├── Results ├── main.png ├── analysis.png └── error_case.jpg ├── Comparison ├── compare.jpg ├── splits.png └── statistic.png ├── Examples ├── 000000000933.jpg ├── 000000006568.jpg ├── 000000015740.jpg ├── 000000057139.jpg ├── 000000070986.jpg ├── 000000100633.jpg ├── 000000121362.jpg ├── 000000142379.jpg ├── examples_1-4.png ├── examples_5-8.png └── examples.jsonl ├── Code ├── eval │ ├── README.md │ ├── calculate_result_rule.py │ ├── calculate_xyz.py │ └── calculate_prf1.py ├── requirement │ ├── requirement_gpt4.txt │ ├── requirement_gemini.txt │ ├── requirement_spacellava.txt │ ├── requirement_idefics.txt │ ├── requirement_blip.txt │ ├── requirement_mplug.txt │ └── requirement_llava.txt ├── finetune │ ├── llava_lora_train.sh │ ├── spacellava_lora_train.sh │ ├── mPLUG_Owl_train_it.sh │ ├── idefics.py │ ├── blip-vqa-base.py │ ├── instructblip-lora.py │ └── blip2-lora.py ├── experiment │ ├── blip-vqa-base.py │ ├── blip-vqa-base_finetuned.py │ ├── blip2-opt-2.7b.py │ ├── blip2-lora.py │ ├── instructblip-flan-t5-xl.py │ ├── instructblip-lora.py │ ├── spatial_test_mplug.py │ ├── spatial_test_mplug_lora.py │ ├── idefics_new.py │ ├── idefics_lora.py │ ├── spacellava_test.py │ ├── spatial_test_llava.py │ ├── spatial_test_llava_lora.py │ └── spacellava_lora_test.py └── close_models │ ├── gemini_text_only.py │ ├── gemini_0_shot.py │ ├── gpt4_text_only.py │ ├── gpt4_zero_shot.py │ ├── gemini_1_shot_random.py │ ├── gemini_2_shot_random.py │ ├── gpt4_1_shot.py │ ├── gemini_3_shot_random.py │ └── gpt4_2_shot.py ├── Dataset ├── relevant_images │ └── README.md ├── tool │ └── select_relevant_images.py ├── types │ └── types.txt └── README.md ├── Metadata ├── metadata_hf.json └── metadata.json └── README.md /Results/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/main.png -------------------------------------------------------------------------------- /Comparison/compare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/compare.jpg -------------------------------------------------------------------------------- /Comparison/splits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/splits.png -------------------------------------------------------------------------------- /Results/analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/analysis.png -------------------------------------------------------------------------------- /Results/error_case.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/error_case.jpg -------------------------------------------------------------------------------- /Comparison/statistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/statistic.png -------------------------------------------------------------------------------- /Examples/000000000933.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000000933.jpg 
-------------------------------------------------------------------------------- /Examples/000000006568.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000006568.jpg -------------------------------------------------------------------------------- /Examples/000000015740.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000015740.jpg -------------------------------------------------------------------------------- /Examples/000000057139.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000057139.jpg -------------------------------------------------------------------------------- /Examples/000000070986.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000070986.jpg -------------------------------------------------------------------------------- /Examples/000000100633.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000100633.jpg -------------------------------------------------------------------------------- /Examples/000000121362.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000121362.jpg -------------------------------------------------------------------------------- /Examples/000000142379.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000142379.jpg -------------------------------------------------------------------------------- /Examples/examples_1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/examples_1-4.png -------------------------------------------------------------------------------- /Examples/examples_5-8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/examples_5-8.png -------------------------------------------------------------------------------- /Code/eval/README.md: -------------------------------------------------------------------------------- 1 | **calculate_prf1.py** calculates overall accuracy and overall precision, recall, and F1. 2 | 3 | **calculate_xyz.py** calculates accuracy broken down by relation category. 4 | 5 | **calculate_result_rule.py** calculates accuracy broken down by rule. 6 | -------------------------------------------------------------------------------- /Dataset/relevant_images/README.md: -------------------------------------------------------------------------------- 1 | ### Download images 2 | We use a subset of COCO-2017's images. The following script downloads COCO-2017's test-set images and puts them into a single folder, `Dataset/COCO2017/`. 3 | 4 | ```bash 5 | cd Dataset/ 6 | wget http://images.cocodataset.org/zips/test2017.zip 7 | unzip test2017.zip 8 | mv test2017 COCO2017 && rm -r test2017 9 | ``` 10 | Copy only the relevant images to `relevant_images/`.
11 | ```bash 12 | mkdir relevant_images 13 | cd tool 14 | python select_relevant_images.py 15 | ``` 16 | Alternatively, you can browse individual images online directly using the "image" key in each JSON record. 17 |
(Through COCO's open source link, 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.) 18 | -------------------------------------------------------------------------------- /Code/requirement/requirement_gpt4.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.6.0 2 | anyio==4.3.0 3 | certifi==2024.2.2 4 | charset-normalizer==3.3.2 5 | colorama==0.4.6 6 | distro==1.9.0 7 | exceptiongroup==1.2.1 8 | h11==0.14.0 9 | httpcore==1.0.5 10 | httpx==0.27.0 11 | idna==3.7 12 | openai==1.30.1 13 | Pillow==7.2.0 14 | pip==24.0 15 | pydantic==2.7.1 16 | pydantic_core==2.18.2 17 | requests==2.31.0 18 | setuptools==69.5.1 19 | sniffio==1.3.1 20 | tk==0.1.0 21 | tqdm==4.66.4 22 | typing_extensions==4.11.0 23 | urllib3==2.2.1 24 | wheel==0.43.0 25 | Could not fetch URL https://mirrors.aliyun.com/pypi/simple/pip/: There was a problem confirming the ssl certificate: HTTPSConnectionPool(host='mirrors.aliyun.com', port=443): Max retries exceeded with url: /pypi/simple/pip/ (Caused by SSLError(SSLZeroReturnError(6, 'TLS/SSL connection has been closed (EOF) (_ssl.c:1135)'))) - skipping 26 | -------------------------------------------------------------------------------- /Code/requirement/requirement_gemini.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.6.0 2 | beautifulsoup4==4.12.3 3 | cachetools==5.3.3 4 | certifi==2024.2.2 5 | charset-normalizer==3.3.2 6 | colorama==0.4.6 7 | google==3.0.0 8 | google-ai-generativelanguage==0.6.3 9 | google-api-core==2.19.0 10 | google-api-python-client==2.129.0 11 | google-auth==2.29.0 12 | google-auth-httplib2==0.2.0 13 | google-generativeai==0.5.3 14 | googleapis-common-protos==1.63.0 15 | grpcio==1.63.0 16 | grpcio-status==1.62.2 17 | httplib2==0.22.0 18 | idna==3.7 19 | joblib==1.4.2 20 | numpy==1.26.4 21 | pillow==10.3.0 22 | pip==24.0 23 | proto-plus==1.23.0 24 | protobuf==4.25.3 25 | pyasn1==0.6.0 26 | pyasn1_modules==0.4.0 27 | pydantic==2.7.1 28 | pydantic_core==2.18.2 29 | pyparsing==3.1.2 30 | requests==2.31.0 31 | rsa==4.9 32 | scikit-learn==1.5.0 33 | scipy==1.13.1 34 | setuptools==69.5.1 35 | soupsieve==2.5 36 | threadpoolctl==3.5.0 37 | tqdm==4.66.4 38 | typing_extensions==4.11.0 39 | uritemplate==4.1.1 40 | urllib3==2.2.1 41 | wheel==0.43.0 42 | -------------------------------------------------------------------------------- /Code/requirement/requirement_spacellava.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | brotli=1.1.0=pypi_0 7 | ca-certificates=2024.7.2=h06a4308_0 8 | certifi=2024.7.4=pypi_0 9 | charset-normalizer=3.3.2=pypi_0 10 | gevent=24.2.1=pypi_0 11 | geventhttpclient=2.0.12=pypi_0 12 | greenlet=3.0.3=pypi_0 13 | idna=3.7=pypi_0 14 | ld_impl_linux-64=2.38=h1181459_1 15 | libffi=3.4.4=h6a678d5_1 16 | libgcc-ng=11.2.0=h1234567_1 17 | libgomp=11.2.0=h1234567_1 18 | libstdcxx-ng=11.2.0=h1234567_1 19 | ncurses=6.4=h6a678d5_0 20 | numpy=1.24.4=pypi_0 21 | openssl=3.0.14=h5eee18b_0 22 | pip=24.0=py38h06a4308_0 23 | python=3.8.19=h955ad1f_0 24 | python-rapidjson=1.20=pypi_0 25 | readline=8.2=h5eee18b_0 26 | requests=2.32.3=pypi_0 27 | setuptools=72.1.0=py38h06a4308_0 28 | six=1.16.0=pypi_0 29 | sqlite=3.45.3=h5eee18b_0 30 | tk=8.6.14=h39e8969_0 31 | 
tritonclient=2.48.0=pypi_0 32 | urllib3=2.2.2=pypi_0 33 | wheel=0.43.0=py38h06a4308_0 34 | xz=5.4.6=h5eee18b_1 35 | zlib=1.2.13=h5eee18b_1 36 | zope-event=5.0=pypi_0 37 | zope-interface=7.0.1=pypi_0 38 | -------------------------------------------------------------------------------- /Dataset/tool/select_relevant_images.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | 6 | def copy_images(jsonl_file, source_dir, destination_dir): 7 | if not os.path.exists(destination_dir): 8 | os.makedirs(destination_dir) 9 | count = 0 10 | with open(jsonl_file, 'r', encoding='gbk', errors='ignore') as f: 11 | for line in f: 12 | count += 1 13 | data = json.loads(line) 14 | image_filename = data['image'] 15 | 16 | source_path = os.path.join(source_dir, image_filename) 17 | 18 | if os.path.exists(source_path): 19 | destination_path = os.path.join(destination_dir, image_filename) 20 | shutil.copyfile(source_path, destination_path) 21 | else: 22 | print(f"Source file {image_filename} not found in {source_dir}") 23 | 24 | print(jsonl_file+': '+str(count)) 25 | print("Copy successful!") 26 | 27 | 28 | jsonl_file = os.listdir('../dataset/') 29 | for i in range(len(jsonl_file)): 30 | jsonl_file[i] = '../dataset/'+jsonl_file[i] 31 | source_dir = '../COCO2017' 32 | destination_dir = '../relevant_images' 33 | 34 | for file in jsonl_file: 35 | copy_images(file, source_dir, destination_dir) 36 | -------------------------------------------------------------------------------- /Code/finetune/llava_lora_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include localhost:6 llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path /projects/Models/LLaVA-main/llava-v1.5-7b \ 7 | --version v1 \ 8 | --data_path /projects/SpatialMQA/datasets/llava_train/train_3780.json \ 9 | --image_folder /projects \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir /projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240601 \ 19 | --num_train_epochs 10 \ 20 | --per_device_train_batch_size 8 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 2 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 100 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-4 \ 28 | --weight_decay 0. 
\ 29 | --warmup_ratio 0.02 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 2 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 0 \ 36 | --lazy_preprocess True 37 | -------------------------------------------------------------------------------- /Code/requirement/requirement_idefics.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.29.2 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | bitsandbytes==0.43.0 7 | certifi==2024.2.2 8 | charset-normalizer==3.3.2 9 | datasets==2.18.0 10 | dill==0.3.8 11 | filelock==3.13.4 12 | frozenlist==1.4.1 13 | fsspec==2024.2.0 14 | huggingface-hub==0.22.2 15 | idna==3.6 16 | Jinja2==3.1.3 17 | MarkupSafe==2.1.5 18 | mpmath==1.3.0 19 | multidict==6.0.5 20 | multiprocess==0.70.16 21 | networkx==3.3 22 | numpy==1.26.4 23 | nvidia-cublas-cu12==12.1.3.1 24 | nvidia-cuda-cupti-cu12==12.1.105 25 | nvidia-cuda-nvrtc-cu12==12.1.105 26 | nvidia-cuda-runtime-cu12==12.1.105 27 | nvidia-cudnn-cu12==8.9.2.26 28 | nvidia-cufft-cu12==11.0.2.54 29 | nvidia-curand-cu12==10.3.2.106 30 | nvidia-cusolver-cu12==11.4.5.107 31 | nvidia-cusparse-cu12==12.1.0.106 32 | nvidia-nccl-cu12==2.19.3 33 | nvidia-nvjitlink-cu12==12.4.127 34 | nvidia-nvtx-cu12==12.1.105 35 | packaging==24.0 36 | pandas==2.2.1 37 | peft==0.10.0 38 | pillow==10.3.0 39 | pip==23.3.1 40 | psutil==5.9.8 41 | pyarrow==15.0.2 42 | pyarrow-hotfix==0.6 43 | python-dateutil==2.9.0.post0 44 | pytz==2024.1 45 | PyYAML==6.0.1 46 | regex==2023.12.25 47 | requests==2.31.0 48 | safetensors==0.4.2 49 | setuptools==68.2.2 50 | six==1.16.0 51 | sympy==1.12 52 | tokenizers==0.15.2 53 | torch==2.2.2 54 | torchvision==0.17.2 55 | tqdm==4.66.2 56 | transformers==4.39.3 57 | triton==2.2.0 58 | typing_extensions==4.11.0 59 | tzdata==2024.1 60 | urllib3==2.2.1 61 | wheel==0.41.2 62 | xxhash==3.4.1 63 | yarl==1.9.4 64 | -------------------------------------------------------------------------------- /Code/finetune/spacellava_lora_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include localhost:0,1 llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf \ 7 | --version v1 \ 8 | --data_path /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/datasets/train_3000.json \ 9 | --image_folder /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/datasets/images \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240816_3000 \ 19 | --num_train_epochs 2 \ 20 | --per_device_train_batch_size 8 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 60 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-4 \ 28 | --weight_decay 0. 
\ 29 | --warmup_ratio 0.02 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 0 \ 36 | --lazy_preprocess True 37 | # --report_to wandb 38 | -------------------------------------------------------------------------------- /Dataset/types/types.txt: -------------------------------------------------------------------------------- 1 | person 2 | letter 3 | natural scenery 4 | hand 5 | cat 6 | cup 7 | sports ball 8 | dog 9 | car 10 | lamp 11 | cell phone 12 | fork 13 | clock 14 | bottle 15 | shadow 16 | building 17 | window 18 | laptop 19 | baseball bat 20 | knife 21 | tv 22 | bicycle 23 | number 24 | backpack 25 | stop sign 26 | paper 27 | bus 28 | book 29 | door 30 | elephant 31 | frisbee 32 | banana 33 | trash can 34 | wine glass 35 | towel 36 | mouse 37 | box 38 | remote 39 | spoon 40 | horse 41 | sink 42 | motorcycle 43 | teddy bear 44 | vase 45 | fire hydrant 46 | jar 47 | tree and wood 48 | potted plant 49 | chair 50 | shelf 51 | suitcase 52 | bear 53 | bird 54 | flower and grass 55 | pot 56 | traffic light 57 | refrigerator 58 | bowl 59 | fence 60 | skateboard 61 | mural 62 | mirror 63 | socket 64 | toy 65 | cow 66 | basket 67 | frame 68 | cake 69 | surfboard 70 | pillow 71 | clothes 72 | bench 73 | scissors 74 | orange 75 | handbag 76 | bread 77 | plate 78 | giraffe 79 | sheep 80 | shower 81 | camera 82 | boat 83 | pen 84 | dining table 85 | toothbrush 86 | billboard 87 | sculpture 88 | signboard 89 | umbrella 90 | kettle 91 | microwave 92 | truck 93 | train 94 | keyboard 95 | apple 96 | glass 97 | cabinet 98 | drink 99 | tennis racket 100 | hair drier 101 | speaker 102 | carrot 103 | bathtub 104 | fan 105 | rack 106 | faucet 107 | airplane 108 | picture 109 | snowboard 110 | sanitizer 111 | switch 112 | glasses 113 | ketchup 114 | scale 115 | bed 116 | zebra 117 | couch 118 | seat 119 | pizza 120 | toilet 121 | wall 122 | fireplace 123 | display 124 | pole 125 | screen 126 | dessert 127 | others_sub 128 | others_obj 129 | -------------------------------------------------------------------------------- /Examples/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"} 2 | {"image": "000000006568.jpg", "question": "Where is the cat located relative to the car in the image?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "on/above"} 3 | {"image": "000000057139.jpg", "question": "For the white letters on the red warning sign, where is the letter P located relative to the letter Y?", "options": ["on/above", "below", "left of", "right of"], "answer": "on/above"} 4 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"} 5 | {"image": "000000121362.jpg", "question": "If you are the player in the image, where is the audience located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"} 6 | {"image": "000000142379.jpg", "question": "If you are the giraffe in the image, where is the tree located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"} 
7 | {"image": "000000015740.jpg", "question": "If you are the woman in the image, from your perspective, where is the mouse located relative to the keyboard?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "left of"} 8 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"} 9 | -------------------------------------------------------------------------------- /Code/requirement/requirement_blip.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | accelerate==0.25.0 3 | aiohttp==3.9.1 4 | aiosignal==1.3.1 5 | async-timeout==4.0.3 6 | attrs==23.2.0 7 | bitsandbytes==0.41.3.post2 8 | certifi==2023.11.17 9 | charset-normalizer==3.3.2 10 | click==8.1.7 11 | datasets==2.14.6 12 | dill==0.3.7 13 | filelock==3.13.1 14 | frozenlist==1.4.1 15 | fsspec==2023.10.0 16 | huggingface-hub==0.20.3 17 | idna==3.6 18 | Jinja2==3.1.2 19 | joblib==1.3.2 20 | MarkupSafe==2.1.3 21 | mpmath==1.3.0 22 | multidict==6.0.4 23 | multiprocess==0.70.15 24 | munch==4.0.0 25 | networkx==3.2.1 26 | nltk==3.8.1 27 | numpy==1.26.3 28 | nvidia-cublas-cu12==12.1.3.1 29 | nvidia-cuda-cupti-cu12==12.1.105 30 | nvidia-cuda-nvrtc-cu12==12.1.105 31 | nvidia-cuda-runtime-cu12==12.1.105 32 | nvidia-cudnn-cu12==8.9.2.26 33 | nvidia-cufft-cu12==11.0.2.54 34 | nvidia-curand-cu12==10.3.2.106 35 | nvidia-cusolver-cu12==11.4.5.107 36 | nvidia-cusparse-cu12==12.1.0.106 37 | nvidia-nccl-cu12==2.18.1 38 | nvidia-nvjitlink-cu12==12.3.101 39 | nvidia-nvtx-cu12==12.1.105 40 | packaging==23.2 41 | pandas==2.1.4 42 | peft==0.10.0 43 | Pillow==10.0.1 44 | pip==22.0.4 45 | protobuf==4.25.2 46 | psutil==5.9.7 47 | pyarrow==14.0.2 48 | python-dateutil==2.8.2 49 | pytz==2023.3.post1 50 | PyYAML==6.0.1 51 | regex==2023.12.25 52 | requests==2.31.0 53 | rouge-score==0.1.2 54 | ruamel.yaml==0.18.6 55 | ruamel.yaml.clib==0.2.8 56 | safetensors==0.4.1 57 | scipy==1.11.4 58 | sconf==0.2.5 59 | sentencepiece==0.1.99 60 | setuptools==58.1.0 61 | six==1.16.0 62 | sympy==1.12 63 | tokenizers==0.15.0 64 | torch==2.1.2 65 | torchsummary==1.5.1 66 | torchvision==0.16.0 67 | tqdm==4.66.1 68 | transformers==4.35.2 69 | triton==2.1.0 70 | typing_extensions==4.9.0 71 | tzdata==2023.4 72 | urllib3==2.1.0 73 | wheel==0.37.1 74 | xformers==0.0.23.post1 75 | xxhash==3.4.1 76 | yarl==1.9.4 77 | -------------------------------------------------------------------------------- /Code/finetune/mPLUG_Owl_train_it.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR=`pwd` 3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` 4 | 5 | if [ $MASTER_ADDR ];then 6 | echo $MASTER_ADDR 7 | echo $MASTER_PORT 8 | echo $WORLD_SIZE 9 | echo $RANK 10 | else 11 | MASTER_ADDR=127.0.0.1 12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15 13 | WORLD_SIZE=1 14 | RANK=0 15 | fi 16 | 17 | DISTRIBUTED_ARGS="--nproc_per_node 2 \ 18 | --nnodes ${WORLD_SIZE} \ 19 | --node_rank ${RANK} \ 20 | --master_addr ${MASTER_ADDR} \ 21 | --master_port ${MASTER_PORT}" 22 | 23 | EXP_NAME=sft_v0.1 24 | SAVE_NAME=mplug_lora_20240530 25 | 26 | SAVE_PATH="/projects/SpatialMQA/finetune_models/models_arg/${SAVE_NAME}/" 27 | 28 | max_length=2048 29 | micro_batch_size=4 30 | global_batch_size=8 31 | gradient_accumulation_steps=5 32 | 33 | # train_iters = total_data * train_epochs // 
global_batch_size 34 | # 3780 * 10 / 8 = 4725 35 | train_epochs=10 36 | train_iters=4725 37 | 38 | lr_warmup_iters=50 39 | 40 | eval_iter=50 41 | eval_interval=50 42 | save_interval=500 43 | 44 | mkdir -p ${SAVE_PATH} 45 | 46 | options=" \ 47 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \ 48 | --seq-length ${max_length} \ 49 | --micro-batch-size ${micro_batch_size} \ 50 | --num-training-steps ${train_iters} \ 51 | --train-epochs ${train_epochs} \ 52 | --num-warmup-steps ${lr_warmup_iters} \ 53 | --gradient-accumulation-steps ${gradient_accumulation_steps} \ 54 | --lr 5e-5 \ 55 | --min-lr 1e-6 \ 56 | --eval-iters ${eval_iter} \ 57 | --save-interval ${save_interval} \ 58 | --save-path ${SAVE_PATH} \ 59 | --clip-grad 1.0 \ 60 | --weight-decay 0.0001 \ 61 | --adam-beta1 0.9 \ 62 | --adam-beta2 0.999 \ 63 | --num-workers 32 \ 64 | --use-lora \ 65 | --gradient-checkpointing \ 66 | --bf16" 67 | 68 | multimodal_options=" \ 69 | --mm-config configs/v0.yaml 70 | " 71 | 72 | HF_ENDPOINT=https://hf-mirror.com CUDA_VISIBLE_DEVICES=5,6 python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee /projects/SpatialMQA/finetune_models/finetune_code/log/20240530_mplug_train.log -------------------------------------------------------------------------------- /Dataset/README.md: -------------------------------------------------------------------------------- 1 | # Data of SpatialMQA 2 | 3 | ### Download images 4 | We use a subset of COCO-2017's images. The following script downloads COCO-2017's test-set images and puts them into a single folder, `Dataset/COCO2017/`. 5 | 6 | ```bash 7 | cd Dataset/ 8 | wget http://images.cocodataset.org/zips/test2017.zip 9 | unzip test2017.zip 10 | mv test2017 COCO2017 && rm -r test2017 11 | ``` 12 | Copy only the relevant images to `relevant_images/`. 13 | ```bash 14 | mkdir relevant_images 15 | cd tool 16 | python select_relevant_images.py 17 | ``` 18 | 19 | Alternatively, you can browse individual images online directly using the "image" key in each JSON record. 20 | (Through COCO's open source link, 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.) 21 | 22 | ### Splits 23 | As reported in the following table, SpatialMQA contains 5,392 samples, divided into training, validation, and test sets according to a 7:1:2 ratio. 24 |
All the splited data sets are in the directory [`dataset/`](https://github.com/ziyan-xiaoyu/SpatialMQA/blob/Dataset). 25 | 26 | ### Format of the data 27 | Each `jsonl` file is of the following format: 28 | ```json 29 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"} 30 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"} 31 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"} 32 | {"..."} 33 | ``` 34 | Each line is an individual data point. 35 | `image` denotes name of the image in COCO. `question` is the question with manual annotation, `options` is reasonable combinations of six spatial relationships:(on/above, below, in front of, behind, left of, right of. `answer` is the annotation based on the objective world. 36 | 37 | -------------------------------------------------------------------------------- /Code/experiment/blip-vqa-base.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | from transformers import BlipProcessor, BlipForQuestionAnswering 4 | import torch 5 | import json 6 | 7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/' 8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 9 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") 10 | processor = BlipProcessor.from_pretrained(f"Salesforce/blip-vqa-base") 11 | model = BlipForQuestionAnswering.from_pretrained(f"Salesforce/blip-vqa-base").to(DEVICE_INDEX) 12 | 13 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 14 | count = 0 15 | right_count = 0 16 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip.jsonl', 'w+', encoding="utf-8") as fout: 17 | for line in f: 18 | data = json.loads(line) 19 | question = data['question'] 20 | id = data['id'] 21 | options = data['options'] 22 | image_name = data['image'] 23 | image_filepath = image_dir + image_name 24 | image = Image.open(image_filepath).convert('RGB') 25 | question = f'{question} {",".join(options[:-1])} or {options[-1]}' 26 | print(question) 27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX) 28 | predictions = model.generate(**inputs) 29 | output = (processor.decode(predictions[0], skip_special_tokens=True)) 30 | count += 1 31 | if len(output) == 0: 32 | output = '--' 33 | if output.lower() in data['answer']: 34 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 35 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 36 | right_count += 1 37 | elif data['answer'] in output.lower(): 38 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 39 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 40 | right_count += 1 41 | else: 42 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 43 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 44 | print(f'{output.lower()}') 45 | print(f"{data['answer']}") 46 | print(f'right_count: {right_count}') 47 | 
print(f'count: {count}') 48 | print(f'accuracy: {right_count/count}') 49 | 50 | accuracy = right_count/count 51 | print(f'accuracy: {accuracy}') 52 | 53 | -------------------------------------------------------------------------------- /Code/requirement/requirement_mplug.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.28.0 3 | aiofiles==23.2.1 4 | altair==5.3.0 5 | annotated-types==0.6.0 6 | anyio==4.3.0 7 | asttokens==2.4.1 8 | attrs==23.2.0 9 | bitsandbytes==0.43.0 10 | blinker==1.7.0 11 | Brotli==1.0.9 12 | cchardet==2.1.7 13 | certifi==2024.2.2 14 | chardet==5.2.0 15 | charset-normalizer==2.0.4 16 | click==8.1.7 17 | colorama==0.4.6 18 | contourpy==1.2.0 19 | cycler==0.12.1 20 | decord==0.6.0 21 | einops==0.7.0 22 | exceptiongroup==1.2.0 23 | executing==2.0.1 24 | fastapi==0.110.0 25 | ffmpy==0.3.2 26 | filelock==3.13.3 27 | Flask==3.0.2 28 | fonttools==4.50.0 29 | fsspec==2024.3.1 30 | gradio==4.24.0 31 | gradio_client==0.14.0 32 | grpcio==1.62.1 33 | h11==0.14.0 34 | h5py==3.10.0 35 | httpcore==1.0.5 36 | httpx==0.27.0 37 | huggingface-hub==0.22.2 38 | icecream==2.1.3 39 | idna==3.4 40 | importlib_resources==6.4.0 41 | itsdangerous==2.1.2 42 | Jinja2==3.1.3 43 | jsonschema==4.21.1 44 | jsonschema-specifications==2023.12.1 45 | kiwisolver==1.4.5 46 | Markdown==3.6 47 | markdown-it-py==3.0.0 48 | markdown2==2.4.13 49 | MarkupSafe==2.1.5 50 | matplotlib==3.8.3 51 | mdurl==0.1.2 52 | mkl-fft==1.3.8 53 | mkl-random==1.2.4 54 | mkl-service==2.4.0 55 | munch==4.0.0 56 | numpy==1.26.4 57 | opencv-python==4.9.0.80 58 | opencv-python-headless==4.9.0.80 59 | orjson==3.10.0 60 | packaging==24.0 61 | pandas==2.2.1 62 | peft==0.10.0 63 | pillow==10.2.0 64 | pip==23.3.1 65 | protobuf==5.26.1 66 | psutil==5.9.8 67 | pydantic==2.6.4 68 | pydantic_core==2.16.3 69 | pydub==0.25.1 70 | Pygments==2.17.2 71 | pyparsing==3.1.2 72 | PySocks==1.7.1 73 | python-dateutil==2.9.0.post0 74 | python-multipart==0.0.9 75 | pytz==2024.1 76 | PyYAML==6.0.1 77 | referencing==0.34.0 78 | regex==2023.12.25 79 | requests==2.31.0 80 | rich==13.7.1 81 | rpds-py==0.18.0 82 | ruamel.yaml==0.18.6 83 | ruamel.yaml.clib==0.2.8 84 | ruff==0.3.4 85 | safetensors==0.4.2 86 | sconf==0.2.5 87 | semantic-version==2.10.0 88 | sentencepiece==0.2.0 89 | setuptools==68.2.2 90 | shellingham==1.5.4 91 | six==1.16.0 92 | sniffio==1.3.1 93 | starlette==0.36.3 94 | tensorboard==2.16.2 95 | tensorboard-data-server==0.7.2 96 | tensorboardX==2.6.2.2 97 | tokenizers==0.13.3 98 | tomlkit==0.12.0 99 | toolz==0.12.1 100 | torch==1.13.1 101 | torchaudio==0.13.1 102 | torchvision==0.14.1 103 | tqdm==4.66.2 104 | transformers==4.28.1 105 | typer==0.12.0 106 | typer-cli==0.12.0 107 | typer-slim==0.12.0 108 | typing_extensions==4.9.0 109 | tzdata==2024.1 110 | urllib3==2.1.0 111 | uvicorn==0.29.0 112 | websockets==11.0.3 113 | Werkzeug==3.0.1 114 | wheel==0.41.2 115 | -------------------------------------------------------------------------------- /Code/experiment/blip-vqa-base_finetuned.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | from transformers import BlipProcessor, BlipForQuestionAnswering 4 | import torch 5 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/blip/' 6 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 7 | DEVICE_INDEX = torch.device("cuda:6" if torch.cuda.is_available() else "cpu") 8 | processor = BlipProcessor.from_pretrained(f"Salesforce/blip-vqa-base") 9 | 10 | for 
i in range(8): 11 | model = BlipForQuestionAnswering.from_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{i}").to(DEVICE_INDEX) 12 | 13 | import json 14 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 15 | count = 0 16 | right_count = 0 17 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip_finetuned_{i}.jsonl', 'w', encoding="utf-8") as fout: 18 | for line in f: 19 | data = json.loads(line) 20 | question = data['question'] 21 | id = data['id'] 22 | options = data['options'] 23 | image_name = data['image'] 24 | image_filepath = image_dir + image_name 25 | image = Image.open(image_filepath).convert('RGB') 26 | question = f'{question} {",".join(options[:-1])} or {options[-1]}' 27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX) 28 | predictions = model.generate(**inputs) 29 | output = (processor.decode(predictions[0], skip_special_tokens=True)) 30 | count += 1 31 | if len(output) == 0: 32 | output = '--' 33 | if output == 'on / above': 34 | output = 'on/above' 35 | if output.lower() in data['answer']: 36 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 37 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 38 | right_count += 1 39 | elif data['answer'] in output.lower(): 40 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 41 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 42 | right_count += 1 43 | else: 44 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 45 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 46 | print(f'{output.lower()}') 47 | print(f"{data['answer']}") 48 | print(f'right_count: {right_count}') 49 | print(f'count: {count}') 50 | print(f'accuracy: {right_count/count}') 51 | 52 | accuracy = right_count/count 53 | print(f'accuracy: {accuracy}') 54 | 55 | -------------------------------------------------------------------------------- /Code/experiment/blip2-opt-2.7b.py: -------------------------------------------------------------------------------- 1 | 2 | import requests 3 | from PIL import Image 4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 5 | import torch 6 | 7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/' 8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 9 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") 10 | model_dir = '/projects/Models/' 11 | processor = AutoProcessor.from_pretrained(f"{model_dir}blip2-opt-2.7b") 12 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b").to(DEVICE_INDEX) 13 | import json 14 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 15 | count = 0 16 | right_count = 0 17 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip2_20240531.jsonl', 'w+', encoding="utf-8") as fout: 18 | for line in f: 19 | data = json.loads(line) 20 | question = data['question'] 21 | id = data['id'] 22 | options = data['options'] 23 | image_name = data['image'] 24 | image_filepath = image_dir + image_name 25 | image = Image.open(image_filepath) 26 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. 
\n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:' 27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX) 28 | predictions = model.generate(**inputs) 29 | output = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip().rstrip('.') 30 | count += 1 31 | if len(output) == 0: 32 | output = '--' 33 | if output.lower() in data['answer']: 34 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 35 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 36 | right_count += 1 37 | elif data['answer'] in output.lower(): 38 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 39 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 40 | right_count += 1 41 | else: 42 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 43 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 44 | print(f'{output.lower()}') 45 | print(f"{data['answer']}") 46 | print(f'right_count: {right_count}') 47 | print(f'count: {count}') 48 | print(f'accuracy: {right_count/count}') 49 | 50 | accuracy = right_count/count 51 | print(f'accuracy: {accuracy}') 52 | -------------------------------------------------------------------------------- /Code/requirement/requirement_llava.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | aiofiles==23.2.1 3 | altair==5.2.0 4 | annotated-types==0.6.0 5 | anyio==4.3.0 6 | attrs==23.2.0 7 | bitsandbytes==0.43.0 8 | certifi==2024.2.2 9 | charset-normalizer==3.3.2 10 | click==8.1.7 11 | colorama==0.4.6 12 | contourpy==1.2.0 13 | cycler==0.12.1 14 | deepspeed==0.14.2 15 | docker-pycreds==0.4.0 16 | einops==0.6.1 17 | einops-exts==0.0.4 18 | exceptiongroup==1.2.0 19 | fastapi==0.110.0 20 | ffmpy==0.3.2 21 | filelock==3.13.1 22 | flash-attn==2.5.8 23 | fonttools==4.49.0 24 | fsspec==2024.2.0 25 | gitdb==4.0.11 26 | GitPython==3.1.43 27 | gradio==4.16.0 28 | gradio_client==0.8.1 29 | h11==0.14.0 30 | hjson==3.1.0 31 | httpcore==0.17.3 32 | httpx==0.24.0 33 | huggingface-hub==0.21.4 34 | idna==3.6 35 | importlib_resources==6.3.0 36 | Jinja2==3.1.3 37 | joblib==1.3.2 38 | jsonschema==4.21.1 39 | jsonschema-specifications==2023.12.1 40 | kiwisolver==1.4.5 41 | llava==1.2.2.post1 42 | markdown-it-py==3.0.0 43 | markdown2==2.4.13 44 | MarkupSafe==2.1.5 45 | matplotlib==3.8.3 46 | mdurl==0.1.2 47 | mpmath==1.3.0 48 | networkx==3.2.1 49 | ninja==1.11.1.1 50 | numpy==1.26.4 51 | nvidia-cublas-cu12==12.1.3.1 52 | nvidia-cuda-cupti-cu12==12.1.105 53 | nvidia-cuda-nvrtc-cu12==12.1.105 54 | nvidia-cuda-runtime-cu12==12.1.105 55 | nvidia-cudnn-cu12==8.9.2.26 56 | nvidia-cufft-cu12==11.0.2.54 57 | nvidia-curand-cu12==10.3.2.106 58 | nvidia-cusolver-cu12==11.4.5.107 59 | nvidia-cusparse-cu12==12.1.0.106 60 | nvidia-nccl-cu12==2.18.1 61 | nvidia-nvjitlink-cu12==12.4.99 62 | nvidia-nvtx-cu12==12.1.105 63 | orjson==3.9.15 64 | packaging==24.0 65 | pandas==2.2.1 66 | peft==0.9.0 67 | pillow==10.2.0 68 | pip==24.0 69 | platformdirs==4.2.2 70 | protobuf==4.25.3 71 | psutil==5.9.8 72 | py-cpuinfo==9.0.0 73 | pydantic==2.6.4 74 | pydantic_core==2.16.3 75 | pydub==0.25.1 76 | Pygments==2.17.2 77 | pynvml==11.5.0 78 | pyparsing==3.1.2 79 | python-dateutil==2.9.0.post0 80 | python-multipart==0.0.9 81 | pytz==2024.1 82 | PyYAML==6.0.1 83 | referencing==0.33.0 84 | regex==2023.12.25 85 | requests==2.31.0 86 | rich==13.7.1 87 | rpds-py==0.18.0 88 | ruff==0.3.2 89 | 
safetensors==0.4.2 90 | scikit-learn==1.2.2 91 | scipy==1.12.0 92 | semantic-version==2.10.0 93 | sentencepiece==0.1.99 94 | sentry-sdk==2.2.1 95 | setproctitle==1.3.3 96 | setuptools==68.2.2 97 | shellingham==1.5.4 98 | shortuuid==1.0.13 99 | six==1.16.0 100 | smmap==5.0.1 101 | sniffio==1.3.1 102 | starlette==0.36.3 103 | svgwrite==1.4.3 104 | sympy==1.12 105 | threadpoolctl==3.3.0 106 | timm==0.6.13 107 | tokenizers==0.15.1 108 | tomlkit==0.12.0 109 | toolz==0.12.1 110 | torch==2.1.2 111 | torchvision==0.16.2 112 | tqdm==4.66.2 113 | transformers==4.37.2 114 | triton==2.1.0 115 | typer==0.9.0 116 | typing_extensions==4.10.0 117 | tzdata==2024.1 118 | urllib3==2.2.1 119 | uvicorn==0.28.0 120 | wavedrom==2.0.3.post3 121 | websockets==11.0.3 122 | wheel==0.41.2 123 | -------------------------------------------------------------------------------- /Code/eval/calculate_result_rule.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def load_jsonl(file_path): 6 | data = [] 7 | with open(file_path, 'r', encoding='utf-8') as f: 8 | for line in f: 9 | data.append(json.loads(line.strip())) 10 | return data 11 | 12 | 13 | def main(): 14 | rule_1_file = 'rule_1_test_452.jsonl' 15 | rule_2_file = 'rule_2_test_590.jsonl' 16 | rule_3_file = 'rule_3_test_34.jsonl' 17 | 18 | rule_1_datas = load_jsonl(rule_1_file) 19 | rule_2_datas = load_jsonl(rule_2_file) 20 | rule_3_datas = load_jsonl(rule_3_file) 21 | 22 | test_file_list = os.listdir('model_exp_results/move_to_localhost/') 23 | # test_file_list = os.listdir('model_exp_results/m10/') 24 | for test_file in test_file_list: 25 | correct_count_1 = 0 26 | all_count_1 = 0 27 | correct_count_2 = 0 28 | all_count_2 = 0 29 | correct_count_3 = 0 30 | all_count_3 = 0 31 | 32 | print(f'-------------------------------{test_file}----------------------------------') 33 | test_file_path = f'model_exp_results/move_to_localhost/{test_file}' 34 | # test_file_path = f'model_exp_results/m10/{test_file}' 35 | predictions = load_jsonl(test_file_path) 36 | for predict in predictions: 37 | flag = 0 38 | for rule_1_data in rule_1_datas: 39 | if rule_1_data['id'] == predict['id']: 40 | if predict['result'] == 1: 41 | correct_count_1 += 1 42 | all_count_1 += 1 43 | flag = 1 44 | break 45 | 46 | if flag == 0: 47 | for rule_2_data in rule_2_datas: 48 | if rule_2_data['id'] == predict['id']: 49 | if predict['result'] == 1: 50 | correct_count_2 += 1 51 | all_count_2 += 1 52 | flag = 1 53 | break 54 | 55 | if flag == 0: 56 | for rule_3_data in rule_3_datas: 57 | if rule_3_data['id'] == predict['id']: 58 | if predict['result'] == 1: 59 | correct_count_3 += 1 60 | all_count_3 += 1 61 | flag = 1 62 | break 63 | 64 | if flag == 0: 65 | print("数据有问题") 66 | 67 | print(f"Accuracy for 'rule1': {(correct_count_1 / all_count_1):.4f}") 68 | print(f"Accuracy for 'rule2': {(correct_count_2 / all_count_2):.4f}") 69 | print(f"Accuracy for 'rule3': {(correct_count_3 / all_count_3):.4f}") 70 | print(f"Accuracy for 'overall': {((correct_count_1+correct_count_2+correct_count_3) / len(predictions)):.4f}") 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /Code/experiment/blip2-lora.py: -------------------------------------------------------------------------------- 1 | 2 | import requests 3 | from PIL import Image 4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 5 | from peft import PeftModel, PeftConfig 6 | import 
torch 7 | 8 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/' 9 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 10 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") 11 | model_dir = '/projects/Models/' 12 | 13 | processor = AutoProcessor.from_pretrained(f"{model_dir}blip2-opt-2.7b") 14 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b", device_map="cuda:7", load_in_8bit=True) 15 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531/epoch-9' 16 | 17 | model = PeftModel.from_pretrained(model, lora_point).to(DEVICE_INDEX) 18 | import json 19 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 20 | count = 0 21 | right_count = 0 22 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip2_lora_20240531_ep9.jsonl', 'w+', encoding="utf-8") as fout: 23 | for line in f: 24 | data = json.loads(line) 25 | question = data['question'] 26 | id = data['id'] 27 | options = data['options'] 28 | image_name = data['image'] 29 | image_filepath = image_dir + image_name 30 | image = Image.open(image_filepath) 31 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \nOutput:' 32 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX) 33 | predictions = model.generate(**inputs) 34 | print("predictions: ", predictions) 35 | output = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip().rstrip('.') 36 | count += 1 37 | if len(output) == 0: 38 | output = '--' 39 | if output.lower() in data['answer']: 40 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 41 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 42 | right_count += 1 43 | elif data['answer'] in output.lower(): 44 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 45 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 46 | right_count += 1 47 | else: 48 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 49 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 50 | print(f'{output.lower()}') 51 | print(f"{data['answer']}") 52 | print(f'right_count: {right_count}') 53 | print(f'count: {count}') 54 | print(f'accuracy: {right_count/count}') 55 | 56 | accuracy = right_count/count 57 | print(f'accuracy: {accuracy}') 58 | -------------------------------------------------------------------------------- /Code/experiment/instructblip-flan-t5-xl.py: -------------------------------------------------------------------------------- 1 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 2 | import torch 3 | from PIL import Image 4 | import requests 5 | import json 6 | 7 | 8 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/' 9 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 10 | DEVICE_INDEX = torch.device("cuda:5" if torch.cuda.is_available() else "cpu") 11 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl").to(DEVICE_INDEX) 12 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl") 13 | 14 | image_dir = 
'/projects/SpatialMQA/COCO2017/test2017/' 15 | count = 0 16 | right_count = 0 17 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}instrutblip_20240601.jsonl', 'w+', encoding="utf-8") as fout: 18 | for line in f: 19 | data = json.loads(line) 20 | question = data['question'] 21 | id = data['id'] 22 | options = data['options'] 23 | image_name = data['image'] 24 | image_filepath = image_dir + image_name 25 | image = Image.open(image_filepath) 26 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:' 27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX) 28 | outputs = model.generate( 29 | **inputs, 30 | do_sample=False, 31 | num_beams=5, 32 | max_length=256, 33 | min_length=1, 34 | top_p=0.9, 35 | repetition_penalty=1.5, 36 | length_penalty=1.0, 37 | temperature=1, 38 | ) 39 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip().rstrip('.') 40 | print(generated_text) 41 | output = generated_text 42 | count += 1 43 | if len(output) == 0: 44 | output = '--' 45 | if output.lower() in data['answer']: 46 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 47 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 48 | right_count += 1 49 | elif data['answer'] in output.lower(): 50 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 51 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 52 | right_count += 1 53 | else: 54 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 55 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 56 | print(f'{output.lower()}') 57 | print(f"{data['answer']}") 58 | print(f'right_count: {right_count}') 59 | print(f'count: {count}') 60 | print(f'accuracy: {right_count/count}') 61 | 62 | accuracy = right_count/count 63 | print(f'accuracy: {accuracy}') 64 | -------------------------------------------------------------------------------- /Code/experiment/instructblip-lora.py: -------------------------------------------------------------------------------- 1 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 2 | import torch 3 | from peft import PeftModel, PeftConfig 4 | from PIL import Image 5 | import requests 6 | import re 7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/' 8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 9 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 10 | DEVICE_INDEX = 6 11 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl").to(DEVICE_INDEX) 12 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl") 13 | 14 | device = "cuda:6" if torch.cuda.is_available() else "cpu" 15 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530' 16 | model = PeftModel.from_pretrained(model, lora_point).to(device) 17 | 18 | 19 | import json 20 | 21 | count = 0 22 | right_count = 0 23 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}instrutblip_lora_20240601.jsonl', 'w+', encoding="utf-8") as fout: 24 | 
for line in f: 25 | data = json.loads(line) 26 | question = data['question'] 27 | id = data['id'] 28 | options = data['options'] 29 | image_name = data['image'] 30 | image_filepath = image_dir + image_name 31 | image = Image.open(image_filepath) 32 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:' 33 | inputs = processor(images=image, text=question, return_tensors="pt").to(device) 34 | outputs = model.generate( 35 | **inputs, 36 | do_sample=False, 37 | num_beams=5, 38 | max_length=8, 39 | min_length=1, 40 | top_p=0.9, 41 | repetition_penalty=1.5, 42 | length_penalty=1.0, 43 | temperature=1, 44 | ) 45 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() 46 | print(generated_text) 47 | output = generated_text 48 | 49 | count += 1 50 | if len(output) == 0: 51 | output = '--' 52 | if output.lower() in data['answer']: 53 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 54 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 55 | right_count += 1 56 | elif data['answer'] in output.lower(): 57 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 58 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 59 | right_count += 1 60 | else: 61 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 62 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 63 | print(f'{output.lower()}') 64 | print(f"{data['answer']}") 65 | print(f'right_count: {right_count}') 66 | print(f'count: {count}') 67 | print(f'accuracy: {right_count/count}') 68 | 69 | accuracy = right_count/count 70 | print(f'accuracy: {accuracy}') 71 | -------------------------------------------------------------------------------- /Code/experiment/spatial_test_mplug.py: -------------------------------------------------------------------------------- 1 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration 2 | from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer 3 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor 4 | import torch 5 | from PIL import Image 6 | import json 7 | 8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 9 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/prompt test/' 10 | pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b' 11 | device="cuda:7" 12 | model = MplugOwlForConditionalGeneration.from_pretrained( 13 | pretrained_ckpt, 14 | torch_dtype=torch.bfloat16, 15 | ).to(device) 16 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt) 17 | tokenizer = MplugOwlTokenizer.from_pretrained('MAGAer13/mplug-owl-llama-7b') 18 | processor = MplugOwlProcessor(image_processor, tokenizer) 19 | 20 | generate_kwargs = { 21 | 'do_sample': True, 22 | 'top_k': 1, 23 | 'max_length': 512, 24 | 'temperature':0.1 25 | } 26 | 27 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 28 | 29 | count = 0 30 | right_count = 0 31 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}mplug_1.jsonl', 'w+', encoding="utf-8") as fout: 32 | for line in f: 33 | data = json.loads(line) 34 | question = data['question'] 35 | id = data['id'] 36 | options = 
data['options'] 37 | image_name = data['image'] 38 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:' 39 | if not count: 40 | print(f'question:{question}') 41 | image_filepath = image_dir + image_name 42 | images = [Image.open(image_filepath)] 43 | prompts = [ 44 | f'''The following is a conversation between a curious human and AI assistant. 45 | Human: 46 | Human: {question} 47 | AI: ''' 48 | ] 49 | inputs = processor(text=prompts, images=images, return_tensors='pt') 50 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()} 51 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 52 | with torch.no_grad(): 53 | res = model.generate(**inputs, **generate_kwargs) 54 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True) 55 | print(sentence) 56 | count += 1 57 | output = sentence.strip().rstrip('.') 58 | if len(output) == 0: 59 | output = '--' 60 | if output.lower() in data['answer']: 61 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 62 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 63 | right_count += 1 64 | elif data['answer'] in output.lower(): 65 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 66 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 67 | right_count += 1 68 | else: 69 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 70 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 71 | print(f'{output.lower()}') 72 | print(f"{data['answer']}") 73 | print(f'right_count: {right_count}') 74 | print(f'count: {count}') 75 | print(f'accuracy: {right_count/count}') 76 | 77 | accuracy = right_count/count 78 | print(f'accuracy: {accuracy}') 79 | -------------------------------------------------------------------------------- /Metadata/metadata_hf.json: -------------------------------------------------------------------------------- 1 | {"@context":{"@language":"en","@vocab":"https://schema.org/","arrayShape":"cr:arrayShape","citeAs":"cr:citeAs","column":"cr:column","conformsTo":"dct:conformsTo","cr":"http://mlcommons.org/croissant/","data":{"@id":"cr:data","@type":"@json"},"dataBiases":"cr:dataBiases","dataCollection":"cr:dataCollection","dataType":{"@id":"cr:dataType","@type":"@vocab"},"dct":"http://purl.org/dc/terms/","extract":"cr:extract","field":"cr:field","fileProperty":"cr:fileProperty","fileObject":"cr:fileObject","fileSet":"cr:fileSet","format":"cr:format","includes":"cr:includes","isArray":"cr:isArray","isLiveDataset":"cr:isLiveDataset","jsonPath":"cr:jsonPath","key":"cr:key","md5":"cr:md5","parentField":"cr:parentField","path":"cr:path","personalSensitiveInformation":"cr:personalSensitiveInformation","recordSet":"cr:recordSet","references":"cr:references","regex":"cr:regex","repeated":"cr:repeated","replace":"cr:replace","sc":"https://schema.org/","separator":"cr:separator","source":"cr:source","subField":"cr:subField","transform":"cr:transform"},"@type":"sc:Dataset","distribution":[{"@type":"cr:FileObject","@id":"repo","name":"repo","description":"The Hugging Face git 
repository.","contentUrl":"https://huggingface.co/datasets/liuziyan/SpatialMQA/tree/refs%2Fconvert%2Fparquet","encodingFormat":"git+https","sha256":"https://github.com/mlcommons/croissant/issues/80"},{"@type":"cr:FileSet","@id":"parquet-files-for-config-default","containedIn":{"@id":"repo"},"encodingFormat":"application/x-parquet","includes":"default/*/*.parquet"}],"recordSet":[{"@type":"cr:RecordSet","dataType":"cr:Split","key":{"@id":"default_splits/split_name"},"@id":"default_splits","name":"default_splits","description":"Splits for the default config.","field":[{"@type":"cr:Field","@id":"default_splits/split_name","dataType":"sc:Text"}],"data":[{"default_splits/split_name":"train"},{"default_splits/split_name":"validation"},{"default_splits/split_name":"test"}]},{"@type":"cr:RecordSet","@id":"default","description":"liuziyan/SpatialMQA - 'default' subset\n\nAdditional information:\n- 3 splits: train, validation, test","field":[{"@type":"cr:Field","@id":"default/split","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"fileProperty":"fullpath"},"transform":{"regex":"default/(?:partial-)?(train|validation|test)/.+parquet$"}},"references":{"field":{"@id":"default_splits/split_name"}}},{"@type":"cr:Field","@id":"default/image","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"image"}}},{"@type":"cr:Field","@id":"default/question","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"question"}}},{"@type":"cr:Field","@id":"default/options","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"options"}},"isArray":true,"arrayShape":"-1"},{"@type":"cr:Field","@id":"default/answer","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"answer"}}},{"@type":"cr:Field","@id":"default/_answer","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":",answer"}}}]}],"conformsTo":"http://mlcommons.org/croissant/1.1","name":"SpatialMQA","description":"Welcome to explore our work titled \"Can Multimodal Large Language Models Understand Spatial Relations\".arXiv link: https://arxiv.org/abs/2505.19015.For more information about the paper and the SpatialMQA dataset, please visit our GitHub repository at https://github.com/ziyan-xiaoyu/SpatialMQA.\n\n\n\t\n\t\t\n\t\tlicense: cc-by-4.0\n\t\n\n","alternateName":["liuziyan/SpatialMQA"],"creator":{"@type":"Person","name":"刘子言","url":"https://huggingface.co/liuziyan"},"keywords":["1K - 10K","json","Image","Text","Datasets","pandas","Croissant","Polars","arxiv:2505.19015","🇺🇸 Region: US"],"url":"https://huggingface.co/datasets/liuziyan/SpatialMQA"} 2 | -------------------------------------------------------------------------------- /Code/experiment/spatial_test_mplug_lora.py: -------------------------------------------------------------------------------- 1 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration 2 | from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer 3 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor 4 | from peft import PeftModel, PeftConfig 5 | import torch 6 | from PIL import Image 7 | import json 8 | 9 | 10 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 11 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/m10/' 12 | pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b' 13 | 
device="cuda:6" 14 | peft_model_id = '/projects/SpatialMQA/finetune_models/models_arg/mplug_lora_20240530' 15 | model = MplugOwlForConditionalGeneration.from_pretrained( 16 | pretrained_ckpt, 17 | torch_dtype=torch.bfloat16, 18 | ).to(device) 19 | model = PeftModel.from_pretrained(model, peft_model_id).to(device) 20 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt) 21 | tokenizer = MplugOwlTokenizer.from_pretrained('MAGAer13/mplug-owl-llama-7b') 22 | processor = MplugOwlProcessor(image_processor, tokenizer) 23 | 24 | generate_kwargs = { 25 | 'do_sample': True, 26 | 'top_k': 0, 27 | 'max_length': 512, 28 | 'temperature': 0.1 29 | } 30 | 31 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 32 | count = 0 33 | right_count = 0 34 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}mplug_lora_0.jsonl', 'w+', encoding="utf-8") as fout: 35 | for line in f: 36 | data = json.loads(line) 37 | question = data['question'] 38 | id = data['id'] 39 | options = data['options'] 40 | image_name = data['image'] 41 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:' 42 | if not count: 43 | print(f'question:{question}') 44 | image_filepath = image_dir + image_name 45 | images = [Image.open(image_filepath)] 46 | prompts = [ 47 | f'''The following is a conversation between a curious human and AI assistant. 48 | Human: 49 | Human: {question} 50 | AI: ''' 51 | ] 52 | inputs = processor(text=prompts, images=images, return_tensors='pt') 53 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()} 54 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 55 | with torch.no_grad(): 56 | res = model.generate(**inputs, **generate_kwargs) 57 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True) 58 | print(sentence) 59 | count += 1 60 | output = sentence.strip().rstrip('.') 61 | if len(output) == 0: 62 | output = '--' 63 | if output.lower() in data['answer']: 64 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 65 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 66 | right_count += 1 67 | elif data['answer'] in output.lower(): 68 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 69 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 70 | right_count += 1 71 | else: 72 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 73 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 74 | print(f'{output.lower()}') 75 | print(f"{data['answer']}") 76 | print(f'right_count: {right_count}') 77 | print(f'count: {count}') 78 | print(f'accuracy: {right_count/count}') 79 | 80 | accuracy = right_count/count 81 | print(f'accuracy: {accuracy}') 82 | -------------------------------------------------------------------------------- /Code/experiment/idefics_new.py: -------------------------------------------------------------------------------- 1 | # this is a demo of inference of IDEFICS-9B which needs about 20GB of GPU memory 2 | import json 3 | import re 4 | from PIL import Image 5 | import torch 6 | from transformers import IdeficsForVisionText2Text, AutoProcessor, 
BitsAndBytesConfig 7 | from peft import PeftModel, PeftConfig 8 | 9 | device = "cuda:5" 10 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/idefics-10/' 11 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 12 | 13 | model_path = '/projects/Models/idefics/' 14 | 15 | bnb_config = BitsAndBytesConfig( 16 | load_in_4bit=True, 17 | bnb_4bit_use_double_quant=True, 18 | bnb_4bit_quant_type="nf4", 19 | bnb_4bit_compute_dtype=torch.float16, 20 | llm_int8_skip_modules=["lm_head", "embed_tokens"], 21 | ) 22 | 23 | config = PeftConfig.from_pretrained(model_path) 24 | processor = AutoProcessor.from_pretrained(config.base_model_name_or_path) 25 | model = IdeficsForVisionText2Text.from_pretrained(config.base_model_name_or_path,quantization_config=bnb_config,device_map="cuda:5") 26 | 27 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 28 | count = 0 29 | right_count = 0 30 | 31 | def check_inference(model, processor, prompts, max_new_tokens=50): 32 | tokenizer = processor.tokenizer 33 | bad_words = ["", ""] 34 | if len(bad_words) > 0: 35 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids 36 | 37 | eos_token = "" 38 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) 39 | 40 | inputs = processor(prompts, return_tensors="pt").to(device) 41 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True, 42 | ) 43 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] 44 | return generated_text 45 | 46 | 47 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}idefics_test.jsonl', 'w+', encoding="utf-8") as fout: 48 | for line in f: 49 | data = json.loads(line) 50 | question = data['question'] 51 | id = data['id'] 52 | options = data['options'] 53 | image_name = data['image'] 54 | image_filepath = image_dir + image_name 55 | image = Image.open(image_filepath) 56 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. 
\n Output:' 57 | prompts = [ 58 | image, 59 | question, 60 | ] 61 | 62 | generated_text = check_inference(model, processor, prompts, max_new_tokens=5) 63 | print(f'generated_text:\n{generated_text}') 64 | output = generated_text.lower().split('answer:')[-1].split('\n')[0].strip().rstrip('.') 65 | count += 1 66 | if len(output) == 0: 67 | output = '--' 68 | if output in data['answer']: 69 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 70 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 71 | right_count += 1 72 | elif data['answer'] in output: 73 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 74 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 75 | right_count += 1 76 | else: 77 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 78 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 79 | print(f'{output.lower()}') 80 | print(f"{data['answer']}") 81 | print(f'right_count: {right_count}') 82 | print(f'count: {count}') 83 | print(f'accuracy: {right_count/count}') 84 | 85 | accuracy = right_count/count 86 | print(f'accuracy: {accuracy}') -------------------------------------------------------------------------------- /Code/experiment/idefics_lora.py: -------------------------------------------------------------------------------- 1 | # this is a demo of inference of IDEFICS-9B which needs about 20GB of GPU memory 2 | 3 | import torch 4 | from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig 5 | 6 | device = "cuda:7" if torch.cuda.is_available() else "cpu" 7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/idefics-10/' 8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 9 | 10 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521/checkpoint-35' 11 | 12 | from peft import PeftModel, PeftConfig 13 | 14 | bnb_config = BitsAndBytesConfig( 15 | load_in_4bit=True, 16 | bnb_4bit_use_double_quant=True, 17 | bnb_4bit_quant_type="nf4", 18 | bnb_4bit_compute_dtype=torch.float16, 19 | llm_int8_skip_modules=["lm_head", "embed_tokens"], 20 | ) 21 | 22 | config = PeftConfig.from_pretrained(lora_point) 23 | processor = AutoProcessor.from_pretrained(config.base_model_name_or_path) 24 | model = IdeficsForVisionText2Text.from_pretrained(config.base_model_name_or_path,quantization_config=bnb_config,device_map="cuda:7") 25 | model = PeftModel.from_pretrained(model, lora_point).to(device) 26 | 27 | 28 | import json 29 | import re 30 | from PIL import Image 31 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 32 | count = 0 33 | right_count = 0 34 | 35 | def check_inference(model, processor, prompts, max_new_tokens=50): 36 | tokenizer = processor.tokenizer 37 | bad_words = ["", ""] 38 | if len(bad_words) > 0: 39 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids 40 | 41 | eos_token = "" 42 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) 43 | 44 | inputs = processor(prompts, return_tensors="pt").to(device) 45 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True, 46 | do_sample=True, temperature=0.7) 47 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] 48 | return generated_text 49 | 50 | 51 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}idefics_lora_7.jsonl', 'w+', 
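Editor's note: in `idefics_new.py` (and the other IDEFICS scripts in this dump) `bad_words = ["", ""]` and `eos_token = ""` appear to have lost their angle-bracket token strings, most likely stripped when the files were exported. The values below are not recovered from this repository; they are a hedged reconstruction based on the public Hugging Face IDEFICS inference examples that these scripts follow, which filter the image placeholder tokens and stop on the LLaMA end-of-sequence token. `device` refers to the module-level device defined at the top of the script.

```python
def check_inference(model, processor, prompts, max_new_tokens=50):
    tokenizer = processor.tokenizer
    # Tokens the model should never emit: the image placeholder and the
    # wrapper token IDEFICS places around images (values assumed from the
    # upstream IDEFICS examples, not from this repo).
    bad_words = ["<image>", "<fake_token_around_image>"]
    bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids

    # Stop generation at the LLaMA end-of-sequence token.
    eos_token_id = tokenizer.convert_tokens_to_ids("</s>")

    inputs = processor(prompts, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        eos_token_id=[eos_token_id],
        bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
        early_stopping=True,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```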
encoding="utf-8") as fout: 52 | for line in f: 53 | data = json.loads(line) 54 | question = data['question'] 55 | id = data['id'] 56 | options = data['options'] 57 | image_name = data['image'] 58 | image_filepath = image_dir + image_name 59 | image = Image.open(image_filepath) 60 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \nOutput:' 61 | prompts = [ 62 | image, 63 | question, 64 | ] 65 | 66 | generated_text = check_inference(model, processor, prompts, max_new_tokens=5) 67 | print(f'generated_text:\n{generated_text}') 68 | output = generated_text.lower().split('answer:')[-1].split('\n')[0].strip().rstrip('.') 69 | count += 1 70 | if len(output) == 0: 71 | output = '--' 72 | if output.lower() in data['answer']: 73 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 74 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 75 | right_count += 1 76 | elif data['answer'] in output: 77 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 78 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 79 | right_count += 1 80 | else: 81 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 82 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 83 | print(f'{output.lower()}') 84 | print(f"{data['answer']}") 85 | print(f'right_count: {right_count}') 86 | print(f'count: {count}') 87 | print(f'accuracy: {right_count/count}') 88 | 89 | accuracy = right_count/count 90 | print(f'accuracy: {accuracy}') 91 | -------------------------------------------------------------------------------- /Code/close_models/gemini_text_only.py: -------------------------------------------------------------------------------- 1 | import google.generativeai as genai 2 | from io import BytesIO 3 | import requests 4 | 5 | genai.configure(api_key="your key",transport="rest") 6 | 7 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480} 8 | safety_settings = [ 9 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 10 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 11 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 12 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"} 13 | ] 14 | 15 | FILE_PATH = 'datasets/test_en_select_500_sort.jsonl' 16 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 17 | RESULT_FILE_PATH = 'model_results/gemini_0_shot.jsonl' 18 | 19 | 20 | def fetch_image_content(image_url): 21 | response = requests.get(image_url) 22 | if response.status_code == 200: 23 | return BytesIO(response.content) 24 | else: 25 | return None 26 | 27 | 28 | import PIL.Image as Image 29 | 30 | # few-shot & zero-shot: 31 | # model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 32 | 33 | # text-only: 34 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 35 | 36 | import json 37 | 38 | count = 0 39 | right_count = 0 40 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'w+', encoding="utf-8") as fout: 41 | for line in f: 42 | data = 
json.loads(line) 43 | id = data['id'] 44 | 45 | # text-only: 46 | question = f'You are currently a senior expert in spatial relation reasoning. ' \ 47 | f'\nGiven a Question and Options, your task is to use your common sense to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 48 | f'\nInput: Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 49 | 50 | image_name = data['image'] 51 | image_filepath = f'{IMAGE_DIR}/{image_name}' 52 | 53 | image_content = fetch_image_content(image_filepath) 54 | if image_content is not None: 55 | image = Image.open(image_content) 56 | try: 57 | response = model.generate_content( 58 | # [question, image] # few-shot & zero-shot 59 | [question] # text-only 60 | ) 61 | except Exception as e: 62 | print(e) 63 | try: 64 | response = model.generate_content( 65 | # [question, image] # few-shot & zero-shot 66 | [question] # text-only 67 | ) 68 | except Exception as e: 69 | try: 70 | response = model.generate_content( 71 | # [question, image] # few-shot & zero-shot 72 | [question] # text-only 73 | ) 74 | except Exception as e: 75 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0} 76 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 77 | count += 1 78 | continue 79 | try: 80 | output = response.text.strip().rstrip('.').lower() 81 | except Exception as e: 82 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 83 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 84 | count += 1 85 | continue 86 | count += 1 87 | if output in data['answer']: 88 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 89 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 90 | right_count += 1 91 | else: 92 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 93 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 94 | print(f'{output.lower()}') 95 | print(f"{data['answer']}") 96 | print(f'right_count: {right_count}') 97 | print(f'count: {count}') 98 | 99 | accuracy = right_count / count 100 | print(accuracy) 101 | -------------------------------------------------------------------------------- /Code/close_models/gemini_0_shot.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | 3 | import google.generativeai as genai 4 | from io import BytesIO 5 | import requests 6 | 7 | genai.configure(api_key="your key",transport="rest") 8 | 9 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480} 10 | safety_settings = [ 11 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 12 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 13 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 14 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"} 15 | ] 16 | 17 | FILE_PATH = 'datasets/test_en_select_500_sort.jsonl' 18 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 19 | RESULT_FILE_PATH = 'model_results/gemini_0_shot.jsonl' 20 | 21 | 22 | def fetch_image_content(image_url): 23 | response = requests.get(image_url) 24 | if response.status_code == 200: 25 | return 
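Editor's note: the Gemini scripts repeat `model.generate_content(...)` inside three nested `try/except` blocks to ride out transient API failures. A loop-based retry with a short pause keeps the same behaviour in fewer lines; the helper name and the 3-attempt/5-second defaults below are illustrative choices, and `parts` is the same list passed to `generate_content` in the scripts (`[question]` for text-only, `[question, image]` for zero-/few-shot).

```python
import time

def generate_with_retry(model, parts, attempts=3, pause=5):
    """Call model.generate_content(parts) up to `attempts` times,
    sleeping `pause` seconds between failures; return None if all fail."""
    for i in range(attempts):
        try:
            return model.generate_content(parts)
        except Exception as e:
            print(f"generate_content failed (attempt {i + 1}/{attempts}): {e}")
            time.sleep(pause)
    return None
```

When the helper returns `None`, the caller can write the same `"output": "--"` record the scripts above produce and move on.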
BytesIO(response.content) 26 | else: 27 | return None 28 | 29 | 30 | import PIL.Image as Image 31 | 32 | # few-shot & zero-shot: 33 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 34 | 35 | # text-only: 36 | # model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 37 | 38 | import json 39 | 40 | count = 0 41 | right_count = 0 42 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout: 43 | for line in f: 44 | sleep(1) 45 | data = json.loads(line) 46 | id = data['id'] 47 | 48 | # 1 - shot: 49 | question = f'You are currently a senior expert in spatial relation reasoning. ' \ 50 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 51 | f'\nInput: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 52 | 53 | image_name = data['image'] 54 | image_filepath = f'{IMAGE_DIR}/{image_name}' 55 | 56 | image_content = fetch_image_content(image_filepath) 57 | if image_content is not None: 58 | image = Image.open(image_content) 59 | try: 60 | response = model.generate_content( 61 | [question, image] # few-shot & zero-shot 62 | # [question] # text-only 63 | ) 64 | except Exception as e: 65 | print(e) 66 | try: 67 | response = model.generate_content( 68 | [question, image] # few-shot & zero-shot 69 | # [question] # text-only 70 | ) 71 | except Exception as e: 72 | try: 73 | response = model.generate_content( 74 | [question, image] # few-shot & zero-shot 75 | # [question] # text-only 76 | ) 77 | except Exception as e: 78 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0} 79 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 80 | count += 1 81 | continue 82 | try: 83 | output = response.text.strip().rstrip('.').lower() 84 | except Exception as e: 85 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 86 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 87 | count += 1 88 | continue 89 | count += 1 90 | if output in data['answer']: 91 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 92 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 93 | right_count += 1 94 | else: 95 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0} 96 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 97 | print(f'{output.lower()}') 98 | print(f"{data['answer']}") 99 | print(f'right_count: {right_count}') 100 | print(f'count: {count}') 101 | 102 | accuracy = right_count / count 103 | print(accuracy) 104 | -------------------------------------------------------------------------------- /Metadata/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": { 3 | "@language": "en", 4 | "@vocab": "https://schema.org/", 5 | "citeAs": "cr:citeAs", 6 | "column": "cr:column", 7 | "conformsTo": "dct:conformsTo", 8 | "cr": "http://mlcommons.org/croissant/", 9 | "rai": "http://mlcommons.org/croissant/RAI/", 10 | "data": { 11 | "@id": "cr:data", 12 | "@type": "@json" 13 | }, 14 
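Editor's note: `gemini_0_shot.py` opens its result file in append mode, so an interrupted run can be resumed; skipping ids that were already scored avoids paying for duplicate API calls. A sketch of that pre-pass, assuming the result-file layout shown above (one `{"id": ..., "result": ...}` object per line); `already_scored` is an illustrative helper name.

```python
import json
import os

def already_scored(result_path):
    """Collect the ids already present in a (possibly partial) result file."""
    done = set()
    if os.path.exists(result_path):
        with open(result_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    done.add(json.loads(line)['id'])
                except (json.JSONDecodeError, KeyError):
                    continue  # ignore a truncated trailing line
    return done

# Build the set once before the main loop, then inside the loop:
#     if data['id'] in done:
#         continue
```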
| "dataType": { 15 | "@id": "cr:dataType", 16 | "@type": "@vocab" 17 | }, 18 | "dct": "http://purl.org/dc/terms/", 19 | "examples": { 20 | "@id": "cr:examples", 21 | "@type": "@json" 22 | }, 23 | "extract": "cr:extract", 24 | "field": "cr:field", 25 | "fileProperty": "cr:fileProperty", 26 | "fileObject": "cr:fileObject", 27 | "fileSet": "cr:fileSet", 28 | "format": "cr:format", 29 | "includes": "cr:includes", 30 | "isLiveDataset": "cr:isLiveDataset", 31 | "jsonPath": "cr:jsonPath", 32 | "key": "cr:key", 33 | "md5": "cr:md5", 34 | "parentField": "cr:parentField", 35 | "path": "cr:path", 36 | "recordSet": "cr:recordSet", 37 | "references": "cr:references", 38 | "regex": "cr:regex", 39 | "repeated": "cr:repeated", 40 | "replace": "cr:replace", 41 | "sc": "https://schema.org/", 42 | "separator": "cr:separator", 43 | "source": "cr:source", 44 | "subField": "cr:subField", 45 | "transform": "cr:transform" 46 | }, 47 | "@type": "sc:Dataset", 48 | "name": "SpatialMQA", 49 | "description": "SpatialMQA is a manually annotated dataset designed for multimodal spatial relation reasoning in a multiple-choice question & answer format. The dataset includes 5,392 samples collected from COCO2017, covering 128 subject and object types, without bounding boxes.", 50 | "conformsTo": "http://mlcommons.org/croissant/1.0", 51 | "license": "CC-BY 4.0", 52 | "url": " https://anonymous.4open.science/r/SpatialMQA", // Anonymous link, will be modified to real name link after the paper is reviewed 53 | "version": "1.0.0", 54 | "distribution": [ 55 | { 56 | "@type": "cr:FileObject", 57 | "@id": "github-repository", 58 | "name": "github-repository", 59 | "description": "SpatialMQA repository on GitHub.", 60 | "contentUrl": " https://anonymous.4open.science/r/SpatialMQA", 61 | "encodingFormat": "git+https", 62 | "sha256": "main" 63 | }, 64 | { 65 | "@type": "cr:FileSet", 66 | "@id": "jsonl-files", 67 | "name": "jsonl-files", 68 | "description": "JSONL files are hosted on the GitHub repository.", 69 | "containedIn": { 70 | "@id": "github-repository" 71 | }, 72 | "encodingFormat": "application/jsonlines", 73 | "includes": "Dataset/dataset/*.jsonl" 74 | } 75 | ], 76 | "recordSet": [ 77 | { 78 | "@type": "cr:RecordSet", 79 | "@id": "jsonl", 80 | "name": "jsonl", 81 | "field": [ 82 | { 83 | "@type": "cr:Field", 84 | "@id": "jsonl/image", 85 | "name": "image", 86 | "dataType": "sc:Text", 87 | "source": { 88 | "fileSet": { 89 | "@id": "jsonl-files" 90 | }, 91 | "extract": { 92 | "column": "image" 93 | } 94 | } 95 | }, 96 | { 97 | "@type": "cr:Field", 98 | "@id": "jsonl/question", 99 | "name": "question", 100 | "description": "The expected question of the promt.", 101 | "dataType": "sc:Text", 102 | "source": { 103 | "fileSet": { 104 | "@id": "jsonl-files" 105 | }, 106 | "extract": { 107 | "column": "question" 108 | } 109 | } 110 | }, 111 | { 112 | "@type": "cr:Field", 113 | "@id": "jsonl/options", 114 | "name": "options", 115 | "description": "The expected options of the promt.", 116 | "dataType": "sc:Text", 117 | "source": { 118 | "fileSet": { 119 | "@id": "jsonl-files" 120 | }, 121 | "extract": { 122 | "column": "options" 123 | } 124 | } 125 | }, 126 | { 127 | "@type": "cr:Field", 128 | "@id": "jsonl/answer", 129 | "name": "answer", 130 | "description": "The expected options of the promt.", 131 | "dataType": "sc:Text", 132 | "source": { 133 | "fileSet": { 134 | "@id": "jsonl-files" 135 | }, 136 | "extract": { 137 | "column": "answer" 138 | } 139 | } 140 | }, 141 | { 142 | "@type": "cr:Field", 143 | "@id": "jsonl/task", 144 | 
"name": "task", 145 | "description": "The machine learning task appearing as the name of the file.", 146 | "dataType": "sc:Text", 147 | "source": { 148 | "fileSet": { 149 | "@id": "jsonl-files" 150 | }, 151 | "extract": { 152 | "fileProperty": "filename" 153 | }, 154 | "transform": { 155 | "regex": "^(.*)\\.jsonl$" 156 | } 157 | } 158 | } 159 | ] 160 | } 161 | ] 162 | } 163 | -------------------------------------------------------------------------------- /Code/finetune/idefics.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/image_captioning.ipynb 2 | 3 | # This example demonstrates normal finetuning (w/o peft) - for the sake of keeping the memory 4 | # requirements small it freezes the original pre-trained text and image layers to keep the memory 5 | # requirements to just 40GB. If you have multiple GPUs then you can remove the unfreeze part to 6 | # finetune the whole model. Alternatively use the PEFT solution as shown in 7 | # IDEFICS_finetuning_demo.ipynb notebook which requires only 20GB to finetune the whole model. 8 | 9 | import torch 10 | from datasets import load_dataset 11 | from peft import LoraConfig, get_peft_model 12 | from PIL import Image 13 | from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig 14 | import torchvision.transforms as transforms 15 | 16 | device = "cuda:7" if torch.cuda.is_available() else "cpu" 17 | 18 | checkpoint = "HuggingFaceM4/idefics-9b" 19 | 20 | # Here we skip some special modules that can't be quantized properly 21 | bnb_config = BitsAndBytesConfig( 22 | load_in_4bit=True, 23 | bnb_4bit_use_double_quant=True, 24 | bnb_4bit_quant_type="nf4", 25 | bnb_4bit_compute_dtype=torch.float16, 26 | llm_int8_skip_modules=["lm_head", "embed_tokens"], 27 | ) 28 | 29 | processor = AutoProcessor.from_pretrained(checkpoint) 30 | # Simply take-off the quantization_config arg if you want to load the original model 31 | model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map="cuda:0") 32 | 33 | 34 | def check_inference(model, processor, prompts, max_new_tokens=50): 35 | tokenizer = processor.tokenizer 36 | bad_words = ["", ""] 37 | if len(bad_words) > 0: 38 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids 39 | 40 | eos_token = "" 41 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) 42 | 43 | inputs = processor(prompts, return_tensors="pt").to(device) 44 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, 45 | max_new_tokens=max_new_tokens, early_stopping=True) 46 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] 47 | print(generated_text) 48 | 49 | 50 | url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg" 51 | prompts = [ 52 | # "Instruction: provide an answer to the question. Use the image to answer.\n", 53 | url, 54 | "Question: What's on the picture? 
Answer:", 55 | ] 56 | check_inference(model, processor, prompts) 57 | 58 | train_ds = load_dataset("json", data_files='/projects/SpatialMQA/datasets/idefics_train/train_3780.jsonl')['train'] 59 | eval_ds = load_dataset("json", data_files='/projects/SpatialMQA/datasets/idefics_train/dev_536.jsonl')['train'] 60 | 61 | 62 | def convert_to_rgb(image): 63 | # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background 64 | # for transparent images. The call to `alpha_composite` handles this case 65 | image = Image.open(image) 66 | if image.mode == "RGB": 67 | return image 68 | 69 | image_rgba = image.convert("RGBA") 70 | background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) 71 | alpha_composite = Image.alpha_composite(background, image_rgba) 72 | alpha_composite = alpha_composite.convert("RGB") 73 | return alpha_composite 74 | 75 | 76 | def ds_transforms(example_batch): 77 | image_size = processor.image_processor.image_size 78 | image_mean = processor.image_processor.image_mean 79 | image_std = processor.image_processor.image_std 80 | 81 | image_transform = transforms.Compose([ 82 | convert_to_rgb, 83 | transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), 84 | interpolation=transforms.InterpolationMode.BICUBIC), 85 | transforms.ToTensor(), 86 | transforms.Normalize(mean=image_mean, std=image_std), 87 | ]) 88 | 89 | prompts = [] 90 | for i in range(len(example_batch)): 91 | prompts.append( 92 | [ 93 | example_batch["image"][i], 94 | example_batch["text"][i], 95 | ], 96 | ) 97 | 98 | inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device) 99 | 100 | inputs["labels"] = inputs["input_ids"] 101 | 102 | return inputs 103 | 104 | 105 | train_ds.set_transform(ds_transforms) 106 | eval_ds.set_transform(ds_transforms) 107 | 108 | model_name = checkpoint.split("/")[1] 109 | config = LoraConfig( 110 | r=16, 111 | lora_alpha=32, 112 | target_modules=["q_proj", "k_proj", "v_proj"], 113 | lora_dropout=0.05, 114 | bias="none", 115 | ) 116 | model = get_peft_model(model, config) 117 | 118 | num_train_epochs = 10 119 | 120 | training_args = TrainingArguments( 121 | output_dir=f"/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521", 122 | num_train_epochs=5, 123 | learning_rate=2e-4, 124 | bf16=True, 125 | fp16=False, 126 | per_device_train_batch_size=8, 127 | per_device_eval_batch_size=8, 128 | gradient_accumulation_steps=8, 129 | dataloader_pin_memory=False, 130 | save_total_limit=1, 131 | evaluation_strategy="epoch", 132 | save_strategy="epoch", 133 | logging_steps=100, 134 | remove_unused_columns=False, 135 | push_to_hub=False, 136 | label_names=["labels"], 137 | load_best_model_at_end=True, 138 | report_to=None, 139 | optim="paged_adamw_8bit", 140 | ) 141 | 142 | trainer = Trainer( 143 | model=model, 144 | args=training_args, 145 | train_dataset=train_ds, 146 | eval_dataset=eval_ds, 147 | ) 148 | 149 | trainer.train() 150 | 151 | model.save_pretrained("/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521") 152 | 153 | check_inference(model, processor, prompts, max_new_tokens=100) 154 | -------------------------------------------------------------------------------- /Code/close_models/gpt4_text_only.py: -------------------------------------------------------------------------------- 1 | # import openai 2 | from openai import OpenAI 3 | import json 4 | import requests 5 | from io import BytesIO 6 | import PIL.Image as Image 7 | import base64 8 | 9 | client = OpenAI( 10 | api_key='your 
key', 11 | base_url='https://gpt.mnxcc.com/v1' 12 | ) 13 | client.api_base = "https://api.foureast.cn/v1" 14 | 15 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 16 | count = 0 17 | right_count = 0 18 | 19 | 20 | def fetch_image_content(image_url): 21 | response = requests.get(image_url) 22 | if response.status_code == 200: 23 | return BytesIO(response.content) 24 | else: 25 | return None 26 | 27 | 28 | def encode_image(image): 29 | if image is None: 30 | return None 31 | 32 | buffered = BytesIO() 33 | try: 34 | image.save(buffered, format="JPEG") 35 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') 36 | return 'data:image/jpeg;base64,' + img_str 37 | except Exception as e: 38 | print(f"encoding error: {e}") 39 | return None 40 | 41 | 42 | def call_gpt4(prompt: str, question: str, image): 43 | try: 44 | response = client.chat.completions.create( 45 | model="gpt-4o", 46 | messages=[ 47 | { 48 | "role": "user", 49 | 50 | # text only: 51 | "content": [{"type": "text", "text": prompt}] \ 52 | + [{"type": "text", "text": question}] 53 | } 54 | ], 55 | max_tokens=500, 56 | # temperature = 0.3, 57 | ) 58 | # print(response.choices[0].message.content.strip()) 59 | return response.choices[0].message.content.strip() 60 | 61 | except Exception as e: 62 | print(f"Error during answering: {e}") 63 | return None 64 | 65 | 66 | def process_jsonl(input_file, output_file): 67 | with open(input_file, 'r', encoding='utf-8') as file: 68 | with open(output_file, 'w', encoding='utf-8') as out_file: 69 | for line in file: 70 | data = json.loads(line) 71 | id = data['id'] 72 | 73 | # text - only 74 | prompt = f'You are currently a senior expert in spatial relation reasoning. ' \ 75 | f'\nGiven a Question and Options, your task is to use your common sense to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' 76 | question = f'\nInput: Question: {data["question"]}, Options: {"; ".join(data["options"])}. 
\nOutput: ' 77 | 78 | image_id = data['image'] 79 | image_url = f'{IMAGE_DIR}/{image_id}' 80 | image_content = fetch_image_content(image_url) 81 | 82 | if image_content is not None: 83 | image = Image.open(image_content) 84 | image_encoded = encode_image(image) 85 | 86 | global count 87 | global right_count 88 | try: 89 | model_answer = call_gpt4(prompt, question, image_encoded) 90 | except Exception as e: 91 | print(e) 92 | try: 93 | model_answer = call_gpt4(prompt, question, image_encoded) 94 | except Exception as e: 95 | try: 96 | model_answer = call_gpt4(prompt, question, image_encoded) 97 | except Exception as e: 98 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], 99 | "rule": data['rule'], "example": 0} 100 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 101 | count += 1 102 | continue 103 | 104 | try: 105 | # output = model_answer.text.strip().rstrip('.').lower() 106 | output = model_answer.strip().rstrip('.').lower() 107 | except Exception as e: 108 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 109 | "rule": data['rule'], "example": 0} 110 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 111 | count += 1 112 | continue 113 | count += 1 114 | if output in data['answer']: 115 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], 116 | "rule": data['rule'], "example": 0} 117 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 118 | right_count += 1 119 | else: 120 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 121 | "rule": data['rule'], "example": 0} 122 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 123 | print(f'{output.lower()}') 124 | print(f"{data['answer']}") 125 | print(f'right_count: {right_count}') 126 | print(f'count: {count}') 127 | # print(f'accuracy: {right_count / count}') 128 | 129 | accuracy = right_count / count 130 | print(accuracy) 131 | 132 | 133 | input_file_path = "datasets/test_en_500.jsonl" 134 | output_file_path = "model_results/gpt4_text_only.jsonl" 135 | 136 | process_jsonl(input_file_path, output_file_path) 137 | -------------------------------------------------------------------------------- /Code/close_models/gpt4_zero_shot.py: -------------------------------------------------------------------------------- 1 | # import openai 2 | from openai import OpenAI 3 | import json 4 | import requests 5 | from io import BytesIO 6 | import PIL.Image as Image 7 | import base64 8 | 9 | 10 | client = OpenAI( 11 | api_key='your key', 12 | base_url='https://gpt.mnxcc.com/v1' 13 | ) 14 | client.api_base = "https://api.foureast.cn/v1" 15 | 16 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 17 | count = 0 18 | right_count = 0 19 | 20 | 21 | def fetch_image_content(image_url): 22 | response = requests.get(image_url) 23 | if response.status_code == 200: 24 | return BytesIO(response.content) 25 | else: 26 | return None 27 | 28 | 29 | def encode_image(image): 30 | if image is None: 31 | return None 32 | 33 | buffered = BytesIO() 34 | try: 35 | image.save(buffered, format="JPEG") 36 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') 37 | return 'data:image/jpeg;base64,' + img_str 38 | except Exception as e: 39 | print(f"encoding error: {e}") 40 | return None 41 | 42 | 43 | def call_gpt4(prompt: str, image, detail='auto'): 44 | try: 45 | response = client.chat.completions.create( 46 | model="gpt-4o", 47 | messages=[ 48 
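Editor's note: one fragility shared by the GPT-4 (and Gemini) drivers: when `call_gpt4` returns `None`, the later `model_answer.strip()` raises, and the `except` branch then reads `output`, which may not exist yet on that iteration (or may hold the previous record's value). Initialising the output before touching the reply keeps the failure path well-defined; a sketch of that safer shape, with `score_one` as an illustrative helper name and the record fields unchanged from the scripts above.

```python
import json

def score_one(data, model_answer, out_file):
    """Normalise one model reply and write the result record, never
    referencing an undefined variable on the failure path."""
    output = '--'                                   # safe default if the call failed
    if model_answer:
        output = model_answer.strip().rstrip('.').lower()
    correct = int(output != '--' and output in data['answer'])
    out_file.write(json.dumps(
        {"id": data['id'], "result": correct, "output": output,
         "answer": data['answer'], "rule": data['rule'], "example": 0},
        ensure_ascii=False) + '\n')
    return correct
```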
| { 49 | "role": "user", 50 | # zero-shot + few-shot 51 | "content": [{"type": "text", "text": prompt}] \ 52 | + [{"type": "image_url", "image_url": { "url": image, "detail": detail}}] 53 | } 54 | ], 55 | max_tokens=500, 56 | temperature=0.5, 57 | ) 58 | # print(response.choices[0].message.content.strip()) 59 | return response.choices[0].message.content.strip() 60 | 61 | except Exception as e: 62 | print(f"Error during answering: {e}") 63 | return None 64 | 65 | 66 | def process_jsonl(input_file, output_file): 67 | with open(input_file, 'r', encoding='utf-8') as file: 68 | with open(output_file, 'w', encoding='utf-8') as out_file: 69 | for line in file: 70 | data = json.loads(line) 71 | id = data['id'] 72 | 73 | # zero - shot 74 | prompt = f'You are currently a senior expert in spatial relation reasoning. ' \ 75 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 76 | f'\nInput: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 77 | 78 | image_id = data['image'] 79 | image_url = f'{IMAGE_DIR}/{image_id}' 80 | image_content = fetch_image_content(image_url) 81 | 82 | if image_content is not None: 83 | image = Image.open(image_content) 84 | image_encoded = encode_image(image) 85 | 86 | global count 87 | global right_count 88 | try: 89 | model_answer = call_gpt4(prompt, image_encoded) 90 | except Exception as e: 91 | print(e) 92 | try: 93 | model_answer = call_gpt4(prompt, image_encoded) 94 | except Exception as e: 95 | try: 96 | model_answer = call_gpt4(prompt, image_encoded) 97 | except Exception as e: 98 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], 99 | "rule": data['rule'], "example": 0} 100 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 101 | count += 1 102 | continue 103 | 104 | try: 105 | # output = model_answer.text.strip().rstrip('.').lower() 106 | output = model_answer.strip().rstrip('.').lower() 107 | except Exception as e: 108 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 109 | "rule": data['rule'], "example": 0} 110 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 111 | count += 1 112 | continue 113 | count += 1 114 | if output in data['answer']: 115 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], 116 | "rule": data['rule'], "example": 0} 117 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 118 | right_count += 1 119 | else: 120 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 121 | "rule": data['rule'], "example": 0} 122 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 123 | print(f'{output.lower()}') 124 | print(f"{data['answer']}") 125 | print(f'right_count: {right_count}') 126 | print(f'count: {count}') 127 | # print(f'accuracy: {right_count / count}') 128 | 129 | accuracy = right_count / count 130 | print(accuracy) 131 | 132 | 133 | input_file_path = "datasets/test_en_500.jsonl" 134 | output_file_path = "model_results/gpt4_zero_shot.jsonl" 135 | 136 | 137 | process_jsonl(input_file_path, output_file_path) 138 | -------------------------------------------------------------------------------- /Code/eval/calculate_xyz.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 
all_correct = 0 5 | position = {'y': ["on/above", "below"], 'z': ["in front of", "behind"], 6 | 'x': ["left of", "right of"]} 7 | 8 | 9 | def load_jsonl(file_path): 10 | data = [] 11 | with open(file_path, 'r', encoding='utf-8') as f: 12 | for line in f: 13 | data.append(json.loads(line.strip())) 14 | return data 15 | 16 | 17 | def change_output(predict): 18 | if "front" in predict['output']: 19 | predict['output'] = "in front of" 20 | elif "behind" in predict['output']: 21 | predict['output'] = "behind" 22 | elif "left" in predict['output']: 23 | predict['output'] = "left of" 24 | elif "right" in predict['output']: 25 | predict['output'] = "right of" 26 | elif "below" in predict['output']: 27 | predict['output'] = "below" 28 | elif "on" in predict['output'] or "above" in predict['output']: 29 | predict['output'] = "on/above" 30 | else: 31 | predict['output'] = "--" 32 | return predict['output'] 33 | 34 | 35 | def calculate_accuracy(predictions, option): 36 | correct_count = 0 37 | total_count = 0 38 | option1 = position[option][0] 39 | option2 = position[option][1] 40 | for predict in predictions: 41 | if predict['answer'] == option1 or predict['answer'] == option2: 42 | total_count += 1 43 | if predict['result'] == 1: 44 | correct_count += 1 45 | global all_correct 46 | all_correct += correct_count 47 | if total_count == 0: 48 | return 0.0 49 | print(option + ": " + str(correct_count) + "; " + str(total_count)) 50 | return correct_count / total_count 51 | 52 | 53 | def main(): 54 | # test_file_list = os.listdir('model_exp_results/m10/') 55 | test_file_list = os.listdir('model_exp_results/move_to_localhost/') 56 | for test_file in test_file_list: 57 | print(f'-------------------------------{test_file}----------------------------------') 58 | test_file_path = f'model_exp_results/move_to_localhost/{test_file}' 59 | # test_file_path = f'model_exp_results/m10/{test_file}' 60 | 61 | # 重新处理一下on/above的输出结果: 62 | predictions = load_jsonl(test_file_path) 63 | change_count = 0 64 | # in_front_of_count = 0 65 | for predict in predictions: 66 | if 501 > predict['id'] > 0: 67 | if "on" in predict['output'] or "above" in predict['output']: 68 | predict['result'] = 1 69 | change_count += 1 70 | # elif 1756 > predict['id'] > 945: # 处理in front of(id:958~1751) 71 | # if "front" in predict['output'] or "front" in predict['answer']: 72 | # in_front_of_count += 1 73 | # for predict in predictions: 74 | # if 512 > predict['id']: # 处理on/above(id:1~500) 75 | # if "on" in predict['output'] or "above" in predict['output']: 76 | # predict['result'] = 1 77 | # predict['output'] = "on/above" 78 | # change_count += 1 79 | # else: 80 | # predict['output'] = change_output(predict) 81 | # elif 958 > predict['id'] > 500: # 处理below(id:512~945) 82 | # if "below" in predict['output']: 83 | # predict['result'] = 1 84 | # predict['output'] = "below" 85 | # change_count += 1 86 | # else: 87 | # predict['output'] = change_output(predict) 88 | # elif 1756 > predict['id'] > 945: # 处理in front of(id:958~1751) 89 | # if "front" in predict['output']: 90 | # predict['result'] = 1 91 | # predict['output'] = "in front of" 92 | # change_count += 1 93 | # else: 94 | # predict['output'] = change_output(predict) 95 | # elif 2513 > predict['id'] > 1751: # 处理behind(id:1756~2509) 96 | # if "behind" in predict['output']: 97 | # predict['result'] = 1 98 | # predict['output'] = "behind" 99 | # change_count += 1 100 | # else: 101 | # predict['output'] = change_output(predict) 102 | # elif 3910 > predict['id'] > 2509: # 处理left of(id:2513~3906) 103 | 
# if "left" in predict['output']: 104 | # predict['result'] = 1 105 | # predict['output'] = "left of" 106 | # change_count += 1 107 | # else: 108 | # predict['output'] = change_output(predict) 109 | # elif predict['id'] > 3906: # 处理right of(id:3910~5391) 110 | # if "right" in predict['output']: 111 | # predict['result'] = 1 112 | # predict['output'] = "right of" 113 | # change_count += 1 114 | # else: 115 | # predict['output'] = change_output(predict) 116 | 117 | # with open(test_file_path, 'w', encoding='utf-8') as f: 118 | # for predict in predictions: 119 | # f.write(json.dumps(predict) + '\n') 120 | # print(f"修改此文件中的on/above的结果数据共:{change_count}条") 121 | # print(f"此文件中的in front of的结果数据共:{in_front_of_count}条") 122 | 123 | predictions = load_jsonl(test_file_path) 124 | print(len(predictions)) 125 | # answers = load_jsonl('../SpatialMQA(En)/all_en_test_1076_sort.jsonl') 126 | 127 | accuracies = {} 128 | for option in ['x', 'y', 'z']: 129 | accuracies[option] = calculate_accuracy(predictions, option) 130 | 131 | global all_correct 132 | for option, accuracy in accuracies.items(): 133 | print(f"Accuracy for '{option}': {accuracy:.4f}") 134 | print(f"Accuracy for 'overall': {(all_correct / len(predictions)):.4f}") 135 | all_correct = 0 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /Code/finetune/blip-vqa-base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from transformers import BlipProcessor, BlipForQuestionAnswering 5 | from datasets import load_dataset 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import pickle 11 | 12 | model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") 13 | processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") 14 | 15 | device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") 16 | model.to(device) 17 | 18 | torch.cuda.empty_cache() 19 | 20 | 21 | class VQADataset(torch.utils.data.Dataset): 22 | """VQA (v2) dataset.""" 23 | 24 | def __init__(self, dataset, processor): 25 | self.dataset = dataset 26 | self.processor = processor 27 | 28 | def __len__(self): 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, idx): 32 | # get image + text 33 | question = self.dataset[idx]['question'] 34 | answer = str(self.dataset[idx]['answer']) 35 | image_id = self.dataset[idx]['image'] 36 | image_path = f"/projects/SpatialMQA/COCO2017/test2017/{image_id}" 37 | image = Image.open(image_path).convert("RGB") 38 | text = question 39 | 40 | encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt") 41 | labels = self.processor.tokenizer.encode( 42 | answer, max_length=8, pad_to_max_length=True, return_tensors='pt' 43 | ) 44 | encoding["labels"] = labels 45 | # remove batch dimension 46 | for k, v in encoding.items(): encoding[k] = v.squeeze() 47 | return encoding 48 | 49 | 50 | training_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip_train/train_3780.jsonl", 51 | split="train[:100%]") 52 | valid_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip_train/dev_536.jsonl", 53 | split="train[:100%]") 54 | print("Training sets: {} - Validating set: {}".format(len(training_dataset), len(valid_dataset))) 55 | 56 | train_dataset = VQADataset(dataset=training_dataset, 57 | processor=processor) 58 
| valid_dataset = VQADataset(dataset=valid_dataset, 59 | processor=processor) 60 | 61 | batch_size = 8 62 | cal_num = 2 63 | torch.manual_seed(42) 64 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 65 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 66 | 67 | optimizer = torch.optim.AdamW(model.parameters(), lr=6e-7) 68 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False) 69 | 70 | num_epochs = 30 71 | patience = 5 72 | min_eval_loss = float("inf") 73 | early_stopping_hook = 0 74 | tracking_information = [] 75 | scaler = torch.cuda.amp.GradScaler() 76 | 77 | for epoch in range(num_epochs): 78 | epoch_loss = 0 79 | model.train() 80 | cal_loss = 0 81 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader): 82 | input_ids = batch.pop('input_ids').to(device) 83 | pixel_values = batch.pop('pixel_values').to(device) 84 | attention_masked = batch.pop('attention_mask').to(device) 85 | labels = batch.pop('labels').to(device) 86 | 87 | with torch.amp.autocast(device_type='cuda', dtype=torch.float16): 88 | outputs = model(input_ids=input_ids, 89 | pixel_values=pixel_values, 90 | attention_mask=attention_masked, 91 | labels=labels) 92 | optimizer.zero_grad() 93 | loss = outputs.loss 94 | cal_loss += loss 95 | 96 | epoch_loss += loss.item() 97 | 98 | if (idx + 1) % cal_num == 0 or idx == len(train_dataloader) - 1: 99 | if (idx + 1) % cal_num == 0: 100 | cal_loss = cal_loss / cal_num 101 | else: 102 | cal_loss = cal_loss / ((idx + 1) % cal_num) 103 | scaler.scale(cal_loss).backward() 104 | scaler.step(optimizer) 105 | scaler.update() 106 | cal_loss = 0 107 | 108 | model.eval() 109 | eval_loss = 0 110 | for idx, batch in zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader): 111 | input_ids = batch.pop('input_ids').to(device) 112 | pixel_values = batch.pop('pixel_values').to(device) 113 | attention_masked = batch.pop('attention_mask').to(device) 114 | labels = batch.pop('labels').to(device) 115 | 116 | with torch.amp.autocast(device_type='cuda', dtype=torch.float16): 117 | outputs = model(input_ids=input_ids, 118 | pixel_values=pixel_values, 119 | attention_mask=attention_masked, 120 | labels=labels) 121 | 122 | loss = outputs.loss 123 | eval_loss += loss.item() 124 | 125 | tracking_information.append( 126 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"])) 127 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader), 128 | eval_loss / len(valid_dataloader), 129 | optimizer.param_groups[0]["lr"])) 130 | scheduler.step() 131 | if eval_loss < min_eval_loss: 132 | model.save_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{epoch}", 133 | from_pt=True) 134 | print(f"Saved model to /projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{epoch}") 135 | min_eval_loss = eval_loss 136 | early_stopping_hook = 0 137 | else: 138 | early_stopping_hook += 1 139 | if early_stopping_hook > patience: 140 | break 141 | 142 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb")) 143 | print("The finetuning process has done!") 144 | -------------------------------------------------------------------------------- /Code/experiment/spacellava_test.py: 
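Editor's note: the BLIP fine-tuning loop above accumulates `cal_num` loss tensors (keeping their graphs alive) and calls `optimizer.zero_grad()` on every iteration before a single large `backward()`. The more memory-friendly convention is to backpropagate each micro-batch immediately with the loss divided by the accumulation factor, and to zero gradients only after the optimiser step. A sketch of that shape with the same AMP scaler; variable names (`model`, `device`, `optimizer`, `cal_num`, `train_dataloader`) follow the script above, so this is a sketch of an alternative loop body, not the repository's implementation.

```python
scaler = torch.cuda.amp.GradScaler()
accum_steps = cal_num                      # micro-batches per optimiser step (2 above)

for idx, batch in enumerate(train_dataloader):
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(input_ids=batch['input_ids'].to(device),
                        pixel_values=batch['pixel_values'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=batch['labels'].to(device))

    # Scale the loss down so the accumulated gradients match a full batch,
    # and free each micro-batch graph immediately.
    scaler.scale(outputs.loss / accum_steps).backward()

    if (idx + 1) % accum_steps == 0 or idx == len(train_dataloader) - 1:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```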
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 3 | 4 | import argparse 5 | import torch 6 | import json 7 | from llava.constants import ( 8 | IMAGE_TOKEN_INDEX, 9 | DEFAULT_IMAGE_TOKEN, 10 | DEFAULT_IM_START_TOKEN, 11 | DEFAULT_IM_END_TOKEN, 12 | IMAGE_PLACEHOLDER, 13 | ) 14 | from llava.conversation import conv_templates, SeparatorStyle 15 | from llava.model.builder import load_pretrained_model 16 | from llava.utils import disable_torch_init 17 | from llava.mm_utils import ( 18 | process_images, 19 | tokenizer_image_token, 20 | get_model_name_from_path, 21 | ) 22 | 23 | from PIL import Image 24 | 25 | import requests 26 | from PIL import Image 27 | from io import BytesIO 28 | import re 29 | 30 | RESULT_FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/results/spacellava_20240812_noPrompts_3.jsonl' 31 | FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/all_en_test_1076_sort.jsonl' 32 | IMAGE_DIR = '/home/zouyinan/SpatialVLM/SpaceLLaVa/pic/' 33 | 34 | 35 | def image_parser(args): 36 | out = args.image_file.split(args.sep) 37 | return out 38 | 39 | 40 | def load_image(image_file): 41 | if image_file.startswith("http") or image_file.startswith("https"): 42 | response = requests.get(image_file) 43 | image = Image.open(BytesIO(response.content)).convert("RGB") 44 | else: 45 | image = Image.open(image_file).convert("RGB") 46 | return image 47 | 48 | 49 | def load_images(image_files): 50 | out = [] 51 | for image_file in image_files: 52 | image = load_image(image_file) 53 | out.append(image) 54 | return out 55 | 56 | 57 | # Model 58 | disable_torch_init() 59 | model_path='/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf' 60 | args = type('Args', (), { 61 | "model_path": model_path, 62 | "model_base": None, 63 | "model_name": get_model_name_from_path(model_path), 64 | # "query": prompt, 65 | "conv_mode": None, 66 | # "image_file": image_file, 67 | "sep": ",", 68 | "temperature": 0.9, 69 | "top_p": None, 70 | "num_beams": 1, 71 | "max_new_tokens": 512 72 | })() 73 | 74 | model_name = get_model_name_from_path(model_path) 75 | tokenizer, model, image_processor, context_len = load_pretrained_model( 76 | args.model_path, None, model_name 77 | ) 78 | 79 | def eval_model(args,question,image_file): 80 | 81 | 82 | qs = question 83 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 84 | if IMAGE_PLACEHOLDER in qs: 85 | if model.config.mm_use_im_start_end: 86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 87 | else: 88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 89 | else: 90 | if model.config.mm_use_im_start_end: 91 | qs = image_token_se + "\n" + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 94 | 95 | if "llama-2" in model_name.lower(): 96 | conv_mode = "llava_llama_2" 97 | elif "mistral" in model_name.lower(): 98 | conv_mode = "mistral_instruct" 99 | elif "v1.6-34b" in model_name.lower(): 100 | conv_mode = "chatml_direct" 101 | elif "v1" in model_name.lower(): 102 | conv_mode = "llava_v1" 103 | elif "mpt" in model_name.lower(): 104 | conv_mode = "mpt" 105 | else: 106 | conv_mode = "llava_v0" 107 | 108 | if args.conv_mode is not None and conv_mode != args.conv_mode: 109 | print( 110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 111 | conv_mode, args.conv_mode, args.conv_mode 112 | ) 113 | ) 114 | else: 115 | args.conv_mode = conv_mode 116 | 117 | conv = 
conv_templates[args.conv_mode].copy() 118 | conv.append_message(conv.roles[0], qs) 119 | conv.append_message(conv.roles[1], None) 120 | prompt = conv.get_prompt() 121 | image_files = [image_file] 122 | images = load_images(image_files) 123 | image_sizes = [x.size for x in images] 124 | images_tensor = process_images( 125 | images, 126 | image_processor, 127 | model.config 128 | ).to(model.device, dtype=torch.float16) 129 | 130 | input_ids = ( 131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 132 | .unsqueeze(0) 133 | .cuda() 134 | ) 135 | 136 | with torch.inference_mode(): 137 | output_ids = model.generate( 138 | input_ids, 139 | images=images_tensor, 140 | image_sizes=image_sizes, 141 | do_sample=True if args.temperature > 0 else False, 142 | temperature=args.temperature, 143 | top_p=args.top_p, 144 | num_beams=args.num_beams, 145 | max_new_tokens=args.max_new_tokens, 146 | use_cache=True, 147 | ) 148 | 149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 150 | return outputs 151 | count = 0 152 | right_count = 0 153 | 154 | with open(f'{FILE_PATH}', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}', 'w+', encoding="utf-8") as fout: 155 | for line in f: 156 | data = json.loads(line) 157 | question = data['question'] 158 | id = data['id'] 159 | options = data['options'] 160 | image_name = data['image'] 161 | image_filepath = IMAGE_DIR + image_name 162 | question = f'Question: {question} \nOptions: {"; ".join(options)} \nAnswer:' 163 | output = eval_model(args,question,image_filepath) 164 | count += 1 165 | if len(output) == 0: 166 | output = '--' 167 | if output.lower() in data['answer']: 168 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 169 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 170 | right_count += 1 171 | elif data['answer'] in output.lower(): 172 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 173 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 174 | right_count += 1 175 | else: 176 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 177 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 178 | print(f'{output.lower()}') 179 | print(f"{data['answer']}") 180 | print(f'right_count: {right_count}') 181 | print(f'count: {count}') 182 | print(f'accuracy: {right_count/count}') 183 | # print(f'accuracy: {right_count/count}') 184 | 185 | accuracy = right_count/count 186 | print(f'accuracy: {accuracy}') -------------------------------------------------------------------------------- /Code/finetune/instructblip-lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, T5ForConditionalGeneration 5 | from datasets import load_dataset 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import pickle 11 | 12 | model_dir = '/projects/Models/' 13 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 14 | 15 | from peft import LoraConfig, get_peft_model 16 | 17 | # Let's define the LoraConfig 18 | config = LoraConfig( 19 | r=16, 20 | lora_alpha=32, 21 | lora_dropout=0.05, 22 | bias="none", 23 | target_modules=["q", "k"] 24 | ) 25 | 26 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl") 27 | 28 | device = "cuda:6" 
if torch.cuda.is_available() else "cpu" 29 | 30 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl") 31 | 32 | model = get_peft_model(model, config).to(device) 33 | 34 | torch.cuda.empty_cache() 35 | torch.manual_seed(42) 36 | 37 | 38 | class VQADataset(torch.utils.data.Dataset): 39 | """VQA (v2) dataset.""" 40 | 41 | def __init__(self, dataset, processor): 42 | self.dataset = dataset 43 | self.processor = processor 44 | 45 | def __len__(self): 46 | return len(self.dataset) 47 | 48 | def __getitem__(self, idx): 49 | # get image + text 50 | question = self.dataset[idx]['question'] 51 | answer = str(self.dataset[idx]['answer']) 52 | image_id = self.dataset[idx]['image'] 53 | image_path = f"{image_dir}{image_id}" 54 | image = Image.open(image_path).convert("RGB") 55 | 56 | encoding = self.processor(images=image, text=question, return_tensors="pt") 57 | labels = self.processor.tokenizer.encode( 58 | answer, max_length=8, pad_to_max_length=True, return_tensors='pt' 59 | ) 60 | 61 | encoding["labels"] = labels 62 | # remove batch dimension 63 | for k, v in encoding.items(): encoding[k] = v.squeeze() 64 | print(encoding['labels'], answer) 65 | return encoding 66 | 67 | 68 | training_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/instructblip_train/train_3780.jsonl", 69 | split="train[:100%]") 70 | valid_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/instructblip_train/dev_536.jsonl", 71 | split="train[:100%]") 72 | print("Training sets: {} - Validating set: {}".format(len(training_dataset), len(valid_dataset))) 73 | 74 | train_dataset = VQADataset(dataset=training_dataset, 75 | processor=processor) 76 | valid_dataset = VQADataset(dataset=valid_dataset, 77 | processor=processor) 78 | 79 | batch_size = 8 80 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 81 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 82 | 83 | optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5) 84 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False) 85 | 86 | num_epochs = 30 87 | patience = 5 88 | min_eval_loss = float("inf") 89 | early_stopping_hook = 0 90 | tracking_information = [] 91 | scaler = torch.cuda.amp.GradScaler() 92 | 93 | for epoch in range(num_epochs): 94 | epoch_loss = 0 95 | model.train() 96 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader): 97 | input_ids = batch.pop('input_ids').to(device) 98 | qformer_input_ids = batch.pop('qformer_input_ids').to(device) 99 | qformer_attention_mask = batch.pop('qformer_attention_mask').to(device) 100 | pixel_values = batch.pop('pixel_values').to(device) 101 | attention_mask = batch.pop('attention_mask').to(device) 102 | labels = batch.pop('labels').to(device) 103 | with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): 104 | outputs = model(input_ids=input_ids, 105 | qformer_input_ids=qformer_input_ids, 106 | qformer_attention_mask=qformer_attention_mask, 107 | pixel_values=pixel_values, 108 | attention_mask=attention_mask, 109 | labels=labels 110 | ) 111 | loss = outputs.loss 112 | print('loss:', loss) 113 | epoch_loss += loss.item() 114 | optimizer.zero_grad() 115 | 116 | scaler.scale(loss).backward() 117 | scaler.step(optimizer) 118 | scaler.update() 119 | 120 | model.eval() 121 | eval_loss = 0 122 | for idx, batch in 
zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader): 123 | input_ids = batch.pop('input_ids').to(device) 124 | qformer_input_ids = batch.pop('qformer_input_ids').to(device) 125 | qformer_attention_mask = batch.pop('qformer_attention_mask').to(device) 126 | pixel_values = batch.pop('pixel_values').to(device) 127 | attention_mask = batch.pop('attention_mask').to(device) 128 | labels = batch.pop('labels').to(device) 129 | outputs = model(input_ids=input_ids, 130 | qformer_input_ids=qformer_input_ids, 131 | qformer_attention_mask=qformer_attention_mask, 132 | pixel_values=pixel_values, 133 | attention_mask=attention_mask, 134 | labels=labels) 135 | loss = outputs.loss 136 | eval_loss += loss.item() 137 | 138 | tracking_information.append( 139 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"])) 140 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader), 141 | eval_loss / len(valid_dataloader), 142 | optimizer.param_groups[0]["lr"])) 143 | scheduler.step() 144 | if eval_loss < min_eval_loss: 145 | model.save_pretrained("/projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530", 146 | from_pt=True) 147 | print("Saved model to /projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530") 148 | min_eval_loss = eval_loss 149 | early_stopping_hook = 0 150 | else: 151 | early_stopping_hook += 1 152 | if early_stopping_hook > patience: 153 | break 154 | 155 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb")) 156 | print("The finetuning process has done!") 157 | -------------------------------------------------------------------------------- /Code/experiment/spatial_test_llava.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "6" 4 | 5 | import argparse 6 | import torch 7 | import json 8 | from llava.constants import ( 9 | IMAGE_TOKEN_INDEX, 10 | DEFAULT_IMAGE_TOKEN, 11 | DEFAULT_IM_START_TOKEN, 12 | DEFAULT_IM_END_TOKEN, 13 | IMAGE_PLACEHOLDER, 14 | ) 15 | from llava.conversation import conv_templates, SeparatorStyle 16 | from llava.model.builder import load_pretrained_model 17 | from llava.utils import disable_torch_init 18 | from llava.mm_utils import ( 19 | process_images, 20 | tokenizer_image_token, 21 | get_model_name_from_path, 22 | ) 23 | 24 | from PIL import Image 25 | 26 | import requests 27 | from PIL import Image 28 | from io import BytesIO 29 | import re 30 | 31 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/10/' 32 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 33 | IMAGE_DIR = '/projects/SpatialMQA/COCO2017/test2017/' 34 | 35 | 36 | def image_parser(args): 37 | out = args.image_file.split(args.sep) 38 | return out 39 | 40 | 41 | def load_image(image_file): 42 | if image_file.startswith("http") or image_file.startswith("https"): 43 | response = requests.get(image_file) 44 | image = Image.open(BytesIO(response.content)).convert("RGB") 45 | else: 46 | image = Image.open(image_file).convert("RGB") 47 | return image 48 | 49 | 50 | def load_images(image_files): 51 | out = [] 52 | for image_file in image_files: 53 | image = load_image(image_file) 54 | out.append(image) 55 | return out 56 | 57 | 58 | # Model 59 | disable_torch_init() 60 | model_path = '/projects/Models/LLaVA-main/llava-v1.5-7b' 61 | args = type('Args', (), { 62 | "model_path": model_path, 63 | "model_base": None, 64 | 
"model_name": get_model_name_from_path(model_path), 65 | "conv_mode": None, 66 | "sep": ",", 67 | "temperature": 0.4, 68 | "top_p": None, 69 | "num_beams": 1, 70 | "max_new_tokens": 512 71 | })() 72 | 73 | model_name = get_model_name_from_path(model_path) 74 | tokenizer, model, image_processor, context_len = load_pretrained_model( 75 | args.model_path, None, model_name 76 | ) 77 | 78 | 79 | def eval_model(args, question, image_file): 80 | qs = question 81 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 82 | if IMAGE_PLACEHOLDER in qs: 83 | if model.config.mm_use_im_start_end: 84 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 85 | else: 86 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 87 | else: 88 | if model.config.mm_use_im_start_end: 89 | qs = image_token_se + "\n" + qs 90 | else: 91 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 92 | 93 | if "llama-2" in model_name.lower(): 94 | conv_mode = "llava_llama_2" 95 | elif "mistral" in model_name.lower(): 96 | conv_mode = "mistral_instruct" 97 | elif "v1.6-34b" in model_name.lower(): 98 | conv_mode = "chatml_direct" 99 | elif "v1" in model_name.lower(): 100 | conv_mode = "llava_v1" 101 | elif "mpt" in model_name.lower(): 102 | conv_mode = "mpt" 103 | else: 104 | conv_mode = "llava_v0" 105 | 106 | if args.conv_mode is not None and conv_mode != args.conv_mode: 107 | print( 108 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 109 | conv_mode, args.conv_mode, args.conv_mode 110 | ) 111 | ) 112 | else: 113 | args.conv_mode = conv_mode 114 | 115 | conv = conv_templates[args.conv_mode].copy() 116 | conv.append_message(conv.roles[0], qs) 117 | conv.append_message(conv.roles[1], None) 118 | prompt = conv.get_prompt() 119 | image_files = [image_file] 120 | images = load_images(image_files) 121 | image_sizes = [x.size for x in images] 122 | images_tensor = process_images( 123 | images, 124 | image_processor, 125 | model.config 126 | ).to(model.device, dtype=torch.float16) 127 | 128 | input_ids = ( 129 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 130 | .unsqueeze(0) 131 | .cuda() 132 | ) 133 | 134 | with torch.inference_mode(): 135 | output_ids = model.generate( 136 | input_ids, 137 | images=images_tensor, 138 | image_sizes=image_sizes, 139 | do_sample=True if args.temperature > 0 else False, 140 | temperature=args.temperature, 141 | top_p=args.top_p, 142 | num_beams=args.num_beams, 143 | max_new_tokens=args.max_new_tokens, 144 | use_cache=True, 145 | ) 146 | 147 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 148 | return outputs 149 | 150 | 151 | count = 0 152 | right_count = 0 153 | 154 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f, open( 155 | f'{RESULT_FILE_PATH}llava_4_5.jsonl', 'w+', encoding="utf-8") as fout: 156 | for line in f: 157 | data = json.loads(line) 158 | question = data['question'] 159 | id = data['id'] 160 | options = data['options'] 161 | image_name = data['image'] 162 | image_filepath = IMAGE_DIR + image_name 163 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. 
\n Output:' 164 | output = eval_model(args, question, image_filepath) 165 | count += 1 166 | if len(output) == 0: 167 | output = '--' 168 | if output.lower() in data['answer']: 169 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']} 170 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 171 | right_count += 1 172 | elif data['answer'] in output.lower(): 173 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']} 174 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 175 | right_count += 1 176 | else: 177 | result_json = {'id': id, 'result': 0, "output": output.lower(), "answer": data['answer']} 178 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 179 | print(f'{output.lower()}') 180 | print(f"{data['answer']}") 181 | print(f'right_count: {right_count}') 182 | print(f'count: {count}') 183 | print(f'accuracy: {right_count / count}') 184 | 185 | accuracy = right_count / count 186 | print(f'accuracy: {accuracy}') 187 | -------------------------------------------------------------------------------- /Code/experiment/spatial_test_llava_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7" 4 | 5 | import argparse 6 | import torch 7 | import json 8 | from llava.constants import ( 9 | IMAGE_TOKEN_INDEX, 10 | DEFAULT_IMAGE_TOKEN, 11 | DEFAULT_IM_START_TOKEN, 12 | DEFAULT_IM_END_TOKEN, 13 | IMAGE_PLACEHOLDER, 14 | ) 15 | from llava.conversation import conv_templates, SeparatorStyle 16 | from llava.model.builder import load_pretrained_model 17 | from llava.utils import disable_torch_init 18 | from llava.mm_utils import ( 19 | process_images, 20 | tokenizer_image_token, 21 | get_model_name_from_path, 22 | ) 23 | 24 | from PIL import Image 25 | 26 | import requests 27 | from PIL import Image 28 | from io import BytesIO 29 | import re 30 | 31 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/10/' 32 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/' 33 | IMAGE_DIR = '/projects/SpatialMQA/COCO2017/test2017/' 34 | 35 | 36 | def image_parser(args): 37 | out = args.image_file.split(args.sep) 38 | return out 39 | 40 | 41 | def load_image(image_file): 42 | if image_file.startswith("http") or image_file.startswith("https"): 43 | response = requests.get(image_file) 44 | image = Image.open(BytesIO(response.content)).convert("RGB") 45 | else: 46 | image = Image.open(image_file).convert("RGB") 47 | return image 48 | 49 | 50 | def load_images(image_files): 51 | out = [] 52 | for image_file in image_files: 53 | image = load_image(image_file) 54 | out.append(image) 55 | return out 56 | 57 | 58 | # Model 59 | disable_torch_init() 60 | model_path = '/projects/Models/LLaVA-main/llava-v1.5-7b' 61 | args = type('Args', (), { 62 | "model_path": model_path, 63 | "model_base": None, 64 | "model_name": get_model_name_from_path(model_path), 65 | "conv_mode": None, 66 | "sep": ",", 67 | "temperature": 0.4, 68 | "top_p": None, 69 | "num_beams": 1, 70 | "max_new_tokens": 512 71 | })() 72 | 73 | model_name = get_model_name_from_path(model_path) 74 | tokenizer, model, image_processor, context_len = load_pretrained_model( 75 | args.model_path, None, model_name 76 | ) 77 | peft_model_id = "/projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240522" 78 | model.load_adapter(peft_model_id) 79 | 80 | 81 | def eval_model(args, question, image_file): 82 | qs = question 83 | image_token_se = DEFAULT_IM_START_TOKEN 
+ DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 84 | if IMAGE_PLACEHOLDER in qs: 85 | if model.config.mm_use_im_start_end: 86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 87 | else: 88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 89 | else: 90 | if model.config.mm_use_im_start_end: 91 | qs = image_token_se + "\n" + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 94 | 95 | if "llama-2" in model_name.lower(): 96 | conv_mode = "llava_llama_2" 97 | elif "mistral" in model_name.lower(): 98 | conv_mode = "mistral_instruct" 99 | elif "v1.6-34b" in model_name.lower(): 100 | conv_mode = "chatml_direct" 101 | elif "v1" in model_name.lower(): 102 | conv_mode = "llava_v1" 103 | elif "mpt" in model_name.lower(): 104 | conv_mode = "mpt" 105 | else: 106 | conv_mode = "llava_v0" 107 | 108 | if args.conv_mode is not None and conv_mode != args.conv_mode: 109 | print( 110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 111 | conv_mode, args.conv_mode, args.conv_mode 112 | ) 113 | ) 114 | else: 115 | args.conv_mode = conv_mode 116 | 117 | conv = conv_templates[args.conv_mode].copy() 118 | conv.append_message(conv.roles[0], qs) 119 | conv.append_message(conv.roles[1], None) 120 | prompt = conv.get_prompt() 121 | image_files = [image_file] 122 | images = load_images(image_files) 123 | image_sizes = [x.size for x in images] 124 | images_tensor = process_images( 125 | images, 126 | image_processor, 127 | model.config 128 | ).to(model.device, dtype=torch.float16) 129 | 130 | input_ids = ( 131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 132 | .unsqueeze(0) 133 | .cuda() 134 | ) 135 | 136 | with torch.inference_mode(): 137 | output_ids = model.generate( 138 | input_ids, 139 | images=images_tensor, 140 | image_sizes=image_sizes, 141 | do_sample=True if args.temperature > 0 else False, 142 | temperature=args.temperature, 143 | top_p=args.top_p, 144 | num_beams=args.num_beams, 145 | max_new_tokens=args.max_new_tokens, 146 | use_cache=True, 147 | ) 148 | 149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 150 | return outputs 151 | 152 | 153 | count = 0 154 | right_count = 0 155 | 156 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f, open( 157 | f'{RESULT_FILE_PATH}llava_lora_4_5.jsonl', 'w+', encoding="utf-8") as fout: 158 | for line in f: 159 | data = json.loads(line) 160 | question = data['question'] 161 | id = data['id'] 162 | options = data['options'] 163 | image_name = data['image'] 164 | image_filepath = IMAGE_DIR + image_name 165 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. 
\n Output:' 166 | output = eval_model(args, question, image_filepath) 167 | count += 1 168 | if len(output) == 0: 169 | output = '--' 170 | if output.lower() in data['answer']: 171 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']} 172 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 173 | right_count += 1 174 | elif data['answer'] in output.lower(): 175 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']} 176 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 177 | right_count += 1 178 | else: 179 | result_json = {'id': id, 'result': 0, "output": output.lower(), "answer": data['answer']} 180 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 181 | print(f'{output.lower()}') 182 | print(f"{data['answer']}") 183 | print(f'right_count: {right_count}') 184 | print(f'count: {count}') 185 | print(f'accuracy: {right_count / count}') 186 | 187 | accuracy = right_count / count 188 | print(f'accuracy: {accuracy}') 189 | -------------------------------------------------------------------------------- /Code/eval/calculate_prf1.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score 4 | import warnings 5 | 6 | warnings.filterwarnings("ignore") 7 | 8 | all_correct = 0 9 | 10 | 11 | def load_jsonl(file_path): 12 | data = [] 13 | with open(file_path, 'r', encoding='utf-8') as f: 14 | for line in f: 15 | tmp = json.loads(line.strip()) 16 | if tmp['result'] == 1: 17 | tmp['output'] = tmp['answer'] 18 | data.append(tmp) 19 | return data 20 | 21 | 22 | def main(): 23 | file_list = os.listdir("./model_exp_results/calculate_PRF1/") 24 | # file_list = os.listdir("./human_results_en/") 25 | 26 | # for test_file_path_item in file_list: 27 | # test_file_path = "./model_exp_results/" + test_file_path_item 28 | # 29 | # data = load_jsonl(test_file_path) 30 | # predictions = [] 31 | # for item in data: 32 | # predictions.append(item['result']) 33 | # answers = [] 34 | # for item in data: 35 | # answers.append(item['answer']) 36 | # 37 | # accuracy = accuracy_score(answers, predictions) 38 | # 39 | # # 计算精确率 40 | # precision = precision_score(answers, predictions, average='weighted') 41 | # 42 | # # 计算召回率 43 | # recall = recall_score(answers, predictions, average='weighted') 44 | # 45 | # # 计算F1值 46 | # f1 = f1_score(answers, predictions, average='weighted') 47 | # print("------------------------------------" + test_file_path_item + "-----------------------------------") 48 | # accuracy = "{:,.2f}".format(accuracy) 49 | # precision = "{:,.2f}".format(precision) 50 | # recall = "{:,.2f}".format(recall) 51 | # f1 = "{:,.2f}".format(f1) 52 | # 53 | # print(f"accuracy: {accuracy}") 54 | # print(f"Precision: {precision}") 55 | # print(f"Recall: {recall}") 56 | # print(f"F1 Score: {f1}") 57 | 58 | for test_file_path_item in file_list: 59 | test_file_path = "./model_exp_results/calculate_PRF1/" + test_file_path_item 60 | # test_file_path = "./human_results_en/" + test_file_path_item 61 | 62 | print("------------------------------------" + test_file_path_item + "-----------------------------------") 63 | 64 | data_list = load_jsonl(test_file_path) 65 | print(len(data_list)) 66 | on_pre_num, below_pre_num, left_pre_num, right_pre_num, front_pre_num, behind_pre_num, \ 67 | on_label_num, below_label_num, left_label_num, right_label_num, front_label_num, 
behind_label_num, \ 68 | on_true_num, below_true_num, left_true_num, right_true_num, front_true_num, behind_true_num = \ 69 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 70 | 71 | for data in data_list: 72 | if data['output'] in "on/above": 73 | on_pre_num += 1 74 | elif data['output'] in "below": 75 | below_pre_num += 1 76 | elif data['output'] in "left of": 77 | left_pre_num += 1 78 | elif data['output'] in "right of": 79 | right_pre_num += 1 80 | elif data['output'] in "in front of": 81 | front_pre_num += 1 82 | elif data['output'] in "behind": 83 | behind_pre_num += 1 84 | 85 | if data['answer'] == "on/above": 86 | on_label_num += 1 87 | elif data['answer'] == "below": 88 | below_label_num += 1 89 | elif data['answer'] == "left of": 90 | left_label_num += 1 91 | elif data['answer'] == "right of": 92 | right_label_num += 1 93 | elif data['answer'] == "in front of": 94 | front_label_num += 1 95 | elif data['answer'] == "behind": 96 | behind_label_num += 1 97 | 98 | if data['answer'] == data['output']: 99 | if data['answer'] == "on/above": 100 | on_true_num += 1 101 | elif data['answer'] == "below": 102 | below_true_num += 1 103 | elif data['answer'] == "left of": 104 | left_true_num += 1 105 | elif data['answer'] == "right of": 106 | right_true_num += 1 107 | elif data['answer'] == "in front of": 108 | front_true_num += 1 109 | elif data['answer'] == "behind": 110 | behind_true_num += 1 111 | if on_pre_num == 0: 112 | on_pre_num = 1 113 | if below_pre_num == 0: 114 | below_pre_num = 1 115 | if left_pre_num == 0: 116 | left_pre_num = 1 117 | if right_pre_num == 0: 118 | right_pre_num = 1 119 | if front_pre_num == 0: 120 | front_pre_num = 1 121 | if behind_pre_num == 0: 122 | behind_pre_num = 1 123 | 124 | p = (on_true_num / on_pre_num + below_true_num / below_pre_num + left_true_num / left_pre_num 125 | + right_true_num / right_pre_num + front_true_num / front_pre_num 126 | + behind_true_num / behind_pre_num) / 6 127 | print(on_true_num, on_pre_num, below_true_num, below_pre_num, left_true_num, left_pre_num, 128 | right_true_num, right_pre_num, front_true_num, front_pre_num, 129 | behind_true_num, behind_pre_num) 130 | r = (on_true_num / on_label_num + below_true_num / below_label_num + left_true_num / left_label_num 131 | + right_true_num / right_label_num + front_true_num / front_label_num + 132 | behind_true_num / behind_label_num) / 6 133 | f1 = 2 * p * r / (p + r) 134 | acc = (on_true_num + below_true_num + left_true_num + right_true_num + front_true_num + 135 | behind_true_num) / (on_label_num + below_label_num + left_label_num + right_label_num 136 | + front_label_num + behind_label_num) 137 | 138 | # p = ((on_true_num + below_true_num) / (on_pre_num + below_pre_num) + (left_true_num + right_true_num) 139 | # / (left_pre_num + right_pre_num) + (front_true_num + behind_true_num) / (front_pre_num + 140 | # behind_pre_num)) / 3 141 | # r = ((on_true_num + below_true_num) / (on_label_num + below_label_num) + (left_true_num + right_true_num) 142 | # / (left_label_num + right_label_num) + (front_true_num + behind_true_num) / (front_label_num + 143 | # behind_label_num)) / 3 144 | # f1 = 2 * p * r / (p + r) 145 | # acc = (on_true_num + below_true_num + left_true_num + right_true_num + front_true_num + 146 | # behind_true_num) / (on_label_num + below_label_num + left_label_num + right_label_num 147 | # + front_label_num + behind_label_num) 148 | 149 | print(f"accuracy: {acc}") 150 | print(f"Precision: {p}") 151 | print(f"Recall: {r}") 152 | print(f"F1 Score: {f1}") 153 | 154 | 155 
| if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /Code/finetune/blip2-lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from transformers import BlipProcessor, BlipForQuestionAnswering 5 | from transformers import Blip2Processor, AutoProcessor, Blip2ForConditionalGeneration, \ 6 | AutoModelForVisualQuestionAnswering 7 | from datasets import load_dataset 8 | import torch 9 | import torch.nn as nn 10 | from PIL import Image 11 | from torch.utils.data import Dataset, DataLoader 12 | from tqdm import tqdm 13 | import pickle 14 | from peft import LoraConfig, get_peft_model 15 | 16 | model_dir = '/projects/Models/' 17 | processor = Blip2Processor.from_pretrained(f"{model_dir}blip2-opt-2.7b") 18 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/' 19 | 20 | train_ds = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip2_train/train_3780.jsonl", 21 | split="train[:100%]") 22 | eval_ds = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip2_train/dev_536.jsonl", 23 | split="train[:100%]") 24 | print("Training sets: {} - Validating set: {}".format(len(train_ds), len(eval_ds))) 25 | 26 | device = "cuda:7" if torch.cuda.is_available() else "cpu" 27 | 28 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b", device_map="cuda:7", 29 | load_in_8bit=True) 30 | 31 | # Let's define the LoraConfig 32 | config = LoraConfig( 33 | r=16, 34 | lora_alpha=32, 35 | lora_dropout=0.05, 36 | bias="none", 37 | target_modules=["q_proj", "k_proj"] 38 | ) 39 | 40 | model = get_peft_model(model, config).to(device) 41 | 42 | torch.cuda.empty_cache() 43 | torch.manual_seed(42) 44 | 45 | 46 | class ImageCaptioningDataset(Dataset): 47 | def __init__(self, dataset, processor): 48 | self.dataset = dataset 49 | self.processor = processor 50 | 51 | def __len__(self): 52 | return len(self.dataset) 53 | 54 | def __getitem__(self, idx): 55 | # get image + text 56 | question = self.dataset[idx]['question'] 57 | answer = str(self.dataset[idx]['answer']) 58 | image_id = self.dataset[idx]['image'] 59 | image_path = f"/projects/SpatialMQA/COCO2017/test2017/{image_id}" 60 | image = Image.open(image_path).convert("RGB") 61 | 62 | encoding = self.processor(images=image, text=question, return_tensors="pt") 63 | 64 | labels = self.processor.tokenizer.tokenize(answer, return_tensors='pt') 65 | labels = torch.tensor(self.processor.tokenizer.convert_tokens_to_ids(labels)).unsqueeze(0) 66 | encoding["labels"] = torch.cat((labels, torch.tensor([50118]).unsqueeze(0)), dim=1) 67 | 68 | # remove batch dimension 69 | for k, v in encoding.items(): encoding[k] = v.squeeze() 70 | return encoding 71 | 72 | 73 | train_dataset = ImageCaptioningDataset(dataset=train_ds, processor=processor) 74 | valid_dataset = ImageCaptioningDataset(dataset=eval_ds, processor=processor) 75 | 76 | batch_size = 8 77 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 78 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 79 | 80 | optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5) 81 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False) 82 | 83 | cal_num = 2 84 | num_epochs = 30 85 | patience = 5 86 | min_eval_loss = float("inf") 87 | early_stopping_hook = 0 88 | tracking_information = [] 89 | 
scaler = torch.cuda.amp.GradScaler() 90 | criterion = nn.CrossEntropyLoss(ignore_index=1) 91 | 92 | for epoch in range(num_epochs): 93 | epoch_loss = 0 94 | cal_loss = 0 95 | model.train() 96 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader): 97 | input_ids = batch.pop('input_ids').to(device) 98 | pixel_values = batch.pop('pixel_values').to(device) 99 | attention_mask = batch.pop('attention_mask').to(device) 100 | labels = batch.pop('labels').to(device) 101 | 102 | with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): 103 | outputs = model(input_ids=input_ids, 104 | pixel_values=pixel_values, 105 | attention_mask=attention_mask).logits 106 | 107 | loss = criterion(outputs.view(-1, outputs.shape[-1])[:labels.shape[1], :].contiguous(), 108 | labels.view(-1).contiguous()) 109 | epoch_loss += loss.item() 110 | optimizer.zero_grad() 111 | 112 | cal_loss += loss 113 | if (idx + 1) % cal_num == 0 or idx == len(train_dataloader) - 1: 114 | if (idx + 1) % cal_num == 0: 115 | cal_loss = cal_loss / cal_num 116 | else: 117 | cal_loss = cal_loss / ((idx + 1) % cal_num) 118 | print('loss:', cal_loss) 119 | scaler.scale(cal_loss).backward() 120 | scaler.step(optimizer) 121 | scaler.update() 122 | cal_loss = 0 123 | 124 | model.eval() 125 | eval_loss = 0 126 | for idx, batch in zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader): 127 | input_ids = batch.pop('input_ids').to(device) 128 | pixel_values = batch.pop('pixel_values').to(device) 129 | attention_mask = batch.pop('attention_mask').to(device) 130 | labels = batch.pop('labels').to(device) 131 | # print('labels:',labels) 132 | # outputs = model(input_ids=input_ids, 133 | # pixel_values=pixel_values, 134 | # attention_mask=attention_mask, 135 | # labels=labels) 136 | # loss = outputs.loss 137 | 138 | # labels_mask = batch.pop('labels_mask').to(device) 139 | outputs = model(input_ids=input_ids, 140 | pixel_values=pixel_values, 141 | attention_mask=attention_mask).logits 142 | # loss = criterion(outputs.view(-1, outputs.shape[-1])[:8, :].contiguous() * labels_mask.view(-1, 1).contiguous(), labels.view(-1).contiguous()) 143 | loss = criterion(outputs.view(-1, outputs.shape[-1])[:labels.shape[1], :].contiguous(), 144 | labels.view(-1).contiguous()) 145 | 146 | eval_loss += loss.item() 147 | # break 148 | 149 | # break 150 | 151 | tracking_information.append( 152 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"])) 153 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader), 154 | eval_loss / len(valid_dataloader), 155 | optimizer.param_groups[0]["lr"])) 156 | scheduler.step() 157 | if eval_loss < min_eval_loss: 158 | model.save_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531/epoch-{epoch}", 159 | from_pt=True) 160 | print(f"Saved model to /projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531") 161 | min_eval_loss = eval_loss 162 | early_stopping_hook = 0 163 | else: 164 | early_stopping_hook += 1 165 | if early_stopping_hook > patience: 166 | break 167 | 168 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb")) 169 | print("The finetuning process has done!") 170 | -------------------------------------------------------------------------------- /Code/experiment/spacellava_lora_test.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 3 | 4 | import argparse 5 | import torch 6 | import json 7 | from llava.constants import ( 8 | IMAGE_TOKEN_INDEX, 9 | DEFAULT_IMAGE_TOKEN, 10 | DEFAULT_IM_START_TOKEN, 11 | DEFAULT_IM_END_TOKEN, 12 | IMAGE_PLACEHOLDER, 13 | ) 14 | from llava.conversation import conv_templates, SeparatorStyle 15 | from llava.model.builder import load_pretrained_model 16 | from llava.utils import disable_torch_init 17 | from llava.mm_utils import ( 18 | process_images, 19 | tokenizer_image_token, 20 | get_model_name_from_path, 21 | ) 22 | 23 | from PIL import Image 24 | 25 | import requests 26 | from PIL import Image 27 | from io import BytesIO 28 | import re 29 | RESULT_FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/results/20240816/1000_valid.jsonl' 30 | # FILE_PATH = '/projects/SpatialMethod/Dataset/En_new/' 31 | FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/valid.jsonl' 32 | IMAGE_DIR = '/home/zouyinan/SpatialVLM/SpaceLLaVa/pic/' 33 | def image_parser(args): 34 | out = args.image_file.split(args.sep) 35 | return out 36 | 37 | 38 | def load_image(image_file): 39 | if image_file.startswith("http") or image_file.startswith("https"): 40 | response = requests.get(image_file) 41 | image = Image.open(BytesIO(response.content)).convert("RGB") 42 | else: 43 | image = Image.open(image_file).convert("RGB") 44 | return image 45 | 46 | 47 | def load_images(image_files): 48 | out = [] 49 | for image_file in image_files: 50 | image = load_image(image_file) 51 | out.append(image) 52 | return out 53 | 54 | 55 | # Model 56 | disable_torch_init() 57 | model_path='/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf' 58 | args = type('Args', (), { 59 | "model_path": model_path, 60 | "model_base": None, 61 | "model_name": get_model_name_from_path(model_path), 62 | # "query": prompt, 63 | "conv_mode": None, 64 | # "image_file": image_file, 65 | "sep": ",", 66 | "temperature": 0.1, 67 | "top_p": None, 68 | "num_beams": 1, 69 | "max_new_tokens": 512 70 | })() 71 | 72 | model_name = get_model_name_from_path(model_path) 73 | tokenizer, model, image_processor, context_len = load_pretrained_model( 74 | args.model_path, None, model_name 75 | ) 76 | # peft_model_id = "/projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240522" 77 | # peft_model_id = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240812_3/checkpoint-460' 78 | peft_model_id = "/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240816_1000" 79 | model.load_adapter(peft_model_id) 80 | 81 | def eval_model(args,question,image_file): 82 | qs = question 83 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 84 | if IMAGE_PLACEHOLDER in qs: 85 | if model.config.mm_use_im_start_end: 86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 87 | else: 88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 89 | else: 90 | if model.config.mm_use_im_start_end: 91 | qs = image_token_se + "\n" + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 94 | 95 | if "llama-2" in model_name.lower(): 96 | conv_mode = "llava_llama_2" 97 | elif "mistral" in model_name.lower(): 98 | conv_mode = "mistral_instruct" 99 | elif "v1.6-34b" in model_name.lower(): 100 | conv_mode = "chatml_direct" 101 | elif "v1" in model_name.lower(): 102 | conv_mode = "llava_v1" 103 | elif "mpt" in model_name.lower(): 104 | conv_mode = "mpt" 105 | else: 106 | conv_mode = "llava_v0" 107 | 108 | if args.conv_mode is not None and 
conv_mode != args.conv_mode: 109 | print( 110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 111 | conv_mode, args.conv_mode, args.conv_mode 112 | ) 113 | ) 114 | else: 115 | args.conv_mode = conv_mode 116 | 117 | conv = conv_templates[args.conv_mode].copy() 118 | conv.append_message(conv.roles[0], qs) 119 | conv.append_message(conv.roles[1], None) 120 | prompt = conv.get_prompt() 121 | image_files = [image_file] 122 | images = load_images(image_files) 123 | image_sizes = [x.size for x in images] 124 | images_tensor = process_images( 125 | images, 126 | image_processor, 127 | model.config 128 | ).to(model.device, dtype=torch.float16) 129 | 130 | input_ids = ( 131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 132 | .unsqueeze(0) 133 | .cuda() 134 | ) 135 | 136 | with torch.inference_mode(): 137 | output_ids = model.generate( 138 | input_ids, 139 | images=images_tensor, 140 | image_sizes=image_sizes, 141 | do_sample=True if args.temperature > 0 else False, 142 | temperature=args.temperature, 143 | top_p=args.top_p, 144 | num_beams=args.num_beams, 145 | max_new_tokens=args.max_new_tokens, 146 | use_cache=True, 147 | ) 148 | 149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 150 | return outputs 151 | count = 0 152 | right_count = 0 153 | 154 | with open(f'{FILE_PATH}', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}', 'w+', encoding="utf-8") as fout: 155 | for line in f: 156 | data = json.loads(line) 157 | question = data['question'] 158 | id = data['id'] 159 | options = data['options'] 160 | image_name = data['image'] 161 | image_filepath = IMAGE_DIR + image_name 162 | # question = f'Question: {question} According to the question, please choose one best option from the following options separated by semicolon. Note that you only need to answer one word from the options without explaining the reason.\nOptions: {"; ".join(options)} \nAnswer:' 163 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question, and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: , Question: {question}, Options: {"; ".join(options)}. 
\n Output:' 164 | output = eval_model(args,question,image_filepath) 165 | 166 | count += 1 167 | if len(output) == 0: 168 | output = '--' 169 | if output.lower() in data['answer']: 170 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 171 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 172 | right_count += 1 173 | elif data['answer'] in output.lower(): 174 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']} 175 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 176 | right_count += 1 177 | else: 178 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']} 179 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n') 180 | print(f'{output.lower()}') 181 | print(f"{data['answer']}") 182 | print(f'right_count: {right_count}') 183 | print(f'count: {count}') 184 | print(f'accuracy: {right_count/count}') 185 | # print(f'accuracy: {right_count/count}') 186 | 187 | accuracy = right_count/count 188 | print(f'accuracy: {accuracy}') -------------------------------------------------------------------------------- /Code/close_models/gemini_1_shot_random.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | import PIL.Image as Image 3 | import google.generativeai as genai 4 | from io import BytesIO 5 | import requests 6 | import random 7 | import json 8 | 9 | genai.configure(api_key="your key",transport="rest") 10 | 11 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480} 12 | safety_settings = [ 13 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 14 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 15 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 16 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"} 17 | ] 18 | 19 | FILE_PATH = './datasets/test_en_select_500_sort1.jsonl' 20 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 21 | RESULT_FILE_PATH = './model_results/gemini_1_shot_random.jsonl' 22 | 23 | 24 | example_list_rule1 = [ 25 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."}, 26 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."}, 27 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."} 28 | ] 29 | 30 | example_list_rule2 = [ 31 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."}, 32 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."}, 33 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. 
\nOutput: behind."} 34 | ] 35 | 36 | example_list_rule3 = [ 37 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."}, 38 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."}, 39 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."} 40 | ] 41 | 42 | 43 | example_list = [ 44 | random.choice(example_list_rule1), 45 | random.choice(example_list_rule2), 46 | random.choice(example_list_rule3) 47 | ] 48 | 49 | 50 | def fetch_image_content(image_url): 51 | response = requests.get(image_url) 52 | if response.status_code == 200: 53 | return BytesIO(response.content) 54 | else: 55 | return None 56 | 57 | 58 | def choose_1_example(): 59 | example_1 = random.choice(example_list_rule1 + example_list_rule1 + example_list_rule1) 60 | return example_1 61 | 62 | 63 | # few-shot & zero-shot: 64 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 65 | 66 | # text-only: 67 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings) 68 | 69 | 70 | count = 0 71 | right_count = 0 72 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout: 73 | for line in f: 74 | sleep(1) 75 | data = json.loads(line) 76 | id = data['id'] 77 | 78 | example_1 = choose_1_example() 79 | 80 | # 1 - shot: 81 | # question = f'You are currently a senior expert in spatial relation reasoning. ' \ 82 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 83 | # f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \ 84 | # f'\n{example_1} ' \ 85 | # f'Input: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 86 | prompt1 = f'You are currently a senior expert in spatial relation reasoning. ' \ 87 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 88 | f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \ 89 | f'\nInput: Image: ' 90 | prompt2 = f'\n{example_1["text"]} ' \ 91 | f'\nInput: Image: ' 92 | prompt3 = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. 
\nOutput: ' 93 | 94 | image_eg1_id = example_1["image"] 95 | image_eg1_url = f'{IMAGE_DIR}/{image_eg1_id}' 96 | image_eg1_content = fetch_image_content(image_eg1_url) 97 | 98 | image_id = data['image'] 99 | image_url = f'{IMAGE_DIR}/{image_id}' 100 | image_content = fetch_image_content(image_url) 101 | 102 | if (image_eg1_content is not None) and (image_content is not None): 103 | image_eg1 = Image.open(image_eg1_content) 104 | image = Image.open(image_content) 105 | try: 106 | response = model.generate_content( 107 | [prompt1, image_eg1, prompt2, image, prompt3] 108 | ) 109 | except Exception as e: 110 | print(e) 111 | try: 112 | response = model.generate_content( 113 | [prompt1, image_eg1, prompt2, image, prompt3] 114 | ) 115 | except Exception as e: 116 | try: 117 | response = model.generate_content( 118 | [prompt1, image_eg1, prompt2, image, prompt3] 119 | ) 120 | except Exception as e: 121 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1} 122 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 123 | count += 1 124 | continue 125 | try: 126 | output = response.text.strip().rstrip('.').lower() 127 | except Exception as e: 128 | print(e) 129 | output = '--' 130 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1} 131 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 132 | count += 1 133 | continue 134 | count += 1 135 | if output in data['answer'] or data['answer'] in output: 136 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1} 137 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 138 | right_count += 1 139 | else: 140 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1} 141 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 142 | print(f'{output.lower()}') 143 | print(f"{data['answer']}") 144 | print(f'right_count: {right_count}') 145 | print(f'count: {count}') 146 | # print(f'accuracy: {right_count / count}') 147 | 148 | accuracy = right_count / count 149 | print(accuracy) 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
🔭 Can Multimodal Large Language Models Understand Spatial Relations

SpatialMQA: A new benchmark dataset for spatial reasoning of MLLMs.
18 | 19 | 20 | ## Contents 21 | 22 | - [SpatialMQA](#Contents) 23 | - [Overview](#1-Overview) 24 | - [Examples](#Examples) 25 | - [Detail Information](#Detail-Information) 26 | - [Access SpatialMQA](#2-Access-SpatialMQA) 27 | - [Download Images](#Download-Images) 28 | - [Data Split](#Data-Split) 29 | - [Data Format](#Data-Format) 30 | - [Experiment & Evaluation](#3-Experiment-and-Evaluation) 31 | - [Experiment](#Experiment) 32 | - [Evaluation](#Evaluation) 33 | - [License](#4-License) 34 | 35 | 36 | 37 | 38 | ## 1 Overview 39 | **SpatialMQA** is a **manually annotated** dataset designed for **multimodal spatial relation reasoning** in a **m**ultiple-choice **q**uestion & **a**nswer format. 40 | The dataset includes 5,392 samples collected from COCO2017, covering 128 subject and object types, without bounding boxes. To address the limitations of existing datasets, we clearly define annotation guidelines for SpatialMQA, including standardizing the objective world as the coordinate system and avoiding questions that can be answered solely by the question itself. 41 | 42 | ### Examples 43 | The following figures show some classic examples from our dataset. You can check out [`Examples:1~4/`](Examples/examples_1-4.png) and [`Examples:5~8/`](Examples/examples_5-8.png) to view the details. 44 | 45 | ### Detail Information 46 | The following table [`Splits/`](Comparison/splits.png) lists the detailed statistics of the split dataset. 47 |
48 | For more details, you can find our dataset in the directory **_(Dataset/dataset)_**. 49 |
50 | _Because anonymous links can only redirect to specific files, not directories, we mark all referenced directories in bold italics to make them easier for reviewers to locate. Thank you!_ 51 | 52 | 53 | ## 2 Access SpatialMQA 54 | Our dataset has been officially released on Hugging Face. It is available at https://huggingface.co/datasets/liuziyan/SpatialMQA. 55 | Alternatively, you can download it from GitHub by following the steps below: 56 | ### Download Images 57 | We use a subset of COCO-2017's images. The following script downloads COCO-2017's test set images and puts them into a single folder `Dataset/COCO2017/`. 58 | 59 | ```bash 60 | cd Dataset/ 61 | wget http://images.cocodataset.org/zips/test2017.zip 62 | unzip test2017.zip 63 | mv test2017 COCO2017 && rm -r test2017 64 | ``` 65 | Copy only the relevant images to `relevant_images/`. 66 | ```bash 67 | mkdir relevant_images 68 | cd tool 69 | python select_relevant_images.py 70 | ``` 71 | Alternatively, you can also browse individual images online directly using the key "image" of each JSON record. 72 |
(Through COCO's open source link: 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.) 73 | 74 | ### Data Split 75 | As reported in the following table, SpatialMQA contains 5,392 samples, divided into training, validation, and test sets according to a 7:1:2 ratio. 76 |
All of the split data sets are in the directory **_(Dataset/dataset)_**. 77 |
In addition, we have selected the part of the dataset that contains invisible subjects or objects and placed it in the file [`invisible/`](Dataset/invisible/invisible.jsonl) to facilitate readers' research. 78 | 79 | ### Data Format 80 | Each `jsonl` file is of the following format: 81 | ```json 82 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"} 83 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"} 84 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"} 85 | {"..."} 86 | ``` 87 | Each line is an individual data point. 88 | `image` denotes the name of the image in COCO. `question` is the manually annotated question, `options` is a reasonable combination of the six spatial relations (on/above, below, in front of, behind, left of, right of), and `answer` is the annotation based on the objective world. 89 |
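For reference, the snippet below is a minimal sketch of how one might read a split and assemble a question prompt in the style used by the test scripts; the split filename `test.jsonl` and the local paths are assumptions, so substitute the actual file names under **_(Dataset/dataset)_**.
```python
import json

SPLIT_PATH = "Dataset/dataset/test.jsonl"                    # assumed filename
COCO_URL_PREFIX = "http://images.cocodataset.org/test2017/"  # public COCO test-2017 images

with open(SPLIT_PATH, "r", encoding="utf-8") as f:
    for line in f:
        sample = json.loads(line)                       # one sample per line
        image_url = COCO_URL_PREFIX + sample["image"]   # or Dataset/COCO2017/<image> locally
        prompt = (
            f"Question: {sample['question']} \n"
            f"Options: {'; '.join(sample['options'])} \n"
            f"Answer:"
        )
        print(image_url)
        print(prompt)
        print("gold answer:", sample["answer"])
        break  # remove this to iterate over the whole split
```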
90 | Our dataset's subject and object types extend the categories included in the COCO dataset. There are 113 subject types plus one additional type for subjects with five or fewer samples in our dataset, and 84 object types plus one additional type for objects with five or fewer samples. Due to the overlap between subject and object types, we have a total of 128 distinct subject and object types. You can see all of them in the file [`S & O types/`](Dataset/types/types.txt). 91 | 92 | 93 | ## 3 Experiment and Evaluation 94 | ### Experiment 95 | We have released the model inference code in the directory **_(Code/experiment)_**, as well as the fine-tuning code in the directory **_(Code/finetune)_**. 96 |
97 | - For all 7 open-source MLLMs, you can directly execute the Python files in the directory **_(Code/experiment)_** to perform inference on the models before and after fine-tuning: 98 | ``` 99 | nohup python blip-vqa-base.py > log/blip_exp.log 2>&1 & 100 | nohup python blip-vqa-base_finetuned.py > log/blip_finetuned_exp.log 2>&1 & 101 | nohup python blip2-opt-2.7b.py > log/blip2_exp.log 2>&1 & 102 | nohup python blip2-lora.py > log/blip2_lora_exp.log 2>&1 & 103 | nohup python instructblip-flan-t5-xl.py > log/instructblip_exp.log 2>&1 & 104 | nohup python instructblip-lora.py > log/instructblip_lora_exp.log 2>&1 & 105 | nohup python idefics_new.py > log/idefics_exp.log 2>&1 & 106 | nohup python idefics_lora.py > log/idefics_lora_exp.log 2>&1 & 107 | nohup python spatial_test_llava.py > log/llava_exp.log 2>&1 & 108 | nohup python spatial_test_llava_lora.py > log/llava_lora_exp.log 2>&1 & 109 | nohup python spatial_test_mplug.py > log/mplug_exp.log 2>&1 & 110 | nohup python spatial_test_mplug_lora.py > log/mplug_lora_exp.log 2>&1 & 111 | nohup python spacellava_test.py > log/spacellava_exp.log 2>&1 & 112 | nohup python spacellava_lora_test.py > log/spacellava_lora_exp.log 2>&1 & 113 | ``` 114 | The open-source model code is large, so you need to download it yourself or call it directly from platforms such as [huggingface](https://huggingface.co). 115 |
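If you want a quick sanity check that a model loads before running the full scripts above, a minimal sketch for the BLIP-VQA baseline could look like the following; the local image path is an assumption, and `Salesforce/blip-vqa-base` is used here simply as the public Hugging Face checkpoint.
```python
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

# Load the baseline VQA model directly from Hugging Face.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# Assumed local path; see the Download Images section above.
image = Image.open("Dataset/COCO2017/000000000933.jpg").convert("RGB")
question = "Where is the fork located relative to the pizza?"

inputs = processor(images=image, text=question, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8)
print(processor.decode(out[0], skip_special_tokens=True))
```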
116 | - For blip, blip2, instructblip and idefics, you can directly execute the Python files in the directory **_(Code/finetune)_** to perform fine-tuning: 117 | ``` 118 | nohup python blip-vqa-base.py > log/blip_train.log 2>&1 & 119 | nohup python blip2-lora.py > log/blip2_train.log 2>&1 & 120 | nohup python instructblip-lora.py > log/instructblip_train.log 2>&1 & 121 | nohup python idefics.py > log/idefics_train.log 2>&1 & 122 | ``` 123 | - For llava, spacellava and mplug-owl, you need to execute the bash files in the directory **_(Code/finetune)_** to perform fine-tuning: 124 | ``` 125 | nohup bash llava_lora_train.sh > log/llava_train.log 2>&1 & 126 | nohup bash spacellava_lora_train.sh > log/spacellava_train.log 2>&1 & 127 | nohup bash mPLUG_Owl_train_it.sh > log/mplug_train.log 2>&1 & 128 | ``` 129 | - For gemini-1.5-flash and gpt-4o, you can directly execute our Python files in the directory **_(Code/close_models)_** to perform zero-shot, few-shot, and text-only inference, provided that you have prepared an API key: 130 | ``` 131 | python gemini_text_only.py 132 | python gemini_zero_shot.py 133 | python gemini_1_shot.py 134 | python gemini_2_shot.py 135 | python gemini_3_shot.py 136 | python gpt4_text_only.py 137 | python gpt4_zero_shot.py 138 | python gpt4_1_shot.py 139 | python gpt4_2_shot.py 140 | python gpt4_3_shot.py 141 | ``` 142 | A Gemini API key can be applied for on the [official website](https://aistudio.google.com/app/apikey), and GPT-4 access needs to be purchased on the [official website](https://openai.com/). 143 | 144 | ### Evaluation 145 | You can process the model inference results with the code we provide to calculate the overall accuracy, the overall P, R, F1 indicators, accuracy based on relationship categories, and accuracy based on rules. We integrate the calculation process into the Python files in the directory **_(Code/eval)_**: 146 | ``` 147 | python calculate_prf1.py 148 | python calculate_xyz.py 149 | python calculate_result_rule.py 150 | ``` 151 | 152 | ### Requirements 153 | The environment configuration required to run the code is placed in the directory **_(Code/requirement)_**. 154 |
152 | ### Requirements 153 | The environment configurations required to run the code are provided in the directory **_(Code/requirement)_**. 154 | 155 | The requirements for blip, blip2 and instructblip are all in the file [`requirement_blip.txt`](Code/requirement/requirement_blip.txt). 156 | 157 | ## 4 License 158 | This project is licensed under the [Apache-2.0 License](LICENSE). 159 | -------------------------------------------------------------------------------- /Code/close_models/gemini_2_shot_random.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | 3 | import google.generativeai as genai 4 | from io import BytesIO 5 | import requests 6 | import random 7 | 8 | genai.configure(api_key="your key", transport="rest") 9 | 10 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480} 11 | safety_settings = [ 12 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 13 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 14 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 15 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"} 16 | ] 17 | 18 | FILE_PATH = 'test_en_select_500_sort_cp.jsonl' 19 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 20 | RESULT_FILE_PATH = 'results/gemini_2_shot_random_1.jsonl' 21 | 22 | 23 | example_list_rule1 = [ 24 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."}, 25 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."}, 26 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."} 27 | ] 28 | 29 | example_list_rule2 = [ 30 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."}, 31 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."}, 32 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."} 33 | ] 34 | 35 | example_list_rule3 = [ 36 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."}, 37 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."}, 38 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. 
\nOutput: behind."} 39 | ] 40 | 41 | 42 | def pop_exist(data, datalist): 43 | data_list = datalist.copy() 44 | data_list.remove(data) 45 | return random.choice(data_list) 46 | 47 | 48 | E11 = random.choice(example_list_rule1) 49 | E12 = pop_exist(E11, example_list_rule1) 50 | E21 = random.choice(example_list_rule1) 51 | E22 = random.choice(example_list_rule2) 52 | E31 = random.choice(example_list_rule1) 53 | E32 = random.choice(example_list_rule3) 54 | E41 = random.choice(example_list_rule2) 55 | E42 = pop_exist(E41, example_list_rule2) 56 | E51 = random.choice(example_list_rule2) 57 | E52 = random.choice(example_list_rule3) 58 | E61 = random.choice(example_list_rule3) 59 | E62 = pop_exist(E61, example_list_rule3) 60 | 61 | 62 | example_list = [ 63 | [E11, E12], 64 | [E21, E22], 65 | [E31, E32], 66 | [E41, E42], 67 | [E51, E52], 68 | [E61, E62] 69 | ] 70 | 71 | 72 | def fetch_image_content(image_url): 73 | response = requests.get(image_url) 74 | if response.status_code == 200: 75 | return BytesIO(response.content) 76 | else: 77 | return None 78 | 79 | 80 | import PIL.Image as Image 81 | 82 | # few-shot & zero-shot: 83 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 84 | 85 | # text-only: 86 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings) 87 | 88 | import json 89 | 90 | count = 0 91 | right_count = 0 92 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout: 93 | for line in f: 94 | sleep(1) 95 | data = json.loads(line) 96 | id = data['id'] 97 | 98 | example_2 = example_list[count % 6] 99 | 100 | # 2 - shot 101 | prompt = ['', '', '', ''] 102 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \ 103 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 104 | f'\nGiven the following 2 examples to learn the spatial relation reasoning task:' \ 105 | f'\nExample1: Input: Image: ' 106 | prompt[1] = f'\n{example_2[0]["text"]} ' \ 107 | f'\nExample2: Input: Image: ' 108 | prompt[2] = f'\n{example_2[1]["text"]} ' \ 109 | f'\nInput: Image: ' 110 | 111 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \ 112 | # f'\n{example_list[count % 10][0]["text"]} ' \ 113 | # f'\nYou are currently a senior expert in spatial relation reasoning. ' \ 114 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 115 | # f'\nInput: Image: ' 116 | prompt[3] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. 
\nOutput: ' 117 | 118 | image_list = [] 119 | for item in example_2: 120 | image_id = item['image'] 121 | image_url = f'{IMAGE_DIR}/{image_id}' 122 | image_content = fetch_image_content(image_url) 123 | if image_content is not None: 124 | image = Image.open(image_content) 125 | image_list.append(image) 126 | 127 | image_id = data['image'] 128 | image_url = f'{IMAGE_DIR}/{image_id}' 129 | image_content = fetch_image_content(image_url) 130 | 131 | if image_content is not None: 132 | image = Image.open(image_content) 133 | image_list.append(image) 134 | try: 135 | response = model.generate_content( 136 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot 137 | # [question] # text-only 138 | ) 139 | except Exception as e: 140 | print(e) 141 | try: 142 | response = model.generate_content( 143 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot 144 | # [question] # text-only 145 | ) 146 | except Exception as e: 147 | try: 148 | response = model.generate_content( 149 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot 150 | # [question] # text-only 151 | ) 152 | except Exception as e: 153 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 154 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 155 | count += 1 156 | continue 157 | try: 158 | output = response.text.strip().rstrip('.').lower() 159 | except Exception as e: 160 | print(e) 161 | output = '--' 162 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 163 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 164 | count += 1 165 | continue 166 | count += 1 167 | if output in data['answer'] or data['answer'] in output: 168 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 169 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 170 | right_count += 1 171 | else: 172 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 173 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 174 | print(f'{output.lower()}') 175 | print(f"{data['answer']}") 176 | print(f'right_count: {right_count}') 177 | print(f'count: {count}') 178 | 179 | accuracy = right_count / count 180 | print(accuracy) 181 | -------------------------------------------------------------------------------- /Code/close_models/gpt4_1_shot.py: -------------------------------------------------------------------------------- 1 | # import openai 2 | from openai import OpenAI 3 | import json 4 | import requests 5 | from io import BytesIO 6 | import PIL.Image as Image 7 | import base64 8 | import random 9 | 10 | client = OpenAI( 11 | api_key='your key', 12 | base_url='https://api.mnxcc.com/v1' 13 | ) 14 | 15 | IMAGE_DIR = 'http://images.cocodataset.org/test2017' 16 | 17 | example_list_rule1 = [ 18 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. 
\nOutput: right of."}, 19 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."}, 20 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."} 21 | ] 22 | 23 | example_list_rule2 = [ 24 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."}, 25 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."}, 26 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."} 27 | ] 28 | 29 | example_list_rule3 = [ 30 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."}, 31 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."}, 32 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. 
\nOutput: behind."} 33 | ] 34 | 35 | 36 | def choose_1_example(): 37 | example_1 = random.choice(example_list_rule1 + example_list_rule1 + example_list_rule1) 38 | return example_1 39 | 40 | 41 | # count = 0 42 | # right_count = 0 43 | 44 | 45 | def fetch_image_content(image_url): 46 | response = requests.get(image_url) 47 | if response.status_code == 200: 48 | return BytesIO(response.content) 49 | else: 50 | return None 51 | 52 | 53 | def encode_image(image): 54 | if image is None: 55 | return None 56 | 57 | buffered = BytesIO() 58 | try: 59 | image.save(buffered, format="JPEG") 60 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') 61 | return 'data:image/jpeg;base64,' + img_str 62 | except Exception as e: 63 | print(f"error: {e}") 64 | return None 65 | 66 | 67 | def call_gpt4(prompt1: str, prompt2: str, prompt3: str, image_eg1, image, detail='auto'): 68 | try: 69 | response = client.chat.completions.create( 70 | # model="gpt-4-vision-preview",gpt-4-turbo 71 | model="gpt-4o", 72 | messages=[ 73 | { 74 | "role": "user", 75 | # zero-shot + few-shot 76 | "content": [{"type": "text", "text": prompt1}] \ 77 | + [{"type": "image_url", "image_url": { "url": image_eg1, "detail": detail}}] \ 78 | + [{"type": "text", "text": prompt2}] \ 79 | + [{"type": "image_url", "image_url": { "url": image, "detail": detail}}] \ 80 | + [{"type": "text", "text": prompt3}] 81 | # text only: 82 | # "content": [{"type": "text", "text": question}] 83 | } 84 | ], 85 | max_tokens=500, 86 | temperature=0.5, 87 | ) 88 | # print(response.choices[0].message.content.strip()) 89 | return response.choices[0].message.content.strip() 90 | 91 | except Exception as e: 92 | print(f"Error during answering: {e}") 93 | return None 94 | 95 | 96 | def process_jsonl(input_file, output_file): 97 | count = 0 98 | right_count = 0 99 | 100 | with open(input_file, 'r', encoding='utf-8') as file: 101 | with open(output_file, 'w', encoding='utf-8') as out_file: 102 | for line in file: 103 | data = json.loads(line) 104 | id = data['id'] 105 | 106 | example_1 = choose_1_example() 107 | 108 | # 1 - shot 109 | prompt1 = f'You are currently a senior expert in spatial relation reasoning. ' \ 110 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 111 | f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \ 112 | f'\nInput: Image: ' 113 | prompt2 = f'\n{example_1["text"]} ' \ 114 | f'\nInput: Image: ' 115 | prompt3 = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. 
\nOutput: ' 116 | 117 | image_eg1_id = example_1["image"] 118 | image_eg1_url = f'{IMAGE_DIR}/{image_eg1_id}' 119 | image_eg1_content = fetch_image_content(image_eg1_url) 120 | 121 | image_id = data['image'] 122 | image_url = f'{IMAGE_DIR}/{image_id}' 123 | image_content = fetch_image_content(image_url) 124 | 125 | if (image_eg1_content is not None) and (image_content is not None): 126 | image_eg1 = Image.open(image_eg1_content) 127 | image_eg1_encoded = encode_image(image_eg1) 128 | 129 | image = Image.open(image_content) 130 | image_encoded = encode_image(image) 131 | 132 | try: 133 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded) 134 | except Exception as e: 135 | print(e) 136 | try: 137 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded) 138 | except Exception as e: 139 | try: 140 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded) 141 | except Exception as e: 142 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], 143 | "rule": data['rule'], "example": 0} 144 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 145 | count += 1 146 | continue 147 | 148 | try: 149 | # output = model_answer.text.strip().rstrip('.').lower() 150 | output = model_answer.strip().rstrip('.').lower() 151 | except Exception as e: 152 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 153 | "rule": data['rule'], "example": 0} 154 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 155 | count += 1 156 | continue 157 | count += 1 158 | if output in data['answer'] or data['answer'] in output: 159 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], 160 | "rule": data['rule'], "example": 0} 161 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 162 | right_count += 1 163 | else: 164 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], 165 | "rule": data['rule'], "example": 0} 166 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 167 | print(f'{output.lower()}') 168 | print(f"{data['answer']}") 169 | print(f'right_count: {right_count}') 170 | print(f'count: {count}') 171 | # print(f'accuracy: {right_count / count}') 172 | 173 | accuracy = right_count / count 174 | print(accuracy) 175 | 176 | 177 | input_file_path = "test_en_select_500_sort_cp.jsonl" 178 | output_file_path = "results/gpt4_1_shot_random_new_cp.jsonl" 179 | 180 | process_jsonl(input_file_path, output_file_path) 181 | -------------------------------------------------------------------------------- /Code/close_models/gemini_3_shot_random.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | 3 | import google.generativeai as genai 4 | from io import BytesIO 5 | import requests 6 | import random 7 | 8 | genai.configure(api_key="your key", transport="rest") 9 | 10 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480} 11 | safety_settings = [ 12 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 13 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 14 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 15 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"} 16 | ] 17 | 18 | FILE_PATH = 'test_en_select_500_sort_cp_2.jsonl' 19 | IMAGE_DIR = 
'http://images.cocodataset.org/test2017' 20 | RESULT_FILE_PATH = 'results/gemini_3_shot_random_cp.jsonl' 21 | 22 | 23 | example_list_rule1 = [ 24 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."}, 25 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."}, 26 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."} 27 | ] 28 | 29 | example_list_rule2 = [ 30 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."}, 31 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."}, 32 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."} 33 | ] 34 | 35 | example_list_rule3 = [ 36 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."}, 37 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."}, 38 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. 
\nOutput: behind."} 39 | ] 40 | 41 | 42 | def pop_exist(data, datalist): 43 | data_list = datalist.copy() 44 | data_list.remove(data) 45 | return random.choice(data_list) 46 | 47 | 48 | E11 = example_list_rule1[0] 49 | E12 = example_list_rule1[1] 50 | E13 = example_list_rule1[2] 51 | E21 = random.choice(example_list_rule1) 52 | E22 = pop_exist(E21, example_list_rule1) 53 | E23 = random.choice(example_list_rule2) 54 | E31 = random.choice(example_list_rule1) 55 | E32 = pop_exist(E31, example_list_rule1) 56 | E33 = random.choice(example_list_rule3) 57 | E41 = random.choice(example_list_rule1) 58 | E42 = random.choice(example_list_rule2) 59 | E43 = pop_exist(E42, example_list_rule2) 60 | E51 = random.choice(example_list_rule1) 61 | E52 = random.choice(example_list_rule2) 62 | E53 = random.choice(example_list_rule3) 63 | E61 = random.choice(example_list_rule1) 64 | E62 = random.choice(example_list_rule3) 65 | E63 = pop_exist(E62, example_list_rule3) 66 | E71 = example_list_rule2[0] 67 | E72 = example_list_rule2[1] 68 | E73 = example_list_rule2[2] 69 | E81 = random.choice(example_list_rule2) 70 | E82 = pop_exist(E81, example_list_rule2) 71 | E83 = random.choice(example_list_rule3) 72 | E91 = random.choice(example_list_rule2) 73 | E92 = random.choice(example_list_rule3) 74 | E93 = pop_exist(E92, example_list_rule3) 75 | E101 = example_list_rule3[0] 76 | E102 = example_list_rule3[1] 77 | E103 = example_list_rule3[2] 78 | 79 | 80 | example_list = [ 81 | [E11, E12, E13], 82 | [E21, E22, E23], 83 | [E31, E32, E33], 84 | [E41, E42, E43], 85 | [E51, E52, E53], 86 | [E61, E62, E63], 87 | [E71, E72, E73], 88 | [E81, E82, E83], 89 | [E91, E92, E93], 90 | [E101, E102, E103] 91 | ] 92 | 93 | 94 | def fetch_image_content(image_url): 95 | response = requests.get(image_url) 96 | if response.status_code == 200: 97 | return BytesIO(response.content) 98 | else: 99 | return None 100 | 101 | 102 | import PIL.Image as Image 103 | 104 | # few-shot & zero-shot: 105 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings) 106 | 107 | # text-only: 108 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings) 109 | 110 | import json 111 | 112 | count = 0 113 | right_count = 0 114 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout: 115 | for line in f: 116 | sleep(1) 117 | data = json.loads(line) 118 | id = data['id'] 119 | 120 | example_3 = example_list[count % 10] 121 | 122 | # 3 - shot 123 | prompt = ['', '', '', '', ''] 124 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \ 125 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 126 | f'\nGiven the following 3 examples to learn the spatial relation reasoning task:' \ 127 | f'\nExample1: Input: Image: ' 128 | prompt[1] = f'\n{example_3[0]["text"]} ' \ 129 | f'\nExample2: Input: Image: ' 130 | prompt[2] = f'\n{example_3[1]["text"]} ' \ 131 | f'\nExample3: Input: Image: ' 132 | prompt[3] = f'\n{example_3[2]["text"]} ' \ 133 | f'\nInput: Image: ' 134 | 135 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \ 136 | # f'\n{example_list[count % 10][0]["text"]} ' \ 137 | # f'\nYou are currently a senior expert in spatial relation reasoning. 
' \ 138 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 139 | # f'\nInput: Image: ' 140 | prompt[4] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 141 | 142 | image_list = [] 143 | for item in example_3: 144 | image_id = item['image'] 145 | image_url = f'{IMAGE_DIR}/{image_id}' 146 | image_content = fetch_image_content(image_url) 147 | if image_content is not None: 148 | image = Image.open(image_content) 149 | image_list.append(image) 150 | 151 | image_id = data['image'] 152 | image_url = f'{IMAGE_DIR}/{image_id}' 153 | image_content = fetch_image_content(image_url) 154 | 155 | if image_content is not None: 156 | image = Image.open(image_content) 157 | image_list.append(image) 158 | try: 159 | response = model.generate_content( 160 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot 161 | # [question] # text-only 162 | ) 163 | except Exception as e: 164 | print(e) 165 | try: 166 | response = model.generate_content( 167 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot 168 | # [question] # text-only 169 | ) 170 | except Exception as e: 171 | try: 172 | response = model.generate_content( 173 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot 174 | # [question] # text-only 175 | ) 176 | except Exception as e: 177 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10} 178 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 179 | count += 1 180 | continue 181 | try: 182 | output = response.text.strip().rstrip('.').lower() 183 | except Exception as e: 184 | print(e) 185 | output = '--' 186 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10} 187 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 188 | count += 1 189 | continue 190 | count += 1 191 | if output in data['answer'] or data['answer'] in output: 192 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10} 193 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 194 | right_count += 1 195 | else: 196 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10} 197 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n') 198 | print(f'{output.lower()}') 199 | print(f"{data['answer']}") 200 | print(f'right_count: {right_count}') 201 | print(f'count: {count}') 202 | 203 | accuracy = right_count / count 204 | print(accuracy) 205 | -------------------------------------------------------------------------------- /Code/close_models/gpt4_2_shot.py: -------------------------------------------------------------------------------- 1 | # import openai 2 | from openai import OpenAI 3 | import json 4 | import requests 5 | from io import BytesIO 6 | import PIL.Image as Image 7 | import base64 8 | import random 9 | 10 | client = OpenAI( 11 | api_key='your key', 12 | base_url='https://api.mnxcc.com/v1' 13 | ) 14 | 15 | 
IMAGE_DIR = 'http://images.cocodataset.org/test2017' 16 | 17 | example_list_rule1 = [ 18 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."}, 19 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."}, 20 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."} 21 | ] 22 | 23 | example_list_rule2 = [ 24 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."}, 25 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."}, 26 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."} 27 | ] 28 | 29 | example_list_rule3 = [ 30 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."}, 31 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."}, 32 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. 
\nOutput: behind."} 33 | ] 34 | 35 | 36 | def pop_exist(data, datalist): 37 | data_list = datalist.copy() 38 | data_list.remove(data) 39 | return random.choice(data_list) 40 | 41 | 42 | E11 = random.choice(example_list_rule1) 43 | E12 = pop_exist(E11, example_list_rule1) 44 | E21 = random.choice(example_list_rule1) 45 | E22 = random.choice(example_list_rule2) 46 | E31 = random.choice(example_list_rule1) 47 | E32 = random.choice(example_list_rule3) 48 | E41 = random.choice(example_list_rule2) 49 | E42 = pop_exist(E41, example_list_rule2) 50 | E51 = random.choice(example_list_rule2) 51 | E52 = random.choice(example_list_rule3) 52 | E61 = random.choice(example_list_rule3) 53 | E62 = pop_exist(E61, example_list_rule3) 54 | 55 | 56 | example_list = [ 57 | [E11, E12], 58 | [E21, E22], 59 | [E31, E32], 60 | [E41, E42], 61 | [E51, E52], 62 | [E61, E62] 63 | ] 64 | 65 | 66 | # count = 0 67 | # right_count = 0 68 | 69 | 70 | def fetch_image_content(image_url): 71 | response = requests.get(image_url) 72 | if response.status_code == 200: 73 | return BytesIO(response.content) 74 | else: 75 | return None 76 | 77 | 78 | def encode_image(image): 79 | if image is None: 80 | return None 81 | 82 | buffered = BytesIO() 83 | try: 84 | image.save(buffered, format="JPEG") 85 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') 86 | return 'data:image/jpeg;base64,' + img_str 87 | except Exception as e: 88 | print(f"error: {e}") 89 | return None 90 | 91 | 92 | def call_gpt4(prompt, image, detail='auto'): 93 | try: 94 | response = client.chat.completions.create( 95 | model="gpt-4o", 96 | messages=[ 97 | { 98 | "role": "user", 99 | # zero-shot + few-shot 100 | "content": [{"type": "text", "text": prompt[0]}] \ 101 | + [{"type": "image_url", "image_url": { "url": image[0], "detail": detail}}] \ 102 | + [{"type": "text", "text": prompt[1]}] \ 103 | + [{"type": "image_url", "image_url": { "url": image[1], "detail": detail}}] \ 104 | + [{"type": "text", "text": prompt[2]}] \ 105 | + [{"type": "image_url", "image_url": { "url": image[2], "detail": detail}}] \ 106 | + [{"type": "text", "text": prompt[3]}] 107 | # text only: 108 | # "content": [{"type": "text", "text": question}] 109 | } 110 | ], 111 | max_tokens=500, 112 | temperature=0.5, 113 | ) 114 | # print(response.choices[0].message.content.strip()) 115 | return response.choices[0].message.content.strip() 116 | 117 | except Exception as e: 118 | print(f"Error during answering: {e}") 119 | return None 120 | 121 | 122 | def process_jsonl(input_file, output_file): 123 | count = 0 124 | right_count = 0 125 | 126 | with open(input_file, 'r', encoding='utf-8') as file: 127 | with open(output_file, 'w', encoding='utf-8') as out_file: 128 | for line in file: 129 | data = json.loads(line) 130 | id = data['id'] 131 | 132 | example_2 = example_list[count % 6] 133 | 134 | # 2 - shot 135 | prompt = ['', '', '', ''] 136 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \ 137 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' 
\ 138 | f'\nGiven the following 2 examples to learn the spatial relation reasoning task:' \ 139 | f'\nExample1: Input: Image: ' 140 | prompt[1] = f'\n{example_2[0]["text"]} ' \ 141 | f'\nExample2: Input: Image: ' 142 | prompt[2] = f'\n{example_2[1]["text"]} ' \ 143 | f'\nInput: Image: ' 144 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \ 145 | # f'\n{example_list[count % 10][0]["text"]} ' \ 146 | # f'\nYou are currently a senior expert in spatial relation reasoning. ' \ 147 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \ 148 | # f'\nInput: Image: ' 149 | prompt[3] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: ' 150 | 151 | image_encoded_list = [] 152 | for item in example_2: 153 | image_id = item['image'] 154 | image_url = f'{IMAGE_DIR}/{image_id}' 155 | image_content = fetch_image_content(image_url) 156 | if image_content is not None: 157 | image = Image.open(image_content) 158 | image_encoded = encode_image(image) 159 | image_encoded_list.append(image_encoded) 160 | 161 | image_id = data['image'] 162 | image_url = f'{IMAGE_DIR}/{image_id}' 163 | image_content = fetch_image_content(image_url) 164 | 165 | if image_content is not None: 166 | image = Image.open(image_content) 167 | image_encoded = encode_image(image) 168 | image_encoded_list.append(image_encoded) 169 | 170 | try: 171 | model_answer = call_gpt4(prompt, image_encoded_list) 172 | except Exception as e: 173 | print(e) 174 | try: 175 | model_answer = call_gpt4(prompt, image_encoded_list) 176 | except Exception as e: 177 | try: 178 | model_answer = call_gpt4(prompt, image_encoded_list) 179 | except Exception as e: 180 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 181 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 182 | count += 1 183 | continue 184 | 185 | try: 186 | # output = model_answer.text.strip().rstrip('.').lower() 187 | output = model_answer.strip().rstrip('.').lower() 188 | except Exception as e: 189 | print(e) 190 | output = '--' 191 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 192 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 193 | count += 1 194 | continue 195 | count += 1 196 | if output in data['answer'] or data['answer'] in output: 197 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 198 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 199 | right_count += 1 200 | else: 201 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4} 202 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n') 203 | print(f'{output.lower()}') 204 | print(f"{data['answer']}") 205 | print(f'right_count: {right_count}') 206 | print(f'count: {count}') 207 | # print(f'accuracy: {right_count / count}') 208 | 209 | accuracy = right_count / count 210 | print(accuracy) 211 | 212 | 213 | input_file_path = "test_en_select_500_sort_cp.jsonl" 214 | output_file_path = "results/gpt4_2_shot_random_new_1.jsonl" 215 | 216 | process_jsonl(input_file_path, output_file_path) 217 | 
--------------------------------------------------------------------------------