├── LICENSE
├── README.md
├── app.py
├── img2loc_GPT4V.py
├── requirements.txt
└── static
    ├── figure3.jpg
    └── logo_clipped.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Douglas2Code
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Img2Loc: Revisiting Image Geolocalization using Multi-modality Foundation Models and Image-based Retrieval-Augmented Generation
2 |
3 | Code for the Img2Loc paper presented at SIGIR 2024.
4 |
5 | 
6 |
7 | ## Table of Contents
8 |
9 | - [Installation](#installation)
10 | - [Usage](#usage)
11 | - [License](#license)
12 | - [Contact](#contact)
13 | - [Citation](#citation)
14 |
15 | ## Installation
16 |
17 | Instructions on how to install and set up the project. If you need help accessing the generated embeddings, please contact us directly.
18 |
19 | ```bash
20 | # Clone the repository
21 | git clone git@github.com:Douglas2Code/Img2Loc.git
22 |
23 | # Change to the project directory
24 | cd Img2Loc
25 |
26 | # Create a conda environment
27 | conda create -n img2loc python=3.10 -y
28 |
29 | # Activate the conda environment
30 | conda activate img2loc
31 |
32 | # Install FAISS by following this guide:
33 | # https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
34 |
35 | # Install the project dependencies
36 | pip install -r requirements.txt
37 |
38 | # Download the MP16 dataset:
39 | # http://www.multimediaeval.org/mediaeval2016/placing/
40 |
41 | # Generate embeddings using the CLIP model:
42 | # https://github.com/openai/CLIP
43 |
44 | # Generate a vector database using FAISS:
45 | # https://github.com/facebookresearch/faiss/wiki/Getting-started#in-python-1
46 |
47 | ```
48 |
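The app (`app.py` and `img2loc_GPT4V.py`) expects two files in the project root: `merged.pkl`, a dictionary whose values hold each database image's embedding and its `(lat, lon)` coordinates, and `index.bin`, a FAISS index built over the same embeddings. The sketch below shows one plausible way to produce them from MP16; the `mp16_metadata.csv` file and its `image_path`/`lat`/`lon` columns are hypothetical placeholders for however you store the MP16 metadata.

```python
# build_index.py -- illustrative sketch only; adapt paths and metadata to your MP16 copy.
import pickle

import faiss
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval().to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Hypothetical metadata file with columns: image_path, lat, lon
df = pd.read_csv("mp16_metadata.csv")

embeddings, records = [], {}
for row in df.itertuples():
    image = Image.open(row.image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)[0].cpu().numpy()
    embeddings.append(emb)
    # img2loc_GPT4V.py reads value[1] as the (lat, lon) pair, so keep that layout
    records[row.image_path] = (emb, (row.lat, row.lon))

# Embedding/location mapping loaded by GPT4v2Loc.__init__ as merged.pkl
with open("merged.pkl", "wb") as f:
    pickle.dump(records, f)

# Flat inner-product index over the embeddings, saved as index.bin
xb = np.stack(embeddings).astype("float32")
index = faiss.IndexFlatIP(xb.shape[1])
index.add(xb)
faiss.write_index(index, "index.bin")
```
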
49 | ## Usage
50 |
51 | Run the Streamlit application:
52 |
53 | ```bash
54 | streamlit run app.py --browser.gatherUsageStats false
55 | ```
56 |
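You can also call the model directly from Python instead of going through the Streamlit UI. A minimal sketch, assuming `merged.pkl` and `index.bin` are in the working directory and `photo.jpg` is a hypothetical query image:

```python
from img2loc_GPT4V import GPT4v2Loc

agent = GPT4v2Loc(device="cpu")

# Embed the query image and retrieve the 16 nearest and 16 farthest MP16 locations
agent.set_image("photo.jpg", use_database_search=True, num_neighbors=16, num_farthest=16)

# Ask GPT-4V for coordinates, augmented with the retrieved locations
coords = agent.get_location("YOUR_OPENAI_API_KEY", use_database_search=True)
print(coords)  # e.g. "48.8584, 2.2945"
```
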
57 | ## License
58 | This project is licensed under the MIT License - see the LICENSE file for details.
59 |
60 | ## Contact
61 |
62 | - Zhongliang Zhou: zzldouglas97@gmail.com
63 | - Jielu Zhang: jz20582@uga.edu
64 |
65 | ## Citation
66 |
67 | If you find this project helpful, please consider citing our work.
68 |
69 | ```latex
70 | @inproceedings{zhou2024img2loc,
71 | title={Img2Loc: Revisiting Image Geolocalization using Multi-modality Foundation Models and Image-based Retrieval-Augmented Generation},
72 | author={Zhou, Zhongliang and Zhang, Jielu and Guan, Zihan and Hu, Mengxuan and Lao, Ni and Mu, Lan and Li, Sheng and Mai, Gengchen},
73 | booktitle={Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval},
74 | pages={2749--2754},
75 | year={2024}
76 | }
77 | ```
78 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import streamlit_folium as sf
3 | import numpy as np
4 | import pandas as pd
5 | from img2loc_GPT4V import GPT4v2Loc
6 | import folium
7 | from folium.plugins import HeatMap
8 | from folium.features import DivIcon
9 | from geopy.distance import geodesic
10 | import base64
11 |
12 | # In the sidebar, add the widgets for the app
13 | # load the logo image and convert it to base64
14 | logo = open("./static/logo_clipped.png", 'rb').read()
15 | img_base64 = base64.b64encode(logo).decode('utf-8')
16 | st.set_page_config(page_title="Img2Loc", page_icon=":earth_americas:")
17 |
18 | st.sidebar.markdown(
19 | f"<img src='data:image/png;base64,{img_base64}' alt='Img2Loc logo' style='width: 100%;'>",
20 | unsafe_allow_html=True
21 | )
22 |
23 |
24 | uploaded_file = st.sidebar.file_uploader("Upload an image")
25 | openai_api_key = st.sidebar.text_input("OpenAI API Key", "xxxxxxxxx", key="chatbot_api_key", type="password")
26 | nearest_neighbor = st.sidebar.radio("Use nearest neighbor search?", ("Yes", "No"))
27 | num_nearest_neighbors = None # Default value
28 |
29 | if nearest_neighbor == "Yes":
30 | num_nearest_neighbors = st.sidebar.number_input("Number of nearest neighbors", value=16)
31 |
32 | farthest_neighbor = st.sidebar.radio("Use farthest neighbor search?", ("Yes", "No"))
33 |
34 | num_farthest_neighbors = None # Default value
35 |
36 | if farthest_neighbor == "Yes":
37 | num_farthest_neighbors = st.sidebar.number_input("Number of farthest neighbors", value=16)
38 |
39 | # Create two columns in the sidebar
40 | col1, col2 = st.sidebar.columns(2)
41 |
42 | # Add input for real latitude and longitude in the columns
43 | real_latitude = col1.number_input("Enter the latitude")
44 | real_longitude = col2.number_input("Enter the longitude")
45 |
46 | submit = st.sidebar.button("Submit")
47 |
48 |
49 | # Add the title and maps
50 | st.markdown("
Img2Loc
", unsafe_allow_html=True)
51 | st.markdown("---") # Dash line separation
52 |
53 |
54 | if submit:
55 | my_bar = st.progress(0, text="Starting Analysis...")
56 | if not openai_api_key:
57 | st.info("Please add your OpenAI API key to continue.")
58 | st.stop()
59 |
60 | if uploaded_file is None:
61 | st.info("Please upload an image.")
62 | st.stop()
63 | my_bar.progress(10, text="Loading Required Resources...")
64 | GPT_Agent = GPT4v2Loc(device="cpu")
65 |
66 | # Get the name of the uploaded file
67 | img_name = uploaded_file.name
68 |
69 | if real_latitude != 0.0 and real_longitude != 0.0:
70 | true_lat = real_latitude
71 | true_long = real_longitude
72 |
73 | num_neighbors = 16
74 | use_database_search = nearest_neighbor == "Yes" or farthest_neighbor == "Yes"
75 | if use_database_search:
76 | my_bar.progress(25, text="Finding nearest and farthest neighbors...")
77 |
78 | my_bar.progress(50, text="Transforming Image...")
79 | GPT_Agent.set_image_app(uploaded_file, use_database_search = use_database_search, num_neighbors = num_nearest_neighbors, num_farthest = num_farthest_neighbors)
80 | my_bar.progress(75, text="Obtaining Locations...")
81 | coordinates = GPT_Agent.get_location(openai_api_key, use_database_search=use_database_search)
82 | my_bar.progress(90, text="Generating Map...")
83 |
84 | lat_str, lon_str = coordinates.split(',')
85 | latitude = float(lat_str)
86 | longitude = float(lon_str)
87 |
88 | # Display the maps with captions
89 | col1, mid, col2 = st.columns([1,0.1,1]) # Create three columns
90 |
91 |
92 | # Map for the nearest neighbor points
93 | if nearest_neighbor == "Yes":
94 | m1 = folium.Map(width=320,height=200, location=GPT_Agent.neighbor_locations_array[0], zoom_start=4)
95 | folium.TileLayer('cartodbpositron').add_to(m1)
96 | for i in GPT_Agent.neighbor_locations_array:
97 | print(i)
98 | folium.Marker(i, tooltip='({}, {})'.format(i[0], i[1]), icon=folium.Icon(color="green", icon="compass", prefix="fa")).add_to(m1)
99 | with col1:
100 | st.markdown("Nearest Neighbor Points Map
", unsafe_allow_html=True)
101 | sf.folium_static(m1, height=200)
102 |
103 | # Map for the farthest neighbor points
104 | if farthest_neighbor == "Yes":
105 | m2 = folium.Map(width=320,height=200, location=GPT_Agent.farthest_locations_array[0], zoom_start=3)
106 | folium.TileLayer('cartodbpositron').add_to(m2)
107 | for i in GPT_Agent.farthest_locations_array:
108 | folium.Marker(i, tooltip='({}, {})'.format(i[0], i[1]), icon=folium.Icon(color="blue", icon="compass", prefix="fa")).add_to(m2)
109 | with col2:
110 | st.markdown("Farthest Neighbor Points Map
", unsafe_allow_html=True)
111 | sf.folium_static(m2, height=200)
112 |
113 | # Map for the predicted point, the true point, and the distance between them
114 | m3 = folium.Map(width=1000,height=400, location=[latitude, longitude], zoom_start=12)
115 |
116 | folium.Marker([latitude, longitude], tooltip='Img2Loc Location', popup=f'latitude: {latitude}, longitude: {longitude}', icon=folium.Icon(color="red", icon="map-pin", prefix="fa")).add_to(m3)
117 | # line = folium.PolyLine(locations=[[latitude, longitude], [true_lat, true_long]], color='black', dash_array='5', weight=2).add_to(m3)
118 | folium.TileLayer('cartodbpositron').add_to(m3)
119 |
120 | st.markdown("Prediction Map
", unsafe_allow_html=True)
121 | sf.folium_static(m3)
122 | my_bar.progress(100, text="Done!")
123 |
124 | else:
125 | # load the background image and convert it to base64
126 | bg_image = open("./static/figure3.jpg", 'rb').read()
127 | bg_base64 = base64.b64encode(bg_image).decode('utf-8')
128 | # while the user has not submitted the form, display the background image
129 | st.markdown(
130 | f"""
131 | <style>
132 | .stApp {{
133 | background-image: url("data:image/jpg;base64,{bg_base64}");
134 | background-size: contain;
135 | background-repeat: no-repeat;
136 | background-position: center;
137 | }}
138 | </style>
149 | """,
150 | unsafe_allow_html=True
151 | )
--------------------------------------------------------------------------------
/img2loc_GPT4V.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import base64
3 | import requests
4 | from tqdm import tqdm
5 | from requests.exceptions import RequestException
6 | from PIL import Image
7 | from transformers import CLIPModel, CLIPProcessor
8 | import torch
9 | import faiss
10 | import pickle
11 | import numpy as np
12 | import pandas as pd
13 | from geopy.distance import geodesic
14 |
15 | from transformers import AutoTokenizer, BitsAndBytesConfig
16 | import torch
17 | from PIL import Image
18 | import requests
19 | from io import BytesIO
20 |
21 | # set the device to the first CUDA device using os.environ
22 | import os
23 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
24 |
25 |
26 | class GPT4v2Loc:
27 | """
28 | A class that combines CLIP-based retrieval over the MP-16 database with OpenAI's GPT-4V API to predict the location of an image.
29 | Attributes:
30 | api_key (str): OpenAI API key supplied when calling get_location.
31 | """
32 |
33 | def __init__(self, device="cpu") -> None:
34 | """
35 | Initializes the GPT4v2Loc class by loading the CLIP model, the MP-16 embeddings, and the Faiss index.
36 | Args:
37 | device (str, optional): The torch device to run CLIP on. Defaults to "cpu".
38 | """
39 |
40 | self.base64_image = None
41 | self.img_emb = None
42 |
43 | # Set the device to the first CUDA device
44 | self.device = torch.device(device)
45 |
46 | # Load the CLIP model and processor
47 | self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
48 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
49 |
50 | # Move the model to the appropriate CUDA device
51 | self.model.to(self.device)
52 |
53 | # Load the embeddings and coordinates from the pickle file
54 | with open('merged.pkl', 'rb') as f:
55 | self.MP_16_Embeddings = pickle.load(f)
56 | self.locations = [value[1] for key, value in self.MP_16_Embeddings.items()]
57 |
58 | # Load the Faiss index (kept on the CPU here; uncomment the next line to move it to the GPUs)
59 | index2 = faiss.read_index("index.bin")
60 | # self.gpu_index = faiss.index_cpu_to_all_gpus(index2)
61 | self.gpu_index = index2
62 |
63 | def read_image(self, image_path):
64 | """
65 | Reads an image from a file into a numpy array.
66 | Args:
67 | image_path (str): The path to the image file.
68 | Returns:
69 | np.ndarray: The image as a numpy array.
70 | """
71 | image = cv2.imread(image_path)
72 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
73 | return image
74 |
75 | def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding):
76 | """
77 | Searches for the nearest and farthest neighbors of a query image in the Faiss index.
78 | Args:
79 | faiss_index (faiss.Index): The Faiss index to query.
80 | k_nearest (int), k_farthest (int): The numbers of nearest and farthest neighbors to retrieve.
81 | query_embedding (np.ndarray): The embedding of the query image.
82 | Returns:
83 | tuple: Space-separated strings of the nearest and farthest neighbor locations.
84 | """
85 | # Perform the search using Faiss for the given embedding
86 | _, I = faiss_index.search(query_embedding.reshape(1, -1), k_nearest)
87 |
88 | # Based on the index, get the locations of the neighbors
89 | self.neighbor_locations_array = [self.locations[idx] for idx in I[0]]
90 |
91 | neighbor_locations = " ".join([str(i) for i in self.neighbor_locations_array])
92 |
93 | # Farthest search: query with the negated embedding so the most dissimilar entries rank first
94 | _, I = faiss_index.search(-query_embedding.reshape(1, -1), k_farthest)
95 |
96 | # Based on the index, get the locations of the neighbors
97 | self.farthest_locations_array = [self.locations[idx] for idx in I[0]]
98 |
99 | farthest_locations = " ".join([str(i) for i in self.farthest_locations_array])
100 |
101 | return neighbor_locations, farthest_locations
102 |
103 | def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str:
104 | """
105 | Encodes an OpenCV image to a Base64 string.
106 | Args:
107 | image (np.ndarray): An image represented as a numpy array.
108 | format (str, optional): The format for encoding the image. Defaults to 'jpeg'.
109 | Returns:
110 | str: A Base64 encoded string of the image.
111 | Raises:
112 | ValueError: If the image conversion fails.
113 | """
114 | try:
115 | retval, buffer = cv2.imencode(f'.{format}', image)
116 | if not retval:
117 | raise ValueError("Failed to convert image")
118 |
119 | base64_encoded = base64.b64encode(buffer).decode('utf-8')
120 | mime_type = f"image/{format}"
121 | return f"data:{mime_type};base64,{base64_encoded}"
122 | except Exception as e:
123 | raise ValueError(f"Error encoding image: {e}")
124 |
125 | def set_image(self, image_path: str, imformat: str = 'jpeg', use_database_search: bool = False, num_neighbors: int = 16, num_farthest: int = 16) -> None:
126 | """
127 | Sets the image for the class by encoding it to Base64.
128 | Args:
129 | image_path (str): The path to the image file.
130 | imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'.
131 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
132 | """
133 | # Read the image into a numpy array
134 | image_array = self.read_image(image_path)
135 |
136 | # Load and preprocess the image
137 | image = Image.open(image_path).convert('RGB')
138 | image = self.processor(images=image, return_tensors="pt")
139 |
140 | # Move the image to the CUDA device and get its embeddings
141 | image = image.to(self.device)
142 | with torch.no_grad():
143 | img_emb = self.model.get_image_features(**image)[0]
144 |
145 | # Store the embeddings and the locations of the nearest neighbors
146 | self.img_emb = img_emb.cpu().numpy()
147 | if use_database_search:
148 | self.neighbor_locations, self.farthest_locations = self.search_neighbors(self.gpu_index, num_neighbors, num_farthest, self.img_emb)
149 |
150 | # Encode the image to Base64
151 | self.base64_image = self.encode_image(image_array, imformat)
152 |
153 | def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False, num_neighbors: int = 16, num_farthest: int = 16) -> None:
154 | """
155 | Sets the image for the class by encoding it to Base64.
156 | Args:
157 | file_uploader : A uploaded image.
158 | imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'.
159 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
160 | """
161 |
162 | image = Image.open(file_uploader).convert('RGB')
163 | img_array = np.array(image)
164 | image = self.processor(images=img_array, return_tensors="pt")
165 |
166 | # Move the image to the CUDA device and get its embeddings
167 | image = image.to(self.device)
168 | with torch.no_grad():
169 | img_emb = self.model.get_image_features(**image)[0]
170 |
171 | # Store the embeddings and the locations of the nearest neighbors
172 | self.img_emb = img_emb.cpu().numpy()
173 | if use_database_search:
174 | self.neighbor_locations, self.farthest_locations = self.search_neighbors(self.gpu_index, num_neighbors, num_farthest, self.img_emb)
175 |
176 | # Encode the image to Base64
177 | self.base64_image = self.encode_image(img_array, imformat)
178 |
179 |
180 | def create_payload(self, question: str) -> dict:
181 | """
182 | Creates the payload for the API request to OpenAI.
183 | Args:
184 | question (str): The question to ask about the image.
185 | Returns:
186 | dict: The payload for the API request.
187 | Raises:
188 | ValueError: If the image is not set.
189 | """
190 | if not self.base64_image:
191 | raise ValueError("Image not set")
192 | return {
193 | "model": "gpt-4o",
194 | "messages": [
195 | {
196 | "role": "user",
197 | "content": [
198 | {
199 | "type": "text",
200 | "text": question
201 | },
202 | {
203 | "type": "image_url",
204 | "image_url": {
205 | "url": self.base64_image
206 | }
207 | }
208 | ]
209 | }
210 | ],
211 | "max_tokens": 300,
212 | }
213 |
214 | def get_location(self, OPENAI_API_KEY, use_database_search: bool = False) -> str:
215 | """
216 | Predicts the coordinates of the provided image using OpenAI's GPT-4 API.
217 | Args:
218 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False.
219 | Returns:
220 | str: The predicted coordinates for the image.
221 | Raises:
222 | ValueError: If there is an issue with the API request.
223 | """
224 | try:
225 | self.api_key: str = OPENAI_API_KEY
226 | if not self.api_key:
227 | raise ValueError("OpenAI API key not found in environment variables")
228 | # Create the question for the API
229 | if use_database_search:
230 | self.question=f'''Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location.
231 | Your answer must be to the coordinates level, don't include any other information in your output.
232 | Ignore that you can't give an exact answer, give me some coordinates no matter what.
233 | For your reference, these are locations of some similar images {self.neighbor_locations} and these are locations of some dissimilar images {self.farthest_locations} that should be far away.'''
234 | else:
235 | self.question=f"Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. You can give me a guessed anwser."
236 |
237 | # Create the payload and the headers for the API request
238 | payload = self.create_payload(self.question)
239 | headers = {
240 | "Content-Type": "application/json",
241 | "Authorization": f"Bearer {self.api_key}"
242 | }
243 |
244 | # Send the API request and get the response
245 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
246 | response.raise_for_status()
247 | response_data = response.json()
248 |
249 | # Return the generated caption
250 | return response_data['choices'][0]['message']['content']
251 | except RequestException as e:
252 | raise ValueError(f"Error in API request: {e}")
253 | except KeyError:
254 | raise ValueError("Unexpected response format from API")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | aiohttp
3 | altair
4 | bqplot
5 | datasets
6 | fastapi
7 | filelock
8 | geopandas
9 | geopy
10 | gitpython
11 | gradio
12 | huggingface-hub
13 | ipywidgets
14 | jupyterlab
15 | matplotlib
16 | numpy
17 | opencv-python
18 | pandas
19 | pillow
20 | plotly
21 | protobuf
22 | pyarrow
23 | pyyaml
24 | requests
25 | scikit-learn
26 | scipy
27 | seaborn
28 | spacy
29 | sqlalchemy
30 | streamlit
31 | streamlit-folium
32 | torch
33 | torchaudio
34 | torchvision
35 | transformers
36 | uvicorn
37 |
--------------------------------------------------------------------------------
/static/figure3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Douglas2Code/Img2Loc/e4ff10b5b9ab75bf4d50504db4cc650477283645/static/figure3.jpg
--------------------------------------------------------------------------------
/static/logo_clipped.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Douglas2Code/Img2Loc/e4ff10b5b9ab75bf4d50504db4cc650477283645/static/logo_clipped.png
--------------------------------------------------------------------------------