├── scripts ├── iag │ ├── step5_run_hierarchical_filter.py │ ├── step3_crawl_from_urls.py │ ├── step2_call_search_engines.py │ ├── step4_gen_data4hierarchical_filter.py │ ├── step6_gen_test_searchlvlms.py │ └── step1_gen_search_queries.py ├── utils │ ├── __pycache__ │ │ ├── utils.cpython-39.pyc │ │ └── utils_engine.cpython-39.pyc │ ├── utils.py │ └── utils_engine.py ├── unset_env_variable.bash ├── init_env_variable.sh ├── iag.sh ├── eval │ └── eval_lvlms.py └── eval.sh ├── environment.yml └── README.md /scripts/iag/step5_run_hierarchical_filter.py: -------------------------------------------------------------------------------- 1 | # no python code 2 | # run hierarchical filter by line 52 in iag.sh -------------------------------------------------------------------------------- /scripts/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeverMoreLCH/SearchLVLMs/HEAD/scripts/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /scripts/utils/__pycache__/utils_engine.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeverMoreLCH/SearchLVLMs/HEAD/scripts/utils/__pycache__/utils_engine.cpython-39.pyc -------------------------------------------------------------------------------- /scripts/unset_env_variable.bash: -------------------------------------------------------------------------------- 1 | unset llama3_dir 2 | unset vlmevalkit_dir 3 | 4 | unset OPENAI_API_KEY 5 | unset OPENAI_ENDPOINT 6 | 7 | unset google_api_key 8 | unset google_text_cse_id 9 | unset google_image_cse_id 10 | unset bing_text_api_key 11 | unset bing_img_api_key 12 | unset bing_visual_api_key -------------------------------------------------------------------------------- /scripts/init_env_variable.sh: -------------------------------------------------------------------------------- 1 | export llama3_dir="" 2 | export vlmevalkit_dir="" 3 | 4 | export OPENAI_API_KEY="" 5 | export OPENAI_ENDPOINT="" 6 | 7 | export google_api_key="" 8 | export google_text_cse_id="" 9 | export google_image_cse_id="" 10 | export bing_text_api_key="" 11 | export bing_img_api_key="" 12 | export bing_visual_api_key="" -------------------------------------------------------------------------------- /scripts/iag/step3_crawl_from_urls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('scripts/utils') 4 | from utils import check_fetched_text, text_chunk, fetch_by_newspaper 5 | 6 | import json 7 | from tqdm import tqdm 8 | import time 9 | import argparse 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 15 | parser.add_argument('--save_step', type=int, default=500) 16 | 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def fetch_web_info_by_crawl(url): 22 | try: 23 | fetched_dict = fetch_by_newspaper(url) 24 | except: 25 | fetched_dict = {'fetched_title': '', 'fetched_text': ''} 26 | 27 | fetched_text = fetched_dict['fetched_text'] 28 | fetched_text_chunked_list = [] 29 | if check_fetched_text(fetched_text): 30 | fetched_text_chunked_list = text_chunk(fetched_text, 150) 31 | 32 | return fetched_text, fetched_text_chunked_list 33 | 34 | 35 | def main(): 36 | 37 | args = parse_args() 
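    # Intermediate files (their paths are derived from the dataset name below):
    #   query2url.json - produced by step1/step2; maps each search query to its list of result URLs.
    #   url2info.json  - filled by this step; maps each URL to its crawled text ('fetched_text'),
    #                    its ~150-word chunks ('fetched_text_chunked_list') and a 'have_searched'
    #                    flag. Progress is flushed to disk every --save_step URLs.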
38 | 39 | test_data_path = args.test_data_path 40 | 41 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 42 | query2url_path = './intermediate_files/{}/query2url.json'.format(dataset_name) 43 | url2info_path = './intermediate_files/{}/url2info.json'.format(dataset_name) 44 | 45 | folder_path = os.path.dirname(query2url_path) 46 | os.makedirs(folder_path, exist_ok=True) 47 | 48 | if os.path.exists(query2url_path): 49 | with open(query2url_path, 'r') as f: 50 | query2url = json.load(f) 51 | else: 52 | query2url = {} 53 | 54 | url2info = {} 55 | for k, v in query2url.items(): 56 | for url in v: 57 | url2info[url] = {} 58 | url2info = dict(sorted(url2info.items())) 59 | 60 | candidate = [] 61 | for k, v in url2info.items(): 62 | candidate.append(k) 63 | 64 | total_sample_num = len(candidate) 65 | print(len(url2info.keys()), total_sample_num) 66 | 67 | last_url_type = '' 68 | valid_num = 0 69 | save_step = args.save_step 70 | __iter = tqdm(candidate) 71 | for idx, url in enumerate(__iter): 72 | now_url_type = url[:20] 73 | if now_url_type == last_url_type: 74 | time.sleep(2) 75 | 76 | infos = url2info[url] 77 | do_search = False 78 | if not "have_searched" in infos.keys(): 79 | do_search = True 80 | 81 | if do_search: 82 | fetched_text, fetched_text_chunked_list = fetch_web_info_by_crawl(url) 83 | url2info[url]['fetched_text'] = fetched_text 84 | url2info[url]['fetched_text_chunked_list'] = fetched_text_chunked_list 85 | url2info[url]['have_searched'] = True 86 | 87 | if len(url2info[url]['fetched_text_chunked_list']) > 0: 88 | valid_num = valid_num + 1 89 | __iter.set_description("valid num = {}/{}".format(valid_num, idx+1)) 90 | 91 | last_url_type = url[:20] 92 | if idx % save_step == 0: 93 | json.dump(url2info, open(url2info_path, 'w'), indent=4) 94 | 95 | json.dump(url2info, open(url2info_path, 'w'), indent=4) 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /scripts/iag/step2_call_search_engines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('scripts/utils') 4 | from utils_engine import search_from_web 5 | 6 | from datetime import datetime 7 | from tqdm import tqdm 8 | import json 9 | import argparse 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 16 | parser.add_argument('--save_step', type=int, default=500) 17 | 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def fetch_search_urls(query, freshness=''): 23 | 24 | engine_name = 'bing' 25 | exclude_list = ['www.msn.com', 'www.usatoday.com'] 26 | query_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 27 | 28 | urls = [] 29 | url_infos = {} 30 | try: 31 | res_items = search_from_web(query, engine=engine_name, search_type='text', exclude_list=exclude_list, freshness=freshness) 32 | urls = [x['url'] for x in res_items] 33 | url_infos = {x['url']: x for x in res_items} 34 | 35 | res_items = search_from_web(query, engine=engine_name, search_type='news', exclude_list=exclude_list, freshness=freshness) 36 | urls2 = [x['url'] for x in res_items] 37 | url_infos2 = {x['url']: x for x in res_items} 38 | 39 | for url in urls2: 40 | if not url in urls: 41 | urls.append(url) 42 | url_infos[url] = url_infos2[url] 43 | 44 | return query_time, urls, url_infos 45 | except: 46 | return query_time, urls, url_infos 47 | 48 | 49 | 50 
| 51 | def main(): 52 | 53 | args = parse_args() 54 | 55 | test_data_path = args.test_data_path 56 | 57 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 58 | query2url_path = './intermediate_files/{}/query2url.json'.format(dataset_name) 59 | url2info_path = './intermediate_files/{}/url2info.json'.format(dataset_name) 60 | 61 | folder_path = os.path.dirname(query2url_path) 62 | os.makedirs(folder_path, exist_ok=True) 63 | 64 | if os.path.exists(query2url_path): 65 | with open(query2url_path, 'r') as f: 66 | query2url = json.load(f) 67 | else: 68 | query2url = {} 69 | print("query number = {}".format(len(query2url.keys()))) 70 | 71 | if os.path.exists(url2info_path): 72 | with open(url2info_path, 'r') as f: 73 | url2info = json.load(f) 74 | else: 75 | url2info = {} 76 | 77 | save_step = args.save_step 78 | __iter = tqdm(query2url.keys()) 79 | for idx, query in enumerate(__iter): 80 | urls = query2url[query] 81 | if len(urls) == 0: 82 | _, urls, url_infos = fetch_search_urls(query) 83 | query2url[query] = urls 84 | 85 | for url in urls: 86 | if not url in url2info.keys(): 87 | url2info[url] = url_infos[url] 88 | 89 | if idx % save_step == 0: 90 | json.dump(query2url, open(query2url_path, 'w'), indent=4) 91 | json.dump(url2info, open(url2info_path, 'w'), indent=4) 92 | 93 | json.dump(query2url, open(query2url_path, 'w'), indent=4) 94 | json.dump(url2info, open(url2info_path, 'w'), indent=4) 95 | 96 | if __name__ == '__main__': 97 | main() -------------------------------------------------------------------------------- /scripts/iag.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | device=cuda:0 4 | dataset_name=UDK-VQA 5 | 6 | project_dir="" 7 | llava_dir="" 8 | 9 | ner_model_path="" 10 | llama3_ckpt_dir="" 11 | llama3_tokenizer_path="" 12 | clip_dir="" 13 | ################################## 14 | # ↑↑ Need to modify ↑↑ # 15 | ################################## 16 | 17 | 18 | 19 | 20 | ################################## 21 | # ↓↓ No Modification Required ↓↓ # 22 | ################################## 23 | test_data_path='./datasets/test/'$dataset_name'/test_raw.jsonl' 24 | test_img_dir='./datasets/test/'$dataset_name'/images' 25 | 26 | save_step_for_step1=10 27 | save_step_for_step2=500 28 | save_step_for_step3=500 29 | top_num=10 30 | 31 | prompt_for_segment_selection="How helpful is this context in answering the question based on the image? Choose the best option.\n\nContext: {}\nQuestion: {}\nOptions:\nA. 1.0\nB. 0.8\nC. 0.6\nD. 0.4\nE. 0.2\nF. 0.0\n" 32 | prompt_for_question_answering="Given context: {}.\n\nQuestion: {}\nAnswers:\nA. {}\nB. {}\nC. {}\nD. {}\nE. No correct answers\n\nAnswer with the option's letter from the given choices directly based on the context and the image." 33 | prompt_for_question_answering_nocxt="Question: {}\nAnswers:\nA. {}\nB. {}\nC. {}\nD. {}\nE. No correct answers\n\nAnswer with the option's letter from the given choices directly based on the context and the image." 
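# The '{}' placeholders in the three prompts above are filled in with the retrieved context,
# the question and the answer choices via Python's str.format in step4 and step6.
# Note: the absolute interpreter paths below (/cpfs01/user/lichuanhao/...) point to the authors'
# conda environments; replace them with the python binaries of your own environments
# (e.g. the 'searchlvlms' environment from environment.yml, and a LLaVA environment for step5).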
34 | 35 | llava_ckp_name=llava_lora_content_filter 36 | llava_ckp_dir=$project_dir'/checkpoints/'$llava_ckp_name 37 | question_filepath=$project_dir'/intermediate_files/'$dataset_name'/segment_level_items.jsonl' 38 | answer_filepath=$project_dir'/intermediate_files/'$dataset_name'/segment_score.json' 39 | 40 | /cpfs01/user/lichuanhao/miniconda3/envs/test_env/bin/python scripts/iag/step1_gen_search_queries.py \ 41 | --test_data_path $test_data_path \ 42 | --test_img_dir $test_img_dir \ 43 | --ner_model_path $ner_model_path \ 44 | --llama3_ckpt_dir $llama3_ckpt_dir \ 45 | --llama3_tokenizer_path $llama3_tokenizer_path \ 46 | --clip_dir $clip_dir \ 47 | --save_step $save_step_for_step1 \ 48 | --device $device 49 | 50 | /cpfs01/user/lichuanhao/miniconda3/envs/test_env/bin/python scripts/iag/step2_call_search_engines.py \ 51 | --test_data_path $test_data_path \ 52 | --save_step $save_step_for_step2 53 | 54 | /cpfs01/user/lichuanhao/miniconda3/envs/test_env/bin/python scripts/iag/step3_crawl_from_urls.py \ 55 | --test_data_path $test_data_path \ 56 | --save_step $save_step_for_step3 57 | 58 | /cpfs01/user/lichuanhao/miniconda3/envs/test_env/bin/python scripts/iag/step4_gen_data4hierarchical_filter.py \ 59 | --test_data_path $test_data_path \ 60 | --prompt "$prompt_for_segment_selection" 61 | 62 | # step5 63 | /cpfs01/user/lichuanhao/miniconda3/envs/llava/bin/python $llava_dir'llava/eval/model_vqa.py' \ 64 | --model-base lmsys/vicuna-7b-v1.5 \ 65 | --model-path $llava_ckp_dir \ 66 | --question-file $question_filepath \ 67 | --image-folder $test_img_dir \ 68 | --answers-file $answer_filepath 69 | 70 | /cpfs01/user/lichuanhao/miniconda3/envs/test_env/bin/python scripts/iag/step6_gen_test_searchlvlms.py \ 71 | --test_data_path $test_data_path \ 72 | --clip_dir $clip_dir \ 73 | --top_num $top_num \ 74 | --device $device \ 75 | --prompt "$prompt_for_question_answering" \ 76 | --prompt_nocxt "$prompt_for_question_answering_nocxt" 77 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: searchlvlms 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2024.11.26=h06a4308_0 8 | - ld_impl_linux-64=2.40=h12ee557_0 9 | - libffi=3.4.4=h6a678d5_1 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.15=h5eee18b_0 15 | - pip=24.2=py39h06a4308_0 16 | - python=3.9.20=he870216_1 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=75.1.0=py39h06a4308_0 19 | - sqlite=3.45.3=h5eee18b_0 20 | - tk=8.6.14=h39e8969_0 21 | - tzdata=2024b=h04d1e81_0 22 | - wheel=0.44.0=py39h06a4308_0 23 | - xz=5.4.6=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_1 25 | - pip: 26 | - annotated-types==0.7.0 27 | - anyio==4.7.0 28 | - beautifulsoup4==4.12.3 29 | - cachetools==5.5.0 30 | - certifi==2024.8.30 31 | - charset-normalizer==3.4.0 32 | - click==8.1.7 33 | - cssselect==1.2.0 34 | - distro==1.9.0 35 | - exceptiongroup==1.2.2 36 | - fairscale==0.4.13 37 | - feedfinder2==0.0.4 38 | - feedparser==6.0.11 39 | - filelock==3.16.1 40 | - fsspec==2024.2.0 41 | - ftfy==6.3.1 42 | - google-api-core==2.23.0 43 | - google-api-python-client==2.154.0 44 | - google-auth==2.36.0 45 | - google-auth-httplib2==0.2.0 46 | - googleapis-common-protos==1.66.0 47 | - h11==0.14.0 48 | - httpcore==1.0.7 49 | - httplib2==0.22.0 50 | - httpx==0.28.1 51 | - 
huggingface-hub==0.26.5 52 | - idna==3.10 53 | - jieba3k==0.35.1 54 | - jinja2==3.1.3 55 | - jiter==0.8.0 56 | - joblib==1.4.2 57 | - lxml==5.3.0 58 | - lxml-html-clean==0.4.1 59 | - markupsafe==2.1.5 60 | - mpmath==1.3.0 61 | - networkx==3.2.1 62 | - newspaper3k==0.2.8 63 | - nltk==3.9.1 64 | - numpy==1.26.3 65 | - nvidia-cublas-cu11==11.11.3.6 66 | - nvidia-cuda-cupti-cu11==11.8.87 67 | - nvidia-cuda-nvrtc-cu11==11.8.89 68 | - nvidia-cuda-runtime-cu11==11.8.89 69 | - nvidia-cudnn-cu11==8.7.0.84 70 | - nvidia-cufft-cu11==10.9.0.58 71 | - nvidia-curand-cu11==10.3.0.86 72 | - nvidia-cusolver-cu11==11.4.1.48 73 | - nvidia-cusparse-cu11==11.7.5.86 74 | - nvidia-nccl-cu11==2.19.3 75 | - nvidia-nvtx-cu11==11.8.86 76 | - openai==1.57.0 77 | - openai-clip==1.0.1 78 | - packaging==24.2 79 | - pillow==11.0.0 80 | - proto-plus==1.25.0 81 | - protobuf==5.29.1 82 | - pyasn1==0.6.1 83 | - pyasn1-modules==0.4.1 84 | - pydantic==2.10.3 85 | - pydantic-core==2.27.1 86 | - pyparsing==3.2.0 87 | - python-dateutil==2.9.0.post0 88 | - pyyaml==6.0.2 89 | - regex==2024.11.6 90 | - requests==2.32.3 91 | - requests-file==2.1.0 92 | - rsa==4.9 93 | - safetensors==0.4.5 94 | - scikit-learn==1.5.2 95 | - scipy==1.13.1 96 | - sgmllib3k==1.0.0 97 | - six==1.17.0 98 | - sniffio==1.3.1 99 | - soupsieve==2.6 100 | - sympy==1.13.1 101 | - threadpoolctl==3.5.0 102 | - tiktoken==0.8.0 103 | - tinysegmenter==0.3 104 | - tldextract==5.1.3 105 | - tokenizers==0.21.0 106 | - torch==2.2.0+cu118 107 | - torchaudio==2.2.0+cu118 108 | - torchvision==0.17.0+cu118 109 | - tqdm==4.67.1 110 | - transformers==4.47.0 111 | - triton==2.2.0 112 | - typing-extensions==4.12.2 113 | - uritemplate==4.1.1 114 | - urllib3==2.2.3 115 | - wcwidth==0.2.13 116 | prefix: /cpfs01/user/lichuanhao/miniconda3/envs/searchlvlms 117 | -------------------------------------------------------------------------------- /scripts/eval/eval_lvlms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.environ.get("vlmevalkit_dir")) 4 | from vlmeval.config import supported_VLM 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import logging 9 | logging.getLogger("transformers").setLevel(logging.ERROR) 10 | 11 | 12 | import argparse 13 | import json 14 | from tqdm import tqdm 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('--test_lvlm_list', type=str, default='llava-v1.5-7b-xtuner') 20 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 21 | parser.add_argument('--test_img_dir', type=str, default='./datasets/test/UDK-VQA/images') 22 | parser.add_argument('--prediction_path', type=str, default='./predictions/{}/searchlvlms_{}.json') 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def check_gt_ans(pred, answer_choice, answer_text): 29 | answers = [answer_choice, answer_choice[0], answer_text, answer_choice+' '+answer_text, answer_choice[0]+' '+answer_text] 30 | format_answers1 = ['The answer is {}'.format(x) for x in answers] 31 | format_answers2 = ['Answer: {}'.format(x) for x in answers] 32 | format_answers3 = ['{}'.format(x) for x in answers] 33 | 34 | answers.extend(format_answers1) 35 | answers.extend(format_answers2) 36 | answers.extend(format_answers3) 37 | answers = list(set([x.lower().strip().strip('.') for x in answers])) 38 | pred = pred.lower().strip().strip('.') 39 | 40 | if pred in answers: 41 | return 1 42 | 43 | 
answers.remove(answer_choice[0].lower().strip().strip('.')) 44 | for gt_ans in answers: 45 | min_len = min(len(gt_ans), len(pred)) 46 | trunc_pred = pred[: min_len] 47 | if gt_ans == trunc_pred: 48 | return 1 49 | 50 | format_answers1 = list(set([x.lower().strip().strip('.') for x in format_answers1])) 51 | for gt_ans in format_answers1: 52 | if gt_ans in pred: 53 | return 1 54 | return 0 55 | 56 | 57 | def main(): 58 | args = parse_args() 59 | 60 | test_data_path = args.test_data_path 61 | img_dir = args.test_img_dir 62 | 63 | test_lvlm_list = args.test_lvlm_list.split('@') 64 | candidate_list = [json.loads(q) for q in open(os.path.expanduser(test_data_path), "r")] 65 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 66 | 67 | print(len(candidate_list)) 68 | print('test_data_path = {}'.format(test_data_path)) 69 | print('prediction_path = {}'.format(args.prediction_path)) 70 | 71 | for lvlm in test_lvlm_list: 72 | model = supported_VLM[lvlm]() 73 | lvlm_predictions = {} 74 | 75 | p_path = args.prediction_path.format(lvlm) 76 | folder_path = os.path.dirname(p_path) 77 | os.makedirs(folder_path, exist_ok=True) 78 | 79 | correct_num = 0 80 | __iter = tqdm(candidate_list) 81 | for idx, sample in enumerate(__iter): 82 | __iter.set_description(lvlm) 83 | question_id = sample["question_id"] 84 | answer_choice = sample["category"] 85 | answer_text = sample["answer_text"] 86 | image_name = sample["image"] 87 | img_path = os.path.join(img_dir, image_name) 88 | 89 | pred = model.generate(img_path, sample['text']) 90 | score = check_gt_ans(pred, answer_choice, answer_text) 91 | lvlm_predictions[question_id] = {"pred": pred, "score": score} 92 | correct_num = correct_num + lvlm_predictions[question_id]["score"] 93 | 94 | print("{}: acc = {}".format(lvlm, correct_num * 100 / len(candidate_list))) 95 | json.dump(lvlm_predictions, open(p_path, 'w')) 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /scripts/iag/step4_gen_data4hierarchical_filter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | 5 | import argparse 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 11 | parser.add_argument('--prompt', type=str, default="How helpful is this context in answering the question based on the image? Choose the best option.\n\nContext: {}\nQuestion: {}\nOptions:\nA. 1.0\nB. 0.8\nC. 0.6\nD. 0.4\nE. 0.2\nF. 0.0\n") 12 | 13 | args = parser.parse_args() 14 | return args 15 | 16 | def get_id2infos(test_data_path): 17 | 18 | id2infos = {} 19 | test_list = [json.loads(q) for q in open(os.path.expanduser(test_data_path), "r")] 20 | __iter = tqdm(test_list) 21 | for idx, sample in enumerate(__iter): 22 | question = sample["text"].split("Question: ")[-1].split("\nAnswers:")[0] 23 | question_id = sample["question_id"] 24 | image_name = sample["image"] 25 | 26 | choices_str = sample["text"].split("\nAnswers:\n")[-1].split("\n\nAnswer")[0] 27 | ch1 = choices_str.split("A. ")[-1].split('\n')[0] 28 | ch2 = choices_str.split("B. ")[-1].split('\n')[0] 29 | ch3 = choices_str.split("C. ")[-1].split('\n')[0] 30 | ch4 = choices_str.split("D. 
")[-1].split('\n')[0] 31 | 32 | id2infos[question_id] = { 33 | "question": question, 34 | "image": image_name, 35 | "choices": [ch1, ch2, ch3, ch4], 36 | "gt_choice": sample["category"], 37 | "gt_ans": sample["answer_text"] 38 | } 39 | 40 | return id2infos 41 | 42 | def main(): 43 | 44 | args = parse_args() 45 | 46 | prompt = args.prompt 47 | test_data_path = args.test_data_path 48 | 49 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 50 | id2allquery_path = './intermediate_files/{}/id2allquery.json'.format(dataset_name) 51 | query2url_path = './intermediate_files/{}/query2url.json'.format(dataset_name) 52 | url2info_path = './intermediate_files/{}/url2info.json'.format(dataset_name) 53 | 54 | folder_path = os.path.dirname(id2allquery_path) 55 | os.makedirs(folder_path, exist_ok=True) 56 | 57 | with open(id2allquery_path, 'r') as f: 58 | id2allquery = json.load(f) 59 | with open(query2url_path, 'r') as f: 60 | query2url = json.load(f) 61 | with open(url2info_path, 'r') as f: 62 | url2info = json.load(f) 63 | 64 | id2infos = get_id2infos(test_data_path) 65 | sample_num = 0 66 | res_list = [] 67 | __iter = tqdm(id2allquery.keys()) 68 | for qid in __iter: 69 | queries = id2allquery[qid] 70 | 71 | for query in queries: 72 | if not query in query2url.keys(): 73 | continue 74 | urls = query2url[query] 75 | for url in urls: 76 | if not url in url2info.keys(): 77 | continue 78 | infos = url2info[url] 79 | if "fetched_text_chunked_list" in infos.keys() and len(infos["fetched_text_chunked_list"]) > 0: 80 | for cxt_idx, cxt in enumerate(infos["fetched_text_chunked_list"]): 81 | question_id = "{}-@split@-{}-@split@-{}".format(qid, url, cxt_idx) 82 | res_list.append( 83 | { 84 | "question_id": question_id, 85 | "image": id2infos[qid]["image"], 86 | "text": prompt.format(cxt, id2infos[qid]["question"]).replace('\\n', '\n'), 87 | } 88 | ) 89 | __iter.set_description("{}, totoal sample num = {}".format(dataset_name, len(res_list))) 90 | 91 | save_path = './intermediate_files/{}/segment_level_items.jsonl'.format(dataset_name) 92 | with open(save_path, 'w') as f: 93 | for tmp in res_list: 94 | json.dump(tmp, f) 95 | f.write('\n') 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SearchLVLMs 2 | **SearchLVLMs: A Plug-and-Play Framework for Augmenting Large Vision-Language Models by Searching Up-to-Date Internet Knowledge (NeurIPS 2024)** 3 | Chuanhao Li, Zhen Li, Chenchen Jing, Shuo Liu, Wenqi Shao, Yuwei Wu, Ping Luo, Yu Qiao, Kaipeng Zhang 4 | [[Homepage]](https://nevermorelch.github.io/SearchLVLMs.github.io/) [[Paper]](https://arxiv.org/pdf/2405.14554) 5 | 6 | ![Example Image](https://github.com/NeverMoreLCH/SearchLVLMs.github.io/blob/main/static/images/framework.png?raw=true) 7 | 8 |
9 | 10 | ## News 11 | - 2024.12.09: 🎉 The inference code and UDK-VQA dataset are released! 12 | - 2024.09.26: 🎉 SearchLVLMs is accepted by NeurIPS 2024! 13 | 14 |
15 | 16 | ## Install 17 | ``` 18 | conda env create -f environment.yml 19 | conda activate searchlvlms 20 | ``` 21 | 22 |
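The environment pins Python 3.9 and the CUDA 11.8 builds of PyTorch 2.2.0 (see `environment.yml`). If you want a quick sanity check that the GPU build is usable, something like the following works:
```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```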
23 | 24 | ## Prerequisites 25 | #### Llama3 26 | Install [Llama3](https://github.com/meta-llama/llama3/) and download the [checkpoint](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct). 27 | 28 | #### VLMEvalKit 29 | Install [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and download the checkpoints of LVLMs for testing. 30 | 31 | #### LLaVA-1.5 32 | Install [LLaVA-1.5](https://github.com/haotian-liu/LLaVA) and download the [pretrained model](https://huggingface.co/liuhaotian/llava-v1.5-7b/tree/main) and the [projector weights](https://huggingface.co/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/tree/main). 33 | 34 | #### NER 35 | Download the NER model via [huggingface](https://huggingface.co/dslim/bert-large-NER). 36 | 37 | #### CLIP 38 | Download the CLIP model via [huggingface](https://huggingface.co/docs/transformers/model_doc/clip). 39 | 40 |
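The paths of these downloads are later filled into the variables of `scripts/init_env_variable.sh` and `scripts/iag.sh` (see Configurations below). A sketch of the mapping, where the concrete paths are only placeholders:
```
# scripts/init_env_variable.sh
export llama3_dir="/path/to/llama3"                 # the Llama3 repo
export vlmevalkit_dir="/path/to/VLMEvalKit"         # the VLMEvalKit repo

# scripts/iag.sh
llava_dir="/path/to/LLaVA/"                         # the LLaVA-1.5 repo
ner_model_path="/path/to/bert-large-NER"            # the NER checkpoint
llama3_ckpt_dir="/path/to/Meta-Llama-3-8B-Instruct/"
llama3_tokenizer_path="/path/to/Meta-Llama-3-8B-Instruct/tokenizer.model"
clip_dir="/path/to/clip"                            # download dir for the CLIP ViT-B/32 weights
```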
41 | 42 | ## UDK-VQA Dataset and Checkpoint 43 | Download the UDK-VQA dataset and the checkpoints of our filters from: 44 | [[OneDrive](https://1drv.ms/f/c/da3b3d08ef25cb06/EoDXINrNkOZMoREylIMlBXIBchbhOGpUaSarJAzjFquieg?e=EhNjAj)] or 45 | [[Baidu NetDisk (password: DSPS)](https://pan.baidu.com/s/1XCJq9mSItZAd21fY0Xz1zQ)] 46 | 47 | Unzip the downloaded files and make sure the directory structure looks like this: 48 | ``` 49 | SearchLVLMs 50 | ----checkpoints 51 | --------llava_lora_content_filter 52 | --------llava_lora_website_filter 53 | ----datasets 54 | --------test 55 | --------train 56 | ... 57 | ``` 58 | 59 |
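Each line of `test_raw.jsonl` is one test sample in JSON form. The evaluation and filtering scripts read five fields from it: `question_id`, `image` (a file name under the `images` folder), `text` (the multiple-choice question prompt), `category` (the ground-truth choice) and `answer_text` (the ground-truth answer string). A rough, made-up example of one line (the actual questions and answers come from the downloaded dataset):
```
{"question_id": "0", "image": "example.png", "text": "Question: <question>\nAnswers:\nA. <choice A>\nB. <choice B>\nC. <choice C>\nD. <choice D>\n\nAnswer with the option's letter ...", "category": "A", "answer_text": "<choice A>"}
```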
60 | 61 | ## Configurations 62 | Configure the variables in `scripts/init_env_variable.sh`, `scripts/iag.sh` and `scripts/eval.sh`. 63 | - `scripts/init_env_variable.sh` 64 | ``` 65 | # path to the Llama3 repo, used for generating search queries for the question via Llama3. 66 | llama3_dir="" 67 | 68 | # path to the VLMEvalKit repo, used for running eval.sh 69 | vlmevalkit_dir="" 70 | 71 | # used for calling GPT to generate search queries for the question. 72 | OPENAI_API_KEY="" 73 | OPENAI_ENDPOINT="" 74 | 75 | # the Google search engine keys are optional, as we mainly use the Bing search engine. 76 | google_api_key="" 77 | google_text_cse_id="" 78 | google_image_cse_id="" 79 | 80 | # the img api key is optional, as it is only used for generating samples, which is not released yet. 81 | bing_text_api_key="" 82 | bing_img_api_key="" 83 | bing_visual_api_key="" 84 | ``` 85 | The variables in `scripts/iag.sh` and `scripts/eval.sh` should be easy to understand from their names. 86 | 87 |
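These variables are read at runtime by the Python scripts rather than passed as command-line arguments, which is why `init_env_variable.sh` must be sourced in the same shell before running the pipeline. For example (excerpted from `scripts/iag/step1_gen_search_queries.py`, `scripts/eval/eval_lvlms.py` and `scripts/utils/utils.py`):
```
import os, sys

# the Llama3 and VLMEvalKit repos are put on the import path
sys.path.append(os.environ.get("llama3_dir"))        # step1_gen_search_queries.py
sys.path.append(os.environ.get("vlmevalkit_dir"))    # eval_lvlms.py

# the OpenAI client used for query generation is built from the key/endpoint
from openai import OpenAI
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_ENDPOINT")
)
```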
88 | 89 | ## Evaluation 90 | You can run the following scripts to evaluate LVLMs (or LVLMs+SearchLVLMs): 91 | ``` 92 | cd SearchLVLMs 93 | 94 | # Activate environment variables 95 | source scripts/init_env_variable.sh 96 | 97 | # Run SearchLVLMs to find the best context for each sample in the test set. 98 | sh scripts/iag.sh 99 | 100 | # Evaluate the accuracy of LVLMs (or LVLMs+SearchLVLMs) on the test set. 101 | sh scripts/eval.sh 102 | 103 | # Deactivate environment variables 104 | source scripts/unset_env_variable.bash 105 | ``` 106 | 107 |
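`eval.sh` prints the accuracy of each model and writes its raw predictions to `./predictions/<dataset_name>/<model>_<version>.json`. Each entry maps a question id to the model output and a 0/1 correctness score, roughly of the form (the values here are illustrative):
```
{"<question_id>": {"pred": "A", "score": 1}, ...}
```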
108 | 109 | ## Citation 110 | If any part of our paper and code is helpful to your work, please generously cite with: 111 | ``` 112 | @inproceedings{li2024searchlvlms, 113 | title={SearchLVLMs: A Plug-and-Play Framework for Augmenting Large Vision-Language Models by Searching Up-to-Date Internet Knowledge}, 114 | author={Li, Chuanhao and Li, Zhen and Jing, Chenchen and Liu, Shuo and Shao, Wenqi and Wu, Yuwei and Luo, Ping and Qiao, Yu and Zhang, Kaipeng}, 115 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 116 | year={2024} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu=0 4 | 5 | # dataset_name in [UDK-VQA, UDK-VQA-20240905] 6 | dataset_name=UDK-VQA 7 | 8 | # for version in [raw, gt_segment] 9 | # version=raw 10 | # test_data_path='./datasets/test/'$dataset_name'/test_'$version'.jsonl' 11 | 12 | # for version in [searchlvlms] 13 | version=searchlvlms 14 | test_data_path='./intermediate_files/'$dataset_name'/test_'$version'.jsonl' 15 | ################################## 16 | # ↑↑ Need to modify ↑↑ # 17 | ################################## 18 | 19 | 20 | 21 | 22 | ################################## 23 | # ↓↓ No Modification Required ↓↓ # 24 | ################################## 25 | test_img_dir='./datasets/test/'$dataset_name'/images' 26 | prediction_path='./predictions/'$dataset_name'/{}_'$version'.json' 27 | 28 | test_lvlm_list_llava15=llava_v1.5_7b 29 | test_lvlm_list_llavanext=llava_next_mistral_7b 30 | test_lvlm_list_monkey=monkey 31 | test_lvlm_list_cogvlm_chat=cogvlm-chat 32 | test_lvlm_list_qwen_chat=qwen_chat 33 | test_lvlm_list_llava_v15_7b_xtuner=llava-v1.5-7b-xtuner 34 | test_lvlm_list_MiniCPM_V2=MiniCPM-V2 35 | test_lvlm_list_XComposer2=XComposer2 36 | test_lvlm_list_MMAlaya=MMAlaya 37 | test_lvlm_list_VisualGLM_6b=VisualGLM_6b 38 | test_lvlm_list_mPLUG_Owl2=mPLUG-Owl2 39 | test_lvlm_list_internvl=InternVL-Chat-V1-5 40 | 41 | 42 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/llava_next/bin/python scripts/eval/eval_lvlms.py \ 43 | --test_lvlm_list $test_lvlm_list_llavanext \ 44 | --test_data_path $test_data_path \ 45 | --test_img_dir $test_img_dir \ 46 | --prediction_path $prediction_path 47 | 48 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 49 | --test_lvlm_list $test_lvlm_list_internvl \ 50 | --test_data_path $test_data_path \ 51 | --test_img_dir $test_img_dir \ 52 | --prediction_path $prediction_path 53 | 54 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 55 | --test_lvlm_list $test_lvlm_list_monkey \ 56 | --test_data_path $test_data_path \ 57 | --test_img_dir $test_img_dir \ 58 | --prediction_path $prediction_path 59 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 60 | --test_lvlm_list $test_lvlm_list_cogvlm_chat \ 61 | --test_data_path $test_data_path \ 62 | --test_img_dir $test_img_dir \ 63 | --prediction_path $prediction_path 64 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 65 | --test_lvlm_list $test_lvlm_list_qwen_chat \ 66 | --test_data_path $test_data_path \ 67 | --test_img_dir $test_img_dir \ 68 | --prediction_path $prediction_path 69 | CUDA_VISIBLE_DEVICES=$gpu 
/cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 70 | --test_lvlm_list $test_lvlm_list_llava_v15_7b_xtuner \ 71 | --test_data_path $test_data_path \ 72 | --test_img_dir $test_img_dir \ 73 | --prediction_path $prediction_path 74 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 75 | --test_lvlm_list $test_lvlm_list_MiniCPM_V2 \ 76 | --test_data_path $test_data_path \ 77 | --test_img_dir $test_img_dir \ 78 | --prediction_path $prediction_path 79 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39/bin/python scripts/eval/eval_lvlms.py \ 80 | --test_lvlm_list $test_lvlm_list_XComposer2 \ 81 | --test_data_path $test_data_path \ 82 | --test_img_dir $test_img_dir \ 83 | --prediction_path $prediction_path 84 | 85 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39_trans433/bin/python scripts/eval/eval_lvlms.py \ 86 | --test_lvlm_list $test_lvlm_list_MMAlaya \ 87 | --test_data_path $test_data_path \ 88 | --test_img_dir $test_img_dir \ 89 | --prediction_path $prediction_path 90 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/py39_trans433/bin/python scripts/eval/eval_lvlms.py \ 91 | --test_lvlm_list $test_lvlm_list_VisualGLM_6b \ 92 | --test_data_path $test_data_path \ 93 | --test_img_dir $test_img_dir \ 94 | --prediction_path $prediction_path 95 | 96 | CUDA_VISIBLE_DEVICES=$gpu /cpfs01/user/lichuanhao/miniconda3/envs/mplug_owl2/bin/python scripts/eval/eval_lvlms.py \ 97 | --test_lvlm_list $test_lvlm_list_mPLUG_Owl2 \ 98 | --test_data_path $test_data_path \ 99 | --test_img_dir $test_img_dir \ 100 | --prediction_path $prediction_path -------------------------------------------------------------------------------- /scripts/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import hashlib 4 | import requests 5 | from openai import OpenAI 6 | from newspaper import Article 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | self.file = open(fpath, 'w') 14 | 15 | def __del__(self): 16 | self.close() 17 | 18 | def __enter__(self): 19 | pass 20 | 21 | def __exit__(self, *args): 22 | self.close() 23 | 24 | def write(self, msg): 25 | self.console.write(msg) 26 | if self.file is not None: 27 | self.file.write(msg) 28 | 29 | def flush(self): 30 | self.console.flush() 31 | if self.file is not None: 32 | self.file.flush() 33 | os.fsync(self.file.fileno()) 34 | 35 | def close(self): 36 | self.console.close() 37 | if self.file is not None: 38 | self.file.close() 39 | 40 | 41 | def hash_url(url: str) -> str: 42 | sha256_hash = hashlib.sha256() 43 | sha256_hash.update(url.encode('utf-8')) 44 | hex_dig = sha256_hash.hexdigest() 45 | return hex_dig 46 | 47 | def get_hash_name_suffix(url): 48 | name = hash_url(url) 49 | suffix = name.split('.')[-1] 50 | if 'jpg' in suffix: 51 | suffix = 'jpg' 52 | elif 'jpeg' in suffix: 53 | suffix = 'jpeg' 54 | elif 'png' in suffix: 55 | suffix = 'png' 56 | elif 'gif' in suffix: 57 | suffix = 'gif' 58 | elif 'webp' in suffix: 59 | suffix = 'webp' 60 | elif 'svg' in suffix: 61 | suffix = 'svg' 62 | else: 63 | suffix = 'jpg' 64 | return name, suffix 65 | 66 | def text_chunk(text, token_num_per_chunk=800): 67 | text = text.replace('\n\n', ' ').replace('\n', ' ') 68 | text_list = text.split('.') 69 | 70 | ret = [] 71 | tmp_len = 0 72 | tmp_text_list = [] 73 
| for para in text_list: 74 | if len(para) == 0: 75 | continue 76 | now_len = len(para.split(' ')) 77 | if now_len > token_num_per_chunk: 78 | continue 79 | 80 | if tmp_len + now_len <= token_num_per_chunk: 81 | tmp_str = para.strip() + '.' 82 | tmp_text_list.append(tmp_str) 83 | tmp_len = tmp_len + now_len 84 | else: 85 | ret.append(' '.join(tmp_text_list)) 86 | tmp_len = 0 87 | tmp_text_list = [] 88 | 89 | if tmp_len > 0: 90 | ret.append(' '.join(tmp_text_list)) 91 | 92 | return ret 93 | 94 | def check_fetched_text(text): 95 | good_flag = True 96 | if len(text) < 200: 97 | good_flag = False 98 | if len(text.split(' ')) <= 40: 99 | good_flag = False 100 | if 'Error' in text[:60]: 101 | good_flag = False 102 | if 'cookies' in text[:100]: 103 | good_flag = False 104 | if 'Cookies' in text[:100]: 105 | good_flag = False 106 | if text.startswith('Please upgrade your browser'): 107 | good_flag = False 108 | if 'Please make sure your browser supports JavaScript and cookies' in text: 109 | good_flag = False 110 | return good_flag 111 | 112 | def parse_ner_res(question, ner_res): 113 | ret_list = [] 114 | save_start_idx = 0 115 | save_end_idx = 0 116 | last_type = "" 117 | for info in ner_res: 118 | ent_type = info["entity"] 119 | start_idx = info["start"] 120 | end_idx = info["end"] 121 | 122 | if "-" in ent_type: 123 | prefix, suffix = ent_type.split("-") 124 | else: 125 | prefix = "O" 126 | suffix = "O" 127 | 128 | if prefix in ["B", "O"]: 129 | if not last_type == "": 130 | ret_list.append({"entity":question[save_start_idx:save_end_idx], "type":last_type}) 131 | save_start_idx = start_idx 132 | save_end_idx = end_idx 133 | if prefix == "I": 134 | save_end_idx = end_idx 135 | 136 | last_type = suffix 137 | 138 | if not last_type == "": 139 | ret_list.append({"entity":question[save_start_idx:save_end_idx], "type":last_type}) 140 | 141 | return ret_list 142 | 143 | 144 | def gpt_text_only(text, pair_num=1, temperature=0.1, model='gpt-3.5-turbo', OPENAI_API_KEY='', OPENAI_ENDPOINT=''): 145 | client = OpenAI( 146 | api_key=os.environ.get("OPENAI_API_KEY"), 147 | base_url=os.environ.get("OPENAI_ENDPOINT") 148 | ) 149 | 150 | if pair_num > 1: 151 | response = client.chat.completions.create( 152 | messages=[ 153 | { 154 | "role": "user", 155 | "content": text, 156 | } 157 | ], 158 | model=model, 159 | temperature = temperature, 160 | n = pair_num 161 | ) 162 | else: 163 | response = client.chat.completions.create( 164 | messages=[ 165 | { 166 | "role": "user", 167 | "content": text, 168 | } 169 | ], 170 | model=model, 171 | temperature = temperature 172 | ) 173 | res = [str(c.message.content) for c in response.choices] 174 | return res 175 | 176 | 177 | def fetch_by_newspaper(url): 178 | headers = { 179 | 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 180 | 'Accept-Language': 'en-us;q=0.5,en;q=0.3', 181 | 'Cache-Control': 'max-age=0', 182 | 'Connection': 'keep-alive', 183 | 'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:22.0) Gecko/20100101 Firefox/22.0' 184 | } 185 | try: 186 | response = requests.get(url, headers=headers, timeout=(3, 6)) # headers=headers 187 | html_content = response.text 188 | article = Article(url='') 189 | article.set_html(html_content) 190 | article.parse() 191 | except Exception as err: 192 | return {'fetched_title': 'Error', 'fetched_text': 'Error Fetch! ' + str(err)} 193 | try: 194 | fetched_title = article.title 195 | except Exception as err: 196 | fetched_title = 'Error Title! 
' + str(err) 197 | 198 | try: 199 | fetched_text = article.text 200 | except Exception as err: 201 | fetched_text = 'Error Fetch! ' + str(err) 202 | 203 | ret_dict = {'fetched_title': fetched_title, 'fetched_text': fetched_text} 204 | return ret_dict 205 | 206 | -------------------------------------------------------------------------------- /scripts/iag/step6_gen_test_searchlvlms.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import random 5 | import clip 6 | import torch 7 | import numpy as np 8 | from sklearn.cluster import KMeans 9 | 10 | import argparse 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 16 | parser.add_argument('--clip_dir', type=str, default='/cpfs01/user/lichuanhao/huggingface_cache/clip') 17 | 18 | parser.add_argument('--top_num', type=int, default=10) 19 | parser.add_argument('--device', type=str, default='gpu') 20 | 21 | parser.add_argument('--prompt', type=str, default="Given context: {}.\n\nQuestion: {}\nAnswers:\nA. {}\nB. {}\nC. {}\nD. {}\nE. No correct answers\n\nAnswer with the option's letter from the given choices directly based on the context and the image.") 22 | parser.add_argument('--prompt_nocxt', type=str, default="Question: {}\nAnswers:\nA. {}\nB. {}\nC. {}\nD. {}\nE. No correct answers\n\nAnswer with the option's letter from the given choices directly based on the context and the image.") 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def get_id2infos(test_data_path): 29 | 30 | id2infos = {} 31 | test_list = [json.loads(q) for q in open(os.path.expanduser(test_data_path), "r")] 32 | __iter = tqdm(test_list) 33 | for idx, sample in enumerate(__iter): 34 | question = sample["text"].split("Question: ")[-1].split("\nAnswers:")[0] 35 | question_id = sample["question_id"] 36 | image_name = sample["image"] 37 | 38 | choices_str = sample["text"].split("\nAnswers:\n")[-1].split("\n\nAnswer")[0] 39 | ch1 = choices_str.split("A. ")[-1].split('\n')[0] 40 | ch2 = choices_str.split("B. ")[-1].split('\n')[0] 41 | ch3 = choices_str.split("C. ")[-1].split('\n')[0] 42 | ch4 = choices_str.split("D. 
")[-1].split('\n')[0] 43 | 44 | id2infos[question_id] = { 45 | "question": question, 46 | "image": image_name, 47 | "choices": [ch1, ch2, ch3, ch4], 48 | "gt_choice": sample["category"], 49 | "gt_ans": sample["answer_text"] 50 | } 51 | 52 | return id2infos 53 | 54 | 55 | def get_id2predictions(read_path): 56 | 57 | id2predictions = {} 58 | tmp_pred = [json.loads(q) for q in open(os.path.expanduser(read_path), "r")] 59 | 60 | __iter = tqdm(tmp_pred) 61 | for infos in __iter: 62 | question_id = infos["question_id"] 63 | real_qid = question_id.split("-@split@-")[0] 64 | pred = infos["text"] 65 | cxt = infos["prompt"].split("\n\nContext: ")[-1].split("\nQuestion: ")[0] 66 | 67 | if real_qid in id2predictions.keys(): 68 | id2predictions[real_qid].append({ 69 | "question_id": question_id, 70 | "context": cxt, 71 | "pred": pred 72 | }) 73 | else: 74 | id2predictions[real_qid]= [{ 75 | "question_id": question_id, 76 | "context": cxt, 77 | "pred": pred 78 | }] 79 | 80 | return id2predictions 81 | 82 | 83 | def _cluster(model, processor, device, texts, cluster_num): 84 | cluster_num = min(cluster_num, len(texts)) 85 | if cluster_num == 0: 86 | return [] 87 | 88 | ret_list = [] 89 | text_features = [] 90 | 91 | model = model.to(device) 92 | model = model.float() 93 | 94 | for text in texts: 95 | inputs = clip.tokenize(text, truncate=True).to(device) 96 | # print(f"Inputs device: {inputs.device}, Model device: {device}") 97 | with torch.no_grad(): 98 | features = model.encode_text(inputs) 99 | text_features.append(features.cpu().numpy()) 100 | 101 | text_features = np.concatenate(text_features, axis=0) 102 | 103 | n_clusters = cluster_num 104 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(text_features) 105 | 106 | clustered_texts = {} 107 | for i, label in enumerate(kmeans.labels_): 108 | if label not in clustered_texts: 109 | clustered_texts[label] = [] 110 | clustered_texts[label].append((texts[i], np.linalg.norm(text_features[i] - kmeans.cluster_centers_[label]))) 111 | 112 | for i, texts_distances in clustered_texts.items(): 113 | closest_text, closest_distance = min(texts_distances, key=lambda x: x[1]) 114 | ret_list.append(closest_text) 115 | 116 | ret_list.append(', '.join(ret_list)) 117 | 118 | return ret_list 119 | 120 | 121 | def main(): 122 | 123 | args = parse_args() 124 | 125 | test_data_path = args.test_data_path 126 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 127 | segment_score_path = './intermediate_files/{}/segment_score.json'.format(dataset_name) 128 | 129 | id2infos = get_id2infos(test_data_path) 130 | id2predictions = get_id2predictions(segment_score_path) 131 | 132 | device = torch.device(args.device) 133 | model, processor = clip.load("ViT-B/32", download_root=args.clip_dir, device=device) 134 | 135 | res = [] 136 | top_num = args.top_num 137 | __iter = tqdm(id2infos.keys()) 138 | for idx, qid in enumerate(__iter): 139 | 140 | question = id2infos[qid]["question"] 141 | choices = id2infos[qid]["choices"] 142 | if qid in id2predictions.keys(): 143 | context_list = id2predictions[qid] 144 | select_num = min(len(context_list), top_num*2) 145 | context_list = list([x["context"] for x in context_list])[:select_num] 146 | select_list = random.sample(context_list, select_num) 147 | res_context = _cluster(model, processor, device, select_list, top_num) 148 | text = args.prompt.format(res_context, question, choices[0], choices[1], choices[2], choices[3]).replace('\\n', '\n') 149 | else: 150 | text = args.prompt_nocxt.format(question, choices[0], 
choices[1], choices[2], choices[3]).replace('\\n', '\n') 151 | 152 | res.append({ 153 | "question_id": qid, 154 | "image": id2infos[qid]["image"], 155 | "text": text, 156 | "category": "{}".format(id2infos[qid]["gt_choice"]), 157 | "answer_text": id2infos[qid]["gt_ans"] 158 | }) 159 | 160 | save_path = './intermediate_files/{}/test_searchlvlms.jsonl'.format(dataset_name) 161 | with open(save_path, 'w') as f: 162 | for tmp in res: 163 | json.dump(tmp, f) 164 | f.write('\n') 165 | 166 | if __name__ == '__main__': 167 | main() -------------------------------------------------------------------------------- /scripts/iag/step1_gen_search_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.environ.get("llama3_dir")) 4 | # settings for llama3 5 | os.environ['MASTER_ADDR'] = 'localhost' 6 | os.environ['MASTER_PORT'] = '12339' 7 | os.environ["RANK"] = "0" 8 | os.environ["WORLD_SIZE"] = "1" 9 | 10 | 11 | sys.path.append('scripts/utils') 12 | from utils import parse_ner_res, gpt_text_only 13 | from utils_engine import search_from_web 14 | 15 | 16 | from typing import List, Optional 17 | from llama import Dialog, Llama 18 | from transformers import AutoTokenizer, AutoModelForTokenClassification 19 | from transformers import pipeline 20 | from sklearn.cluster import KMeans 21 | 22 | import torch 23 | import clip 24 | import json 25 | import time 26 | import numpy as np 27 | from tqdm import tqdm 28 | import argparse 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument('--test_data_path', type=str, default='./datasets/test/UDK-VQA/test_raw.jsonl') 35 | parser.add_argument('--test_img_dir', type=str, default='./datasets/test/UDK-VQA/images') 36 | 37 | parser.add_argument('--ner_model_path', type=str, default='/cpfs01/user/lichuanhao/huggingface_cache/hub/models--dslim--bert-large-NER/snapshots/13e784dccceca07aee7a7aab4ad487c605975423') 38 | parser.add_argument('--llama3_ckpt_dir', type=str, default='/cpfs01/user/lichuanhao/huggingface_cache/llama3/Meta-Llama-3-8B-Instruct/') 39 | parser.add_argument('--llama3_tokenizer_path', type=str, default='/cpfs01/user/lichuanhao/huggingface_cache/llama3/Meta-Llama-3-8B-Instruct/tokenizer.model') 40 | parser.add_argument('--clip_dir', type=str, default='/cpfs01/user/lichuanhao/huggingface_cache/clip') 41 | 42 | parser.add_argument('--save_step', type=int, default=10) 43 | parser.add_argument('--device', type=str, default='cuda:0') 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | 50 | def _cluster(model, processor, device, texts, cluster_num): 51 | cluster_num = min(cluster_num, len(texts)) 52 | if cluster_num == 0: 53 | return [] 54 | 55 | ret_list = [] 56 | text_features = [] 57 | 58 | model = model.to(device) 59 | model = model.float() 60 | 61 | for text in texts: 62 | inputs = clip.tokenize(text).to(device) 63 | # print(f"Inputs device: {inputs.device}, Model device: {device}") 64 | with torch.no_grad(): 65 | features = model.encode_text(inputs) 66 | text_features.append(features.cpu().numpy()) 67 | 68 | text_features = np.concatenate(text_features, axis=0) 69 | 70 | n_clusters = cluster_num 71 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(text_features) 72 | 73 | clustered_texts = {} 74 | for i, label in enumerate(kmeans.labels_): 75 | if label not in clustered_texts: 76 | clustered_texts[label] = [] 77 | clustered_texts[label].append((texts[i], np.linalg.norm(text_features[i] - 
kmeans.cluster_centers_[label]))) 78 | 79 | for i, texts_distances in clustered_texts.items(): 80 | closest_text, closest_distance = min(texts_distances, key=lambda x: x[1]) 81 | ret_list.append(closest_text) 82 | 83 | ret_list.append(', '.join(ret_list)) 84 | 85 | return ret_list 86 | 87 | def gen_query_by_gpt(question): 88 | temperature = 0.3 89 | pair_num = 1 90 | prompt = "'Question: {}\n\nDo not try to answer the question, just print the most informative no more than three entities in the question. Put them on one line and separate them with comm." 91 | gen_lvlm = "gpt-3.5-turbo" 92 | 93 | text = prompt.format(question) 94 | res_list = gpt_text_only(text, temperature=temperature, pair_num=pair_num, model=gen_lvlm) 95 | # print('res_list = {}'.format(res_list)) 96 | 97 | return res_list[0].split(', ') 98 | 99 | def gen_query_by_llama3(question, generator): 100 | 101 | llama3_prompt = "Question: {}\n\nDo not try to answer the question, just print the most informative no more than three entities in the question. Put them on one line and separate them with comm." 102 | text = llama3_prompt.format(question) 103 | dialogs: List[Dialog] = [ 104 | [{"role": "user", "content": text}], 105 | ] 106 | results = generator.chat_completion( 107 | dialogs, 108 | max_gen_len=512, 109 | temperature=0.6, 110 | top_p=0.9, 111 | ) 112 | 113 | res = [x.strip() for x in results[0]["generation"]["content"].split(',')] 114 | if "" in res: 115 | res.remove("") 116 | return res 117 | 118 | 119 | def extract_q_query(question, nlp, generator): 120 | 121 | queries = set() 122 | now_time = 0 123 | while(True): 124 | now_time = now_time + 1 125 | try: 126 | ner_set = set([x['entity'].lower() for x in parse_ner_res(question, nlp(question))]) 127 | gpt_set = set([x.lower() for x in gen_query_by_gpt(question)]) 128 | lma_set = set([x.lower() for x in gen_query_by_llama3(question, generator)]) 129 | 130 | queries = queries | ner_set 131 | queries = queries | gpt_set 132 | queries = queries | lma_set 133 | break 134 | 135 | except: 136 | if now_time >= 3: 137 | break 138 | time.sleep(1) 139 | return queries 140 | 141 | def extract_v_query(img_path, nlp): 142 | 143 | now_time = 0 144 | queries = set() 145 | while True: 146 | now_time = now_time + 1 147 | try: 148 | res = search_from_web(img_path, engine='bing', search_type='visual', nlp=nlp) 149 | queries = set([res['search_str']]) 150 | break 151 | except: 152 | if now_time >= 3: 153 | break 154 | time.sleep(1) 155 | 156 | return queries 157 | 158 | 159 | 160 | 161 | def main(): 162 | 163 | args = parse_args() 164 | 165 | test_data_path = args.test_data_path 166 | test_img_dir = args.test_img_dir 167 | test_data = [json.loads(q) for q in open(os.path.expanduser(test_data_path), "r")] 168 | 169 | # load NER, llama3 170 | ner_model_path = args.ner_model_path 171 | llama3_ckpt_dir = args.llama3_ckpt_dir 172 | llama3_tokenizer_path = args.llama3_tokenizer_path 173 | 174 | ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_path) 175 | ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_path) 176 | nlp = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer) 177 | llama3_generator = Llama.build( 178 | ckpt_dir=llama3_ckpt_dir, 179 | tokenizer_path=llama3_tokenizer_path, 180 | max_seq_len=512, 181 | max_batch_size=6, 182 | ) 183 | 184 | device = torch.device(args.device) 185 | clip_model, clip_processor = clip.load("ViT-B/32", download_root=args.clip_dir, device=device) 186 | 187 | dataset_name = os.path.dirname(test_data_path).split('/')[-1] 188 | 
id2query_path = './intermediate_files/{}/id2query.json'.format(dataset_name) 189 | id2clusterquery_path = './intermediate_files/{}/id2clusterquery.json'.format(dataset_name) 190 | id2allquery_path = './intermediate_files/{}/id2allquery.json'.format(dataset_name) 191 | query2url_path = './intermediate_files/{}/query2url.json'.format(dataset_name) 192 | 193 | folder_path = os.path.dirname(id2query_path) 194 | os.makedirs(folder_path, exist_ok=True) 195 | 196 | if os.path.exists(id2clusterquery_path): 197 | with open(id2clusterquery_path, 'r') as f: 198 | id2clusterquery = json.load(f) 199 | else: 200 | id2clusterquery = {} 201 | 202 | if os.path.exists(id2allquery_path): 203 | with open(id2allquery_path, 'r') as f: 204 | id2allquery = json.load(f) 205 | else: 206 | id2allquery = {} 207 | 208 | if os.path.exists(id2query_path): 209 | with open(id2query_path, 'r') as f: 210 | id2query = json.load(f) 211 | else: 212 | id2query = {} 213 | 214 | if os.path.exists(query2url_path): 215 | with open(query2url_path, 'r') as f: 216 | query2url = json.load(f) 217 | else: 218 | query2url = {} 219 | 220 | save_step = args.save_step 221 | __iter = tqdm(test_data) 222 | for idx, sample in enumerate(__iter): 223 | __iter.set_description('processed query num = {}'.format(len(query2url.keys()))) 224 | 225 | question_id = sample["question_id"] 226 | question = sample["text"].split('Question: ')[-1].split('\nAnswers:')[0] 227 | image_name = sample["image"] 228 | img_path = os.path.join(test_img_dir, image_name) 229 | 230 | if not question_id in id2query.keys(): 231 | q_queries = extract_q_query(question, nlp, llama3_generator) 232 | else: 233 | q_queries = set(id2query[question_id]) 234 | if len(q_queries) == 0: 235 | q_queries = extract_q_query(question, nlp, llama3_generator) 236 | 237 | if not image_name in id2query.keys(): 238 | v_queries = extract_v_query(img_path, nlp) 239 | else: 240 | v_queries = set(id2query[image_name]) 241 | 242 | all_queries = q_queries | v_queries 243 | tmp_queries = [] 244 | for x in all_queries: 245 | if len(x) > 5: 246 | tmp_queries.append(x) 247 | clustered_queries = set(_cluster(clip_model, clip_processor, device, tmp_queries, 3)) 248 | all_queries = all_queries | clustered_queries 249 | 250 | id2query[question_id] = list(q_queries) 251 | id2query[image_name] = list(v_queries) 252 | id2clusterquery[question_id] = list(clustered_queries) 253 | id2allquery[question_id] = list(all_queries) 254 | 255 | for query in clustered_queries: 256 | if not query in query2url.keys(): 257 | query2url[query] = [] 258 | 259 | if idx % save_step == 0: 260 | json.dump(id2query, open(id2query_path, 'w'), indent=4) 261 | json.dump(id2allquery, open(id2allquery_path, 'w'), indent=4) 262 | json.dump(id2clusterquery, open(id2clusterquery_path, 'w'), indent=4) 263 | json.dump(query2url, open(query2url_path, 'w'), indent=4) 264 | 265 | json.dump(id2query, open(id2query_path, 'w'), indent=4) 266 | json.dump(id2allquery, open(id2allquery_path, 'w'), indent=4) 267 | json.dump(id2clusterquery, open(id2clusterquery_path, 'w'), indent=4) 268 | json.dump(query2url, open(query2url_path, 'w'), indent=4) 269 | 270 | 271 | if __name__ == '__main__': 272 | main() -------------------------------------------------------------------------------- /scripts/utils/utils_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from googleapiclient.discovery import build 3 | import requests 4 | from utils import get_hash_name_suffix, parse_ner_res 5 | import urllib.parse 6 
| import json 7 | import string 8 | 9 | 10 | def longest_common_substring_3(s1, s2, s3): 11 | len1, len2, len3 = len(s1), len(s2), len(s3) 12 | dp = [[[0] * (len3+1) for _ in range(len2+1)] for __ in range(len1+1)] 13 | 14 | max_len = 0 15 | end_index_s1 = -1 16 | 17 | for i in range(1, len1+1): 18 | for j in range(1, len2+1): 19 | for k in range(1, len3+1): 20 | if s1[i-1] == s2[j-1] == s3[k-1]: 21 | dp[i][j][k] = dp[i-1][j-1][k-1] + 1 22 | if dp[i][j][k] > max_len: 23 | max_len = dp[i][j][k] 24 | end_index_s1 = i - 1 25 | else: 26 | dp[i][j][k] = 0 27 | 28 | if max_len > 0: 29 | return s1[end_index_s1 - max_len + 1:end_index_s1 + 1].strip() 30 | else: 31 | return "" 32 | 33 | def longest_common_substring_2(s1, s2): 34 | len1, len2 = len(s1), len(s2) 35 | dp = [[0] * (len2 + 1) for _ in range(len1 + 1)] 36 | 37 | max_len = 0 38 | end_index_s1 = -1 39 | 40 | for i in range(1, len1 + 1): 41 | for j in range(1, len2 + 1): 42 | if s1[i - 1] == s2[j - 1]: 43 | dp[i][j] = dp[i - 1][j - 1] + 1 44 | if dp[i][j] > max_len: 45 | max_len = dp[i][j] 46 | end_index_s1 = i - 1 47 | else: 48 | dp[i][j] = 0 49 | 50 | if max_len > 0: 51 | return s1[end_index_s1 - max_len + 1:end_index_s1 + 1].strip() 52 | else: 53 | return "" 54 | 55 | 56 | def bing_visual_search(img_path, SUBSCRIPTION_KEY, site='', nlp=None): 57 | BASE_URI = 'https://api.bing.microsoft.com/v7.0/images/visualsearch' 58 | 59 | HEADERS = {'Ocp-Apim-Subscription-Key': SUBSCRIPTION_KEY, "Accept-Language": 'en'} 60 | file = {'image' : ('myfile', open(img_path, 'rb'))} 61 | params = {"mkt": "en-US"} 62 | 63 | try: 64 | try: 65 | response = requests.post(BASE_URI, headers=HEADERS, params=params, files=file) 66 | except Exception as err: 67 | return str(err) 68 | response.raise_for_status() 69 | response_json = response.json() 70 | 71 | brq_set = set() 72 | entity_set = set() 73 | related_list = [] 74 | url_title_set = set() 75 | for tag in response_json['tags']: 76 | for act in tag["actions"]: 77 | if "actionType" in act.keys() and act["actionType"] == "BestRepresentativeQuery" and "displayName" in act.keys(): 78 | brq_set.add(act["displayName"]) 79 | if "actionType" in act.keys() and act["actionType"] == "Entity" and "displayName" in act.keys(): 80 | entity_set.add(act["displayName"]) 81 | if "actionType" in act.keys() and act["actionType"] == "RelatedSearches": 82 | if "data" in act.keys() and "value" in act["data"].keys(): 83 | data = act["data"]["value"] 84 | for result in data: 85 | if not result['text'] in related_list: 86 | related_list.append(result['text']) 87 | if "actionType" in act.keys() and act["actionType"] in ["PagesIncluding", "VisualSearch"]: 88 | if "data" in act.keys() and "value" in act["data"].keys(): 89 | data = act["data"]["value"] 90 | for result in data: 91 | title = result["name"] if "name" in result.keys() else "" 92 | url_title_set.add(title) 93 | 94 | brq_list = list(brq_set) 95 | entity_list = list(entity_set) 96 | url_title_list = list(url_title_set) 97 | 98 | search_str = '' 99 | if len(brq_list) > 0: 100 | search_str = ', '.join(brq_list) 101 | elif len(entity_list) > 0: 102 | search_str = ', '.join(entity_list) 103 | elif len(related_list) >= 3: 104 | str1, str2, str3 = related_list[0], related_list[1], related_list[2] 105 | search_str = longest_common_substring_3(str1, str2, str3) 106 | 107 | if search_str == '' and len(related_list) == 2: 108 | str1, str2 = related_list[0], related_list[1] 109 | search_str = longest_common_substring_2(str1, str2) 110 | if search_str == '' and len(related_list) == 1: 111 | 
130 | def bing_text_search(search_term, subscription_key, site='', freshness=''):
131 | 
132 |     endpoint = "https://api.bing.microsoft.com/v7.0/search"
133 |     query = search_term
134 |     if len(site) > 0:
135 |         query = query + ' ' + site
136 | 
137 |     headers = {"Ocp-Apim-Subscription-Key": subscription_key}
138 |     if freshness == '':
139 |         params = {"q": query, "count": 10, "mkt": "en-US"}
140 |     else:
141 |         params = {"q": query, "count": 10, "mkt": "en-US", 'freshness': freshness}
142 | 
143 |     try:
144 |         try:
145 |             response = requests.get(endpoint, headers=headers, params=params)
146 |         except Exception as err:
147 |             return str(err)
148 |         response.raise_for_status()
149 |         search_results = response.json()
150 | 
151 |         ret = []
152 |         for result in search_results.get("webPages", {}).get("value", []):
153 |             title = result["name"] if "name" in result.keys() else ""
154 |             snippet = result["snippet"] if "snippet" in result.keys() else ""
155 |             url = result["url"] if "url" in result.keys() else ""
156 |             datePublished = result["datePublished"] if "datePublished" in result.keys() else "0000-00-00T00:00:00.0000000Z"
157 |             primaryImageOfPage = result["primaryImageOfPage"] if "primaryImageOfPage" in result.keys() else ""
158 |             item_type = "bing_text_search"
159 | 
160 |             ret.append({"title": title, "url": url, "snippet": snippet, "datePublished": datePublished, "primaryImageOfPage": primaryImageOfPage, "type": item_type})
161 |         return ret
162 |     except Exception as err:
163 |         return str(err)
164 | 
165 | def bing_news_search(search_term, subscription_key, site='', exclude='', freshness=''):
166 | 
167 |     endpoint = 'https://api.bing.microsoft.com/v7.0/news/search'
168 |     query = search_term
169 |     if len(site) > 0:
170 |         query = query + ' ' + site
171 |     if len(exclude) > 0:
172 |         query = query + ' ' + exclude
173 | 
174 |     headers = {"Ocp-Apim-Subscription-Key": subscription_key}
175 |     if freshness == '':
176 |         params = {"q": query, "count": 10, "mkt": "en-US"}
177 |     else:
178 |         params = {"q": query, "count": 10, "mkt": "en-US", 'freshness': freshness}
179 | 
180 |     try:
181 |         try:
182 |             response = requests.get(endpoint, headers=headers, params=params)
183 |         except Exception as err:
184 |             return str(err)
185 | 
186 |         response.raise_for_status()
187 |         search_results = response.json()
188 | 
189 |         ret = []
190 |         for result in search_results.get("value", []):
191 |             title = result["name"] if "name" in result.keys() else ""
192 |             snippet = result["description"] if "description" in result.keys() else ""
193 |             url = result["url"] if "url" in result.keys() else ""
194 |             datePublished = result["datePublished"] if "datePublished" in result.keys() else "0000-00-00T00:00:00.0000000Z"
195 |             primaryImageOfPage = result["primaryImageOfPage"] if "primaryImageOfPage" in result.keys() else ""
196 |             item_type = "bing_news_search"
197 | 
198 |             ret.append({"title": title, "url": url, "snippet": snippet, "datePublished": datePublished, "primaryImageOfPage": primaryImageOfPage, "type": item_type})
199 |         return ret
200 |     except Exception as err:
201 |         return str(err)
202 | 
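# Both Bing endpoints above accept the optional 'freshness' parameter; per the Bing v7 API
# it is typically "Day", "Week" or "Month". A call such as the following (illustrative query
# string, key taken from the environment set up by scripts/init_env_variable.sh) restricts
# results to recently discovered pages:
#   bing_news_search("super bowl winner", os.environ.get("bing_text_api_key"), freshness="Week")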
203 | def bing_image_search(search_term, subscription_key, site=''):
204 |     endpoint = "https://api.bing.microsoft.com/v7.0/images/search"
205 | 
206 |     query = search_term
207 |     if len(site) > 0:
208 |         query = query + ' ' + site
209 | 
210 |     headers = {"Ocp-Apim-Subscription-Key": subscription_key}
211 |     params = {"q": query, "count": 10, "imageType": "photo"}
212 | 
213 |     try:
214 |         try:
215 |             response = requests.get(endpoint, headers=headers, params=params)
216 |         except Exception as err:
217 |             return str(err)
218 |         response.raise_for_status()
219 |         search_results = response.json()
220 | 
221 |         ret = []
222 |         for item in search_results.get("value", []):
223 |             web_url = item['hostPageUrl'] if 'hostPageUrl' in item.keys() else ''
224 |             web_title = item['title'] if 'title' in item.keys() else ''
225 |             web_snippet = item['snippet'] if 'snippet' in item.keys() else ''
226 | 
227 |             try:
228 |                 img_url = item['contentUrl']
229 |             except:
230 |                 try:
231 |                     img_url = item['thumbnailUrl']
232 |                 except:
233 |                     img_url = 'Error'
234 |             img_name, parse_suffix = get_hash_name_suffix(img_url)
235 |             try:
236 |                 img_suffix = item['encodingFormat']
237 |             except:
238 |                 img_suffix = parse_suffix
239 |             img_name = 'bing_' + img_name
240 | 
241 |             tmp_dict = {'img_url': img_url, 'img_name': img_name, 'img_suffix': img_suffix,
242 |                         'web_url': web_url, 'web_title': web_title, 'web_snippet': web_snippet}
243 |             ret.append(tmp_dict)
244 |         return ret
245 | 
246 |     except Exception as err:
247 |         return str(err)
248 | 
249 | def google_search(search_term, api_key, cse_id, site=''):
250 |     try:
251 |         service = build("customsearch", "v1", developerKey=api_key)
252 |         results = service.cse().list(q=search_term, cx=cse_id).execute()
253 |         if isinstance(results, str):
254 |             return results
255 |         return results['items']
256 |     except Exception as err:
257 |         return str(err)
258 | 
259 | def google_text_search(search_term, api_key, cse_id, site=''):
260 |     try:
261 |         query = search_term
262 |         if len(site) > 0:
263 |             query = query + ' ' + site
264 | 
265 |         try:
266 |             results = google_search(query, api_key, cse_id, site)
267 |         except Exception as err:
268 |             return str(err)
269 |         if isinstance(results, str):
270 |             return results
271 |         if isinstance(results, dict) and 'error' in results.keys():
272 |             return results['error']['message']
273 | 
274 |         ret = []
275 |         for result in results:
276 |             title = result["title"] if "title" in result.keys() else ""
277 |             snippet = result["snippet"] if "snippet" in result.keys() else ""
278 |             url = result["link"] if "link" in result.keys() else ""
279 |             datePublished = result["datePublished"] if "datePublished" in result.keys() else "0000-00-00T00:00:00.0000000Z"
280 |             primaryImageOfPage = result["primaryImageOfPage"] if "primaryImageOfPage" in result.keys() else ""
281 |             item_type = "google_text_search"
282 | 
283 |             ret.append({"title": title, "url": url, "snippet": snippet, "datePublished": datePublished, "primaryImageOfPage": primaryImageOfPage, "type": item_type})
284 |         return ret
285 |     except Exception as err:
286 |         return str(results) + str(err)
287 | 
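# The Google helpers above need a Custom Search API key and a Programmable Search Engine id;
# with the variables exported by scripts/init_env_variable.sh a call looks roughly like this
# (illustrative query string):
#   items = google_text_search("2024 olympics 100m winner",
#                              os.environ.get("google_api_key"),
#                              os.environ.get("google_text_cse_id"))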
288 | def google_image_search(search_term, api_key, cse_id, site=''):
289 |     try:
290 |         query = search_term
291 |         if len(site) > 0:
292 |             query = query + ' ' + site
293 |         encoded_query = urllib.parse.quote(query)
294 |         search_url = 'https://www.googleapis.com/customsearch/v1?q={}&cx={}&searchType=image&key={}'.format(encoded_query, cse_id, api_key)
295 |         results_raw = requests.get(search_url)
296 |         results = results_raw.json()
297 |     except Exception as err:
298 |         return str(err)
299 |     if isinstance(results, str):
300 |         return results
301 |     if isinstance(results, dict) and 'error' in results.keys():
302 |         return results['error']['message']
303 | 
304 |     ret = []
305 |     for item in results.get("items", []):
306 |         if "image" in item.keys() and "contextLink" in item["image"].keys():
307 |             web_url = item["image"]['contextLink']
308 |         else:
309 |             web_url = ' '
310 |         web_title = item['title'] if 'title' in item.keys() else ''
311 |         web_snippet = item['snippet'] if 'snippet' in item.keys() else ''
312 | 
313 |         try:
314 |             img_url = item['link']
315 |         except:
316 |             try:
317 |                 img_url = item['image']['thumbnailLink']
318 |             except:
319 |                 img_url = 'Error'
320 |         img_name, parse_suffix = get_hash_name_suffix(img_url)
321 |         img_suffix = item['fileFormat'].split('image/')[-1] if 'fileFormat' in item.keys() else ''
322 |         img_suffix = parse_suffix if img_suffix == '' else img_suffix
323 | 
324 |         img_name = 'google_' + img_name
325 |         tmp_dict = {'img_url': img_url, 'img_name': img_name, 'img_suffix': img_suffix,
326 |                     'web_url': web_url, 'web_title': web_title, 'web_snippet': web_snippet}
327 |         ret.append(tmp_dict)
328 |     return ret
329 | 
330 | def search_from_web(search_term, engine='google', search_type='text', site_list=[], exclude_list=[], freshness='', nlp=None):
331 |     google_api_key = os.environ.get("google_api_key")
332 |     google_text_cse_id = os.environ.get("google_text_cse_id")
333 |     google_image_cse_id = os.environ.get("google_image_cse_id")
334 | 
335 |     bing_text_api_key = os.environ.get("bing_text_api_key")
336 |     bing_img_api_key = os.environ.get("bing_img_api_key")
337 |     bing_visual_api_key = os.environ.get("bing_visual_api_key")
338 | 
339 |     site_list = ['site:{}'.format(x) for x in site_list]
340 |     site_str = ' OR '.join(site_list)
341 | 
342 |     exclude_list = ['-site:{}'.format(x) for x in exclude_list]
343 |     exclude_str = ' '.join(exclude_list)
344 | 
345 |     if search_type == 'text':
346 |         if engine == 'google':
347 |             return google_text_search(search_term, google_api_key, google_text_cse_id, site_str)
348 |         elif engine == 'bing':
349 |             return bing_text_search(search_term, bing_text_api_key, site_str, freshness=freshness)
350 | 
351 |     elif search_type == 'image':
352 |         if engine == 'google':
353 |             return google_image_search(search_term, google_api_key, google_image_cse_id, site_str)
354 |         elif engine == 'bing':
355 |             return bing_image_search(search_term, bing_img_api_key, site_str)
356 | 
357 |     elif search_type == 'news':
358 |         return bing_news_search(search_term, bing_text_api_key, site_str, exclude_str, freshness=freshness)
359 | 
360 |     elif search_type == 'visual':
361 |         return bing_visual_search(search_term, bing_visual_api_key, site_str, nlp=nlp)
362 | 
363 | 
--------------------------------------------------------------------------------
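A minimal sketch of exercising utils_engine.py, assuming the API keys have already been exported via scripts/init_env_variable.sh and the repository root is the working directory (the query string below is only an illustration):

import sys
sys.path.append('scripts/utils')
from utils_engine import search_from_web

items = search_from_web('latest UEFA Champions League winner',
                        engine='bing', search_type='news',
                        exclude_list=['www.msn.com'], freshness='Week')
if isinstance(items, str):
    # the error paths above return the error message as a plain string
    print('search failed:', items)
else:
    for item in items:
        print(item['title'], item['url'])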