├── Results
│   ├── main.png
│   ├── analysis.png
│   └── error_case.jpg
├── Comparison
│   ├── compare.jpg
│   ├── splits.png
│   └── statistic.png
├── Examples
│   ├── 000000000933.jpg
│   ├── 000000006568.jpg
│   ├── 000000015740.jpg
│   ├── 000000057139.jpg
│   ├── 000000070986.jpg
│   ├── 000000100633.jpg
│   ├── 000000121362.jpg
│   ├── 000000142379.jpg
│   ├── examples_1-4.png
│   ├── examples_5-8.png
│   └── examples.jsonl
├── Code
│   ├── eval
│   │   ├── README.md
│   │   ├── calculate_result_rule.py
│   │   ├── calculate_xyz.py
│   │   └── calculate_prf1.py
│   ├── requirement
│   │   ├── requirement_gpt4.txt
│   │   ├── requirement_gemini.txt
│   │   ├── requirement_spacellava.txt
│   │   ├── requirement_idefics.txt
│   │   ├── requirement_blip.txt
│   │   ├── requirement_mplug.txt
│   │   └── requirement_llava.txt
│   ├── finetune
│   │   ├── llava_lora_train.sh
│   │   ├── spacellava_lora_train.sh
│   │   ├── mPLUG_Owl_train_it.sh
│   │   ├── idefics.py
│   │   ├── blip-vqa-base.py
│   │   ├── instructblip-lora.py
│   │   └── blip2-lora.py
│   ├── experiment
│   │   ├── blip-vqa-base.py
│   │   ├── blip-vqa-base_finetuned.py
│   │   ├── blip2-opt-2.7b.py
│   │   ├── blip2-lora.py
│   │   ├── instructblip-flan-t5-xl.py
│   │   ├── instructblip-lora.py
│   │   ├── spatial_test_mplug.py
│   │   ├── spatial_test_mplug_lora.py
│   │   ├── idefics_new.py
│   │   ├── idefics_lora.py
│   │   ├── spacellava_test.py
│   │   ├── spatial_test_llava.py
│   │   ├── spatial_test_llava_lora.py
│   │   └── spacellava_lora_test.py
│   └── close_models
│       ├── gemini_text_only.py
│       ├── gemini_0_shot.py
│       ├── gpt4_text_only.py
│       ├── gpt4_zero_shot.py
│       ├── gemini_1_shot_random.py
│       ├── gemini_2_shot_random.py
│       ├── gpt4_1_shot.py
│       ├── gemini_3_shot_random.py
│       └── gpt4_2_shot.py
├── Dataset
│   ├── relevant_images
│   │   └── README.md
│   ├── tool
│   │   └── select_relevant_images.py
│   ├── types
│   │   └── types.txt
│   └── README.md
├── Metadata
│   ├── metadata_hf.json
│   └── metadata.json
└── README.md
/Results/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/main.png
--------------------------------------------------------------------------------
/Comparison/compare.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/compare.jpg
--------------------------------------------------------------------------------
/Comparison/splits.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/splits.png
--------------------------------------------------------------------------------
/Results/analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/analysis.png
--------------------------------------------------------------------------------
/Results/error_case.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Results/error_case.jpg
--------------------------------------------------------------------------------
/Comparison/statistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Comparison/statistic.png
--------------------------------------------------------------------------------
/Examples/000000000933.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000000933.jpg
--------------------------------------------------------------------------------
/Examples/000000006568.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000006568.jpg
--------------------------------------------------------------------------------
/Examples/000000015740.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000015740.jpg
--------------------------------------------------------------------------------
/Examples/000000057139.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000057139.jpg
--------------------------------------------------------------------------------
/Examples/000000070986.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000070986.jpg
--------------------------------------------------------------------------------
/Examples/000000100633.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000100633.jpg
--------------------------------------------------------------------------------
/Examples/000000121362.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000121362.jpg
--------------------------------------------------------------------------------
/Examples/000000142379.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/000000142379.jpg
--------------------------------------------------------------------------------
/Examples/examples_1-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/examples_1-4.png
--------------------------------------------------------------------------------
/Examples/examples_5-8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziyan-xiaoyu/SpatialMQA/HEAD/Examples/examples_5-8.png
--------------------------------------------------------------------------------
/Code/eval/README.md:
--------------------------------------------------------------------------------
1 | **calculate_prf1.py** calculates overall accuracy and the overall precision (P), recall (R), and F1 metrics.
2 |
3 | **calculate_xyz.py** calculates accuracy broken down by spatial-relation category.
4 |
5 | **calculate_result_rule.py** calculates accuracy broken down by rule (the rule 1/2/3 test subsets).
6 |
--------------------------------------------------------------------------------
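Note: the prediction files consumed by these scripts are the jsonl files written by the scripts in `Code/experiment/`, one JSON object per line with `id`, `result`, `output`, and `answer` fields. The sketch below is only an illustration of how overall accuracy and macro-averaged P/R/F1 could be recomputed from such a file; the file name `blip.jsonl` is a placeholder, and the exact averaging used in `calculate_prf1.py` may differ.

```python
import json
from collections import defaultdict


def load_jsonl(path):
    """Read one prediction file: one JSON object per line."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def overall_metrics(pred_file):
    preds = load_jsonl(pred_file)
    # Overall accuracy: fraction of lines marked result == 1 by the experiment script.
    accuracy = sum(p["result"] for p in preds) / len(preds)

    # Per-class counts, using the gold answer / model output strings as class labels.
    tp, fp, fn = defaultdict(int), defaultdict(int), defaultdict(int)
    for p in preds:
        if p["result"] == 1:
            tp[p["answer"]] += 1
        else:
            fp[p["output"]] += 1
            fn[p["answer"]] += 1

    classes = set(tp) | set(fp) | set(fn)
    precision = sum(tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in classes) / len(classes)
    recall = sum(tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in classes) / len(classes)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


if __name__ == "__main__":
    acc, p, r, f1 = overall_metrics("blip.jsonl")  # placeholder prediction file
    print(f"accuracy={acc:.4f} P={p:.4f} R={r:.4f} F1={f1:.4f}")
```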
/Dataset/relevant_images/README.md:
--------------------------------------------------------------------------------
1 | ### Download images
2 | We use a subset of the COCO-2017 images. The following script downloads the COCO-2017 test-set images and puts them into a single folder, `Dataset/COCO2017/`.
3 |
4 | ```bash
5 | cd Dataset/
6 | wget http://images.cocodataset.org/zips/test2017.zip
7 | unzip test2017.zip
8 | mv test2017 COCO2017 && rm test2017.zip
9 | ```
10 | Copy only relevant images to `relevant_images/`.
11 | ```bash
12 | mkdir relevant_images
13 | cd tool
14 | python select_relevant_images.py
15 | ```
16 | Alternatively, you can also view individual images online directly via the "image" key of each JSON line.
17 | (Use COCO's public link: 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.)
18 |
--------------------------------------------------------------------------------
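As an illustration of the online alternative mentioned in this README, the minimal sketch below fetches a single image through the COCO link pattern using the `image` field of one SpatialMQA JSON line; the input file name `examples.jsonl` is only a placeholder for any of the dataset's jsonl splits.

```python
import json

import requests

COCO_BASE = "http://images.cocodataset.org/test2017/"

# Read the first sample of a SpatialMQA jsonl file (placeholder file name).
with open("examples.jsonl", "r", encoding="utf-8") as f:
    sample = json.loads(f.readline())

# Download the referenced COCO test2017 image and save it locally.
resp = requests.get(COCO_BASE + sample["image"], timeout=30)
resp.raise_for_status()
with open(sample["image"], "wb") as out:
    out.write(resp.content)
print(f"saved {sample['image']} ({len(resp.content)} bytes)")
```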
/Code/requirement/requirement_gpt4.txt:
--------------------------------------------------------------------------------
1 | annotated-types==0.6.0
2 | anyio==4.3.0
3 | certifi==2024.2.2
4 | charset-normalizer==3.3.2
5 | colorama==0.4.6
6 | distro==1.9.0
7 | exceptiongroup==1.2.1
8 | h11==0.14.0
9 | httpcore==1.0.5
10 | httpx==0.27.0
11 | idna==3.7
12 | openai==1.30.1
13 | Pillow==7.2.0
14 | pip==24.0
15 | pydantic==2.7.1
16 | pydantic_core==2.18.2
17 | requests==2.31.0
18 | setuptools==69.5.1
19 | sniffio==1.3.1
20 | tk==0.1.0
21 | tqdm==4.66.4
22 | typing_extensions==4.11.0
23 | urllib3==2.2.1
24 | wheel==0.43.0
25 |
26 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_gemini.txt:
--------------------------------------------------------------------------------
1 | annotated-types==0.6.0
2 | beautifulsoup4==4.12.3
3 | cachetools==5.3.3
4 | certifi==2024.2.2
5 | charset-normalizer==3.3.2
6 | colorama==0.4.6
7 | google==3.0.0
8 | google-ai-generativelanguage==0.6.3
9 | google-api-core==2.19.0
10 | google-api-python-client==2.129.0
11 | google-auth==2.29.0
12 | google-auth-httplib2==0.2.0
13 | google-generativeai==0.5.3
14 | googleapis-common-protos==1.63.0
15 | grpcio==1.63.0
16 | grpcio-status==1.62.2
17 | httplib2==0.22.0
18 | idna==3.7
19 | joblib==1.4.2
20 | numpy==1.26.4
21 | pillow==10.3.0
22 | pip==24.0
23 | proto-plus==1.23.0
24 | protobuf==4.25.3
25 | pyasn1==0.6.0
26 | pyasn1_modules==0.4.0
27 | pydantic==2.7.1
28 | pydantic_core==2.18.2
29 | pyparsing==3.1.2
30 | requests==2.31.0
31 | rsa==4.9
32 | scikit-learn==1.5.0
33 | scipy==1.13.1
34 | setuptools==69.5.1
35 | soupsieve==2.5
36 | threadpoolctl==3.5.0
37 | tqdm==4.66.4
38 | typing_extensions==4.11.0
39 | uritemplate==4.1.1
40 | urllib3==2.2.1
41 | wheel==0.43.0
42 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_spacellava.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _openmp_mutex=5.1=1_gnu
6 | brotli=1.1.0=pypi_0
7 | ca-certificates=2024.7.2=h06a4308_0
8 | certifi=2024.7.4=pypi_0
9 | charset-normalizer=3.3.2=pypi_0
10 | gevent=24.2.1=pypi_0
11 | geventhttpclient=2.0.12=pypi_0
12 | greenlet=3.0.3=pypi_0
13 | idna=3.7=pypi_0
14 | ld_impl_linux-64=2.38=h1181459_1
15 | libffi=3.4.4=h6a678d5_1
16 | libgcc-ng=11.2.0=h1234567_1
17 | libgomp=11.2.0=h1234567_1
18 | libstdcxx-ng=11.2.0=h1234567_1
19 | ncurses=6.4=h6a678d5_0
20 | numpy=1.24.4=pypi_0
21 | openssl=3.0.14=h5eee18b_0
22 | pip=24.0=py38h06a4308_0
23 | python=3.8.19=h955ad1f_0
24 | python-rapidjson=1.20=pypi_0
25 | readline=8.2=h5eee18b_0
26 | requests=2.32.3=pypi_0
27 | setuptools=72.1.0=py38h06a4308_0
28 | six=1.16.0=pypi_0
29 | sqlite=3.45.3=h5eee18b_0
30 | tk=8.6.14=h39e8969_0
31 | tritonclient=2.48.0=pypi_0
32 | urllib3=2.2.2=pypi_0
33 | wheel=0.43.0=py38h06a4308_0
34 | xz=5.4.6=h5eee18b_1
35 | zlib=1.2.13=h5eee18b_1
36 | zope-event=5.0=pypi_0
37 | zope-interface=7.0.1=pypi_0
38 |
--------------------------------------------------------------------------------
/Dataset/tool/select_relevant_images.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 |
5 |
6 | def copy_images(jsonl_file, source_dir, destination_dir):
7 | if not os.path.exists(destination_dir):
8 | os.makedirs(destination_dir)
9 | count = 0
10 | with open(jsonl_file, 'r', encoding='gbk', errors='ignore') as f:
11 | for line in f:
12 | count += 1
13 | data = json.loads(line)
14 | image_filename = data['image']
15 |
16 | source_path = os.path.join(source_dir, image_filename)
17 |
18 | if os.path.exists(source_path):
19 | destination_path = os.path.join(destination_dir, image_filename)
20 | shutil.copyfile(source_path, destination_path)
21 | else:
22 | print(f"Source file {image_filename} not found in {source_dir}")
23 |
24 | print(jsonl_file+': '+str(count))
25 | print("Copy successful!")
26 |
27 |
28 | jsonl_file = os.listdir('../dataset/')
29 | for i in range(len(jsonl_file)):
30 | jsonl_file[i] = '../dataset/'+jsonl_file[i]
31 | source_dir = '../COCO2017'
32 | destination_dir = '../relevant_images'
33 |
34 | for file in jsonl_file:
35 | copy_images(file, source_dir, destination_dir)
36 |
--------------------------------------------------------------------------------
/Code/finetune/llava_lora_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed --include localhost:6 llava/train/train_mem.py \
4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path /projects/Models/LLaVA-main/llava-v1.5-7b \
7 | --version v1 \
8 | --data_path /projects/SpatialMQA/datasets/llava_train/train_3780.json \
9 | --image_folder /projects \
10 | --vision_tower openai/clip-vit-large-patch14-336 \
11 | --mm_projector_type mlp2x_gelu \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length True \
17 | --bf16 True \
18 | --output_dir /projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240601 \
19 | --num_train_epochs 10 \
20 | --per_device_train_batch_size 8 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 2 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 100 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-4 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.02 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 2 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 0 \
36 | --lazy_preprocess True
37 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_idefics.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.29.2
2 | aiohttp==3.9.3
3 | aiosignal==1.3.1
4 | async-timeout==4.0.3
5 | attrs==23.2.0
6 | bitsandbytes==0.43.0
7 | certifi==2024.2.2
8 | charset-normalizer==3.3.2
9 | datasets==2.18.0
10 | dill==0.3.8
11 | filelock==3.13.4
12 | frozenlist==1.4.1
13 | fsspec==2024.2.0
14 | huggingface-hub==0.22.2
15 | idna==3.6
16 | Jinja2==3.1.3
17 | MarkupSafe==2.1.5
18 | mpmath==1.3.0
19 | multidict==6.0.5
20 | multiprocess==0.70.16
21 | networkx==3.3
22 | numpy==1.26.4
23 | nvidia-cublas-cu12==12.1.3.1
24 | nvidia-cuda-cupti-cu12==12.1.105
25 | nvidia-cuda-nvrtc-cu12==12.1.105
26 | nvidia-cuda-runtime-cu12==12.1.105
27 | nvidia-cudnn-cu12==8.9.2.26
28 | nvidia-cufft-cu12==11.0.2.54
29 | nvidia-curand-cu12==10.3.2.106
30 | nvidia-cusolver-cu12==11.4.5.107
31 | nvidia-cusparse-cu12==12.1.0.106
32 | nvidia-nccl-cu12==2.19.3
33 | nvidia-nvjitlink-cu12==12.4.127
34 | nvidia-nvtx-cu12==12.1.105
35 | packaging==24.0
36 | pandas==2.2.1
37 | peft==0.10.0
38 | pillow==10.3.0
39 | pip==23.3.1
40 | psutil==5.9.8
41 | pyarrow==15.0.2
42 | pyarrow-hotfix==0.6
43 | python-dateutil==2.9.0.post0
44 | pytz==2024.1
45 | PyYAML==6.0.1
46 | regex==2023.12.25
47 | requests==2.31.0
48 | safetensors==0.4.2
49 | setuptools==68.2.2
50 | six==1.16.0
51 | sympy==1.12
52 | tokenizers==0.15.2
53 | torch==2.2.2
54 | torchvision==0.17.2
55 | tqdm==4.66.2
56 | transformers==4.39.3
57 | triton==2.2.0
58 | typing_extensions==4.11.0
59 | tzdata==2024.1
60 | urllib3==2.2.1
61 | wheel==0.41.2
62 | xxhash==3.4.1
63 | yarl==1.9.4
64 |
--------------------------------------------------------------------------------
/Code/finetune/spacellava_lora_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed --include localhost:0,1 llava/train/train_mem.py \
4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf \
7 | --version v1 \
8 | --data_path /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/datasets/train_3000.json \
9 | --image_folder /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/datasets/images \
10 | --vision_tower openai/clip-vit-large-patch14-336 \
11 | --mm_projector_type mlp2x_gelu \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length True \
17 | --bf16 True \
18 | --output_dir /home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240816_3000 \
19 | --num_train_epochs 2 \
20 | --per_device_train_batch_size 8 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 1 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 60 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-4 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.02 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 0 \
36 | --lazy_preprocess True
37 | # --report_to wandb
38 |
--------------------------------------------------------------------------------
/Dataset/types/types.txt:
--------------------------------------------------------------------------------
1 | person
2 | letter
3 | natural scenery
4 | hand
5 | cat
6 | cup
7 | sports ball
8 | dog
9 | car
10 | lamp
11 | cell phone
12 | fork
13 | clock
14 | bottle
15 | shadow
16 | building
17 | window
18 | laptop
19 | baseball bat
20 | knife
21 | tv
22 | bicycle
23 | number
24 | backpack
25 | stop sign
26 | paper
27 | bus
28 | book
29 | door
30 | elephant
31 | frisbee
32 | banana
33 | trash can
34 | wine glass
35 | towel
36 | mouse
37 | box
38 | remote
39 | spoon
40 | horse
41 | sink
42 | motorcycle
43 | teddy bear
44 | vase
45 | fire hydrant
46 | jar
47 | tree and wood
48 | potted plant
49 | chair
50 | shelf
51 | suitcase
52 | bear
53 | bird
54 | flower and grass
55 | pot
56 | traffic light
57 | refrigerator
58 | bowl
59 | fence
60 | skateboard
61 | mural
62 | mirror
63 | socket
64 | toy
65 | cow
66 | basket
67 | frame
68 | cake
69 | surfboard
70 | pillow
71 | clothes
72 | bench
73 | scissors
74 | orange
75 | handbag
76 | bread
77 | plate
78 | giraffe
79 | sheep
80 | shower
81 | camera
82 | boat
83 | pen
84 | dining table
85 | toothbrush
86 | billboard
87 | sculpture
88 | signboard
89 | umbrella
90 | kettle
91 | microwave
92 | truck
93 | train
94 | keyboard
95 | apple
96 | glass
97 | cabinet
98 | drink
99 | tennis racket
100 | hair drier
101 | speaker
102 | carrot
103 | bathtub
104 | fan
105 | rack
106 | faucet
107 | airplane
108 | picture
109 | snowboard
110 | sanitizer
111 | switch
112 | glasses
113 | ketchup
114 | scale
115 | bed
116 | zebra
117 | couch
118 | seat
119 | pizza
120 | toilet
121 | wall
122 | fireplace
123 | display
124 | pole
125 | screen
126 | dessert
127 | others_sub
128 | others_obj
129 |
--------------------------------------------------------------------------------
/Examples/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"}
2 | {"image": "000000006568.jpg", "question": "Where is the cat located relative to the car in the image?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "on/above"}
3 | {"image": "000000057139.jpg", "question": "For the white letters on the red warning sign, where is the letter P located relative to the letter Y?", "options": ["on/above", "below", "left of", "right of"], "answer": "on/above"}
4 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"}
5 | {"image": "000000121362.jpg", "question": "If you are the player in the image, where is the audience located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"}
6 | {"image": "000000142379.jpg", "question": "If you are the giraffe in the image, where is the tree located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"}
7 | {"image": "000000015740.jpg", "question": "If you are the woman in the image, from your perspective, where is the mouse located relative to the keyboard?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "left of"}
8 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"}
9 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_blip.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.0.0
2 | accelerate==0.25.0
3 | aiohttp==3.9.1
4 | aiosignal==1.3.1
5 | async-timeout==4.0.3
6 | attrs==23.2.0
7 | bitsandbytes==0.41.3.post2
8 | certifi==2023.11.17
9 | charset-normalizer==3.3.2
10 | click==8.1.7
11 | datasets==2.14.6
12 | dill==0.3.7
13 | filelock==3.13.1
14 | frozenlist==1.4.1
15 | fsspec==2023.10.0
16 | huggingface-hub==0.20.3
17 | idna==3.6
18 | Jinja2==3.1.2
19 | joblib==1.3.2
20 | MarkupSafe==2.1.3
21 | mpmath==1.3.0
22 | multidict==6.0.4
23 | multiprocess==0.70.15
24 | munch==4.0.0
25 | networkx==3.2.1
26 | nltk==3.8.1
27 | numpy==1.26.3
28 | nvidia-cublas-cu12==12.1.3.1
29 | nvidia-cuda-cupti-cu12==12.1.105
30 | nvidia-cuda-nvrtc-cu12==12.1.105
31 | nvidia-cuda-runtime-cu12==12.1.105
32 | nvidia-cudnn-cu12==8.9.2.26
33 | nvidia-cufft-cu12==11.0.2.54
34 | nvidia-curand-cu12==10.3.2.106
35 | nvidia-cusolver-cu12==11.4.5.107
36 | nvidia-cusparse-cu12==12.1.0.106
37 | nvidia-nccl-cu12==2.18.1
38 | nvidia-nvjitlink-cu12==12.3.101
39 | nvidia-nvtx-cu12==12.1.105
40 | packaging==23.2
41 | pandas==2.1.4
42 | peft==0.10.0
43 | Pillow==10.0.1
44 | pip==22.0.4
45 | protobuf==4.25.2
46 | psutil==5.9.7
47 | pyarrow==14.0.2
48 | python-dateutil==2.8.2
49 | pytz==2023.3.post1
50 | PyYAML==6.0.1
51 | regex==2023.12.25
52 | requests==2.31.0
53 | rouge-score==0.1.2
54 | ruamel.yaml==0.18.6
55 | ruamel.yaml.clib==0.2.8
56 | safetensors==0.4.1
57 | scipy==1.11.4
58 | sconf==0.2.5
59 | sentencepiece==0.1.99
60 | setuptools==58.1.0
61 | six==1.16.0
62 | sympy==1.12
63 | tokenizers==0.15.0
64 | torch==2.1.2
65 | torchsummary==1.5.1
66 | torchvision==0.16.0
67 | tqdm==4.66.1
68 | transformers==4.35.2
69 | triton==2.1.0
70 | typing_extensions==4.9.0
71 | tzdata==2023.4
72 | urllib3==2.1.0
73 | wheel==0.37.1
74 | xformers==0.0.23.post1
75 | xxhash==3.4.1
76 | yarl==1.9.4
77 |
--------------------------------------------------------------------------------
/Code/finetune/mPLUG_Owl_train_it.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DIR=`pwd`
3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
4 |
5 | if [ $MASTER_ADDR ];then
6 | echo $MASTER_ADDR
7 | echo $MASTER_PORT
8 | echo $WORLD_SIZE
9 | echo $RANK
10 | else
11 | MASTER_ADDR=127.0.0.1
12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15
13 | WORLD_SIZE=1
14 | RANK=0
15 | fi
16 |
17 | DISTRIBUTED_ARGS="--nproc_per_node 2 \
18 | --nnodes ${WORLD_SIZE} \
19 | --node_rank ${RANK} \
20 | --master_addr ${MASTER_ADDR} \
21 | --master_port ${MASTER_PORT}"
22 |
23 | EXP_NAME=sft_v0.1
24 | SAVE_NAME=mplug_lora_20240530
25 |
26 | SAVE_PATH="/projects/SpatialMQA/finetune_models/models_arg/${SAVE_NAME}/"
27 |
28 | max_length=2048
29 | micro_batch_size=4
30 | global_batch_size=8
31 | gradient_accumulation_steps=5
32 |
33 | # train_iters = total_data * train_epochs // global_batch_size
34 | # 3780 * 10 / 8 = 4725
35 | train_epochs=10
36 | train_iters=4725
37 |
38 | lr_warmup_iters=50
39 |
40 | eval_iter=50
41 | eval_interval=50
42 | save_interval=500
43 |
44 | mkdir -p ${SAVE_PATH}
45 |
46 | options=" \
47 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \
48 | --seq-length ${max_length} \
49 | --micro-batch-size ${micro_batch_size} \
50 | --num-training-steps ${train_iters} \
51 | --train-epochs ${train_epochs} \
52 | --num-warmup-steps ${lr_warmup_iters} \
53 | --gradient-accumulation-steps ${gradient_accumulation_steps} \
54 | --lr 5e-5 \
55 | --min-lr 1e-6 \
56 | --eval-iters ${eval_iter} \
57 | --save-interval ${save_interval} \
58 | --save-path ${SAVE_PATH} \
59 | --clip-grad 1.0 \
60 | --weight-decay 0.0001 \
61 | --adam-beta1 0.9 \
62 | --adam-beta2 0.999 \
63 | --num-workers 32 \
64 | --use-lora \
65 | --gradient-checkpointing \
66 | --bf16"
67 |
68 | multimodal_options=" \
69 | --mm-config configs/v0.yaml
70 | "
71 |
72 | HF_ENDPOINT=https://hf-mirror.com CUDA_VISIBLE_DEVICES=5,6 python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee /projects/SpatialMQA/finetune_models/finetune_code/log/20240530_mplug_train.log
--------------------------------------------------------------------------------
/Dataset/README.md:
--------------------------------------------------------------------------------
1 | # Data of SpatialMQA
2 |
3 | ### Download images
4 | We use a subset of the COCO-2017 images. The following script downloads the COCO-2017 test-set images and puts them into a single folder, `Dataset/COCO2017/`.
5 |
6 | ```bash
7 | cd Dataset/
8 | wget http://images.cocodataset.org/zips/test2017.zip
9 | unzip test2017.zip
10 | mv test2017 COCO2017 && rm test2017.zip
11 | ```
12 | Copy only relevant images to `relevant_images/`.
13 | ```bash
14 | mkdir relevant_images
15 | cd tool
16 | python select_relevant_images.py
17 | ```
18 |
19 | Alternatively, you can also view individual images online directly via the "image" key of each JSON line.
20 | (Use COCO's public link: 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.)
21 |
22 | ### Splits
23 | As reported in the following table, SpatialMQA contains 5,392 samples, divided into training, validation, and test sets according to a 7:1:2 ratio.
24 | All the split data sets are in the directory [`dataset/`](https://github.com/ziyan-xiaoyu/SpatialMQA/blob/Dataset).
25 |
26 | ### Format of the data
27 | Each `jsonl` file is of the following format:
28 | ```json
29 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"}
30 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"}
31 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"}
32 | {"..."}
33 | ```
34 | Each line is an individual data point.
35 | `image` is the name of the image in COCO. `question` is the manually annotated question, `options` is a reasonable combination of the six spatial relations (on/above, below, in front of, behind, left of, right of), and `answer` is the annotation grounded in the objective world.
36 |
37 |
--------------------------------------------------------------------------------
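As a small usage illustration of the format described in this README, the sketch below iterates over one split, builds an option string in the same style as the scripts in `Code/experiment/`, and tallies the gold answers; the split file name `train.jsonl` is a placeholder.

```python
import json
from collections import Counter

answer_counts = Counter()
with open("train.jsonl", "r", encoding="utf-8") as f:  # placeholder split file
    for line in f:
        sample = json.loads(line)
        # `image` is the COCO test2017 file name to pair with the prompt below.
        prompt = f'{sample["question"]} Options: {"; ".join(sample["options"])}.'
        assert sample["answer"] in sample["options"]
        answer_counts[sample["answer"]] += 1

# Distribution of gold answers over the six spatial relations.
print(answer_counts)
```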
/Code/experiment/blip-vqa-base.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from PIL import Image
3 | from transformers import BlipProcessor, BlipForQuestionAnswering
4 | import torch
5 | import json
6 |
7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/'
8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
9 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
10 | processor = BlipProcessor.from_pretrained(f"Salesforce/blip-vqa-base")
11 | model = BlipForQuestionAnswering.from_pretrained(f"Salesforce/blip-vqa-base").to(DEVICE_INDEX)
12 |
13 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
14 | count = 0
15 | right_count = 0
16 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip.jsonl', 'w+', encoding="utf-8") as fout:
17 | for line in f:
18 | data = json.loads(line)
19 | question = data['question']
20 | id = data['id']
21 | options = data['options']
22 | image_name = data['image']
23 | image_filepath = image_dir + image_name
24 | image = Image.open(image_filepath).convert('RGB')
25 | question = f'{question} {",".join(options[:-1])} or {options[-1]}'
26 | print(question)
27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX)
28 | predictions = model.generate(**inputs)
29 | output = (processor.decode(predictions[0], skip_special_tokens=True))
30 | count += 1
31 | if len(output) == 0:
32 | output = '--'
33 | if output.lower() in data['answer']:
34 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
35 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
36 | right_count += 1
37 | elif data['answer'] in output.lower():
38 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
39 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
40 | right_count += 1
41 | else:
42 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
43 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
44 | print(f'{output.lower()}')
45 | print(f"{data['answer']}")
46 | print(f'right_count: {right_count}')
47 | print(f'count: {count}')
48 | print(f'accuracy: {right_count/count}')
49 |
50 | accuracy = right_count/count
51 | print(f'accuracy: {accuracy}')
52 |
53 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_mplug.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.28.0
3 | aiofiles==23.2.1
4 | altair==5.3.0
5 | annotated-types==0.6.0
6 | anyio==4.3.0
7 | asttokens==2.4.1
8 | attrs==23.2.0
9 | bitsandbytes==0.43.0
10 | blinker==1.7.0
11 | Brotli==1.0.9
12 | cchardet==2.1.7
13 | certifi==2024.2.2
14 | chardet==5.2.0
15 | charset-normalizer==2.0.4
16 | click==8.1.7
17 | colorama==0.4.6
18 | contourpy==1.2.0
19 | cycler==0.12.1
20 | decord==0.6.0
21 | einops==0.7.0
22 | exceptiongroup==1.2.0
23 | executing==2.0.1
24 | fastapi==0.110.0
25 | ffmpy==0.3.2
26 | filelock==3.13.3
27 | Flask==3.0.2
28 | fonttools==4.50.0
29 | fsspec==2024.3.1
30 | gradio==4.24.0
31 | gradio_client==0.14.0
32 | grpcio==1.62.1
33 | h11==0.14.0
34 | h5py==3.10.0
35 | httpcore==1.0.5
36 | httpx==0.27.0
37 | huggingface-hub==0.22.2
38 | icecream==2.1.3
39 | idna==3.4
40 | importlib_resources==6.4.0
41 | itsdangerous==2.1.2
42 | Jinja2==3.1.3
43 | jsonschema==4.21.1
44 | jsonschema-specifications==2023.12.1
45 | kiwisolver==1.4.5
46 | Markdown==3.6
47 | markdown-it-py==3.0.0
48 | markdown2==2.4.13
49 | MarkupSafe==2.1.5
50 | matplotlib==3.8.3
51 | mdurl==0.1.2
52 | mkl-fft==1.3.8
53 | mkl-random==1.2.4
54 | mkl-service==2.4.0
55 | munch==4.0.0
56 | numpy==1.26.4
57 | opencv-python==4.9.0.80
58 | opencv-python-headless==4.9.0.80
59 | orjson==3.10.0
60 | packaging==24.0
61 | pandas==2.2.1
62 | peft==0.10.0
63 | pillow==10.2.0
64 | pip==23.3.1
65 | protobuf==5.26.1
66 | psutil==5.9.8
67 | pydantic==2.6.4
68 | pydantic_core==2.16.3
69 | pydub==0.25.1
70 | Pygments==2.17.2
71 | pyparsing==3.1.2
72 | PySocks==1.7.1
73 | python-dateutil==2.9.0.post0
74 | python-multipart==0.0.9
75 | pytz==2024.1
76 | PyYAML==6.0.1
77 | referencing==0.34.0
78 | regex==2023.12.25
79 | requests==2.31.0
80 | rich==13.7.1
81 | rpds-py==0.18.0
82 | ruamel.yaml==0.18.6
83 | ruamel.yaml.clib==0.2.8
84 | ruff==0.3.4
85 | safetensors==0.4.2
86 | sconf==0.2.5
87 | semantic-version==2.10.0
88 | sentencepiece==0.2.0
89 | setuptools==68.2.2
90 | shellingham==1.5.4
91 | six==1.16.0
92 | sniffio==1.3.1
93 | starlette==0.36.3
94 | tensorboard==2.16.2
95 | tensorboard-data-server==0.7.2
96 | tensorboardX==2.6.2.2
97 | tokenizers==0.13.3
98 | tomlkit==0.12.0
99 | toolz==0.12.1
100 | torch==1.13.1
101 | torchaudio==0.13.1
102 | torchvision==0.14.1
103 | tqdm==4.66.2
104 | transformers==4.28.1
105 | typer==0.12.0
106 | typer-cli==0.12.0
107 | typer-slim==0.12.0
108 | typing_extensions==4.9.0
109 | tzdata==2024.1
110 | urllib3==2.1.0
111 | uvicorn==0.29.0
112 | websockets==11.0.3
113 | Werkzeug==3.0.1
114 | wheel==0.41.2
115 |
--------------------------------------------------------------------------------
/Code/experiment/blip-vqa-base_finetuned.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from PIL import Image
3 | from transformers import BlipProcessor, BlipForQuestionAnswering
4 | import torch
5 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/blip/'
6 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
7 | DEVICE_INDEX = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
8 | processor = BlipProcessor.from_pretrained(f"Salesforce/blip-vqa-base")
9 |
10 | for i in range(8):
11 | model = BlipForQuestionAnswering.from_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{i}").to(DEVICE_INDEX)
12 |
13 |     import json
14 |     image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
15 |     count = 0
16 |     right_count = 0
17 |     with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip_finetuned_{i}.jsonl', 'w', encoding="utf-8") as fout:
18 |         for line in f:
19 |             data = json.loads(line)
20 |             question = data['question']
21 |             id = data['id']
22 |             options = data['options']
23 |             image_name = data['image']
24 |             image_filepath = image_dir + image_name
25 |             image = Image.open(image_filepath).convert('RGB')
26 |             question = f'{question} {",".join(options[:-1])} or {options[-1]}'
27 |             inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX)
28 |             predictions = model.generate(**inputs)
29 |             output = (processor.decode(predictions[0], skip_special_tokens=True))
30 |             count += 1
31 |             if len(output) == 0:
32 |                 output = '--'
33 |             if output == 'on / above':
34 |                 output = 'on/above'
35 |             if output.lower() in data['answer']:
36 |                 result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
37 |                 fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
38 |                 right_count += 1
39 |             elif data['answer'] in output.lower():
40 |                 result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
41 |                 fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
42 |                 right_count += 1
43 |             else:
44 |                 result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
45 |                 fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
46 |             print(f'{output.lower()}')
47 |             print(f"{data['answer']}")
48 |             print(f'right_count: {right_count}')
49 |             print(f'count: {count}')
50 |             print(f'accuracy: {right_count/count}')
51 |
52 |     accuracy = right_count/count
53 |     print(f'accuracy: {accuracy}')
54 |
55 |
--------------------------------------------------------------------------------
/Code/experiment/blip2-opt-2.7b.py:
--------------------------------------------------------------------------------
1 |
2 | import requests
3 | from PIL import Image
4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration
5 | import torch
6 |
7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/'
8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
9 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
10 | model_dir = '/projects/Models/'
11 | processor = AutoProcessor.from_pretrained(f"{model_dir}blip2-opt-2.7b")
12 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b").to(DEVICE_INDEX)
13 | import json
14 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
15 | count = 0
16 | right_count = 0
17 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip2_20240531.jsonl', 'w+', encoding="utf-8") as fout:
18 | for line in f:
19 | data = json.loads(line)
20 | question = data['question']
21 | id = data['id']
22 | options = data['options']
23 | image_name = data['image']
24 | image_filepath = image_dir + image_name
25 | image = Image.open(image_filepath)
26 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX)
28 | predictions = model.generate(**inputs)
29 | output = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip().rstrip('.')
30 | count += 1
31 | if len(output) == 0:
32 | output = '--'
33 | if output.lower() in data['answer']:
34 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
35 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
36 | right_count += 1
37 | elif data['answer'] in output.lower():
38 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
39 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
40 | right_count += 1
41 | else:
42 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
43 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
44 | print(f'{output.lower()}')
45 | print(f"{data['answer']}")
46 | print(f'right_count: {right_count}')
47 | print(f'count: {count}')
48 | print(f'accuracy: {right_count/count}')
49 |
50 | accuracy = right_count/count
51 | print(f'accuracy: {accuracy}')
52 |
--------------------------------------------------------------------------------
/Code/requirement/requirement_llava.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | aiofiles==23.2.1
3 | altair==5.2.0
4 | annotated-types==0.6.0
5 | anyio==4.3.0
6 | attrs==23.2.0
7 | bitsandbytes==0.43.0
8 | certifi==2024.2.2
9 | charset-normalizer==3.3.2
10 | click==8.1.7
11 | colorama==0.4.6
12 | contourpy==1.2.0
13 | cycler==0.12.1
14 | deepspeed==0.14.2
15 | docker-pycreds==0.4.0
16 | einops==0.6.1
17 | einops-exts==0.0.4
18 | exceptiongroup==1.2.0
19 | fastapi==0.110.0
20 | ffmpy==0.3.2
21 | filelock==3.13.1
22 | flash-attn==2.5.8
23 | fonttools==4.49.0
24 | fsspec==2024.2.0
25 | gitdb==4.0.11
26 | GitPython==3.1.43
27 | gradio==4.16.0
28 | gradio_client==0.8.1
29 | h11==0.14.0
30 | hjson==3.1.0
31 | httpcore==0.17.3
32 | httpx==0.24.0
33 | huggingface-hub==0.21.4
34 | idna==3.6
35 | importlib_resources==6.3.0
36 | Jinja2==3.1.3
37 | joblib==1.3.2
38 | jsonschema==4.21.1
39 | jsonschema-specifications==2023.12.1
40 | kiwisolver==1.4.5
41 | llava==1.2.2.post1
42 | markdown-it-py==3.0.0
43 | markdown2==2.4.13
44 | MarkupSafe==2.1.5
45 | matplotlib==3.8.3
46 | mdurl==0.1.2
47 | mpmath==1.3.0
48 | networkx==3.2.1
49 | ninja==1.11.1.1
50 | numpy==1.26.4
51 | nvidia-cublas-cu12==12.1.3.1
52 | nvidia-cuda-cupti-cu12==12.1.105
53 | nvidia-cuda-nvrtc-cu12==12.1.105
54 | nvidia-cuda-runtime-cu12==12.1.105
55 | nvidia-cudnn-cu12==8.9.2.26
56 | nvidia-cufft-cu12==11.0.2.54
57 | nvidia-curand-cu12==10.3.2.106
58 | nvidia-cusolver-cu12==11.4.5.107
59 | nvidia-cusparse-cu12==12.1.0.106
60 | nvidia-nccl-cu12==2.18.1
61 | nvidia-nvjitlink-cu12==12.4.99
62 | nvidia-nvtx-cu12==12.1.105
63 | orjson==3.9.15
64 | packaging==24.0
65 | pandas==2.2.1
66 | peft==0.9.0
67 | pillow==10.2.0
68 | pip==24.0
69 | platformdirs==4.2.2
70 | protobuf==4.25.3
71 | psutil==5.9.8
72 | py-cpuinfo==9.0.0
73 | pydantic==2.6.4
74 | pydantic_core==2.16.3
75 | pydub==0.25.1
76 | Pygments==2.17.2
77 | pynvml==11.5.0
78 | pyparsing==3.1.2
79 | python-dateutil==2.9.0.post0
80 | python-multipart==0.0.9
81 | pytz==2024.1
82 | PyYAML==6.0.1
83 | referencing==0.33.0
84 | regex==2023.12.25
85 | requests==2.31.0
86 | rich==13.7.1
87 | rpds-py==0.18.0
88 | ruff==0.3.2
89 | safetensors==0.4.2
90 | scikit-learn==1.2.2
91 | scipy==1.12.0
92 | semantic-version==2.10.0
93 | sentencepiece==0.1.99
94 | sentry-sdk==2.2.1
95 | setproctitle==1.3.3
96 | setuptools==68.2.2
97 | shellingham==1.5.4
98 | shortuuid==1.0.13
99 | six==1.16.0
100 | smmap==5.0.1
101 | sniffio==1.3.1
102 | starlette==0.36.3
103 | svgwrite==1.4.3
104 | sympy==1.12
105 | threadpoolctl==3.3.0
106 | timm==0.6.13
107 | tokenizers==0.15.1
108 | tomlkit==0.12.0
109 | toolz==0.12.1
110 | torch==2.1.2
111 | torchvision==0.16.2
112 | tqdm==4.66.2
113 | transformers==4.37.2
114 | triton==2.1.0
115 | typer==0.9.0
116 | typing_extensions==4.10.0
117 | tzdata==2024.1
118 | urllib3==2.2.1
119 | uvicorn==0.28.0
120 | wavedrom==2.0.3.post3
121 | websockets==11.0.3
122 | wheel==0.41.2
123 |
--------------------------------------------------------------------------------
/Code/eval/calculate_result_rule.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | def load_jsonl(file_path):
6 | data = []
7 | with open(file_path, 'r', encoding='utf-8') as f:
8 | for line in f:
9 | data.append(json.loads(line.strip()))
10 | return data
11 |
12 |
13 | def main():
14 | rule_1_file = 'rule_1_test_452.jsonl'
15 | rule_2_file = 'rule_2_test_590.jsonl'
16 | rule_3_file = 'rule_3_test_34.jsonl'
17 |
18 | rule_1_datas = load_jsonl(rule_1_file)
19 | rule_2_datas = load_jsonl(rule_2_file)
20 | rule_3_datas = load_jsonl(rule_3_file)
21 |
22 | test_file_list = os.listdir('model_exp_results/move_to_localhost/')
23 | # test_file_list = os.listdir('model_exp_results/m10/')
24 | for test_file in test_file_list:
25 | correct_count_1 = 0
26 | all_count_1 = 0
27 | correct_count_2 = 0
28 | all_count_2 = 0
29 | correct_count_3 = 0
30 | all_count_3 = 0
31 |
32 | print(f'-------------------------------{test_file}----------------------------------')
33 | test_file_path = f'model_exp_results/move_to_localhost/{test_file}'
34 | # test_file_path = f'model_exp_results/m10/{test_file}'
35 | predictions = load_jsonl(test_file_path)
36 | for predict in predictions:
37 | flag = 0
38 | for rule_1_data in rule_1_datas:
39 | if rule_1_data['id'] == predict['id']:
40 | if predict['result'] == 1:
41 | correct_count_1 += 1
42 | all_count_1 += 1
43 | flag = 1
44 | break
45 |
46 | if flag == 0:
47 | for rule_2_data in rule_2_datas:
48 | if rule_2_data['id'] == predict['id']:
49 | if predict['result'] == 1:
50 | correct_count_2 += 1
51 | all_count_2 += 1
52 | flag = 1
53 | break
54 |
55 | if flag == 0:
56 | for rule_3_data in rule_3_datas:
57 | if rule_3_data['id'] == predict['id']:
58 | if predict['result'] == 1:
59 | correct_count_3 += 1
60 | all_count_3 += 1
61 | flag = 1
62 | break
63 |
64 | if flag == 0:
65 |                 print("Problem with the data: prediction id not found in any rule file")
66 |
67 | print(f"Accuracy for 'rule1': {(correct_count_1 / all_count_1):.4f}")
68 | print(f"Accuracy for 'rule2': {(correct_count_2 / all_count_2):.4f}")
69 | print(f"Accuracy for 'rule3': {(correct_count_3 / all_count_3):.4f}")
70 | print(f"Accuracy for 'overall': {((correct_count_1+correct_count_2+correct_count_3) / len(predictions)):.4f}")
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/Code/experiment/blip2-lora.py:
--------------------------------------------------------------------------------
1 |
2 | import requests
3 | from PIL import Image
4 | from transformers import AutoProcessor, Blip2ForConditionalGeneration
5 | from peft import PeftModel, PeftConfig
6 | import torch
7 |
8 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/'
9 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
10 | DEVICE_INDEX = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
11 | model_dir = '/projects/Models/'
12 |
13 | processor = AutoProcessor.from_pretrained(f"{model_dir}blip2-opt-2.7b")
14 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b", device_map="cuda:7", load_in_8bit=True)
15 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531/epoch-9'
16 |
17 | model = PeftModel.from_pretrained(model, lora_point).to(DEVICE_INDEX)
18 | import json
19 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
20 | count = 0
21 | right_count = 0
22 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}blip2_lora_20240531_ep9.jsonl', 'w+', encoding="utf-8") as fout:
23 | for line in f:
24 | data = json.loads(line)
25 | question = data['question']
26 | id = data['id']
27 | options = data['options']
28 | image_name = data['image']
29 | image_filepath = image_dir + image_name
30 | image = Image.open(image_filepath)
31 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \nOutput:'
32 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX)
33 | predictions = model.generate(**inputs)
34 | print("predictions: ", predictions)
35 | output = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip().rstrip('.')
36 | count += 1
37 | if len(output) == 0:
38 | output = '--'
39 | if output.lower() in data['answer']:
40 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
41 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
42 | right_count += 1
43 | elif data['answer'] in output.lower():
44 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
45 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
46 | right_count += 1
47 | else:
48 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
49 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
50 | print(f'{output.lower()}')
51 | print(f"{data['answer']}")
52 | print(f'right_count: {right_count}')
53 | print(f'count: {count}')
54 | print(f'accuracy: {right_count/count}')
55 |
56 | accuracy = right_count/count
57 | print(f'accuracy: {accuracy}')
58 |
--------------------------------------------------------------------------------
/Code/experiment/instructblip-flan-t5-xl.py:
--------------------------------------------------------------------------------
1 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
2 | import torch
3 | from PIL import Image
4 | import requests
5 | import json
6 |
7 |
8 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/'
9 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
10 | DEVICE_INDEX = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
11 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl").to(DEVICE_INDEX)
12 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
13 |
14 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
15 | count = 0
16 | right_count = 0
17 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}instrutblip_20240601.jsonl', 'w+', encoding="utf-8") as fout:
18 | for line in f:
19 | data = json.loads(line)
20 | question = data['question']
21 | id = data['id']
22 | options = data['options']
23 | image_name = data['image']
24 | image_filepath = image_dir + image_name
25 | image = Image.open(image_filepath)
26 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
27 | inputs = processor(images=image, text=question, return_tensors="pt").to(DEVICE_INDEX)
28 | outputs = model.generate(
29 | **inputs,
30 | do_sample=False,
31 | num_beams=5,
32 | max_length=256,
33 | min_length=1,
34 | top_p=0.9,
35 | repetition_penalty=1.5,
36 | length_penalty=1.0,
37 | temperature=1,
38 | )
39 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip().rstrip('.')
40 | print(generated_text)
41 | output = generated_text
42 | count += 1
43 | if len(output) == 0:
44 | output = '--'
45 | if output.lower() in data['answer']:
46 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
47 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
48 | right_count += 1
49 | elif data['answer'] in output.lower():
50 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
51 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
52 | right_count += 1
53 | else:
54 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
55 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
56 | print(f'{output.lower()}')
57 | print(f"{data['answer']}")
58 | print(f'right_count: {right_count}')
59 | print(f'count: {count}')
60 | print(f'accuracy: {right_count/count}')
61 |
62 | accuracy = right_count/count
63 | print(f'accuracy: {accuracy}')
64 |
--------------------------------------------------------------------------------
/Code/experiment/instructblip-lora.py:
--------------------------------------------------------------------------------
1 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
2 | import torch
3 | from peft import PeftModel, PeftConfig
4 | from PIL import Image
5 | import requests
6 | import re
7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/'
8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
9 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
10 | DEVICE_INDEX = 6
11 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl").to(DEVICE_INDEX)
12 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
13 |
14 | device = "cuda:6" if torch.cuda.is_available() else "cpu"
15 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530'
16 | model = PeftModel.from_pretrained(model, lora_point).to(device)
17 |
18 |
19 | import json
20 |
21 | count = 0
22 | right_count = 0
23 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}instrutblip_lora_20240601.jsonl', 'w+', encoding="utf-8") as fout:
24 | for line in f:
25 | data = json.loads(line)
26 | question = data['question']
27 | id = data['id']
28 | options = data['options']
29 | image_name = data['image']
30 | image_filepath = image_dir + image_name
31 | image = Image.open(image_filepath)
32 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
33 | inputs = processor(images=image, text=question, return_tensors="pt").to(device)
34 | outputs = model.generate(
35 | **inputs,
36 | do_sample=False,
37 | num_beams=5,
38 | max_length=8,
39 | min_length=1,
40 | top_p=0.9,
41 | repetition_penalty=1.5,
42 | length_penalty=1.0,
43 | temperature=1,
44 | )
45 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
46 | print(generated_text)
47 | output = generated_text
48 |
49 | count += 1
50 | if len(output) == 0:
51 | output = '--'
52 | if output.lower() in data['answer']:
53 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
54 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
55 | right_count += 1
56 | elif data['answer'] in output.lower():
57 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
58 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
59 | right_count += 1
60 | else:
61 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
62 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
63 | print(f'{output.lower()}')
64 | print(f"{data['answer']}")
65 | print(f'right_count: {right_count}')
66 | print(f'count: {count}')
67 | print(f'accuracy: {right_count/count}')
68 |
69 | accuracy = right_count/count
70 | print(f'accuracy: {accuracy}')
71 |
--------------------------------------------------------------------------------
/Code/experiment/spatial_test_mplug.py:
--------------------------------------------------------------------------------
1 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
2 | from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
3 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
4 | import torch
5 | from PIL import Image
6 | import json
7 |
8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
9 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/prompt test/'
10 | pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'
11 | device="cuda:7"
12 | model = MplugOwlForConditionalGeneration.from_pretrained(
13 | pretrained_ckpt,
14 | torch_dtype=torch.bfloat16,
15 | ).to(device)
16 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
17 | tokenizer = MplugOwlTokenizer.from_pretrained('MAGAer13/mplug-owl-llama-7b')
18 | processor = MplugOwlProcessor(image_processor, tokenizer)
19 |
20 | generate_kwargs = {
21 | 'do_sample': True,
22 | 'top_k': 1,
23 | 'max_length': 512,
24 | 'temperature':0.1
25 | }
26 |
27 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
28 |
29 | count = 0
30 | right_count = 0
31 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}mplug_1.jsonl', 'w+', encoding="utf-8") as fout:
32 | for line in f:
33 | data = json.loads(line)
34 | question = data['question']
35 | id = data['id']
36 | options = data['options']
37 | image_name = data['image']
38 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
39 | if not count:
40 | print(f'question:{question}')
41 | image_filepath = image_dir + image_name
42 | images = [Image.open(image_filepath)]
43 | prompts = [
44 | f'''The following is a conversation between a curious human and AI assistant.
45 | Human: <image>
46 | Human: {question}
47 | AI: '''
48 | ]
49 | inputs = processor(text=prompts, images=images, return_tensors='pt')
50 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
51 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
52 | with torch.no_grad():
53 | res = model.generate(**inputs, **generate_kwargs)
54 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
55 | print(sentence)
56 | count += 1
57 | output = sentence.strip().rstrip('.')
58 | if len(output) == 0:
59 | output = '--'
60 | if output.lower() in data['answer']:
61 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
62 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
63 | right_count += 1
64 | elif data['answer'] in output.lower():
65 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
66 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
67 | right_count += 1
68 | else:
69 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
70 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
71 | print(f'{output.lower()}')
72 | print(f"{data['answer']}")
73 | print(f'right_count: {right_count}')
74 | print(f'count: {count}')
75 | print(f'accuracy: {right_count/count}')
76 |
77 | accuracy = right_count/count
78 | print(f'accuracy: {accuracy}')
79 |
--------------------------------------------------------------------------------
/Metadata/metadata_hf.json:
--------------------------------------------------------------------------------
1 | {"@context":{"@language":"en","@vocab":"https://schema.org/","arrayShape":"cr:arrayShape","citeAs":"cr:citeAs","column":"cr:column","conformsTo":"dct:conformsTo","cr":"http://mlcommons.org/croissant/","data":{"@id":"cr:data","@type":"@json"},"dataBiases":"cr:dataBiases","dataCollection":"cr:dataCollection","dataType":{"@id":"cr:dataType","@type":"@vocab"},"dct":"http://purl.org/dc/terms/","extract":"cr:extract","field":"cr:field","fileProperty":"cr:fileProperty","fileObject":"cr:fileObject","fileSet":"cr:fileSet","format":"cr:format","includes":"cr:includes","isArray":"cr:isArray","isLiveDataset":"cr:isLiveDataset","jsonPath":"cr:jsonPath","key":"cr:key","md5":"cr:md5","parentField":"cr:parentField","path":"cr:path","personalSensitiveInformation":"cr:personalSensitiveInformation","recordSet":"cr:recordSet","references":"cr:references","regex":"cr:regex","repeated":"cr:repeated","replace":"cr:replace","sc":"https://schema.org/","separator":"cr:separator","source":"cr:source","subField":"cr:subField","transform":"cr:transform"},"@type":"sc:Dataset","distribution":[{"@type":"cr:FileObject","@id":"repo","name":"repo","description":"The Hugging Face git repository.","contentUrl":"https://huggingface.co/datasets/liuziyan/SpatialMQA/tree/refs%2Fconvert%2Fparquet","encodingFormat":"git+https","sha256":"https://github.com/mlcommons/croissant/issues/80"},{"@type":"cr:FileSet","@id":"parquet-files-for-config-default","containedIn":{"@id":"repo"},"encodingFormat":"application/x-parquet","includes":"default/*/*.parquet"}],"recordSet":[{"@type":"cr:RecordSet","dataType":"cr:Split","key":{"@id":"default_splits/split_name"},"@id":"default_splits","name":"default_splits","description":"Splits for the default config.","field":[{"@type":"cr:Field","@id":"default_splits/split_name","dataType":"sc:Text"}],"data":[{"default_splits/split_name":"train"},{"default_splits/split_name":"validation"},{"default_splits/split_name":"test"}]},{"@type":"cr:RecordSet","@id":"default","description":"liuziyan/SpatialMQA - 'default' subset\n\nAdditional information:\n- 3 splits: train, validation, test","field":[{"@type":"cr:Field","@id":"default/split","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"fileProperty":"fullpath"},"transform":{"regex":"default/(?:partial-)?(train|validation|test)/.+parquet$"}},"references":{"field":{"@id":"default_splits/split_name"}}},{"@type":"cr:Field","@id":"default/image","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"image"}}},{"@type":"cr:Field","@id":"default/question","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"question"}}},{"@type":"cr:Field","@id":"default/options","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"options"}},"isArray":true,"arrayShape":"-1"},{"@type":"cr:Field","@id":"default/answer","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":"answer"}}},{"@type":"cr:Field","@id":"default/_answer","dataType":"sc:Text","source":{"fileSet":{"@id":"parquet-files-for-config-default"},"extract":{"column":",answer"}}}]}],"conformsTo":"http://mlcommons.org/croissant/1.1","name":"SpatialMQA","description":"Welcome to explore our work titled \"Can Multimodal Large Language Models Understand Spatial Relations\".arXiv link: https://arxiv.org/abs/2505.19015.For more information about the paper and the SpatialMQA 
dataset, please visit our GitHub repository at https://github.com/ziyan-xiaoyu/SpatialMQA.\n\n\n\t\n\t\t\n\t\tlicense: cc-by-4.0\n\t\n\n","alternateName":["liuziyan/SpatialMQA"],"creator":{"@type":"Person","name":"刘子言","url":"https://huggingface.co/liuziyan"},"keywords":["1K - 10K","json","Image","Text","Datasets","pandas","Croissant","Polars","arxiv:2505.19015","🇺🇸 Region: US"],"url":"https://huggingface.co/datasets/liuziyan/SpatialMQA"}
2 |
--------------------------------------------------------------------------------
/Code/experiment/spatial_test_mplug_lora.py:
--------------------------------------------------------------------------------
1 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
2 | from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
3 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
4 | from peft import PeftModel, PeftConfig
5 | import torch
6 | from PIL import Image
7 | import json
8 |
9 |
10 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
11 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/m10/'
12 | pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'
13 | device="cuda:6"
14 | peft_model_id = '/projects/SpatialMQA/finetune_models/models_arg/mplug_lora_20240530'
15 | model = MplugOwlForConditionalGeneration.from_pretrained(
16 | pretrained_ckpt,
17 | torch_dtype=torch.bfloat16,
18 | ).to(device)
19 | model = PeftModel.from_pretrained(model, peft_model_id).to(device)
20 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
21 | tokenizer = MplugOwlTokenizer.from_pretrained('MAGAer13/mplug-owl-llama-7b')
22 | processor = MplugOwlProcessor(image_processor, tokenizer)
23 |
24 | generate_kwargs = {
25 | 'do_sample': True,
26 | 'top_k': 0,
27 | 'max_length': 512,
28 | 'temperature': 0.1
29 | }
30 |
31 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
32 | count = 0
33 | right_count = 0
34 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}mplug_lora_0.jsonl', 'w+', encoding="utf-8") as fout:
35 | for line in f:
36 | data = json.loads(line)
37 | question = data['question']
38 | id = data['id']
39 | options = data['options']
40 | image_name = data['image']
41 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
42 | if not count:
43 | print(f'question:{question}')
44 | image_filepath = image_dir + image_name
45 | images = [Image.open(image_filepath)]
46 | prompts = [
47 | f'''The following is a conversation between a curious human and AI assistant.
48 | Human: <image>
49 | Human: {question}
50 | AI: '''
51 | ]
52 | inputs = processor(text=prompts, images=images, return_tensors='pt')
53 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
54 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
55 | with torch.no_grad():
56 | res = model.generate(**inputs, **generate_kwargs)
57 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
58 | print(sentence)
59 | count += 1
60 | output = sentence.strip().rstrip('.')
61 | if len(output) == 0:
62 | output = '--'
63 | if output.lower() in data['answer']:
64 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
65 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
66 | right_count += 1
67 | elif data['answer'] in output.lower():
68 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
69 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
70 | right_count += 1
71 | else:
72 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
73 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
74 | print(f'{output.lower()}')
75 | print(f"{data['answer']}")
76 | print(f'right_count: {right_count}')
77 | print(f'count: {count}')
78 | print(f'accuracy: {right_count/count}')
79 |
80 | accuracy = right_count/count
81 | print(f'accuracy: {accuracy}')
82 |
--------------------------------------------------------------------------------
/Code/experiment/idefics_new.py:
--------------------------------------------------------------------------------
1 | # Inference demo for IDEFICS-9B; requires about 20GB of GPU memory.
2 | import json
3 | import re
4 | from PIL import Image
5 | import torch
6 | from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig
7 | from peft import PeftModel, PeftConfig
8 |
9 | device = "cuda:5"
10 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/idefics-10/'
11 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
12 |
13 | model_path = '/projects/Models/idefics/'
14 |
15 | bnb_config = BitsAndBytesConfig(
16 | load_in_4bit=True,
17 | bnb_4bit_use_double_quant=True,
18 | bnb_4bit_quant_type="nf4",
19 | bnb_4bit_compute_dtype=torch.float16,
20 | llm_int8_skip_modules=["lm_head", "embed_tokens"],
21 | )
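# 4-bit NF4 quantization with double quantization helps fit the 9B model on a single GPU;
# lm_head and embed_tokens are excluded from quantization so they stay in higher precision.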
22 |
23 | config = PeftConfig.from_pretrained(model_path)
24 | processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)
25 | model = IdeficsForVisionText2Text.from_pretrained(config.base_model_name_or_path,quantization_config=bnb_config,device_map="cuda:5")
26 |
27 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
28 | count = 0
29 | right_count = 0
30 |
31 | def check_inference(model, processor, prompts, max_new_tokens=50):
32 | tokenizer = processor.tokenizer
33 | bad_words = ["<image>", "<fake_token_around_image>"]
34 | if len(bad_words) > 0:
35 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
36 |
37 | eos_token = "</s>"
38 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
39 |
40 | inputs = processor(prompts, return_tensors="pt").to(device)
41 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True,
42 | )
43 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
44 | return generated_text
45 |
46 |
47 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}idefics_test.jsonl', 'w+', encoding="utf-8") as fout:
48 | for line in f:
49 | data = json.loads(line)
50 | question = data['question']
51 | id = data['id']
52 | options = data['options']
53 | image_name = data['image']
54 | image_filepath = image_dir + image_name
55 | image = Image.open(image_filepath)
56 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
57 | prompts = [
58 | image,
59 | question,
60 | ]
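# The IDEFICS processor accepts an interleaved list of images and text and builds the
# multimodal prompt from it.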
61 |
62 | generated_text = check_inference(model, processor, prompts, max_new_tokens=5)
63 | print(f'generated_text:\n{generated_text}')
64 | output = generated_text.lower().split('answer:')[-1].split('\n')[0].strip().rstrip('.')
65 | count += 1
66 | if len(output) == 0:
67 | output = '--'
68 | if output in data['answer']:
69 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
70 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
71 | right_count += 1
72 | elif data['answer'] in output:
73 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
74 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
75 | right_count += 1
76 | else:
77 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
78 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
79 | print(f'{output.lower()}')
80 | print(f"{data['answer']}")
81 | print(f'right_count: {right_count}')
82 | print(f'count: {count}')
83 | print(f'accuracy: {right_count/count}')
84 |
85 | accuracy = right_count/count
86 | print(f'accuracy: {accuracy}')
--------------------------------------------------------------------------------
/Code/experiment/idefics_lora.py:
--------------------------------------------------------------------------------
1 | # Inference demo for IDEFICS-9B; requires about 20GB of GPU memory.
2 |
3 | import torch
4 | from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig
5 |
6 | device = "cuda:7" if torch.cuda.is_available() else "cpu"
7 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/idefics-10/'
8 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
9 |
10 | lora_point = '/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521/checkpoint-35'
11 |
12 | from peft import PeftModel, PeftConfig
13 |
14 | bnb_config = BitsAndBytesConfig(
15 | load_in_4bit=True,
16 | bnb_4bit_use_double_quant=True,
17 | bnb_4bit_quant_type="nf4",
18 | bnb_4bit_compute_dtype=torch.float16,
19 | llm_int8_skip_modules=["lm_head", "embed_tokens"],
20 | )
21 |
22 | config = PeftConfig.from_pretrained(lora_point)
23 | processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)
24 | model = IdeficsForVisionText2Text.from_pretrained(config.base_model_name_or_path,quantization_config=bnb_config,device_map="cuda:7")
25 | model = PeftModel.from_pretrained(model, lora_point).to(device)
26 |
27 |
28 | import json
29 | import re
30 | from PIL import Image
31 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
32 | count = 0
33 | right_count = 0
34 |
35 | def check_inference(model, processor, prompts, max_new_tokens=50):
36 | tokenizer = processor.tokenizer
37 | bad_words = ["<image>", "<fake_token_around_image>"]
38 | if len(bad_words) > 0:
39 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
40 |
41 | eos_token = "</s>"
42 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
43 |
44 | inputs = processor(prompts, return_tensors="pt").to(device)
45 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True,
46 | do_sample=True, temperature=0.7)
47 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
48 | return generated_text
49 |
50 |
51 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}idefics_lora_7.jsonl', 'w+', encoding="utf-8") as fout:
52 | for line in f:
53 | data = json.loads(line)
54 | question = data['question']
55 | id = data['id']
56 | options = data['options']
57 | image_name = data['image']
58 | image_filepath = image_dir + image_name
59 | image = Image.open(image_filepath)
60 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \nOutput:'
61 | prompts = [
62 | image,
63 | question,
64 | ]
65 |
66 | generated_text = check_inference(model, processor, prompts, max_new_tokens=5)
67 | print(f'generated_text:\n{generated_text}')
68 | output = generated_text.lower().split('answer:')[-1].split('\n')[0].strip().rstrip('.')
69 | count += 1
70 | if len(output) == 0:
71 | output = '--'
72 | if output.lower() in data['answer']:
73 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
74 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
75 | right_count += 1
76 | elif data['answer'] in output:
77 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
78 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
79 | right_count += 1
80 | else:
81 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
82 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
83 | print(f'{output.lower()}')
84 | print(f"{data['answer']}")
85 | print(f'right_count: {right_count}')
86 | print(f'count: {count}')
87 | print(f'accuracy: {right_count/count}')
88 |
89 | accuracy = right_count/count
90 | print(f'accuracy: {accuracy}')
91 |
--------------------------------------------------------------------------------
/Code/close_models/gemini_text_only.py:
--------------------------------------------------------------------------------
1 | import google.generativeai as genai
2 | from io import BytesIO
3 | import requests
4 |
5 | genai.configure(api_key="your key",transport="rest")
6 |
7 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480}
8 | safety_settings = [
9 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
10 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
11 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
12 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
13 | ]
14 |
15 | FILE_PATH = 'datasets/test_en_select_500_sort.jsonl'
16 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
17 | RESULT_FILE_PATH = 'model_results/gemini_text_only.jsonl'
18 |
19 |
20 | def fetch_image_content(image_url):
21 | response = requests.get(image_url)
22 | if response.status_code == 200:
23 | return BytesIO(response.content)
24 | else:
25 | return None
26 |
27 |
28 | import PIL.Image as Image
29 |
30 | # few-shot & zero-shot:
31 | # model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
32 |
33 | # text-only:
34 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
35 |
36 | import json
37 |
38 | count = 0
39 | right_count = 0
40 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'w+', encoding="utf-8") as fout:
41 | for line in f:
42 | data = json.loads(line)
43 | id = data['id']
44 |
45 | # text-only:
46 | question = f'You are currently a senior expert in spatial relation reasoning. ' \
47 | f'\nGiven a Question and Options, your task is to use your common sense to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
48 | f'\nInput: Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
49 |
50 | image_name = data['image']
51 | image_filepath = f'{IMAGE_DIR}/{image_name}'
52 |
53 | image_content = fetch_image_content(image_filepath)
54 | if image_content is not None:
55 | image = Image.open(image_content)
56 | try:
57 | response = model.generate_content(
58 | # [question, image] # few-shot & zero-shot
59 | [question] # text-only
60 | )
61 | except Exception as e:
62 | print(e)
63 | try:
64 | response = model.generate_content(
65 | # [question, image] # few-shot & zero-shot
66 | [question] # text-only
67 | )
68 | except Exception as e:
69 | try:
70 | response = model.generate_content(
71 | # [question, image] # few-shot & zero-shot
72 | [question] # text-only
73 | )
74 | except Exception as e:
75 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0}
76 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
77 | count += 1
78 | continue
79 | try:
80 | output = response.text.strip().rstrip('.').lower()
81 | except Exception as e:
82 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0}
83 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
84 | count += 1
85 | continue
86 | count += 1
87 | if output in data['answer']:
88 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0}
89 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
90 | right_count += 1
91 | else:
92 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0}
93 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
94 | print(f'{output.lower()}')
95 | print(f"{data['answer']}")
96 | print(f'right_count: {right_count}')
97 | print(f'count: {count}')
98 |
99 | accuracy = right_count / count
100 | print(accuracy)
101 |
--------------------------------------------------------------------------------
/Code/close_models/gemini_0_shot.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 |
3 | import google.generativeai as genai
4 | from io import BytesIO
5 | import requests
6 |
7 | genai.configure(api_key="your key",transport="rest")
8 |
9 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480}
10 | safety_settings = [
11 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
12 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
13 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
14 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
15 | ]
16 |
17 | FILE_PATH = 'datasets/test_en_select_500_sort.jsonl'
18 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
19 | RESULT_FILE_PATH = 'model_results/gemini_0_shot.jsonl'
20 |
21 |
22 | def fetch_image_content(image_url):
23 | response = requests.get(image_url)
24 | if response.status_code == 200:
25 | return BytesIO(response.content)
26 | else:
27 | return None
28 |
29 |
30 | import PIL.Image as Image
31 |
32 | # few-shot & zero-shot:
33 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
34 |
35 | # text-only:
36 | # model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
37 |
38 | import json
39 |
40 | count = 0
41 | right_count = 0
42 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout:
43 | for line in f:
44 | sleep(1)
45 | data = json.loads(line)
46 | id = data['id']
47 |
48 | # zero-shot:
49 | question = f'You are currently a senior expert in spatial relation reasoning. ' \
50 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
51 | f'\nInput: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
52 |
53 | image_name = data['image']
54 | image_filepath = f'{IMAGE_DIR}/{image_name}'
55 |
56 | image_content = fetch_image_content(image_filepath)
57 | if image_content is not None:
58 | image = Image.open(image_content)
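# The request is retried up to three times; if every attempt fails, a placeholder
# result ("--") is written and the sample is skipped.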
59 | try:
60 | response = model.generate_content(
61 | [question, image] # few-shot & zero-shot
62 | # [question] # text-only
63 | )
64 | except Exception as e:
65 | print(e)
66 | try:
67 | response = model.generate_content(
68 | [question, image] # few-shot & zero-shot
69 | # [question] # text-only
70 | )
71 | except Exception as e:
72 | try:
73 | response = model.generate_content(
74 | [question, image] # few-shot & zero-shot
75 | # [question] # text-only
76 | )
77 | except Exception as e:
78 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0}
79 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
80 | count += 1
81 | continue
82 | try:
83 | output = response.text.strip().rstrip('.').lower()
84 | except Exception as e:
85 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": 0}
86 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
87 | count += 1
88 | continue
89 | count += 1
90 | if output in data['answer']:
91 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0}
92 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
93 | right_count += 1
94 | else:
95 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": 0}
96 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
97 | print(f'{output.lower()}')
98 | print(f"{data['answer']}")
99 | print(f'right_count: {right_count}')
100 | print(f'count: {count}')
101 |
102 | accuracy = right_count / count
103 | print(accuracy)
104 |
--------------------------------------------------------------------------------
/Metadata/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "@context": {
3 | "@language": "en",
4 | "@vocab": "https://schema.org/",
5 | "citeAs": "cr:citeAs",
6 | "column": "cr:column",
7 | "conformsTo": "dct:conformsTo",
8 | "cr": "http://mlcommons.org/croissant/",
9 | "rai": "http://mlcommons.org/croissant/RAI/",
10 | "data": {
11 | "@id": "cr:data",
12 | "@type": "@json"
13 | },
14 | "dataType": {
15 | "@id": "cr:dataType",
16 | "@type": "@vocab"
17 | },
18 | "dct": "http://purl.org/dc/terms/",
19 | "examples": {
20 | "@id": "cr:examples",
21 | "@type": "@json"
22 | },
23 | "extract": "cr:extract",
24 | "field": "cr:field",
25 | "fileProperty": "cr:fileProperty",
26 | "fileObject": "cr:fileObject",
27 | "fileSet": "cr:fileSet",
28 | "format": "cr:format",
29 | "includes": "cr:includes",
30 | "isLiveDataset": "cr:isLiveDataset",
31 | "jsonPath": "cr:jsonPath",
32 | "key": "cr:key",
33 | "md5": "cr:md5",
34 | "parentField": "cr:parentField",
35 | "path": "cr:path",
36 | "recordSet": "cr:recordSet",
37 | "references": "cr:references",
38 | "regex": "cr:regex",
39 | "repeated": "cr:repeated",
40 | "replace": "cr:replace",
41 | "sc": "https://schema.org/",
42 | "separator": "cr:separator",
43 | "source": "cr:source",
44 | "subField": "cr:subField",
45 | "transform": "cr:transform"
46 | },
47 | "@type": "sc:Dataset",
48 | "name": "SpatialMQA",
49 | "description": "SpatialMQA is a manually annotated dataset designed for multimodal spatial relation reasoning in a multiple-choice question & answer format. The dataset includes 5,392 samples collected from COCO2017, covering 128 subject and object types, without bounding boxes.",
50 | "conformsTo": "http://mlcommons.org/croissant/1.0",
51 | "license": "CC-BY 4.0",
52 | "url": " https://anonymous.4open.science/r/SpatialMQA", // Anonymous link, will be modified to real name link after the paper is reviewed
53 | "version": "1.0.0",
54 | "distribution": [
55 | {
56 | "@type": "cr:FileObject",
57 | "@id": "github-repository",
58 | "name": "github-repository",
59 | "description": "SpatialMQA repository on GitHub.",
60 | "contentUrl": " https://anonymous.4open.science/r/SpatialMQA",
61 | "encodingFormat": "git+https",
62 | "sha256": "main"
63 | },
64 | {
65 | "@type": "cr:FileSet",
66 | "@id": "jsonl-files",
67 | "name": "jsonl-files",
68 | "description": "JSONL files are hosted on the GitHub repository.",
69 | "containedIn": {
70 | "@id": "github-repository"
71 | },
72 | "encodingFormat": "application/jsonlines",
73 | "includes": "Dataset/dataset/*.jsonl"
74 | }
75 | ],
76 | "recordSet": [
77 | {
78 | "@type": "cr:RecordSet",
79 | "@id": "jsonl",
80 | "name": "jsonl",
81 | "field": [
82 | {
83 | "@type": "cr:Field",
84 | "@id": "jsonl/image",
85 | "name": "image",
86 | "dataType": "sc:Text",
87 | "source": {
88 | "fileSet": {
89 | "@id": "jsonl-files"
90 | },
91 | "extract": {
92 | "column": "image"
93 | }
94 | }
95 | },
96 | {
97 | "@type": "cr:Field",
98 | "@id": "jsonl/question",
99 | "name": "question",
100 | "description": "The expected question of the promt.",
101 | "dataType": "sc:Text",
102 | "source": {
103 | "fileSet": {
104 | "@id": "jsonl-files"
105 | },
106 | "extract": {
107 | "column": "question"
108 | }
109 | }
110 | },
111 | {
112 | "@type": "cr:Field",
113 | "@id": "jsonl/options",
114 | "name": "options",
115 | "description": "The expected options of the promt.",
116 | "dataType": "sc:Text",
117 | "source": {
118 | "fileSet": {
119 | "@id": "jsonl-files"
120 | },
121 | "extract": {
122 | "column": "options"
123 | }
124 | }
125 | },
126 | {
127 | "@type": "cr:Field",
128 | "@id": "jsonl/answer",
129 | "name": "answer",
130 | "description": "The expected options of the promt.",
131 | "dataType": "sc:Text",
132 | "source": {
133 | "fileSet": {
134 | "@id": "jsonl-files"
135 | },
136 | "extract": {
137 | "column": "answer"
138 | }
139 | }
140 | },
141 | {
142 | "@type": "cr:Field",
143 | "@id": "jsonl/task",
144 | "name": "task",
145 | "description": "The machine learning task appearing as the name of the file.",
146 | "dataType": "sc:Text",
147 | "source": {
148 | "fileSet": {
149 | "@id": "jsonl-files"
150 | },
151 | "extract": {
152 | "fileProperty": "filename"
153 | },
154 | "transform": {
155 | "regex": "^(.*)\\.jsonl$"
156 | }
157 | }
158 | }
159 | ]
160 | }
161 | ]
162 | }
163 |
--------------------------------------------------------------------------------
/Code/finetune/idefics.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/image_captioning.ipynb
2 |
3 | # This example fine-tunes IDEFICS-9B with PEFT (LoRA) on top of a 4-bit quantized base model,
4 | # which keeps the GPU memory requirements to roughly 20GB. The LoRA adapters are added to the
5 | # attention projections (q_proj/k_proj/v_proj) further below; only those adapter weights are
6 | # trained. Remove the quantization_config argument and the get_peft_model call if you want to
7 | # finetune the original full-precision model instead, which needs considerably more memory.
8 |
9 | import torch
10 | from datasets import load_dataset
11 | from peft import LoraConfig, get_peft_model
12 | from PIL import Image
13 | from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
14 | import torchvision.transforms as transforms
15 |
16 | device = "cuda:7" if torch.cuda.is_available() else "cpu"
17 |
18 | checkpoint = "HuggingFaceM4/idefics-9b"
19 |
20 | # Here we skip some special modules that can't be quantized properly
21 | bnb_config = BitsAndBytesConfig(
22 | load_in_4bit=True,
23 | bnb_4bit_use_double_quant=True,
24 | bnb_4bit_quant_type="nf4",
25 | bnb_4bit_compute_dtype=torch.float16,
26 | llm_int8_skip_modules=["lm_head", "embed_tokens"],
27 | )
28 |
29 | processor = AutoProcessor.from_pretrained(checkpoint)
30 | # Simply take-off the quantization_config arg if you want to load the original model
31 | model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map=device)
32 |
33 |
34 | def check_inference(model, processor, prompts, max_new_tokens=50):
35 | tokenizer = processor.tokenizer
36 | bad_words = ["<image>", "<fake_token_around_image>"]
37 | if len(bad_words) > 0:
38 | bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
39 |
40 | eos_token = "</s>"
41 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
42 |
43 | inputs = processor(prompts, return_tensors="pt").to(device)
44 | generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids,
45 | max_new_tokens=max_new_tokens, early_stopping=True)
46 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
47 | print(generated_text)
48 |
49 |
50 | url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
51 | prompts = [
52 | # "Instruction: provide an answer to the question. Use the image to answer.\n",
53 | url,
54 | "Question: What's on the picture? Answer:",
55 | ]
56 | check_inference(model, processor, prompts)
57 |
58 | train_ds = load_dataset("json", data_files='/projects/SpatialMQA/datasets/idefics_train/train_3780.jsonl')['train']
59 | eval_ds = load_dataset("json", data_files='/projects/SpatialMQA/datasets/idefics_train/dev_536.jsonl')['train']
60 |
61 |
62 | def convert_to_rgb(image):
63 | # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
64 | # for transparent images. The call to `alpha_composite` handles this case
65 | image = Image.open(image)
66 | if image.mode == "RGB":
67 | return image
68 |
69 | image_rgba = image.convert("RGBA")
70 | background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
71 | alpha_composite = Image.alpha_composite(background, image_rgba)
72 | alpha_composite = alpha_composite.convert("RGB")
73 | return alpha_composite
74 |
75 |
76 | def ds_transforms(example_batch):
77 | image_size = processor.image_processor.image_size
78 | image_mean = processor.image_processor.image_mean
79 | image_std = processor.image_processor.image_std
80 |
81 | image_transform = transforms.Compose([
82 | convert_to_rgb,
83 | transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0),
84 | interpolation=transforms.InterpolationMode.BICUBIC),
85 | transforms.ToTensor(),
86 | transforms.Normalize(mean=image_mean, std=image_std),
87 | ])
88 |
89 | prompts = []
90 | for i in range(len(example_batch)):
91 | prompts.append(
92 | [
93 | example_batch["image"][i],
94 | example_batch["text"][i],
95 | ],
96 | )
97 |
98 | inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)
99 |
100 | inputs["labels"] = inputs["input_ids"]
101 |
102 | return inputs
103 |
104 |
105 | train_ds.set_transform(ds_transforms)
106 | eval_ds.set_transform(ds_transforms)
107 |
108 | model_name = checkpoint.split("/")[1]
109 | config = LoraConfig(
110 | r=16,
111 | lora_alpha=32,
112 | target_modules=["q_proj", "k_proj", "v_proj"],
113 | lora_dropout=0.05,
114 | bias="none",
115 | )
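# Rank-16 LoRA adapters are injected into the attention projections (q/k/v); only these
# adapter weights are trained while the quantized base model stays frozen.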
116 | model = get_peft_model(model, config)
117 |
118 | num_train_epochs = 10
119 |
120 | training_args = TrainingArguments(
121 | output_dir=f"/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521",
122 | num_train_epochs=5,
123 | learning_rate=2e-4,
124 | bf16=True,
125 | fp16=False,
126 | per_device_train_batch_size=8,
127 | per_device_eval_batch_size=8,
128 | gradient_accumulation_steps=8,
129 | dataloader_pin_memory=False,
130 | save_total_limit=1,
131 | evaluation_strategy="epoch",
132 | save_strategy="epoch",
133 | logging_steps=100,
134 | remove_unused_columns=False,
135 | push_to_hub=False,
136 | label_names=["labels"],
137 | load_best_model_at_end=True,
138 | report_to=None,
139 | optim="paged_adamw_8bit",
140 | )
141 |
142 | trainer = Trainer(
143 | model=model,
144 | args=training_args,
145 | train_dataset=train_ds,
146 | eval_dataset=eval_ds,
147 | )
148 |
149 | trainer.train()
150 |
151 | model.save_pretrained("/projects/SpatialMQA/finetune_models/models_arg/idefics_lora_20240521")
152 |
153 | check_inference(model, processor, prompts, max_new_tokens=100)
154 |
--------------------------------------------------------------------------------
/Code/close_models/gpt4_text_only.py:
--------------------------------------------------------------------------------
1 | # import openai
2 | from openai import OpenAI
3 | import json
4 | import requests
5 | from io import BytesIO
6 | import PIL.Image as Image
7 | import base64
8 |
9 | client = OpenAI(
10 | api_key='your key',
11 | base_url='https://gpt.mnxcc.com/v1'
12 | )
13 | client.api_base = "https://api.foureast.cn/v1"
14 |
15 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
16 | count = 0
17 | right_count = 0
18 |
19 |
20 | def fetch_image_content(image_url):
21 | response = requests.get(image_url)
22 | if response.status_code == 200:
23 | return BytesIO(response.content)
24 | else:
25 | return None
26 |
27 |
28 | def encode_image(image):
29 | if image is None:
30 | return None
31 |
32 | buffered = BytesIO()
33 | try:
34 | image.save(buffered, format="JPEG")
35 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
36 | return 'data:image/jpeg;base64,' + img_str
37 | except Exception as e:
38 | print(f"encoding error: {e}")
39 | return None
40 |
41 |
42 | def call_gpt4(prompt: str, question: str, image):
43 | try:
44 | response = client.chat.completions.create(
45 | model="gpt-4o",
46 | messages=[
47 | {
48 | "role": "user",
49 |
50 | # text only:
51 | "content": [{"type": "text", "text": prompt}] \
52 | + [{"type": "text", "text": question}]
53 | }
54 | ],
55 | max_tokens=500,
56 | # temperature = 0.3,
57 | )
58 | # print(response.choices[0].message.content.strip())
59 | return response.choices[0].message.content.strip()
60 |
61 | except Exception as e:
62 | print(f"Error during answering: {e}")
63 | return None
64 |
65 |
66 | def process_jsonl(input_file, output_file):
67 | with open(input_file, 'r', encoding='utf-8') as file:
68 | with open(output_file, 'w', encoding='utf-8') as out_file:
69 | for line in file:
70 | data = json.loads(line)
71 | id = data['id']
72 |
73 | # text-only:
74 | prompt = f'You are currently a senior expert in spatial relation reasoning. ' \
75 | f'\nGiven a Question and Options, your task is to use your common sense to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.'
76 | question = f'\nInput: Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
77 |
78 | image_id = data['image']
79 | image_url = f'{IMAGE_DIR}/{image_id}'
80 | image_content = fetch_image_content(image_url)
81 |
82 | if image_content is not None:
83 | image = Image.open(image_content)
84 | image_encoded = encode_image(image)
85 |
86 | global count
87 | global right_count
88 | try:
89 | model_answer = call_gpt4(prompt, question, image_encoded)
90 | except Exception as e:
91 | print(e)
92 | try:
93 | model_answer = call_gpt4(prompt, question, image_encoded)
94 | except Exception as e:
95 | try:
96 | model_answer = call_gpt4(prompt, question, image_encoded)
97 | except Exception as e:
98 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
99 | "rule": data['rule'], "example": 0}
100 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
101 | count += 1
102 | continue
103 |
104 | try:
105 | # output = model_answer.text.strip().rstrip('.').lower()
106 | output = model_answer.strip().rstrip('.').lower()
107 | except Exception as e:
108 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
109 | "rule": data['rule'], "example": 0}
110 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
111 | count += 1
112 | continue
113 | count += 1
114 | if output in data['answer']:
115 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'],
116 | "rule": data['rule'], "example": 0}
117 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
118 | right_count += 1
119 | else:
120 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'],
121 | "rule": data['rule'], "example": 0}
122 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
123 | print(f'{output.lower()}')
124 | print(f"{data['answer']}")
125 | print(f'right_count: {right_count}')
126 | print(f'count: {count}')
127 | # print(f'accuracy: {right_count / count}')
128 |
129 | accuracy = right_count / count
130 | print(accuracy)
131 |
132 |
133 | input_file_path = "datasets/test_en_500.jsonl"
134 | output_file_path = "model_results/gpt4_text_only.jsonl"
135 |
136 | process_jsonl(input_file_path, output_file_path)
137 |
--------------------------------------------------------------------------------
/Code/close_models/gpt4_zero_shot.py:
--------------------------------------------------------------------------------
1 | # import openai
2 | from openai import OpenAI
3 | import json
4 | import requests
5 | from io import BytesIO
6 | import PIL.Image as Image
7 | import base64
8 |
9 |
10 | client = OpenAI(
11 | api_key='your key',
12 | base_url='https://gpt.mnxcc.com/v1'
13 | )
14 | client.api_base = "https://api.foureast.cn/v1"
15 |
16 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
17 | count = 0
18 | right_count = 0
19 |
20 |
21 | def fetch_image_content(image_url):
22 | response = requests.get(image_url)
23 | if response.status_code == 200:
24 | return BytesIO(response.content)
25 | else:
26 | return None
27 |
28 |
29 | def encode_image(image):
30 | if image is None:
31 | return None
32 |
33 | buffered = BytesIO()
34 | try:
35 | image.save(buffered, format="JPEG")
36 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
37 | return 'data:image/jpeg;base64,' + img_str
38 | except Exception as e:
39 | print(f"encoding error: {e}")
40 | return None
41 |
42 |
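# The base64 data URL produced above is sent as an image_url content part alongside the
# text prompt in a single user message.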
43 | def call_gpt4(prompt: str, image, detail='auto'):
44 | try:
45 | response = client.chat.completions.create(
46 | model="gpt-4o",
47 | messages=[
48 | {
49 | "role": "user",
50 | # zero-shot + few-shot
51 | "content": [{"type": "text", "text": prompt}] \
52 | + [{"type": "image_url", "image_url": { "url": image, "detail": detail}}]
53 | }
54 | ],
55 | max_tokens=500,
56 | temperature=0.5,
57 | )
58 | # print(response.choices[0].message.content.strip())
59 | return response.choices[0].message.content.strip()
60 |
61 | except Exception as e:
62 | print(f"Error during answering: {e}")
63 | return None
64 |
65 |
66 | def process_jsonl(input_file, output_file):
67 | with open(input_file, 'r', encoding='utf-8') as file:
68 | with open(output_file, 'w', encoding='utf-8') as out_file:
69 | for line in file:
70 | data = json.loads(line)
71 | id = data['id']
72 |
73 | # zero-shot:
74 | prompt = f'You are currently a senior expert in spatial relation reasoning. ' \
75 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
76 | f'\nInput: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
77 |
78 | image_id = data['image']
79 | image_url = f'{IMAGE_DIR}/{image_id}'
80 | image_content = fetch_image_content(image_url)
81 |
82 | if image_content is not None:
83 | image = Image.open(image_content)
84 | image_encoded = encode_image(image)
85 |
86 | global count
87 | global right_count
88 | try:
89 | model_answer = call_gpt4(prompt, image_encoded)
90 | except Exception as e:
91 | print(e)
92 | try:
93 | model_answer = call_gpt4(prompt, image_encoded)
94 | except Exception as e:
95 | try:
96 | model_answer = call_gpt4(prompt, image_encoded)
97 | except Exception as e:
98 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
99 | "rule": data['rule'], "example": 0}
100 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
101 | count += 1
102 | continue
103 |
104 | try:
105 | # output = model_answer.text.strip().rstrip('.').lower()
106 | output = model_answer.strip().rstrip('.').lower()
107 | except Exception as e:
108 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
109 | "rule": data['rule'], "example": 0}
110 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
111 | count += 1
112 | continue
113 | count += 1
114 | if output in data['answer']:
115 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'],
116 | "rule": data['rule'], "example": 0}
117 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
118 | right_count += 1
119 | else:
120 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'],
121 | "rule": data['rule'], "example": 0}
122 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
123 | print(f'{output.lower()}')
124 | print(f"{data['answer']}")
125 | print(f'right_count: {right_count}')
126 | print(f'count: {count}')
127 | # print(f'accuracy: {right_count / count}')
128 |
129 | accuracy = right_count / count
130 | print(accuracy)
131 |
132 |
133 | input_file_path = "datasets/test_en_500.jsonl"
134 | output_file_path = "model_results/gpt4_zero_shot.jsonl"
135 |
136 |
137 | process_jsonl(input_file_path, output_file_path)
138 |
--------------------------------------------------------------------------------
/Code/eval/calculate_xyz.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | all_correct = 0
5 | position = {'y': ["on/above", "below"], 'z': ["in front of", "behind"],
6 | 'x': ["left of", "right of"]}
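# Axis grouping used for the per-dimension accuracy: y = vertical (on/above vs. below),
# z = depth (in front of vs. behind), x = horizontal (left of vs. right of).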
7 |
8 |
9 | def load_jsonl(file_path):
10 | data = []
11 | with open(file_path, 'r', encoding='utf-8') as f:
12 | for line in f:
13 | data.append(json.loads(line.strip()))
14 | return data
15 |
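# Each prediction line follows the format written by the experiment scripts, e.g.
# {"id": 1, "result": 1, "output": "left of", "answer": "left of"}.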
16 |
17 | def change_output(predict):
18 | if "front" in predict['output']:
19 | predict['output'] = "in front of"
20 | elif "behind" in predict['output']:
21 | predict['output'] = "behind"
22 | elif "left" in predict['output']:
23 | predict['output'] = "left of"
24 | elif "right" in predict['output']:
25 | predict['output'] = "right of"
26 | elif "below" in predict['output']:
27 | predict['output'] = "below"
28 | elif "on" in predict['output'] or "above" in predict['output']:
29 | predict['output'] = "on/above"
30 | else:
31 | predict['output'] = "--"
32 | return predict['output']
33 |
34 |
35 | def calculate_accuracy(predictions, option):
36 | correct_count = 0
37 | total_count = 0
38 | option1 = position[option][0]
39 | option2 = position[option][1]
40 | for predict in predictions:
41 | if predict['answer'] == option1 or predict['answer'] == option2:
42 | total_count += 1
43 | if predict['result'] == 1:
44 | correct_count += 1
45 | global all_correct
46 | all_correct += correct_count
47 | if total_count == 0:
48 | return 0.0
49 | print(option + ": " + str(correct_count) + "; " + str(total_count))
50 | return correct_count / total_count
51 |
52 |
53 | def main():
54 | # test_file_list = os.listdir('model_exp_results/m10/')
55 | test_file_list = os.listdir('model_exp_results/move_to_localhost/')
56 | for test_file in test_file_list:
57 | print(f'-------------------------------{test_file}----------------------------------')
58 | test_file_path = f'model_exp_results/move_to_localhost/{test_file}'
59 | # test_file_path = f'model_exp_results/m10/{test_file}'
60 |
61 | # Re-process the on/above outputs:
62 | predictions = load_jsonl(test_file_path)
63 | change_count = 0
64 | # in_front_of_count = 0
65 | for predict in predictions:
66 | if 501 > predict['id'] > 0:
67 | if "on" in predict['output'] or "above" in predict['output']:
68 | predict['result'] = 1
69 | change_count += 1
70 | # elif 1756 > predict['id'] > 945:  # handle "in front of" (id: 958~1751)
71 | # if "front" in predict['output'] or "front" in predict['answer']:
72 | # in_front_of_count += 1
73 | # for predict in predictions:
74 | # if 512 > predict['id']:  # handle "on/above" (id: 1~500)
75 | # if "on" in predict['output'] or "above" in predict['output']:
76 | # predict['result'] = 1
77 | # predict['output'] = "on/above"
78 | # change_count += 1
79 | # else:
80 | # predict['output'] = change_output(predict)
81 | # elif 958 > predict['id'] > 500:  # handle "below" (id: 512~945)
82 | # if "below" in predict['output']:
83 | # predict['result'] = 1
84 | # predict['output'] = "below"
85 | # change_count += 1
86 | # else:
87 | # predict['output'] = change_output(predict)
88 | # elif 1756 > predict['id'] > 945:  # handle "in front of" (id: 958~1751)
89 | # if "front" in predict['output']:
90 | # predict['result'] = 1
91 | # predict['output'] = "in front of"
92 | # change_count += 1
93 | # else:
94 | # predict['output'] = change_output(predict)
95 | # elif 2513 > predict['id'] > 1751:  # handle "behind" (id: 1756~2509)
96 | # if "behind" in predict['output']:
97 | # predict['result'] = 1
98 | # predict['output'] = "behind"
99 | # change_count += 1
100 | # else:
101 | # predict['output'] = change_output(predict)
102 | # elif 3910 > predict['id'] > 2509:  # handle "left of" (id: 2513~3906)
103 | # if "left" in predict['output']:
104 | # predict['result'] = 1
105 | # predict['output'] = "left of"
106 | # change_count += 1
107 | # else:
108 | # predict['output'] = change_output(predict)
109 | # elif predict['id'] > 3906:  # handle "right of" (id: 3910~5391)
110 | # if "right" in predict['output']:
111 | # predict['result'] = 1
112 | # predict['output'] = "right of"
113 | # change_count += 1
114 | # else:
115 | # predict['output'] = change_output(predict)
116 |
117 | # with open(test_file_path, 'w', encoding='utf-8') as f:
118 | # for predict in predictions:
119 | # f.write(json.dumps(predict) + '\n')
120 | # print(f"Number of on/above results modified in this file: {change_count}")
121 | # print(f"Number of 'in front of' results in this file: {in_front_of_count}")
122 |
123 | predictions = load_jsonl(test_file_path)
124 | print(len(predictions))
125 | # answers = load_jsonl('../SpatialMQA(En)/all_en_test_1076_sort.jsonl')
126 |
127 | accuracies = {}
128 | for option in ['x', 'y', 'z']:
129 | accuracies[option] = calculate_accuracy(predictions, option)
130 |
131 | global all_correct
132 | for option, accuracy in accuracies.items():
133 | print(f"Accuracy for '{option}': {accuracy:.4f}")
134 | print(f"Accuracy for 'overall': {(all_correct / len(predictions)):.4f}")
135 | all_correct = 0
136 |
137 |
138 | if __name__ == "__main__":
139 | main()
140 |
--------------------------------------------------------------------------------
/Code/finetune/blip-vqa-base.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 | from transformers import BlipProcessor, BlipForQuestionAnswering
5 | from datasets import load_dataset
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | import pickle
11 |
12 | model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
13 | processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
14 |
15 | device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
16 | model.to(device)
17 |
18 | torch.cuda.empty_cache()
19 |
20 |
21 | class VQADataset(torch.utils.data.Dataset):
22 | """VQA (v2) dataset."""
23 |
24 | def __init__(self, dataset, processor):
25 | self.dataset = dataset
26 | self.processor = processor
27 |
28 | def __len__(self):
29 | return len(self.dataset)
30 |
31 | def __getitem__(self, idx):
32 | # get image + text
33 | question = self.dataset[idx]['question']
34 | answer = str(self.dataset[idx]['answer'])
35 | image_id = self.dataset[idx]['image']
36 | image_path = f"/projects/SpatialMQA/COCO2017/test2017/{image_id}"
37 | image = Image.open(image_path).convert("RGB")
38 | text = question
39 |
40 | encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
41 | labels = self.processor.tokenizer.encode(
42 | answer, max_length=8, pad_to_max_length=True, return_tensors='pt'
43 | )
44 | encoding["labels"] = labels
45 | # remove batch dimension
46 | for k, v in encoding.items(): encoding[k] = v.squeeze()
47 | return encoding
48 |
49 |
50 | training_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip_train/train_3780.jsonl",
51 | split="train[:100%]")
52 | valid_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip_train/dev_536.jsonl",
53 | split="train[:100%]")
54 | print("Training sets: {} - Validating set: {}".format(len(training_dataset), len(valid_dataset)))
55 |
56 | train_dataset = VQADataset(dataset=training_dataset,
57 | processor=processor)
58 | valid_dataset = VQADataset(dataset=valid_dataset,
59 | processor=processor)
60 |
61 | batch_size = 8
62 | cal_num = 2
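# cal_num controls a simple form of gradient accumulation: the losses of cal_num mini-batches
# are averaged and backpropagated in a single optimizer step (see the training loop below).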
63 | torch.manual_seed(42)
64 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
65 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
66 |
67 | optimizer = torch.optim.AdamW(model.parameters(), lr=6e-7)
68 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False)
69 |
70 | num_epochs = 30
71 | patience = 5
72 | min_eval_loss = float("inf")
73 | early_stopping_hook = 0
74 | tracking_information = []
75 | scaler = torch.cuda.amp.GradScaler()
76 |
77 | for epoch in range(num_epochs):
78 | epoch_loss = 0
79 | model.train()
80 | cal_loss = 0
81 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader):
82 | input_ids = batch.pop('input_ids').to(device)
83 | pixel_values = batch.pop('pixel_values').to(device)
84 | attention_masked = batch.pop('attention_mask').to(device)
85 | labels = batch.pop('labels').to(device)
86 |
87 | with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
88 | outputs = model(input_ids=input_ids,
89 | pixel_values=pixel_values,
90 | attention_mask=attention_masked,
91 | labels=labels)
92 | optimizer.zero_grad()
93 | loss = outputs.loss
94 | cal_loss += loss
95 |
96 | epoch_loss += loss.item()
97 |
98 | if (idx + 1) % cal_num == 0 or idx == len(train_dataloader) - 1:
99 | if (idx + 1) % cal_num == 0:
100 | cal_loss = cal_loss / cal_num
101 | else:
102 | cal_loss = cal_loss / ((idx + 1) % cal_num)
103 | scaler.scale(cal_loss).backward()
104 | scaler.step(optimizer)
105 | scaler.update()
106 | cal_loss = 0
107 |
108 | model.eval()
109 | eval_loss = 0
110 | for idx, batch in zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader):
111 | input_ids = batch.pop('input_ids').to(device)
112 | pixel_values = batch.pop('pixel_values').to(device)
113 | attention_masked = batch.pop('attention_mask').to(device)
114 | labels = batch.pop('labels').to(device)
115 |
116 | with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
117 | outputs = model(input_ids=input_ids,
118 | pixel_values=pixel_values,
119 | attention_mask=attention_masked,
120 | labels=labels)
121 |
122 | loss = outputs.loss
123 | eval_loss += loss.item()
124 |
125 | tracking_information.append(
126 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"]))
127 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader),
128 | eval_loss / len(valid_dataloader),
129 | optimizer.param_groups[0]["lr"]))
130 | scheduler.step()
131 | if eval_loss < min_eval_loss:
132 | model.save_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{epoch}",
133 | from_pt=True)
134 | print(f"Saved model to /projects/SpatialMQA/finetune_models/models_arg/blip_finetune_20240602/epoch-{epoch}")
135 | min_eval_loss = eval_loss
136 | early_stopping_hook = 0
137 | else:
138 | early_stopping_hook += 1
139 | if early_stopping_hook > patience:
140 | break
141 |
142 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb"))
143 | print("The finetuning process has done!")
144 |
--------------------------------------------------------------------------------
/Code/experiment/spacellava_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
3 |
4 | import argparse
5 | import torch
6 | import json
7 | from llava.constants import (
8 | IMAGE_TOKEN_INDEX,
9 | DEFAULT_IMAGE_TOKEN,
10 | DEFAULT_IM_START_TOKEN,
11 | DEFAULT_IM_END_TOKEN,
12 | IMAGE_PLACEHOLDER,
13 | )
14 | from llava.conversation import conv_templates, SeparatorStyle
15 | from llava.model.builder import load_pretrained_model
16 | from llava.utils import disable_torch_init
17 | from llava.mm_utils import (
18 | process_images,
19 | tokenizer_image_token,
20 | get_model_name_from_path,
21 | )
22 |
23 | from PIL import Image
24 |
25 | import requests
26 | from PIL import Image
27 | from io import BytesIO
28 | import re
29 |
30 | RESULT_FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/results/spacellava_20240812_noPrompts_3.jsonl'
31 | FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/all_en_test_1076_sort.jsonl'
32 | IMAGE_DIR = '/home/zouyinan/SpatialVLM/SpaceLLaVa/pic/'
33 |
34 |
35 | def image_parser(args):
36 | out = args.image_file.split(args.sep)
37 | return out
38 |
39 |
40 | def load_image(image_file):
41 | if image_file.startswith("http") or image_file.startswith("https"):
42 | response = requests.get(image_file)
43 | image = Image.open(BytesIO(response.content)).convert("RGB")
44 | else:
45 | image = Image.open(image_file).convert("RGB")
46 | return image
47 |
48 |
49 | def load_images(image_files):
50 | out = []
51 | for image_file in image_files:
52 | image = load_image(image_file)
53 | out.append(image)
54 | return out
55 |
56 |
57 | # Model
58 | disable_torch_init()
59 | model_path='/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf'
60 | args = type('Args', (), {
61 | "model_path": model_path,
62 | "model_base": None,
63 | "model_name": get_model_name_from_path(model_path),
64 | # "query": prompt,
65 | "conv_mode": None,
66 | # "image_file": image_file,
67 | "sep": ",",
68 | "temperature": 0.9,
69 | "top_p": None,
70 | "num_beams": 1,
71 | "max_new_tokens": 512
72 | })()
73 |
74 | model_name = get_model_name_from_path(model_path)
75 | tokenizer, model, image_processor, context_len = load_pretrained_model(
76 | args.model_path, None, model_name
77 | )
78 |
79 | def eval_model(args,question,image_file):
80 |
81 |
82 | qs = question
83 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
84 | if IMAGE_PLACEHOLDER in qs:
85 | if model.config.mm_use_im_start_end:
86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
87 | else:
88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
89 | else:
90 | if model.config.mm_use_im_start_end:
91 | qs = image_token_se + "\n" + qs
92 | else:
93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
94 |
95 | if "llama-2" in model_name.lower():
96 | conv_mode = "llava_llama_2"
97 | elif "mistral" in model_name.lower():
98 | conv_mode = "mistral_instruct"
99 | elif "v1.6-34b" in model_name.lower():
100 | conv_mode = "chatml_direct"
101 | elif "v1" in model_name.lower():
102 | conv_mode = "llava_v1"
103 | elif "mpt" in model_name.lower():
104 | conv_mode = "mpt"
105 | else:
106 | conv_mode = "llava_v0"
107 |
108 | if args.conv_mode is not None and conv_mode != args.conv_mode:
109 | print(
110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
111 | conv_mode, args.conv_mode, args.conv_mode
112 | )
113 | )
114 | else:
115 | args.conv_mode = conv_mode
116 |
117 | conv = conv_templates[args.conv_mode].copy()
118 | conv.append_message(conv.roles[0], qs)
119 | conv.append_message(conv.roles[1], None)
120 | prompt = conv.get_prompt()
121 | image_files = [image_file]
122 | images = load_images(image_files)
123 | image_sizes = [x.size for x in images]
124 | images_tensor = process_images(
125 | images,
126 | image_processor,
127 | model.config
128 | ).to(model.device, dtype=torch.float16)
129 |
130 | input_ids = (
131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
132 | .unsqueeze(0)
133 | .cuda()
134 | )
135 |
136 | with torch.inference_mode():
137 | output_ids = model.generate(
138 | input_ids,
139 | images=images_tensor,
140 | image_sizes=image_sizes,
141 | do_sample=True if args.temperature > 0 else False,
142 | temperature=args.temperature,
143 | top_p=args.top_p,
144 | num_beams=args.num_beams,
145 | max_new_tokens=args.max_new_tokens,
146 | use_cache=True,
147 | )
148 |
149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150 | return outputs
151 | count = 0
152 | right_count = 0
153 |
154 | with open(f'{FILE_PATH}', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}', 'w+', encoding="utf-8") as fout:
155 | for line in f:
156 | data = json.loads(line)
157 | question = data['question']
158 | id = data['id']
159 | options = data['options']
160 | image_name = data['image']
161 | image_filepath = IMAGE_DIR + image_name
162 | question = f'Question: {question} \nOptions: {"; ".join(options)} \nAnswer:'
163 | output = eval_model(args,question,image_filepath)
164 | count += 1
165 | if len(output) == 0:
166 | output = '--'
167 | if output.lower() in data['answer']:
168 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
169 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
170 | right_count += 1
171 | elif data['answer'] in output.lower():
172 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
173 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
174 | right_count += 1
175 | else:
176 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
177 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
178 | print(f'{output.lower()}')
179 | print(f"{data['answer']}")
180 | print(f'right_count: {right_count}')
181 | print(f'count: {count}')
182 | print(f'accuracy: {right_count/count}')
183 | # print(f'accuracy: {right_count/count}')
184 |
185 | accuracy = right_count/count
186 | print(f'accuracy: {accuracy}')
--------------------------------------------------------------------------------
/Code/finetune/instructblip-lora.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, T5ForConditionalGeneration
5 | from datasets import load_dataset
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | import pickle
11 |
12 | model_dir = '/projects/Models/'
13 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
14 |
15 | from peft import LoraConfig, get_peft_model
16 |
17 | # Let's define the LoraConfig
18 | config = LoraConfig(
19 | r=16,
20 | lora_alpha=32,
21 | lora_dropout=0.05,
22 | bias="none",
23 | target_modules=["q", "k"]
24 | )
25 |
26 | processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
27 |
28 | device = "cuda:6" if torch.cuda.is_available() else "cpu"
29 |
30 | model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl")
31 |
32 | model = get_peft_model(model, config).to(device)
33 |
34 | torch.cuda.empty_cache()
35 | torch.manual_seed(42)
36 |
37 |
38 | class VQADataset(torch.utils.data.Dataset):
39 | """VQA (v2) dataset."""
40 |
41 | def __init__(self, dataset, processor):
42 | self.dataset = dataset
43 | self.processor = processor
44 |
45 | def __len__(self):
46 | return len(self.dataset)
47 |
48 | def __getitem__(self, idx):
49 | # get image + text
50 | question = self.dataset[idx]['question']
51 | answer = str(self.dataset[idx]['answer'])
52 | image_id = self.dataset[idx]['image']
53 | image_path = f"{image_dir}{image_id}"
54 | image = Image.open(image_path).convert("RGB")
55 |
56 | encoding = self.processor(images=image, text=question, return_tensors="pt")
57 | labels = self.processor.tokenizer.encode(
58 | answer, max_length=8, pad_to_max_length=True, return_tensors='pt'
59 | )
60 |
61 | encoding["labels"] = labels
62 | # remove batch dimension
63 | for k, v in encoding.items(): encoding[k] = v.squeeze()
64 | print(encoding['labels'], answer)
65 | return encoding
66 |
67 |
68 | training_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/instructblip_train/train_3780.jsonl",
69 | split="train[:100%]")
70 | valid_dataset = load_dataset("json", data_files="/projects/SpatialMQA/datasets/instructblip_train/dev_536.jsonl",
71 | split="train[:100%]")
72 | print("Training sets: {} - Validating set: {}".format(len(training_dataset), len(valid_dataset)))
73 |
74 | train_dataset = VQADataset(dataset=training_dataset,
75 | processor=processor)
76 | valid_dataset = VQADataset(dataset=valid_dataset,
77 | processor=processor)
78 |
79 | batch_size = 8
80 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
81 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
82 |
83 | optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
84 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False)
85 |
86 | num_epochs = 30
87 | patience = 5
88 | min_eval_loss = float("inf")
89 | early_stopping_hook = 0
90 | tracking_information = []
91 | scaler = torch.cuda.amp.GradScaler()
92 |
93 | for epoch in range(num_epochs):
94 | epoch_loss = 0
95 | model.train()
96 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader):
97 | input_ids = batch.pop('input_ids').to(device)
98 | qformer_input_ids = batch.pop('qformer_input_ids').to(device)
99 | qformer_attention_mask = batch.pop('qformer_attention_mask').to(device)
100 | pixel_values = batch.pop('pixel_values').to(device)
101 | attention_mask = batch.pop('attention_mask').to(device)
102 | labels = batch.pop('labels').to(device)
103 | with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
104 | outputs = model(input_ids=input_ids,
105 | qformer_input_ids=qformer_input_ids,
106 | qformer_attention_mask=qformer_attention_mask,
107 | pixel_values=pixel_values,
108 | attention_mask=attention_mask,
109 | labels=labels
110 | )
111 | loss = outputs.loss
112 | print('loss:', loss)
113 | epoch_loss += loss.item()
114 | optimizer.zero_grad()
115 |
116 | scaler.scale(loss).backward()
117 | scaler.step(optimizer)
118 | scaler.update()
119 |
120 | model.eval()
121 | eval_loss = 0
122 | for idx, batch in zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader):
123 | input_ids = batch.pop('input_ids').to(device)
124 | qformer_input_ids = batch.pop('qformer_input_ids').to(device)
125 | qformer_attention_mask = batch.pop('qformer_attention_mask').to(device)
126 | pixel_values = batch.pop('pixel_values').to(device)
127 | attention_mask = batch.pop('attention_mask').to(device)
128 | labels = batch.pop('labels').to(device)
129 | outputs = model(input_ids=input_ids,
130 | qformer_input_ids=qformer_input_ids,
131 | qformer_attention_mask=qformer_attention_mask,
132 | pixel_values=pixel_values,
133 | attention_mask=attention_mask,
134 | labels=labels)
135 | loss = outputs.loss
136 | eval_loss += loss.item()
137 |
138 | tracking_information.append(
139 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"]))
140 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader),
141 | eval_loss / len(valid_dataloader),
142 | optimizer.param_groups[0]["lr"]))
143 | scheduler.step()
144 | if eval_loss < min_eval_loss:
145 | model.save_pretrained("/projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530",
146 | from_pt=True)
147 | print("Saved model to /projects/SpatialMQA/finetune_models/models_arg/instructblip_lora_20240530")
148 | min_eval_loss = eval_loss
149 | early_stopping_hook = 0
150 | else:
151 | early_stopping_hook += 1
152 | if early_stopping_hook > patience:
153 | break
154 |
155 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb"))
156 | print("The finetuning process has done!")
157 |
--------------------------------------------------------------------------------
/Code/experiment/spatial_test_llava.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "6"
4 |
5 | import argparse
6 | import torch
7 | import json
8 | from llava.constants import (
9 | IMAGE_TOKEN_INDEX,
10 | DEFAULT_IMAGE_TOKEN,
11 | DEFAULT_IM_START_TOKEN,
12 | DEFAULT_IM_END_TOKEN,
13 | IMAGE_PLACEHOLDER,
14 | )
15 | from llava.conversation import conv_templates, SeparatorStyle
16 | from llava.model.builder import load_pretrained_model
17 | from llava.utils import disable_torch_init
18 | from llava.mm_utils import (
19 | process_images,
20 | tokenizer_image_token,
21 | get_model_name_from_path,
22 | )
23 |
24 | from PIL import Image
25 |
26 | import requests
27 | from PIL import Image
28 | from io import BytesIO
29 | import re
30 |
31 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/10/'
32 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
33 | IMAGE_DIR = '/projects/SpatialMQA/COCO2017/test2017/'
34 |
35 |
36 | def image_parser(args):
37 | out = args.image_file.split(args.sep)
38 | return out
39 |
40 |
41 | def load_image(image_file):
42 | if image_file.startswith("http") or image_file.startswith("https"):
43 | response = requests.get(image_file)
44 | image = Image.open(BytesIO(response.content)).convert("RGB")
45 | else:
46 | image = Image.open(image_file).convert("RGB")
47 | return image
48 |
49 |
50 | def load_images(image_files):
51 | out = []
52 | for image_file in image_files:
53 | image = load_image(image_file)
54 | out.append(image)
55 | return out
56 |
57 |
58 | # Model
59 | disable_torch_init()
60 | model_path = '/projects/Models/LLaVA-main/llava-v1.5-7b'
61 | args = type('Args', (), {
62 | "model_path": model_path,
63 | "model_base": None,
64 | "model_name": get_model_name_from_path(model_path),
65 | "conv_mode": None,
66 | "sep": ",",
67 | "temperature": 0.4,
68 | "top_p": None,
69 | "num_beams": 1,
70 | "max_new_tokens": 512
71 | })()
72 |
73 | model_name = get_model_name_from_path(model_path)
74 | tokenizer, model, image_processor, context_len = load_pretrained_model(
75 | args.model_path, None, model_name
76 | )
77 |
78 |
79 | def eval_model(args, question, image_file):
80 | qs = question
81 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
82 | if IMAGE_PLACEHOLDER in qs:
83 | if model.config.mm_use_im_start_end:
84 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
85 | else:
86 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
87 | else:
88 | if model.config.mm_use_im_start_end:
89 | qs = image_token_se + "\n" + qs
90 | else:
91 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
92 |
93 | if "llama-2" in model_name.lower():
94 | conv_mode = "llava_llama_2"
95 | elif "mistral" in model_name.lower():
96 | conv_mode = "mistral_instruct"
97 | elif "v1.6-34b" in model_name.lower():
98 | conv_mode = "chatml_direct"
99 | elif "v1" in model_name.lower():
100 | conv_mode = "llava_v1"
101 | elif "mpt" in model_name.lower():
102 | conv_mode = "mpt"
103 | else:
104 | conv_mode = "llava_v0"
105 |
106 | if args.conv_mode is not None and conv_mode != args.conv_mode:
107 | print(
108 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
109 | conv_mode, args.conv_mode, args.conv_mode
110 | )
111 | )
112 | else:
113 | args.conv_mode = conv_mode
114 |
115 | conv = conv_templates[args.conv_mode].copy()
116 | conv.append_message(conv.roles[0], qs)
117 | conv.append_message(conv.roles[1], None)
118 | prompt = conv.get_prompt()
119 | image_files = [image_file]
120 | images = load_images(image_files)
121 | image_sizes = [x.size for x in images]
122 | images_tensor = process_images(
123 | images,
124 | image_processor,
125 | model.config
126 | ).to(model.device, dtype=torch.float16)
127 |
128 | input_ids = (
129 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
130 | .unsqueeze(0)
131 | .cuda()
132 | )
133 |
134 | with torch.inference_mode():
135 | output_ids = model.generate(
136 | input_ids,
137 | images=images_tensor,
138 | image_sizes=image_sizes,
139 | do_sample=True if args.temperature > 0 else False,
140 | temperature=args.temperature,
141 | top_p=args.top_p,
142 | num_beams=args.num_beams,
143 | max_new_tokens=args.max_new_tokens,
144 | use_cache=True,
145 | )
146 |
147 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
148 | return outputs
149 |
150 |
151 | count = 0
152 | right_count = 0
153 |
154 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f, open(
155 | f'{RESULT_FILE_PATH}llava_4_5.jsonl', 'w+', encoding="utf-8") as fout:
156 | for line in f:
157 | data = json.loads(line)
158 | question = data['question']
159 | id = data['id']
160 | options = data['options']
161 | image_name = data['image']
162 | image_filepath = IMAGE_DIR + image_name
163 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
164 | output = eval_model(args, question, image_filepath)
165 | count += 1
166 | if len(output) == 0:
167 | output = '--'
168 | if output.lower() in data['answer']:
169 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']}
170 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
171 | right_count += 1
172 | elif data['answer'] in output.lower():
173 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']}
174 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
175 | right_count += 1
176 | else:
177 | result_json = {'id': id, 'result': 0, "output": output.lower(), "answer": data['answer']}
178 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
179 | print(f'{output.lower()}')
180 | print(f"{data['answer']}")
181 | print(f'right_count: {right_count}')
182 | print(f'count: {count}')
183 | print(f'accuracy: {right_count / count}')
184 |
185 | accuracy = right_count / count
186 | print(f'accuracy: {accuracy}')
187 |
--------------------------------------------------------------------------------
/Code/experiment/spatial_test_llava_lora.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
4 |
5 | import argparse
6 | import torch
7 | import json
8 | from llava.constants import (
9 | IMAGE_TOKEN_INDEX,
10 | DEFAULT_IMAGE_TOKEN,
11 | DEFAULT_IM_START_TOKEN,
12 | DEFAULT_IM_END_TOKEN,
13 | IMAGE_PLACEHOLDER,
14 | )
15 | from llava.conversation import conv_templates, SeparatorStyle
16 | from llava.model.builder import load_pretrained_model
17 | from llava.utils import disable_torch_init
18 | from llava.mm_utils import (
19 | process_images,
20 | tokenizer_image_token,
21 | get_model_name_from_path,
22 | )
23 |
24 | from PIL import Image
25 |
26 | import requests
27 | from PIL import Image
28 | from io import BytesIO
29 | import re
30 |
31 | RESULT_FILE_PATH = '/projects/SpatialMQA/output/10/'
32 | FILE_PATH = '/projects/SpatialMQA/datasets/spatial/'
33 | IMAGE_DIR = '/projects/SpatialMQA/COCO2017/test2017/'
34 |
35 |
36 | def image_parser(args):
37 | out = args.image_file.split(args.sep)
38 | return out
39 |
40 |
41 | def load_image(image_file):
42 | if image_file.startswith("http") or image_file.startswith("https"):
43 | response = requests.get(image_file)
44 | image = Image.open(BytesIO(response.content)).convert("RGB")
45 | else:
46 | image = Image.open(image_file).convert("RGB")
47 | return image
48 |
49 |
50 | def load_images(image_files):
51 | out = []
52 | for image_file in image_files:
53 | image = load_image(image_file)
54 | out.append(image)
55 | return out
56 |
57 |
58 | # Model
59 | disable_torch_init()
60 | model_path = '/projects/Models/LLaVA-main/llava-v1.5-7b'
61 | args = type('Args', (), {
62 | "model_path": model_path,
63 | "model_base": None,
64 | "model_name": get_model_name_from_path(model_path),
65 | "conv_mode": None,
66 | "sep": ",",
67 | "temperature": 0.4,
68 | "top_p": None,
69 | "num_beams": 1,
70 | "max_new_tokens": 512
71 | })()
72 |
73 | model_name = get_model_name_from_path(model_path)
74 | tokenizer, model, image_processor, context_len = load_pretrained_model(
75 | args.model_path, None, model_name
76 | )
77 | peft_model_id = "/projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240522"
78 | model.load_adapter(peft_model_id)
79 |
80 |
81 | def eval_model(args, question, image_file):
82 | qs = question
83 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
84 | if IMAGE_PLACEHOLDER in qs:
85 | if model.config.mm_use_im_start_end:
86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
87 | else:
88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
89 | else:
90 | if model.config.mm_use_im_start_end:
91 | qs = image_token_se + "\n" + qs
92 | else:
93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
94 |
95 | if "llama-2" in model_name.lower():
96 | conv_mode = "llava_llama_2"
97 | elif "mistral" in model_name.lower():
98 | conv_mode = "mistral_instruct"
99 | elif "v1.6-34b" in model_name.lower():
100 | conv_mode = "chatml_direct"
101 | elif "v1" in model_name.lower():
102 | conv_mode = "llava_v1"
103 | elif "mpt" in model_name.lower():
104 | conv_mode = "mpt"
105 | else:
106 | conv_mode = "llava_v0"
107 |
108 | if args.conv_mode is not None and conv_mode != args.conv_mode:
109 | print(
110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
111 | conv_mode, args.conv_mode, args.conv_mode
112 | )
113 | )
114 | else:
115 | args.conv_mode = conv_mode
116 |
117 | conv = conv_templates[args.conv_mode].copy()
118 | conv.append_message(conv.roles[0], qs)
119 | conv.append_message(conv.roles[1], None)
120 | prompt = conv.get_prompt()
121 | image_files = [image_file]
122 | images = load_images(image_files)
123 | image_sizes = [x.size for x in images]
124 | images_tensor = process_images(
125 | images,
126 | image_processor,
127 | model.config
128 | ).to(model.device, dtype=torch.float16)
129 |
130 | input_ids = (
131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
132 | .unsqueeze(0)
133 | .cuda()
134 | )
135 |
136 | with torch.inference_mode():
137 | output_ids = model.generate(
138 | input_ids,
139 | images=images_tensor,
140 | image_sizes=image_sizes,
141 | do_sample=True if args.temperature > 0 else False,
142 | temperature=args.temperature,
143 | top_p=args.top_p,
144 | num_beams=args.num_beams,
145 | max_new_tokens=args.max_new_tokens,
146 | use_cache=True,
147 | )
148 |
149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150 | return outputs
151 |
152 |
153 | count = 0
154 | right_count = 0
155 |
156 | with open(f'{FILE_PATH}all_en_test_1076_sort.jsonl', 'r', encoding="utf-8") as f, open(
157 | f'{RESULT_FILE_PATH}llava_lora_4_5.jsonl', 'w+', encoding="utf-8") as fout:
158 | for line in f:
159 | data = json.loads(line)
160 | question = data['question']
161 | id = data['id']
162 | options = data['options']
163 | image_name = data['image']
164 | image_filepath = IMAGE_DIR + image_name
165 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: Image: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
166 | output = eval_model(args, question, image_filepath)
167 | count += 1
168 | if len(output) == 0:
169 | output = '--'
170 | if output.lower() in data['answer']:
171 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']}
172 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
173 | right_count += 1
174 | elif data['answer'] in output.lower():
175 | result_json = {'id': id, 'result': 1, "output": output.lower(), "answer": data['answer']}
176 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
177 | right_count += 1
178 | else:
179 | result_json = {'id': id, 'result': 0, "output": output.lower(), "answer": data['answer']}
180 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
181 | print(f'{output.lower()}')
182 | print(f"{data['answer']}")
183 | print(f'right_count: {right_count}')
184 | print(f'count: {count}')
185 | print(f'accuracy: {right_count / count}')
186 |
187 | accuracy = right_count / count
188 | print(f'accuracy: {accuracy}')
189 |
--------------------------------------------------------------------------------
/Code/eval/calculate_prf1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
4 | import warnings
5 |
6 | warnings.filterwarnings("ignore")
7 |
8 | all_correct = 0
9 |
10 |
11 | def load_jsonl(file_path):
12 | data = []
13 | with open(file_path, 'r', encoding='utf-8') as f:
14 | for line in f:
15 | tmp = json.loads(line.strip())
16 | if tmp['result'] == 1:
17 | tmp['output'] = tmp['answer']
18 | data.append(tmp)
19 | return data
20 |
21 |
22 | def main():
23 | file_list = os.listdir("./model_exp_results/calculate_PRF1/")
24 | # file_list = os.listdir("./human_results_en/")
25 |
26 | # for test_file_path_item in file_list:
27 | # test_file_path = "./model_exp_results/" + test_file_path_item
28 | #
29 | # data = load_jsonl(test_file_path)
30 | # predictions = []
31 | # for item in data:
32 | # predictions.append(item['result'])
33 | # answers = []
34 | # for item in data:
35 | # answers.append(item['answer'])
36 | #
37 | # accuracy = accuracy_score(answers, predictions)
38 | #
39 | # # Compute precision
40 | # precision = precision_score(answers, predictions, average='weighted')
41 | #
42 | # # Compute recall
43 | # recall = recall_score(answers, predictions, average='weighted')
44 | #
45 | # # Compute F1 score
46 | # f1 = f1_score(answers, predictions, average='weighted')
47 | # print("------------------------------------" + test_file_path_item + "-----------------------------------")
48 | # accuracy = "{:,.2f}".format(accuracy)
49 | # precision = "{:,.2f}".format(precision)
50 | # recall = "{:,.2f}".format(recall)
51 | # f1 = "{:,.2f}".format(f1)
52 | #
53 | # print(f"accuracy: {accuracy}")
54 | # print(f"Precision: {precision}")
55 | # print(f"Recall: {recall}")
56 | # print(f"F1 Score: {f1}")
57 |
58 | for test_file_path_item in file_list:
59 | test_file_path = "./model_exp_results/calculate_PRF1/" + test_file_path_item
60 | # test_file_path = "./human_results_en/" + test_file_path_item
61 |
62 | print("------------------------------------" + test_file_path_item + "-----------------------------------")
63 |
64 | data_list = load_jsonl(test_file_path)
65 | print(len(data_list))
66 | on_pre_num, below_pre_num, left_pre_num, right_pre_num, front_pre_num, behind_pre_num, \
67 | on_label_num, below_label_num, left_label_num, right_label_num, front_label_num, behind_label_num, \
68 | on_true_num, below_true_num, left_true_num, right_true_num, front_true_num, behind_true_num = \
69 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
70 |
71 | for data in data_list:
72 | if data['output'] in "on/above":
73 | on_pre_num += 1
74 | elif data['output'] in "below":
75 | below_pre_num += 1
76 | elif data['output'] in "left of":
77 | left_pre_num += 1
78 | elif data['output'] in "right of":
79 | right_pre_num += 1
80 | elif data['output'] in "in front of":
81 | front_pre_num += 1
82 | elif data['output'] in "behind":
83 | behind_pre_num += 1
84 |
85 | if data['answer'] == "on/above":
86 | on_label_num += 1
87 | elif data['answer'] == "below":
88 | below_label_num += 1
89 | elif data['answer'] == "left of":
90 | left_label_num += 1
91 | elif data['answer'] == "right of":
92 | right_label_num += 1
93 | elif data['answer'] == "in front of":
94 | front_label_num += 1
95 | elif data['answer'] == "behind":
96 | behind_label_num += 1
97 |
98 | if data['answer'] == data['output']:
99 | if data['answer'] == "on/above":
100 | on_true_num += 1
101 | elif data['answer'] == "below":
102 | below_true_num += 1
103 | elif data['answer'] == "left of":
104 | left_true_num += 1
105 | elif data['answer'] == "right of":
106 | right_true_num += 1
107 | elif data['answer'] == "in front of":
108 | front_true_num += 1
109 | elif data['answer'] == "behind":
110 | behind_true_num += 1
111 | if on_pre_num == 0:
112 | on_pre_num = 1
113 | if below_pre_num == 0:
114 | below_pre_num = 1
115 | if left_pre_num == 0:
116 | left_pre_num = 1
117 | if right_pre_num == 0:
118 | right_pre_num = 1
119 | if front_pre_num == 0:
120 | front_pre_num = 1
121 | if behind_pre_num == 0:
122 | behind_pre_num = 1
123 |
124 | p = (on_true_num / on_pre_num + below_true_num / below_pre_num + left_true_num / left_pre_num
125 | + right_true_num / right_pre_num + front_true_num / front_pre_num
126 | + behind_true_num / behind_pre_num) / 6
127 | print(on_true_num, on_pre_num, below_true_num, below_pre_num, left_true_num, left_pre_num,
128 | right_true_num, right_pre_num, front_true_num, front_pre_num,
129 | behind_true_num, behind_pre_num)
130 | r = (on_true_num / on_label_num + below_true_num / below_label_num + left_true_num / left_label_num
131 | + right_true_num / right_label_num + front_true_num / front_label_num +
132 | behind_true_num / behind_label_num) / 6
133 | f1 = 2 * p * r / (p + r)
134 | acc = (on_true_num + below_true_num + left_true_num + right_true_num + front_true_num +
135 | behind_true_num) / (on_label_num + below_label_num + left_label_num + right_label_num
136 | + front_label_num + behind_label_num)
137 |
138 | # p = ((on_true_num + below_true_num) / (on_pre_num + below_pre_num) + (left_true_num + right_true_num)
139 | # / (left_pre_num + right_pre_num) + (front_true_num + behind_true_num) / (front_pre_num +
140 | # behind_pre_num)) / 3
141 | # r = ((on_true_num + below_true_num) / (on_label_num + below_label_num) + (left_true_num + right_true_num)
142 | # / (left_label_num + right_label_num) + (front_true_num + behind_true_num) / (front_label_num +
143 | # behind_label_num)) / 3
144 | # f1 = 2 * p * r / (p + r)
145 | # acc = (on_true_num + below_true_num + left_true_num + right_true_num + front_true_num +
146 | # behind_true_num) / (on_label_num + below_label_num + left_label_num + right_label_num
147 | # + front_label_num + behind_label_num)
148 |
149 | print(f"accuracy: {acc}")
150 | print(f"Precision: {p}")
151 | print(f"Recall: {r}")
152 | print(f"F1 Score: {f1}")
153 |
154 |
155 | if __name__ == "__main__":
156 | main()
157 |
--------------------------------------------------------------------------------
/Code/finetune/blip2-lora.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 | from transformers import BlipProcessor, BlipForQuestionAnswering
5 | from transformers import Blip2Processor, AutoProcessor, Blip2ForConditionalGeneration, \
6 | AutoModelForVisualQuestionAnswering
7 | from datasets import load_dataset
8 | import torch
9 | import torch.nn as nn
10 | from PIL import Image
11 | from torch.utils.data import Dataset, DataLoader
12 | from tqdm import tqdm
13 | import pickle
14 | from peft import LoraConfig, get_peft_model
15 |
16 | model_dir = '/projects/Models/'
17 | processor = Blip2Processor.from_pretrained(f"{model_dir}blip2-opt-2.7b")
18 | image_dir = '/projects/SpatialMQA/COCO2017/test2017/'
19 |
20 | train_ds = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip2_train/train_3780.jsonl",
21 | split="train[:100%]")
22 | eval_ds = load_dataset("json", data_files="/projects/SpatialMQA/datasets/blip2_train/dev_536.jsonl",
23 | split="train[:100%]")
24 | print("Training sets: {} - Validating set: {}".format(len(train_ds), len(eval_ds)))
25 |
26 | device = "cuda:7" if torch.cuda.is_available() else "cpu"
27 |
28 | model = Blip2ForConditionalGeneration.from_pretrained(f"{model_dir}blip2-opt-2.7b", device_map="cuda:7",
29 | load_in_8bit=True)
30 |
31 | # Let's define the LoraConfig
32 | config = LoraConfig(
33 | r=16,
34 | lora_alpha=32,
35 | lora_dropout=0.05,
36 | bias="none",
37 | target_modules=["q_proj", "k_proj"]
38 | )
39 |
40 | model = get_peft_model(model, config).to(device)
41 |
42 | torch.cuda.empty_cache()
43 | torch.manual_seed(42)
44 |
45 |
46 | class ImageCaptioningDataset(Dataset):
47 | def __init__(self, dataset, processor):
48 | self.dataset = dataset
49 | self.processor = processor
50 |
51 | def __len__(self):
52 | return len(self.dataset)
53 |
54 | def __getitem__(self, idx):
55 | # get image + text
56 | question = self.dataset[idx]['question']
57 | answer = str(self.dataset[idx]['answer'])
58 | image_id = self.dataset[idx]['image']
59 | image_path = f"/projects/SpatialMQA/COCO2017/test2017/{image_id}"
60 | image = Image.open(image_path).convert("RGB")
61 |
62 | encoding = self.processor(images=image, text=question, return_tensors="pt")
63 |
64 | labels = self.processor.tokenizer.tokenize(answer, return_tensors='pt')
65 | labels = torch.tensor(self.processor.tokenizer.convert_tokens_to_ids(labels)).unsqueeze(0)
66 | encoding["labels"] = torch.cat((labels, torch.tensor([50118]).unsqueeze(0)), dim=1)
67 |
68 | # remove batch dimension
69 | for k, v in encoding.items(): encoding[k] = v.squeeze()
70 | return encoding
71 |
72 |
73 | train_dataset = ImageCaptioningDataset(dataset=train_ds, processor=processor)
74 | valid_dataset = ImageCaptioningDataset(dataset=eval_ds, processor=processor)
75 |
76 | batch_size = 8
77 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
78 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
79 |
80 | optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
81 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1, verbose=False)
82 |
83 | cal_num = 2
84 | num_epochs = 30
85 | patience = 5
86 | min_eval_loss = float("inf")
87 | early_stopping_hook = 0
88 | tracking_information = []
89 | scaler = torch.cuda.amp.GradScaler()
90 | criterion = nn.CrossEntropyLoss(ignore_index=1)
91 |
92 | for epoch in range(num_epochs):
93 | epoch_loss = 0
94 | cal_loss = 0
95 | model.train()
96 | for idx, batch in zip(tqdm(range(len(train_dataloader)), desc='Training batch: ...'), train_dataloader):
97 | input_ids = batch.pop('input_ids').to(device)
98 | pixel_values = batch.pop('pixel_values').to(device)
99 | attention_mask = batch.pop('attention_mask').to(device)
100 | labels = batch.pop('labels').to(device)
101 |
102 | with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
103 | outputs = model(input_ids=input_ids,
104 | pixel_values=pixel_values,
105 | attention_mask=attention_mask).logits
106 |
107 | loss = criterion(outputs.view(-1, outputs.shape[-1])[:labels.shape[1], :].contiguous(),
108 | labels.view(-1).contiguous())
109 | epoch_loss += loss.item()
110 | optimizer.zero_grad()
111 |
112 | cal_loss += loss
113 | if (idx + 1) % cal_num == 0 or idx == len(train_dataloader) - 1:
114 | if (idx + 1) % cal_num == 0:
115 | cal_loss = cal_loss / cal_num
116 | else:
117 | cal_loss = cal_loss / ((idx + 1) % cal_num)
118 | print('loss:', cal_loss)
119 | scaler.scale(cal_loss).backward()
120 | scaler.step(optimizer)
121 | scaler.update()
122 | cal_loss = 0
123 |
124 | model.eval()
125 | eval_loss = 0
126 | for idx, batch in zip(tqdm(range(len(valid_dataloader)), desc='Validating batch: ...'), valid_dataloader):
127 | input_ids = batch.pop('input_ids').to(device)
128 | pixel_values = batch.pop('pixel_values').to(device)
129 | attention_mask = batch.pop('attention_mask').to(device)
130 | labels = batch.pop('labels').to(device)
131 | # print('labels:',labels)
132 | # outputs = model(input_ids=input_ids,
133 | # pixel_values=pixel_values,
134 | # attention_mask=attention_mask,
135 | # labels=labels)
136 | # loss = outputs.loss
137 |
138 | # labels_mask = batch.pop('labels_mask').to(device)
139 | outputs = model(input_ids=input_ids,
140 | pixel_values=pixel_values,
141 | attention_mask=attention_mask).logits
142 | # loss = criterion(outputs.view(-1, outputs.shape[-1])[:8, :].contiguous() * labels_mask.view(-1, 1).contiguous(), labels.view(-1).contiguous())
143 | loss = criterion(outputs.view(-1, outputs.shape[-1])[:labels.shape[1], :].contiguous(),
144 | labels.view(-1).contiguous())
145 |
146 | eval_loss += loss.item()
147 | # break
148 |
149 | # break
150 |
151 | tracking_information.append(
152 | (epoch_loss / len(train_dataloader), eval_loss / len(valid_dataloader), optimizer.param_groups[0]["lr"]))
153 | print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch + 1, epoch_loss / len(train_dataloader),
154 | eval_loss / len(valid_dataloader),
155 | optimizer.param_groups[0]["lr"]))
156 | scheduler.step()
157 | if eval_loss < min_eval_loss:
158 | model.save_pretrained(f"/projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531/epoch-{epoch}",
159 | from_pt=True)
160 | print(f"Saved model to /projects/SpatialMQA/finetune_models/models_arg/blip2_lora_20240531")
161 | min_eval_loss = eval_loss
162 | early_stopping_hook = 0
163 | else:
164 | early_stopping_hook += 1
165 | if early_stopping_hook > patience:
166 | break
167 |
168 | pickle.dump(tracking_information, open("tracking_information.pkl", "wb"))
169 | print("The finetuning process has done!")
170 |
--------------------------------------------------------------------------------
/Code/experiment/spacellava_lora_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
3 |
4 | import argparse
5 | import torch
6 | import json
7 | from llava.constants import (
8 | IMAGE_TOKEN_INDEX,
9 | DEFAULT_IMAGE_TOKEN,
10 | DEFAULT_IM_START_TOKEN,
11 | DEFAULT_IM_END_TOKEN,
12 | IMAGE_PLACEHOLDER,
13 | )
14 | from llava.conversation import conv_templates, SeparatorStyle
15 | from llava.model.builder import load_pretrained_model
16 | from llava.utils import disable_torch_init
17 | from llava.mm_utils import (
18 | process_images,
19 | tokenizer_image_token,
20 | get_model_name_from_path,
21 | )
22 |
23 | from PIL import Image
24 |
25 | import requests
26 | from PIL import Image
27 | from io import BytesIO
28 | import re
29 | RESULT_FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/results/20240816/1000_valid.jsonl'
30 | # FILE_PATH = '/projects/SpatialMethod/Dataset/En_new/'
31 | FILE_PATH = '/home/zouyinan/SpatialVLM/SpaceLLaVa/valid.jsonl'
32 | IMAGE_DIR = '/home/zouyinan/SpatialVLM/SpaceLLaVa/pic/'
33 | def image_parser(args):
34 | out = args.image_file.split(args.sep)
35 | return out
36 |
37 |
38 | def load_image(image_file):
39 | if image_file.startswith("http") or image_file.startswith("https"):
40 | response = requests.get(image_file)
41 | image = Image.open(BytesIO(response.content)).convert("RGB")
42 | else:
43 | image = Image.open(image_file).convert("RGB")
44 | return image
45 |
46 |
47 | def load_images(image_files):
48 | out = []
49 | for image_file in image_files:
50 | image = load_image(image_file)
51 | out.append(image)
52 | return out
53 |
54 |
55 | # Model
56 | disable_torch_init()
57 | model_path='/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/spacellava_hf'
58 | args = type('Args', (), {
59 | "model_path": model_path,
60 | "model_base": None,
61 | "model_name": get_model_name_from_path(model_path),
62 | # "query": prompt,
63 | "conv_mode": None,
64 | # "image_file": image_file,
65 | "sep": ",",
66 | "temperature": 0.1,
67 | "top_p": None,
68 | "num_beams": 1,
69 | "max_new_tokens": 512
70 | })()
71 |
72 | model_name = get_model_name_from_path(model_path)
73 | tokenizer, model, image_processor, context_len = load_pretrained_model(
74 | args.model_path, None, model_name
75 | )
76 | # peft_model_id = "/projects/SpatialMQA/finetune_models/models_arg/llava_lora_20240522"
77 | # peft_model_id = '/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240812_3/checkpoint-460'
78 | peft_model_id = "/home/zouyinan/SpatialVLM/SpaceLLaVa/spacellava/saved_model/spacellava_lora_20240816_1000"
79 | model.load_adapter(peft_model_id)
80 |
81 | def eval_model(args,question,image_file):
82 | qs = question
83 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
84 | if IMAGE_PLACEHOLDER in qs:
85 | if model.config.mm_use_im_start_end:
86 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
87 | else:
88 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
89 | else:
90 | if model.config.mm_use_im_start_end:
91 | qs = image_token_se + "\n" + qs
92 | else:
93 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
94 |
95 | if "llama-2" in model_name.lower():
96 | conv_mode = "llava_llama_2"
97 | elif "mistral" in model_name.lower():
98 | conv_mode = "mistral_instruct"
99 | elif "v1.6-34b" in model_name.lower():
100 | conv_mode = "chatml_direct"
101 | elif "v1" in model_name.lower():
102 | conv_mode = "llava_v1"
103 | elif "mpt" in model_name.lower():
104 | conv_mode = "mpt"
105 | else:
106 | conv_mode = "llava_v0"
107 |
108 | if args.conv_mode is not None and conv_mode != args.conv_mode:
109 | print(
110 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
111 | conv_mode, args.conv_mode, args.conv_mode
112 | )
113 | )
114 | else:
115 | args.conv_mode = conv_mode
116 |
117 | conv = conv_templates[args.conv_mode].copy()
118 | conv.append_message(conv.roles[0], qs)
119 | conv.append_message(conv.roles[1], None)
120 | prompt = conv.get_prompt()
121 | image_files = [image_file]
122 | images = load_images(image_files)
123 | image_sizes = [x.size for x in images]
124 | images_tensor = process_images(
125 | images,
126 | image_processor,
127 | model.config
128 | ).to(model.device, dtype=torch.float16)
129 |
130 | input_ids = (
131 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
132 | .unsqueeze(0)
133 | .cuda()
134 | )
135 |
136 | with torch.inference_mode():
137 | output_ids = model.generate(
138 | input_ids,
139 | images=images_tensor,
140 | image_sizes=image_sizes,
141 | do_sample=True if args.temperature > 0 else False,
142 | temperature=args.temperature,
143 | top_p=args.top_p,
144 | num_beams=args.num_beams,
145 | max_new_tokens=args.max_new_tokens,
146 | use_cache=True,
147 | )
148 |
149 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150 | return outputs
151 | count = 0
152 | right_count = 0
153 |
154 | with open(f'{FILE_PATH}', 'r', encoding="utf-8") as f,open(f'{RESULT_FILE_PATH}', 'w+', encoding="utf-8") as fout:
155 | for line in f:
156 | data = json.loads(line)
157 | question = data['question']
158 | id = data['id']
159 | options = data['options']
160 | image_name = data['image']
161 | image_filepath = IMAGE_DIR + image_name
162 | # question = f'Question: {question} According to the question, please choose one best option from the following options separated by semicolon. Note that you only need to answer one word from the options without explaining the reason.\nOptions: {"; ".join(options)} \nAnswer:'
163 | question = f'You are currently a senior expert in spatial relation reasoning. \n Given an Image, a Question, and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason. \n Input: , Question: {question}, Options: {"; ".join(options)}. \n Output:'
164 | output = eval_model(args,question,image_filepath)
165 |
166 | count += 1
167 | if len(output) == 0:
168 | output = '--'
169 | if output.lower() in data['answer']:
170 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
171 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
172 | right_count += 1
173 | elif data['answer'] in output.lower():
174 | result_json = {'id':id,'result':1,"output":output.lower(),"answer":data['answer']}
175 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
176 | right_count += 1
177 | else:
178 | result_json = {'id':id,'result':0,"output":output.lower(),"answer":data['answer']}
179 | fout.write(json.dumps(result_json,ensure_ascii=False)+'\n')
180 | print(f'{output.lower()}')
181 | print(f"{data['answer']}")
182 | print(f'right_count: {right_count}')
183 | print(f'count: {count}')
184 | print(f'accuracy: {right_count/count}')
185 | # print(f'accuracy: {right_count/count}')
186 |
187 | accuracy = right_count/count
188 | print(f'accuracy: {accuracy}')
--------------------------------------------------------------------------------
/Code/close_models/gemini_1_shot_random.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 | import PIL.Image as Image
3 | import google.generativeai as genai
4 | from io import BytesIO
5 | import requests
6 | import random
7 | import json
8 |
9 | genai.configure(api_key="your key",transport="rest")
10 |
11 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480}
12 | safety_settings = [
13 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
14 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
15 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
16 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
17 | ]
18 |
19 | FILE_PATH = './datasets/test_en_select_500_sort1.jsonl'
20 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
21 | RESULT_FILE_PATH = './model_results/gemini_1_shot_random.jsonl'
22 |
23 |
24 | example_list_rule1 = [
25 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."},
26 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."},
27 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."}
28 | ]
29 |
30 | example_list_rule2 = [
31 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."},
32 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."},
33 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."}
34 | ]
35 |
36 | example_list_rule3 = [
37 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."},
38 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."},
39 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."}
40 | ]
41 |
42 |
43 | example_list = [
44 | random.choice(example_list_rule1),
45 | random.choice(example_list_rule2),
46 | random.choice(example_list_rule3)
47 | ]
48 |
49 |
50 | def fetch_image_content(image_url):
51 | response = requests.get(image_url)
52 | if response.status_code == 200:
53 | return BytesIO(response.content)
54 | else:
55 | return None
56 |
57 |
58 | def choose_1_example():
59 | example_1 = random.choice(example_list_rule1 + example_list_rule2 + example_list_rule3)  # draw one example from the pool of all three rules
60 | return example_1
61 |
62 |
63 | # few-shot & zero-shot:
64 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
65 |
66 | # text-only:
67 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings)
68 |
69 |
70 | count = 0
71 | right_count = 0
72 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout:
73 | for line in f:
74 | sleep(1)
75 | data = json.loads(line)
76 | id = data['id']
77 |
78 | example_1 = choose_1_example()
79 |
80 | # 1 - shot:
81 | # question = f'You are currently a senior expert in spatial relation reasoning. ' \
82 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
83 | # f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \
84 | # f'\n{example_1} ' \
85 | # f'Input: Image:, Question: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
86 | prompt1 = f'You are currently a senior expert in spatial relation reasoning. ' \
87 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
88 | f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \
89 | f'\nInput: Image: '
90 | prompt2 = f'\n{example_1["text"]} ' \
91 | f'\nInput: Image: '
92 | prompt3 = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
93 |
94 | image_eg1_id = example_1["image"]
95 | image_eg1_url = f'{IMAGE_DIR}/{image_eg1_id}'
96 | image_eg1_content = fetch_image_content(image_eg1_url)
97 |
98 | image_id = data['image']
99 | image_url = f'{IMAGE_DIR}/{image_id}'
100 | image_content = fetch_image_content(image_url)
101 |
102 | if (image_eg1_content is not None) and (image_content is not None):
103 | image_eg1 = Image.open(image_eg1_content)
104 | image = Image.open(image_content)
105 | try:
106 | response = model.generate_content(
107 | [prompt1, image_eg1, prompt2, image, prompt3]
108 | )
109 | except Exception as e:
110 | print(e)
111 | try:
112 | response = model.generate_content(
113 | [prompt1, image_eg1, prompt2, image, prompt3]
114 | )
115 | except Exception as e:
116 | try:
117 | response = model.generate_content(
118 | [prompt1, image_eg1, prompt2, image, prompt3]
119 | )
120 | except Exception as e:
121 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1}
122 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
123 | count += 1
124 | continue
125 | try:
126 | output = response.text.strip().rstrip('.').lower()
127 | except Exception as e:
128 | print(e)
129 | output = '--'
130 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1}
131 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
132 | count += 1
133 | continue
134 | count += 1
135 | if output in data['answer'] or data['answer'] in output:
136 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1}
137 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
138 | right_count += 1
139 | else:
140 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 3 + 1}
141 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
142 | print(f'{output.lower()}')
143 | print(f"{data['answer']}")
144 | print(f'right_count: {right_count}')
145 | print(f'count: {count}')
146 | # print(f'accuracy: {right_count / count}')
147 |
148 | accuracy = right_count / count
149 | print(accuracy)
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🔭 Can Multimodal Large Language Models Understand Spatial Relations
2 | SpatialMQA: A new benchmark dataset for spatial reasoning of MLLMs.
3 |
4 | huggingface · github · license
5 |
20 | ## Contents
21 |
22 | - [SpatialMQA](#Contents)
23 | - [Overview](#1-Overview)
24 | - [Examples](#Examples)
25 | - [Detail Information](#Detail-Information)
26 | - [Access SpatialMQA](#2-Access-SpatialMQA)
27 | - [Download Images](#Download-Images)
28 | - [Data Split](#Data-Split)
29 | - [Data Format](#Data-Format)
30 | - [Experiment & Evaluation](#3-Experiment-and-Evaluation)
31 | - [Experiment](#Experiment)
32 | - [Evaluation](#Evaluation)
33 | - [License](#4-License)
34 |
35 |
36 |
37 |
38 | ## 1 Overview
39 | **SpatialMQA** is a **manually annotated** dataset designed for **multimodal spatial relation reasoning** in a **m**ultiple-choice **q**uestion & **a**nswer format.
40 | The dataset includes 5,392 samples collected from COCO2017, covering 128 subject and object types, without bounding boxes. To address the limitations of existing datasets, we clearly define annotation guidelines for SpatialMQA, including standardizing the objective world as the coordinate system and avoiding questions that can be answered solely by the question itself.
41 |
42 | ### Examples
43 | The following figures show some classic examples from our dataset. You can click on [`Examples:1~4/`](Examples/examples_1-4.png) and [`Examples:5~8/`](Examples/examples_5-8.png) to view the details.
44 |
45 | ### Detail Information
46 | The following table [`Splits/`](Comparison/splits.png) lists detailed statistics of the split dataset.
47 |
48 | You can find our dataset under the path **_(Dataset/dataset)_** for more details.
49 |
50 | _Because anonymous links only support redirecting to a specified file, not to a directory, we mark all referenced directories in bold italics to make them easier for reviewers to locate. Thank you!_
51 |
52 |
53 | ## 2 Access SpatialMQA
54 | Our dataset has been officially released on Hugging Face. It is available at https://huggingface.co/datasets/liuziyan/SpatialMQA.
55 | Alternatively, you can download it from GitHub by following the steps below:
56 | ### Download Images
57 | We use a subset of COCO-2017's images. The following script downloads COCO-2017's test set images and puts them into a single folder `Dataset/COCO2017/`.
58 |
59 | ```bash
60 | cd Dataset/
61 | wget http://images.cocodataset.org/zips/test2017.zip
62 | unzip test2017.zip
63 | mv test2017 COCO2017 && rm test2017.zip
64 | ```
65 | Copy only relevant images to `relevant_images/`.
66 | ```bash
67 | mkdir relevant_images
68 | cd tool
69 | python select_relevant_images.py
70 | ```
71 | Alternatively, you can also browse individual images online directly using the "image" key of each JSON record.
72 | (Through COCO's public link, 'http://images.cocodataset.org/test2017/' + 'image_name'. For example: http://images.cocodataset.org/test2017/000000195921.jpg.)
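
For example, a single image can be fetched and opened in Python roughly as follows (a minimal sketch that mirrors the `load_image` helper used in the experiment scripts; the image name below is just an example):

```python
import requests
from io import BytesIO
from PIL import Image

# Build the URL from the "image" key of a sample, e.g. "000000195921.jpg".
image_name = "000000195921.jpg"
url = "http://images.cocodataset.org/test2017/" + image_name

# Download the image and open it as an RGB PIL image.
image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
```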
73 |
74 | ### Data Split
75 | As reported in the following table, SpatialMQA contains 5,392 samples, divided into training, validation, and test sets according to a 7:1:2 ratio.
76 | All the split data sets are in the directory **_(Dataset/dataset)_**.
77 | In addition, we have selected the part of the dataset that contains invisible subjects or objects and placed it in the file [`invisible/`](Dataset/invisible/invisible.jsonl) to facilitate readers' research.
78 |
79 | ### Data Format
80 | Each `jsonl` file is of the following format:
81 | ```json
82 | {"image": "000000000933.jpg", "question": "Where is the fork located relative to the pizza?", "options": ["on/above", "below", "in front of", "behind", "left of", "right of"], "answer": "right of"}
83 | {"image": "000000100633.jpg", "question": "If you are the cyclist in the image, where is the dog located relative to you?", "options": ["in front of", "behind", "left of", "right of"], "answer": "behind"}
84 | {"image": "000000070986.jpg", "question": "If you are the driver of the bus in the image, from your perspective, where is the red car located relative to the bus?", "options": ["in front of", "behind", "left of", "right of"], "answer": "left of"}
85 | {"..."}
86 | ```
87 | Each line is an individual data point.
88 | `image` denotes the name of the image in COCO. `question` is the manually annotated question, `options` is a reasonable combination of the six spatial relations (on/above, below, in front of, behind, left of, right of), and `answer` is the annotation based on the objective world.
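
For reference, a minimal loader for these `jsonl` files might look like the sketch below (the file name is a placeholder for one of the split files under **_(Dataset/dataset)_**):

```python
import json

# Placeholder path: substitute the actual split file you want to read.
with open("Dataset/dataset/train.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        sample = json.loads(line)
        image_name = sample["image"]    # e.g. "000000000933.jpg"
        question = sample["question"]   # manually annotated question
        options = sample["options"]     # a subset of the six spatial relations
        answer = sample["answer"]       # gold label based on the objective world
```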
89 |
90 | Our dataset is built by expanding on the categories included in the COCO dataset. There are 113 subject types plus one additional type for subjects with five or fewer samples, and 84 object types plus one additional type for objects with five or fewer samples. Due to the overlap between subject and object types, we have a total of 128 distinct subject and object types. You can see all of them in the file [`S & O types/`](Dataset/types/types.txt).
91 |
92 |
93 | ## 3 Experiment and Evaluation
94 | ### Experiment
95 | We have released the inference code for the models in the directory **_(Code/experiment)_**, as well as the fine-tuning code in the directory **_(Code/finetune)_**.
96 |
97 | - For all 7 open-source MLLMs, you can directly execute the Python files in the directory **_(Code/experiment)_** to perform inference with the models before and after fine-tuning:
98 | ```
99 | nohup python blip-vqa-base.py > log/blip_exp.log 2>&1 &
100 | nohup python blip-vqa-base_finetuned.py > log/blip_finetuned_exp.log 2>&1 &
101 | nohup python blip2-opt-2.7b.py > log/blip2_exp.log 2>&1 &
102 | nohup python blip2-lora.py > log/blip2_lora_exp.log 2>&1 &
103 | nohup python instructblip-flan-t5-xl.py > log/instructblip_exp.log 2>&1 &
104 | nohup python instructblip-lora.py > log/instructblip_lora_exp.log 2>&1 &
105 | nohup python idefics_new.py > log/idefics_exp.log 2>&1 &
106 | nohup python idefics_lora.py > log/idefics_lora_exp.log 2>&1 &
107 | nohup python spatial_test_llava.py > log/llava_exp.log 2>&1 &
108 | nohup python spatial_test_llava_lora.py > log/llava_lora_exp.log 2>&1 &
109 | nohup python spatial_test_mplug.py > log/mplug_exp.log 2>&1 &
110 | nohup python spatial_test_mplug_lora.py > log/mplug_lora_exp.log 2>&1 &
111 | nohup python spacellava_test.py > log/spacellava_exp.log 2>&1 &
112 | nohup python spacellava_lora_test.py > log/spacellava_lora_exp.log 2>&1 &
113 | ```
114 | Because the open-source model weights and code are large, we do not include them here; you need to download them yourself or load them directly from platforms such as [Hugging Face](https://huggingface.co).
115 |
116 | - For blip, blip2, instructblip and idefics, you can directly execute the Python files in the directory **_(Code/finetune)_** to perform fine-tuning:
117 | ```
118 | nohup python blip-vqa-base.py > log/blip_train.log 2>&1 &
119 | nohup python blip2-lora.py > log/blip2_train.log 2>&1 &
120 | nohup python instructblip-lora.py > log/instructblip_train.log 2>&1 &
121 | nohup python idefics.py > log/idefics_train.log 2>&1 &
122 | ```
123 | - For llava, spacellava and mplug-owl, you need to execute the bash files in the directory **_(Code/finetune)_** to perform fine-tuning:
124 | ```
125 | nohup bash llava_lora_train.sh > log/llava_train.log 2>&1 &
126 | nohup bash spacellava_lora_train.sh > log/spacellava_train.log 2>&1 &
127 | nohup bash mPLUG_Owl_train_it.sh > log/mplug_train.log 2>&1 &
128 | ```
129 | - For gemini-1.5-flash and gpt-4o, you can directly execute our Python files in the directory **_(Code/close_models)_** to run zero-shot, few-shot and text-only inference, provided that you have prepared an API key:
130 | ```
131 | python gemini_text_only.py
132 | python gemini_0_shot.py
133 | python gemini_1_shot_random.py
134 | python gemini_2_shot_random.py
135 | python gemini_3_shot_random.py
136 | python gpt4_text_only.py
137 | python gpt4_zero_shot.py
138 | python gpt4_1_shot.py
139 | python gpt4_2_shot.py
140 | python gpt4_3_shot.py
141 | ```
142 | A Gemini API key can be applied for on the [official website](https://aistudio.google.com/app/apikey), and GPT-4 access needs to be purchased from the [official website](https://openai.com/).
143 |
144 | ### Evaluation
145 | You can process the model inference results with the code we provide to calculate overall accuracy, overall precision/recall/F1, accuracy by relation category, and accuracy by rule. The calculations are integrated into the Python files in the directory **_(Code/eval)_**:
146 | ```
147 | python calculate_prf1.py
148 | python calculate_xyz.py
149 | python calculate_result_rule.py
150 | ```
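The inference scripts write one JSON object per test sample (with fields such as `id`, `result`, `output`, `answer` and `rule`). As a minimal sketch under that assumption, overall and per-rule accuracy can also be computed directly from such a result file:
```python
import json

# Minimal sketch, assuming the result-file format produced by the inference
# scripts: one JSON object per line with at least the fields
# "result" (1 = correct, 0 = wrong) and "rule".
def accuracy_by_rule(result_path):
    total, correct = 0, 0
    per_rule = {}
    with open(result_path, "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            total += 1
            correct += record["result"]
            seen, right = per_rule.get(record["rule"], (0, 0))
            per_rule[record["rule"]] = (seen + 1, right + record["result"])
    if total == 0:
        raise ValueError(f"no records found in {result_path}")
    print(f"overall accuracy: {correct / total:.4f}")
    for rule, (seen, right) in sorted(per_rule.items()):
        print(f"rule {rule}: {right / seen:.4f}")

# Hypothetical usage with one of the result files written by the scripts:
# accuracy_by_rule("results/gemini_2_shot_random_1.jsonl")
```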
151 |
152 | ### Requirements
153 | The environment configuration required to run the code is placed in the directory **_(Code/requirement)_**.
154 |
155 | The requirements for the blip, blip2 and instructblip models are all in the file [`requirement_blip.txt`](Code/requirement/requirement_blip.txt).
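For example, an environment for the blip family could be set up as follows (the environment name and Python version are placeholders):
```bash
conda create -n spatialmqa python=3.10 -y   # Python version is an assumption
conda activate spatialmqa
pip install -r Code/requirement/requirement_blip.txt
```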
156 |
157 | ## 4 License
158 | This project is licensed under the [Apache-2.0 License](LICENSE).
159 |
--------------------------------------------------------------------------------
/Code/close_models/gemini_2_shot_random.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 |
3 | import google.generativeai as genai
4 | from io import BytesIO
5 | import requests
6 | import random
7 |
8 | genai.configure(api_key="your key", transport="rest")
9 |
10 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480}
11 | safety_settings = [
12 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
13 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
14 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
15 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
16 | ]
17 |
18 | FILE_PATH = 'test_en_select_500_sort_cp.jsonl'
19 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
20 | RESULT_FILE_PATH = 'results/gemini_2_shot_random_1.jsonl'
21 |
22 |
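# Few-shot exemplars grouped by question rule (rule 1, 2 and 3). Pairs of them
# are combined into six fixed 2-shot combinations below and cycled with count % 6.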
23 | example_list_rule1 = [
24 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."},
25 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."},
26 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."}
27 | ]
28 |
29 | example_list_rule2 = [
30 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."},
31 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."},
32 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."}
33 | ]
34 |
35 | example_list_rule3 = [
36 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."},
37 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."},
38 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."}
39 | ]
40 |
41 |
42 | def pop_exist(data, datalist):
43 | data_list = datalist.copy()
44 | data_list.remove(data)
45 | return random.choice(data_list)
46 |
47 |
48 | E11 = random.choice(example_list_rule1)
49 | E12 = pop_exist(E11, example_list_rule1)
50 | E21 = random.choice(example_list_rule1)
51 | E22 = random.choice(example_list_rule2)
52 | E31 = random.choice(example_list_rule1)
53 | E32 = random.choice(example_list_rule3)
54 | E41 = random.choice(example_list_rule2)
55 | E42 = pop_exist(E41, example_list_rule2)
56 | E51 = random.choice(example_list_rule2)
57 | E52 = random.choice(example_list_rule3)
58 | E61 = random.choice(example_list_rule3)
59 | E62 = pop_exist(E61, example_list_rule3)
60 |
61 |
62 | example_list = [
63 | [E11, E12],
64 | [E21, E22],
65 | [E31, E32],
66 | [E41, E42],
67 | [E51, E52],
68 | [E61, E62]
69 | ]
70 |
71 |
72 | def fetch_image_content(image_url):
73 | response = requests.get(image_url)
74 | if response.status_code == 200:
75 | return BytesIO(response.content)
76 | else:
77 | return None
78 |
79 |
80 | import PIL.Image as Image
81 |
82 | # few-shot & zero-shot:
83 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
84 |
85 | # text-only:
86 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings)
87 |
88 | import json
89 |
90 | count = 0
91 | right_count = 0
92 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout:
93 | for line in f:
94 | sleep(1)
95 | data = json.loads(line)
96 | id = data['id']
97 |
98 | example_2 = example_list[count % 6]
99 |
100 | # 2 - shot
101 | prompt = ['', '', '', '']
102 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \
103 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
104 | f'\nGiven the following 2 examples to learn the spatial relation reasoning task:' \
105 | f'\nExample1: Input: Image: '
106 | prompt[1] = f'\n{example_2[0]["text"]} ' \
107 | f'\nExample2: Input: Image: '
108 | prompt[2] = f'\n{example_2[1]["text"]} ' \
109 | f'\nInput: Image: '
110 |
111 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \
112 | # f'\n{example_list[count % 10][0]["text"]} ' \
113 | # f'\nYou are currently a senior expert in spatial relation reasoning. ' \
114 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
115 | # f'\nInput: Image: '
116 | prompt[3] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
117 |
118 | image_list = []
119 | for item in example_2:
120 | image_id = item['image']
121 | image_url = f'{IMAGE_DIR}/{image_id}'
122 | image_content = fetch_image_content(image_url)
123 | if image_content is not None:
124 | image = Image.open(image_content)
125 | image_list.append(image)
126 |
127 | image_id = data['image']
128 | image_url = f'{IMAGE_DIR}/{image_id}'
129 | image_content = fetch_image_content(image_url)
130 |
131 | if image_content is not None:
132 | image = Image.open(image_content)
133 | image_list.append(image)
134 | try:
135 | response = model.generate_content(
136 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot
137 | # [question] # text-only
138 | )
139 | except Exception as e:
140 | print(e)
141 | try:
142 | response = model.generate_content(
143 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot
144 | # [question] # text-only
145 | )
146 | except Exception as e:
147 | try:
148 | response = model.generate_content(
149 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3]] # few-shot & zero-shot
150 | # [question] # text-only
151 | )
152 | except Exception as e:
153 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
154 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
155 | count += 1
156 | continue
157 | try:
158 | output = response.text.strip().rstrip('.').lower()
159 | except Exception as e:
160 | print(e)
161 | output = '--'
162 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
163 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
164 | count += 1
165 | continue
166 | count += 1
167 | if output in data['answer'] or data['answer'] in output:
168 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
169 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
170 | right_count += 1
171 | else:
172 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
173 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
174 | print(f'{output.lower()}')
175 | print(f"{data['answer']}")
176 | print(f'right_count: {right_count}')
177 | print(f'count: {count}')
178 |
179 | accuracy = right_count / count
180 | print(accuracy)
181 |
--------------------------------------------------------------------------------
/Code/close_models/gpt4_1_shot.py:
--------------------------------------------------------------------------------
1 | # import openai
2 | from openai import OpenAI
3 | import json
4 | import requests
5 | from io import BytesIO
6 | import PIL.Image as Image
7 | import base64
8 | import random
9 |
10 | client = OpenAI(
11 | api_key='your key',
12 | base_url='https://api.mnxcc.com/v1'
13 | )
14 |
15 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
16 |
17 | example_list_rule1 = [
18 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."},
19 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."},
20 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."}
21 | ]
22 |
23 | example_list_rule2 = [
24 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."},
25 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."},
26 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."}
27 | ]
28 |
29 | example_list_rule3 = [
30 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."},
31 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."},
32 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."}
33 | ]
34 |
35 |
36 | def choose_1_example():
37 |     example_1 = random.choice(example_list_rule1 + example_list_rule2 + example_list_rule3)
38 | return example_1
39 |
40 |
41 | # count = 0
42 | # right_count = 0
43 |
44 |
45 | def fetch_image_content(image_url):
46 | response = requests.get(image_url)
47 | if response.status_code == 200:
48 | return BytesIO(response.content)
49 | else:
50 | return None
51 |
52 |
53 | def encode_image(image):
54 | if image is None:
55 | return None
56 |
57 | buffered = BytesIO()
58 | try:
59 | image.save(buffered, format="JPEG")
60 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
61 | return 'data:image/jpeg;base64,' + img_str
62 | except Exception as e:
63 | print(f"error: {e}")
64 | return None
65 |
66 |
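# One chat-completion request with interleaved text and image_url parts:
# the task instructions, the exemplar image with its worked answer, then the
# test image and question.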
67 | def call_gpt4(prompt1: str, prompt2: str, prompt3: str, image_eg1, image, detail='auto'):
68 | try:
69 | response = client.chat.completions.create(
70 | # model="gpt-4-vision-preview",gpt-4-turbo
71 | model="gpt-4o",
72 | messages=[
73 | {
74 | "role": "user",
75 | # zero-shot + few-shot
76 | "content": [{"type": "text", "text": prompt1}] \
77 | + [{"type": "image_url", "image_url": { "url": image_eg1, "detail": detail}}] \
78 | + [{"type": "text", "text": prompt2}] \
79 | + [{"type": "image_url", "image_url": { "url": image, "detail": detail}}] \
80 | + [{"type": "text", "text": prompt3}]
81 | # text only:
82 | # "content": [{"type": "text", "text": question}]
83 | }
84 | ],
85 | max_tokens=500,
86 | temperature=0.5,
87 | )
88 | # print(response.choices[0].message.content.strip())
89 | return response.choices[0].message.content.strip()
90 |
91 | except Exception as e:
92 | print(f"Error during answering: {e}")
93 | return None
94 |
95 |
96 | def process_jsonl(input_file, output_file):
97 | count = 0
98 | right_count = 0
99 |
100 | with open(input_file, 'r', encoding='utf-8') as file:
101 | with open(output_file, 'w', encoding='utf-8') as out_file:
102 | for line in file:
103 | data = json.loads(line)
104 | id = data['id']
105 |
106 | example_1 = choose_1_example()
107 |
108 | # 1 - shot
109 | prompt1 = f'You are currently a senior expert in spatial relation reasoning. ' \
110 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
111 | f'\nGiven the following 1 examples to learn the spatial relation reasoning task:' \
112 | f'\nInput: Image: '
113 | prompt2 = f'\n{example_1["text"]} ' \
114 | f'\nInput: Image: '
115 | prompt3 = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
116 |
117 | image_eg1_id = example_1["image"]
118 | image_eg1_url = f'{IMAGE_DIR}/{image_eg1_id}'
119 | image_eg1_content = fetch_image_content(image_eg1_url)
120 |
121 | image_id = data['image']
122 | image_url = f'{IMAGE_DIR}/{image_id}'
123 | image_content = fetch_image_content(image_url)
124 |
125 | if (image_eg1_content is not None) and (image_content is not None):
126 | image_eg1 = Image.open(image_eg1_content)
127 | image_eg1_encoded = encode_image(image_eg1)
128 |
129 | image = Image.open(image_content)
130 | image_encoded = encode_image(image)
131 |
132 | try:
133 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded)
134 | except Exception as e:
135 | print(e)
136 | try:
137 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded)
138 | except Exception as e:
139 | try:
140 | model_answer = call_gpt4(prompt1, prompt2, prompt3, image_eg1_encoded, image_encoded)
141 | except Exception as e:
142 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
143 | "rule": data['rule'], "example": 0}
144 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
145 | count += 1
146 | continue
147 |
148 | try:
149 | # output = model_answer.text.strip().rstrip('.').lower()
150 | output = model_answer.strip().rstrip('.').lower()
151 | except Exception as e:
152 |                         result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'],
153 | "rule": data['rule'], "example": 0}
154 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
155 | count += 1
156 | continue
157 | count += 1
158 | if output in data['answer'] or data['answer'] in output:
159 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'],
160 | "rule": data['rule'], "example": 0}
161 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
162 | right_count += 1
163 | else:
164 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'],
165 | "rule": data['rule'], "example": 0}
166 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
167 | print(f'{output.lower()}')
168 | print(f"{data['answer']}")
169 | print(f'right_count: {right_count}')
170 | print(f'count: {count}')
171 | # print(f'accuracy: {right_count / count}')
172 |
173 | accuracy = right_count / count
174 | print(accuracy)
175 |
176 |
177 | input_file_path = "test_en_select_500_sort_cp.jsonl"
178 | output_file_path = "results/gpt4_1_shot_random_new_cp.jsonl"
179 |
180 | process_jsonl(input_file_path, output_file_path)
181 |
--------------------------------------------------------------------------------
/Code/close_models/gemini_3_shot_random.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 |
3 | import google.generativeai as genai
4 | from io import BytesIO
5 | import requests
6 | import random
7 |
8 | genai.configure(api_key="your key", transport="rest")
9 |
10 | generation_config = {"temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 480}
11 | safety_settings = [
12 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
13 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
14 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
15 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
16 | ]
17 |
18 | FILE_PATH = 'test_en_select_500_sort_cp_2.jsonl'
19 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
20 | RESULT_FILE_PATH = 'results/gemini_3_shot_random_cp.jsonl'
21 |
22 |
23 | example_list_rule1 = [
24 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."},
25 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."},
26 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."}
27 | ]
28 |
29 | example_list_rule2 = [
30 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."},
31 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."},
32 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."}
33 | ]
34 |
35 | example_list_rule3 = [
36 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."},
37 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."},
38 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."}
39 | ]
40 |
41 |
42 | def pop_exist(data, datalist):
43 | data_list = datalist.copy()
44 | data_list.remove(data)
45 | return random.choice(data_list)
46 |
47 |
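# Ten fixed 3-shot exemplar combinations are built below, mixing the three rule
# lists; each test sample picks one combination via count % 10.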
48 | E11 = example_list_rule1[0]
49 | E12 = example_list_rule1[1]
50 | E13 = example_list_rule1[2]
51 | E21 = random.choice(example_list_rule1)
52 | E22 = pop_exist(E21, example_list_rule1)
53 | E23 = random.choice(example_list_rule2)
54 | E31 = random.choice(example_list_rule1)
55 | E32 = pop_exist(E31, example_list_rule1)
56 | E33 = random.choice(example_list_rule3)
57 | E41 = random.choice(example_list_rule1)
58 | E42 = random.choice(example_list_rule2)
59 | E43 = pop_exist(E42, example_list_rule2)
60 | E51 = random.choice(example_list_rule1)
61 | E52 = random.choice(example_list_rule2)
62 | E53 = random.choice(example_list_rule3)
63 | E61 = random.choice(example_list_rule1)
64 | E62 = random.choice(example_list_rule3)
65 | E63 = pop_exist(E62, example_list_rule3)
66 | E71 = example_list_rule2[0]
67 | E72 = example_list_rule2[1]
68 | E73 = example_list_rule2[2]
69 | E81 = random.choice(example_list_rule2)
70 | E82 = pop_exist(E81, example_list_rule2)
71 | E83 = random.choice(example_list_rule3)
72 | E91 = random.choice(example_list_rule2)
73 | E92 = random.choice(example_list_rule3)
74 | E93 = pop_exist(E92, example_list_rule3)
75 | E101 = example_list_rule3[0]
76 | E102 = example_list_rule3[1]
77 | E103 = example_list_rule3[2]
78 |
79 |
80 | example_list = [
81 | [E11, E12, E13],
82 | [E21, E22, E23],
83 | [E31, E32, E33],
84 | [E41, E42, E43],
85 | [E51, E52, E53],
86 | [E61, E62, E63],
87 | [E71, E72, E73],
88 | [E81, E82, E83],
89 | [E91, E92, E93],
90 | [E101, E102, E103]
91 | ]
92 |
93 |
94 | def fetch_image_content(image_url):
95 | response = requests.get(image_url)
96 | if response.status_code == 200:
97 | return BytesIO(response.content)
98 | else:
99 | return None
100 |
101 |
102 | import PIL.Image as Image
103 |
104 | # few-shot & zero-shot:
105 | model = genai.GenerativeModel('gemini-1.5-flash', generation_config=generation_config, safety_settings=safety_settings)
106 |
107 | # text-only:
108 | # model = genai.GenerativeModel('gemini-pro', generation_config=generation_config, safety_settings=safety_settings)
109 |
110 | import json
111 |
112 | count = 0
113 | right_count = 0
114 | with open(FILE_PATH, 'r', encoding="utf-8") as f, open(RESULT_FILE_PATH, 'a', encoding="utf-8") as fout:
115 | for line in f:
116 | sleep(1)
117 | data = json.loads(line)
118 | id = data['id']
119 |
120 | example_3 = example_list[count % 10]
121 |
122 | # 3 - shot
123 | prompt = ['', '', '', '', '']
124 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \
125 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
126 | f'\nGiven the following 3 examples to learn the spatial relation reasoning task:' \
127 | f'\nExample1: Input: Image: '
128 | prompt[1] = f'\n{example_3[0]["text"]} ' \
129 | f'\nExample2: Input: Image: '
130 | prompt[2] = f'\n{example_3[1]["text"]} ' \
131 | f'\nExample3: Input: Image: '
132 | prompt[3] = f'\n{example_3[2]["text"]} ' \
133 | f'\nInput: Image: '
134 |
135 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \
136 | # f'\n{example_list[count % 10][0]["text"]} ' \
137 | # f'\nYou are currently a senior expert in spatial relation reasoning. ' \
138 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
139 | # f'\nInput: Image: '
140 | prompt[4] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
141 |
142 | image_list = []
143 | for item in example_3:
144 | image_id = item['image']
145 | image_url = f'{IMAGE_DIR}/{image_id}'
146 | image_content = fetch_image_content(image_url)
147 | if image_content is not None:
148 | image = Image.open(image_content)
149 | image_list.append(image)
150 |
151 | image_id = data['image']
152 | image_url = f'{IMAGE_DIR}/{image_id}'
153 | image_content = fetch_image_content(image_url)
154 |
155 | if image_content is not None:
156 | image = Image.open(image_content)
157 | image_list.append(image)
158 | try:
159 | response = model.generate_content(
160 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot
161 | # [question] # text-only
162 | )
163 | except Exception as e:
164 | print(e)
165 | try:
166 | response = model.generate_content(
167 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot
168 | # [question] # text-only
169 | )
170 | except Exception as e:
171 | try:
172 | response = model.generate_content(
173 | [prompt[0], image_list[0], prompt[1], image_list[1], prompt[2], image_list[2], prompt[3], image_list[3], prompt[4]] # few-shot & zero-shot
174 | # [question] # text-only
175 | )
176 | except Exception as e:
177 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10}
178 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
179 | count += 1
180 | continue
181 | try:
182 | output = response.text.strip().rstrip('.').lower()
183 | except Exception as e:
184 | print(e)
185 | output = '--'
186 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10}
187 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
188 | count += 1
189 | continue
190 | count += 1
191 | if output in data['answer'] or data['answer'] in output:
192 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10}
193 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
194 | right_count += 1
195 | else:
196 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 10 + 10}
197 | fout.write(json.dumps(result_json, ensure_ascii=False) + '\n')
198 | print(f'{output.lower()}')
199 | print(f"{data['answer']}")
200 | print(f'right_count: {right_count}')
201 | print(f'count: {count}')
202 |
203 | accuracy = right_count / count
204 | print(accuracy)
205 |
--------------------------------------------------------------------------------
/Code/close_models/gpt4_2_shot.py:
--------------------------------------------------------------------------------
1 | # import openai
2 | from openai import OpenAI
3 | import json
4 | import requests
5 | from io import BytesIO
6 | import PIL.Image as Image
7 | import base64
8 | import random
9 |
10 | client = OpenAI(
11 | api_key='your key',
12 | base_url='https://api.mnxcc.com/v1'
13 | )
14 |
15 | IMAGE_DIR = 'http://images.cocodataset.org/test2017'
16 |
17 | example_list_rule1 = [
18 | {"id": 1, "image": "000000358641.jpg", "text": "Question: For the clock in the picture, which side of the 1 scale does the hour hand point to?, Options: left of; right of. \nOutput: right of."},
19 | {"id": 2, "image": "000000209618.jpg", "text": "Question: Where is the white plate located relative to the glass?, Options: in front of; behind; left of; right of. \nOutput: in front of."},
20 | {"id": 3, "image": "000000010682.jpg", "text": "Question: For the letters on the warning sign, where is the letter W located relative to the letter O?, Options: on/above; below; left of; right of. \nOutput: below."}
21 | ]
22 |
23 | example_list_rule2 = [
24 | {"id": 1, "image": "000000057664.jpg", "text": "Question: If you are the person skiing in the picture, where is your shadow located relative to you?, Options: in front of; behind; left of; right of. \nOutput: right of."},
25 | {"id": 2, "image": "000000073924.jpg", "text": "Question: If you were a player playing on the court, where would the tennis ball be located relative to you?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: on/above."},
26 | {"id": 3, "image": "000000022707.jpg", "text": "Question: If you were the little girl in the picture, where would the window be located relative to you?, Options: in front of; behind; left of; right of. \nOutput: behind."}
27 | ]
28 |
29 | example_list_rule3 = [
30 | {"id": 1, "image": "000000139664.jpg", "text": "Question: If you are the driver of the bus in the picture, from your perspective, where is the stroller located relative to the bus?, Options: in front of; behind; left of; right of. \nOutput: left of."},
31 | {"id": 2, "image": "000000221101.jpg", "text": "If you are sitting in front of the computer in the picture, where is the scissors located relative to the laptop from your perspective?, Options: on/above; below; in front of; behind; left of; right of. \nOutput: right of."},
32 | {"id": 3, "image": "000000164692.jpg", "text": "Question: If you are the driver of the white car in the picture, from your perspective, where is the motorcycle located relative to the car?, Options: in front of; behind; left of; right of. \nOutput: behind."}
33 | ]
34 |
35 |
36 | def pop_exist(data, datalist):
37 | data_list = datalist.copy()
38 | data_list.remove(data)
39 | return random.choice(data_list)
40 |
41 |
42 | E11 = random.choice(example_list_rule1)
43 | E12 = pop_exist(E11, example_list_rule1)
44 | E21 = random.choice(example_list_rule1)
45 | E22 = random.choice(example_list_rule2)
46 | E31 = random.choice(example_list_rule1)
47 | E32 = random.choice(example_list_rule3)
48 | E41 = random.choice(example_list_rule2)
49 | E42 = pop_exist(E41, example_list_rule2)
50 | E51 = random.choice(example_list_rule2)
51 | E52 = random.choice(example_list_rule3)
52 | E61 = random.choice(example_list_rule3)
53 | E62 = pop_exist(E61, example_list_rule3)
54 |
55 |
56 | example_list = [
57 | [E11, E12],
58 | [E21, E22],
59 | [E31, E32],
60 | [E41, E42],
61 | [E51, E52],
62 | [E61, E62]
63 | ]
64 |
65 |
66 | # count = 0
67 | # right_count = 0
68 |
69 |
70 | def fetch_image_content(image_url):
71 | response = requests.get(image_url)
72 | if response.status_code == 200:
73 | return BytesIO(response.content)
74 | else:
75 | return None
76 |
77 |
78 | def encode_image(image):
79 | if image is None:
80 | return None
81 |
82 | buffered = BytesIO()
83 | try:
84 | image.save(buffered, format="JPEG")
85 | img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
86 | return 'data:image/jpeg;base64,' + img_str
87 | except Exception as e:
88 | print(f"error: {e}")
89 | return None
90 |
91 |
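# One chat-completion request whose content alternates the four prompt segments
# with the two exemplar images and the test image (all as base64 data URLs).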
92 | def call_gpt4(prompt, image, detail='auto'):
93 | try:
94 | response = client.chat.completions.create(
95 | model="gpt-4o",
96 | messages=[
97 | {
98 | "role": "user",
99 | # zero-shot + few-shot
100 | "content": [{"type": "text", "text": prompt[0]}] \
101 | + [{"type": "image_url", "image_url": { "url": image[0], "detail": detail}}] \
102 | + [{"type": "text", "text": prompt[1]}] \
103 | + [{"type": "image_url", "image_url": { "url": image[1], "detail": detail}}] \
104 | + [{"type": "text", "text": prompt[2]}] \
105 | + [{"type": "image_url", "image_url": { "url": image[2], "detail": detail}}] \
106 | + [{"type": "text", "text": prompt[3]}]
107 | # text only:
108 | # "content": [{"type": "text", "text": question}]
109 | }
110 | ],
111 | max_tokens=500,
112 | temperature=0.5,
113 | )
114 | # print(response.choices[0].message.content.strip())
115 | return response.choices[0].message.content.strip()
116 |
117 | except Exception as e:
118 | print(f"Error during answering: {e}")
119 | return None
120 |
121 |
122 | def process_jsonl(input_file, output_file):
123 | count = 0
124 | right_count = 0
125 |
126 | with open(input_file, 'r', encoding='utf-8') as file:
127 | with open(output_file, 'w', encoding='utf-8') as out_file:
128 | for line in file:
129 | data = json.loads(line)
130 | id = data['id']
131 |
132 | example_2 = example_list[count % 6]
133 |
134 | # 2 - shot
135 | prompt = ['', '', '', '']
136 | prompt[0] = f'You are currently a senior expert in spatial relation reasoning. ' \
137 | f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
138 | f'\nGiven the following 2 examples to learn the spatial relation reasoning task:' \
139 | f'\nExample1: Input: Image: '
140 | prompt[1] = f'\n{example_2[0]["text"]} ' \
141 | f'\nExample2: Input: Image: '
142 | prompt[2] = f'\n{example_2[1]["text"]} ' \
143 | f'\nInput: Image: '
144 | # prompt = f'Given the following 3 examples to learn the spatial relation reasoning task:' \
145 | # f'\n{example_list[count % 10][0]["text"]} ' \
146 | # f'\nYou are currently a senior expert in spatial relation reasoning. ' \
147 | # f'\nGiven an Image, a Question and Options, your task is to answer the correct spatial relation. Note that you only need to choose one option from the all options without explaining any reason.' \
148 | # f'\nInput: Image: '
149 | prompt[3] = f'\nQuestion: {data["question"]}, Options: {"; ".join(data["options"])}. \nOutput: '
150 |
151 | image_encoded_list = []
152 | for item in example_2:
153 | image_id = item['image']
154 | image_url = f'{IMAGE_DIR}/{image_id}'
155 | image_content = fetch_image_content(image_url)
156 | if image_content is not None:
157 | image = Image.open(image_content)
158 | image_encoded = encode_image(image)
159 | image_encoded_list.append(image_encoded)
160 |
161 | image_id = data['image']
162 | image_url = f'{IMAGE_DIR}/{image_id}'
163 | image_content = fetch_image_content(image_url)
164 |
165 | if image_content is not None:
166 | image = Image.open(image_content)
167 | image_encoded = encode_image(image)
168 | image_encoded_list.append(image_encoded)
169 |
170 | try:
171 | model_answer = call_gpt4(prompt, image_encoded_list)
172 | except Exception as e:
173 | print(e)
174 | try:
175 | model_answer = call_gpt4(prompt, image_encoded_list)
176 | except Exception as e:
177 | try:
178 | model_answer = call_gpt4(prompt, image_encoded_list)
179 | except Exception as e:
180 | result_json = {"id": id, "result": 0, "output": "--", "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
181 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
182 | count += 1
183 | continue
184 |
185 | try:
186 | # output = model_answer.text.strip().rstrip('.').lower()
187 | output = model_answer.strip().rstrip('.').lower()
188 | except Exception as e:
189 | print(e)
190 | output = '--'
191 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
192 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
193 | count += 1
194 | continue
195 | count += 1
196 | if output in data['answer'] or data['answer'] in output:
197 | result_json = {"id": id, "result": 1, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
198 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
199 | right_count += 1
200 | else:
201 | result_json = {"id": id, "result": 0, "output": output.lower(), "answer": data['answer'], "rule": data['rule'], "example": count % 6 + 4}
202 | out_file.write(json.dumps(result_json, ensure_ascii=False) + '\n')
203 | print(f'{output.lower()}')
204 | print(f"{data['answer']}")
205 | print(f'right_count: {right_count}')
206 | print(f'count: {count}')
207 | # print(f'accuracy: {right_count / count}')
208 |
209 | accuracy = right_count / count
210 | print(accuracy)
211 |
212 |
213 | input_file_path = "test_en_select_500_sort_cp.jsonl"
214 | output_file_path = "results/gpt4_2_shot_random_new_1.jsonl"
215 |
216 | process_jsonl(input_file_path, output_file_path)
217 |
--------------------------------------------------------------------------------