; " for i in range(len(visuals))]
157 | image_tokens = "".join(image_tokens)
158 | contexts = image_tokens + contexts
159 | elif self.modality == "video":
160 | image = visuals
161 |
162 | if "max_new_tokens" not in gen_kwargs:
163 | gen_kwargs["max_new_tokens"] = 1024
164 | if "temperature" not in gen_kwargs:
165 | gen_kwargs["temperature"] = 0
166 | if "top_p" not in gen_kwargs:
167 | gen_kwargs["top_p"] = None
168 | if "num_beams" not in gen_kwargs:
169 | gen_kwargs["num_beams"] = 1
170 |
171 | try:
172 | with torch.autocast(device_type="cuda", dtype=torch.float16):
173 | response, his = self.model.chat(self.tokenizer, contexts, image, do_sample=False, num_beams=1, use_meta=True, max_new_tokens=gen_kwargs["max_new_tokens"])
174 | except Exception as e:
175 | eval_logger.error(f"Error: {e}")
176 | response = ""
177 |
178 | res.append(response)
179 | pbar.update(1)
180 | pbar.close()
181 | return res
182 |
183 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
184 | assert False, "Not implemented yet."
185 |
186 | def generate_until_multi_round(self, requests) -> List[str]:
187 | raise NotImplementedError("TODO: Implement multi-round generation")
188 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Euclid’s Gift: Enhancing Spatial Perception and Reasoning in Vision‑Language Models via Geometric Surrogate Tasks
2 | [Issues](https://github.com/LiamLian0727/Euclids_Gift/issues)
3 | [Fork](https://github.com/LiamLian0727/Euclids_Gift/fork)
4 | [Stars](https://github.com/LiamLian0727/Euclids_Gift/stargazers)
5 | [Model Collection (Hugging Face)](https://huggingface.co/collections/LiamLian0727/euclid-model)
6 | [Euclid30K Dataset (Hugging Face)](https://huggingface.co/datasets/LiamLian0727/Euclid30K)
7 | [Paper (arXiv)](https://arxiv.org/abs/2509.24473)
8 | [License](LICENSE)
9 |
10 | ## 📢 News
11 |
12 | - [10/24/2025] :zap: We trained Qwen3VL (4B, 8B, and 30B) using Euclid30K, and the results show that the models also achieve significant gains across various spatial intelligence tasks. The weights of the fine-tuned models are available [here](https://huggingface.co/collections/LiamLian0727/euclid-model).
13 |
14 |
15 |
16 | | Model | SuperClevr | Omni3D Bench | VSIBench* | MindCube |
17 | | :------------------ | :----------------: | :----------------: | :---------------: | :---------------: |
18 | | Qwen3VL-4B | 55.36 | 27.74 | 35.51 | 26.11 |
19 | | Qwen3VL-Euclid-4B | 61.24 **(+5.88)** | 31.74 **(+4.00)** | 42.26 **(+6.75)** | 32.98 **(+6.87)** |
20 | | Qwen3VL-8B | 48.30 | 34.01 | 33.25 | 34.16 |
21 | | Qwen3VL-Euclid-8B | 48.96 **(+0.66)** | 35.03 **(+1.02)** | 35.54 **(+2.29)** | 41.02 **(+6.86)** |
22 | | Qwen3VL-30B | 64.12 | 36.71 | 40.00 | 39.75 |
23 | | Qwen3VL-Euclid-30B | 70.18 **(+6.06)** | 38.90 **(+2.19)** | 45.80 **(+5.80)** | 40.68 **(+0.93)** |
24 |
25 |
26 |
27 | > Qwen3VL and Qwen3VL-Euclid are evaluated using the same prompting template defined in [test/eval_qwen.sh](test/eval_qwen.sh) to ensure a fair comparison.
28 |
29 |
30 | - [10/17/2025] Thanks to Synced (机器之心) for covering our work: [wechat article](https://mp.weixin.qq.com/s/OfCiijFuj1nITUyAF7Svfw) / [zhihu](https://zhuanlan.zhihu.com/p/1962478345846501995).
31 | - [09/30/2025] We released our paper on [arXiv](https://arxiv.org/abs/2509.24473) and the Euclid30K dataset on [Hugging Face](https://huggingface.co/datasets/LiamLian0727/Euclid30K).
32 |
33 | ## Abstract
34 | Spatial intelligence spans abilities such as visualizing and transforming shapes, mental rotation, reasoning about relative positions and containment, and counting/estimation. These remain challenging for modern Multimodal Large Language Models (MLLMs). We propose solving Euclidean geometry problems as a surrogate task and construct Euclid30K, a dataset of roughly 30K 2D and 3D geometry questions. We then fine‑tune Qwen2.5‑VL and RoboBrain2.0 models with Group Relative Policy Optimization (GRPO), enabling the models to internalize and apply Euclidean principles for shape recognition, counting, relation extraction, and multi‑step deductive reasoning. Without task‑specific adaptations, our models achieve significant zero‑shot gains on four spatial‑reasoning benchmarks: Super‑CLEVR, Omni3DBench, VSI‑Bench, and MindCube. For example, on VSI‑Bench, average accuracy improves from 34.5% to 40.5% (+5.5 percentage points); RoboBrain2.0‑Euclid‑7B reaches 49.6%, surpassing the previous SOTA (Spatial‑MLLM).
35 |
36 | 
37 |
38 | 
39 |
40 | ## Quick Start
41 |
42 | ### 1) Environment Setup
43 | **Training**
44 | - Install [EasyR1](https://github.com/hiyouga/EasyR1) following the official documentation.
45 | - Install the required Python dependencies: `pip install -r requirements.txt`.
46 | - Download the Euclid30K dataset from Hugging Face: https://huggingface.co/datasets/LiamLian0727/Euclid30K
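
A quick way to sanity-check the downloaded parquet splits (a minimal sketch; the file paths match the training command below, and the printed column names are simply whatever the dataset defines):

```python
# Minimal sanity check of the Euclid30K parquet splits (paths as used in the
# training command below; adjust them to wherever you stored the dataset).
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files={
        "train": "/mnt/datasets/Euclid30K/Euclid30K_train.parquet",
        "val": "/mnt/datasets/Euclid30K/Euclid30K_val.parquet",
    },
)
print(ds)                     # split names, sizes, and column names
print(ds["train"][0].keys())  # fields of a single example
```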
47 |
48 | **Evaluation**
49 | - Install [lmms‑eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) following its official documentation. You can either:
50 | - Use the [`lmms-eval/`](https://github.com/EvolvingLMMs-Lab/lmms-eval) copy included in this repository; or
51 | - Copy the four task folders provided under `test/lmms_eval/tasks/` into your existing lmms‑eval setup (a command sketch follows this list).
52 | - Download the benchmark datasets [Super‑CLEVR](https://huggingface.co/datasets/MMInstruction/SuperClevr_Val), [Omni3DBench](https://huggingface.co/datasets/dmarsili/Omni3D-Bench), [VSI‑Bench](https://huggingface.co/datasets/nyu-visionx/VSI-Bench), and [MindCube_lmms_eval](https://huggingface.co/datasets/LiamLian0727/MindCube_lmms_eval); then update the dataset paths in each corresponding YAML under `test/lmms_eval/tasks/`.
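
For the copy-based option above, something like the following is enough (a sketch only; `/path/to/lmms-eval` is a placeholder for your own checkout, and the exact YAML key holding the dataset path may differ across the provided task configs):

```bash
# Copy the provided task configs into an existing lmms-eval checkout
# (/path/to/lmms-eval is a placeholder), then point each task YAML at your
# local copies of the benchmark datasets.
cp -r test/lmms_eval/tasks/* /path/to/lmms-eval/lmms_eval/tasks/
grep -rn "dataset_path" /path/to/lmms-eval/lmms_eval/tasks/   # entries to edit
```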
53 |
54 | ### 2) Training
55 |
56 | Below is an example training command (8 GPUs per node; `trainer.nnodes=2` requests two nodes). To launch multi‑node, multi‑GPU training, see the example script [train/dist_train.sh](train/dist_train.sh).
57 |
58 | ```bash
59 | python3 -m verl.trainer.main \
60 | config=examples/config.yaml \
61 | data.train_files=/mnt/datasets/Euclid30K/Euclid30K_train.parquet \
62 | data.val_files=/mnt/datasets/Euclid30K/Euclid30K_val.parquet \
63 | worker.actor.model.model_path=/mnt/models/Qwen2.5-VL-7B-Instruct \
64 | trainer.experiment_name=EXPERIMENT_NAME \
65 | worker.actor.micro_batch_size_per_device_for_update=1 \
66 | worker.actor.micro_batch_size_per_device_for_experience=8 \
67 | worker.actor.clip_ratio_low=0.2 \
68 | worker.actor.clip_ratio_high=0.28 \
69 | worker.reward.reward_function=/mnt/code/Euclids_Gift/train/euclid.py:compute_score \
70 | trainer.total_epochs=10 \
71 | trainer.n_gpus_per_node=8 \
72 | trainer.nnodes=2 \
73 | trainer.save_checkpoint_path=/mnt/models/Qwen2.5-VL-7B-Euclid
74 | ```
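
The `worker.reward.reward_function` entry points EasyR1 at the scoring function shipped in [train/euclid.py](train/euclid.py). Purely as a hypothetical illustration (the real signature and scoring logic are defined by EasyR1's reward-function interface and `euclid.py`, not by this snippet), a rule-based reward for boxed geometry answers might look like:

```python
# Hypothetical sketch of a rule-based reward; the actual interface and logic
# follow EasyR1's reward-function spec and train/euclid.py.
import re

def compute_score(predict: str, ground_truth: str) -> float:
    """Return 1.0 if the final boxed answer matches the reference, else 0.0."""
    match = re.search(r"\\boxed\{(.+?)\}", predict)
    if match is None:
        return 0.0
    answer = match.group(1).strip()
    try:
        # numeric answers: compare with a small tolerance
        return float(abs(float(answer) - float(ground_truth)) < 1e-3)
    except ValueError:
        # otherwise fall back to an exact string match
        return float(answer == ground_truth.strip())
```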
75 |
76 | ### 3) Evaluation
77 |
78 | 
79 |
80 | Use [`test/eval_qwen.sh`](test/eval_qwen.sh), [`test/eval_robo.sh`](test/eval_robo.sh), and [`test/eval_euclid.sh`](test/eval_euclid.sh) to evaluate the Qwen2.5‑VL series, the RoboBrain 2.0 series, and Euclid models trained on Euclid30K, respectively.
81 |
82 | Before running these scripts, set `model_path` in each script to the path of the model you want to evaluate.
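
These scripts wrap lmms-eval; for orientation only (an illustrative sketch — the model identifier, task names, prompts, and generation settings actually used are those in the scripts above), an evaluation run looks roughly like:

```bash
# Illustrative lmms-eval invocation; see test/eval_qwen.sh for the exact
# model/task names and prompt configuration used in our experiments.
accelerate launch --num_processes=8 -m lmms_eval \
    --model qwen2_5_vl \
    --model_args pretrained=/mnt/models/Qwen2.5-VL-7B-Euclid \
    --tasks vsibench \
    --batch_size 1 \
    --log_samples \
    --output_path ./logs/
```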
83 |
84 | > Notably, as observed in the VSI-Bench paper, **spatial reasoning ability is the primary bottleneck limiting MLLM performance on VSI-Bench**. To better expose how models perceive scenes and carry out spatial reasoning, and to verify whether they genuinely acquire spatial intelligence from geometric knowledge, we deviate from the original VSI-Bench setup, which uses prompts such as "*Answer with the option's letter from the given choices directly*" or "*Please answer the question using a single word or phrase*" and caps responses at 16 tokens. Instead, we follow the prompt configuration described in RoboBrain2.0 Sec. B, which encourages the model to reason about the problem before answering, and we raise the maximum response length to 1024 tokens. This setup lets us observe the model's intermediate reasoning and assess whether it has internalized transferable spatial priors from Euclid30K training.
85 |
86 |
87 | ## Citation
88 | If you find this project or the dataset helpful, please cite:
89 | ```bibtex
90 | @misc{Euclids_Gift,
91 | title={Euclid’s Gift: Enhancing Spatial Perception and Reasoning in Vision-Language Models via Geometric Surrogate Tasks},
92 | author={Shijie Lian and Changti Wu and Laurence Tianruo Yang and Hang Yuan and Bin Yu and Lei Zhang and Kai Chen},
93 | year={2025},
94 | eprint={2509.24473},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.CV},
97 | url={https://arxiv.org/abs/2509.24473}
98 | }
99 | ```
100 |
101 | ## Acknowledgements
102 |
103 | We thank the [VeRL](https://github.com/volcengine/verl) / [EasyR1](https://github.com/hiyouga/EasyR1) training framework, as well as the benchmark suites [Super‑CLEVR](https://huggingface.co/datasets/MMInstruction/SuperClevr_Val), [Omni3DBench](https://huggingface.co/datasets/dmarsili/Omni3D-Bench), [VSI‑Bench](https://huggingface.co/datasets/nyu-visionx/VSI-Bench), and [MindCube](https://huggingface.co/datasets/MLL-Lab/MindCube).
104 |
105 | ## ⭐ Stargazers
106 | [Stargazers over time](https://github.com/LiamLian0727/Euclids_Gift/stargazers)
107 |
--------------------------------------------------------------------------------
/test/lmms_eval/api/model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import hashlib
3 | import json
4 | import os
5 | from typing import List, Optional, Tuple, Type, TypeVar, Union
6 |
7 | from loguru import logger as eval_logger
8 | from sqlitedict import SqliteDict
9 | from tqdm import tqdm
10 |
11 | from lmms_eval import utils
12 | from lmms_eval.api.instance import Instance
13 |
14 | T = TypeVar("T", bound="lmms")
15 |
16 |
17 | class lmms(abc.ABC):
18 | def __init__(self) -> None:
19 | """Defines the interface that should be implemented by all lmms subclasses.
20 | lmms subclasses are assumed to take image-text as input and yield strings as output
21 | (inputs/outputs should be tokenization-agnostic.)
22 | """
23 | # set rank and world size to a single process, by default.
24 | self._rank = 0
25 | self._world_size = 1
26 | self.cache_hook = CacheHook(None)
27 | self.task_dict = {}
28 |
29 | @abc.abstractmethod
30 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
31 | """Compute log-likelihood of generating a continuation from a context.
32 | Downstream tasks should attempt to use loglikelihood instead of other
33 | LMM calls whenever possible.
34 |
35 | :param requests: list[Instance]
36 | A list of Instance objects, with property `args` which returns a tuple (context, continuation).
37 | `context: str`
38 | Context string. Implementations of LMM must be able to handle an
39 | empty context string.
40 | `continuation: str`
41 | The continuation over which log likelihood will be calculated. If
42 | there is a word boundary, the space should be in the continuation.
43 | For example, context="hello" continuation=" world" is correct.
44 | 'visual_list: list[dict]'
45 | Visual input to the model. Can be None.
46 |
47 | :return: list[tuple[float, bool]]
48 | A list of pairs (logprob, isgreedy)
49 | `logprob: float`
50 | The log probability of `continuation`.
51 | `isgreedy`:
52 | Whether `continuation` would be generated by greedy sampling from `context`.
53 | """
54 | pass
55 |
56 | # TODO: Add an optional max length
57 | @abc.abstractmethod
58 | def generate_until(self, requests) -> List[str]:
59 | """Generate greedily until a stopping sequence
60 |
61 | :param requests: list[Instance]
62 | A list of Instance objects with property `args` which returns a tuple (context, until).
63 | context: str
64 | Context string
65 | generation_kwargs: dict
66 | Generation Kwargs
67 | 'visual_list: list[dict]'
68 | Visual input to the model. Can be None.
69 | :return: list[str]
70 | A list of strings continuation
71 | continuation: str
72 | The generated continuation.
73 | """
74 | pass
75 |
76 | @abc.abstractmethod
77 | def generate_until_multi_round(self, requests) -> List[str]:
78 | """Generate greedily until a stopping sequence
79 |
80 | :param requests: list[Instance]
81 | A list of Instance objects with property `args` which returns a tuple (context, until).
82 | context: str
83 | Context string
84 | generation_kwargs: dict
85 | Generation Kwargs
86 | 'visual_list: list[dict]'
87 | Visual input to the model. Can be None.
88 | :return: list[str]
89 | A list of strings continuation
90 | continuation: str
91 | The generated continuation.
92 | """
93 | pass
94 |
95 | @classmethod
96 | def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
97 | """
98 | Creates an instance of the LMM class using the given argument string and additional config.
99 |
100 | Parameters:
101 | - arg_string: A string containing arguments in the format key1=value1,key2=value2.
102 | - additional_config: Optional dictionary containing additional configuration parameters.
103 |
104 | Returns:
105 | - Instance of the LMM class.
106 | """
107 | additional_config = {} if additional_config is None else additional_config
108 | args = utils.simple_parse_args_string(arg_string)
109 | args2 = {k: v for k, v in additional_config.items() if v is not None}
110 | return cls(**args, **args2)
111 |
112 | @property
113 | def rank(self):
114 | # used in the case of parallelism. Hardcoded to
115 | # ensure no errors arise using API models which do
116 | # not support multi-device parallelism nor expect it.
117 | return self._rank
118 |
119 | @property
120 | def world_size(self):
121 | # used in the case of parallelism. Hardcoded to
122 | # ensure no errors arise using API models which do
123 | # not support multi-device parallelism nor expect it.
124 | return self._world_size
125 |
126 | def set_cache_hook(self, cache_hook) -> None:
127 | self.cache_hook = cache_hook
128 |
129 |
130 | ### SQLite-based caching of LMM responses
131 | def hash_args(attr, args):
132 | dat = json.dumps([attr] + list(args))
133 | return hashlib.sha256(dat.encode("utf-8")).hexdigest()
134 |
135 |
136 | class CacheHook:
137 | def __init__(self, cachinglm) -> None:
138 | if cachinglm is None:
139 | self.dbdict = None
140 | return
141 |
142 | self.dbdict = cachinglm.dbdict
143 |
144 | def add_partial(self, attr, req, res) -> None:
145 | if self.dbdict is None:
146 | return
147 | hsh = hash_args(attr, req)
148 | self.dbdict[hsh] = res
149 |
150 |
151 | class CachingLMM:
152 | def __init__(self, lm, cache_db) -> None:
153 | """LMM wrapper that returns cached results if they exist, and uses the underlying LMM if not.
154 |
155 | :param lm: LMM
156 | Underlying LMM
157 | :param cache_db: str
158 | Path to cache db
159 | """
160 | self.lm = lm
161 | self.cache_db = cache_db
162 | if os.path.dirname(cache_db):
163 | os.makedirs(os.path.dirname(cache_db), exist_ok=True)
164 | self.dbdict = SqliteDict(cache_db, autocommit=True)
165 |
166 | # add hook to lm
167 | lm.set_cache_hook(self.get_cache_hook())
168 |
169 | def __getattr__(self, attr):
170 | lm_attr = getattr(self.lm, attr)
171 | if not callable(lm_attr):
172 | return lm_attr
173 |
174 | def fn(requests):
175 | res = []
176 | remaining_reqs = []
177 | warned = False
178 | # figure out which ones are cached and which ones are new
179 | eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...")
180 | for req in tqdm(requests):
181 | hsh = hash_args(attr, req.args)
182 | if attr in ["generate_until", "generate_until_multi_round"] and req.args[1].get("do_sample", False):
183 | # when we are doing non-greedy generation, don't use the cache
184 | # (else every "randomly sampled" generation would be identical for repeats > 1).
185 | if not warned:
186 | eval_logger.warning(f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests.")
187 | warned = True
188 | res.append(None)
189 | remaining_reqs.append(req)
190 | elif hsh in self.dbdict:
191 | ob = self.dbdict[hsh]
192 |
193 | assert ob is not None
194 |
195 | res.append(ob)
196 | else:
197 | res.append(None)
198 | remaining_reqs.append(req)
199 |
200 | # actually run the LMM on the requests that do not have cached results
201 | rem_res = getattr(self.lm, attr)(remaining_reqs)
202 |
203 | # stick the new ones back into the list and also cache any of the new ones
204 | resptr = 0
205 | for req, r in zip(remaining_reqs, rem_res):
206 | while res[resptr] is not None:
207 | resptr += 1
208 |
209 | res[resptr] = r
210 |
211 | # caching
212 | hsh = hash_args(attr, req.args)
213 | self.dbdict[hsh] = r
214 | self.dbdict.commit()
215 |
216 | return res
217 |
218 | return fn
219 |
220 | def get_cache_hook(self):
221 | return CacheHook(self)
222 |
--------------------------------------------------------------------------------
/test/lmms_eval/models/reka.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import os
4 | import time
5 | from copy import deepcopy
6 | from io import BytesIO
7 | from typing import List, Tuple
8 |
9 | import numpy as np
10 | import requests as url_requests
11 | from accelerate import Accelerator, DistributedType
12 | from PIL import Image
13 | from tqdm import tqdm
14 |
15 | from lmms_eval.api.instance import Instance
16 | from lmms_eval.api.model import lmms
17 | from lmms_eval.api.registry import register_model
18 |
19 | NUM_SECONDS_TO_SLEEP = 30
20 |
21 | from loguru import logger
22 |
23 | eval_logger = logger
24 |
25 | try:
26 | from decord import VideoReader, cpu
27 | from reka import ChatMessage
28 | from reka.client import Reka as RekaClient
29 | except Exception as e:
30 | eval_logger.warning(f"Error importing reka: {e}")
31 |
32 |
33 | @register_model("reka")
34 | class Reka(lmms):
35 | def __init__(
36 | self,
37 | model_version: str = "reka-edge",
38 | modality: str = "image",
39 | max_frames_num: int = 5,
40 | timeout: int = 120,
41 | continual_mode: bool = False,
42 | response_persistent_folder: str = None, # We will cache the Reka API response in this path and use it for future requests
43 | **kwargs,
44 | ) -> None:
45 | super().__init__()
46 | self.model_version = model_version
47 | self.modality = modality
48 | self.max_frames_num = max_frames_num
49 | self.timeout = timeout
50 | self.continual_mode = continual_mode
51 | if self.continual_mode:
52 | if response_persistent_folder is None:
53 | raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
54 |
55 | os.makedirs(response_persistent_folder, exist_ok=True)
56 | self.response_persistent_folder = response_persistent_folder
57 | self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
58 |
59 | if os.path.exists(self.response_persistent_file):
60 | with open(self.response_persistent_file, "r") as f:
61 | self.response_cache = json.load(f)
62 | self.cache_mode = "resume"
63 | else:
64 | self.response_cache = {}
65 | self.cache_mode = "start"
66 |
67 | self.reka = RekaClient(api_key=os.getenv("REKA_API_KEY", "YOUR_API_KEY"))
68 |
69 | accelerator = Accelerator()
70 | if accelerator.num_processes > 1:
71 | assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
72 | self.accelerator = accelerator
73 | if self.accelerator.is_local_main_process:
74 | eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
75 | self._rank = self.accelerator.local_process_index
76 | self._world_size = self.accelerator.num_processes
77 | else:
78 | self.accelerator = accelerator
79 | self._rank = self.accelerator.local_process_index
80 | self._world_size = self.accelerator.num_processes
81 |
82 | self.device = self.accelerator.device
83 |
84 | def encode_image(self, image):
85 | if type(image) == list:
86 | media_urls = []
87 | for img in image:
88 | output_buffer = BytesIO()
89 | img.save(output_buffer, format="PNG")
90 | byte_data = output_buffer.getvalue()
91 | base64_str = base64.b64encode(byte_data).decode("utf-8")
92 | media_urls.append(f"data:image/png;base64,{base64_str}")
93 | return media_urls
94 | else:
95 | output_buffer = BytesIO()
96 | image.save(output_buffer, format="PNG")
97 | byte_data = output_buffer.getvalue()
98 | base64_str = base64.b64encode(byte_data).decode("utf-8")
99 |
100 | return f"data:image/png;base64,{base64_str}"
101 |
102 | def encode_video(self, video_path):
103 | vr = VideoReader(video_path, ctx=cpu(0))
104 | total_frame_num = len(vr)
105 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_num, dtype=int)
106 | frame_idx = uniform_sampled_frames.tolist()
107 | frames = vr.get_batch(frame_idx).asnumpy()
108 |
109 | base64_frames = []
110 | for frame in frames:
111 | img = Image.fromarray(frame)
112 | output_buffer = BytesIO()
113 | img.save(output_buffer, format="PNG")
114 | byte_data = output_buffer.getvalue()
115 | base64_str = base64.b64encode(byte_data).decode("utf-8")
116 | base64_frames.append(f"data:image/png;base64,{base64_str}")
117 |
118 | return base64_frames
119 |
120 | def generate_until(self, requests) -> List[str]:
121 | res = []
122 | pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
123 |
124 | for context, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
125 | if self.continual_mode is True and self.cache_mode == "resume":
126 | doc_uuid = f"{task}___{split}___{doc_id}"
127 | if doc_uuid in self.response_cache:
128 | response_text = self.response_cache[doc_uuid]
129 | if response_text:
130 | res.append(response_text)
131 | pbar.update(1)
132 | continue
133 |
134 | visual = doc_to_visual(self.task_dict[task][split][doc_id])
135 |
136 | message_content = []
137 |
138 | if self.modality == "image":
139 | media_urls = self.encode_image(visual)
140 | message_content.append({"type": "text", "text": context})
141 | for media_url in media_urls:
142 | message_content.append({"type": "image_url", "image_url": media_url})
143 | elif self.modality == "video":
144 | message_content.append({"type": "text", "text": context})
145 | assert len(visual) == 1, "Reka only supports one video per request"
146 | media_urls = self.encode_video(visual[0])
147 | assert len(media_urls) == self.max_frames_num, f"Reka only supports {self.max_frames_num} frames per request"
148 | for media_url in media_urls:
149 | message_content.append({"type": "image_url", "image_url": media_url})
150 |
151 | if "max_new_tokens" not in gen_kwargs:
152 | gen_kwargs["max_new_tokens"] = 1024
153 | if "temperature" not in gen_kwargs:
154 | gen_kwargs["temperature"] = 0
155 | if "top_p" not in gen_kwargs:
156 | gen_kwargs["top_p"] = None
157 | if "num_beams" not in gen_kwargs:
158 | gen_kwargs["num_beams"] = 1
159 |
160 | for attempt in range(5):
161 | try:
162 | response = self.reka.chat.create(
163 | messages=[
164 | ChatMessage(
165 | role="user",
166 | content=message_content,
167 | )
168 | ],
169 | model=self.model_version,
170 | )
171 | response_text = response.responses[0].message.content.strip()
172 | break # If successful, break out of the loop
173 |
174 | except Exception as e:
175 | eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
176 | if attempt < 5 - 1: # If we have retries left, sleep and then continue to next attempt
177 | time.sleep(NUM_SECONDS_TO_SLEEP)
178 | else: # If this was the last attempt, log and return empty
179 | eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
180 | response_text = ""
181 |
182 | res.append(response_text)
183 | pbar.update(1)
184 | if self.continual_mode is True: # Cache the response
185 | doc_uuid = f"{task}___{split}___{doc_id}"
186 | self.response_cache[doc_uuid] = response_text
187 | with open(self.response_persistent_file, "w") as f:
188 | json.dump(self.response_cache, f)
189 |
190 | pbar.close()
191 | return res
192 |
193 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
194 | # TODO
195 | assert False, "Reka does not support loglikelihood"
196 |
197 | def generate_until_multi_round(self, requests) -> List[str]:
198 | raise NotImplementedError("TODO: Implement multi-round generation")
199 |
--------------------------------------------------------------------------------
/test/lmms_eval/models/batch_gpt4.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import base64
3 | import json
4 | import os
5 | import time
6 | from copy import deepcopy
7 | from io import BytesIO
8 | from typing import List
9 | import numpy as np
10 | import requests as url_requests
11 |
12 | # Related third-party imports
13 | from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
14 | from accelerate.state import AcceleratorState
15 | from loguru import logger as eval_logger
16 | from openai import OpenAI
17 | from PIL import Image
18 | from tqdm import tqdm
19 |
20 | # Local application/library specific imports
21 | from lmms_eval.api.instance import Instance
22 | from lmms_eval.api.model import lmms
23 | from lmms_eval.api.registry import register_model
24 |
25 | # Conditional imports
26 | try:
27 | from decord import VideoReader, cpu
28 | except ImportError:
29 | eval_logger.warning("Decord is not installed. Video input will not be supported.")
30 |
31 | # Constants and global configurations
32 | API_TYPE = os.getenv("API_TYPE", "openai")
33 | NUM_SECONDS_TO_SLEEP = 5
34 |
35 | if API_TYPE == "openai":
36 | API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
37 | API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
38 | headers = {
39 | "Authorization": f"Bearer {API_KEY}",
40 | "Content-Type": "application/json",
41 | }
42 | elif API_TYPE == "azure":
43 | API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
44 | API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
45 | headers = {
46 | "api-key": API_KEY,
47 | "Content-Type": "application/json",
48 | }
49 | else:
50 | API_URL = "YOUR_API_URL"
51 | API_KEY = "YOUR_API_KEY"
52 |
53 |
54 | @register_model("batch_gpt4")
55 | class BatchGPT4(lmms):
56 | def __init__(
57 | self,
58 | model_version: str = "gpt-4o",
59 | api_key: str = API_KEY,
60 | api_url: str = API_URL,
61 | modality: str = "image",
62 | max_frames_num: int = 10,
63 | timeout: int = 120,
64 | **kwargs,
65 | ) -> None:
66 | super().__init__()
67 | # Manually set an image token for GPT-4V so that we can search for it
68 | # and split the text from the images.
69 | # Here we just use the same token as LLaVA for convenience.
70 | self.model_version = model_version
71 | self.modality = modality
72 | self.max_frames_num = max_frames_num
73 | self.image_token = "<image>"
74 | self.timeout = timeout
75 |
76 | self.api_key = api_key
77 | self.api_url = api_url
78 | self.client = OpenAI(api_key=api_key)
79 |
80 | accelerator = Accelerator()
81 | assert accelerator.state.local_process_index == 0, "BatchGPT4 does not support distributed inference."
82 | assert accelerator.state.num_processes == 1, "BatchGPT4 does not support distributed inference."
83 |
84 | # Function to encode the image
85 | def encode_image(self, image: Image):
86 | output_buffer = BytesIO()
87 | image.save(output_buffer, format="PNG")
88 | byte_data = output_buffer.getvalue()
89 | base64_str = base64.b64encode(byte_data).decode("utf-8")
90 | return base64_str
91 |
92 | # Function to encode the video
93 | def encode_video(self, video_path, for_get_frames_num):
94 | vr = VideoReader(video_path, ctx=cpu(0))
95 | total_frame_num = len(vr)
96 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
97 | frame_idx = uniform_sampled_frames.tolist()
98 | frames = vr.get_batch(frame_idx).asnumpy()
99 |
100 | base64_frames = []
101 | for frame in frames:
102 | img = Image.fromarray(frame)
103 | output_buffer = BytesIO()
104 | img.save(output_buffer, format="PNG")
105 | byte_data = output_buffer.getvalue()
106 | base64_str = base64.b64encode(byte_data).decode("utf-8")
107 | base64_frames.append(base64_str)
108 |
109 | return base64_frames
110 |
111 | def flatten(self, input):
112 | new_list = []
113 | for i in input:
114 | for j in i:
115 | new_list.append(j)
116 | return new_list
117 |
118 | def generate_until(self, requests):
119 | # Prepare the batch requests data
120 | requests_data = {}
121 | pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Batch Preparing")
122 | for idx, (contexts, gen_kwargs, doc_to_visual, doc_id, task, split) in enumerate([reg.args for reg in requests]):
123 | visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
124 | visuals = self.flatten(visuals)
125 | imgs = []
126 | for visual in visuals:
127 | if self.modality == "image":
128 | img = self.encode_image(visual)
129 | imgs.append(img)
130 | elif self.modality == "video":
131 | frames = self.encode_video(visual, self.max_frames_num)
132 | imgs.extend(frames)
133 |
134 |             messages = []
135 |             if self.image_token not in contexts:
136 |                 content = [{"type": "text", "text": contexts}]
137 |                 content += [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}} for img in imgs]
138 |             else:
139 |                 # Interleave text segments with images at each image-token placeholder.
140 |                 contexts_split = contexts.split(self.image_token)
141 |                 content = []
142 |                 for part_idx, context in enumerate(contexts_split):
143 |                     content.append({"type": "text", "text": context})
144 |                     if part_idx < len(imgs):
145 |                         content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{imgs[part_idx]}"}})
146 |             messages.append({"role": "user", "content": content})
147 |
148 | requests_data[f"request-{idx}"] = {"model": self.model_version, "messages": messages, "max_tokens": gen_kwargs.get("max_new_tokens", 1024)}
149 | pbar.update(1)
150 |
151 | file_path = os.path.join(os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface")), f"batchinput_{len(requests_data)}.jsonl")
152 | file_path = self.create_batch_input_file(requests_data, file_path)
153 | file_id = self.upload_input_file(file_path)
154 |
155 | batch_response = self.create_batch(file_id, metadata={"description": "Batch Processing for GPT-4"})
156 | batch_status = self.check_batch_status(batch_response.id)
157 | while True:
158 | batch_status = self.check_batch_status(batch_response.id)
159 | if batch_status.status == "completed":
160 | eval_logger.info("Batch processing completed.")
161 | batch_lines = self.retrieve_batch_results(batch_status.output_file_id).text.splitlines()
162 | results_by_id = {rec["custom_id"]: rec["response"]["body"]["choices"][0]["message"]["content"] for rec in map(json.loads, batch_lines)}
163 | return [results_by_id.get(f"request-{i}", "") for i in range(len(requests))]
164 | elif batch_status.status == "failed":
165 | eval_logger.info("Batch processing failed.")
166 | res = ["Batch failed"] * len(requests)
167 | return res
168 | else:
169 | eval_logger.info(f"Batch status: {batch_status.status}. Retrying in {NUM_SECONDS_TO_SLEEP} seconds.")
170 | time.sleep(NUM_SECONDS_TO_SLEEP)
171 |
172 | def loglikelihood(self, requests):
173 | # TODO
174 | assert False, "GPT4V does not support loglikelihood"
175 |
176 | def create_batch_input_file(self, requests_data, file_path="batchinput.jsonl"):
177 | with open(file_path, "w") as file:
178 | for request_id, data in requests_data.items():
179 | json_record = json.dumps({"custom_id": request_id, "method": "POST", "url": "/v1/chat/completions", "body": data})
180 | file.write(json_record + "\n")
181 | return file_path
182 |
183 | def upload_input_file(self, file_path):
184 | with open(file_path, "rb") as file:
185 | response = self.client.files.create(file=file, purpose="batch")
186 | return response.id
187 |
188 | def create_batch(self, file_id, metadata=None):
189 | if metadata is None:
190 | metadata = {}
191 | response = self.client.batches.create(input_file_id=file_id, endpoint="/v1/chat/completions", completion_window="24h", metadata=metadata)
192 | return response
193 |
194 | def check_batch_status(self, batch_id):
195 | return self.client.batches.retrieve(batch_id)
196 |
197 | def retrieve_batch_results(self, file_id):
198 | return self.client.files.content(file_id)
199 |
200 | def cancel_batch(self, batch_id):
201 | return self.client.batches.cancel(batch_id)
202 |
203 | def list_batches(self, limit=10):
204 | return self.client.batches.list(limit=limit)
205 |
206 | def generate_until_multi_round(self, requests) -> List[str]:
207 | raise NotImplementedError("TODO: Implement multi-round generation for BatchGPT4")
208 |
--------------------------------------------------------------------------------