├── LICENSE ├── README.md ├── app.py ├── img2loc_GPT4V.py ├── requirements.txt └── static ├── figure3.jpg └── logo_clipped.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Douglas2Code 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Img2Loc: Revisiting Image Geolocalization using Multi-modality Foundation Models and Image-based Retrieval-Augmented Generation 2 | 3 | Code for the Img2Loc paper, presented at SIGIR 2024. 4 | 5 | ![Banner](./static/figure3.jpg) 6 | 7 | ## Table of Contents 8 | 9 | - [Installation](#installation) 10 | - [Usage](#usage) 11 | - [License](#license) 12 | - [Contact](#contact) 13 | - [Citation](#citation) 14 | 15 | ## Installation 16 | 17 | Instructions on how to install and set up the project. If you need help accessing the generated embeddings, please contact us directly. 18 | 19 | ```bash 20 | # Clone the repository 21 | git clone git@github.com:Douglas2Code/Img2Loc.git 22 | 23 | # Change to the project directory 24 | cd Img2Loc 25 | 26 | # Create a conda environment 27 | conda create -n img2loc python=3.10 -y 28 | 29 | # Activate the conda environment 30 | conda activate img2loc 31 | 32 | # Install the FAISS library following this guide 33 | https://github.com/facebookresearch/faiss/blob/main/INSTALL.md 34 | 35 | # Install the project dependencies 36 | pip install -r requirements.txt 37 | 38 | # Download the MP16 dataset 39 | http://www.multimediaeval.org/mediaeval2016/placing/ 40 | 41 | # Generate embeddings using the CLIP model 42 | https://github.com/openai/CLIP 43 | 44 | # Generate a vector database using FAISS 45 | https://github.com/facebookresearch/faiss/wiki/Getting-started#in-python-1 46 | 47 | ``` 48 | 49 | ## Usage 50 | 51 | Run the Streamlit application: 52 | 53 | ```bash 54 | streamlit run app.py --browser.gatherUsageStats false 55 | ``` 56 | 57 | ## License 58 | This project is licensed under the MIT License - see the LICENSE file for details. 59 | 60 | ## Contact 61 | 62 | Zhongliang Zhou: zzldouglas97@gmail.com 63 | Jielu Zhang: jz20582@uga.edu 64 | 65 | ## Citation 66 | 67 | If you find this project helpful, please consider citing our work.
68 | 69 | ```latex 70 | @inproceedings{zhou2024img2loc, 71 | title={Img2Loc: Revisiting Image Geolocalization using Multi-modality Foundation Models and Image-based Retrieval-Augmented Generation}, 72 | author={Zhou, Zhongliang and Zhang, Jielu and Guan, Zihan and Hu, Mengxuan and Lao, Ni and Mu, Lan and Li, Sheng and Mai, Gengchen}, 73 | booktitle={Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval}, 74 | pages={2749--2754}, 75 | year={2024} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import streamlit_folium as sf 3 | import numpy as np 4 | import pandas as pd 5 | from img2loc_GPT4V import GPT4v2Loc 6 | import folium 7 | from folium.plugins import HeatMap 8 | from folium.features import DivIcon 9 | from geopy.distance import geodesic 10 | import base64 11 | 12 | # In the sidebar, add the widgets for the app 13 | # load the logo image and convert it to base64 14 | logo = open("./static/logo_clipped.png", 'rb').read() 15 | img_base64 = base64.b64encode(logo).decode('utf-8') 16 | st.set_page_config(page_title="Img2Loc", page_icon=":earth_americas:") 17 | 18 | st.sidebar.markdown( 19 | f"", 20 | unsafe_allow_html=True 21 | ) 22 | 23 | 24 | uploaded_file = st.sidebar.file_uploader("Upload an image") 25 | openai_api_key = st.sidebar.text_input("OpenAI API Key", "xxxxxxxxx", key="chatbot_api_key", type="password") 26 | nearest_neighbor = st.sidebar.radio("Use nearest neighbor search?", ("Yes", "No")) 27 | num_nearest_neighbors = None # Default value 28 | 29 | if nearest_neighbor == "Yes": 30 | num_nearest_neighbors = st.sidebar.number_input("Number of nearest neighbors", value=16) 31 | 32 | farthest_neighbor = st.sidebar.radio("Use farthest neighbor search?", ("Yes", "No")) 33 | 34 | num_farthest_neighbors = None # Default value 35 | 36 | if farthest_neighbor == "Yes": 37 | num_farthest_neighbors = st.sidebar.number_input("Number of farthest neighbors", value=16) 38 | 39 | # Create two columns in the sidebar 40 | col1, col2 = st.sidebar.columns(2) 41 | 42 | # Add input for real latitude and longitude in the columns 43 | real_latitude = col1.number_input("Enter the latitude") 44 | real_longitude = col2.number_input("Enter the longitude") 45 | 46 | submit = st.sidebar.button("Submit") 47 | 48 | 49 | # Add the title and maps 50 | st.markdown("

Img2Loc
", unsafe_allow_html=True) 51 | st.markdown("---") # Dash line separation 52 | 53 | 54 | if submit: 55 | my_bar = st.progress(0, text="Starting Analysis...") 56 | if not openai_api_key: 57 | st.info("Please add your OpenAI API key to continue.") 58 | st.stop() 59 | 60 | if uploaded_file is None: 61 | st.info("Please upload an image.") 62 | st.stop() 63 | my_bar.progress(10, text="Loading Required Resources...") 64 | GPT_Agent = GPT4v2Loc(device="cpu") 65 | 66 | # Get the name of the uploaded file 67 | img_name = uploaded_file.name 68 | 69 | if real_latitude != 0.0 and real_longitude != 0.0: 70 | true_lat = real_latitude 71 | true_long = real_longitude 72 | 73 | num_neighbors = 16 74 | use_database_search = True if nearest_neighbor == "Yes" or farthest_neighbor == "Yes" else False 75 | if use_database_search: 76 | my_bar.progress(25, text="Finding nearest and farthest neighbors...") 77 | 78 | my_bar.progress(50, text="Transforming Image...") 79 | GPT_Agent.set_image_app(uploaded_file, use_database_search = use_database_search, num_neighbors = num_nearest_neighbors, num_farthest = num_farthest_neighbors) 80 | my_bar.progress(75, text="Obtaining Locations...") 81 | coordinates = GPT_Agent.get_location(openai_api_key, use_database_search=use_database_search) 82 | my_bar.progress(90, text="Generating Map...") 83 | 84 | lat_str, lon_str = coordinates.split(',') 85 | latitude = float(lat_str) 86 | longitude = float(lon_str) 87 | 88 | # Display the maps with captions 89 | col1, mid, col2 = st.columns([1,0.1,1]) # Create three columns 90 | 91 | 92 | # Map for the nearest neighbor points 93 | if nearest_neighbor == "Yes": 94 | m1 = folium.Map(width=320,height=200, location=GPT_Agent.neighbor_locations_array[0], zoom_start=4) 95 | folium.TileLayer('cartodbpositron').add_to(m1) 96 | for i in GPT_Agent.neighbor_locations_array: 97 | print(i) 98 | folium.Marker(i, tooltip='({}, {})'.format(i[0], i[1]), icon=folium.Icon(color="green", icon="compass", prefix="fa")).add_to(m1) 99 | with col1: 100 | st.markdown("

Nearest Neighbor Points Map
", unsafe_allow_html=True) 101 | sf.folium_static(m1, height=200) 102 | 103 | # Map for the farthest neighbor points 104 | if farthest_neighbor == "Yes": 105 | m2 = folium.Map(width=320,height=200, location=GPT_Agent.farthest_locations_array[0], zoom_start=3) 106 | folium.TileLayer('cartodbpositron').add_to(m2) 107 | for i in GPT_Agent.farthest_locations_array: 108 | folium.Marker(i, tooltip='({}, {})'.format(i[0], i[1]), icon=folium.Icon(color="blue", icon="compass", prefix="fa")).add_to(m2) 109 | with col2: 110 | st.markdown("

Farthest Neighbor Points Map
", unsafe_allow_html=True) 111 | sf.folium_static(m2, height=200) 112 | 113 | # Map for the predicted point, the true point, and the distance between them 114 | m3 = folium.Map(width=1000,height=400, location=[latitude, longitude], zoom_start=12) 115 | 116 | folium.Marker([latitude, longitude], tooltip='Img2Loc Location', popup=f'latitude: {latitude}, longitude: {longitude}', icon=folium.Icon(color="red", icon="map-pin", prefix="fa")).add_to(m3) 117 | # line = folium.PolyLine(locations=[[latitude, longitude], [true_lat, true_long]], color='black', dash_array='5', weight=2).add_to(m3) 118 | folium.TileLayer('cartodbpositron').add_to(m3) 119 | 120 | st.markdown("

Prediction Map
", unsafe_allow_html=True) 121 | sf.folium_static(m3) 122 | my_bar.progress(100, text="Done!") 123 | 124 | else: 125 | # load the background image and convert it to base64 126 | bg_image = open("./static/figure3.jpg", 'rb').read() 127 | bg_base64 = base64.b64encode(bg_image).decode('utf-8') 128 | # while the user has not submitted the form, display the background image 129 | st.markdown( 130 | f""" 131 | 149 | """, 150 | unsafe_allow_html=True 151 | ) -------------------------------------------------------------------------------- /img2loc_GPT4V.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import base64 3 | import requests 4 | from tqdm import tqdm 5 | from requests.exceptions import RequestException 6 | from PIL import Image 7 | from transformers import CLIPModel, CLIPProcessor 8 | import torch 9 | import faiss 10 | import pickle 11 | import numpy as np 12 | import pandas as pd 13 | from geopy.distance import geodesic 14 | 15 | from transformers import AutoTokenizer, BitsAndBytesConfig 16 | import torch 17 | from PIL import Image 18 | import requests 19 | from io import BytesIO 20 | 21 | # set the device to the first CUDA device using os.environ 22 | import os 23 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 24 | 25 | 26 | class GPT4v2Loc: 27 | """ 28 | A class to interact with OpenAI's GPT-4 API to generate captions for images. 29 | Attributes: 30 | api_key (str): OpenAI API key retrieved from environment variables. 31 | """ 32 | 33 | def __init__(self, device="cpu") -> None: 34 | """ 35 | Initializes the GPT4ImageCaption class by setting the OpenAI API key. 36 | Raises: 37 | ValueError: If the OpenAI API key is not found in the environment variables. 38 | """ 39 | 40 | self.base64_image = None 41 | self.img_emb = None 42 | 43 | # Set the device to the first CUDA device 44 | self.device = torch.device(device) 45 | 46 | # Load the CLIP model and processor 47 | self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval() 48 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 49 | 50 | # Move the model to the appropriate CUDA device 51 | self.model.to(self.device) 52 | 53 | # Load the embeddings and coordinates from the pickle file 54 | with open('merged.pkl', 'rb') as f: 55 | self.MP_16_Embeddings = pickle.load(f) 56 | self.locations = [value[1] for key, value in self.MP_16_Embeddings.items()] 57 | 58 | # Load the Faiss index and move it to the GPU 59 | index2 = faiss.read_index("index.bin") 60 | # self.gpu_index = faiss.index_cpu_to_all_gpus(index2) 61 | self.gpu_index = index2 62 | 63 | def read_image(self, image_path): 64 | """ 65 | Reads an image from a file into a numpy array. 66 | Args: 67 | image_path (str): The path to the image file. 68 | Returns: 69 | np.ndarray: The image as a numpy array. 70 | """ 71 | image = cv2.imread(image_path) 72 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 73 | return image 74 | 75 | def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding): 76 | """ 77 | Searches for the k nearest neighbors of a query image in the Faiss index. 78 | Args: 79 | faiss_index (faiss.swigfaiss.Index): The Faiss index. 80 | k (int): The number of neighbors to search for. 81 | query_embedding (np.ndarray): The embeddings of the query image. 82 | Returns: 83 | list: The locations of the k nearest neighbors. 
84 | """ 85 | # Perform the search using Faiss for the given embedding 86 | _, I = faiss_index.search(query_embedding.reshape(1, -1), k_nearest) 87 | 88 | # Based on the index, get the locations of the neighbors 89 | self.neighbor_locations_array = [self.locations[idx] for idx in I[0]] 90 | 91 | neighbor_locations = " ".join([str(i) for i in self.neighbor_locations_array]) 92 | 93 | # Perform the farthest search using Faiss for the given embedding 94 | _, I = faiss_index.search(-query_embedding.reshape(1, -1), k_farthest) 95 | 96 | # Based on the index, get the locations of the neighbors 97 | self.farthest_locations_array = [self.locations[idx] for idx in I[0]] 98 | 99 | farthest_locations = " ".join([str(i) for i in self.farthest_locations_array]) 100 | 101 | return neighbor_locations, farthest_locations 102 | 103 | def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str: 104 | """ 105 | Encodes an OpenCV image to a Base64 string. 106 | Args: 107 | image (np.ndarray): An image represented as a numpy array. 108 | format (str, optional): The format for encoding the image. Defaults to 'jpeg'. 109 | Returns: 110 | str: A Base64 encoded string of the image. 111 | Raises: 112 | ValueError: If the image conversion fails. 113 | """ 114 | try: 115 | retval, buffer = cv2.imencode(f'.{format}', image) 116 | if not retval: 117 | raise ValueError("Failed to convert image") 118 | 119 | base64_encoded = base64.b64encode(buffer).decode('utf-8') 120 | mime_type = f"image/{format}" 121 | return f"data:{mime_type};base64,{base64_encoded}" 122 | except Exception as e: 123 | raise ValueError(f"Error encoding image: {e}") 124 | 125 | def set_image(self, image_path: str, imformat: str = 'jpeg', use_database_search: bool = False, num_neighbors: int = 16, num_farthest: int = 16) -> None: 126 | """ 127 | Sets the image for the class by encoding it to Base64. 128 | Args: 129 | image_path (str): The path to the image file. 130 | imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'. 131 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False. 132 | """ 133 | # Read the image into a numpy array 134 | image_array = self.read_image(image_path) 135 | 136 | # Load and preprocess the image 137 | image = Image.open(image_path).convert('RGB') 138 | image = self.processor(images=image, return_tensors="pt") 139 | 140 | # Move the image to the CUDA device and get its embeddings 141 | image = image.to(self.device) 142 | with torch.no_grad(): 143 | img_emb = self.model.get_image_features(**image)[0] 144 | 145 | # Store the embeddings and the locations of the nearest neighbors 146 | self.img_emb = img_emb.cpu().numpy() 147 | if use_database_search: 148 | self.neighbor_locations, self.farthest_locations = self.search_neighbors(self.gpu_index, num_neighbors, num_farthest, self.img_emb) 149 | 150 | # Encode the image to Base64 151 | self.base64_image = self.encode_image(image_array, imformat) 152 | 153 | def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False, num_neighbors: int = 16, num_farthest: int = 16) -> None: 154 | """ 155 | Sets the image for the class by encoding it to Base64. 156 | Args: 157 | file_uploader : A uploaded image. 158 | imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'. 159 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False. 
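num_neighbors (int, optional): Number of nearest neighbors to retrieve from the FAISS index when database search is enabled. Defaults to 16.
num_farthest (int, optional): Number of farthest neighbors to retrieve. Defaults to 16.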
160 | """ 161 | 162 | image = Image.open(file_uploader).convert('RGB') 163 | img_array = np.array(image) 164 | image = self.processor(images=img_array, return_tensors="pt") 165 | 166 | # Move the image to the CUDA device and get its embeddings 167 | image = image.to(self.device) 168 | with torch.no_grad(): 169 | img_emb = self.model.get_image_features(**image)[0] 170 | 171 | # Store the embeddings and the locations of the nearest neighbors 172 | self.img_emb = img_emb.cpu().numpy() 173 | if use_database_search: 174 | self.neighbor_locations, self.farthest_locations = self.search_neighbors(self.gpu_index, num_neighbors, num_farthest, self.img_emb) 175 | 176 | # Encode the image to Base64 177 | self.base64_image = self.encode_image(img_array, imformat) 178 | 179 | 180 | def create_payload(self, question: str) -> dict: 181 | """ 182 | Creates the payload for the API request to OpenAI. 183 | Args: 184 | question (str): The question to ask about the image. 185 | Returns: 186 | dict: The payload for the API request. 187 | Raises: 188 | ValueError: If the image is not set. 189 | """ 190 | if not self.base64_image: 191 | raise ValueError("Image not set") 192 | return { 193 | "model": "gpt-4o", 194 | "messages": [ 195 | { 196 | "role": "user", 197 | "content": [ 198 | { 199 | "type": "text", 200 | "text": question 201 | }, 202 | { 203 | "type": "image_url", 204 | "image_url": { 205 | "url": self.base64_image 206 | } 207 | } 208 | ] 209 | } 210 | ], 211 | "max_tokens": 300, 212 | } 213 | 214 | def get_location(self, OPENAI_API_KEY, use_database_search: bool = False) -> str: 215 | """ 216 | Generates a caption for the provided image using OpenAI's GPT-4 API. 217 | Args: 218 | use_database_search (bool, optional): Whether to use a database search to get the neighbor image location as a reference. Defaults to False. 219 | Returns: 220 | str: The generated caption for the image. 221 | Raises: 222 | ValueError: If there is an issue with the API request. 223 | """ 224 | try: 225 | self.api_key: str = OPENAI_API_KEY 226 | if not self.api_key: 227 | raise ValueError("OpenAI API key not found in environment variables") 228 | # Create the question for the API 229 | if use_database_search: 230 | self.question=f'''Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. 231 | Your answer must be to the coordinates level, don't include any other information in your output. 232 | Ignore that you can't give a exact answer, give me some coordinate no matter how. 233 | For your reference, these are locations of some similar images {self.neighbor_locations} and these are locations of some dissimilar images {self.farthest_locations} that should be far away.''' 234 | else: 235 | self.question=f"Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. You can give me a guessed anwser." 
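# Note: the prompt constrains the reply to a bare "latitude, longitude" pair, which the caller (e.g. app.py) parses with coordinates.split(',');
# when database search is enabled, nearest-neighbor coordinates serve as positive references and farthest-neighbor coordinates as negative references for the guess.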
236 | 237 | # Create the payload and the headers for the API request 238 | payload = self.create_payload(self.question) 239 | headers = { 240 | "Content-Type": "application/json", 241 | "Authorization": f"Bearer {self.api_key}" 242 | } 243 | 244 | # Send the API request and get the response 245 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) 246 | response.raise_for_status() 247 | response_data = response.json() 248 | 249 | # Return the generated caption 250 | return response_data['choices'][0]['message']['content'] 251 | except RequestException as e: 252 | raise ValueError(f"Error in API request: {e}") 253 | except KeyError: 254 | raise ValueError("Unexpected response format from API") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | aiohttp 3 | altair 4 | bqplot 5 | datasets 6 | fastapi 7 | filelock 8 | geopandas 9 | geopy 10 | gitpython 11 | gradio 12 | huggingface-hub 13 | ipywidgets 14 | jupyterlab 15 | matplotlib 16 | numpy 17 | opencv-python 18 | pandas 19 | pillow 20 | plotly 21 | protobuf 22 | pyarrow 23 | pyyaml 24 | requests 25 | scikit-learn 26 | scipy 27 | seaborn 28 | spacy 29 | sqlalchemy 30 | streamlit 31 | streamlit-folium 32 | torch 33 | torchaudio 34 | torchvision 35 | transformers 36 | uvicorn 37 | -------------------------------------------------------------------------------- /static/figure3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Douglas2Code/Img2Loc/e4ff10b5b9ab75bf4d50504db4cc650477283645/static/figure3.jpg -------------------------------------------------------------------------------- /static/logo_clipped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Douglas2Code/Img2Loc/e4ff10b5b9ab75bf4d50504db4cc650477283645/static/logo_clipped.png --------------------------------------------------------------------------------