├── IMG ├── ask_test_2.5.gif ├── logo.png ├── main_result.jpg └── method4.jpg ├── README.md ├── dataset ├── DynVQA_en │ ├── DynVQA_en.202406.jsonl │ ├── DynVQA_en.202408.jsonl │ ├── DynVQA_en.202410.jsonl │ ├── DynVQA_en.202412.jsonl │ └── DynVQA_en.202502.jsonl ├── DynVQA_zh │ ├── DynVQA_zh.202406.jsonl │ ├── DynVQA_zh.202408.jsonl │ ├── DynVQA_zh.202410.jsonl │ ├── DynVQA_zh.202412.jsonl │ └── DynVQA_zh.202502.jsonl └── training_data │ └── training_data_infoseek_en.json └── src ├── .DS_Store ├── Omnisearch_gpt ├── agent.py ├── conversation_manager.py ├── evaluate.py ├── llm_config.py ├── main.py ├── prompt.py └── search_api.py └── Omnisearch_qwen ├── Omnisearch_qwen.py └── search_api.py /IMG/ask_test_2.5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-NLP/OmniSearch/b01e78edb694fa0e8b05b808d2722d046f3f204b/IMG/ask_test_2.5.gif -------------------------------------------------------------------------------- /IMG/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-NLP/OmniSearch/b01e78edb694fa0e8b05b808d2722d046f3f204b/IMG/logo.png -------------------------------------------------------------------------------- /IMG/main_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-NLP/OmniSearch/b01e78edb694fa0e8b05b808d2722d046f3f204b/IMG/main_result.jpg -------------------------------------------------------------------------------- /IMG/method4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-NLP/OmniSearch/b01e78edb694fa0e8b05b808d2722d046f3f204b/IMG/method4.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # A Self-Adaptive Planning Agent For Multimodal RAG 6 | 7 | ![](https://img.shields.io/badge/version-1.0.0-blue)[![Pytorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?e&logo=PyTorch&logoColor=white)](https://pytorch.org/)[![arxiv badge](https://img.shields.io/badge/arxiv-2411.02937-red)](https://arxiv.org/abs/2411.02937) 8 | 9 | Repo for [*Benchmarking Multimodal Retrieval Augmented Generation with Dynamic VQA Dataset and Self-adaptive Planning Agent*](https://arxiv.org/abs/2411.02937) 10 | 11 | You can visit the OmniSearch homepage by clicking [*here!*](https://alibaba-nlp.github.io/OmniSearch/) 12 | 13 | 🌏 The **Chinese Web Demo** is now available on [ModelScope](https://modelscope.cn/studios/iic/OmniSearch/summary?header=default&fullWidth=false)! 14 | 15 | 16 | 17 | - We propose OmniSearch, a self-adaptive retrieval agent that plans each retrieval action in real time, based on the current stage of solving the question and the content retrieved so far. As far as we know, **OmniSearch is the first planning agent for multimodal RAG.** 18 | - We reveal that existing VQA-based mRAG benchmarks fail to reflect that real-world questions require dynamic knowledge retrieval, and we propose the novel **Dyn-VQA dataset, which contains three types of dynamic questions.** 19 | - We **benchmark various mRAG methods** with leading MLLMs on Dyn-VQA, demonstrating their shortcomings in providing sufficient and relevant knowledge for dynamic questions. 20 | 21 | 22 | 23 |
24 | 25 |
26 | 27 | 28 | 29 | ## 💡 Performance 30 | 31 | The performance of various MLLMs with different mRAG strategies is shown below: 32 | 33 |
34 | 35 |
36 | 37 | Additional analysis experiments can be found in the paper. 38 | 39 | # 📚 Dyn-VQA Dataset 40 | 41 | Each Dyn-VQA item is a JSON object (one per line in the `.jsonl` files) with the following format: 42 | ```json 43 | { 44 | "image_url": "https://www.pcarmarket.com/static/media/uploads/galleries/photos/uploads/galleries/22387-pasewark-1986-porsche-944/.thumbnails/IMG_7102.JPG.jpg/IMG_7102.JPG-tiny-2048x0-0.5x0.jpg", 45 | "question": "What is the model of car from this brand?", 46 | "question_id": "qid", 47 | "answer": ["保时捷 944", "Porsche 944."] 48 | } 49 | ``` 50 | 51 | 🔥 Dyn-VQA **will be updated regularly.** Latest version: 202502. 52 | 53 | # 🛠 Dependencies 54 | 55 | ```bash 56 | pip install -r requirement.txt 57 | ``` 58 | 59 | #### Details 60 | 61 | - Python = 3.11.9 62 | - [PyTorch](http://pytorch.org/) (>= 2.0.0) 63 | - pillow = 10.4.0 64 | - requests = 2.32.3 65 | - google-search-results = 2.4.2 66 | - serpapi = 0.1.5 67 | 68 | # 💻 Running OmniSearch 69 | 70 | - GPT-4V-based OmniSearch 71 | 72 | We have released the code of GPT-4V-based OmniSearch for English questions (the model name is set by `model` in `main.py`, currently `"gpt-4o"`). 73 | 74 | Before running, replace the placeholders with your own OpenAI API key and Google Search (SerpAPI) key.
The OpenAI key is on line 11 of `main.py`: 75 | 76 | ```python 77 | GPT_API_KEY = "your_actual_key_here" 78 | headers = { 79 | "Authorization": f"Bearer {GPT_API_KEY}" 80 | } 81 | ``` 82 | 83 | The Google Search (SerpAPI) key is on line 12 of `search_api.py`: 84 | 85 | ```python 86 | API_KEY = "your api-key" 87 | ``` 88 | 89 | Results are saved to: 90 | 91 | ```python 92 | output_path = os.path.join(meta_save_path, dataset_name, "output_from_gpt4v.jsonl") 93 | ``` 94 | 95 | Run `main.py`: 96 | 97 | ```bash 98 | python main.py --test_dataset 'path/to/dataset.jsonl' --dataset_name NAME --meta_save_path 'path/to/results' 99 | ``` 100 | 101 | - Qwen-VL-based OmniSearch 102 | 103 | We have made the [training data](https://github.com/Alibaba-NLP/OmniSearch/tree/main/dataset/training_data) for Qwen-VL-based OmniSearch publicly available. This data, together with the [CogVLM dataset](https://modelscope.cn/datasets/ZhipuAI/CogVLM-SFT-311K), was used to jointly fine-tune [Qwen-VL-Chat](https://www.modelscope.cn/models/Qwen/Qwen-VL-Chat) with the [SWIFT framework](https://github.com/modelscope/ms-swift). Training can be run as follows: 104 | 105 | ```bash 106 | swift sft --model_type qwen-vl-chat --dataset /Data/Path/to/Training_data_1 /Data/Path/to/Training_data_2 --model_id_or_path /Model/Path/to/Qwen-VL-Chat/ --output_dir /Output/Model/Path --max_length 8192 --evaluation_strategy 'no' 107 | ``` 108 | 109 | You can download the model from [OmniSearch-Qwen-VL-Chat-en on Hugging Face](https://huggingface.co/Alibaba-NLP/OmniSearch-Qwen-VL-Chat-en/tree/main). 110 | 111 | Run the test script.
Run the `Omnisearch_qwen.py` file: 112 | 113 | ```bash 114 | python Omnisearch_qwen.py --test_dataset '/path/to/dataset.jsonl' --dataset_name NAME --meta_save_path '/path/to/results' --model_path '/local/path/to/OmniSearch-Qwen-Chat-VL-weight' 115 | ``` 116 | 117 | 118 | 119 | # 🔍 Evaluation 120 | 121 | `evaluate.py` reports token-level recall (printed as "Token F1-Recall"): for each line of the output JSONL, the fraction of gold-answer tokens that appear in the normalized `prediction`, taking the best match over the gold answers and averaging over all lines. Run it as follows: 122 | 123 | ```bash 124 | python evaluate.py --evaluate_file_path [path to output jsonl file] \ 125 | --lang [language of the QA dataset: en/zh] 126 | ``` 127 | 128 | ## 🔥 TODO 129 | 130 | - Release code for Qwen-VL-Chat based OmniSearch 131 | - Release the corresponding model weights 132 | - Create a benchmark for Dyn-VQA 133 | 134 | ## 📄 Acknowledgements 135 | 136 | - This repo was contributed by Xinyu Wang, Shuo Guo, Zhen Zhang and Yangning Li. 137 | - This work was inspired by ReAct, Self-Ask and FreshLLMs. Sincere thanks for their efforts. 138 | 139 | ## 📝 Citation 140 | 141 | ```bibtex 142 | @article{li2024benchmarkingmultimodalretrievalaugmented, 143 | title={Benchmarking Multimodal Retrieval Augmented Generation with Dynamic VQA Dataset and Self-adaptive Planning Agent}, 144 | author={Yangning Li and Yinghui Li and Xinyu Wang and Yong Jiang and Zhen Zhang and Xinran Zheng and Hui Wang and Hai-Tao Zheng and Pengjun Xie and Philip S. Yu and Fei Huang and Jingren Zhou}, 145 | year={2024}, 146 | eprint={2411.02937}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.CL}, 149 | url={https://arxiv.org/abs/2411.02937}, 150 | } 151 | ``` 152 | 153 | 154 | When citing our work, please kindly consider citing the original papers. The relevant citation information is listed here.
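The token-level recall reported by `evaluate.py` can be sanity-checked on a few items without installing NLTK or jieba. The following is a minimal sketch under stated assumptions, not the official scorer: it assumes each JSONL line is a Dyn-VQA item plus the `prediction` field that `main.py` writes, swaps `nltk.word_tokenize`/`jieba` for a simple regex tokenizer, and skips the punctuation, number-word and article normalization, so its scores can differ from `evaluate.py`. The helper names `tokenize`, `token_recall`, `score_item` and `score_jsonl` are hypothetical.

```python
import json
import re
from collections import Counter


def tokenize(text):
    # Simplified tokenizer (assumption): lowercase, then keep runs of word
    # characters. evaluate.py uses nltk.word_tokenize (en) or jieba (zh) instead.
    return re.findall(r"\w+", str(text).lower())


def token_recall(gold, pred):
    # Fraction of gold-answer tokens that also occur in the prediction
    # (multiset overlap), following compute_acc_single in evaluate.py.
    gold_toks, pred_toks = tokenize(gold), tokenize(pred)
    if not gold_toks or not pred_toks:
        return float(gold_toks == pred_toks)
    common = Counter(gold_toks) & Counter(pred_toks)
    return sum(common.values()) / len(gold_toks)


def score_item(item):
    # A Dyn-VQA item plus a "prediction" field; keep the best gold answer.
    pred = item.get("prediction", "")
    if not pred:
        return 0.0
    return max(token_recall(gold, pred) for gold in item["answer"])


def score_jsonl(lines):
    # Average over non-empty JSONL lines, reported as a percentage.
    scores = [score_item(json.loads(line)) for line in lines if line.strip()]
    if not scores:
        return 0.0
    return round(100 * sum(scores) / len(scores), 2)


sample = {
    "image_url": "https://example.com/car.jpg",  # placeholder URL
    "question": "What is the model of car from this brand?",
    "question_id": "qid",
    "answer": ["保时捷 944", "Porsche 944."],
    "prediction": "It is a Porsche 944.",
}
print(score_jsonl([json.dumps(sample, ensure_ascii=False)]))  # 100.0
```

Because `score_jsonl` only iterates over lines, it also accepts an open file handle, for example one for `output_from_gpt4v.jsonl`. Use `evaluate.py` for any reported numbers.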
155 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-NLP/OmniSearch/b01e78edb694fa0e8b05b808d2722d046f3f204b/src/.DS_Store -------------------------------------------------------------------------------- /src/Omnisearch_gpt/agent.py: -------------------------------------------------------------------------------- 1 | from llm_config import call_gpt 2 | 3 | class QAAgent: 4 | def __init__(self, model, headers): 5 | self.model = model 6 | self.headers = headers 7 | 8 | def ask_gpt(self, messages, idx): 9 | success, idx, message, answer = call_gpt( 10 | self.model, messages, idx, self.headers 11 | ) 12 | 13 | return success, idx, message, answer 14 | 15 | -------------------------------------------------------------------------------- /src/Omnisearch_gpt/conversation_manager.py: -------------------------------------------------------------------------------- 1 | from search_api import fine_search 2 | import prompt as prompt 3 | 4 | class ConversationManager: 5 | def __init__(self, qa_agent, dataset_name, save_path): 6 | self.qa_agent = qa_agent 7 | self.dataset_name = dataset_name 8 | self.save_path = save_path 9 | self.conversation_num = 0 10 | self.total_image_quota = 9 11 | 12 | def manage_conversation(self, input_question, image_url, idx): 13 | self.conversation_num = 0 14 | messages = [ 15 | { 16 | "role": "user", 17 | "content": [ 18 | {"type": "text", "text": prompt.sys_prompt_1.format(input_question)}, 19 | {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}} 20 | ] 21 | } 22 | ] 23 | current_message = messages 24 | 25 | success, idx, message, answer = self.qa_agent.ask_gpt(messages, idx) 26 | print("first response:", answer) 27 | 28 | while self.conversation_num < 5: 29 | if "Final Answer" in answer: 30 | tmp_d = {"role": "assistant"} 31 | tmp_d.update(message) 32 | 
current_message.append(tmp_d) 33 | print(answer) 34 | print("-------") 35 | print(answer.split("Final Answer: ")[-1]) 36 | return answer.split("Final Answer: ")[-1], current_message 37 | 38 | if any(phrase in answer for phrase in ["Image Retrieval with Input Image", "Text Retrieval", "Image Retrieval with Text Query"]): 39 | tmp_d = {"role": "assistant"} 40 | tmp_d.update(message) 41 | current_message.append(tmp_d) 42 | sub_question = answer.split('\n')[-1].split('\n')[0] 43 | search_images, search_text = self.handle_retrieval(answer, image_url, idx) 44 | 45 | contents = self.prepare_contents(search_images,messages,sub_question,idx, search_text, image_url) 46 | current_message.append({"role": "user", "content": contents}) 47 | 48 | success, idx, message, answer = self.qa_agent.ask_gpt(current_message, idx) 49 | print("conversation step:", self.conversation_num, answer) 50 | if not success: 51 | print("Request failed.") 52 | break 53 | print(self.conversation_num) 54 | self.conversation_num += 1 55 | print(answer) 56 | print(self.conversation_num) 57 | #print(current_message) 58 | print("OVER!") 59 | return answer, current_message 60 | 61 | def handle_retrieval(self, answer, image_url, idx): 62 | if 'Image Retrieval with Input Image' in answer: 63 | return fine_search(image_url, 'img_search_img', self.save_path, self.dataset_name, idx, self.conversation_num) 64 | elif 'Text Retrieval' in answer: 65 | query = self.extract_query(answer, 'Text Retrieval') 66 | return fine_search(query, 'text_search_text', self.save_path, self.dataset_name, idx, self.conversation_num) 67 | elif 'Image Retrieval with Text Query' in answer: 68 | query = self.extract_query(answer, 'Image Retrieval with Text Query') 69 | return fine_search(query, 'text_search_img', self.save_path, self.dataset_name, idx, self.conversation_num) 70 | 71 | def extract_query(self, answer, retrieval_type): 72 | return answer.split(retrieval_type)[-1].replace(':', '').replace('"', '').replace('>', '') 73 | 74 | 
def prepare_contents(self, search_images,messages,sub_question,idx,search_text, image_url): 75 | if len(search_images) > 0: 76 | # If the assertion below is re-enabled, it reports search_text on failure 77 | #assert len(search_images) == len(search_text), (search_text) 78 | contents = [{"type": "text", "text": "Contents of retrieved images: "}] 79 | use_imgs_num = min(5, self.total_image_quota) 80 | self.total_image_quota -= use_imgs_num 81 | for img, txt in zip(search_images[:use_imgs_num], search_text[:use_imgs_num]): 82 | contents.extend([ 83 | { 84 | "type": "image_url", 85 | "image_url": { 86 | "url": img[0], 87 | "detail": "high" 88 | } 89 | }, 90 | { 91 | "type": "text", 92 | "text": "Description: "+txt 93 | } 94 | ]) 95 | else: 96 | contents = [ 97 | { 98 | "type": "text", 99 | "text": "Below are related documents retrieved, which may be helpful for answering questions later on:" 100 | } 101 | ] 102 | for txt in search_text: 103 | contents.append({ 104 | "type": "text", 105 | "text": txt 106 | }) 107 | 108 | contents.append({ 109 | "type": "text", 110 | "text": "\nInput Image:" 111 | }) 112 | contents.append({ 113 | "type": "image_url", 114 | "image_url": { 115 | "url": image_url, 116 | "detail": "high" 117 | } 118 | }) 119 | contents.append({ 120 | "type": "text", 121 | "text": sub_question + " Answer:" 122 | }) 123 | sub_messages = [ 124 | { 125 | "role": "user", 126 | "content": contents 127 | } 128 | ] 129 | 130 | # ask_gpt returns (success, idx, message, answer); unpack it so answer is the text, not the whole tuple 131 | success, _, _, answer = self.qa_agent.ask_gpt(sub_messages, idx) 132 | contents = [{"type": "text", "text": "Contents of retrieved documents: "}] 133 | if success: 134 | contents.extend([{"type": "text", "text": answer}]) 135 | else: 136 | for txt in search_text: 137 | contents.extend([ 138 | { 139 | "type": "text", 140 | "text": txt 141 | } 142 | ]) 143 | return contents 144 | 145 | -------------------------------------------------------------------------------- /src/Omnisearch_gpt/evaluate.py: -------------------------------------------------------------------------------- 1 | from
nltk.tokenize import word_tokenize 2 | import re 3 | import sys 4 | import jieba 5 | from collections import Counter 6 | from tqdm import tqdm 7 | import json, argparse 8 | 9 | contractions = { 10 | 'aint': "ain't", 11 | 'arent': "aren't", 12 | 'cant': "can't", 13 | 'couldve': "could've", 14 | 'couldnt': "couldn't", 15 | "couldn'tve": "couldn't've", 16 | "couldnt've": "couldn't've", 17 | 'didnt': "didn't", 18 | 'doesnt': "doesn't", 19 | 'dont': "don't", 20 | 'hadnt': "hadn't", 21 | "hadnt've": "hadn't've", 22 | "hadn'tve": "hadn't've", 23 | 'hasnt': "hasn't", 24 | 'havent': "haven't", 25 | 'hed': "he'd", 26 | "hed've": "he'd've", 27 | "he'dve": "he'd've", 28 | 'hes': "he's", 29 | 'howd': "how'd", 30 | 'howll': "how'll", 31 | 'hows': "how's", 32 | "Id've": "I'd've", 33 | "I'dve": "I'd've", 34 | 'Im': "I'm", 35 | 'Ive': "I've", 36 | 'isnt': "isn't", 37 | 'itd': "it'd", 38 | "itd've": "it'd've", 39 | "it'dve": "it'd've", 40 | 'itll': "it'll", 41 | "let's": "let's", 42 | 'maam': "ma'am", 43 | 'mightnt': "mightn't", 44 | "mightnt've": "mightn't've", 45 | "mightn'tve": "mightn't've", 46 | 'mightve': "might've", 47 | 'mustnt': "mustn't", 48 | 'mustve': "must've", 49 | 'neednt': "needn't", 50 | 'notve': "not've", 51 | 'oclock': "o'clock", 52 | 'oughtnt': "oughtn't", 53 | "ow's'at": "'ow's'at", 54 | "'ows'at": "'ow's'at", 55 | "'ow'sat": "'ow's'at", 56 | 'shant': "shan't", 57 | "shed've": "she'd've", 58 | "she'dve": "she'd've", 59 | "she's": "she's", 60 | 'shouldve': "should've", 61 | 'shouldnt': "shouldn't", 62 | "shouldnt've": "shouldn't've", 63 | "shouldn'tve": "shouldn't've", 64 | "somebody'd": 'somebodyd', 65 | "somebodyd've": "somebody'd've", 66 | "somebody'dve": "somebody'd've", 67 | 'somebodyll': "somebody'll", 68 | 'somebodys': "somebody's", 69 | 'someoned': "someone'd", 70 | "someoned've": "someone'd've", 71 | "someone'dve": "someone'd've", 72 | 'someonell': "someone'll", 73 | 'someones': "someone's", 74 | 'somethingd': "something'd", 75 | "somethingd've": 
"something'd've", 76 | "something'dve": "something'd've", 77 | 'somethingll': "something'll", 78 | 'thats': "that's", 79 | 'thered': "there'd", 80 | "thered've": "there'd've", 81 | "there'dve": "there'd've", 82 | 'therere': "there're", 83 | 'theres': "there's", 84 | 'theyd': "they'd", 85 | "theyd've": "they'd've", 86 | "they'dve": "they'd've", 87 | 'theyll': "they'll", 88 | 'theyre': "they're", 89 | 'theyve': "they've", 90 | 'twas': "'twas", 91 | 'wasnt': "wasn't", 92 | "wed've": "we'd've", 93 | "we'dve": "we'd've", 94 | 'weve': "we've", 95 | 'werent': "weren't", 96 | 'whatll': "what'll", 97 | 'whatre': "what're", 98 | 'whats': "what's", 99 | 'whatve': "what've", 100 | 'whens': "when's", 101 | 'whered': "where'd", 102 | 'wheres': "where's", 103 | 'whereve': "where've", 104 | 'whod': "who'd", 105 | "whod've": "who'd've", 106 | "who'dve": "who'd've", 107 | 'wholl': "who'll", 108 | 'whos': "who's", 109 | 'whove': "who've", 110 | 'whyll': "why'll", 111 | 'whyre': "why're", 112 | 'whys': "why's", 113 | 'wont': "won't", 114 | 'wouldve': "would've", 115 | 'wouldnt': "wouldn't", 116 | "wouldnt've": "wouldn't've", 117 | "wouldn'tve": "wouldn't've", 118 | 'yall': "y'all", 119 | "yall'll": "y'all'll", 120 | "y'allll": "y'all'll", 121 | "yall'd've": "y'all'd've", 122 | "y'alld've": "y'all'd've", 123 | "y'all'dve": "y'all'd've", 124 | 'youd': "you'd", 125 | "youd've": "you'd've", 126 | "you'dve": "you'd've", 127 | 'youll': "you'll", 128 | 'youre': "you're", 129 | 'youve': "you've", 130 | } 131 | 132 | manualMap = { 133 | 'none': '0', 134 | 'zero': '0', 135 | 'one': '1', 136 | 'two': '2', 137 | 'three': '3', 138 | 'four': '4', 139 | 'five': '5', 140 | 'six': '6', 141 | 'seven': '7', 142 | 'eight': '8', 143 | 'nine': '9', 144 | 'ten': '10', 145 | } 146 | 147 | articles = ['a', 'an', 'the'] 148 | 149 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') 150 | 151 | commaStrip = re.compile('(\d)(,)(\d)') 152 | 153 | punct = [ 154 | ';', 155 | r'/', 156 | '[', 157 | ']', 158 | '"', 159 | 
'{', 160 | '}', 161 | '(', 162 | ')', 163 | '=', 164 | '+', 165 | '\\', 166 | '_', 167 | '-', 168 | '>', 169 | '<', 170 | '@', 171 | '`', 172 | ',', 173 | '?', 174 | '!', 175 | ] 176 | 177 | parser = argparse.ArgumentParser(description="") 178 | parser.add_argument('--evaluate_file_path', default='') 179 | parser.add_argument('--lang', default='en') 180 | args = parser.parse_args() 181 | 182 | def process_string(s): 183 | s = str(s) 184 | words = [] 185 | for word in ' '.join(jieba.cut(s)).split(): 186 | if word not in ',、。 ,.《》': 187 | words.append(word) 188 | return words 189 | 190 | def process_string_en(s): 191 | s = str(s).lower() 192 | words = [] 193 | for word in word_tokenize(s): 194 | if word not in ',.?!:;\'"': 195 | words.append(word) 196 | return words 197 | 198 | def compute_acc_single(gold_toks, pred_toks): 199 | common = Counter(gold_toks) & Counter(pred_toks) 200 | num_same = sum(common.values()) 201 | if len(gold_toks) == 0 or len(pred_toks) == 0: 202 | return float(gold_toks == pred_toks) 203 | if num_same == 0: 204 | return 0 205 | return num_same / len(gold_toks) 206 | 207 | def compute_acc(a_golds, a_pred, lang): 208 | if lang == 'zh': 209 | if a_pred == '': 210 | return 0 211 | golds_toks = [process_string(a_gold) for a_gold in a_golds] 212 | pred_toks = process_string(a_pred) 213 | elif lang == 'en': 214 | if a_pred == '': 215 | return 0 216 | golds_toks = [process_string_en(a_gold) for a_gold in a_golds] 217 | pred_toks = process_string_en(a_pred) 218 | 219 | return max( 220 | compute_acc_single(gold_toks, pred_toks) for gold_toks in golds_toks) 221 | 222 | def processPunctuation(inText): 223 | outText = inText 224 | for p in punct: 225 | if (p + ' ' in inText or ' ' + p 226 | in inText) or (re.search(commaStrip, inText) != None): 227 | outText = outText.replace(p, '') 228 | else: 229 | outText = outText.replace(p, ' ') 230 | outText = periodStrip.sub('', outText, re.UNICODE) 231 | return outText 232 | 233 | def processDigitArticle(inText): 
234 | outText = [] 235 | tempText = inText.lower().split() 236 | for word in tempText: 237 | word = manualMap.setdefault(word, word) 238 | if word not in articles: 239 | outText.append(word) 240 | else: 241 | pass 242 | for wordId, word in enumerate(outText): 243 | if word in contractions: 244 | outText[wordId] = contractions[word] 245 | outText = ' '.join(outText) 246 | return outText 247 | 248 | evaluate_file_path = args.evaluate_file_path 249 | 250 | acc_list = [] 251 | f = open(evaluate_file_path, 'r', encoding='utf-8') 252 | for idd, line in enumerate(f.readlines()): 253 | data = json.loads(line) 254 | resAns = data['prediction'] 255 | resAns = resAns.replace('\n', ' ') 256 | resAns = resAns.replace('\t', ' ') 257 | resAns = resAns.strip() 258 | resAns = processPunctuation(resAns) 259 | resAns = processDigitArticle(resAns) 260 | 261 | gtAnswers = data['answer'] 262 | avgGTAcc = compute_acc(a_golds=gtAnswers, a_pred=resAns, lang=args.lang) 263 | acc_list.append(avgGTAcc) 264 | 265 | print('Token F1-Recall: ', round(100 * float(sum(acc_list)) / len(acc_list), 2)) 266 | 267 | -------------------------------------------------------------------------------- /src/Omnisearch_gpt/llm_config.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import time 3 | 4 | workers=4 5 | skip_lines = 0 6 | 7 | LIMIT = 1000000000000 8 | retry_attempt = 10 9 | 10 | 11 | def call_gpt(model, messages, idx, headers ): 12 | # Build the request payload: model, conversation messages and generation settings 13 | data = { 14 | "model": model, 15 | "messages": messages, 16 | "n": 1, # number of completions to return 17 | "max_tokens": 4096 18 | } 19 | 20 | answer = None 21 | while answer is None: 22 | try: 23 | 24 | r = requests.post( 25 | #'https://api.chatanywhere.tech/v1/chat/completions', 26 | 'https://api.openai.com/v1/chat/completions', 27 | json=data, 28 | headers=headers 29 | ) 30 | resp = r.json() 31 | 32 | if r.status_code != 200: 33 | print('Request failed, retrying:', resp) 34 | time.sleep(1) # back off briefly instead of retrying immediately 35 | continue 36 | 37 | if 'choices' in resp and
resp['choices'][0].get('finish_reason') in ['content_filter', 'ResponsibleAIPolicyViolation']: 38 | print('Response blocked by the content policy; returning an empty result.') 39 | return (False, idx, {}, "") 40 | message = resp['choices'][0]['message'] 41 | answer = message['content'] 42 | 43 | return (True, idx, message, answer) 44 | 45 | except Exception as e: 46 | print(e) 47 | print('Exception raised, retrying!') 48 | time.sleep(1) # wait briefly before retrying 49 | continue 50 | 51 | -------------------------------------------------------------------------------- /src/Omnisearch_gpt/main.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from concurrent.futures import ThreadPoolExecutor 3 | import os 4 | import json 5 | import asyncio 6 | import argparse 7 | from agent import QAAgent 8 | from conversation_manager import ConversationManager 9 | 10 | model = "gpt-4o" 11 | GPT_API_KEY = "" 12 | headers = { 13 | "Authorization": f"Bearer {GPT_API_KEY}" 14 | } 15 | 16 | meta_save_path = './vfreshqa_datasets_v2/' # default save directory; main() receives --meta_save_path instead 17 | 18 | 19 | write_lock = threading.Lock() 20 | 21 | def safe_write(file_path, data): 22 | with write_lock: 23 | with open(file_path, "a", encoding="utf-8") as f: 24 | f.write(json.dumps(data, ensure_ascii=False) + "\n") 25 | 26 | def process_item(item, conversation_manager, meta_save_path, dataset_name): 27 | input_question = item['question'] 28 | idx = item['question_id'] 29 | image_url = item['image_url'] 30 | 31 | answer, current_message = conversation_manager.manage_conversation( 32 | input_question=input_question, image_url=image_url, idx=idx 33 | ) 34 | 35 | # Store the prediction in the item 36 | item['prediction'] = answer 37 | # Append the item to the output JSONL file 38 | output_path = os.path.join(meta_save_path, dataset_name, "output_from_gpt4v.jsonl") 39 | safe_write(output_path, item) 40 | 41 | def main(test_dataset,dataset_name, meta_save_path): 42 | 43 | num_threads = 1 44 | 45 | qa_agent = QAAgent(model=model, headers=headers) 46 | 47 | with open(test_dataset, "r", encoding="utf-8") as f: 48 | datas =
[json.loads(line) for line in f.readlines()] 49 | 50 | output_path = os.path.join(meta_save_path, dataset_name, "output_from_gpt4v.jsonl") 51 | if os.path.exists(output_path): 52 | with open(output_path, "r", encoding="utf-8") as fin: 53 | done_id = {json.loads(data)['question_id'] for data in fin.readlines()} 54 | datas = [data for data in datas if data['question_id'] not in done_id] 55 | 56 | save_path = os.path.join(meta_save_path, dataset_name, "search_images_gpt4v") 57 | os.makedirs(save_path, exist_ok=True) 58 | 59 | conversation_manager = ConversationManager(qa_agent=qa_agent, dataset_name=dataset_name, save_path=save_path) 60 | for item in datas: 61 | process_item(item, conversation_manager, meta_save_path, dataset_name) 62 | 63 | 64 | 65 | if __name__ == "__main__": 66 | # Set up command-line arguments 67 | parser = argparse.ArgumentParser(description="Run OmniSearch on the specified dataset") 68 | parser.add_argument("--test_dataset", type=str, required=True, help="Path to the dataset (.jsonl)") 69 | parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (used as the output subdirectory)") 70 | parser.add_argument("--meta_save_path", type=str, required=True, help="Root directory for saving results") 71 | 72 | 73 | args = parser.parse_args() 74 | 75 | # Call main with the parsed arguments 76 | main(args.test_dataset, args.dataset_name, args.meta_save_path) 77 | -------------------------------------------------------------------------------- /src/Omnisearch_gpt/prompt.py: -------------------------------------------------------------------------------- 1 | sys_prompt_1 = '''You are a helpful multimodal question answering assistant. Decompose the original question into sub-questions and solve them step by step. You can use "Final Answer" to output a sentence in the answer, use "Search" to state what additional context or information is needed to provide a precise answer to the "Sub-Question".
In the "Search" step, You can use "Image Retrieval with Input Image" to seek images similar to the original ones and determine their titles, "Text Retrieval" with a specific query to fetch pertinent documents and summarize their content, "Image Retrieval with Text Query" to fetch images related to the entered keywords. 2 | Use the following format strictly: 3 | 4 | Analyse questions and answer of the sub-questions, then think about what is next sub-question. 5 | 6 | Sub-Question needs to be solved in one step, without references. 7 | 8 | One of three retrieval methods: Image Retrieval with Input Image. Text Retrieval: xxx. Image Retrieval with Text Query: xxx. 9 | 10 | ... (this Thought/Sub-Question/Search can be repeated zero or more times) 11 | 12 | 13 | Integrate retrieved information and reason to a final answer 14 | 15 | Final Answer: the final answer to the original input question 16 | 17 | Extra notes: 18 | 1. Do not use your own knowledge to analyse input image or answer questions 19 | 2. After you give each action, please wait for me to provide you with the answer to the sub-question, and then think about the next thought carefully. 20 | 3.
The answers to the questions can be found on the internet and are not private 21 | 22 | Input Question:{} 23 | ''' -------------------------------------------------------------------------------- /src/Omnisearch_gpt/search_api.py: -------------------------------------------------------------------------------- 1 | from mimetypes import guess_type 2 | import base64 3 | import json 4 | import os 5 | import time 6 | from io import BytesIO 7 | 8 | import requests 9 | from PIL import Image 10 | from serpapi import GoogleSearch 11 | 12 | API_KEY = "" 13 | retry_attempt = 3 14 | 15 | 16 | def local_image_to_data_url(image_path): 17 | mime_type, _ = guess_type(image_path) 18 | if mime_type is None: 19 | mime_type = "application/octet-stream" 20 | with open(image_path, "rb") as f: 21 | b64 = base64.b64encode(f.read()).decode("utf-8") 22 | return f"data:{mime_type};base64,{b64}" 23 | 24 | 25 | def search_text_by_text(text): 26 | params = { 27 | "engine": "google", 28 | "q": text, 29 | "api_key": API_KEY, 30 | "num": 5, 31 | } 32 | for i in range(retry_attempt): 33 | try: 34 | search = GoogleSearch(params) 35 | results = search.get_dict() 36 | return results.get("organic_results", []) 37 | except Exception as e: 38 | print(f"Attempt {i+1} failed: {e}") 39 | if i < retry_attempt - 1: 40 | time.sleep(2) 41 | else: 42 | print("All retries failed.") 43 | return [] 44 | 45 | 46 | def search_image_by_text(text): 47 | params = { 48 | "engine": "google_images", 49 | "q": text, 50 | "api_key": API_KEY, 51 | } 52 | for i in range(retry_attempt): 53 | try: 54 | search = GoogleSearch(params) 55 | results = search.get_dict() 56 | images = results.get("images_results", []) 57 | return images[0] if images else {} 58 | except Exception as e: 59 | print(f"Attempt {i+1} failed: {e}") 60 | if i < retry_attempt - 1: 61 | time.sleep(2) 62 | else: 63 | print("All retries failed.") 64 | return {} 65 | 66 | 67 | def search_image_by_image_url(input_url): 68 | params = { 69 | "engine": 
"google_reverse_image", 70 | "image_url": input_url, 71 | "hl": "zh-CN", 72 | "gl": "CN", 73 | "api_key": API_KEY, 74 | } 75 | for i in range(retry_attempt): 76 | try: 77 | search = GoogleSearch(params) 78 | return search.get_dict() 79 | except Exception as e: 80 | print(f"Attempt {i+1} failed: {e}") 81 | if "SSLError" in str(e): 82 | print("SSL error encountered.") 83 | elif "ConnectionError" in str(e): 84 | print("Network connection error.") 85 | if i < retry_attempt - 1: 86 | time.sleep(2) 87 | else: 88 | print("All retries failed. Returning empty result.") 89 | return {} 90 | 91 | 92 | def parse_image_search_result_by_text(result, save_path, idx, conversation_num): 93 | images, texts = [], [] 94 | url = result.get("thumbnail") 95 | try: 96 | resp = requests.get(url) 97 | resp.raise_for_status() 98 | img = Image.open(BytesIO(resp.content)) 99 | fname = f"{idx}_{conversation_num}_{result.get('position','0')}.png" 100 | out = os.path.join(save_path, fname) 101 | img.save(out, format="PNG") 102 | images.append((url, out)) 103 | texts.append(result.get("title", "")) 104 | except Exception as e: 105 | print(f"Failed to save thumbnail {url}: {e}") 106 | return images, texts 107 | 108 | 109 | def parse_image_search_result_by_image(results, save_path, idx, conversation_num): 110 | images, texts = [], [] 111 | kg = results.get("knowledge_graph", {}) 112 | if "header_images" in kg: 113 | for item in kg["header_images"]: 114 | url = item.get("source") 115 | try: 116 | resp = requests.get(url) 117 | resp.raise_for_status() 118 | img = Image.open(BytesIO(resp.content)) 119 | fname = f"{idx}_{conversation_num}_header.png" 120 | out = os.path.join(save_path, fname) 121 | img.save(out, format="PNG") 122 | text = f"{kg.get('title','')}: {kg.get('description','')}" 123 | images.append((url, out)) 124 | texts.append(text) 125 | except Exception as e: 126 | print(f"Failed to save header image {url}: {e}") 127 | elif "image_results" in results: 128 | for item in 
results["image_results"]: 129 | snippet = item.get("snippet") 130 | if snippet: 131 | texts.append(snippet) 132 | else: 133 | print("No 'knowledge_graph' or 'image_results' in response.") 134 | return images, texts 135 | 136 | 137 | def fine_search(query, search_type, save_path, dataset_name, idx, conversation_num): 138 | if search_type == "text_search_text": 139 | results = search_text_by_text(query) 140 | texts = [item.get("title","") + ": " + item.get("snippet","") for item in results] 141 | return [], texts 142 | 143 | if search_type == "img_search_img": 144 | cache = os.path.join(save_path, dataset_name, f"image_search_res_{idx}.json") 145 | if os.path.exists(cache): 146 | with open(cache) as f: 147 | saved = json.load(f) 148 | imgs, txts = parse_image_search_result_by_image(saved, save_path, idx, conversation_num) 149 | if not txts: 150 | saved = search_image_by_image_url(query) 151 | imgs, txts = parse_image_search_result_by_image(saved, save_path, idx, conversation_num) 152 | return imgs, txts 153 | saved = search_image_by_image_url(query) 154 | return parse_image_search_result_by_image(saved, save_path, idx, conversation_num) 155 | 156 | if search_type == "text_search_img": 157 | result = search_image_by_text(query) 158 | return parse_image_search_result_by_text(result, save_path, idx, conversation_num) 159 | 160 | return [], [] 161 | -------------------------------------------------------------------------------- /src/Omnisearch_qwen/Omnisearch_qwen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import threading 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from swift.llm import ( 8 | ModelType, 9 | get_default_template_type, 10 | get_model_tokenizer, 11 | get_template, 12 | inference, 13 | ) 14 | from swift.utils import seed_everything 15 | from search_api import fine_search 16 | 17 | import argparse 18 | 19 | 20 | # Parse command-line arguments with argparse 21 | parser =
argparse.ArgumentParser(description="Multimodal Question Answering Agent") 22 | parser.add_argument( 23 | "--test_dataset", 24 | type=str, 25 | required=True, 26 | help="Path to the test dataset file (e.g., .jsonl)", 27 | ) 28 | parser.add_argument( 29 | "--dataset_name", 30 | type=str, 31 | required=True, 32 | help="Name of the dataset (used for saving outputs)", 33 | ) 34 | parser.add_argument( 35 | "--meta_save_path", 36 | type=str, 37 | required=True, 38 | help="Root directory to save intermediate and final outputs", 39 | ) 40 | parser.add_argument( 41 | "--model_path", 42 | type=str, 43 | required=True, 44 | help="Path to the model checkpoint (e.g., /home/user/model)", 45 | ) 46 | args = parser.parse_args() 47 | 48 | 49 | model_type = ModelType.qwen_vl_chat 50 | template_type = get_default_template_type(model_type) 51 | print(f"template_type: {template_type}") 52 | 53 | model, tokenizer = get_model_tokenizer( 54 | model_type, 55 | model_kwargs={"device_map": "auto"}, 56 | model_id_or_path=args.model_path, 57 | ) 58 | 59 | model.config.seq_length = 8192 60 | model.generation_config.max_new_tokens = 4096 61 | 62 | template = get_template(template_type, tokenizer) 63 | print(f"template: {template}") 64 | seed_everything(42) 65 | 66 | call_image_num = 0 67 | call_image_num_succ = 0 68 | 69 | SYS_PROMPT = '''You are a helpful multimodal question answering assistant. Decompose the original question into sub-questions and solve them step by step. You can use "Final Answer" to output a sentence in the answer, use "Search" to state what additional context or information is needed to provide a precise answer to the "Sub-Question". In the "Search" step, You can use "Image Retrieval with Input Image" to seek images similar to the original ones and determine their titles, "Text Retrieval" with a specific query to fetch pertinent documents and summarize their content, "Image Retrieval with Text Query" to fetch images related to the entered keywords. 
Use the following format strictly:

Analyse questions and answer of the sub-questions, then think about what is next sub-question.

Sub-Question needs to be solved in one step, without references.

One of four retrieval methods: Image Retrieval with Input Image. Text Retrieval: xxx. Image Retrieval with Text Query: xxx. No Retrieval

... (this Thought/Sub-Question/Search can be repeated zero or more times)


Integrate retrieved information and reason to a final answer

Final Answer: the final answer to the original input question

Extra notes:
1. Do not use your own knowledge to analyse input image or answer questions
2. After you give each action, please wait for me to provide you with the answer to the sub-question, and then think about the next thought carefully.
3. The answers to the questions can be found on the internet and are not private

Input Question:{}'''


def vqa_agent_v3(
    input_question: str,
    image_url: str,
    idx: int,
    search_image_save_path: str,
    args,
):
    global call_image_num, call_image_num_succ

    query = SYS_PROMPT.format(input_question) + f"\n{image_url}"
    response, history = inference(model, template, query)
    print("first response:\n", response)

    # max_turns bounds the loop; total_image_quota caps how many retrieved
    # images are passed back to the model across all turns.
    conversation_num, max_turns = 0, 5
    total_image_quota = 9

    while conversation_num < max_turns:
        if "Final Answer" in response:
            final_answer = response.split("Final Answer:")[-1].strip()
            return final_answer, history

        # Detect which retrieval action the model requested in this turn.
        need_img_ret = "Image Retrieval with Input Image" in response
        need_txt_ret = "Text Retrieval" in response
        need_txt_img_ret = "Image Retrieval with Text Query" in response

        if need_img_ret or need_txt_ret or need_txt_img_ret:
            if need_img_ret:
                call_image_num += 1
                search_images, search_text = fine_search(
                    image_url,
                    "img_search_img",
                    search_image_save_path,
                    args.dataset_name,  # dataset name from the CLI arguments
                    idx,
                    conversation_num,
                )
                if search_images:
                    call_image_num_succ += 1

            elif need_txt_ret:
                # Everything after the last "Text Retrieval" marker becomes the query.
                query_txt = (
                    response.split("Text Retrieval")[-1]
                    .replace(":", "")
                    .replace('"', "")
                    .replace(">", "")
                )
                search_images, search_text = fine_search(
                    query_txt,
                    "text_search_text",
                    search_image_save_path,
                    args.dataset_name,  # dataset name from the CLI arguments
                    idx,
                    conversation_num,
                )

            else:  # need_txt_img_ret
                query_txt = (
                    response.split("Image Retrieval with Text Query")[-1]
                    .replace(":", "")
                    .replace('"', "")
                    .replace(">", "")
                )
                search_images, search_text = fine_search(
                    query_txt,
                    "text_search_img",
                    search_image_save_path,
                    args.dataset_name,  # dataset name from the CLI arguments
                    idx,
                    conversation_num,
                )

            # Build the observation message that is fed back to the model.
            contents = []
            if search_images:
                assert len(search_images) == len(search_text)
                use_n = min(5, total_image_quota)
                total_image_quota -= use_n
                contents.append("Contents of retrieved images:")
                for img, txt in zip(search_images[:use_n], search_text[:use_n]):
                    contents.extend([f"{img[0]}", f"Description: {txt}"])
            elif search_text:
                contents.append("Contents of retrieved documents:")
                contents.extend(search_text)
            else:
                contents.append("No relevant information found.")

            try:
                response, history = inference(
                    model, template, "\n".join(contents), history
                )
            except Exception as e:
                print("Inference error:", e)
                return response, history

        conversation_num += 1

    return response, history


# A single lock shared by all worker threads, so concurrent appends to the
# output file cannot interleave.
_write_lock = threading.Lock()


def safe_write(file_path: str, data: dict):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with _write_lock:
        with open(file_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(data, ensure_ascii=False) + "\n")


def process_item(item: dict, save_dir: str, meta_save_path: str, ds_name: str, args):
    answer, conv = vqa_agent_v3(
        item["question"],
        item["image_url"],
        item["question_id"],
        save_dir,
        args,  # pass the CLI arguments through
    )
    item["prediction"] = answer
    item["conversation"] = conv

    out_path = os.path.join(meta_save_path, ds_name, "output.jsonl")
    safe_write(out_path, item)


def main():
    test_dataset = args.test_dataset
    ds_name = args.dataset_name
    meta_save_path = args.meta_save_path

    # Load the test dataset (one JSON object per line)
    with open(test_dataset, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Skip samples that already have an entry in output.jsonl (resume support)
    output_file = os.path.join(meta_save_path, ds_name, "output.jsonl")
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as fin:
            done = {json.loads(line)["question_id"] for line in fin if line.strip()}
        data = [d for d in data if d["question_id"] not in done]

    search_img_save = os.path.join(meta_save_path, ds_name, "search_images")
    os.makedirs(search_img_save, exist_ok=True)

    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = [
            executor.submit(
                process_item, item, search_img_save, meta_save_path, ds_name, args
            )
            for item in data
        ]
        for f in futures:
            f.result()

    print(f"Image search calls: {call_image_num} | Success: {call_image_num_succ}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/Omnisearch_qwen/search_api.py:
--------------------------------------------------------------------------------
from mimetypes import guess_type
import base64
import json
import os
import time
from io import BytesIO

import requests
from PIL import Image
from serpapi import GoogleSearch

API_KEY = ""
retry_attempt = 3


def local_image_to_data_url(image_path):
    mime_type, _ = guess_type(image_path)
    if mime_type is None:
        mime_type = "application/octet-stream"
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{b64}"


def search_text_by_text(text):
    params = {
        "engine": "google",
        "q": text,
        "api_key": API_KEY,
        "num": 5,
    }
    for i in range(retry_attempt):
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            return results.get("organic_results", [])
        except Exception as e:
            print(f"Attempt {i+1} failed: {e}")
            if i < retry_attempt - 1:
                time.sleep(2)
            else:
                print("All retries failed.")
    return []


def search_image_by_text(text):
    params = {
        "engine": "google_images",
        "q": text,
        "api_key": API_KEY,
    }
    for i in range(retry_attempt):
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            images = results.get("images_results", [])
            return images[0] if images else {}
        except Exception as e:
            print(f"Attempt {i+1} failed: {e}")
            if i < retry_attempt - 1:
                time.sleep(2)
            else:
                print("All retries failed.")
    return {}


def search_image_by_image_url(input_url):
    params = {
        "engine": "google_reverse_image",
        "image_url": input_url,
        "hl": "zh-CN",
        "gl": "CN",
        "api_key": API_KEY,
    }
    for i in range(retry_attempt):
        try:
            search = GoogleSearch(params)
            return search.get_dict()
        except Exception as e:
            print(f"Attempt {i+1} failed: {e}")
            if "SSLError" in str(e):
                print("SSL error encountered.")
            elif "ConnectionError" in str(e):
                print("Network connection error.")
            if i < retry_attempt - 1:
                time.sleep(2)
            else:
                print("All retries failed. Returning empty result.")
    return {}


def parse_image_search_result_by_text(result, save_path, idx, conversation_num):
    images, texts = [], []
    url = result.get("thumbnail")
    if not url:
        # search_image_by_text() returns {} when no image was found or all retries failed.
        return images, texts
    try:
        # A timeout keeps one slow host from stalling the whole agent loop.
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content))
        fname = f"{idx}_{conversation_num}_{result.get('position', '0')}.png"
        out = os.path.join(save_path, fname)
        img.save(out, format="PNG")
        images.append((url, out))
        texts.append(result.get("title", ""))
    except Exception as e:
        print(f"Failed to save thumbnail {url}: {e}")
    return images, texts


def parse_image_search_result_by_image(results, save_path, idx, conversation_num):
    images, texts = [], []
    kg = results.get("knowledge_graph", {})
    if "header_images" in kg:
        for n, item in enumerate(kg["header_images"]):
            url = item.get("source")
            try:
                resp = requests.get(url, timeout=10)
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content))
                # Include the image index so several header images do not overwrite one file.
                fname = f"{idx}_{conversation_num}_header_{n}.png"
                out = os.path.join(save_path, fname)
                img.save(out, format="PNG")
                text = f"{kg.get('title', '')}: {kg.get('description', '')}"
                images.append((url, out))
                texts.append(text)
            except Exception as e:
                print(f"Failed to save header image {url}: {e}")
    elif "image_results" in results:
        for item in results["image_results"]:
            snippet = item.get("snippet")
            if snippet:
                texts.append(snippet)
    else:
        print("No 'knowledge_graph' or 'image_results' in response.")
    return images, texts


def fine_search(query, search_type, save_path, dataset_name, idx, conversation_num):
    # dataset_name matches the arguments passed by Omnisearch_qwen.py (and the
    # Omnisearch_gpt variant); this version reads its cache directly from save_path.
    if search_type == "text_search_text":
        results = search_text_by_text(query)
        texts = [item.get("title", "") + item.get("snippet", "") for item in results]
        return [], texts

    if search_type == "img_search_img":
        # Reuse a saved reverse-image result if one exists; fall back to a live
        # query when the cached result yields no text.
        cache = os.path.join(save_path, f"image_search_res_{idx}.json")
        if os.path.exists(cache):
            with open(cache, encoding="utf-8") as f:
                saved = json.load(f)
            imgs, txts = parse_image_search_result_by_image(saved, save_path, idx, conversation_num)
            if not txts:
                saved = search_image_by_image_url(query)
                imgs, txts = parse_image_search_result_by_image(saved, save_path, idx, conversation_num)
            return imgs, txts
        saved = search_image_by_image_url(query)
        return parse_image_search_result_by_image(saved, save_path, idx, conversation_num)

    if search_type == "text_search_img":
        result = search_image_by_text(query)
        return parse_image_search_result_by_text(result, save_path, idx, conversation_num)

    return [], []
--------------------------------------------------------------------------------
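The marker handling in `vqa_agent_v3` (the `Final Answer` check followed by the three retrieval markers named in `SYS_PROMPT`) can be pulled out into a pure function and exercised without loading Qwen-VL or calling SerpApi. The sketch below is a hypothetical helper, `parse_turn`, that is not part of the repository. It assumes the same if/elif/else precedence as `vqa_agent_v3`, returns the `search_type` strings that `fine_search` dispatches on, and, unlike the original loop, strips surrounding whitespace from the extracted query.

```python
from typing import Optional, Tuple

# Hypothetical helper, not part of the OmniSearch repository. Each marker from
# SYS_PROMPT maps to the search_type string that fine_search() expects.
# Order matters: it follows the if/elif/else precedence in vqa_agent_v3.
_MARKERS = (
    ("Image Retrieval with Input Image", "img_search_img"),
    ("Text Retrieval", "text_search_text"),
    ("Image Retrieval with Text Query", "text_search_img"),
)


def parse_turn(response: str) -> Tuple[str, Optional[str]]:
    """Return (action, payload) for one model turn.

    action is "final_answer", one of the fine_search() search types, or
    "none" when the turn requests no retrieval.
    """
    if "Final Answer" in response:
        return "final_answer", response.split("Final Answer:")[-1].strip()
    for marker, search_type in _MARKERS:
        if marker not in response:
            continue
        if search_type == "img_search_img":
            # Reverse image search uses the input image URL, so there is no text query.
            return search_type, None
        # Keep the text after the last marker and drop the characters that
        # vqa_agent_v3 removes (':', '"', '>'); the strip() is an addition.
        tail = response.split(marker)[-1]
        query = tail.replace(":", "").replace('"', "").replace(">", "").strip()
        return search_type, query
    return "none", None
```

Keeping the parser free of model and network calls makes the agent's control flow unit-testable; the loop would then only branch on the returned action.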