├── LICENSE ├── README.md ├── Visual-ARFT ├── README.md ├── assets │ ├── dataset.png │ ├── framework.png │ ├── tesear.png │ └── test.txt ├── demo │ ├── .env │ ├── README.md │ ├── cases │ │ ├── 23.jpg │ │ ├── annotation2.jpg │ │ ├── coding_case.json │ │ ├── crop_13_ori.png │ │ ├── crop_8_ori.png │ │ ├── dark_1_proc.png │ │ ├── none_2_proc.png │ │ ├── rotation180_1_proc.png │ │ └── rotation90+dark_0_proc.png │ ├── coding_demo.ipynb │ ├── search_demo.ipynb │ └── web_search.py ├── evaluation_coding │ ├── README.md │ ├── evaluation.ipynb │ ├── evaluation_mat_coding_ngpu_7b_df.py │ ├── evaluation_mat_coding_visual_arft.py │ └── evaluation_results │ │ ├── 3b_mat_coding_grpo_step1400.json │ │ ├── 3b_mat_coding_grpo_step1400_v2.json │ │ ├── 7b_mat_coding_grpo_step1400.json │ │ └── 7b_mat_coding_grpo_step1400_v2.json ├── evaluation_search │ ├── .env │ ├── MAT-Search-Benchmark │ │ ├── evaluation.ipynb │ │ ├── evaluation_mat_search_ngpu_7b_df.py │ │ └── evaluation_mat_search_ngpu_7b_visual_arft.py │ ├── README.md │ ├── evaluation_results │ │ ├── 2wiki │ │ │ ├── 3b_step200_2wiki_web_n4.json │ │ │ └── 7b_step200_2wiki_web_n4.json │ │ ├── MAT-Search │ │ │ ├── 3b_step200_mat_search_web_n4.json │ │ │ └── 7b_step200_mat_search_web_n4.json │ │ ├── bamboogle │ │ │ ├── 3b_step200_bamboogle_web_n4.json │ │ │ └── 7b_step200_bamboogle_web_n4.json │ │ ├── hotpotqa │ │ │ ├── 3b_step200_hotpotqa_web_n4.json │ │ │ └── 7b_step200_hotpotqa_web_n4.json │ │ └── musique │ │ │ ├── 3b_step200_musique_web_n4.json │ │ │ └── 7b_step200_musique_web_n4.json │ ├── tools │ │ └── web_search.py │ └── traditional_benchmark │ │ ├── evaluation.ipynb │ │ ├── evaluation_2wikimultihopqa_ngpu_7b_df.py │ │ ├── evaluation_2wikimultihopqa_ngpu_7b_visual_arft.py │ │ ├── evaluation_bamboogle_ngpu_7b_df.py │ │ ├── evaluation_bamboogle_ngpu_7b_visual_arft.py │ │ ├── evaluation_hotpotqa_ngpu_7b_df.py │ │ ├── evaluation_hotpotqa_ngpu_7b_visual_arft.py │ │ ├── evaluation_musique_ngpu_7b_df.py │ │ └── evaluation_musique_ngpu_7b_visual_arft.py ├── setup.sh └── src │ ├── scripts │ ├── run_grpo_agent_code_3b_1_2k_new2_gpu8.sh │ ├── run_grpo_agent_code_7b_1_2k_new2_gpu8.sh │ ├── run_grpo_agent_search_3b_gpu8.sh │ └── run_grpo_agent_search_7b_gpu8.sh │ └── visual_arft │ ├── .gitignore │ ├── LICENSE │ ├── Makefile │ ├── configs │ ├── ddp.yaml │ ├── qwen2vl_sft_config.yaml │ ├── zero2.yaml │ └── zero3.yaml │ ├── local_scripts │ ├── create_vision_cot_data.py │ ├── lmms_eval_qwen2vl.sh │ ├── prepare_hf_data.py │ ├── train_aria_moe.sh │ ├── train_qwen2_vl.sh │ ├── zero1_no_optimizer.json │ ├── zero2.json │ ├── zero3.json │ ├── zero3.yaml │ └── zero3_offload.json │ ├── run_grpo.sh │ ├── setup.cfg │ ├── setup.py │ ├── src │ └── open_r1 │ │ ├── __init__.py │ │ ├── evaluate.py │ │ ├── generate.py │ │ ├── grpo.py │ │ ├── grpo_agent_code.py │ │ ├── grpo_agent_search.py │ │ ├── sft.py │ │ └── trainer │ │ ├── __init__.py │ │ ├── grpo_trainer.py │ │ ├── grpo_trainer_aid.py │ │ ├── grpo_trainer_mp.py │ │ ├── vllm_grpo_trainer.py │ │ └── vllm_grpo_trainer_modified.py │ └── temp_image.png ├── assets ├── case_cls.png ├── case_lisa.png ├── framework.png ├── pokeymon.jpg ├── radar.png └── teaser.png ├── classification ├── Qwen2_VL_classification_infere.py └── val_data │ ├── fgvc_aircraft.pth │ ├── fgvc_aircraft.txt │ ├── oxford_flowers.pth │ ├── oxford_flowers.txt │ ├── pets.pth │ ├── pets.txt │ ├── stanford_cars.pth │ └── stanford_cars.txt ├── coco_evaluation ├── Qwen2_VL_coco_infere.py ├── coco_evaluation.py ├── evaluation.ipynb ├── exist_map_coco_Qwen2_vl_2B_baseline.json 
└── exist_map_coco_Qwen2_vl_7B_baseline.json ├── dataset ├── README.md └── build_dataset.ipynb ├── demo ├── README.md └── lisa_demo.ipynb ├── lisa_evaluation ├── Qwen2_VL_lisa_infere.py ├── Qwen2_VL_lisa_infere.sh ├── README.md ├── box2mask.py ├── evaluation.ipynb ├── gen_box_ann.py ├── gen_sft.py ├── mask_iou.py └── merge_eval.py ├── lvis_evaluation ├── Qwen2_VL_lvis_infere.py ├── exist_map_lvis_Qwen2_vl_2B_baseline.json └── exist_map_lvis_Qwen2_vl_7B_baseline.json ├── setup.sh └── src ├── scripts ├── 2B_aircraft_4_shot.sh ├── 2B_base65cate_6k.sh ├── 2B_lisa_grounding.sh ├── 7B_base65cate_6k.sh └── example.sh └── virft ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── configs ├── ddp.yaml ├── zero2.yaml └── zero3.yaml ├── local_scripts ├── create_vision_cot_data.py ├── lmms_eval_qwen2vl.sh ├── prepare_hf_data.py ├── train_aria_moe.sh ├── train_qwen2_vl.sh ├── zero2.json ├── zero3.json ├── zero3.yaml └── zero3_offload.json ├── setup.cfg ├── setup.py ├── slurm ├── evaluate.slurm ├── generate.slurm └── sft.slurm └── src └── open_r1 ├── __init__.py ├── evaluate.py ├── generate.py ├── grpo.py ├── grpo_classification.py ├── grpo_lisa.py ├── sft.py └── trainer ├── __init__.py ├── grpo_trainer.py ├── grpo_trainer_aid.py ├── grpo_trainer_mp.py ├── history ├── grpo_trainer_v1.py └── vllm_grpo_trainer_v1.py ├── vllm_grpo_trainer.py └── vllm_grpo_trainer_modified.py /Visual-ARFT/assets/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/dataset.png -------------------------------------------------------------------------------- /Visual-ARFT/assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/framework.png -------------------------------------------------------------------------------- /Visual-ARFT/assets/tesear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/tesear.png -------------------------------------------------------------------------------- /Visual-ARFT/assets/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Visual-ARFT/demo/.env: -------------------------------------------------------------------------------- 1 | SERPER_API = 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' -------------------------------------------------------------------------------- /Visual-ARFT/demo/README.md: -------------------------------------------------------------------------------- 1 | # Instruction for demos 2 | 3 | ## Prepare Datasets and Models 4 | You can download the MAT Benchmark on 🤗Datasets. 5 | 6 | You can download our models: 🤗Visual-ARFT-Search and 🤗Visual-ARFT-Coding 7 | 8 | ## Inference 9 | Replace the model and benchmark paths in our demo, then run `coding_demo.ipynb` or `search_demo.ipynb` step by step. 10 | 11 | > 🔔 If you want to try **Visual-ARFT-Search**, add your Serper API key to `.env` to enable the web search function. (Registration includes 2,500 free usage credits.)
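You can sanity-check the Serper setup before opening the notebooks with the helper in `demo/web_search.py`. A minimal sketch, assuming you run it from the `demo` folder with the packages imported by `web_search.py` installed (the query string is just an example):

```python
# Quick check that the SERPER_API key in demo/.env works for search_demo.ipynb.
from dotenv import load_dotenv
from web_search import web_search_SERPER_API  # helper defined in demo/web_search.py

load_dotenv()  # reads SERPER_API from the .env file in this folder

# 'fast' mode returns a list of {'title', 'href', 'body'} dicts from the Serper API;
# 'pro' mode additionally scrapes each hit into markdown via Firecrawl (needs FIRECRAWL_API).
results = web_search_SERPER_API("Who directed Uncommon Valor?", search_num=2, search_mode="fast")
for r in results:
    print(r["title"], "->", r["href"])
```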
12 | -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/23.jpg -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/annotation2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/annotation2.jpg -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/coding_case.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 12, 4 | "question": "What is the serving size of the beverage?", 5 | "answers": [ 6 | "1 Can (473ml)" 7 | ], 8 | "ori_image_path": "rotation180_1_ori.png", 9 | "processed_image_path": "rotation180_1_proc.png", 10 | "type": [ 11 | "rotation180" 12 | ], 13 | "split": "simple" 14 | }, 15 | { 16 | "id": 22, 17 | "question": "从图中提取: 发票号码, 发票代码", 18 | "answers": [ 19 | "{'发票号码': '16734841', '发票代码': '112001970102'}" 20 | ], 21 | "ori_image_path": "dark_1_ori.png", 22 | "processed_image_path": "dark_1_proc.png", 23 | "type": [ 24 | "dark" 25 | ], 26 | "split": "simple" 27 | }, 28 | { 29 | "id": 63, 30 | "question": "What is the title of the movie featured in the image?", 31 | "answers": [ 32 | "Uncommon Valor" 33 | ], 34 | "ori_image_path": "none_2_ori.png", 35 | "processed_image_path": "none_2_proc.png", 36 | "type": [ 37 | "none" 38 | ], 39 | "split": "simple" 40 | }, 41 | { 42 | "id": 71, 43 | "question": "从图中提取: 发票号码, 发票代码, 金额", 44 | "answers": [ 45 | "{'发票号码': '00134721', '发票代码': '151002065029', '金额': '#'}" 46 | ], 47 | "ori_image_path": "rotation90+dark_0_ori.png", 48 | "processed_image_path": "rotation90+dark_0_proc.png", 49 | "type": [ 50 | "rotation90", 51 | "dark" 52 | ], 53 | "split": "hard" 54 | }, 55 | { 56 | "id": 119, 57 | "question": "Recognize the text within the [366, 723, 506, 851] of the image. The coordinates have been normalized ranging from 0 to 1000 by the image width and height. ", 58 | "answers": [ 59 | "SINCE 1994" 60 | ], 61 | "ori_image_path": "crop_8_ori.png", 62 | "processed_image_path": "crop_8_proc.png", 63 | "type": [ 64 | "crop" 65 | ], 66 | "split": "hard" 67 | }, 68 | { 69 | "id": 124, 70 | "question": "Recognize the text within the [489, 155, 595, 259] of the image. The coordinates have been normalized ranging from 0 to 1000 by the image width and height. 
", 71 | "answers": [ 72 | "Spaghetti\nGOEMON" 73 | ], 74 | "ori_image_path": "crop_13_ori.png", 75 | "processed_image_path": "crop_13_proc.png", 76 | "type": [ 77 | "crop" 78 | ], 79 | "split": "hard" 80 | } 81 | ] -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/crop_13_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/crop_13_ori.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/crop_8_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/crop_8_ori.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/dark_1_proc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/dark_1_proc.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/none_2_proc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/none_2_proc.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/rotation180_1_proc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/rotation180_1_proc.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/cases/rotation90+dark_0_proc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/rotation90+dark_0_proc.png -------------------------------------------------------------------------------- /Visual-ARFT/demo/web_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from duckduckgo_search import DDGS 3 | import json 4 | import requests 5 | from typing import List, Literal, Optional, Union 6 | from firecrawl import FirecrawlApp 7 | from dotenv import load_dotenv 8 | load_dotenv() 9 | 10 | # 使用免费的 duck duck go 进行网页检索,使用 firecrawl 将网页转化为markdown格式 11 | def web_search_DDG(query: Optional[str], search_num: int = 2, search_mode: str = 'fast') -> Optional[List[str]]: 12 | assert search_mode == 'fast' or search_mode == 'pro' 13 | if search_mode == 'fast': 14 | assert type(query)==str 15 | results = DDGS().text(query, max_results=search_num) 16 | return results 17 | elif search_mode == 'pro': 18 | assert type(query)==str 19 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API")) 20 | results = DDGS().text(query, max_results=search_num) 21 | for result in results: 22 | web_url = result['href'] 23 | # firecrawl_app returns markdown and metadata 24 | web_content = firecrawl_app.scrape_url(web_url) 25 | web_content_markdown = web_content['markdown'] 26 | web_content_metadata = web_content['metadata'] 27 | 
result['web_content_markdown'] = web_content_markdown 28 | result['web_content_metadata'] = web_content_metadata 29 | return results 30 | 31 | def web_search_SERPER_API(query: Optional[str], search_num=2, search_mode: str = 'fast') -> Optional[List[str]]: 32 | assert search_mode == 'fast' or search_mode == 'pro' 33 | if search_mode == 'fast': 34 | assert type(query)==str 35 | url = "https://google.serper.dev/search" 36 | payload = json.dumps({"q": query, "num": search_num}) 37 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'} 38 | response = requests.request("POST", url, headers=headers, data=payload) 39 | response = json.loads(response.text) 40 | results = [] 41 | for item in response['organic']: 42 | results.append( 43 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']} 44 | ) 45 | return results 46 | elif search_mode == 'pro': 47 | assert type(query)==str 48 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API")) 49 | url = "https://google.serper.dev/search" 50 | payload = json.dumps({"q": query, "num": search_num}) 51 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'} 52 | response = requests.request("POST", url, headers=headers, data=payload) 53 | response = json.loads(response.text) 54 | results = [] 55 | for item in response['organic']: 56 | results.append( 57 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']} 58 | ) 59 | for result in results: 60 | web_url = result['href'] 61 | # firecrawl_app returns markdown and metadata 62 | web_content = firecrawl_app.scrape_url(web_url) 63 | web_content_markdown = web_content['markdown'] 64 | web_content_metadata = web_content['metadata'] 65 | result['web_content_markdown'] = web_content_markdown 66 | result['web_content_metadata'] = web_content_metadata 67 | return results -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_coding/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation on MAT-Coding 2 | In this folder, we provide two Python files for evaluation on **MAT-Coding**. 3 | ``evaluation_mat_coding_ngpu_7b_df.py`` is used to evaluate the results of direct inference, while `evaluation_mat_coding_visual_arft.py` is used to evaluate the inference results of the **Visual-Agentic-Coding** model. 4 | 5 | ## Download Dataset and Models 6 | You can download MAT Benchmakr on 🤗Datasets. Use `/MAT/MAT-Benchmark/MAT-Coding.json`. 7 | 8 | You can download our model: 🤗Visual-ARFT-Coding 9 | 10 | ## Inference on MAT-Coding 11 | 12 | To run `evaluation_mat_coding_visual_arft.py`, you need to replace the paths to the model and dataset: 13 | 14 | - Line 175: Replace **model_name** with the actual model path (**Visual-ARFT-Coding**). 15 | - Line 184: Replace **json_path** with the actual dataset path (MAT-Coding.json). 16 | - Line 201: Replace **input_image_path** with the actual image path. 17 | - Line 342: Set the results save path. 18 | 19 | > 🔔The intermediate image processing result will be saved as `cache.jpg` in the current directory, as specified in Line 206. 20 | 21 | To run `evaluation_mat_coding_ngpu_7b_df.py`, you need to replace the paths to the model and dataset: 22 | 23 | - Line 20: Replace **model_name** with the actual model path (**Original Qwen2.5-VL without Visual-ARFT**). 24 | - Line 30: Replace **json_path** with the actual dataset path (MAT-Coding.json). 
25 | - Line 42: Replace **input_image_path** with the actual image path. 26 | - Line 103: Set the results save path. 27 | 28 | > ⏳ The inference time for **Visual-ARFT-Coding-7B** is approximately 1.5 hours, while the **3B model** takes around 50 minutes. The inference time for **Original Qwen2.5-VL without Visual-ARFT** is much faster. 29 | 30 | 31 | ## Evaluation 32 | After obtaining the inference results, run the `evaluation.ipynb` step by step. The `.ipynb` file will provide the final evaluation scores. 33 | 34 | We have saved the inference results of the `Visual-ARFT-Coding-7B` model in the `evaluation_results` folder. You can directly use the results inside to test the evaluation scores. 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_coding/evaluation_mat_coding_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | import string 5 | import numpy as np 6 | from tqdm import tqdm 7 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 8 | from qwen_vl_utils import process_vision_info 9 | import matplotlib.pyplot as plt 10 | import matplotlib.image as mpimg 11 | from PIL import Image 12 | 13 | 14 | # 定义颜色的ANSI代码 15 | RED = '\033[91m' 16 | GREEN = '\033[92m' 17 | YELLOW = '\033[93m' 18 | RESET = '\033[0m' # 重置颜色 19 | 20 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 21 | 22 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 23 | model_path, 24 | torch_dtype=torch.bfloat16, 25 | attn_implementation="flash_attention_2", 26 | device_map='cuda:0', 27 | ) 28 | processor = AutoProcessor.from_pretrained(model_path) 29 | 30 | with open('/share_data/MAT/MAT-Benchmark/MAT-Coding.json', 'r') as file: 31 | wikimultihopqa = json.load(file) 32 | print(len(wikimultihopqa)) 33 | 34 | 35 | combine_results = [] 36 | for item in tqdm(wikimultihopqa[:]): 37 | print("########################################") 38 | if item['type'][0] == 'crop': 39 | input_image_path = item['ori_image_path'] 40 | else: 41 | input_image_path = item['processed_image_path'] 42 | input_image_path = '/share_data/MAT/MAT-Benchmark/MAT-Coding-image/' + input_image_path 43 | query = item['question'] 44 | data_type = item['type'] 45 | item_id = item['id'] 46 | answer = item['answers'] 47 | 48 | # ### Direct Inference 49 | input_text = query + '\n' + "Answer the question directly. The answer should be very brief." 50 | 51 | ### CoT 52 | # input_text = SYSTEM_PROMPT + '\n' + query + "\n" + "You must output your thinking processs in . The answer between should be very brief." 
53 | 54 | print(RED+input_text+RESET) 55 | print(GREEN+str(answer)+RESET) 56 | 57 | try: 58 | ###### 进行一轮推理 ###### 59 | messages = [ 60 | { "role": "user", "content": [{"type": "image","image": input_image_path}, {"type": "text", "text": input_text}]} 61 | ] 62 | # Preparation for inference 63 | text = processor.apply_chat_template( 64 | messages, tokenize=False, add_generation_prompt=True 65 | ) 66 | image_inputs, video_inputs = process_vision_info(messages) 67 | inputs = processor( 68 | text=[text], 69 | images=image_inputs, 70 | videos=video_inputs, 71 | padding=True, 72 | return_tensors="pt", 73 | ) 74 | inputs = inputs.to(model.device) 75 | 76 | # Inference: Generation of the output 77 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 78 | generated_ids_trimmed = [ 79 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 80 | ] 81 | output_text = processor.batch_decode( 82 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 83 | ) 84 | result = output_text[0] 85 | 86 | pred_answer = result 87 | print(YELLOW+result+RESET) 88 | 89 | # match = re.search(r"(.*?)", result, re.DOTALL) 90 | # if match: 91 | # pred_answer = match.group(1).strip() 92 | # print(YELLOW+result+RESET) 93 | # print(YELLOW+pred_answer+RESET) 94 | 95 | except Exception as e: 96 | print("ERROR OCCURES") 97 | print({e}) 98 | combine_results.append( 99 | {'pred_answer': pred_answer, 'gt': answer, 'query': query} 100 | ) 101 | 102 | print(len(combine_results)) 103 | with open(f"./7B_df.json", "w", encoding="utf-8") as f: 104 | json.dump(combine_results, f, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/.env: -------------------------------------------------------------------------------- 1 | SERPER_API = 'xxxxxxxxxxxxxxxxxxxxxxxxx' -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/MAT-Search-Benchmark/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import re\n", 10 | "import string\n", 11 | "def normalize(s):\n", 12 | " def remove_articles(text):\n", 13 | " return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n", 14 | "\n", 15 | " def white_space_fix(text):\n", 16 | " return \" \".join(text.split())\n", 17 | "\n", 18 | " def remove_punc(text):\n", 19 | " exclude = set(string.punctuation)\n", 20 | " return \"\".join(ch for ch in text if ch not in exclude)\n", 21 | "\n", 22 | " def lower(text):\n", 23 | " return text.lower()\n", 24 | "\n", 25 | " return white_space_fix(remove_articles(remove_punc(lower(s))))\n", 26 | "\n", 27 | "def compute_f1(prediction, ground_truth):\n", 28 | " if prediction is None:\n", 29 | " return 0.0\n", 30 | " prediction_tokens = normalize(prediction).split()\n", 31 | " ground_truth_tokens = normalize(ground_truth).split()\n", 32 | "\n", 33 | " common = set(prediction_tokens) & set(ground_truth_tokens)\n", 34 | " num_same = len(common)\n", 35 | "\n", 36 | " if num_same == 0:\n", 37 | " return 0.0\n", 38 | "\n", 39 | " precision = num_same / len(prediction_tokens)\n", 40 | " recall = num_same / len(ground_truth_tokens)\n", 41 | " f1 = 2 * precision * recall / (precision + recall)\n", 42 | " return f1\n", 43 | "\n", 44 | "def exact_match_score(prediction, ground_truth):\n", 45 | " 
if prediction is None:\n", 46 | " return 0.0\n", 47 | " return int(normalize(prediction) == normalize(ground_truth))\n", 48 | "\n", 49 | "def evaluate(predictions):\n", 50 | " total = len(predictions)\n", 51 | " f1_total = []\n", 52 | " em_total = []\n", 53 | "\n", 54 | " for item in predictions:\n", 55 | " # if item['pred_answer_ori'] == None:\n", 56 | " pred = item['pred_answer']\n", 57 | " # else:\n", 58 | " # pred = item['pred_answer_ori']\n", 59 | " gts = item['gt']\n", 60 | "\n", 61 | " # 若gt是str,统一转换为列表处理\n", 62 | " if isinstance(gts, str):\n", 63 | " gts = [gts]\n", 64 | "\n", 65 | " f1 = max([compute_f1(pred, gt) for gt in gts])\n", 66 | " em = max([exact_match_score(pred, gt) for gt in gts])\n", 67 | " if em == 1:\n", 68 | " f1 = 1\n", 69 | "\n", 70 | " f1_total.append(f1)\n", 71 | " em_total.append(em)\n", 72 | "\n", 73 | " return {\n", 74 | " \"avg_f1\": sum(f1_total) / total if total > 0 else 0,\n", 75 | " \"avg_em\": sum(em_total) / total if total > 0 else 0,\n", 76 | " \"simple_f1\": sum(f1_total[:75]) / 75 if total > 0 else 0,\n", 77 | " \"simple_em\": sum(em_total[:75]) / 75 if total > 0 else 0,\n", 78 | " \"hard_f1\": sum(f1_total[75:]) / 75 if total > 0 else 0,\n", 79 | " \"hard_em\": sum(em_total[75:]) / 75 if total > 0 else 0,\n", 80 | " }" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# 12576 125 7405 2417\n", 90 | "\n", 91 | "import json\n", 92 | "with open('7B_df.json', 'r') as f:\n", 93 | " combine_results = json.load(f)\n", 94 | "print(len(combine_results))\n", 95 | "\n", 96 | "count_none = 0\n", 97 | "for item in combine_results:\n", 98 | " if item['pred_answer'] == None:\n", 99 | " count_none += 1\n", 100 | "print(count_none)\n", 101 | "results = evaluate(combine_results)\n", 102 | "print(*[f\"{results[k]*100:.2f}\" for k in ['simple_f1', 'simple_em', 'hard_f1', 'hard_em', 'avg_f1', 'avg_em']])" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "r1-v", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.11.11" 130 | }, 131 | "orig_nbformat": 4 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 2 135 | } 136 | -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/MAT-Search-Benchmark/evaluation_mat_search_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | import json 5 | import torch 6 | from tqdm import tqdm 7 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 8 | from qwen_vl_utils import process_vision_info 9 | 10 | # 定义颜色的ANSI代码 11 | RED = '\033[91m' 12 | GREEN = '\033[92m' 13 | YELLOW = '\033[93m' 14 | RESET = '\033[0m' # 重置颜色 15 | 16 | 17 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 18 | 19 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 20 | model_path, 21 | torch_dtype=torch.bfloat16, 22 | 
attn_implementation="flash_attention_2", 23 | device_map='cuda:0', 24 | ) 25 | processor = AutoProcessor.from_pretrained(model_path) 26 | 27 | with open('/share_data/MAT/MAT-Benchmark/MAT-Coding.json', 'r') as file: 28 | wikimultihopqa = json.load(file) 29 | print(len(wikimultihopqa)) 30 | 31 | combine_results = [] 32 | for i in tqdm(range(len(wikimultihopqa))): 33 | pred_answer = None 34 | query = wikimultihopqa[i]['question'] 35 | answer = wikimultihopqa[i]['answer'] 36 | image_path = wikimultihopqa[i]['image_path'] 37 | input_image_path = '/MAT/MAT-Benchmark/MAT-Coding-image/' + image_path 38 | item_id = wikimultihopqa[i]['id'] 39 | 40 | # ### Direct Inference 41 | input_text = query + '\n' + "Answer the question directly. The answer should be very brief." 42 | 43 | ### CoT 44 | # input_text = SYSTEM_PROMPT + '\n' + query + "\n" + "You must output your thinking processs in . The answer between should be very brief." 45 | 46 | print(RED+input_text+RESET) 47 | print(GREEN+str(answer)+RESET) 48 | 49 | try: 50 | ###### 进行一轮推理 ###### 51 | messages = [ 52 | { "role": "user", "content": [{"type": "image","image": input_image_path}, {"type": "text", "text": input_text}]} 53 | ] 54 | # Preparation for inference 55 | text = processor.apply_chat_template( 56 | messages, tokenize=False, add_generation_prompt=True 57 | ) 58 | image_inputs, video_inputs = process_vision_info(messages) 59 | inputs = processor( 60 | text=[text], 61 | images=image_inputs, 62 | videos=video_inputs, 63 | padding=True, 64 | return_tensors="pt", 65 | ) 66 | inputs = inputs.to(model.device) 67 | 68 | # Inference: Generation of the output 69 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 70 | generated_ids_trimmed = [ 71 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 72 | ] 73 | output_text = processor.batch_decode( 74 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 75 | ) 76 | result = output_text[0] 77 | 78 | pred_answer = result 79 | print(YELLOW+result+RESET) 80 | 81 | # match = re.search(r"(.*?)", result, re.DOTALL) 82 | # if match: 83 | # pred_answer = match.group(1).strip() 84 | # print(YELLOW+result+RESET) 85 | # print(YELLOW+pred_answer+RESET) 86 | 87 | except Exception as e: 88 | print("ERROR OCCURES") 89 | print({e}) 90 | combine_results.append( 91 | {'pred_answer': pred_answer, 'gt': answer, 'query': query} 92 | ) 93 | 94 | print(len(combine_results)) 95 | with open(f"./7B_df.json", "w", encoding="utf-8") as f: 96 | json.dump(combine_results, f, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation on MAT-Search 2 | In this folder, we provide scripts for evaluation on **MAT-Search** and other four multihopQA Benchmark (2wikimultihopQA, HotpotQA, MuSiQue, Bamboogle). 3 | ``MAT-Search-Benchmark`` folder is used to evaluate the results on **MAT-Search**, while `traditional_benchmark` folder is used to evaluate the inference results of the other four multihopQA Benchmark. 4 | 5 | ## Download Dataset and Models 6 | You can download MAT Benchmakr on 🤗Datasets. Use `/MAT/MAT-Benchmark/MAT-Search.json`. 7 | 8 | You can download our model: 🤗Visual-ARFT-Search 9 | 10 | You can download other four multihopQA benchmark from 🤗Dataset. 11 | 12 | ## Web Search API 13 | We use SerperAPI for web search. 
You can start by registering an account to receive 2,500 free queries, and then add your API key to the `.env` file. 14 | 15 | ## Inference on MAT-Search 16 | 17 | First, `cd MAT-Search-Benchmark`. In this folder, there are `evaluation_mat_search_ngpu_7b_df.py` and `evaluation_mat_search_ngpu_7b_visual_arft.py`. The first script evaluates direct inference, while the second evaluates the **Visual-ARFT-Search** model. 18 | 19 | To run `evaluation_mat_search_ngpu_7b_visual_arft.py`, you need to replace the paths to the model and dataset: 20 | 21 | - Line 59: Replace **model_name** with the actual model path (**Visual-ARFT-Search**). 22 | - Line 76: Replace **json_path** with the actual dataset path (MAT-Search.json). 23 | - Line 97: Replace **input_image_path** with the actual image path. 24 | - Line 192: Set the results save path. 25 | 26 | To run `evaluation_mat_search_ngpu_7b_df.py`, you need to replace the paths to the model and dataset: 27 | 28 | - Line 17: Replace **model_name** with the actual model path (**Original Qwen2.5-VL without Visual-ARFT**). 29 | - Line 27: Replace **json_path** with the actual dataset path (MAT-Search.json). 30 | - Line 37: Replace **input_image_path** with the actual image path. 31 | - Line 95: Set the results save path. 32 | 33 | > ⏳ The code `evaluation_mat_search_ngpu_7b_visual_arft.py` supports multi-GPU execution. The inference time for **Visual-ARFT-Search-3/7B** is around ten minutes with 8 GPUs. 34 | 35 | ## Inference on Other MultihopQA Benchmarks 36 | 37 | First, `cd traditional_benchmark`. Then replace the paths as instructed above. 38 | 39 | **2Wiki** contains over 12k questions and **HotpotQA** has over 7k, so inference may take a relatively long time. However, all scripts in the `traditional_benchmark` folder support multi-GPU inference. 40 | 41 | ## Evaluation 42 | After obtaining the inference results, run the `evaluation.ipynb` in each folder step by step. The `.ipynb` file will provide the final evaluation scores. 43 | 44 | We have also saved the inference results of the `Visual-ARFT-Search-7B` model in the `evaluation_results` folder. You can use these results directly to test the evaluation scores.
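For reference, the scoring in `evaluation.ipynb` boils down to token-level F1 and exact match over normalized answers (lowercased, with punctuation and articles removed). A condensed, self-contained sketch of that logic, using illustrative strings:

```python
# Condensed version of the F1 / EM scoring used in evaluation.ipynb.
import re
import string

def normalize(s):
    s = "".join(ch for ch in s.lower() if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)          # drop articles
    return " ".join(s.split())                     # collapse whitespace

def f1(pred, gt):
    p, g = normalize(pred).split(), normalize(gt).split()
    common = len(set(p) & set(g))
    if common == 0:
        return 0.0
    precision, recall = common / len(p), common / len(g)
    return 2 * precision * recall / (precision + recall)

pred, gold = "The answer is Uncommon Valor.", "Uncommon Valor"
print(round(f1(pred, gold), 3))                    # 0.667 (2 shared tokens out of 4 vs 2)
print(int(normalize(pred) == normalize(gold)))     # 0 (not an exact match)
```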
45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/tools/web_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from duckduckgo_search import DDGS 3 | import json 4 | import requests 5 | from typing import List, Literal, Optional, Union 6 | from firecrawl import FirecrawlApp 7 | from dotenv import load_dotenv 8 | load_dotenv() 9 | 10 | # 使用免费的 duck duck go 进行网页检索,使用 firecrawl 将网页转化为markdown格式 11 | def web_search_DDG(query: Optional[str], search_num: int = 2, search_mode: str = 'fast') -> Optional[List[str]]: 12 | assert search_mode == 'fast' or search_mode == 'pro' 13 | if search_mode == 'fast': 14 | assert type(query)==str 15 | results = DDGS().text(query, max_results=search_num) 16 | return results 17 | elif search_mode == 'pro': 18 | assert type(query)==str 19 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API")) 20 | results = DDGS().text(query, max_results=search_num) 21 | for result in results: 22 | web_url = result['href'] 23 | # firecrawl_app returns markdown and metadata 24 | web_content = firecrawl_app.scrape_url(web_url) 25 | web_content_markdown = web_content['markdown'] 26 | web_content_metadata = web_content['metadata'] 27 | result['web_content_markdown'] = web_content_markdown 28 | result['web_content_metadata'] = web_content_metadata 29 | return results 30 | 31 | def web_search_SERPER_API(query: Optional[str], search_num=2, search_mode: str = 'fast') -> Optional[List[str]]: 32 | assert search_mode == 'fast' or search_mode == 'pro' 33 | if search_mode == 'fast': 34 | assert type(query)==str 35 | url = "https://google.serper.dev/search" 36 | payload = json.dumps({"q": query, "num": search_num}) 37 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'} 38 | response = requests.request("POST", url, headers=headers, data=payload) 39 | response = json.loads(response.text) 40 | results = [] 41 | for item in response['organic']: 42 | results.append( 43 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']} 44 | ) 45 | return results 46 | elif search_mode == 'pro': 47 | assert type(query)==str 48 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API")) 49 | url = "https://google.serper.dev/search" 50 | payload = json.dumps({"q": query, "num": search_num}) 51 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'} 52 | response = requests.request("POST", url, headers=headers, data=payload) 53 | response = json.loads(response.text) 54 | results = [] 55 | for item in response['organic']: 56 | results.append( 57 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']} 58 | ) 59 | for result in results: 60 | web_url = result['href'] 61 | # firecrawl_app returns markdown and metadata 62 | web_content = firecrawl_app.scrape_url(web_url) 63 | web_content_markdown = web_content['markdown'] 64 | web_content_metadata = web_content['metadata'] 65 | result['web_content_markdown'] = web_content_markdown 66 | result['web_content_metadata'] = web_content_metadata 67 | return results -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/traditional_benchmark/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | 
"outputs": [], 8 | "source": [ 9 | "import re\n", 10 | "import string\n", 11 | "def normalize(s):\n", 12 | " def remove_articles(text):\n", 13 | " return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n", 14 | "\n", 15 | " def white_space_fix(text):\n", 16 | " return \" \".join(text.split())\n", 17 | "\n", 18 | " def remove_punc(text):\n", 19 | " exclude = set(string.punctuation)\n", 20 | " return \"\".join(ch for ch in text if ch not in exclude)\n", 21 | "\n", 22 | " def lower(text):\n", 23 | " return text.lower()\n", 24 | "\n", 25 | " return white_space_fix(remove_articles(remove_punc(lower(s))))\n", 26 | "\n", 27 | "def compute_f1(prediction, ground_truth):\n", 28 | " if prediction is None:\n", 29 | " return 0.0\n", 30 | " prediction_tokens = normalize(prediction).split()\n", 31 | " ground_truth_tokens = normalize(ground_truth).split()\n", 32 | "\n", 33 | " common = set(prediction_tokens) & set(ground_truth_tokens)\n", 34 | " num_same = len(common)\n", 35 | "\n", 36 | " if num_same == 0:\n", 37 | " return 0.0\n", 38 | "\n", 39 | " precision = num_same / len(prediction_tokens)\n", 40 | " recall = num_same / len(ground_truth_tokens)\n", 41 | " f1 = 2 * precision * recall / (precision + recall)\n", 42 | " return f1\n", 43 | "\n", 44 | "def exact_match_score(prediction, ground_truth):\n", 45 | " if prediction is None:\n", 46 | " return 0.0\n", 47 | " return int(normalize(prediction) == normalize(ground_truth))\n", 48 | "\n", 49 | "def evaluate(predictions):\n", 50 | " total = len(predictions)\n", 51 | " f1_total = 0\n", 52 | " em_total = 0\n", 53 | "\n", 54 | " for item in predictions:\n", 55 | " # if item['pred_answer_ori'] == None:\n", 56 | " pred = item['pred_answer']\n", 57 | " # else:\n", 58 | " # pred = item['pred_answer_ori']\n", 59 | " gts = item['gt']\n", 60 | "\n", 61 | " # 若gt是str,统一转换为列表处理\n", 62 | " if isinstance(gts, str):\n", 63 | " gts = [gts]\n", 64 | "\n", 65 | " f1 = max([compute_f1(pred, gt) for gt in gts])\n", 66 | " em = max([exact_match_score(pred, gt) for gt in gts])\n", 67 | " if em == 1:\n", 68 | " f1 = 1\n", 69 | "\n", 70 | " f1_total += f1\n", 71 | " em_total += em\n", 72 | "\n", 73 | " return {\n", 74 | " \"avg_f1\": f1_total / total if total > 0 else 0,\n", 75 | " \"avg_em\": em_total / total if total > 0 else 0\n", 76 | " }" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "import json\n", 86 | "\n", 87 | "# results list:\n", 88 | "json_file_paths = [\n", 89 | " \"7b_step200_musique_web_n4_0.json\",\n", 90 | " \"7b_step200_musique_web_n4_1.json\",\n", 91 | " \"7b_step200_musique_web_n4_2.json\",\n", 92 | " \"7b_step200_musique_web_n4_3.json\",\n", 93 | " \"7b_step200_musique_web_n4_4.json\",\n", 94 | " \"7b_step200_musique_web_n4_5.json\",\n", 95 | " \"7b_step200_musique_web_n4_6.json\",\n", 96 | " \"7b_step200_musique_web_n4_7.json\",\n", 97 | "]\n", 98 | "\n", 99 | "# combined to one file\n", 100 | "merged_data = []\n", 101 | "for path in json_file_paths:\n", 102 | " with open(path, \"r\", encoding=\"utf-8\") as f:\n", 103 | " data = json.load(f)\n", 104 | " merged_data.extend(data)\n", 105 | "\n", 106 | "print(len(merged_data))\n", 107 | "# save\n", 108 | "with open(\"7b_step200_musique_web_n4.json\", \"w\", encoding=\"utf-8\") as f:\n", 109 | " json.dump(merged_data, f, ensure_ascii=False, indent=2)\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import json\n", 119 
| "with open('./evaluation_results/musique/7b_step200_musique_web_n4.json', 'r') as f:\n", 120 | " combine_results = json.load(f)\n", 121 | "print(len(combine_results))\n", 122 | "\n", 123 | "count_none = 0\n", 124 | "for item in combine_results:\n", 125 | " if item['pred_answer'] == None:\n", 126 | " count_none += 1\n", 127 | "print(count_none)\n", 128 | "results = evaluate(combine_results)\n", 129 | "print(\"Average F1:\", results['avg_f1'])\n", 130 | "print(\"Average EM:\", results['avg_em'])" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "r1-v", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.11.11" 158 | }, 159 | "orig_nbformat": 4 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 2 163 | } 164 | -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_2wikimultihopqa_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 5 | 6 | # 定义颜色的ANSI代码 7 | RED = '\033[91m' 8 | GREEN = '\033[92m' 9 | YELLOW = '\033[93m' 10 | RESET = '\033[0m' # 重置颜色 11 | 12 | import logging 13 | logging.basicConfig() 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.INFO) 16 | 17 | import functools 18 | import itertools 19 | import multiprocessing as mp 20 | from argparse import ArgumentParser 21 | from multiprocessing import Pool 22 | 23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 24 | 25 | def run(rank, world_size): 26 | ### 多卡时候,device_map需要设置为cpu,再分配到不同GPU上,不能设置为 auto 27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.bfloat16, 30 | attn_implementation="flash_attention_2", 31 | device_map='cpu', 32 | ) 33 | processor = AutoProcessor.from_pretrained(model_path) 34 | 35 | model = model.to(torch.device(rank)) 36 | model = model.eval() 37 | 38 | # 加载数据集数据 39 | wikimultihopqa = [] 40 | with open('/2wikimultihopqa/dev.jsonl', 'r', encoding='utf-8') as f: 41 | for line in f: 42 | if line.strip(): # 跳过空行 43 | wikimultihopqa.append(json.loads(line)) 44 | print(len(wikimultihopqa)) 45 | 46 | # wikimultihopqa = wikimultihopqa[:4] 47 | 48 | print(wikimultihopqa[0]['question']) 49 | print(wikimultihopqa[0]['golden_answers']) 50 | print("Rank:" + str(rank)) 51 | print("World Size:" + str(world_size)) 52 | import math 53 | split_length = math.ceil(len(wikimultihopqa)/world_size) 54 | print("Split Chunk Length:" + str(split_length)) 55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)] 56 | print(len(split_wikimultihopqa)) 57 | wikimultihopqa = split_wikimultihopqa 58 | 59 | combine_results = [] 60 | for i in tqdm(range(len(wikimultihopqa))): 61 | pred_answer = None 62 | query = wikimultihopqa[i]['question'] 63 | answer = wikimultihopqa[i]['golden_answers'] 64 | item_id = wikimultihopqa[i]['id'] 65 | # + 'Answer the question directly.' 
66 | input_text = query + '\n' + 'Only Return the answer.' 67 | # print("################################################################") 68 | # print(query) 69 | # print(answer) 70 | try: 71 | messages = [ 72 | { "role": "user", 73 | "content": [{"type": "text", "text": input_text}]} 74 | ] 75 | # Preparation for inference 76 | text = processor.apply_chat_template( 77 | messages, tokenize=False, add_generation_prompt=True 78 | ) 79 | # image_inputs, video_inputs = process_vision_info(messages) 80 | inputs = processor( 81 | text=[text], 82 | # images=image_inputs, 83 | # videos=video_inputs, 84 | padding=True, 85 | return_tensors="pt", 86 | ) 87 | inputs = inputs.to(model.device) 88 | 89 | # Inference: Generation of the output 90 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 91 | generated_ids_trimmed = [ 92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 93 | ] 94 | output_text = processor.batch_decode( 95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 96 | ) 97 | result = output_text[0] 98 | # print(result) 99 | pred_answer = result 100 | 101 | except Exception as e: 102 | logger.info("ERROR OCCURES") 103 | logger.info({e}) 104 | if pred_answer != None: 105 | combine_results.append( 106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query} 107 | ) 108 | else: 109 | combine_results.append( 110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query} 111 | ) 112 | with open(f"7b_ori_2wiki_direct_infere_{rank}.json", "w", encoding="utf-8") as f: 113 | json.dump(combine_results, f, ensure_ascii=False, indent=4) 114 | return combine_results 115 | 116 | def main(): 117 | multiprocess = torch.cuda.device_count() >= 2 118 | mp.set_start_method('spawn') 119 | if multiprocess: 120 | logger.info('started generation') 121 | n_gpus = torch.cuda.device_count() 122 | world_size = n_gpus 123 | with Pool(world_size) as pool: 124 | func = functools.partial(run, world_size=world_size) 125 | result_lists = pool.map(func, range(world_size)) 126 | 127 | global_results = [] 128 | for i in range(world_size): 129 | global_results = global_results + result_lists[i] 130 | 131 | logger.info("Done") 132 | logger.info('finished running') 133 | else: 134 | logger.info("Not enough GPUs") 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_bamboogle_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 5 | 6 | # 定义颜色的ANSI代码 7 | RED = '\033[91m' 8 | GREEN = '\033[92m' 9 | YELLOW = '\033[93m' 10 | RESET = '\033[0m' # 重置颜色 11 | 12 | import logging 13 | logging.basicConfig() 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.INFO) 16 | 17 | import functools 18 | import itertools 19 | import multiprocessing as mp 20 | from argparse import ArgumentParser 21 | from multiprocessing import Pool 22 | 23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 24 | 25 | def run(rank, world_size): 26 | ### 多卡时候,device_map需要设置为cpu,再分配到不同GPU上,不能设置为 auto 27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.bfloat16, 30 | attn_implementation="flash_attention_2", 31 | device_map='cpu', 32 | ) 33 | 
processor = AutoProcessor.from_pretrained(model_path) 34 | 35 | model = model.to(torch.device(rank)) 36 | model = model.eval() 37 | 38 | # 加载数据集数据 39 | wikimultihopqa = [] 40 | with open('/bamboogle/test.jsonl', 'r', encoding='utf-8') as f: 41 | for line in f: 42 | if line.strip(): # 跳过空行 43 | wikimultihopqa.append(json.loads(line)) 44 | print(len(wikimultihopqa)) 45 | 46 | # wikimultihopqa = wikimultihopqa[:4] 47 | 48 | print(wikimultihopqa[0]['question']) 49 | print(wikimultihopqa[0]['golden_answers']) 50 | print("Rank:" + str(rank)) 51 | print("World Size:" + str(world_size)) 52 | import math 53 | split_length = math.ceil(len(wikimultihopqa)/world_size) 54 | print("Split Chunk Length:" + str(split_length)) 55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)] 56 | print(len(split_wikimultihopqa)) 57 | wikimultihopqa = split_wikimultihopqa 58 | 59 | combine_results = [] 60 | for i in tqdm(range(len(wikimultihopqa))): 61 | pred_answer = None 62 | query = wikimultihopqa[i]['question'] 63 | answer = wikimultihopqa[i]['golden_answers'] 64 | item_id = wikimultihopqa[i]['id'] 65 | # + 'Answer the question directly.' 66 | input_text = query + '\n' + 'Only Return the answer.' 67 | # print("################################################################") 68 | # print(query) 69 | # print(answer) 70 | try: 71 | messages = [ 72 | { "role": "user", 73 | "content": [{"type": "text", "text": input_text}]} 74 | ] 75 | # Preparation for inference 76 | text = processor.apply_chat_template( 77 | messages, tokenize=False, add_generation_prompt=True 78 | ) 79 | # image_inputs, video_inputs = process_vision_info(messages) 80 | inputs = processor( 81 | text=[text], 82 | # images=image_inputs, 83 | # videos=video_inputs, 84 | padding=True, 85 | return_tensors="pt", 86 | ) 87 | inputs = inputs.to(model.device) 88 | 89 | # Inference: Generation of the output 90 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 91 | generated_ids_trimmed = [ 92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 93 | ] 94 | output_text = processor.batch_decode( 95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 96 | ) 97 | result = output_text[0] 98 | # print(result) 99 | pred_answer = result 100 | 101 | except Exception as e: 102 | logger.info("ERROR OCCURES") 103 | logger.info({e}) 104 | if pred_answer != None: 105 | combine_results.append( 106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query} 107 | ) 108 | else: 109 | combine_results.append( 110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query} 111 | ) 112 | with open(f"7b_ori_bamboogle_direct_infere_{rank}.json", "w", encoding="utf-8") as f: 113 | json.dump(combine_results, f, ensure_ascii=False, indent=4) 114 | return combine_results 115 | 116 | def main(): 117 | multiprocess = torch.cuda.device_count() >= 2 118 | mp.set_start_method('spawn') 119 | if multiprocess: 120 | logger.info('started generation') 121 | n_gpus = torch.cuda.device_count() 122 | world_size = n_gpus 123 | with Pool(world_size) as pool: 124 | func = functools.partial(run, world_size=world_size) 125 | result_lists = pool.map(func, range(world_size)) 126 | 127 | global_results = [] 128 | for i in range(world_size): 129 | global_results = global_results + result_lists[i] 130 | 131 | logger.info("Done") 132 | logger.info('finished running') 133 | else: 134 | logger.info("Not enough GPUs") 135 | 136 | if __name__ == "__main__": 137 | main() 
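Each rank writes its own `7b_ori_bamboogle_direct_infere_{rank}.json`, so the per-rank outputs need to be concatenated before scoring, as the merge cell in `evaluation.ipynb` does. A minimal sketch of that step (the merged filename is just an example):

```python
# Merge the per-rank result files produced by the multi-GPU run into one list.
import glob
import json

merged = []
for path in sorted(glob.glob("7b_ori_bamboogle_direct_infere_[0-9]*.json")):
    with open(path, "r", encoding="utf-8") as f:
        merged.extend(json.load(f))

print(len(merged), "predictions")
with open("7b_ori_bamboogle_direct_infere_all.json", "w", encoding="utf-8") as f:
    json.dump(merged, f, ensure_ascii=False, indent=2)
```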
-------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_hotpotqa_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 5 | 6 | # 定义颜色的ANSI代码 7 | RED = '\033[91m' 8 | GREEN = '\033[92m' 9 | YELLOW = '\033[93m' 10 | RESET = '\033[0m' # 重置颜色 11 | 12 | import logging 13 | logging.basicConfig() 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.INFO) 16 | 17 | import functools 18 | import itertools 19 | import multiprocessing as mp 20 | from argparse import ArgumentParser 21 | from multiprocessing import Pool 22 | 23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 24 | 25 | def run(rank, world_size): 26 | ### 多卡时候,device_map需要设置为cpu,再分配到不同GPU上,不能设置为 auto 27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.bfloat16, 30 | attn_implementation="flash_attention_2", 31 | device_map='cpu', 32 | ) 33 | processor = AutoProcessor.from_pretrained(model_path) 34 | 35 | model = model.to(torch.device(rank)) 36 | model = model.eval() 37 | 38 | # 加载数据集数据 39 | wikimultihopqa = [] 40 | with open('/hotpotqa/dev.jsonl', 'r', encoding='utf-8') as f: 41 | for line in f: 42 | if line.strip(): # 跳过空行 43 | wikimultihopqa.append(json.loads(line)) 44 | print(len(wikimultihopqa)) 45 | 46 | # wikimultihopqa = wikimultihopqa[:4] 47 | 48 | print(wikimultihopqa[0]['question']) 49 | print(wikimultihopqa[0]['golden_answers']) 50 | print("Rank:" + str(rank)) 51 | print("World Size:" + str(world_size)) 52 | import math 53 | split_length = math.ceil(len(wikimultihopqa)/world_size) 54 | print("Split Chunk Length:" + str(split_length)) 55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)] 56 | print(len(split_wikimultihopqa)) 57 | wikimultihopqa = split_wikimultihopqa 58 | 59 | combine_results = [] 60 | for i in tqdm(range(len(wikimultihopqa))): 61 | pred_answer = None 62 | query = wikimultihopqa[i]['question'] 63 | answer = wikimultihopqa[i]['golden_answers'] 64 | item_id = wikimultihopqa[i]['id'] 65 | # + 'Answer the question directly.' 66 | input_text = query + '\n' + 'Only Return the answer.' 
67 | # print("################################################################") 68 | # print(query) 69 | # print(answer) 70 | try: 71 | messages = [ 72 | { "role": "user", 73 | "content": [{"type": "text", "text": input_text}]} 74 | ] 75 | # Preparation for inference 76 | text = processor.apply_chat_template( 77 | messages, tokenize=False, add_generation_prompt=True 78 | ) 79 | # image_inputs, video_inputs = process_vision_info(messages) 80 | inputs = processor( 81 | text=[text], 82 | # images=image_inputs, 83 | # videos=video_inputs, 84 | padding=True, 85 | return_tensors="pt", 86 | ) 87 | inputs = inputs.to(model.device) 88 | 89 | # Inference: Generation of the output 90 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 91 | generated_ids_trimmed = [ 92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 93 | ] 94 | output_text = processor.batch_decode( 95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 96 | ) 97 | result = output_text[0] 98 | # print(result) 99 | pred_answer = result 100 | 101 | except Exception as e: 102 | logger.info("ERROR OCCURES") 103 | logger.info({e}) 104 | if pred_answer != None: 105 | combine_results.append( 106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query} 107 | ) 108 | else: 109 | combine_results.append( 110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query} 111 | ) 112 | with open(f"7b_ori_hotpotqa_direct_infere_{rank}.json", "w", encoding="utf-8") as f: 113 | json.dump(combine_results, f, ensure_ascii=False, indent=4) 114 | return combine_results 115 | 116 | def main(): 117 | multiprocess = torch.cuda.device_count() >= 2 118 | mp.set_start_method('spawn') 119 | if multiprocess: 120 | logger.info('started generation') 121 | n_gpus = torch.cuda.device_count() 122 | world_size = n_gpus 123 | with Pool(world_size) as pool: 124 | func = functools.partial(run, world_size=world_size) 125 | result_lists = pool.map(func, range(world_size)) 126 | 127 | global_results = [] 128 | for i in range(world_size): 129 | global_results = global_results + result_lists[i] 130 | 131 | logger.info("Done") 132 | logger.info('finished running') 133 | else: 134 | logger.info("Not enough GPUs") 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_musique_ngpu_7b_df.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 5 | 6 | # 定义颜色的ANSI代码 7 | RED = '\033[91m' 8 | GREEN = '\033[92m' 9 | YELLOW = '\033[93m' 10 | RESET = '\033[0m' # 重置颜色 11 | 12 | import logging 13 | logging.basicConfig() 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.INFO) 16 | 17 | import functools 18 | import itertools 19 | import multiprocessing as mp 20 | from argparse import ArgumentParser 21 | from multiprocessing import Pool 22 | 23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct" 24 | 25 | def run(rank, world_size): 26 | ### 多卡时候,device_map需要设置为cpu,再分配到不同GPU上,不能设置为 auto 27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.bfloat16, 30 | attn_implementation="flash_attention_2", 31 | device_map='cpu', 32 | ) 33 | processor = AutoProcessor.from_pretrained(model_path) 34 | 35 | 
model = model.to(torch.device(rank)) 36 | model = model.eval() 37 | 38 | # 加载数据集数据 39 | wikimultihopqa = [] 40 | with open('/musique/dev.jsonl', 'r', encoding='utf-8') as f: 41 | for line in f: 42 | if line.strip(): # 跳过空行 43 | wikimultihopqa.append(json.loads(line)) 44 | print(len(wikimultihopqa)) 45 | 46 | # wikimultihopqa = wikimultihopqa[:4] 47 | 48 | print(wikimultihopqa[0]['question']) 49 | print(wikimultihopqa[0]['golden_answers']) 50 | print("Rank:" + str(rank)) 51 | print("World Size:" + str(world_size)) 52 | import math 53 | split_length = math.ceil(len(wikimultihopqa)/world_size) 54 | print("Split Chunk Length:" + str(split_length)) 55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)] 56 | print(len(split_wikimultihopqa)) 57 | wikimultihopqa = split_wikimultihopqa 58 | 59 | combine_results = [] 60 | for i in tqdm(range(len(wikimultihopqa))): 61 | pred_answer = None 62 | query = wikimultihopqa[i]['question'] 63 | answer = wikimultihopqa[i]['golden_answers'] 64 | item_id = wikimultihopqa[i]['id'] 65 | # + 'Answer the question directly.' 66 | input_text = query + '\n' + 'Only Return the answer.' 67 | # print("################################################################") 68 | # print(query) 69 | # print(answer) 70 | try: 71 | messages = [ 72 | { "role": "user", 73 | "content": [{"type": "text", "text": input_text}]} 74 | ] 75 | # Preparation for inference 76 | text = processor.apply_chat_template( 77 | messages, tokenize=False, add_generation_prompt=True 78 | ) 79 | # image_inputs, video_inputs = process_vision_info(messages) 80 | inputs = processor( 81 | text=[text], 82 | # images=image_inputs, 83 | # videos=video_inputs, 84 | padding=True, 85 | return_tensors="pt", 86 | ) 87 | inputs = inputs.to(model.device) 88 | 89 | # Inference: Generation of the output 90 | generated_ids = model.generate(**inputs, max_new_tokens=2048) 91 | generated_ids_trimmed = [ 92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 93 | ] 94 | output_text = processor.batch_decode( 95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 96 | ) 97 | result = output_text[0] 98 | # print(result) 99 | pred_answer = result 100 | 101 | except Exception as e: 102 | logger.info("ERROR OCCURES") 103 | logger.info({e}) 104 | if pred_answer != None: 105 | combine_results.append( 106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query} 107 | ) 108 | else: 109 | combine_results.append( 110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query} 111 | ) 112 | with open(f"7b_ori_musique_direct_infere_{rank}.json", "w", encoding="utf-8") as f: 113 | json.dump(combine_results, f, ensure_ascii=False, indent=4) 114 | return combine_results 115 | 116 | def main(): 117 | multiprocess = torch.cuda.device_count() >= 2 118 | mp.set_start_method('spawn') 119 | if multiprocess: 120 | logger.info('started generation') 121 | n_gpus = torch.cuda.device_count() 122 | world_size = n_gpus 123 | with Pool(world_size) as pool: 124 | func = functools.partial(run, world_size=world_size) 125 | result_lists = pool.map(func, range(world_size)) 126 | 127 | global_results = [] 128 | for i in range(world_size): 129 | global_results = global_results + result_lists[i] 130 | 131 | logger.info("Done") 132 | logger.info('finished running') 133 | else: 134 | logger.info("Not enough GPUs") 135 | 136 | if __name__ == "__main__": 137 | main() 
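All four `*_df.py` scripts shard the dataset the same way: one process per GPU, each taking a contiguous slice of `ceil(N / world_size)` items, so the last rank may receive fewer. A small sketch of that arithmetic (the sizes below are illustrative):

```python
# How the multi-GPU scripts split work across ranks (contiguous chunks of ceil(N / world_size)).
import math

N, world_size = 2417, 8                      # e.g. a MuSiQue-sized dev set on 8 GPUs
split = math.ceil(N / world_size)            # 303 items per rank
for rank in range(world_size):
    lo = rank * split
    hi = min((rank + 1) * split, N)
    print(f"rank {rank}: items {lo}..{hi - 1} ({hi - lo} total)")  # last rank gets 296
```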
-------------------------------------------------------------------------------- /Visual-ARFT/setup.sh: -------------------------------------------------------------------------------- 1 | cd src/visual_arft 2 | pip install -e ".[dev]" 3 | 4 | # Addtional modules 5 | pip install wandb==0.18.3 6 | pip install tensorboardx 7 | pip install qwen_vl_utils torchvision 8 | pip install flash-attn --no-build-isolation 9 | 10 | # vLLM support 11 | pip install vllm==0.7.2 12 | 13 | # fix transformers version 14 | pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef 15 | -------------------------------------------------------------------------------- /Visual-ARFT/src/scripts/run_grpo_agent_code_3b_1_2k_new2_gpu8.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 2 | export LOG_PATH="./log_qwen25vl_3b_grpo_agent_code_1_2k_new2_gpu8.txt" 3 | 4 | torchrun --nproc_per_node="8" \ 5 | --nnodes="1" \ 6 | --node_rank="0" \ 7 | --master_addr="127.0.0.1" \ 8 | --master_port="12345" \ 9 | /src/visual_arft/src/open_r1/grpo_agent_code.py \ 10 | --output_dir /share_models/Qwen2.5-VL-3B-Instruct_GRPO_agent_code_1_2k_new2_gpu8 \ 11 | --model_name_or_path /shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-3B-Instruct \ 12 | --dataset_name /train_data/rft_agent_code_1_2k.json \ 13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \ 14 | --max_prompt_length 2048 \ 15 | --per_device_train_batch_size 1 \ 16 | --gradient_accumulation_steps 2 \ 17 | --logging_steps 1 \ 18 | --bf16 true \ 19 | --report_to wandb \ 20 | --gradient_checkpointing true \ 21 | --attn_implementation flash_attention_2 \ 22 | --max_pixels 401408 \ 23 | --num_train_epochs 10 \ 24 | --run_name Qwen25-VL-3B-GRPO-Agent-code-1_2k-new2-gpu8 \ 25 | --save_steps 100 \ 26 | --save_only_model true \ 27 | --num_generations 8 -------------------------------------------------------------------------------- /Visual-ARFT/src/scripts/run_grpo_agent_code_7b_1_2k_new2_gpu8.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 2 | export LOG_PATH="./log_qwen25vl_7b_grpo_agent_code_1_2k_new2_gpu8.txt" 3 | 4 | torchrun --nproc_per_node="8" \ 5 | --nnodes="1" \ 6 | --node_rank="0" \ 7 | --master_addr="127.0.0.1" \ 8 | --master_port="12345" \ 9 | /src/visual_arft/src/open_r1/grpo_agent_code.py \ 10 | --output_dir /share_models/Qwen2.5-VL-7B-Instruct_GRPO_agent_code_1_2k_new2_gpu8 \ 11 | --model_name_or_path /share_model/Qwen2.5-VL-7B-Instruct \ 12 | --dataset_name /train_data/rft_agent_code_1_2k.json \ 13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \ 14 | --max_prompt_length 2048 \ 15 | --per_device_train_batch_size 1 \ 16 | --gradient_accumulation_steps 2 \ 17 | --logging_steps 1 \ 18 | --bf16 true \ 19 | --report_to wandb \ 20 | --gradient_checkpointing true \ 21 | --attn_implementation flash_attention_2 \ 22 | --max_pixels 401408 \ 23 | --num_train_epochs 10 \ 24 | --run_name Qwen25-VL-7B-GRPO-Agent-code-1_2k-new2-gpu8 \ 25 | --save_steps 100 \ 26 | --save_only_model true \ 27 | --num_generations 8 -------------------------------------------------------------------------------- /Visual-ARFT/src/scripts/run_grpo_agent_search_3b_gpu8.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" 
# Enable Debug if you want to see the rollout of model during RL 2 | export LOG_PATH="./log_qwen25vl_3b_grpo_agent_search_data20_63_gpu8.txt" 3 | 4 | torchrun --nproc_per_node="8" \ 5 | --nnodes="1" \ 6 | --node_rank="0" \ 7 | --master_addr="127.0.0.1" \ 8 | --master_port="12345" \ 9 | /src/visual_arft/src/open_r1/grpo_agent_search.py \ 10 | --output_dir /share_models/Qwen2.5-VL-3B-Instruct_GRPO_agent_search_data20_63_gpu8 \ 11 | --model_name_or_path /shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-3B-Instruct \ 12 | --dataset_name /train_data/rft_agent_20.json \ 13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \ 14 | --max_prompt_length 2048 \ 15 | --per_device_train_batch_size 1 \ 16 | --gradient_accumulation_steps 2 \ 17 | --logging_steps 1 \ 18 | --bf16 true \ 19 | --report_to wandb \ 20 | --gradient_checkpointing true \ 21 | --attn_implementation flash_attention_2 \ 22 | --max_pixels 401408 \ 23 | --num_train_epochs 400 \ 24 | --run_name Qwen25-VL-3B-GRPO-Agent-Search-data20-63-gpu8 \ 25 | --save_steps 100 \ 26 | --save_only_model true \ 27 | --num_generations 8 -------------------------------------------------------------------------------- /Visual-ARFT/src/scripts/run_grpo_agent_search_7b_gpu8.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 2 | export LOG_PATH="./log_qwen25vl_7b_grpo_agent_search_data20_63_gpu8.txt" 3 | 4 | torchrun --nproc_per_node="8" \ 5 | --nnodes="1" \ 6 | --node_rank="0" \ 7 | --master_addr="127.0.0.1" \ 8 | --master_port="12345" \ 9 | /src/visual_arft/src/open_r1/grpo_agent_search.py \ 10 | --output_dir /share_models/Qwen2.5-VL-7B-Instruct_GRPO_agent_search_data20_63_gpu8 \ 11 | --model_name_or_path /share_model/Qwen2.5-VL-7B-Instruct \ 12 | --dataset_name /train_data/rft_agent_20.json \ 13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \ 14 | --max_prompt_length 2048 \ 15 | --per_device_train_batch_size 1 \ 16 | --gradient_accumulation_steps 2 \ 17 | --logging_steps 1 \ 18 | --bf16 true \ 19 | --report_to wandb \ 20 | --gradient_checkpointing true \ 21 | --attn_implementation flash_attention_2 \ 22 | --max_pixels 401408 \ 23 | --num_train_epochs 400 \ 24 | --run_name Qwen25-VL-7B-GRPO-Agent-Search-data20-63-gpu8 \ 25 | --save_steps 100 \ 26 | --save_only_model true \ 27 | --num_generations 8 -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # Temp folders 174 | data/ 175 | wandb/ 176 | scripts/ 177 | checkpoints/ 178 | .vscode/ -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: style quality 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 4 | export PYTHONPATH = src 5 | 6 | check_dirs := src 7 | 8 | style: 9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py 10 | isort $(check_dirs) setup.py 11 | 12 | quality: 13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py 14 | isort --check-only $(check_dirs) setup.py 15 | flake8 --max-line-length 119 $(check_dirs) setup.py 16 | 17 | 18 | # Evaluation 19 | 20 | evaluate: 21 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/configs/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/configs/qwen2vl_sft_config.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: Qwen/Qwen2-VL-2B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | 6 | # Data training arguments 7 | dataset_name: MMInstruction/Clevr_CoGenT_TrainA_R1 8 | dataset_configs: 9 | - all 10 | preprocessing_num_workers: 8 11 | 12 | # SFT trainer config 13 | bf16: true 14 | do_eval: true 15 | eval_strategy: "no" 16 | gradient_accumulation_steps: 4 17 | gradient_checkpointing: true 18 | gradient_checkpointing_kwargs: 19 | use_reentrant: false 20 | hub_model_id: Qwen2-VL-2B-Instruct-SFT 21 | hub_strategy: every_save 22 | learning_rate: 2.0e-05 23 | log_level: info 24 | logging_steps: 5 25 | logging_strategy: steps 26 | lr_scheduler_type: cosine 27 | packing: true 28 | max_seq_length: 4096 29 | max_steps: -1 30 | num_train_epochs: 1 31 | output_dir: data/Qwen2-VL-2B-Instruct-SFT 32 | overwrite_output_dir: true 33 | per_device_eval_batch_size: 4 34 | per_device_train_batch_size: 4 35 | push_to_hub: true 36 | report_to: 37 | - wandb 38 | save_strategy: "no" 39 | seed: 42 40 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | 
num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/create_vision_cot_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import concurrent.futures 4 | import io 5 | import json 6 | import os 7 | import random 8 | import re 9 | import time 10 | from concurrent.futures import ThreadPoolExecutor 11 | from functools import partial 12 | from io import BytesIO 13 | from typing import Dict, List 14 | 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import pandas as pd 18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk 19 | from tqdm import tqdm 20 | 21 | import bytedtos 22 | import seaborn as sns 23 | import yaml 24 | from openai import AzureOpenAI 25 | from PIL import Image 26 | from pillow_avif import AvifImagePlugin 27 | 28 | 29 | PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions. 30 | 31 | Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A". 32 | 33 | Please strictly do not include "Answer:" in the question part to avoid confusion and leakage. 
34 | 35 | Input Format: 36 | Original Question: {original_question} 37 | Original Answer: {original_answer} 38 | 39 | Output Format: 40 | Question: [rewrite the question if necessary] 41 | Answer: [answer with reasoning steps, including calculations where applicable] 42 | step-by-step reasoning process 43 | easy to verify answer 44 | """ 45 | 46 | 47 | def get_image_data_url(image_input): 48 | if isinstance(image_input, str) and image_input.startswith("data:"): 49 | return image_input 50 | 51 | if isinstance(image_input, str) and image_input.startswith("http"): 52 | image_input = load_image(image_input) 53 | 54 | if isinstance(image_input, str): 55 | image_input = Image.open(image_input) 56 | 57 | if not isinstance(image_input, Image.Image): 58 | raise ValueError("Unsupported image input type") 59 | 60 | if image_input.mode != "RGB": 61 | image_input = image_input.convert("RGB") 62 | 63 | buffer = BytesIO() 64 | image_input.save(buffer, format="JPEG") 65 | img_bytes = buffer.getvalue() 66 | base64_data = base64.b64encode(img_bytes).decode("utf-8") 67 | return f"data:image/jpeg;base64,{base64_data}" 68 | 69 | 70 | def gpt4o_query(image, prompt, max_retries=5, initial_delay=3): 71 | if image is None: 72 | return None 73 | 74 | data_url_list = [get_image_data_url(image)] 75 | client = AzureOpenAI( 76 | azure_endpoint="YOUR_AZURE_ENDPOINT", 77 | api_version="2023-07-01-preview", 78 | api_key="YOUR_API_KEY", 79 | ) 80 | 81 | for attempt in range(max_retries): 82 | try: 83 | messages = [ 84 | { 85 | "role": "system", 86 | "content": "You are an expert to analyze the image and provide useful information for users.", 87 | }, 88 | { 89 | "role": "user", 90 | "content": [ 91 | {"type": "text", "text": prompt}, 92 | ], 93 | }, 94 | ] 95 | 96 | for data_url in data_url_list: 97 | messages[1]["content"].insert( 98 | 0, {"type": "image_url", "image_url": {"url": data_url}} 99 | ) 100 | 101 | response = client.chat.completions.create( 102 | model="gpt-4o-2024-08-06", 103 | messages=messages, 104 | temperature=0.2, 105 | max_tokens=8192, 106 | ) 107 | return response.choices[0].message.content 108 | 109 | except Exception as e: 110 | if attempt == max_retries - 1: 111 | raise Exception( 112 | f"Failed after {max_retries} attempts. 
Last error: {str(e)}" 113 | ) 114 | delay = initial_delay * (2**attempt) + random.uniform( 115 | 0, 0.1 * initial_delay * (2**attempt) 116 | ) 117 | time.sleep(delay) 118 | 119 | 120 | def process_single_item(example): 121 | try: 122 | image_path = example["image_path"] 123 | formatted_prompt = PROMPT_FORMAT.format( 124 | original_question=example["question"], original_answer=example["answer"] 125 | ) 126 | 127 | response = gpt4o_query(image_path, formatted_prompt) 128 | example["gpt4o_response"] = response 129 | return example 130 | except Exception as e: 131 | print(f"Error processing item: {str(e)}") 132 | example["gpt4o_response"] = None 133 | return example 134 | 135 | 136 | def main(): 137 | dataset_path = "path/to/your/dataset" 138 | full_dataset = load_from_disk(dataset_path) 139 | 140 | processed_dataset = full_dataset.map( 141 | function=partial(process_single_item), 142 | num_proc=256, 143 | desc="Processing dataset with GPT-4o", 144 | keep_in_memory=True, 145 | ) 146 | 147 | output_path = f"{dataset_path}_processed" 148 | processed_dataset.save_to_disk(output_path) 149 | print(f"Processed dataset saved to: {output_path}") 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/lmms_eval_qwen2vl.sh: -------------------------------------------------------------------------------- 1 | export HF_HOME="" 2 | export HF_TOKEN="" 3 | export HF_HUB_ENABLE_HF_TRANSFER="1" 4 | 5 | export API_TYPE="" 6 | export AZURE_ENDPOINT="" 7 | export AZURE_API_KEY="" 8 | export API_VERSION="" 9 | export MODEL_VERSION="" 10 | export NAVIT_ATTENTION_IMPLEMENTATION="eager" 11 | 12 | # Prompt for installation with 3-second timeout 13 | read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true 14 | if [ "$install_deps" = "YES" ]; then 15 | # Prepare the environment 16 | pip3 install --upgrade pip 17 | pip3 install -U setuptools 18 | 19 | cd 20 | if [ ! -d "maas_engine" ]; then 21 | git clone 22 | else 23 | echo "maas_engine directory already exists, skipping clone" 24 | fi 25 | cd maas_engine 26 | git pull 27 | git checkout 28 | pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]" 29 | 30 | current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2) 31 | if [ "$current_version" != "4.46.2" ]; then 32 | echo "Installing transformers 4.46.2 (current version: $current_version)" 33 | pip3 install transformers==4.46.2 34 | else 35 | echo "transformers 4.46.2 is already installed" 36 | fi 37 | 38 | cd 39 | rm -rf 40 | pip3 install -e . 
41 | pip3 install -U pydantic 42 | pip3 install Levenshtein 43 | pip3 install nltk 44 | python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)" 45 | fi 46 | 47 | TASKS=mmmu_val,mathvista_testmini,mmmu_pro 48 | MODEL_BASENAME=qwen2_vl 49 | 50 | model_checkpoint="" 51 | echo "MODEL_BASENAME: ${MODEL_BASENAME}" 52 | cd 53 | 54 | python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \ 55 | --model qwen2_vl \ 56 | --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \ 57 | --tasks ${TASKS} \ 58 | --batch_size 1 \ 59 | --log_samples \ 60 | --log_samples_suffix ${MODEL_BASENAME} \ 61 | --output_path ./logs -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/prepare_hf_data.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | import random 5 | from typing import List, Dict 6 | import numpy as np 7 | from concurrent.futures import ThreadPoolExecutor 8 | from tqdm import tqdm 9 | import datasets 10 | 11 | import io 12 | from datasets import load_dataset, load_from_disk, concatenate_datasets 13 | from PIL import Image 14 | from tqdm import tqdm 15 | from functools import partial 16 | from pillow_avif import AvifImagePlugin 17 | from datasets import Dataset 18 | import json 19 | import yaml 20 | import os 21 | import re 22 | import time 23 | import random 24 | import base64 25 | from openai import AzureOpenAI 26 | import concurrent.futures 27 | from typing import List, Dict 28 | import argparse 29 | import time 30 | 31 | 32 | def extract_problem_solution(gpt4o_response): 33 | # Split the response into parts 34 | parts = gpt4o_response.split("") 35 | 36 | # Extract the problem (first part before any tags) 37 | problem = parts[0].strip() 38 | # Remove "Question:" prefix if it exists 39 | problem = re.sub(r"^Question:\s*", "", problem) 40 | # Remove "Answer:" at the end of the problem 41 | problem = re.sub(r"\s*Answer:\s*$", "", problem).strip() 42 | 43 | # Combine all the reasoning steps into a single block 44 | think_parts = [p.split("")[0].strip() for p in parts[1:] if "" in p] 45 | solution = f"{' '.join(think_parts)}" 46 | 47 | # Add the final answer if it exists, removing "Answer:" prefix 48 | if "" in gpt4o_response: 49 | final_answer = ( 50 | gpt4o_response.split("")[-1].split("")[0].strip() 51 | ) 52 | final_answer = re.sub(r"^Answer:\s*", "", final_answer) 53 | solution += f"\n\n{final_answer}" 54 | 55 | return problem, solution 56 | 57 | 58 | def load_image_from_path(image_path): 59 | try: 60 | img = Image.open(image_path) 61 | return img 62 | except Exception as e: 63 | print(f"Error loading image {image_path}: {str(e)}") 64 | return None 65 | 66 | 67 | def process_raw_data(raw_data): 68 | # Parse the raw data if it's a string 69 | if isinstance(raw_data, str): 70 | data = json.loads(raw_data) 71 | else: 72 | data = raw_data 73 | 74 | # Extract problem and solution 75 | try: 76 | problem, solution = extract_problem_solution(data["gpt4o_response"]) 77 | image = load_image_from_path(data["image_path"]) 78 | 79 | return { 80 | "image": image, 81 | "problem": problem, 82 | "solution": solution, 83 | "original_question": data["question"], 84 | "original_answer": data["answer"], 85 | } 86 | except Exception as e: 87 | print(f"Error processing data {data}: {str(e)}") 88 | return { 89 | "image": None, 90 | "problem": 
None, 91 | "solution": None, 92 | "original_question": None, 93 | "original_answer": None, 94 | } 95 | 96 | 97 | raw_data_list = [ 98 | "/path/to/reasoning_data_with_response_90k_verified", 99 | ] 100 | 101 | raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list]) 102 | 103 | processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42) 104 | 105 | hf_dict = { 106 | "image": [], 107 | "problem": [], 108 | "solution": [], 109 | "original_question": [], 110 | "original_answer": [], 111 | } 112 | 113 | for item in tqdm(processed_data): 114 | hf_dict["image"].append(item["image"]) 115 | hf_dict["problem"].append(item["problem"]) 116 | hf_dict["solution"].append(item["solution"]) 117 | hf_dict["original_question"].append(item["original_question"]) 118 | hf_dict["original_answer"].append(item["original_answer"]) 119 | 120 | 121 | features = datasets.Features( 122 | { 123 | "image": datasets.Image(), 124 | "problem": datasets.Value("string"), 125 | "solution": datasets.Value("string"), 126 | "original_question": datasets.Value("string"), 127 | "original_answer": datasets.Value("string"), 128 | } 129 | ) 130 | 131 | 132 | def has_empty_tags(text): 133 | # Pattern to match empty tags like 134 | pattern = r"<[^>]+>]+>" 135 | return bool(re.search(pattern, text)) 136 | 137 | 138 | def has_answer_pattern(text): 139 | if "Answer:" in text: 140 | return True 141 | return False 142 | 143 | 144 | def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement 145 | # Assuming the image is in a format that can be checked for dimensions 146 | # You might need to adjust this depending on how the image is stored in your dataset 147 | try: 148 | image = example["image"] # or however your image is accessed 149 | if isinstance(image, dict) and "height" in image and "width" in image: 150 | return image["height"] >= 28 and image["width"] >= 28 151 | # If image is a PIL Image or similar 152 | return image.height >= 28 and image.width >= 28 153 | except: 154 | return False 155 | 156 | 157 | ds = datasets.Dataset.from_dict(hf_dict, features=features) 158 | ds = ds.filter( 159 | lambda x: not has_empty_tags(x["solution"]) 160 | and not has_answer_pattern(x["problem"]) 161 | and has_valid_image_size(x) 162 | and x["image"] is not None, 163 | num_proc=128, 164 | ) 165 | # Push to Hugging Face Hub 166 | ds.push_to_hub("path/to/your/dataset") 167 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/train_aria_moe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=0 4 | export TOKENIZERS_PARALLELISM=false 5 | export OMP_NUM_THREADS=8 6 | export NCCL_IB_DISABLE=0 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth0 9 | export NCCL_DEBUG=INFO 10 | 11 | # CONFIG Huggingface 12 | # export HF_TOKEN="" 13 | export HF_TOKEN="" 14 | export HF_HOME="$HOME/.cache/huggingface" 15 | export HF_HUB_ENABLE_HF_TRANSFER="1" 16 | 17 | export NCCL_DEBUG=INFO 18 | 19 | GPUS="0,1,2,3,4,5,6,7" 20 | 21 | # 取 worker0 第一个 port 22 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) 23 | port=${ports[0]} 24 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')" 25 | 26 | echo "total workers: ${ARNOLD_WORKER_NUM}" 27 | echo "cur worker id: ${ARNOLD_ID}" 28 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}" 29 | echo "master ip: ${METIS_WORKER_0_HOST}" 30 | echo "master port: ${port}" 31 | echo "master port 
in cmd: ${port_in_cmd}" 32 | 33 | # export WANDB_BASE_URL=https://api.wandb.ai 34 | # export WANDB_API_KEY="" 35 | # wandb login $WANDB_API_KEY 36 | 37 | export WANDB_BASE_URL=https://api.wandb.ai 38 | export WANDB_PROJECT=vision-reasoning 39 | export WANDB_API_KEY="" 40 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S) 41 | wandb login $WANDB_API_KEY 42 | 43 | cd /home/tiger/multimodal-open-r1 44 | # pip3 install vllm==0.6.6.post1 45 | pip3 install -e ".[dev]" 46 | pip3 install wandb==0.18.3 47 | 48 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \ 49 | --nnodes="${ARNOLD_WORKER_NUM}" \ 50 | --node_rank="${ARNOLD_ID}" \ 51 | --master_addr="${METIS_WORKER_0_HOST}" \ 52 | --master_port="${port_in_cmd}" \ 53 | src/open_r1/grpo.py \ 54 | --deepspeed scripts/zero3.json \ 55 | --output_dir Aria-GRPO-mini_cot_80k \ 56 | --model_name_or_path rhymes-ai/Aria \ 57 | --dataset_name luodian/mini_cot_80k \ 58 | --max_prompt_length 8192 \ 59 | --per_device_train_batch_size 1 \ 60 | --gradient_accumulation_steps 1 \ 61 | --logging_steps 1 \ 62 | --bf16 \ 63 | --report_to wandb \ 64 | --gradient_checkpointing true \ 65 | --attn_implementation eager \ 66 | --save_total_limit 8 \ 67 | --num_train_epochs 1 \ 68 | --run_name $WANDB_RUN_NAME 69 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/train_qwen2_vl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=0 4 | export TOKENIZERS_PARALLELISM=false 5 | export OMP_NUM_THREADS=8 6 | export NCCL_IB_DISABLE=0 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth0 9 | export NCCL_DEBUG=INFO 10 | 11 | GPUS="0,1,2,3,4,5,6,7" 12 | 13 | # 取 worker0 第一个 port 14 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) 15 | port=${ports[0]} 16 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')" 17 | 18 | echo "total workers: ${ARNOLD_WORKER_NUM}" 19 | echo "cur worker id: ${ARNOLD_ID}" 20 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}" 21 | echo "master ip: ${METIS_WORKER_0_HOST}" 22 | echo "master port: ${port}" 23 | echo "master port in cmd: ${port_in_cmd}" 24 | 25 | # export WANDB_BASE_URL=https://api.wandb.ai 26 | # export WANDB_API_KEY="" 27 | # wandb login $WANDB_API_KEY 28 | 29 | export WANDB_BASE_URL=https://api.wandb.ai 30 | export WANDB_PROJECT=vision-reasoning 31 | export WANDB_API_KEY="" 32 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S) 33 | wandb login $WANDB_API_KEY 34 | 35 | cd /home/tiger/multimodal-open-r1 36 | # pip3 install vllm==0.6.6.post1 37 | pip3 install -e ".[dev]" 38 | pip3 install wandb==0.18.3 39 | 40 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \ 41 | --nnodes="${ARNOLD_WORKER_NUM}" \ 42 | --node_rank="${ARNOLD_ID}" \ 43 | --master_addr="${METIS_WORKER_0_HOST}" \ 44 | --master_port="${port_in_cmd}" \ 45 | src/open_r1/grpo.py \ 46 | --deepspeed scripts/zero3.json \ 47 | --output_dir checkpoints/${WANDB_RUN_NAME} \ 48 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 49 | --dataset_name luodian/${DATASET_NAME} \ 50 | --max_prompt_length 8192 \ 51 | --per_device_train_batch_size 1 \ 52 | --gradient_accumulation_steps 1 \ 53 | --logging_steps 1 \ 54 | --bf16 \ 55 | --report_to wandb \ 56 | --gradient_checkpointing true \ 57 | --attn_implementation flash_attention_2 \ 58 | --max_pixels 2359296 \ 59 | --save_total_limit 8 \ 60 | --num_train_epochs 1 \ 61 | --run_name $WANDB_RUN_NAME 62 | 
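The launch scripts above set per_device_train_batch_size and gradient_accumulation_steps on the command line and leave the batch-size fields in the DeepSpeed configs that follow as "auto". As a minimal sketch (not part of the repository), the snippet below spells out how those "auto" values are conventionally resolved into a global batch size; the helper name global_batch_size is ours, and the formula is the standard micro-batch x gradient-accumulation x world-size rule assumed by DeepSpeed when driven through the HF Trainer.

# A minimal sketch: assumed resolution of the "auto" batch-size fields, i.e.
#   train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
def global_batch_size(micro_batch_per_gpu, grad_accum, gpus_per_node, num_nodes=1):
    world_size = gpus_per_node * num_nodes
    return micro_batch_per_gpu * grad_accum * world_size

# GRPO agent scripts (src/scripts/*.sh): per_device_train_batch_size=1,
# gradient_accumulation_steps=2, 8 GPUs on one node -> 16 sequences per optimizer step.
print(global_batch_size(1, 2, 8))      # 16
# train_qwen2_vl.sh: per_device_train_batch_size=1, gradient_accumulation_steps=1,
# ARNOLD_WORKER_GPU GPUs per node x ARNOLD_WORKER_NUM nodes (e.g. 8 x 1) -> 8.
print(global_batch_size(1, 1, 8, 1))   # 8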
-------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/zero1_no_optimizer.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e9, 6 | "overlap_comm": false, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e9, 9 | "contiguous_gradients": true 10 | }, 11 | "fp16": { 12 | "enabled": "auto", 13 | "auto_cast": true, 14 | "loss_scale": 0, 15 | "initial_scale_power": 32, 16 | "loss_scale_window": 1000, 17 | "hysteresis": 2, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "gradient_accumulation_steps": "auto", 24 | "gradient_clipping": "auto", 25 | "steps_per_print": 1, 26 | "train_batch_size": "auto", 27 | "train_micro_batch_size_per_gpu": "auto", 28 | "wall_clock_breakdown": true 29 | } -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": false 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": false 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/zero3.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/local_scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/run_grpo.sh: -------------------------------------------------------------------------------- 1 | cd src/r1-v 2 | 3 | export DEBUG_MODE="true" 4 | export LOG_PATH="./debug_log_2b.txt" 5 | 6 | 7 | 8 | torchrun --nproc_per_node="8" \ 9 | --nnodes="1" \ 10 | --node_rank="0" \ 11 | --master_addr="127.0.0.1" \ 12 | --master_port="12345" \ 13 | src/open_r1/grpo.py \ 14 | --output_dir \ 15 | --model_name_or_path \ 16 | --dataset_name \ 17 | --max_prompt_length 1024 \ 18 | --per_device_train_batch_size 1 \ 19 | --gradient_accumulation_steps 2 \ 20 | --logging_steps 1 \ 21 | --bf16 \ 22 | --report_to wandb \ 23 | --gradient_checkpointing false \ 24 | --attn_implementation flash_attention_2 \ 25 | --max_pixels 401408 \ 26 | --num_train_epochs 2 \ 27 | --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \ 28 | --save_steps 100 \ 29 | --save_only_model true -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = open_r1 7 | known_third_party = 8 | transformers 9 | datasets 10 | fugashi 11 | git 12 | h5py 13 | matplotlib 14 | nltk 15 | numpy 16 | packaging 17 | 
pandas 18 | psutil 19 | pytest 20 | rouge_score 21 | sacrebleu 22 | seqeval 23 | sklearn 24 | streamlit 25 | torch 26 | tqdm 27 | 28 | line_length = 119 29 | lines_after_imports = 2 30 | multi_line_output = 3 31 | use_parentheses = True 32 | 33 | [flake8] 34 | ignore = E203, E501, E741, W503, W605 35 | max-line-length = 119 36 | per-file-ignores = 37 | # imported but unused 38 | __init__.py: F401 39 | 40 | [tool:pytest] 41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py 16 | 17 | 18 | import re 19 | import shutil 20 | from pathlib import Path 21 | 22 | from setuptools import find_packages, setup 23 | 24 | 25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info" 27 | if stale_egg_info.exists(): 28 | print( 29 | ( 30 | "Warning: {} exists.\n\n" 31 | "If you recently updated open_r1, this is expected,\n" 32 | "but it may prevent open_r1 from installing in editable mode.\n\n" 33 | "This directory is automatically generated by Python's packaging tools.\n" 34 | "I will remove it now.\n\n" 35 | "See https://github.com/pypa/pip/issues/5466 for details.\n" 36 | ).format(stale_egg_info) 37 | ) 38 | shutil.rmtree(stale_egg_info) 39 | 40 | 41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any. 42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version 43 | _deps = [ 44 | "accelerate>=1.2.1", 45 | "bitsandbytes>=0.42.0", 46 | "black>=24.4.2", 47 | "datasets>=3.2.0", 48 | "deepspeed==0.15.4", 49 | "distilabel[vllm,ray,openai]>=1.5.2", 50 | "einops>=0.8.0", 51 | "flake8>=6.0.0", 52 | "hf_transfer>=0.1.4", 53 | "huggingface-hub[cli]>=0.19.2,<1.0", 54 | "isort>=5.12.0", 55 | "liger_kernel==0.5.2", 56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", 57 | "math-verify", # Used for math verification in grpo 58 | "packaging>=23.0", 59 | "parameterized>=0.9.0", 60 | "pytest", 61 | "safetensors>=0.3.3", 62 | "sentencepiece>=0.1.99", 63 | "torch>=2.5.1", 64 | # "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef", 65 | "trl==0.14.0", 66 | "vllm==0.6.6.post1", 67 | "wandb>=0.19.1", 68 | "pillow", 69 | ] 70 | 71 | # this is a lookup table with items like: 72 | # 73 | # tokenizers: "tokenizers==0.9.4" 74 | # packaging: "packaging" 75 | # 76 | # some of the values are versioned whereas others aren't. 
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} 78 | 79 | 80 | def deps_list(*pkgs): 81 | return [deps[pkg] for pkg in pkgs] 82 | 83 | 84 | extras = {} 85 | extras["tests"] = deps_list("pytest", "parameterized") 86 | extras["torch"] = deps_list("torch") 87 | extras["quality"] = deps_list("black", "isort", "flake8") 88 | extras["eval"] = deps_list("lighteval", "math-verify") 89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] 90 | 91 | # core dependencies shared across the whole project - keep this to a bare minimum :) 92 | install_requires = [ 93 | deps["accelerate"], 94 | deps["bitsandbytes"], 95 | deps["einops"], 96 | deps["datasets"], 97 | deps["deepspeed"], 98 | deps["hf_transfer"], 99 | deps["huggingface-hub"], 100 | deps["liger_kernel"], 101 | deps["packaging"], # utilities from PyPA to e.g., compare versions 102 | deps["safetensors"], 103 | deps["sentencepiece"], 104 | # deps["transformers"], 105 | deps["trl"], 106 | ] 107 | 108 | setup( 109 | name="r1-v", 110 | version="0.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 111 | author="The r1-v team and the Hugging Face team (past and future)", 112 | description="R1-V", 113 | license="Apache", 114 | url="https://github.com/Deep-Agent/R1-V", 115 | package_dir={"": "src"}, 116 | packages=find_packages("src"), 117 | zip_safe=False, 118 | extras_require=extras, 119 | python_requires=">=3.10.9", 120 | install_requires=install_requires, 121 | classifiers=[ 122 | "Development Status :: 3 - Alpha", 123 | "Intended Audience :: Developers", 124 | "Intended Audience :: Education", 125 | "Intended Audience :: Science/Research", 126 | "License :: OSI Approved :: Apache Software License", 127 | "Operating System :: OS Independent", 128 | "Programming Language :: Python :: 3", 129 | "Programming Language :: Python :: 3.10", 130 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 131 | ], 132 | ) 133 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/src/open_r1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/src/visual_arft/src/open_r1/__init__.py -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/src/open_r1/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Custom evaluation tasks for LightEval.""" 16 | 17 | from lighteval.metrics.dynamic_metrics import ( 18 | ExprExtractionConfig, 19 | LatexExtractionConfig, 20 | multilingual_extractive_match_metric, 21 | ) 22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 23 | from lighteval.tasks.requests import Doc 24 | from lighteval.utils.language import Language 25 | 26 | 27 | metric = multilingual_extractive_match_metric( 28 | language=Language.ENGLISH, 29 | fallback_mode="first_match", 30 | precision=5, 31 | gold_extraction_target=(LatexExtractionConfig(),), 32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 33 | aggregation_function=max, 34 | ) 35 | 36 | 37 | def prompt_fn(line, task_name: str = None): 38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically""" 39 | return Doc( 40 | task_name=task_name, 41 | query=line["problem"], 42 | choices=[line["solution"]], 43 | gold_index=0, 44 | ) 45 | 46 | 47 | # Define tasks 48 | aime24 = LightevalTaskConfig( 49 | name="aime24", 50 | suite=["custom"], 51 | prompt_function=prompt_fn, 52 | hf_repo="HuggingFaceH4/aime_2024", 53 | hf_subset="default", 54 | hf_avail_splits=["train"], 55 | evaluation_splits=["train"], 56 | few_shots_split=None, 57 | few_shots_select=None, 58 | generation_size=32768, 59 | metric=[metric], 60 | version=1, 61 | ) 62 | math_500 = LightevalTaskConfig( 63 | name="math_500", 64 | suite=["custom"], 65 | prompt_function=prompt_fn, 66 | hf_repo="HuggingFaceH4/MATH-500", 67 | hf_subset="default", 68 | hf_avail_splits=["test"], 69 | evaluation_splits=["test"], 70 | few_shots_split=None, 71 | few_shots_select=None, 72 | generation_size=32768, 73 | metric=[metric], 74 | version=1, 75 | ) 76 | 77 | # Add tasks to the table 78 | TASKS_TABLE = [] 79 | TASKS_TABLE.append(aime24) 80 | TASKS_TABLE.append(math_500) 81 | 82 | # MODULE LOGIC 83 | if __name__ == "__main__": 84 | print([t["name"] for t in TASKS_TABLE]) 85 | print(len(TASKS_TABLE)) 86 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/src/open_r1/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Optional 16 | 17 | from distilabel.llms import OpenAILLM 18 | from distilabel.pipeline import Pipeline 19 | from distilabel.steps.tasks import TextGeneration 20 | 21 | 22 | def build_distilabel_pipeline( 23 | model: str, 24 | base_url: str = "http://localhost:8000/v1", 25 | prompt_column: Optional[str] = None, 26 | temperature: Optional[float] = None, 27 | top_p: Optional[float] = None, 28 | max_new_tokens: int = 8192, 29 | num_generations: int = 1, 30 | ) -> Pipeline: 31 | generation_kwargs = {"max_new_tokens": max_new_tokens} 32 | 33 | if temperature is not None: 34 | generation_kwargs["temperature"] = temperature 35 | 36 | if top_p is not None: 37 | generation_kwargs["top_p"] = top_p 38 | 39 | with Pipeline().ray() as pipeline: 40 | TextGeneration( 41 | llm=OpenAILLM( 42 | base_url=base_url, 43 | api_key="something", 44 | model=model, 45 | # thinking can take some time... 46 | timeout=10 * 60, 47 | generation_kwargs=generation_kwargs, 48 | ), 49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, 50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion 51 | num_generations=num_generations, 52 | ) 53 | 54 | return pipeline 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | 62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") 63 | parser.add_argument( 64 | "--hf-dataset", 65 | type=str, 66 | required=True, 67 | help="HuggingFace dataset to load", 68 | ) 69 | parser.add_argument( 70 | "--hf-dataset-config", 71 | type=str, 72 | required=False, 73 | help="Dataset config to use", 74 | ) 75 | parser.add_argument( 76 | "--hf-dataset-split", 77 | type=str, 78 | default="train", 79 | help="Dataset split to use", 80 | ) 81 | parser.add_argument("--prompt-column", type=str, default="prompt") 82 | parser.add_argument( 83 | "--model", 84 | type=str, 85 | required=True, 86 | help="Model name to use for generation", 87 | ) 88 | parser.add_argument( 89 | "--vllm-server-url", 90 | type=str, 91 | default="http://localhost:8000/v1", 92 | help="URL of the vLLM server", 93 | ) 94 | parser.add_argument( 95 | "--temperature", 96 | type=float, 97 | help="Temperature for generation", 98 | ) 99 | parser.add_argument( 100 | "--top-p", 101 | type=float, 102 | help="Top-p value for generation", 103 | ) 104 | parser.add_argument( 105 | "--max-new-tokens", 106 | type=int, 107 | default=8192, 108 | help="Maximum number of new tokens to generate", 109 | ) 110 | parser.add_argument( 111 | "--num-generations", 112 | type=int, 113 | default=1, 114 | help="Number of generations per problem", 115 | ) 116 | parser.add_argument( 117 | "--hf-output-dataset", 118 | type=str, 119 | required=False, 120 | help="HuggingFace repo to push results to", 121 | ) 122 | parser.add_argument( 123 | "--private", 124 | action="store_true", 125 | help="Whether to make the output dataset private when pushing to HF Hub", 126 | ) 127 | 128 | args = parser.parse_args() 129 | 130 | print("\nRunning with arguments:") 131 | for arg, value in vars(args).items(): 132 | print(f" {arg}: {value}") 133 | print() 134 | 135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") 136 | dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split) 137 | print("Dataset loaded!") 138 | 139 | pipeline = build_distilabel_pipeline( 140 | model=args.model, 141 | 
base_url=args.vllm_server_url, 142 | prompt_column=args.prompt_column, 143 | temperature=args.temperature, 144 | top_p=args.top_p, 145 | max_new_tokens=args.max_new_tokens, 146 | num_generations=args.num_generations, 147 | ) 148 | 149 | print("Running generation pipeline...") 150 | distiset = pipeline.run(dataset=dataset, use_cache=False) 151 | print("Generation pipeline finished!") 152 | 153 | if args.hf_output_dataset: 154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") 155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private) 156 | print("Dataset pushed!") 157 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/src/open_r1/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_trainer import Qwen2VLGRPOTrainer 2 | from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer 3 | from .vllm_grpo_trainer_modified import Qwen2VLGRPOVLLMTrainerModified 4 | from .grpo_trainer_aid import Qwen2VLGRPOTrainer_AID 5 | from .grpo_trainer_visual_rft import Qwen2VLGRPOTrainer_Visual_RFT 6 | from .grpo_trainer_mp import Qwen2VLGRPOTrainer_MP 7 | 8 | __all__ = [ 9 | "Qwen2VLGRPOTrainer", 10 | "Qwen2VLGRPOVLLMTrainer", 11 | "Qwen2VLGRPOVLLMTrainerModified", 12 | "Qwen2VLGRPOTrainer_AID", 13 | "Qwen2VLGRPOTrainer_Visual_RFT", 14 | "Qwen2VLGRPOTrainer_MP", 15 | ] 16 | -------------------------------------------------------------------------------- /Visual-ARFT/src/visual_arft/temp_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/src/visual_arft/temp_image.png -------------------------------------------------------------------------------- /assets/case_cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/case_cls.png -------------------------------------------------------------------------------- /assets/case_lisa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/case_lisa.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/framework.png -------------------------------------------------------------------------------- /assets/pokeymon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/pokeymon.jpg -------------------------------------------------------------------------------- /assets/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/radar.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/teaser.png -------------------------------------------------------------------------------- /classification/val_data/fgvc_aircraft.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/fgvc_aircraft.pth -------------------------------------------------------------------------------- /classification/val_data/fgvc_aircraft.txt: -------------------------------------------------------------------------------- 1 | 707-320 2 | 727-200 3 | 737-200 4 | 737-300 5 | 737-400 6 | 737-500 7 | 737-600 8 | 737-700 9 | 737-800 10 | 737-900 11 | 747-100 12 | 747-200 13 | 747-300 14 | 747-400 15 | 757-200 16 | 757-300 17 | 767-200 18 | 767-300 19 | 767-400 20 | 777-200 21 | 777-300 22 | A300B4 23 | A310 24 | A318 25 | A319 26 | A320 27 | A321 28 | A330-200 29 | A330-300 30 | A340-200 31 | A340-300 32 | A340-500 33 | A340-600 34 | A380 35 | ATR-42 36 | ATR-72 37 | An-12 38 | BAE 146-200 39 | BAE 146-300 40 | BAE-125 41 | Beechcraft 1900 42 | Boeing 717 43 | C-130 44 | C-47 45 | CRJ-200 46 | CRJ-700 47 | CRJ-900 48 | Cessna 172 49 | Cessna 208 50 | Cessna 525 51 | Cessna 560 52 | Challenger 600 53 | DC-10 54 | DC-3 55 | DC-6 56 | DC-8 57 | DC-9-30 58 | DH-82 59 | DHC-1 60 | DHC-6 61 | DHC-8-100 62 | DHC-8-300 63 | DR-400 64 | Dornier 328 65 | E-170 66 | E-190 67 | E-195 68 | EMB-120 69 | ERJ 135 70 | ERJ 145 71 | Embraer Legacy 600 72 | Eurofighter Typhoon 73 | F-16A/B 74 | F/A-18 75 | Falcon 2000 76 | Falcon 900 77 | Fokker 100 78 | Fokker 50 79 | Fokker 70 80 | Global Express 81 | Gulfstream IV 82 | Gulfstream V 83 | Hawk T1 84 | Il-76 85 | L-1011 86 | MD-11 87 | MD-80 88 | MD-87 89 | MD-90 90 | Metroliner 91 | Model B200 92 | PA-28 93 | SR-20 94 | Saab 2000 95 | Saab 340 96 | Spitfire 97 | Tornado 98 | Tu-134 99 | Tu-154 100 | Yak-42 -------------------------------------------------------------------------------- /classification/val_data/oxford_flowers.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/oxford_flowers.pth -------------------------------------------------------------------------------- /classification/val_data/oxford_flowers.txt: -------------------------------------------------------------------------------- 1 | pink primrose 2 | hard-leaved pocket orchid 3 | canterbury bells 4 | sweet pea 5 | english marigold 6 | tiger lily 7 | moon orchid 8 | bird of paradise 9 | monkshood 10 | globe thistle 11 | snapdragon 12 | colts foot 13 | king protea 14 | spear thistle 15 | yellow iris 16 | globe-flower 17 | purple coneflower 18 | peruvian lily 19 | balloon flower 20 | giant white arum lily 21 | fire lily 22 | pincushion flower 23 | fritillary 24 | red ginger 25 | grape hyacinth 26 | corn poppy 27 | prince of wales feathers 28 | stemless gentian 29 | artichoke 30 | sweet william 31 | carnation 32 | garden phlox 33 | love in the mist 34 | mexican aster 35 | alpine sea holly 36 | ruby-lipped cattleya 37 | cape flower 38 | great masterwort 39 | siam tulip 40 | lenten rose 41 | barbeton daisy 42 | daffodil 43 | sword lily 44 | poinsettia 45 | bolero deep blue 46 | wallflower 47 | marigold 48 | buttercup 49 | oxeye daisy 50 | common dandelion 51 | petunia 52 | wild pansy 53 | primula 54 | sunflower 55 
| pelargonium 56 | bishop of llandaff 57 | gaura 58 | geranium 59 | orange dahlia 60 | pink-yellow dahlia 61 | cautleya spicata 62 | japanese anemone 63 | black-eyed susan 64 | silverbush 65 | californian poppy 66 | osteospermum 67 | spring crocus 68 | bearded iris 69 | windflower 70 | tree poppy 71 | gazania 72 | azalea 73 | water lily 74 | rose 75 | thorn apple 76 | morning glory 77 | passion flower 78 | lotus 79 | toad lily 80 | anthurium 81 | frangipani 82 | clematis 83 | hibiscus 84 | columbine 85 | desert-rose 86 | tree mallow 87 | magnolia 88 | cyclamen 89 | watercress 90 | canna lily 91 | hippeastrum 92 | bee balm 93 | ball moss 94 | foxglove 95 | bougainvillea 96 | camellia 97 | mallow 98 | mexican petunia 99 | bromelia 100 | blanket flower 101 | trumpet creeper 102 | blackberry lily -------------------------------------------------------------------------------- /classification/val_data/pets.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/pets.pth -------------------------------------------------------------------------------- /classification/val_data/pets.txt: -------------------------------------------------------------------------------- 1 | abyssinian 2 | american_bulldog 3 | american_pit_bull_terrier 4 | basset_hound 5 | beagle 6 | bengal 7 | birman 8 | bombay 9 | boxer 10 | british_shorthair 11 | chihuahua 12 | egyptian_mau 13 | english_cocker_spaniel 14 | english_setter 15 | german_shorthaired 16 | great_pyrenees 17 | havanese 18 | japanese_chin 19 | keeshond 20 | leonberger 21 | maine_coon 22 | miniature_pinscher 23 | newfoundland 24 | persian 25 | pomeranian 26 | pug 27 | ragdoll 28 | russian_blue 29 | saint_bernard 30 | samoyed 31 | scottish_terrier 32 | shiba_inu 33 | siamese 34 | sphynx 35 | staffordshire_bull_terrier 36 | wheaten_terrier 37 | yorkshire_terrier -------------------------------------------------------------------------------- /classification/val_data/stanford_cars.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/stanford_cars.pth -------------------------------------------------------------------------------- /classification/val_data/stanford_cars.txt: -------------------------------------------------------------------------------- 1 | 2000 AM General Hummer SUV 2 | 2012 Acura RL Sedan 3 | 2012 Acura TL Sedan 4 | 2008 Acura TL Type-S 5 | 2012 Acura TSX Sedan 6 | 2001 Acura Integra Type R 7 | 2012 Acura ZDX Hatchback 8 | 2012 Aston Martin V8 Vantage Convertible 9 | 2012 Aston Martin V8 Vantage Coupe 10 | 2012 Aston Martin Virage Convertible 11 | 2012 Aston Martin Virage Coupe 12 | 2008 Audi RS 4 Convertible 13 | 2012 Audi A5 Coupe 14 | 2012 Audi TTS Coupe 15 | 2012 Audi R8 Coupe 16 | 1994 Audi V8 Sedan 17 | 1994 Audi 100 Sedan 18 | 1994 Audi 100 Wagon 19 | 2011 Audi TT Hatchback 20 | 2011 Audi S6 Sedan 21 | 2012 Audi S5 Convertible 22 | 2012 Audi S5 Coupe 23 | 2012 Audi S4 Sedan 24 | 2007 Audi S4 Sedan 25 | 2012 Audi TT RS Coupe 26 | 2012 BMW ActiveHybrid 5 Sedan 27 | 2012 BMW 1 Series Convertible 28 | 2012 BMW 1 Series Coupe 29 | 2012 BMW 3 Series Sedan 30 | 2012 BMW 3 Series Wagon 31 | 2007 BMW 6 Series Convertible 32 | 2007 BMW X5 SUV 33 | 2012 BMW X6 SUV 34 | 2012 BMW M3 Coupe 35 | 2010 BMW M5 Sedan 36 | 2010 BMW M6 Convertible 37 | 2012 BMW 
X3 SUV 38 | 2012 BMW Z4 Convertible 39 | 2012 Bentley Continental Supersports Conv. Convertible 40 | 2009 Bentley Arnage Sedan 41 | 2011 Bentley Mulsanne Sedan 42 | 2012 Bentley Continental GT Coupe 43 | 2007 Bentley Continental GT Coupe 44 | 2007 Bentley Continental Flying Spur Sedan 45 | 2009 Bugatti Veyron 16.4 Convertible 46 | 2009 Bugatti Veyron 16.4 Coupe 47 | 2012 Buick Regal GS 48 | 2007 Buick Rainier SUV 49 | 2012 Buick Verano Sedan 50 | 2012 Buick Enclave SUV 51 | 2012 Cadillac CTS-V Sedan 52 | 2012 Cadillac SRX SUV 53 | 2007 Cadillac Escalade EXT Crew Cab 54 | 2012 Chevrolet Silverado 1500 Hybrid Crew Cab 55 | 2012 Chevrolet Corvette Convertible 56 | 2012 Chevrolet Corvette ZR1 57 | 2007 Chevrolet Corvette Ron Fellows Edition Z06 58 | 2012 Chevrolet Traverse SUV 59 | 2012 Chevrolet Camaro Convertible 60 | 2010 Chevrolet HHR SS 61 | 2007 Chevrolet Impala Sedan 62 | 2012 Chevrolet Tahoe Hybrid SUV 63 | 2012 Chevrolet Sonic Sedan 64 | 2007 Chevrolet Express Cargo Van 65 | 2012 Chevrolet Avalanche Crew Cab 66 | 2010 Chevrolet Cobalt SS 67 | 2010 Chevrolet Malibu Hybrid Sedan 68 | 2009 Chevrolet TrailBlazer SS 69 | 2012 Chevrolet Silverado 2500HD Regular Cab 70 | 2007 Chevrolet Silverado 1500 Classic Extended Cab 71 | 2007 Chevrolet Express Van 72 | 2007 Chevrolet Monte Carlo Coupe 73 | 2007 Chevrolet Malibu Sedan 74 | 2012 Chevrolet Silverado 1500 Extended Cab 75 | 2012 Chevrolet Silverado 1500 Regular Cab 76 | 2009 Chrysler Aspen SUV 77 | 2010 Chrysler Sebring Convertible 78 | 2012 Chrysler Town and Country Minivan 79 | 2010 Chrysler 300 SRT-8 80 | 2008 Chrysler Crossfire Convertible 81 | 2008 Chrysler PT Cruiser Convertible 82 | 2002 Daewoo Nubira Wagon 83 | 2012 Dodge Caliber Wagon 84 | 2007 Dodge Caliber Wagon 85 | 1997 Dodge Caravan Minivan 86 | 2010 Dodge Ram Pickup 3500 Crew Cab 87 | 2009 Dodge Ram Pickup 3500 Quad Cab 88 | 2009 Dodge Sprinter Cargo Van 89 | 2012 Dodge Journey SUV 90 | 2010 Dodge Dakota Crew Cab 91 | 2007 Dodge Dakota Club Cab 92 | 2008 Dodge Magnum Wagon 93 | 2011 Dodge Challenger SRT8 94 | 2012 Dodge Durango SUV 95 | 2007 Dodge Durango SUV 96 | 2012 Dodge Charger Sedan 97 | 2009 Dodge Charger SRT-8 98 | 1998 Eagle Talon Hatchback 99 | 2012 FIAT 500 Abarth 100 | 2012 FIAT 500 Convertible 101 | 2012 Ferrari FF Coupe 102 | 2012 Ferrari California Convertible 103 | 2012 Ferrari 458 Italia Convertible 104 | 2012 Ferrari 458 Italia Coupe 105 | 2012 Fisker Karma Sedan 106 | 2012 Ford F-450 Super Duty Crew Cab 107 | 2007 Ford Mustang Convertible 108 | 2007 Ford Freestar Minivan 109 | 2009 Ford Expedition EL SUV 110 | 2012 Ford Edge SUV 111 | 2011 Ford Ranger SuperCab 112 | 2006 Ford GT Coupe 113 | 2012 Ford F-150 Regular Cab 114 | 2007 Ford F-150 Regular Cab 115 | 2007 Ford Focus Sedan 116 | 2012 Ford E-Series Wagon Van 117 | 2012 Ford Fiesta Sedan 118 | 2012 GMC Terrain SUV 119 | 2012 GMC Savana Van 120 | 2012 GMC Yukon Hybrid SUV 121 | 2012 GMC Acadia SUV 122 | 2012 GMC Canyon Extended Cab 123 | 1993 Geo Metro Convertible 124 | 2010 HUMMER H3T Crew Cab 125 | 2009 HUMMER H2 SUT Crew Cab 126 | 2012 Honda Odyssey Minivan 127 | 2007 Honda Odyssey Minivan 128 | 2012 Honda Accord Coupe 129 | 2012 Honda Accord Sedan 130 | 2012 Hyundai Veloster Hatchback 131 | 2012 Hyundai Santa Fe SUV 132 | 2012 Hyundai Tucson SUV 133 | 2012 Hyundai Veracruz SUV 134 | 2012 Hyundai Sonata Hybrid Sedan 135 | 2007 Hyundai Elantra Sedan 136 | 2012 Hyundai Accent Sedan 137 | 2012 Hyundai Genesis Sedan 138 | 2012 Hyundai Sonata Sedan 139 | 2012 Hyundai Elantra Touring Hatchback 140 | 2012 
Hyundai Azera Sedan 141 | 2012 Infiniti G Coupe IPL 142 | 2011 Infiniti QX56 SUV 143 | 2008 Isuzu Ascender SUV 144 | 2012 Jaguar XK XKR 145 | 2012 Jeep Patriot SUV 146 | 2012 Jeep Wrangler SUV 147 | 2012 Jeep Liberty SUV 148 | 2012 Jeep Grand Cherokee SUV 149 | 2012 Jeep Compass SUV 150 | 2008 Lamborghini Reventon Coupe 151 | 2012 Lamborghini Aventador Coupe 152 | 2012 Lamborghini Gallardo LP 570-4 Superleggera 153 | 2001 Lamborghini Diablo Coupe 154 | 2012 Land Rover Range Rover SUV 155 | 2012 Land Rover LR2 SUV 156 | 2011 Lincoln Town Car Sedan 157 | 2012 MINI Cooper Roadster Convertible 158 | 2012 Maybach Landaulet Convertible 159 | 2011 Mazda Tribute SUV 160 | 2012 McLaren MP4-12C Coupe 161 | 1993 Mercedes-Benz 300-Class Convertible 162 | 2012 Mercedes-Benz C-Class Sedan 163 | 2009 Mercedes-Benz SL-Class Coupe 164 | 2012 Mercedes-Benz E-Class Sedan 165 | 2012 Mercedes-Benz S-Class Sedan 166 | 2012 Mercedes-Benz Sprinter Van 167 | 2012 Mitsubishi Lancer Sedan 168 | 2012 Nissan Leaf Hatchback 169 | 2012 Nissan NV Passenger Van 170 | 2012 Nissan Juke Hatchback 171 | 1998 Nissan 240SX Coupe 172 | 1999 Plymouth Neon Coupe 173 | 2012 Porsche Panamera Sedan 174 | 2012 Ram C/V Cargo Van Minivan 175 | 2012 Rolls-Royce Phantom Drophead Coupe Convertible 176 | 2012 Rolls-Royce Ghost Sedan 177 | 2012 Rolls-Royce Phantom Sedan 178 | 2012 Scion xD Hatchback 179 | 2009 Spyker C8 Convertible 180 | 2009 Spyker C8 Coupe 181 | 2007 Suzuki Aerio Sedan 182 | 2012 Suzuki Kizashi Sedan 183 | 2012 Suzuki SX4 Hatchback 184 | 2012 Suzuki SX4 Sedan 185 | 2012 Tesla Model S Sedan 186 | 2012 Toyota Sequoia SUV 187 | 2012 Toyota Camry Sedan 188 | 2012 Toyota Corolla Sedan 189 | 2012 Toyota 4Runner SUV 190 | 2012 Volkswagen Golf Hatchback 191 | 1991 Volkswagen Golf Hatchback 192 | 2012 Volkswagen Beetle Hatchback 193 | 2012 Volvo C30 Hatchback 194 | 1993 Volvo 240 Sedan 195 | 2007 Volvo XC90 SUV 196 | 2012 smart fortwo Convertible -------------------------------------------------------------------------------- /coco_evaluation/coco_evaluation.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import io 4 | import itertools 5 | import json 6 | import logging 7 | import os 8 | import warnings 9 | 10 | import numpy as np 11 | from pycocotools.cocoeval import COCOeval 12 | from tabulate import tabulate 13 | 14 | import cv2 15 | import torch 16 | from pycocotools.coco import COCO 17 | 18 | logger = logging.getLogger("coco") 19 | 20 | 21 | def xyxy2xywh(bbox): 22 | """ 23 | change bbox to coco format 24 | :param bbox: [x1, y1, x2, y2] 25 | :return: [x, y, w, h] 26 | """ 27 | return [ 28 | bbox[0], 29 | bbox[1], 30 | bbox[2] - bbox[0], 31 | bbox[3] - bbox[1], 32 | ] 33 | 34 | 35 | class CocoDetectionEvaluator: 36 | def __init__(self, ann_path): 37 | 38 | self.coco_api = COCO(ann_path) 39 | self.cat_ids = sorted(self.coco_api.getCatIds()) 40 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} 41 | self.cats = self.coco_api.loadCats(self.cat_ids) 42 | self.class_names = [cat["name"] for cat in self.cats] 43 | self.img_ids = sorted(self.coco_api.imgs.keys()) 44 | img_info = self.coco_api.loadImgs(self.img_ids) 45 | 46 | # assert hasattr(dataset, "coco_api") 47 | # self.class_names = dataset.class_names 48 | # self.coco_api = dataset.coco_api 49 | # self.cat_ids = dataset.cat_ids 50 | self.metric_names = ["mAP", "AP_50", "AP_75", "AP_small", "AP_m", "AP_l"] 51 | 52 | def results2json(self, results): 53 | """ 54 | results: {image_id: 
{label: [bboxes...] } } 55 | :return coco json format: {image_id: 56 | category_id: 57 | bbox: 58 | score: } 59 | """ 60 | json_results = [] 61 | for image_id, dets in results.items(): 62 | for label, bboxes in dets.items(): 63 | category_id = self.cat_ids[label] 64 | for bbox in bboxes: 65 | score = float(bbox[4]) 66 | detection = dict( 67 | image_id=int(image_id), 68 | category_id=int(category_id), 69 | bbox=xyxy2xywh(bbox), 70 | score=score, 71 | ) 72 | json_results.append(detection) 73 | return json_results 74 | 75 | def evaluate(self, results, save_dir, rank=-1): 76 | ### Original 77 | # results_json = self.results2json(results) 78 | ### lzy modified 79 | with open(results, 'r') as json_file: 80 | results_json = json.load(json_file) 81 | if len(results_json) == 0: 82 | warnings.warn( 83 | "Detection result is empty! Please check whether " 84 | "training set is too small (need to increase val_interval " 85 | "in config and train more epochs). Or check annotation " 86 | "correctness." 87 | ) 88 | empty_eval_results = {} 89 | for key in self.metric_names: 90 | empty_eval_results[key] = 0 91 | return empty_eval_results 92 | json_path = os.path.join(save_dir, "results{}.json".format(rank)) 93 | json.dump(results_json, open(json_path, "w")) 94 | coco_dets = self.coco_api.loadRes(json_path) 95 | coco_eval = COCOeval( 96 | copy.deepcopy(self.coco_api), copy.deepcopy(coco_dets), "bbox" 97 | ) 98 | coco_eval.evaluate() 99 | coco_eval.accumulate() 100 | 101 | # use logger to log coco eval results 102 | redirect_string = io.StringIO() 103 | with contextlib.redirect_stdout(redirect_string): 104 | coco_eval.summarize() 105 | logger.info("\n" + redirect_string.getvalue()) 106 | 107 | # print per class AP 108 | headers = ["class", "AP50", "mAP"] 109 | colums = 6 110 | per_class_ap50s = [] 111 | per_class_maps = [] 112 | precisions = coco_eval.eval["precision"] 113 | # dimension of precisions: [TxRxKxAxM] 114 | # precision has dims (iou, recall, cls, area range, max dets) 115 | assert len(self.class_names) == precisions.shape[2] 116 | 117 | ### lzy modified 118 | per_class_results = [] 119 | for idx, name in enumerate(self.class_names): 120 | # area range index 0: all area ranges 121 | # max dets index -1: typically 100 per image 122 | precision_50 = precisions[0, :, idx, 0, -1] 123 | precision_50 = precision_50[precision_50 > -1] 124 | ap50 = np.mean(precision_50) if precision_50.size else float("nan") 125 | per_class_ap50s.append(float(ap50 * 100)) 126 | 127 | precision = precisions[:, :, idx, 0, -1] 128 | precision = precision[precision > -1] 129 | ap = np.mean(precision) if precision.size else float("nan") 130 | per_class_maps.append(float(ap * 100)) 131 | per_class_results.append({name:float(ap * 100)}) 132 | 133 | num_cols = min(colums, len(self.class_names) * len(headers)) 134 | flatten_results = [] 135 | for name, ap50, mAP in zip(self.class_names, per_class_ap50s, per_class_maps): 136 | flatten_results += [name, ap50, mAP] 137 | 138 | row_pair = itertools.zip_longest( 139 | *[flatten_results[i::num_cols] for i in range(num_cols)] 140 | ) 141 | table_headers = headers * (num_cols // len(headers)) 142 | table = tabulate( 143 | row_pair, 144 | tablefmt="pipe", 145 | floatfmt=".1f", 146 | headers=table_headers, 147 | numalign="left", 148 | ) 149 | logger.info("\n" + table) 150 | 151 | aps = coco_eval.stats[:6] 152 | eval_results = {} 153 | for k, v in zip(self.metric_names, aps): 154 | eval_results[k] = v 155 | return eval_results, per_class_results 156 | 
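Note that the modified `evaluate()` above loads a prediction JSON straight from disk instead of calling `results2json`, so the file must already be in COCO detection format. Below is a minimal sketch (not part of the repository) of how such a file could be produced from per-image `[x1, y1, x2, y2, score]` boxes; the `predictions` structure, function name, and file paths are assumptions chosen to mirror what `results2json()` produces.

```python
# Minimal sketch (assumption, not repo code): write predictions in the COCO
# detection format that CocoDetectionEvaluator.evaluate() expects to load.
# `predictions` maps image_id -> label index -> list of [x1, y1, x2, y2, score],
# mirroring the structure consumed by results2json() above.
import json

from pycocotools.coco import COCO


def dump_coco_predictions(predictions, ann_path, out_path):
    coco_api = COCO(ann_path)
    cat_ids = sorted(coco_api.getCatIds())  # label index -> COCO category id

    json_results = []
    for image_id, dets in predictions.items():
        for label, bboxes in dets.items():
            for x1, y1, x2, y2, score in bboxes:
                json_results.append({
                    "image_id": int(image_id),
                    "category_id": int(cat_ids[label]),
                    # COCO uses [x, y, w, h], matching xyxy2xywh() above
                    "bbox": [x1, y1, x2 - x1, y2 - y1],
                    "score": float(score),
                })

    with open(out_path, "w") as f:
        json.dump(json_results, f)


# Hypothetical usage: one detection of class index 0 on image 42.
# dump_coco_predictions({42: {0: [[10, 20, 110, 220, 0.9]]}},
#                       "./data/coco/annotations/instances_val2017.json",
#                       "./prediction_Qwen2_vl_2B_GRPO_coco.json")
```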
-------------------------------------------------------------------------------- /coco_evaluation/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d8fcd1a9-d7c5-4049-bfff-27dbc49b029d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from coco_evaluation import CocoDetectionEvaluator" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "aa90e100-072b-4f55-999c-a8e03d56fe87", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "evaluator = CocoDetectionEvaluator('./data/coco/annotations/instances_val2017.json')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "585276a5-cfb2-4e2b-a9de-30a46cbf30c5", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "results, per_class_results = evaluator.evaluate('./prediction_Qwen2_vl_2B_GRPO_coco.json', './results')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "92d76ac4-44aa-43b5-8123-d871453e6750", 37 | "metadata": { 38 | "scrolled": true 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "### mAP and AP for all categories\n", 43 | "results, per_class_results" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "014151c7-5a90-4d46-9782-7195cb7147ed", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "### mAP and AP for selected categories\n", 54 | "selected_cate = ['bus', 'train', 'fire hydrant', 'stop sign', 'cat', 'dog', 'bed', 'toilet']\n", 55 | "# selected_cate = ['mouse', 'fork', 'hot dog', 'cat', 'airplane', 'suitcase', 'parking meter', 'sandwich', 'train', 'hair drier', 'toilet', 'toaster', 'snowboard', 'frisbee', 'bear']\n", 56 | "results, per_class_results\n", 57 | "AP_sum = 0\n", 58 | "for item in per_class_results:\n", 59 | " for key, value in item.items():\n", 60 | " if key in selected_cate:\n", 61 | " print(f\"Key: {key}, Value: {value}\")\n", 62 | " AP_sum += value\n", 63 | "print(\"mAP for selected categories: \", (AP_sum)/(len(selected_cate)))" 64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3 (ipykernel)", 70 | "language": "python", 71 | "name": "python3" 72 | }, 73 | "language_info": { 74 | "codemirror_mode": { 75 | "name": "ipython", 76 | "version": 3 77 | }, 78 | "file_extension": ".py", 79 | "mimetype": "text/x-python", 80 | "name": "python", 81 | "nbconvert_exporter": "python", 82 | "pygments_lexer": "ipython3", 83 | "version": "3.10.13" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 5 88 | } 89 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataset/build_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "2f84d496-a970-452d-99a1-2f67718dff9b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import time\n", 11 | "from datasets import DatasetDict, Dataset\n", 12 | "from PIL import Image\n", 13 | "import json" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "5a8ba1f2-d465-4b78-a48a-744f591a14ab", 20 | "metadata": {}, 21 | 
"outputs": [], 22 | "source": [ 23 | "\"\"\"\n", 24 | "turn your json to DatasetDict\n", 25 | "\"\"\"\n", 26 | "def json_to_dataset(json_file_path):\n", 27 | " # read json file\n", 28 | " with open(json_file_path, 'r') as f:\n", 29 | " data = json.load(f)\n", 30 | "\n", 31 | " image_paths = [item['image_path'] for item in data]\n", 32 | " problems = [item['problem'] for item in data]\n", 33 | " solutions = [item['solution'] for item in data]\n", 34 | "\n", 35 | " images = [Image.open(image_path).convert('RGBA') for image_path in image_paths]\n", 36 | "\n", 37 | " dataset_dict = {\n", 38 | " 'image': images,\n", 39 | " 'problem': problems,\n", 40 | " 'solution': solutions\n", 41 | " }\n", 42 | "\n", 43 | " dataset = Dataset.from_dict(dataset_dict)\n", 44 | " dataset_dict = DatasetDict({\n", 45 | " 'train': dataset\n", 46 | " })\n", 47 | " return dataset_dict\n", 48 | "\n", 49 | "\n", 50 | "time1 = time.asctime()\n", 51 | "print(time1)\n", 52 | "### Your dataset in JSON file format consists of three parts: image, problem and solution\n", 53 | "dataset_dict = json_to_dataset('your_dataset_json_file.json')\n", 54 | "time2 = time.asctime()\n", 55 | "print(time2)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "0e2a20b4-131f-49fa-baee-4c2675479de3", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "\"\"\"\n", 66 | "save to your local disk\n", 67 | "\"\"\"\n", 68 | "def save_dataset(dataset_dict, save_path):\n", 69 | " # save DatasetDict to your disk\n", 70 | " dataset_dict.save_to_disk(save_path)\n", 71 | "\n", 72 | "save_path = './share_data/your_local_dataset'\n", 73 | "save_dataset(dataset_dict, save_path)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "22aa1938-3b68-4df3-896d-74c8b7c854c3", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "\"\"\"\n", 84 | "read from your local disk\n", 85 | "\"\"\"\n", 86 | "def load_dataset(save_path):\n", 87 | " # load DatasetDict\n", 88 | " return DatasetDict.load_from_disk(save_path)\n", 89 | "\n", 90 | "test_dataset_dict = load_dataset('./share_data/your_local_dataset')" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3 (ipykernel)", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.10.13" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 5 115 | } 116 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | ## Inference on LISA Dataset 2 | We've uploaded the model trained with 239 samples from the LISA dataset(🤗Huggingface). You can use the following code for inference to test the model's **reasoning grounding** capability (or use `lisa_demo.ipynb`). 
3 | ```python 4 | import torch 5 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 6 | from qwen_vl_utils import process_vision_info 7 | import json 8 | import os 9 | from PIL import Image 10 | import logging 11 | from tqdm import tqdm 12 | import re 13 | import math 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | torch.manual_seed(1234) 17 | img2description = dict() 18 | 19 | SYSTEM_PROMPT = ( 20 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " 21 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " 22 | "process and answer are enclosed within and tags, respectively, i.e., " 23 | " reasoning process here answer here " 24 | ) 25 | 26 | model = Qwen2VLForConditionalGeneration.from_pretrained( 27 | "Zery/Qwen2-VL-7B_visual_rft_lisa_IoU_reward", device_map="auto" 28 | ).eval() 29 | 30 | processor = AutoProcessor.from_pretrained("Zery/Qwen2-VL-7B_visual_rft_lisa_IoU_reward") 31 | 32 | def prepare_inputs(img_path, instruction): 33 | messages = [ 34 | {"role": "system", "content": SYSTEM_PROMPT}, 35 | { 36 | "role": "user", 37 | "content": [ 38 | {"type": "image", "image": img_path}, 39 | {"type": "text", "text": f"Output the bounding box in the image corresponding to the instruction: {instruction}. Output the thinking process in and your grouding box. Following \" thinking process \n(x1,y1),(x2,y2))\" format."} 40 | ] 41 | } 42 | ] 43 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 44 | image_inputs, _ = process_vision_info(messages) 45 | inputs = processor( 46 | text=[text], 47 | images=image_inputs, 48 | padding=True, 49 | return_tensors="pt", 50 | ) 51 | return inputs.to("cuda") 52 | 53 | image_path = "assets/pokeymon.jpg" 54 | inputs = prepare_inputs(image_path, "the pokeymon that can perform Thunderbolt. Output thinking process as detail as possibile") 55 | 56 | with torch.no_grad(): 57 | generated_ids = model.generate(**inputs, max_new_tokens=128) 58 | response = processor.batch_decode( 59 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False 60 | )[0] 61 | print(response) 62 | ``` 63 | -------------------------------------------------------------------------------- /lisa_evaluation/Qwen2_VL_lisa_infere.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 3 | from qwen_vl_utils import process_vision_info 4 | import json 5 | import os 6 | from PIL import Image 7 | import logging 8 | from tqdm import tqdm 9 | import re 10 | # from process_utils import pred_2_point, extract_bbox 11 | import math 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | torch.manual_seed(1234) 15 | img2description = dict() 16 | 17 | SYSTEM_PROMPT = ( 18 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " 19 | "first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning " 20 | "process and answer are enclosed within and tags, respectively, i.e., " 21 | " reasoning process here answer here " 22 | ) 23 | 24 | # Load Qwen2-VL-2B model and processor 25 | model = Qwen2VLForConditionalGeneration.from_pretrained( 26 | "/path/to/your/checkpoint-498", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2" 27 | ).eval() 28 | 29 | processor = AutoProcessor.from_pretrained("/path/to/your//checkpoint-498") 30 | 31 | logging.info("Model and processor loaded successfully") 32 | 33 | def process_image(image_path): 34 | if not os.path.exists(image_path): 35 | raise FileNotFoundError(f"Image not found: {image_path}") 36 | return Image.open(image_path) 37 | 38 | def prepare_inputs(img_path, instruction): 39 | messages = [ 40 | {"role": "system", "content": SYSTEM_PROMPT}, 41 | { 42 | "role": "user", 43 | "content": [ 44 | {"type": "image", "image": img_path}, 45 | {"type": "text", "text": f"Output the bounding box in the image corresponding to the instruction: {instruction}. Output the thinking process in and your grouding box. Following \" thinking process \n(x1,y1),(x2,y2))\" format."} 46 | ] 47 | } 48 | ] 49 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 50 | image_inputs, _ = process_vision_info(messages) 51 | inputs = processor( 52 | text=[text], 53 | images=image_inputs, 54 | padding=True, 55 | return_tensors="pt", 56 | ) 57 | return inputs.to("cuda") 58 | 59 | def extract_bbox(response): 60 | try: 61 | match = re.search(r"\[(\d+),(\d+),(\d+),(\d+)\]", response) 62 | if match: 63 | return [int(match.group(i)) for i in range(1, 5)] 64 | else: 65 | raise ValueError("Invalid response format") 66 | except Exception as e: 67 | logging.error(f"Error extracting bbox: {e}") 68 | return None 69 | 70 | def compute_iou(boxA, boxB): 71 | """ 72 | 计算 IoU (Intersection over Union) 73 | boxA, boxB 格式: [x1, y1, x2, y2] 74 | """ 75 | xA = max(boxA[0], boxB[0]) 76 | yA = max(boxA[1], boxB[1]) 77 | xB = min(boxA[2], boxB[2]) 78 | yB = min(boxA[3], boxB[3]) 79 | 80 | interArea = max(0, xB - xA) * max(0, yB - yA) 81 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) 82 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) 83 | 84 | iou = interArea / float(boxAArea + boxBArea - interArea + 1e-6) 85 | return iou 86 | 87 | def evaluate_model(tasks): 88 | results = [] 89 | box_res =[] 90 | for task in tasks: 91 | logging.info(f"Processing task: {task}") 92 | ious = [] 93 | screenspot_data = json.load(open(f"path/to/lisa_{task}.json", 'r')) 94 | data_per_gpu = math.ceil(len(screenspot_data) / int(os.environ['SPLIT_NUM'])) 95 | start_idx = int(os.environ['SPLIT']) * data_per_gpu 96 | end_idx = min(start_idx + data_per_gpu, len(screenspot_data)) 97 | screenspot_data = screenspot_data[start_idx:end_idx] 98 | 99 | for item in tqdm(screenspot_data): 100 | img_path = item['image_path'] 101 | try: 102 | image = process_image(img_path) 103 | w, h = image.size 104 | instruction = item["instruction"][0] 105 | bbox = item["boxes"][0] 106 | bbox = [ 107 | bbox[0] / image.size[0], 108 | bbox[1] / image.size[1], 109 | bbox[2] / image.size[0], 110 | bbox[3] / image.size[1], 111 | ] 112 | inputs = prepare_inputs(img_path, instruction) 113 | 114 | with torch.no_grad(): 115 | generated_ids = model.generate(**inputs, max_new_tokens=128) 116 | response = processor.batch_decode( 117 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False 118 | )[0] 119 | print(response) 120 | pattern = 
r"\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)" 121 | matches = re.findall(pattern, response) 122 | x1, y1, x2, y2 = map(int, matches[0]) 123 | pred_bbox = [int(x1) / 1000, int(y1) / 1000, int(x2) / 1000, int(y2) / 1000] 124 | 125 | iou = compute_iou(pred_bbox, bbox) 126 | box_res.append( 127 | { 128 | "image_pth": item['image_path'], 129 | "pred_bbox": pred_bbox, 130 | "thinking_process": response 131 | } 132 | ) 133 | ious.append(iou) 134 | except Exception as e: 135 | ious.append(0) 136 | json.dump(box_res, open(f"tmp/resbox_{os.environ['SPLIT']}_r1_w_think_7b.json", 'w'), indent=4) 137 | json.dump(ious, open(f"tmp/res_{os.environ['SPLIT']}.json", 'w'), indent=4) 138 | return 139 | 140 | if __name__ == "__main__": 141 | import argparse 142 | 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument('--task', type=str, required=True) 145 | args = parser.parse_args() 146 | 147 | if args.task == "all": 148 | tasks = ["val", "test"] 149 | else: 150 | tasks = [args.task] 151 | 152 | results = evaluate_model(tasks) 153 | -------------------------------------------------------------------------------- /lisa_evaluation/Qwen2_VL_lisa_infere.sh: -------------------------------------------------------------------------------- 1 | TASKS=("test" "val") 2 | # Adjust to your gpu num 3 | GPU_IDS=(0 1 2 3 4 5 6 7) 4 | SPLIT_NUM=8 5 | 6 | for task in "${TASKS[@]}"; do 7 | echo "Starting inference for task: $task" 8 | 9 | # 遍历 GPU 和 SPLIT 10 | for i in "${!GPU_IDS[@]}"; do 11 | GPU_ID=${GPU_IDS[$i]} 12 | SPLIT=$i 13 | echo "Launching task=$task on GPU=$GPU_ID with SPLIT=$SPLIT" 14 | SPLIT=$SPLIT SPLIT_NUM=$SPLIT_NUM python Qwen2_VL_lisa_infere.py \ 15 | --task $task & 16 | sleep 1 17 | done 18 | wait 19 | echo "Merging results for task: $task" 20 | SPLIT_NUM=$SPLIT_NUM python merge_eval.py >> res.txt 21 | done 22 | 23 | echo "All tasks completed!" 24 | -------------------------------------------------------------------------------- /lisa_evaluation/README.md: -------------------------------------------------------------------------------- 1 | ## ViRFT for reasoning grounding 2 | 3 | ## training 4 | 1. Download [LISA dataset](https://github.com/dvlab-research/LISA) 5 | 2. use `gen_box_ann.py` to generate box from mask. 6 | 3. use `gen_sft.py` to generate SFT/Visual-RFT training annotations. 7 | 4. use `src/scripts/2B_lisa_grounding.sh` to train the model, with annotation path changed to step.3 generated annotations. 8 | 9 | After training model, replace model path in `Qwen2_VL_lisa_infere.py` with your own ckpt. 10 | 11 | ```python 12 | # Load Qwen2-VL-2B model and processor 13 | model = Qwen2VLForConditionalGeneration.from_pretrained( 14 | "/path/to/your/checkpoint-498", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2" 15 | ).eval() 16 | 17 | processor = AutoProcessor.from_pretrained("/path/to/your/checkpoint-498") 18 | ``` 19 | 20 | to compute gIoU, follow the process bellow. 21 | 1. Use `box2mask.py` to extract mask from [SAM](https://github.com/facebookresearch/segment-anything) 22 | 2. Use `mask_iou` to comput mask IoU. 
23 | 24 | ```shell 25 | cd lisa_evaluation 26 | bash Qwen2_VL_lisa_infere.sh 27 | ``` 28 | -------------------------------------------------------------------------------- /lisa_evaluation/box2mask.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import numpy as np 5 | import cv2 6 | from pycocotools import mask as maskUtils 7 | from segment_anything import sam_model_registry, SamPredictor 8 | 9 | def load_json(json_path): 10 | with open(json_path, 'r') as f: 11 | data = json.load(f) 12 | return data 13 | 14 | def load_image(image_path): 15 | image = cv2.imread(image_path) 16 | if image is None: 17 | raise FileNotFoundError(f"Image not found at {image_path}") 18 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 19 | return image 20 | 21 | def scale_bbox(bbox, width, height): 22 | x1, y1, x2, y2 = bbox 23 | return [int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)] 24 | 25 | def mask_to_rle(mask): 26 | rle = maskUtils.encode(np.asfortranarray(mask.astype(np.uint8))) 27 | rle["counts"] = rle["counts"].decode("utf-8") 28 | return rle 29 | 30 | def save_rle_to_json(rle_data_list, save_path): 31 | with open(save_path, 'w') as f: 32 | json.dump(rle_data_list, f) 33 | print(f"Saved all RLE data to {save_path}") 34 | 35 | def main(json_path, sam_checkpoint, output_json_path, model_type="vit_h"): 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | 38 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 39 | sam.to(device=device) 40 | predictor = SamPredictor(sam) 41 | 42 | data = load_json(json_path) 43 | 44 | all_rle_data = [] 45 | 46 | for idx, item in enumerate(data): 47 | image_pth = item["image_pth"] 48 | pred_bbox = item["pred_bbox"] 49 | 50 | image = load_image(image_pth) 51 | height, width, _ = image.shape 52 | 53 | abs_bbox = scale_bbox(pred_bbox, width, height) 54 | 55 | predictor.set_image(image) 56 | 57 | transformed_bbox = predictor.transform.apply_boxes_torch( 58 | torch.tensor([abs_bbox], device=device), image.shape[:2] 59 | ) 60 | masks, scores, _ = predictor.predict_torch( 61 | point_coords=None, 62 | point_labels=None, 63 | boxes=transformed_bbox, 64 | multimask_output=False, 65 | ) 66 | 67 | mask = masks[0][0].cpu().numpy() 68 | 69 | rle = mask_to_rle(mask) 70 | 71 | rle_data = { 72 | "image_pth": image_pth, 73 | "pred_bbox": pred_bbox, 74 | "rle_mask": rle 75 | } 76 | all_rle_data.append(rle_data) 77 | 78 | save_rle_to_json(all_rle_data, output_json_path) 79 | 80 | if __name__ == "__main__": 81 | json_path = f"/path/to/resbox" 82 | sam_checkpoint = "sam_vit_h_4b8939.pth" 83 | output_json_path = f"./all_masks_rle.json" 84 | 85 | main(json_path, sam_checkpoint, output_json_path) 86 | -------------------------------------------------------------------------------- /lisa_evaluation/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d8fcd1a9-d7c5-4049-bfff-27dbc49b029d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from coco_evaluation import CocoDetectionEvaluator" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "aa90e100-072b-4f55-999c-a8e03d56fe87", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "evaluator = CocoDetectionEvaluator('./data/coco/annotations/instances_val2017.json')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | 
"execution_count": null, 26 | "id": "585276a5-cfb2-4e2b-a9de-30a46cbf30c5", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "results, per_class_results = evaluator.evaluate('./prediction_Qwen2_vl_2B_GRPO_coco.json', './results')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "92d76ac4-44aa-43b5-8123-d871453e6750", 37 | "metadata": { 38 | "scrolled": true 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "### mAP and AP for all categories\n", 43 | "results, per_class_results" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "014151c7-5a90-4d46-9782-7195cb7147ed", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "### mAP and AP for selected categories\n", 54 | "selected_cate = ['bus', 'train', 'fire hydrant', 'stop sign', 'cat', 'dog', 'bed', 'toilet']\n", 55 | "# selected_cate = ['mouse', 'fork', 'hot dog', 'cat', 'airplane', 'suitcase', 'parking meter', 'sandwich', 'train', 'hair drier', 'toilet', 'toaster', 'snowboard', 'frisbee', 'bear']\n", 56 | "results, per_class_results\n", 57 | "AP_sum = 0\n", 58 | "for item in per_class_results:\n", 59 | " for key, value in item.items():\n", 60 | " if key in selected_cate:\n", 61 | " print(f\"Key: {key}, Value: {value}\")\n", 62 | " AP_sum += value\n", 63 | "print(\"mAP for selected categories: \", (AP_sum)/(len(selected_cate)))" 64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3 (ipykernel)", 70 | "language": "python", 71 | "name": "python3" 72 | }, 73 | "language_info": { 74 | "codemirror_mode": { 75 | "name": "ipython", 76 | "version": 3 77 | }, 78 | "file_extension": ".py", 79 | "mimetype": "text/x-python", 80 | "name": "python", 81 | "nbconvert_exporter": "python", 82 | "pygments_lexer": "ipython3", 83 | "version": "3.10.13" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 5 88 | } 89 | -------------------------------------------------------------------------------- /lisa_evaluation/gen_box_ann.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image, ImageDraw 4 | 5 | res = [] 6 | base_path = "/path/to/your/LISA-main/data/train" 7 | 8 | for pth in os.listdir(base_path): 9 | if pth.endswith(".json"): 10 | json_path = os.path.join(base_path, pth) 11 | 12 | with open(json_path, 'r') as f: 13 | item = json.load(f) 14 | 15 | instruct = item["text"] 16 | shapes = item["shapes"] 17 | 18 | boxes = [] 19 | for shape in shapes[:1]: 20 | points = shape["points"] 21 | x_coords = [p[0] for p in points] 22 | y_coords = [p[1] for p in points] 23 | 24 | x_min, x_max = min(x_coords), max(x_coords) 25 | y_min, y_max = min(y_coords), max(y_coords) 26 | boxes.append((x_min, y_min, x_max, y_max)) 27 | 28 | img_path = json_path.replace(".json", ".jpg") 29 | if os.path.exists(img_path): 30 | res.append({ 31 | "image_path": img_path, 32 | "instruction": instruct, 33 | "boxes": boxes 34 | }) 35 | 36 | json.dump(res, open("lisa_train.json", 'w'), indent=4) 37 | -------------------------------------------------------------------------------- /lisa_evaluation/gen_sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | res = [] 5 | index = 0 6 | for i, item in enumerate(json.load(open("path/to/your/lisa_train.json", 'r'))): 7 | for instruct in item['instruction']: 8 | w, h= Image.open(item['image_path']).size 9 | res.append({ 10 | "id": 
f"lisa_{index}", 11 | "conversations": [ 12 | { 13 | "from": "user", 14 | "value": f"{item['image_path']}\n Output the bounding box in the image corresponding to the instruction: {instruct}" 15 | }, 16 | { 17 | "from": "assistant", 18 | "value": f"({int(item['boxes'][0][0] / w * 1000)},{int(item['boxes'][0][1] / h * 1000)}),({int(item['boxes'][0][2] / w * 1000)},{int(item['boxes'][0][3] / h * 1000)})" 19 | } 20 | ] 21 | }) 22 | index += 1 23 | json.dump(res, open("lisa_train_sft.json", 'w'), indent=4) 24 | -------------------------------------------------------------------------------- /lisa_evaluation/mask_iou.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from pycocotools import mask as maskUtils 5 | from PIL import Image, ImageDraw 6 | from tqdm import tqdm 7 | 8 | def polygon_to_mask(polygon, size): 9 | mask = Image.new('L', (size[1], size[0]), 0) 10 | ImageDraw.Draw(mask).polygon([tuple(point) for point in polygon], outline=1, fill=1) 11 | return np.array(mask, dtype=np.uint8) 12 | 13 | def rle_to_mask(rle): 14 | rle_decoded = maskUtils.decode(rle) 15 | return rle_decoded.astype(np.uint8) 16 | 17 | def compute_iou(pred_mask, gt_mask): 18 | intersection = np.logical_and(pred_mask, gt_mask).sum() 19 | union = np.logical_or(pred_mask, gt_mask).sum() 20 | iou = intersection / (union + 1e-6) 21 | return intersection, union, iou 22 | 23 | res = [] 24 | base_path = "/path/to/LISA-main/data/test" 25 | 26 | test_res = {} 27 | 28 | with open(f"all_masks_rle.json", 'r') as f: 29 | for item in json.load(f): 30 | test_res[item['image_pth']] = item['rle_mask'] 31 | 32 | total_intersection = 0 33 | total_union = 0 34 | ious = [] 35 | 36 | for pth in tqdm(os.listdir(base_path)): 37 | if pth.endswith(".json"): 38 | json_path = os.path.join(base_path, pth) 39 | 40 | with open(json_path, 'r') as f: 41 | item = json.load(f) 42 | 43 | try: 44 | gt_polygon = item["shapes"][0]['points'] 45 | except: 46 | print("no res") 47 | continue 48 | 49 | image_pth = f"/path/to/LISA-main/data/test/{pth.replace('.json', '.jpg')}" 50 | pred_rle = test_res.get(image_pth, None) 51 | 52 | 53 | with Image.open(image_pth) as img: 54 | img_width, img_height = img.size 55 | real_size = [img_height, img_width] 56 | 57 | if pred_rle is not None: 58 | size = pred_rle['size'] 59 | 60 | if size != real_size: 61 | print("shape mismatch") 62 | pred_mask = np.zeros((size[0], size[1]), dtype=np.uint8) 63 | continue 64 | pred_mask = rle_to_mask(pred_rle) 65 | else: 66 | print(f"Warning: No prediction found for {image_pth}. 
Setting IoU to 0.") 67 | pred_mask = np.zeros((size[0], size[1]), dtype=np.uint8) 68 | 69 | gt_mask = polygon_to_mask(gt_polygon, real_size) 70 | intersection, union, iou = compute_iou(pred_mask, gt_mask) 71 | 72 | total_intersection += intersection 73 | total_union += union 74 | 75 | ious.append(iou) 76 | 77 | gIoU = np.mean(ious) if ious else 0 78 | cIoU = total_intersection / (total_union + 1e-6) 79 | 80 | print(f"gIoU (Global IoU): {gIoU:.4f}") 81 | print(f"cIoU (Cumulative IoU): {cIoU:.4f}") 82 | -------------------------------------------------------------------------------- /lisa_evaluation/merge_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | merged = [] 4 | for i in range(int(os.environ['SPLIT_NUM'])): 5 | data = json.load(open(f"tmp/res_{i}.json", 'r')) 6 | merged += data 7 | print(f"mIoU: {sum(merged) / len(merged)}") 8 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | cd src/virft 2 | pip install -e ".[dev]" 3 | 4 | # Addtional modules 5 | pip install wandb==0.18.3 6 | pip install tensorboardx 7 | pip install qwen_vl_utils torchvision 8 | pip install flash-attn --no-build-isolation 9 | 10 | # vLLM support 11 | pip install vllm==0.7.2 12 | 13 | # fix transformers version 14 | pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef 15 | -------------------------------------------------------------------------------- /src/scripts/2B_aircraft_4_shot.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" 2 | export LOG_PATH="./debug_log_2b_GRPO_aircraft_4_shot.txt" 3 | 4 | export DATA_PATH=./share_data/ViRFT_CLS_fgvc_aircraft_4_shot 5 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct 6 | export SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_aircraft_4_shot 7 | 8 | 9 | torchrun --nproc_per_node="8" \ 10 | --nnodes="1" \ 11 | --node_rank="0" \ 12 | --master_addr="127.0.0.1" \ 13 | --master_port="12345" \ 14 | /src/open_r1/grpo_classification.py \ 15 | --output_dir ${SAVE_PATH} \ 16 | --model_name_or_path ${CKPT_PATH} \ 17 | --dataset_name ${DATA_PATH} \ 18 | --deepspeed /local_scripts/zero3.json \ 19 | --max_prompt_length 1024 \ 20 | --per_device_train_batch_size 1 \ 21 | --gradient_accumulation_steps 2 \ 22 | --logging_steps 1 \ 23 | --bf16 \ 24 | --report_to wandb \ 25 | --gradient_checkpointing false \ 26 | --attn_implementation flash_attention_2 \ 27 | --max_pixels 401408 \ 28 | --num_train_epochs 1 \ 29 | --run_name Qwen2-VL-2B_GRPO_aircraft100_4shot \ 30 | --save_steps 100 \ 31 | --save_only_model true \ 32 | --num_generations 8 33 | -------------------------------------------------------------------------------- /src/scripts/2B_base65cate_6k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=./share_data/base65cate_6k_think 2 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct 3 | export SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_coco_base65cate_6k 4 | 5 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 6 | export LOG_PATH="./debug_log_2b_GRPO_coco_base65cate_6k.txt" 7 | 8 | torchrun --nproc_per_node="8" \ 9 | --nnodes="1" \ 10 | --node_rank="0" \ 11 | --master_addr="127.0.0.1" \ 12 | --master_port="12345" \ 13 | src/open_r1/grpo.py \ 14 | --output_dir 
${SAVE_PATH} \ 15 | --model_name_or_path ${CKPT_PATH} \ 16 | --dataset_name ${DATA_PATH} \ 17 | --deepspeed ./local_scripts/zero3.json \ 18 | --max_prompt_length 1024 \ 19 | --per_device_train_batch_size 1 \ 20 | --gradient_accumulation_steps 2 \ 21 | --logging_steps 1 \ 22 | --bf16 \ 23 | --report_to wandb \ 24 | --gradient_checkpointing false \ 25 | --attn_implementation flash_attention_2 \ 26 | --max_pixels 401408 \ 27 | --num_train_epochs 2 \ 28 | --run_name Qwen2-VL-2B_GRPO_coco_base65cate_6k \ 29 | --save_steps 100 \ 30 | --save_only_model true \ 31 | --num_generations 8 ' 32 | -------------------------------------------------------------------------------- /src/scripts/2B_lisa_grounding.sh: -------------------------------------------------------------------------------- 1 | # cd src/open-r1-multimodal 2 | 3 | export DEBUG_MODE="true" 4 | export LOG_PATH="./debug_log_2b.txt" 5 | torchrun --nproc_per_node="8" \ 6 | --nnodes="1" \ 7 | --node_rank="0" \ 8 | --master_addr="127.0.0.1" \ 9 | --master_port="12346" \ 10 | src/open_r1/grpo_gui_grounding_lisa.py \ 11 | --output_dir "out/lisa_train_GIoU" \ 12 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 13 | --dataset_name NOT_USED \ 14 | --deepspeed local_scripts/zero3.json \ 15 | --max_prompt_length 1024 \ 16 | --per_device_train_batch_size 1 \ 17 | --gradient_accumulation_steps 2 \ 18 | --logging_steps 1 \ 19 | --bf16 \ 20 | --gradient_checkpointing true \ 21 | --attn_implementation flash_attention_2 \ 22 | --max_pixels 401408 \ 23 | --num_train_epochs 6 \ 24 | --run_name Qwen2-VL-2B-GRPO-groud_lisa_train \ 25 | --save_steps 50 \ 26 | --save_only_model true \ 27 | --num_generations 8 # number of outputs G in grpo, reduce it would lead to faster training and smaller memory cost but higher variance 28 | -------------------------------------------------------------------------------- /src/scripts/7B_base65cate_6k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=./share_data/base65cate_6k_think 2 | export CKPT_PATH=./share_models/Qwen2-VL-7B-Instruct 3 | export SAVE_PATH=./share_models/Qwen2-VL-7B-Instruct_GRPO_coco_base65cate_6k 4 | 5 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 6 | export LOG_PATH="./debug_log_7b_GRPO_coco_base65cate_6k.txt" 7 | 8 | torchrun --nproc_per_node="8" \ 9 | --nnodes="1" \ 10 | --node_rank="0" \ 11 | --master_addr="127.0.0.1" \ 12 | --master_port="12345" \ 13 | src/open_r1/grpo.py \ 14 | --output_dir ${SAVE_PATH} \ 15 | --model_name_or_path ${CKPT_PATH} \ 16 | --dataset_name ${DATA_PATH} \ 17 | --deepspeed ./local_scripts/zero3.json \ 18 | --max_prompt_length 1024 \ 19 | --per_device_train_batch_size 1 \ 20 | --gradient_accumulation_steps 2 \ 21 | --logging_steps 1 \ 22 | --bf16 \ 23 | --report_to wandb \ 24 | --gradient_checkpointing False \ 25 | --attn_implementation flash_attention_2 \ 26 | --max_pixels 401408 \ 27 | --num_train_epochs 2 \ 28 | --run_name Qwen2-VL-7B_GRPO_coco_base65cate_6k \ 29 | --save_steps 100 \ 30 | --save_only_model true \ 31 | --num_generations 4' 32 | -------------------------------------------------------------------------------- /src/scripts/example.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" 2 | export LOG_PATH="./debug_log_2b_GRPO_coco_base65cate_6k.txt" 3 | 4 | export DATA_PATH=./share_data/base65cate_6k_think 5 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct 6 | export 
SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_coco_base65cate_6k 7 | 8 | 9 | torchrun --nproc_per_node="8" \ 10 | --nnodes="1" \ 11 | --node_rank="0" \ 12 | --master_addr="127.0.0.1" \ 13 | --master_port="12345" \ 14 | src/open_r1/grpo.py \ 15 | --output_dir ${SAVE_PATH} \ 16 | --model_name_or_path ${CKPT_PATH} \ 17 | --dataset_name ${DATA_PATH} \ 18 | --deepspeed local_scripts/zero3.json \ 19 | --max_prompt_length 1024 \ 20 | --per_device_train_batch_size 1 \ 21 | --gradient_accumulation_steps 2 \ 22 | --logging_steps 1 \ 23 | --bf16 \ 24 | --report_to wandb \ 25 | --gradient_checkpointing false \ 26 | --attn_implementation flash_attention_2 \ 27 | --max_pixels 401408 \ 28 | --num_train_epochs 2 \ 29 | --run_name Qwen2-VL-2B_GRPO_coco_base65cate_6k \ 30 | --save_steps 100 \ 31 | --save_only_model true \ 32 | --num_generations 8 ' 33 | -------------------------------------------------------------------------------- /src/virft/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # Temp folders 174 | data/ 175 | wandb/ 176 | scripts/ 177 | checkpoints/ 178 | .vscode/ -------------------------------------------------------------------------------- /src/virft/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: style quality 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
4 | export PYTHONPATH = src 5 | 6 | check_dirs := src 7 | 8 | style: 9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py 10 | isort $(check_dirs) setup.py 11 | 12 | quality: 13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py 14 | isort --check-only $(check_dirs) setup.py 15 | flake8 --max-line-length 119 $(check_dirs) setup.py 16 | 17 | 18 | # Evaluation 19 | 20 | evaluate: 21 | -------------------------------------------------------------------------------- /src/virft/README.md: -------------------------------------------------------------------------------- 1 | # Visual-RFT 2 | -------------------------------------------------------------------------------- /src/virft/configs/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /src/virft/configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /src/virft/configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /src/virft/local_scripts/create_vision_cot_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import concurrent.futures 4 | import io 5 | import json 6 | import os 7 | import random 8 | import re 9 | import time 10 | from concurrent.futures import ThreadPoolExecutor 11 | from functools import partial 12 | from io import BytesIO 13 | from typing import Dict, List 14 | 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import pandas as pd 18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk 19 | from tqdm import tqdm 20 | 21 | import bytedtos 22 | import seaborn as sns 23 | import 
yaml 24 | from openai import AzureOpenAI 25 | from PIL import Image 26 | from pillow_avif import AvifImagePlugin 27 | 28 | 29 | PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions. 30 | 31 | Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A". 32 | 33 | Please strictly do not include "Answer:" in the question part to avoid confusion and leakage. 34 | 35 | Input Format: 36 | Original Question: {original_question} 37 | Original Answer: {original_answer} 38 | 39 | Output Format: 40 | Question: [rewrite the question if necessary] 41 | Answer: [answer with reasoning steps, including calculations where applicable] 42 | step-by-step reasoning process 43 | easy to verify answer 44 | """ 45 | 46 | 47 | def get_image_data_url(image_input): 48 | if isinstance(image_input, str) and image_input.startswith("data:"): 49 | return image_input 50 | 51 | if isinstance(image_input, str) and image_input.startswith("http"): 52 | image_input = load_image(image_input) 53 | 54 | if isinstance(image_input, str): 55 | image_input = Image.open(image_input) 56 | 57 | if not isinstance(image_input, Image.Image): 58 | raise ValueError("Unsupported image input type") 59 | 60 | if image_input.mode != "RGB": 61 | image_input = image_input.convert("RGB") 62 | 63 | buffer = BytesIO() 64 | image_input.save(buffer, format="JPEG") 65 | img_bytes = buffer.getvalue() 66 | base64_data = base64.b64encode(img_bytes).decode("utf-8") 67 | return f"data:image/jpeg;base64,{base64_data}" 68 | 69 | 70 | def gpt4o_query(image, prompt, max_retries=5, initial_delay=3): 71 | if image is None: 72 | return None 73 | 74 | data_url_list = [get_image_data_url(image)] 75 | client = AzureOpenAI( 76 | azure_endpoint="YOUR_AZURE_ENDPOINT", 77 | api_version="2023-07-01-preview", 78 | api_key="YOUR_API_KEY", 79 | ) 80 | 81 | for attempt in range(max_retries): 82 | try: 83 | messages = [ 84 | { 85 | "role": "system", 86 | "content": "You are an expert to analyze the image and provide useful information for users.", 87 | }, 88 | { 89 | "role": "user", 90 | "content": [ 91 | {"type": "text", "text": prompt}, 92 | ], 93 | }, 94 | ] 95 | 96 | for data_url in data_url_list: 97 | messages[1]["content"].insert( 98 | 0, {"type": "image_url", "image_url": {"url": data_url}} 99 | ) 100 | 101 | response = client.chat.completions.create( 102 | model="gpt-4o-2024-08-06", 103 | messages=messages, 104 | temperature=0.2, 105 | max_tokens=8192, 106 | ) 107 | return response.choices[0].message.content 108 | 109 | except Exception as e: 110 | if attempt == max_retries - 1: 111 | raise Exception( 112 | f"Failed after {max_retries} attempts. 
Last error: {str(e)}" 113 | ) 114 | delay = initial_delay * (2**attempt) + random.uniform( 115 | 0, 0.1 * initial_delay * (2**attempt) 116 | ) 117 | time.sleep(delay) 118 | 119 | 120 | def process_single_item(example): 121 | try: 122 | image_path = example["image_path"] 123 | formatted_prompt = PROMPT_FORMAT.format( 124 | original_question=example["question"], original_answer=example["answer"] 125 | ) 126 | 127 | response = gpt4o_query(image_path, formatted_prompt) 128 | example["gpt4o_response"] = response 129 | return example 130 | except Exception as e: 131 | print(f"Error processing item: {str(e)}") 132 | example["gpt4o_response"] = None 133 | return example 134 | 135 | 136 | def main(): 137 | dataset_path = "path/to/your/dataset" 138 | full_dataset = load_from_disk(dataset_path) 139 | 140 | processed_dataset = full_dataset.map( 141 | function=partial(process_single_item), 142 | num_proc=256, 143 | desc="Processing dataset with GPT-4o", 144 | keep_in_memory=True, 145 | ) 146 | 147 | output_path = f"{dataset_path}_processed" 148 | processed_dataset.save_to_disk(output_path) 149 | print(f"Processed dataset saved to: {output_path}") 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /src/virft/local_scripts/lmms_eval_qwen2vl.sh: -------------------------------------------------------------------------------- 1 | export HF_HOME="" 2 | export HF_TOKEN="" 3 | export HF_HUB_ENABLE_HF_TRANSFER="1" 4 | 5 | export API_TYPE="" 6 | export AZURE_ENDPOINT="" 7 | export AZURE_API_KEY="" 8 | export API_VERSION="" 9 | export MODEL_VERSION="" 10 | export NAVIT_ATTENTION_IMPLEMENTATION="eager" 11 | 12 | # Prompt for installation with 3-second timeout 13 | read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true 14 | if [ "$install_deps" = "YES" ]; then 15 | # Prepare the environment 16 | pip3 install --upgrade pip 17 | pip3 install -U setuptools 18 | 19 | cd 20 | if [ ! -d "maas_engine" ]; then 21 | git clone 22 | else 23 | echo "maas_engine directory already exists, skipping clone" 24 | fi 25 | cd maas_engine 26 | git pull 27 | git checkout 28 | pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]" 29 | 30 | current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2) 31 | if [ "$current_version" != "4.46.2" ]; then 32 | echo "Installing transformers 4.46.2 (current version: $current_version)" 33 | pip3 install transformers==4.46.2 34 | else 35 | echo "transformers 4.46.2 is already installed" 36 | fi 37 | 38 | cd 39 | rm -rf 40 | pip3 install -e . 
41 | pip3 install -U pydantic 42 | pip3 install Levenshtein 43 | pip3 install nltk 44 | python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)" 45 | fi 46 | 47 | TASKS=mmmu_val,mathvista_testmini,mmmu_pro 48 | MODEL_BASENAME=qwen2_vl 49 | 50 | model_checkpoint="" 51 | echo "MODEL_BASENAME: ${MODEL_BASENAME}" 52 | cd 53 | 54 | python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \ 55 | --model qwen2_vl \ 56 | --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \ 57 | --tasks ${TASKS} \ 58 | --batch_size 1 \ 59 | --log_samples \ 60 | --log_samples_suffix ${MODEL_BASENAME} \ 61 | --output_path ./logs -------------------------------------------------------------------------------- /src/virft/local_scripts/prepare_hf_data.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | import random 5 | from typing import List, Dict 6 | import numpy as np 7 | from concurrent.futures import ThreadPoolExecutor 8 | from tqdm import tqdm 9 | import datasets 10 | 11 | import io 12 | from datasets import load_dataset, load_from_disk, concatenate_datasets 13 | from PIL import Image 14 | from tqdm import tqdm 15 | from functools import partial 16 | from pillow_avif import AvifImagePlugin 17 | from datasets import Dataset 18 | import json 19 | import yaml 20 | import os 21 | import re 22 | import time 23 | import random 24 | import base64 25 | from openai import AzureOpenAI 26 | import concurrent.futures 27 | from typing import List, Dict 28 | import argparse 29 | import time 30 | 31 | 32 | def extract_problem_solution(gpt4o_response): 33 | # Split the response into parts 34 | parts = gpt4o_response.split("<think>") 35 | 36 | # Extract the problem (first part before any tags) 37 | problem = parts[0].strip() 38 | # Remove "Question:" prefix if it exists 39 | problem = re.sub(r"^Question:\s*", "", problem) 40 | # Remove "Answer:" at the end of the problem 41 | problem = re.sub(r"\s*Answer:\s*$", "", problem).strip() 42 | 43 | # Combine all the reasoning steps into a single block 44 | think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p] 45 | solution = f"<think>{' '.join(think_parts)}</think>" 46 | 47 | # Add the final answer if it exists, removing "Answer:" prefix 48 | if "<answer>" in gpt4o_response: 49 | final_answer = ( 50 | gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip() 51 | ) 52 | final_answer = re.sub(r"^Answer:\s*", "", final_answer) 53 | solution += f"\n\n<answer>{final_answer}</answer>" 54 | 55 | return problem, solution 56 | 57 | 58 | def load_image_from_path(image_path): 59 | try: 60 | img = Image.open(image_path) 61 | return img 62 | except Exception as e: 63 | print(f"Error loading image {image_path}: {str(e)}") 64 | return None 65 | 66 | 67 | def process_raw_data(raw_data): 68 | # Parse the raw data if it's a string 69 | if isinstance(raw_data, str): 70 | data = json.loads(raw_data) 71 | else: 72 | data = raw_data 73 | 74 | # Extract problem and solution 75 | try: 76 | problem, solution = extract_problem_solution(data["gpt4o_response"]) 77 | image = load_image_from_path(data["image_path"]) 78 | 79 | return { 80 | "image": image, 81 | "problem": problem, 82 | "solution": solution, 83 | "original_question": data["question"], 84 | "original_answer": data["answer"], 85 | } 86 | except Exception as e: 87 | print(f"Error processing data {data}: {str(e)}") 88 | return { 89 | "image": None, 90 | "problem": None, 91 | 
"solution": None, 92 | "original_question": None, 93 | "original_answer": None, 94 | } 95 | 96 | 97 | raw_data_list = [ 98 | "/path/to/reasoning_data_with_response_90k_verified", 99 | ] 100 | 101 | raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list]) 102 | 103 | processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42) 104 | 105 | hf_dict = { 106 | "image": [], 107 | "problem": [], 108 | "solution": [], 109 | "original_question": [], 110 | "original_answer": [], 111 | } 112 | 113 | for item in tqdm(processed_data): 114 | hf_dict["image"].append(item["image"]) 115 | hf_dict["problem"].append(item["problem"]) 116 | hf_dict["solution"].append(item["solution"]) 117 | hf_dict["original_question"].append(item["original_question"]) 118 | hf_dict["original_answer"].append(item["original_answer"]) 119 | 120 | 121 | features = datasets.Features( 122 | { 123 | "image": datasets.Image(), 124 | "problem": datasets.Value("string"), 125 | "solution": datasets.Value("string"), 126 | "original_question": datasets.Value("string"), 127 | "original_answer": datasets.Value("string"), 128 | } 129 | ) 130 | 131 | 132 | def has_empty_tags(text): 133 | # Pattern to match empty tags like 134 | pattern = r"<[^>]+>]+>" 135 | return bool(re.search(pattern, text)) 136 | 137 | 138 | def has_answer_pattern(text): 139 | if "Answer:" in text: 140 | return True 141 | return False 142 | 143 | 144 | def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement 145 | # Assuming the image is in a format that can be checked for dimensions 146 | # You might need to adjust this depending on how the image is stored in your dataset 147 | try: 148 | image = example["image"] # or however your image is accessed 149 | if isinstance(image, dict) and "height" in image and "width" in image: 150 | return image["height"] >= 28 and image["width"] >= 28 151 | # If image is a PIL Image or similar 152 | return image.height >= 28 and image.width >= 28 153 | except: 154 | return False 155 | 156 | 157 | ds = datasets.Dataset.from_dict(hf_dict, features=features) 158 | ds = ds.filter( 159 | lambda x: not has_empty_tags(x["solution"]) 160 | and not has_answer_pattern(x["problem"]) 161 | and has_valid_image_size(x) 162 | and x["image"] is not None, 163 | num_proc=128, 164 | ) 165 | # Push to Hugging Face Hub 166 | ds.push_to_hub("path/to/your/dataset") 167 | -------------------------------------------------------------------------------- /src/virft/local_scripts/train_aria_moe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=0 4 | export TOKENIZERS_PARALLELISM=false 5 | export OMP_NUM_THREADS=8 6 | export NCCL_IB_DISABLE=0 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth0 9 | export NCCL_DEBUG=INFO 10 | 11 | # CONFIG Huggingface 12 | # export HF_TOKEN="" 13 | export HF_TOKEN="" 14 | export HF_HOME="$HOME/.cache/huggingface" 15 | export HF_HUB_ENABLE_HF_TRANSFER="1" 16 | 17 | export NCCL_DEBUG=INFO 18 | 19 | GPUS="0,1,2,3,4,5,6,7" 20 | 21 | # 取 worker0 第一个 port 22 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) 23 | port=${ports[0]} 24 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')" 25 | 26 | echo "total workers: ${ARNOLD_WORKER_NUM}" 27 | echo "cur worker id: ${ARNOLD_ID}" 28 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}" 29 | echo "master ip: ${METIS_WORKER_0_HOST}" 30 | echo "master port: ${port}" 31 | echo "master port in cmd: ${port_in_cmd}" 32 | 
33 | # export WANDB_BASE_URL=https://api.wandb.ai 34 | # export WANDB_API_KEY="" 35 | # wandb login $WANDB_API_KEY 36 | 37 | export WANDB_BASE_URL=https://api.wandb.ai 38 | export WANDB_PROJECT=vision-reasoning 39 | export WANDB_API_KEY="" 40 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S) 41 | wandb login $WANDB_API_KEY 42 | 43 | cd /home/tiger/multimodal-open-r1 44 | # pip3 install vllm==0.6.6.post1 45 | pip3 install -e ".[dev]" 46 | pip3 install wandb==0.18.3 47 | 48 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \ 49 | --nnodes="${ARNOLD_WORKER_NUM}" \ 50 | --node_rank="${ARNOLD_ID}" \ 51 | --master_addr="${METIS_WORKER_0_HOST}" \ 52 | --master_port="${port_in_cmd}" \ 53 | src/open_r1/grpo.py \ 54 | --deepspeed scripts/zero3.json \ 55 | --output_dir Aria-GRPO-mini_cot_80k \ 56 | --model_name_or_path rhymes-ai/Aria \ 57 | --dataset_name luodian/mini_cot_80k \ 58 | --max_prompt_length 8192 \ 59 | --per_device_train_batch_size 1 \ 60 | --gradient_accumulation_steps 1 \ 61 | --logging_steps 1 \ 62 | --bf16 \ 63 | --report_to wandb \ 64 | --gradient_checkpointing true \ 65 | --attn_implementation eager \ 66 | --save_total_limit 8 \ 67 | --num_train_epochs 1 \ 68 | --run_name $WANDB_RUN_NAME 69 | -------------------------------------------------------------------------------- /src/virft/local_scripts/train_qwen2_vl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_BLOCKING_WAIT=0 4 | export TOKENIZERS_PARALLELISM=false 5 | export OMP_NUM_THREADS=8 6 | export NCCL_IB_DISABLE=0 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth0 9 | export NCCL_DEBUG=INFO 10 | 11 | GPUS="0,1,2,3,4,5,6,7" 12 | 13 | # 取 worker0 第一个 port 14 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) 15 | port=${ports[0]} 16 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')" 17 | 18 | echo "total workers: ${ARNOLD_WORKER_NUM}" 19 | echo "cur worker id: ${ARNOLD_ID}" 20 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}" 21 | echo "master ip: ${METIS_WORKER_0_HOST}" 22 | echo "master port: ${port}" 23 | echo "master port in cmd: ${port_in_cmd}" 24 | 25 | # export WANDB_BASE_URL=https://api.wandb.ai 26 | # export WANDB_API_KEY="" 27 | # wandb login $WANDB_API_KEY 28 | 29 | export WANDB_BASE_URL=https://api.wandb.ai 30 | export WANDB_PROJECT=vision-reasoning 31 | export WANDB_API_KEY="" 32 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S) 33 | wandb login $WANDB_API_KEY 34 | 35 | cd /home/tiger/multimodal-open-r1 36 | # pip3 install vllm==0.6.6.post1 37 | pip3 install -e ".[dev]" 38 | pip3 install wandb==0.18.3 39 | 40 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \ 41 | --nnodes="${ARNOLD_WORKER_NUM}" \ 42 | --node_rank="${ARNOLD_ID}" \ 43 | --master_addr="${METIS_WORKER_0_HOST}" \ 44 | --master_port="${port_in_cmd}" \ 45 | src/open_r1/grpo.py \ 46 | --deepspeed scripts/zero3.json \ 47 | --output_dir checkpoints/${WANDB_RUN_NAME} \ 48 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 49 | --dataset_name luodian/${DATASET_NAME} \ 50 | --max_prompt_length 8192 \ 51 | --per_device_train_batch_size 1 \ 52 | --gradient_accumulation_steps 1 \ 53 | --logging_steps 1 \ 54 | --bf16 \ 55 | --report_to wandb \ 56 | --gradient_checkpointing true \ 57 | --attn_implementation flash_attention_2 \ 58 | --max_pixels 2359296 \ 59 | --save_total_limit 8 \ 60 | --num_train_epochs 1 \ 61 | --run_name $WANDB_RUN_NAME 62 | 
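The two launcher scripts above are written for an internal scheduler that injects the ARNOLD_* and METIS_* variables. As a rough, unverified sketch of how the same entry point could be exercised on a single 8-GPU machine, those values can be filled in by hand; the dataset suffix, run name, and DeepSpeed config path below are placeholders and may need to be adapted to the local checkout.

    # Stand-ins for the scheduler-provided variables (assumed values, single node)
    export ARNOLD_WORKER_NUM=1 ARNOLD_ID=0 ARNOLD_WORKER_GPU=8
    export METIS_WORKER_0_HOST=127.0.0.1 METIS_WORKER_0_PORT=12345
    export DATASET_NAME=your_grpo_dataset            # placeholder dataset suffix
    export WANDB_RUN_NAME=Qwen2-VL-2B-GRPO-local

    # Same entry point and flags as the script above, minus the multi-node plumbing
    torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
        --master_addr=127.0.0.1 --master_port=12345 \
        src/open_r1/grpo.py \
        --deepspeed local_scripts/zero3.json \
        --output_dir checkpoints/${WANDB_RUN_NAME} \
        --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
        --dataset_name luodian/${DATASET_NAME} \
        --max_prompt_length 8192 \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 1 \
        --bf16 \
        --report_to none \
        --num_train_epochs 1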
-------------------------------------------------------------------------------- /src/virft/local_scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /src/virft/local_scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /src/virft/local_scripts/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /src/virft/local_scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 
| "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /src/virft/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = open_r1 7 | known_third_party = 8 | transformers 9 | datasets 10 | fugashi 11 | git 12 | h5py 13 | matplotlib 14 | nltk 15 | numpy 16 | packaging 17 | pandas 18 | psutil 19 | pytest 20 | rouge_score 21 | sacrebleu 22 | seqeval 23 | sklearn 24 | streamlit 25 | torch 26 | tqdm 27 | 28 | line_length = 119 29 | lines_after_imports = 2 30 | multi_line_output = 3 31 | use_parentheses = True 32 | 33 | [flake8] 34 | ignore = E203, E501, E741, W503, W605 35 | max-line-length = 119 36 | per-file-ignores = 37 | # imported but unused 38 | __init__.py: F401 39 | 40 | [tool:pytest] 41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /src/virft/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py 16 | 17 | 18 | import re 19 | import shutil 20 | from pathlib import Path 21 | 22 | from setuptools import find_packages, setup 23 | 24 | 25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info" 27 | if stale_egg_info.exists(): 28 | print( 29 | ( 30 | "Warning: {} exists.\n\n" 31 | "If you recently updated open_r1, this is expected,\n" 32 | "but it may prevent open_r1 from installing in editable mode.\n\n" 33 | "This directory is automatically generated by Python's packaging tools.\n" 34 | "I will remove it now.\n\n" 35 | "See https://github.com/pypa/pip/issues/5466 for details.\n" 36 | ).format(stale_egg_info) 37 | ) 38 | shutil.rmtree(stale_egg_info) 39 | 40 | 41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any. 42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version 43 | _deps = [ 44 | "accelerate>=1.2.1", 45 | "bitsandbytes>=0.43.0", 46 | "black>=24.4.2", 47 | "datasets>=3.2.0", 48 | "deepspeed==0.15.4", 49 | "distilabel[vllm,ray,openai]>=1.5.2", 50 | "einops>=0.8.0", 51 | "flake8>=6.0.0", 52 | "hf_transfer>=0.1.4", 53 | "huggingface-hub[cli]>=0.19.2,<1.0", 54 | "isort>=5.12.0", 55 | "liger_kernel==0.5.2", 56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", 57 | "math-verify", # Used for math verification in grpo 58 | "packaging>=23.0", 59 | "parameterized>=0.9.0", 60 | "pytest", 61 | "safetensors>=0.3.3", 62 | "sentencepiece>=0.1.99", 63 | "torch>=2.5.1", 64 | "transformers @ git+https://github.com/huggingface/transformers.git@main", 65 | "trl @ git+https://github.com/huggingface/trl.git@main", 66 | "vllm==0.6.6.post1", 67 | "wandb>=0.19.1", 68 | "pillow", 69 | ] 70 | 71 | # this is a lookup table with items like: 72 | # 73 | # tokenizers: "tokenizers==0.9.4" 74 | # packaging: "packaging" 75 | # 76 | # some of the values are versioned whereas others aren't. 
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} 78 | 79 | 80 | def deps_list(*pkgs): 81 | return [deps[pkg] for pkg in pkgs] 82 | 83 | 84 | extras = {} 85 | extras["tests"] = deps_list("pytest", "parameterized") 86 | extras["torch"] = deps_list("torch") 87 | extras["quality"] = deps_list("black", "isort", "flake8") 88 | extras["eval"] = deps_list("lighteval", "math-verify") 89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] 90 | 91 | # core dependencies shared across the whole project - keep this to a bare minimum :) 92 | install_requires = [ 93 | deps["accelerate"], 94 | deps["bitsandbytes"], 95 | deps["einops"], 96 | deps["datasets"], 97 | deps["deepspeed"], 98 | deps["hf_transfer"], 99 | deps["huggingface-hub"], 100 | deps["liger_kernel"], 101 | deps["packaging"], # utilities from PyPA to e.g., compare versions 102 | deps["safetensors"], 103 | deps["sentencepiece"], 104 | deps["transformers"], 105 | deps["trl"], 106 | ] 107 | 108 | setup( 109 | name="open-r1", 110 | version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 111 | author="The Hugging Face team (past and future)", 112 | author_email="lewis@huggingface.co", 113 | description="Open R1", 114 | long_description="This is the Visual-RFT project.", 115 | long_description_content_type="text/markdown", 116 | keywords="llm inference-time compute reasoning", 117 | license="Apache", 118 | url="https://github.com/huggingface/open-r1", 119 | package_dir={"": "src"}, 120 | packages=find_packages("src"), 121 | zip_safe=False, 122 | extras_require=extras, 123 | python_requires=">=3.10.9", 124 | install_requires=install_requires, 125 | classifiers=[ 126 | "Development Status :: 3 - Alpha", 127 | "Intended Audience :: Developers", 128 | "Intended Audience :: Education", 129 | "Intended Audience :: Science/Research", 130 | "License :: OSI Approved :: Apache Software License", 131 | "Operating System :: OS Independent", 132 | "Programming Language :: Python :: 3", 133 | "Programming Language :: Python :: 3.10", 134 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 135 | ], 136 | ) 137 | -------------------------------------------------------------------------------- /src/virft/slurm/evaluate.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=open-r1-evaluate 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --exclusive 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --partition=hopper-prod 8 | #SBATCH --time=01:59:00 9 | #SBATCH --output=./logs/evaluate/%x-%j.out 10 | #SBATCH --err=./logs/evaluate/%x-%j.err 11 | 12 | # Usage: sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B aime24 13 | 14 | set -x -e 15 | 16 | source ~/.bashrc 17 | conda activate openr1 18 | module load cuda/12.1 19 | echo "START TIME: $(date)" 20 | echo "PYTHON ENV: $(which python)" 21 | 22 | 23 | NUM_GPUS=8 24 | MODEL=$1 25 | TASK=$2 26 | MODEL_ARGS="pretrained=$MODEL,dtype=float16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8" 27 | OUTPUT_DIR=data/evals/$MODEL 28 | 29 | 30 | # force crashing on nccl issues like hanging broadcast 31 | export NCCL_ASYNC_ERROR_HANDLING=1 32 | # export NCCL_DEBUG=INFO 33 | # export NCCL_DEBUG_SUBSYS=COLL 34 | # export NCCL_SOCKET_NTHREADS=1 35 | # export NCCL_NSOCKS_PERTHREAD=1 36 | # export CUDA_LAUNCH_BLOCKING=1 37 | 38 | # Specific configuration 
optimized for the Hugging Face Compute Cluster 39 | # Be ye warned this may not work on other clusters! 40 | module load cuda/12.1 41 | 42 | lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \ 43 | --custom-tasks src/open_r1/evaluate.py \ 44 | --use-chat-template \ 45 | --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \ 46 | --output-dir $OUTPUT_DIR 47 | 48 | 49 | echo "END TIME: $(date)" 50 | -------------------------------------------------------------------------------- /src/virft/slurm/sft.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=open-r1-sft 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --exclusive 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --partition=hopper-prod 8 | #SBATCH --output=./logs/%x-%j.out 9 | #SBATCH --err=./logs/%x-%j.err 10 | 11 | set -x -e 12 | 13 | source ~/.bashrc 14 | conda activate openr1 15 | module load cuda/12.1 16 | echo "START TIME: $(date)" 17 | echo "PYTHON ENV: $(which python)" 18 | 19 | MODEL_PATH=$1 20 | DATASET_PATH=$2 21 | ACCELERATOR=$3 22 | 23 | # Training setup 24 | NUM_NODES=$SLURM_NNODES 25 | GPUS_PER_NODE=8 26 | WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE)) 27 | 28 | # so processes know who to talk to 29 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 30 | MASTER_PORT=6000 31 | 32 | export CMD=" \ 33 | src/open_r1/sft.py \ 34 | --model_name_or_path $MODEL_PATH \ 35 | --dataset_name $DATASET_PATH \ 36 | --use_liger_kernel true \ 37 | --learning_rate 2.0e-5 \ 38 | --num_train_epochs 1 \ 39 | --packing \ 40 | --max_seq_length 4096 \ 41 | --per_device_train_batch_size 4 \ 42 | --per_device_eval_batch_size 4 \ 43 | --gradient_accumulation_steps 4 \ 44 | --gradient_checkpointing \ 45 | --bf16 \ 46 | --logging_steps 5 \ 47 | --eval_strategy steps \ 48 | --eval_steps 100 \ 49 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill 50 | " 51 | 52 | export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ 53 | --config_file configs/$ACCELERATOR.yaml \ 54 | --gradient_accumulation_steps 4 \ 55 | --num_machines $NUM_NODES \ 56 | --num_processes $WORLD_SIZE \ 57 | --main_process_ip $MASTER_ADDR \ 58 | --main_process_port $MASTER_PORT \ 59 | --machine_rank \$SLURM_PROCID \ 60 | --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \ 61 | --max_restarts 1 \ 62 | --role \$(hostname -s): \ 63 | --tee 3 \ 64 | " 65 | 66 | # force crashing on nccl issues like hanging broadcast 67 | export NCCL_ASYNC_ERROR_HANDLING=1 68 | # export NCCL_DEBUG=INFO 69 | # export NCCL_DEBUG_SUBSYS=COLL 70 | # export NCCL_SOCKET_NTHREADS=1 71 | # export NCCL_NSOCKS_PERTHREAD=1 72 | # export CUDA_LAUNCH_BLOCKING=1 73 | 74 | # Specific configuration optimized for the Hugging Face Compute Cluster 75 | # Be ye warned this may not work on other clusters! 
76 | module load cuda/12.1 77 | 78 | # srun error handling: 79 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks 80 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code 81 | SRUN_ARGS=" \ 82 | --wait=60 \ 83 | --kill-on-bad-exit=1 \ 84 | " 85 | 86 | clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1 87 | 88 | echo "END TIME: $(date)" 89 | -------------------------------------------------------------------------------- /src/virft/src/open_r1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/src/virft/src/open_r1/__init__.py -------------------------------------------------------------------------------- /src/virft/src/open_r1/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom evaluation tasks for LightEval.""" 16 | 17 | from lighteval.metrics.dynamic_metrics import ( 18 | ExprExtractionConfig, 19 | LatexExtractionConfig, 20 | multilingual_extractive_match_metric, 21 | ) 22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 23 | from lighteval.tasks.requests import Doc 24 | from lighteval.utils.language import Language 25 | 26 | 27 | metric = multilingual_extractive_match_metric( 28 | language=Language.ENGLISH, 29 | fallback_mode="first_match", 30 | precision=5, 31 | gold_extraction_target=(LatexExtractionConfig(),), 32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 33 | aggregation_function=max, 34 | ) 35 | 36 | 37 | def prompt_fn(line, task_name: str = None): 38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically""" 39 | return Doc( 40 | task_name=task_name, 41 | query=line["problem"], 42 | choices=[line["solution"]], 43 | gold_index=0, 44 | ) 45 | 46 | 47 | # Define tasks 48 | aime24 = LightevalTaskConfig( 49 | name="aime24", 50 | suite=["custom"], 51 | prompt_function=prompt_fn, 52 | hf_repo="HuggingFaceH4/aime_2024", 53 | hf_subset="default", 54 | hf_avail_splits=["train"], 55 | evaluation_splits=["train"], 56 | few_shots_split=None, 57 | few_shots_select=None, 58 | generation_size=32768, 59 | metric=[metric], 60 | version=1, 61 | ) 62 | math_500 = LightevalTaskConfig( 63 | name="math_500", 64 | suite=["custom"], 65 | prompt_function=prompt_fn, 66 | hf_repo="HuggingFaceH4/MATH-500", 67 | hf_subset="default", 68 | hf_avail_splits=["test"], 69 | evaluation_splits=["test"], 70 | few_shots_split=None, 71 | few_shots_select=None, 72 | generation_size=32768, 73 | metric=[metric], 74 | version=1, 75 | ) 76 | 77 | # Add tasks to the table 78 | TASKS_TABLE = [] 79 | TASKS_TABLE.append(aime24) 80 | TASKS_TABLE.append(math_500) 81 | 82 | # 
MODULE LOGIC 83 | if __name__ == "__main__": 84 | print([t["name"] for t in TASKS_TABLE]) 85 | print(len(TASKS_TABLE)) 86 | -------------------------------------------------------------------------------- /src/virft/src/open_r1/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | from distilabel.llms import OpenAILLM 18 | from distilabel.pipeline import Pipeline 19 | from distilabel.steps.tasks import TextGeneration 20 | 21 | 22 | def build_distilabel_pipeline( 23 | model: str, 24 | base_url: str = "http://localhost:8000/v1", 25 | prompt_column: Optional[str] = None, 26 | temperature: Optional[float] = None, 27 | top_p: Optional[float] = None, 28 | max_new_tokens: int = 8192, 29 | num_generations: int = 1, 30 | ) -> Pipeline: 31 | generation_kwargs = {"max_new_tokens": max_new_tokens} 32 | 33 | if temperature is not None: 34 | generation_kwargs["temperature"] = temperature 35 | 36 | if top_p is not None: 37 | generation_kwargs["top_p"] = top_p 38 | 39 | with Pipeline().ray() as pipeline: 40 | TextGeneration( 41 | llm=OpenAILLM( 42 | base_url=base_url, 43 | api_key="something", 44 | model=model, 45 | # thinking can take some time... 
46 | timeout=10 * 60, 47 | generation_kwargs=generation_kwargs, 48 | ), 49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, 50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion 51 | num_generations=num_generations, 52 | ) 53 | 54 | return pipeline 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | 62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") 63 | parser.add_argument( 64 | "--hf-dataset", 65 | type=str, 66 | required=True, 67 | help="HuggingFace dataset to load", 68 | ) 69 | parser.add_argument( 70 | "--hf-dataset-config", 71 | type=str, 72 | required=False, 73 | help="Dataset config to use", 74 | ) 75 | parser.add_argument( 76 | "--hf-dataset-split", 77 | type=str, 78 | default="train", 79 | help="Dataset split to use", 80 | ) 81 | parser.add_argument("--prompt-column", type=str, default="prompt") 82 | parser.add_argument( 83 | "--model", 84 | type=str, 85 | required=True, 86 | help="Model name to use for generation", 87 | ) 88 | parser.add_argument( 89 | "--vllm-server-url", 90 | type=str, 91 | default="http://localhost:8000/v1", 92 | help="URL of the vLLM server", 93 | ) 94 | parser.add_argument( 95 | "--temperature", 96 | type=float, 97 | help="Temperature for generation", 98 | ) 99 | parser.add_argument( 100 | "--top-p", 101 | type=float, 102 | help="Top-p value for generation", 103 | ) 104 | parser.add_argument( 105 | "--max-new-tokens", 106 | type=int, 107 | default=8192, 108 | help="Maximum number of new tokens to generate", 109 | ) 110 | parser.add_argument( 111 | "--num-generations", 112 | type=int, 113 | default=1, 114 | help="Number of generations per problem", 115 | ) 116 | parser.add_argument( 117 | "--hf-output-dataset", 118 | type=str, 119 | required=False, 120 | help="HuggingFace repo to push results to", 121 | ) 122 | parser.add_argument( 123 | "--private", 124 | action="store_true", 125 | help="Whether to make the output dataset private when pushing to HF Hub", 126 | ) 127 | 128 | args = parser.parse_args() 129 | 130 | print("\nRunning with arguments:") 131 | for arg, value in vars(args).items(): 132 | print(f" {arg}: {value}") 133 | print() 134 | 135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") 136 | dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split) 137 | print("Dataset loaded!") 138 | 139 | pipeline = build_distilabel_pipeline( 140 | model=args.model, 141 | base_url=args.vllm_server_url, 142 | prompt_column=args.prompt_column, 143 | temperature=args.temperature, 144 | top_p=args.top_p, 145 | max_new_tokens=args.max_new_tokens, 146 | num_generations=args.num_generations, 147 | ) 148 | 149 | print("Running generation pipeline...") 150 | distiset = pipeline.run(dataset=dataset, use_cache=False) 151 | print("Generation pipeline finished!") 152 | 153 | if args.hf_output_dataset: 154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") 155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private) 156 | print("Dataset pushed!") 157 | -------------------------------------------------------------------------------- /src/virft/src/open_r1/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Supervised fine-tuning script for decoder language models. 17 | 18 | Usage: 19 | 20 | # One 1 node of 8 x H100s 21 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \ 22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ 23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \ 24 | --learning_rate 2.0e-5 \ 25 | --num_train_epochs 1 \ 26 | --packing \ 27 | --max_seq_length 4096 \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --gradient_checkpointing \ 31 | --bf16 \ 32 | --logging_steps 5 \ 33 | --eval_strategy steps \ 34 | --eval_steps 100 \ 35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill 36 | """ 37 | 38 | from datasets import load_dataset 39 | from transformers import AutoTokenizer 40 | 41 | from trl import ( 42 | ModelConfig, 43 | ScriptArguments, 44 | SFTConfig, 45 | SFTTrainer, 46 | TrlParser, 47 | get_kbit_device_map, 48 | get_peft_config, 49 | get_quantization_config, 50 | ) 51 | 52 | 53 | def main(script_args, training_args, model_args): 54 | ################ 55 | # Model init kwargs & Tokenizer 56 | ################ 57 | quantization_config = get_quantization_config(model_args) 58 | model_kwargs = dict( 59 | revision=model_args.model_revision, 60 | trust_remote_code=model_args.trust_remote_code, 61 | attn_implementation=model_args.attn_implementation, 62 | torch_dtype=model_args.torch_dtype, 63 | use_cache=False if training_args.gradient_checkpointing else True, 64 | device_map=get_kbit_device_map() if quantization_config is not None else None, 65 | quantization_config=quantization_config, 66 | ) 67 | training_args.model_init_kwargs = model_kwargs 68 | tokenizer = AutoTokenizer.from_pretrained( 69 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True 70 | ) 71 | tokenizer.pad_token = tokenizer.eos_token 72 | 73 | ################ 74 | # Dataset 75 | ################ 76 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 77 | 78 | ################ 79 | # Training 80 | ################ 81 | trainer = SFTTrainer( 82 | model=model_args.model_name_or_path, 83 | args=training_args, 84 | train_dataset=dataset[script_args.dataset_train_split], 85 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 86 | processing_class=tokenizer, 87 | peft_config=get_peft_config(model_args), 88 | ) 89 | 90 | trainer.train() 91 | 92 | # Save and push to hub 93 | trainer.save_model(training_args.output_dir) 94 | if training_args.push_to_hub: 95 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) 100 | script_args, training_args, model_args = parser.parse_args_and_config() 101 | main(script_args, training_args, model_args) 102 | 
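The docstring above already shows the full 8-GPU ZeRO-3 launch; as a smaller, hypothetical smoke test, the same script can be pointed at the bundled configs/ddp.yaml with a single process and a reduced effective batch. This is only a sketch: the output directory name is arbitrary, the batch-size and accumulation values are chosen for a quick run rather than taken from the paper recipe, and evaluation is disabled.

    # Single-GPU smoke test of sft.py (assumed settings, not a reference recipe)
    accelerate launch --config_file=configs/ddp.yaml --num_processes=1 src/open_r1/sft.py \
        --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
        --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
        --learning_rate 2.0e-5 \
        --num_train_epochs 1 \
        --max_seq_length 4096 \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --gradient_checkpointing \
        --bf16 \
        --logging_steps 5 \
        --eval_strategy no \
        --output_dir data/Qwen2.5-1.5B-SFT-smoke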
-------------------------------------------------------------------------------- /src/virft/src/open_r1/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_trainer import Qwen2VLGRPOTrainer 2 | from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer 3 | from .grpo_trainer_mp import Qwen2VLGRPOTrainer_MP 4 | from .grpo_trainer_aid import Qwen2VLGRPOTrainer_AID 5 | 6 | __all__ = ["Qwen2VLGRPOTrainer", "Qwen2VLGRPOVLLMTrainer", "Qwen2VLGRPOTrainer_MP", "Qwen2VLGRPOTrainer_AID"] 7 | --------------------------------------------------------------------------------
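Since trainer/__init__.py re-exports the GRPO trainer variants listed in its __all__, a quick import check is a cheap way to confirm that the editable install picked the package up. The one-liner below is a sketch and assumes the optional heavy dependencies (e.g. vllm) are installed, because the vLLM trainer module is imported eagerly.

    # Import smoke test for the exported trainer classes (run after `pip3 install -e ".[dev]"`)
    python3 -c "from open_r1.trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainer, Qwen2VLGRPOTrainer_MP, Qwen2VLGRPOTrainer_AID; print('open_r1.trainer exports import cleanly')"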