├── IndexSearch.py ├── LICENSE ├── README.md ├── aggregate_llm_predictions.py ├── data ├── im2gps3k │ └── put_data_here.txt └── yfcc4k │ └── put_data_here.txt ├── index └── put_index_here.txt ├── llm_predict.py ├── llm_predict_hf.py ├── run_G3.py └── utils ├── G3.py ├── rff ├── functional.py └── layers.py └── utils.py /IndexSearch.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import torch 3 | import numpy as np 4 | import os 5 | import argparse 6 | import pandas as pd 7 | import ast 8 | import itertools 9 | from PIL import Image 10 | from geopy.distance import geodesic 11 | from transformers import CLIPImageProcessor, CLIPModel 12 | from utils.utils import MP16Dataset, im2gps3kDataset, yfcc4kDataset 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from torch.utils.data import Dataset, DataLoader 16 | from datetime import datetime 17 | 18 | def build_index(args): 19 | if args.index == 'g3': 20 | model = torch.load('./checkpoints/g3.pth', map_location='cuda:0') 21 | model.requires_grad_(False) 22 | vision_processor = model.vision_processor 23 | dataset = MP16Dataset(vision_processor = model.vision_processor, text_processor = None) 24 | index_flat = faiss.IndexFlatIP(768*3) 25 | dataloader = DataLoader(dataset, batch_size=1024, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=3) 26 | model.eval() 27 | t= tqdm(dataloader) 28 | for i, (images, texts, longitude, latitude) in enumerate(t): 29 | images = images.to(args.device) 30 | vision_output = model.vision_model(images)[1] 31 | image_embeds = model.vision_projection(vision_output) 32 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 33 | 34 | image_text_embeds = model.vision_projection_else_1(model.vision_projection(vision_output)) 35 | image_text_embeds = image_text_embeds / image_text_embeds.norm(p=2, dim=-1, keepdim=True) 36 | 37 | image_location_embeds = model.vision_projection_else_2(model.vision_projection(vision_output)) 38 | image_location_embeds = image_location_embeds / image_location_embeds.norm(p=2, dim=-1, keepdim=True) 39 | 40 | image_embeds = torch.cat([image_embeds, image_text_embeds, image_location_embeds], dim=1) 41 | index_flat.add(image_embeds.cpu().detach().numpy()) 42 | 43 | faiss.write_index(index_flat, f'./index/{args.index}.index') 44 | 45 | def search_index(args, index, topk): 46 | print('start searching...') 47 | if args.dataset == 'im2gps3k': 48 | if args.index == 'g3': 49 | model = torch.load('./checkpoints/g3.pth', map_location='cuda:0') 50 | model.requires_grad_(False) 51 | vision_processor = model.vision_processor 52 | dataset = im2gps3kDataset(vision_processor = vision_processor, text_processor = None) 53 | dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=5) 54 | test_images_embeds = np.empty((0, 768*3)) 55 | model.eval() 56 | print('generating embeds...') 57 | t = tqdm(dataloader) 58 | for i, (images, texts, longitude, latitude) in enumerate(t): 59 | images = images.to(args.device) 60 | vision_output = model.vision_model(images)[1] 61 | image_embeds = model.vision_projection(vision_output) 62 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 63 | 64 | image_text_embeds = model.vision_projection_else_1(model.vision_projection(vision_output)) 65 | image_text_embeds = image_text_embeds / image_text_embeds.norm(p=2, dim=-1, keepdim=True) 66 | 67 | image_location_embeds = 
model.vision_projection_else_2(model.vision_projection(vision_output)) 68 | image_location_embeds = image_location_embeds / image_location_embeds.norm(p=2, dim=-1, keepdim=True) 69 | 70 | image_embeds = torch.cat([image_embeds, image_text_embeds, image_location_embeds], dim=1) 71 | test_images_embeds = np.concatenate([test_images_embeds, image_embeds.cpu().detach().numpy()], axis=0) 72 | print(test_images_embeds.shape) 73 | test_images_embeds = test_images_embeds.reshape(-1, 768*3) 74 | print('start searching NN...') 75 | D, I = index.search(test_images_embeds, topk) 76 | print(I) 77 | return D, I 78 | elif args.dataset == 'yfcc4k': 79 | if args.index == 'g3': 80 | model = torch.load('./checkpoints/g3.pth', map_location='cuda:0') 81 | model.requires_grad_(False) 82 | vision_processor = model.vision_processor 83 | dataset = yfcc4kDataset(vision_processor = vision_processor, text_processor = None) 84 | dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=5) 85 | test_images_embeds = np.empty((0, 768*3)) 86 | model.eval() 87 | print('generating embeds...') 88 | t = tqdm(dataloader) 89 | for i, (images, texts, longitude, latitude) in enumerate(t): 90 | images = images.to(args.device) 91 | vision_output = model.vision_model(images)[1] 92 | image_embeds = model.vision_projection(vision_output) 93 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 94 | 95 | image_text_embeds = model.vision_projection_else_1(model.vision_projection(vision_output)) 96 | image_text_embeds = image_text_embeds / image_text_embeds.norm(p=2, dim=-1, keepdim=True) 97 | 98 | image_location_embeds = model.vision_projection_else_2(model.vision_projection(vision_output)) 99 | image_location_embeds = image_location_embeds / image_location_embeds.norm(p=2, dim=-1, keepdim=True) 100 | 101 | image_embeds = torch.cat([image_embeds, image_text_embeds, image_location_embeds], dim=1) 102 | test_images_embeds = np.concatenate([test_images_embeds, image_embeds.cpu().detach().numpy()], axis=0) 103 | print(test_images_embeds.shape) 104 | test_images_embeds = test_images_embeds.reshape(-1, 768*3) 105 | print('start searching NN...') 106 | D, I = index.search(test_images_embeds, topk) 107 | return D, I 108 | 109 | class GeoImageDataset(Dataset): 110 | def __init__(self, dataframe, img_folder, topn, vision_processor, database_df, I): 111 | self.dataframe = dataframe 112 | self.img_folder = img_folder 113 | self.topn = topn 114 | self.vision_processor = vision_processor 115 | self.database_df = database_df 116 | self.I = I 117 | 118 | def __len__(self): 119 | return len(self.dataframe) 120 | 121 | def __getitem__(self, idx): 122 | img_path = f'{self.img_folder}/{self.dataframe.loc[idx, "IMG_ID"]}' 123 | image = Image.open(img_path).convert('RGB') 124 | image = self.vision_processor(images=image, return_tensors='pt')['pixel_values'].reshape(3,224,224) 125 | 126 | gps_data = [] 127 | search_top1_latitude, search_top1_longitude = self.database_df.loc[self.I[idx][0], ['LAT', 'LON']].values 128 | rag_5, rag_10, rag_15, zs = [],[],[],[] 129 | for j in range(self.topn): 130 | gps_data.extend([ 131 | float(self.dataframe.loc[idx, f'5_rag_{j}_latitude']), 132 | float(self.dataframe.loc[idx, f'5_rag_{j}_longitude']), 133 | float(self.dataframe.loc[idx, f'10_rag_{j}_latitude']), 134 | float(self.dataframe.loc[idx, f'10_rag_{j}_longitude']), 135 | float(self.dataframe.loc[idx, f'15_rag_{j}_latitude']), 136 | float(self.dataframe.loc[idx, f'15_rag_{j}_longitude']), 137 
| float(self.dataframe.loc[idx, f'zs_{j}_latitude']), 138 | float(self.dataframe.loc[idx, f'zs_{j}_longitude']), 139 | search_top1_latitude, 140 | search_top1_longitude 141 | ]) 142 | 143 | gps_data = np.array(gps_data).reshape(-1, 2) 144 | return image, gps_data, idx 145 | 146 | def evaluate(args, I): 147 | print('start evaluation') 148 | if args.database == 'mp16': 149 | database = args.database_df 150 | df = args.dataset_df 151 | df['NN_idx'] = I[:, 0] 152 | df['LAT_pred'] = df.apply(lambda x: database.loc[x['NN_idx'],'LAT'], axis=1) 153 | df['LON_pred'] = df.apply(lambda x: database.loc[x['NN_idx'],'LON'], axis=1) 154 | 155 | df_llm = pd.read_csv(f'./data/{args.dataset}/{args.dataset}_prediction.csv') 156 | model = torch.load('./checkpoints/g3.pth', map_location='cuda:0') 157 | topn = 5 # number of candidates 158 | 159 | dataset = GeoImageDataset(df_llm, f'./data/{args.dataset}/images', topn, vision_processor=model.vision_processor, database_df=database, I=I) 160 | data_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True) 161 | 162 | for images, gps_batch, indices in tqdm(data_loader): 163 | images = images.to(args.device) 164 | image_embeds = model.vision_projection_else_2(model.vision_projection(model.vision_model(images)[1])) 165 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) # b, 768 166 | 167 | gps_batch = gps_batch.to(args.device) 168 | gps_input = gps_batch.clone().detach() 169 | b, c, _ = gps_input.shape 170 | gps_input = gps_input.reshape(b*c, 2) 171 | location_embeds = model.location_encoder(gps_input) 172 | location_embeds = model.location_projection_else(location_embeds.reshape(b*c, -1)) 173 | location_embeds = location_embeds / location_embeds.norm(p=2, dim=-1, keepdim=True) 174 | location_embeds = location_embeds.reshape(b, c, -1) # b, c, 768 175 | 176 | similarity = torch.matmul(image_embeds.unsqueeze(1), location_embeds.permute(0, 2, 1)) # b, 1, c 177 | similarity = similarity.squeeze(1).cpu().detach().numpy() 178 | max_idxs = np.argmax(similarity, axis=1) 179 | 180 | # update DataFrame 181 | for i, max_idx in enumerate(max_idxs): 182 | final_idx = indices[i] 183 | final_idx = final_idx.item() 184 | final_latitude, final_longitude = gps_batch[i][max_idx] 185 | final_latitude, final_longitude = final_latitude.item(), final_longitude.item() 186 | if final_latitude < -90 or final_latitude > 90: 187 | final_latitude = 0 188 | if final_longitude < -180 or final_longitude > 180: 189 | final_longitude = 0 190 | df.loc[final_idx, 'LAT_pred'] = final_latitude 191 | df.loc[final_idx, 'LON_pred'] = final_longitude 192 | 193 | df['geodesic'] = df.apply(lambda x: geodesic((x['LAT'], x['LON']), (x['LAT_pred'], x['LON_pred'])).km, axis=1) 194 | print(df.head()) 195 | df.to_csv(f'./data/{args.dataset}_{args.index}_results.csv', index=False) 196 | 197 | # 1, 25, 200, 750, 2500 km level 198 | print('2500km level: ', df[df['geodesic'] < 2500].shape[0] / df.shape[0]) 199 | print('750km level: ', df[df['geodesic'] < 750].shape[0] / df.shape[0]) 200 | print('200km level: ', df[df['geodesic'] < 200].shape[0] / df.shape[0]) 201 | print('25km level: ', df[df['geodesic'] < 25].shape[0] / df.shape[0]) 202 | print('1km level: ', df[df['geodesic'] < 1].shape[0] / df.shape[0]) 203 | 204 | if __name__ == '__main__': 205 | 206 | res = faiss.StandardGpuResources() 207 | 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--index', type=str, default='g3') 210 | parser.add_argument('--dataset', type=str, 
default='im2gps3k') 211 | parser.add_argument('--database', type=str, default='mp16') 212 | args = parser.parse_args() 213 | if args.dataset == 'im2gps3k': 214 | args.dataset_df = pd.read_csv('./data/im2gps3k/im2gps3k_places365.csv') 215 | elif args.dataset == 'yfcc4k': 216 | args.dataset_df = pd.read_csv('./data/yfcc4k/yfcc4k_places365.csv') 217 | 218 | if args.database == 'mp16': 219 | args.database_df = pd.read_csv('./data/MP16_Pro_filtered.csv') 220 | 221 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 222 | 223 | if not os.path.exists(f'./index'): os.makedirs(f'./index') 224 | if not os.path.exists(f'./index/{args.index}.index'): 225 | build_index(args) 226 | else: 227 | # gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index) 228 | if not os.path.exists(f'./index/I_{args.index}_{args.dataset}.npy'): 229 | index = faiss.read_index(f'./index/{args.index}.index') 230 | print('read index success') 231 | D,I = search_index(args, index, 20) 232 | np.save(f'./index/D_{args.index}_{args.dataset}.npy', D) 233 | np.save(f'./index/I_{args.index}_{args.dataset}.npy', I) 234 | else: 235 | D = np.load(f'./index/D_{args.index}_{args.dataset}.npy') 236 | I = np.load(f'./index/I_{args.index}_{args.dataset}.npy') 237 | evaluate(args, I) 238 | 239 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the code repository for the paper "G3: An Effective and Adaptive Framework for Worldwide Geolocalization Using Large Multi-Modality Models". 2 | 3 | # MP16-Pro 4 | 5 | You can download the images and metadata of MP16-Pro from Hugging Face: [Jia-py/MP16-Pro](https://huggingface.co/datasets/Jia-py/MP16-Pro/tree/main) 6 | 7 | # Data 8 | 9 | IM2GPS3K: [images](http://www.mediafire.com/file/7ht7sn78q27o9we/im2gps3ktest.zip) | [metadata](https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/meta/im2gps3k_places365.csv) 10 | 11 | YFCC4K: [images](http://www.mediafire.com/file/3og8y3o6c9de3ye/yfcc4k.zip) | [metadata](https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/yfcc25600_places365.csv) 12 | 13 | # Environment Setup 14 | 15 | ```bash 16 | # tested on CUDA 12.0 17 | conda create -n g3 python=3.9 18 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121 19 | pip install transformers accelerate huggingface_hub pandas 20 | ``` 21 | 22 | # Running samples 23 | 24 | 1. Geo-alignment 25 | 26 | You can run `python run_G3.py` to train the model. 27 | 28 | 2. Geo-diversification 29 | 30 | First, you need to build the index file using `python IndexSearch.py`. 31 | 32 | Parameters in `IndexSearch.py`: 33 | - index --> which model you want to use for embedding 34 | - dataset --> im2gps3k or yfcc4k 35 | - database --> default mp16 36 | 37 | Then, you also need to construct an index for negative samples by modifying `images_embeds` to `-1 * images_embeds` (see the sketch after this section). 38 | 39 | Then, you can run `llm_predict_hf.py` or `llm_predict.py` to generate LLM predictions. 40 | 41 | After that, run `aggregate_llm_predictions.py` to aggregate the predictions. 42 | 43 | 3. Geo-verification 44 | 45 | Run `python IndexSearch.py --index=g3 --dataset=im2gps3k` (or `--dataset=yfcc4k`) to verify the predictions and evaluate.
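The negative-sample index in step 2 above is only described in prose; below is a minimal sketch of one way to obtain it, assuming the flat inner-product index built by `IndexSearch.py` and the `I_<index>_<dataset>_reverse.npy` file name that `llm_predict.py` and `llm_predict_hf.py` load. The helper name `save_reverse_neighbours` is illustrative and not part of the repository.

```python
import faiss
import numpy as np

def save_reverse_neighbours(index_path, query_embeds, out_path, topk=20):
    """Search an inner-product index with negated queries, so the returned
    neighbours are the least similar database images (the 'dissimilar'
    coordinates referenced in the RAG prompts)."""
    index = faiss.read_index(index_path)
    queries = np.ascontiguousarray(-1.0 * query_embeds, dtype=np.float32)
    _, I_rev = index.search(queries, topk)
    np.save(out_path, I_rev)

# Example, using the test embeddings computed inside search_index():
# save_reverse_neighbours('./index/g3.index', test_images_embeds,
#                         './index/I_g3_im2gps3k_reverse.npy')
```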
46 | 47 | # Citation 48 | 49 | ```bib 50 | @article{jia2024g3, 51 | title={G3: an effective and adaptive framework for worldwide geolocalization using large multi-modality models}, 52 | author={Jia, Pengyue and Liu, Yiding and Li, Xiaopeng and Zhao, Xiangyu and Wang, Yuhao and Du, Yantong and Han, Xiao and Wei, Xuetao and Wang, Shuaiqiang and Yin, Dawei}, 53 | journal={Advances in Neural Information Processing Systems}, 54 | volume={37}, 55 | pages={53198--53221}, 56 | year={2024} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /aggregate_llm_predictions.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | import ast 4 | from tqdm import tqdm 5 | 6 | df_raw = pd.read_csv('./data/im2gps3k/im2gps3k_places365.csv') 7 | zs_df = pd.read_csv('./data/im2gps3k/llm_predict_results_zs.csv') 8 | rag_5_df = pd.read_csv('./data/im2gps3k/5_llm_predict_results_rag.csv') 9 | rag_10_df = pd.read_csv('./data/im2gps3k/10_llm_predict_results_rag.csv') 10 | rag_15_df = pd.read_csv('./data/im2gps3k/15_llm_predict_results_rag.csv') 11 | 12 | pattern = r'[-+]?\d+\.\d+' 13 | 14 | for i in tqdm(range(zs_df.shape[0])): 15 | response = zs_df.loc[i, 'response'] 16 | response = ast.literal_eval(response) 17 | for idx, content in enumerate(response): 18 | try: 19 | match = re.findall(pattern, content) 20 | latitude = match[0] 21 | longitude = match[1] 22 | df_raw.loc[i, f'zs_{idx}_latitude'] = latitude 23 | df_raw.loc[i, f'zs_{idx}_longitude'] = longitude 24 | except: 25 | df_raw.loc[i, f'zs_{idx}_latitude'] = '0.0' 26 | df_raw.loc[i, f'zs_{idx}_longitude'] = '0.0' 27 | 28 | for i in tqdm(range(df_raw.shape[0])): 29 | response = rag_5_df.loc[i, 'rag_response'] 30 | response = ast.literal_eval(response) 31 | for idx, content in enumerate(response): 32 | try: 33 | match = re.findall(pattern, content) 34 | latitude = match[0] 35 | longitude = match[1] 36 | df_raw.loc[i, f'5_rag_{idx}_latitude'] = latitude 37 | df_raw.loc[i, f'5_rag_{idx}_longitude'] = longitude 38 | except: 39 | df_raw.loc[i, f'5_rag_{idx}_latitude'] = '0.0' 40 | df_raw.loc[i, f'5_rag_{idx}_longitude'] = '0.0' 41 | 42 | for i in tqdm(range(df_raw.shape[0])): 43 | response = rag_10_df.loc[i, 'rag_response'] 44 | response = ast.literal_eval(response) 45 | for idx, content in enumerate(response): 46 | try: 47 | match = re.findall(pattern, content) 48 | latitude = match[0] 49 | longitude = match[1] 50 | df_raw.loc[i, f'10_rag_{idx}_latitude'] = latitude 51 | df_raw.loc[i, f'10_rag_{idx}_longitude'] = longitude 52 | except: 53 | df_raw.loc[i, f'10_rag_{idx}_latitude'] = '0.0' 54 | df_raw.loc[i, f'10_rag_{idx}_longitude'] = '0.0' 55 | 56 | for i in tqdm(range(df_raw.shape[0])): 57 | response = rag_15_df.loc[i, 'rag_response'] 58 | response = ast.literal_eval(response) 59 | for idx, content in enumerate(response): 60 | try: 61 | match = re.findall(pattern, content) 62 | latitude = match[0] 63 | longitude = match[1] 64 | df_raw.loc[i, f'15_rag_{idx}_latitude'] = latitude 65 | df_raw.loc[i, f'15_rag_{idx}_longitude'] = longitude 66 | except: 67 | df_raw.loc[i, f'15_rag_{idx}_latitude'] = '0.0' 68 | df_raw.loc[i, f'15_rag_{idx}_longitude'] = '0.0' 69 | 70 | df_raw.to_csv('./data/im2gps3k/im2gps3k_prediction.csv', index=False) 71 | -------------------------------------------------------------------------------- /data/im2gps3k/put_data_here.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/G3/0c6b0193af2f0e4904c105f6ee1dfa780395bfe5/data/im2gps3k/put_data_here.txt -------------------------------------------------------------------------------- /data/yfcc4k/put_data_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/G3/0c6b0193af2f0e4904c105f6ee1dfa780395bfe5/data/yfcc4k/put_data_here.txt -------------------------------------------------------------------------------- /index/put_index_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/G3/0c6b0193af2f0e4904c105f6ee1dfa780395bfe5/index/put_index_here.txt -------------------------------------------------------------------------------- /llm_predict.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | import os 4 | import re 5 | import pandas as pd 6 | import numpy as np 7 | import ast 8 | from pandarallel import pandarallel 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | def encode_image(image_path): 13 | with open(image_path, "rb") as image_file: 14 | return base64.b64encode(image_file.read()).decode('utf-8') 15 | 16 | def get_response(image_path, base_url, api_key, model_name, detail="low", max_tokens=200, temperature=1.2, n=10): 17 | base64_image = encode_image(image_path) 18 | headers = { 19 | "Content-Type": "application/json", 20 | "Authorization": f"Bearer {api_key}" 21 | } 22 | 23 | payload = { 24 | "model": model_name, 25 | "messages": [ 26 | { 27 | "role": "user", 28 | "content": [ 29 | { 30 | "type": "text", 31 | "text": """Suppose you are an expert in geo-localization, you have the ability to give two number GPS coordination given an image. 32 | Please give me the location of the given image. 33 | Remember, you must have an answer, just output your best guess, don't answer me that you can't give a location. 34 | Your answer should be in the following JSON format without any other information: {"latitude": float,"longitude": float}. 35 | Your answer should be in the following JSON format without any other information: {"latitude": float,"longitude": float}.""" 36 | }, 37 | { 38 | "type": "image_url", 39 | "image_url": { 40 | "url": f"data:image/jpeg;base64,{base64_image}", 41 | "detail": detail 42 | } 43 | } 44 | ] 45 | } 46 | ], 47 | "max_tokens": max_tokens, 48 | "temperature": temperature, 49 | "n": n 50 | } 51 | 52 | response = requests.post(base_url, headers=headers, json=payload, timeout=(30,60)) 53 | ans = [] 54 | for choice in response.json()['choices']: 55 | try: 56 | ans.append(choice['message']['content']) 57 | except: 58 | ans.append('{"latitude": 0.0,"longitude": 0.0}') 59 | return ans 60 | 61 | def get_response_rag(image_path, base_url, api_key, model_name, candidates_gps, reverse_gps, detail="low", max_tokens=200, temperature=1.2, n=10): 62 | base64_image = encode_image(image_path) 63 | headers = { 64 | "Content-Type": "application/json", 65 | "Authorization": f"Bearer {api_key}" 66 | } 67 | 68 | payload = { 69 | "model": model_name, 70 | "messages": [ 71 | { 72 | "role": "user", 73 | "content": [ 74 | { 75 | "type": "text", 76 | "text": f"""Suppose you are an expert in geo-localization, Please analyze this image and give me a guess of the location. 77 | Your answer must be to the coordinates level in (latitude, longitude) format. 
78 | For your reference, these are coordinates of some similar images: {candidates_gps}, and these are coordinates of some dissimilar images: {reverse_gps}. 79 | Remember, you must have an answer, just output your best guess, don't answer me that you can't give an location. 80 | Your answer should be in the following JSON format without any other information: {{"latitude": float,"longitude": float}}. 81 | Your answer should be in the following JSON format without any other information: {{"latitude": float,"longitude": float}}. 82 | """ 83 | }, 84 | { 85 | "type": "image_url", 86 | "image_url": { 87 | "url": f"data:image/jpeg;base64,{base64_image}", 88 | "detail": detail 89 | } 90 | } 91 | ] 92 | } 93 | ], 94 | "max_tokens": max_tokens, 95 | "temperature": temperature, 96 | "n": n 97 | } 98 | 99 | response = requests.post(base_url, headers=headers, json=payload, timeout=(30,60)) 100 | ans = [] 101 | for choice in response.json()['choices']: 102 | try: 103 | ans.append(choice['message']['content']) 104 | except: 105 | ans.append('{"latitude": 0.0,"longitude": 0.0}') 106 | return ans 107 | 108 | def process_row(row, base_url, api_key, model_name, root_path, image_path): 109 | image_path = os.path.join(root_path, image_path, row["IMG_ID"]) 110 | try: 111 | response = get_response(image_path, base_url, api_key, model_name) 112 | except Exception as e: 113 | response = "None" 114 | print(e) 115 | row['response'] = response 116 | return row 117 | 118 | def process_row_rag(row, base_url, api_key, model_name, root_path, image_path, rag_sample_num): 119 | image_path = os.path.join(root_path, image_path, row["IMG_ID"]) 120 | try: 121 | #candidates_gps = [eval(row[f'candidate_{i}_gps']) for i in range(rag_sample_num)] 122 | candidates_gps = [row[f'candidate_{i}_gps'] for i in range(rag_sample_num)] 123 | candidates_gps = str(candidates_gps) 124 | #reverse_gps = [eval(row[f'reverse_{i}_gps']) for i in range(rag_sample_num)] 125 | reverse_gps = [row[f'reverse_{i}_gps'] for i in range(rag_sample_num)] 126 | reverse_gps = str(reverse_gps) 127 | response = get_response_rag(image_path, base_url, api_key, model_name, candidates_gps, reverse_gps) 128 | except Exception as e: 129 | response = "None" 130 | print(e) 131 | row['rag_response'] = response 132 | return row 133 | 134 | def check_conditions(coord_str): 135 | if coord_str.startswith('[]') or coord_str.startswith('None'): 136 | return True 137 | try: 138 | coordinates = ast.literal_eval(coord_str) 139 | return float(coordinates[0]) == 0.0 140 | except: 141 | return False 142 | def run(args): 143 | api_key = args.api_key 144 | model_name = args.model_name 145 | base_url = args.base_url 146 | root_path = args.root_path 147 | text_path = args.text_path 148 | image_path = args.image_path 149 | result_path = args.result_path 150 | rag_path = args.rag_path 151 | process = args.process 152 | rag_sample_num = args.rag_sample_num 153 | searching_file_name = args.searching_file_name 154 | 155 | if process == 'predict': 156 | if os.path.exists(os.path.join(root_path, result_path)): 157 | df = pd.read_csv(os.path.join(root_path, result_path)) 158 | df_rerun = df[df['response'].isna()] 159 | print('Need Rerun:', df_rerun.shape[0]) 160 | df_rerun = df_rerun.parallel_apply(lambda row: process_row(row, base_url, api_key, model_name, root_path, image_path), axis=1) 161 | df.update(df_rerun) 162 | df.to_csv(os.path.join(root_path, result_path), index=False) 163 | else: 164 | df = pd.read_csv(os.path.join(root_path, text_path)) 165 | df = df.parallel_apply(lambda row: 
process_row(row, base_url, api_key, model_name, root_path, image_path), axis=1) 166 | df.to_csv(os.path.join(root_path, result_path), index=False) 167 | 168 | if process == 'extract': 169 | df = pd.read_csv(os.path.join(root_path, result_path)) 170 | pattern = r'[-+]?\d+\.\d+' 171 | df['coordinates'] = df['response'].apply(lambda x: re.findall(pattern, x)) 172 | df.to_csv(os.path.join(root_path, result_path), index=False) 173 | 174 | if process == 'rag': 175 | database_df = pd.read_csv('./data/MP16_Pro_filtered.csv') 176 | if not os.path.exists(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path)): 177 | df = pd.read_csv(os.path.join(root_path, text_path)) 178 | I = np.load('./index/{}.npy'.format(searching_file_name)) 179 | reverse_I = np.load('./index/{}_reverse.npy'.format(searching_file_name)) 180 | for i in tqdm(range(df.shape[0])): 181 | candidate_idx_lis = I[i] 182 | candidate_gps = database_df.loc[candidate_idx_lis, ['LAT', 'LON', 'city', 'state', 'country']].values 183 | for idx, (latitude, longitude, city, state, country) in enumerate(candidate_gps): 184 | df.loc[i, f'candidate_{idx}_gps'] = f'[{latitude}, {longitude}]' 185 | reverse_idx_lis = reverse_I[i] 186 | reverse_gps = database_df.loc[reverse_idx_lis, ['LAT', 'LON', 'city', 'state', 'country']].values 187 | for idx, (latitude, longitude, city, state, country) in enumerate(reverse_gps): 188 | df.loc[i, f'reverse_{idx}_gps'] = f'[{latitude}, {longitude}]' 189 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 190 | df = df.parallel_apply(lambda row: process_row_rag(row, base_url, api_key, model_name, root_path, image_path, rag_sample_num), axis=1) 191 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 192 | else: 193 | df = pd.read_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path)) 194 | # df_rerun = df[df['rag_coordinates'].apply(check_conditions)] 195 | df_rerun = df[df['rag_response'].isna()] 196 | print('Need Rerun:', df_rerun.shape[0]) 197 | df_rerun = df_rerun.parallel_apply(lambda row: process_row_rag(row, base_url, api_key, model_name, root_path, image_path, rag_sample_num), axis=1) 198 | df.update(df_rerun) 199 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 200 | 201 | if process == 'rag_extract': 202 | df = pd.read_csv(os.path.join(root_path, rag_path)).fillna("None") 203 | pattern = r'[-+]?\d+\.\d+' 204 | df['rag_coordinates'] = df['rag_response'].apply(lambda x: re.findall(pattern, x)) 205 | df.to_csv(os.path.join(root_path, rag_path), index=False) 206 | 207 | if __name__ == '__main__': 208 | args = argparse.ArgumentParser() 209 | api_key = "sk-xxx" 210 | model_name = "gpt-xxx" # gpt-4-vision-preview, gpt-4-turbo-2024-04-09 211 | base_url = "https://xxx" 212 | 213 | root_path = "./data/im2gps3k" 214 | text_path = "im2gps3k_places365.csv" 215 | image_path = "images" 216 | result_path = "llm_predict_results_zs.csv" 217 | rag_path = "llm_predict_results_rag.csv" 218 | process = 'rag' # predict, extract, rag, rag_extract 219 | rag_sample_num = 15 220 | searching_file_name = 'I_g3_im2gps3k' 221 | 222 | pandarallel.initialize(progress_bar=True, nb_workers=16) 223 | args.add_argument('--api_key', type=str, default=api_key) 224 | args.add_argument('--model_name', type=str, default=model_name) 225 | args.add_argument('--base_url', type=str, default=base_url) 226 | args.add_argument('--root_path', type=str, default=root_path) 227 | args.add_argument('--text_path', type=str, 
default=text_path) 228 | args.add_argument('--image_path', type=str, default=image_path) 229 | args.add_argument('--result_path', type=str, default=result_path) 230 | args.add_argument('--rag_path', type=str, default=rag_path) 231 | args.add_argument('--process', type=str, default=process) 232 | args.add_argument('--rag_sample_num', type=int, default=rag_sample_num) 233 | args.add_argument('--searching_file_name', type=str, default=searching_file_name) 234 | args = args.parse_args() 235 | print(args) 236 | 237 | run(args) 238 | 239 | 240 | -------------------------------------------------------------------------------- /llm_predict_hf.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | import os 4 | import re 5 | import pandas as pd 6 | import numpy as np 7 | import ast 8 | from pandarallel import pandarallel 9 | from tqdm import tqdm 10 | import json 11 | import time 12 | import argparse 13 | from retrying import retry 14 | from PIL import Image 15 | import torch 16 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 17 | import datetime 18 | 19 | def encode_image(image_path): 20 | with open(image_path, "rb") as image_file: 21 | return base64.b64encode(image_file.read()).decode('utf-8') 22 | 23 | def get_response(image_path, model, processor, max_tokens=200, temperature=0.7, n=10): 24 | image = Image.open(image_path) 25 | conversation = [ 26 | { 27 | 28 | "role": "user", 29 | "content": [ 30 | {"type": "text", "text": '''Suppose you are an expert in geo-localization, you have the ability to give two number GPS coordination given an image. 31 | Please give me the location of the given image. 32 | Remember, you must have an answer, just output your best guess, don't answer me that you can't give a location. 33 | Your answer should be in the following JSON format without any other information: {"latitude": float,"longitude": float}. 34 | Your answer should be in the following JSON format without any other information: {"latitude": float,"longitude": float}.'''}, 35 | {"type": "image"}, 36 | ], 37 | }, 38 | ] 39 | prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) 40 | inputs = processor(prompt, image, return_tensors="pt").to(model.device) 41 | output = model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature, num_return_sequences=n, do_sample=True, pad_token_id=processor.tokenizer.pad_token_id) 42 | ans = [] 43 | dialogue = processor.batch_decode(output, skip_special_tokens=True) 44 | for i in range(n): 45 | assistant_reply = dialogue[i].split("assistant")[-1].strip() 46 | ans.append(assistant_reply) 47 | return ans 48 | 49 | def get_response_rag(image_path, model, processor, candidates_gps, reverse_gps, max_tokens=200, temperature=0.7, n=10): 50 | image = Image.open(image_path) 51 | conversation = [ 52 | { 53 | 54 | "role": "user", 55 | "content": [ 56 | {"type": "text", "text": f"""Suppose you are an expert in geo-localization, Please analyze this image and give me a guess of the location. 57 | Your answer must be to the coordinates level in (latitude, longitude) format. 58 | For your reference, these are coordinates of some similar images: {candidates_gps}, and these are coordinates of some dissimilar images: {reverse_gps}. 59 | Remember, you must have an answer, just output your best guess, don't answer me that you can't give an location. 
60 | Your answer should be in the following JSON format without any other information: {{"latitude": float,"longitude": float}}. 61 | Your answer should be in the following JSON format without any other information: {{"latitude": float,"longitude": float}}. 62 | """}, 63 | {"type": "image"}, 64 | ], 65 | }, 66 | ] 67 | 68 | prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) 69 | inputs = processor(prompt, image, return_tensors="pt").to(model.device) 70 | output = model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature, num_return_sequences=n, do_sample=True, pad_token_id=processor.tokenizer.pad_token_id) 71 | ans = [] 72 | dialogue = processor.batch_decode(output, skip_special_tokens=True) 73 | for i in range(n): 74 | assistant_reply = dialogue[i].split("assistant")[-1].strip() 75 | ans.append(assistant_reply) 76 | return ans 77 | 78 | def process_row(row, model, processor, root_path, image_path): 79 | image_path = os.path.join(root_path, image_path, row["IMG_ID"]) 80 | try: 81 | response = get_response(image_path, model, processor) 82 | except Exception as e: 83 | response = "None" 84 | print(e) 85 | row['response'] = response 86 | return row 87 | 88 | def process_row_rag(row, model, processor, root_path, image_path, rag_sample_num): 89 | image_path = os.path.join(root_path, image_path, row["IMG_ID"]) 90 | try: 91 | candidates_gps = [eval(row[f'candidate_{i}_gps']) for i in range(rag_sample_num)] 92 | candidates_gps = str(candidates_gps) 93 | reverse_gps = [eval(row[f'reverse_{i}_gps']) for i in range(rag_sample_num)] 94 | reverse_gps = str(reverse_gps) 95 | response = get_response_rag(image_path, model, processor, candidates_gps, reverse_gps) 96 | except Exception as e: 97 | response = "None" 98 | print(e) 99 | row['rag_response'] = response 100 | return row 101 | 102 | def run(args): 103 | root_path = args.root_path 104 | text_path = args.text_path 105 | image_path = args.image_path 106 | result_path = args.result_path 107 | rag_path = args.rag_path 108 | process = args.process 109 | rag_sample_num = args.rag_sample_num 110 | searching_file_name = args.searching_file_name 111 | model = args.model 112 | processor = args.processor 113 | tqdm.pandas(desc='') 114 | 115 | if process == 'predict': 116 | if os.path.exists(os.path.join(root_path, result_path)): 117 | df = pd.read_csv(os.path.join(root_path, result_path)) 118 | df_rerun = df[df['response'].isna()] 119 | print('Need Rerun:', df_rerun.shape[0]) 120 | df_rerun = df_rerun.progress_apply(lambda row: process_row(row, model, processor, root_path, image_path), axis=1) 121 | df.update(df_rerun) 122 | df.to_csv(os.path.join(root_path, result_path), index=False) 123 | else: 124 | df = pd.read_csv(os.path.join(root_path, text_path)) 125 | df = df.progress_apply(lambda row: process_row(row, model, processor, root_path, image_path), axis=1) 126 | df.to_csv(os.path.join(root_path, result_path), index=False) 127 | 128 | if process == 'extract': 129 | df = pd.read_csv(os.path.join(root_path, result_path)) 130 | pattern = r'[-+]?\d+\.\d+' 131 | df['coordinates'] = df['response'].apply(lambda x: re.findall(pattern, x)) 132 | df.to_csv(os.path.join(root_path, result_path), index=False) 133 | 134 | if process == 'rag': 135 | database_df = pd.read_csv('./data/MP16_Pro_filtered.csv') 136 | if not os.path.exists(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path)): 137 | df = pd.read_csv(os.path.join(root_path, text_path)) 138 | I = np.load('./index/{}.npy'.format(searching_file_name)) 139 | 
reverse_I = np.load('./index/{}_reverse.npy'.format(searching_file_name)) 140 | for i in tqdm(range(df.shape[0])): 141 | candidate_idx_lis = I[i] 142 | candidate_gps = database_df.loc[candidate_idx_lis, ['LAT', 'LON']].values 143 | for idx, (latitude, longitude) in enumerate(candidate_gps): 144 | df.loc[i, f'candidate_{idx}_gps'] = f'[{latitude}, {longitude}]' 145 | reverse_idx_lis = reverse_I[i] 146 | reverse_gps = database_df.loc[reverse_idx_lis, ['LAT', 'LON']].values 147 | for idx, (latitude, longitude) in enumerate(reverse_gps): 148 | df.loc[i, f'reverse_{idx}_gps'] = f'[{latitude}, {longitude}]' 149 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 150 | df = df.progress_apply(lambda row: process_row_rag(row, model, processor, root_path, image_path, rag_sample_num), axis=1) 151 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 152 | else: 153 | df = pd.read_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path)) 154 | # df_rerun = df[df['rag_coordinates'].apply(check_conditions)] 155 | df_rerun = df[df['rag_response'].isna()] 156 | print('Need Rerun:', df_rerun.shape[0]) 157 | df_rerun = df_rerun.progress_apply(lambda row: process_row_rag(row, model, processor, root_path, image_path, rag_sample_num), axis=1) 158 | df.update(df_rerun) 159 | df.to_csv(os.path.join(root_path, str(rag_sample_num) + '_' + rag_path), index=False) 160 | 161 | if process == 'rag_extract': 162 | df = pd.read_csv(os.path.join(root_path, rag_path)).fillna("None") 163 | pattern = r'[-+]?\d+\.\d+' 164 | df['rag_coordinates'] = df['rag_response'].apply(lambda x: re.findall(pattern, x)) 165 | df.to_csv(os.path.join(root_path, rag_path), index=False) 166 | 167 | 168 | if __name__ == '__main__': 169 | args = argparse.ArgumentParser() 170 | model_path = "./llava-next-8b-llama3" 171 | root_path = "./data/im2gps3k" 172 | text_path = "im2gps3k_places365.csv" 173 | image_path = "images" 174 | result_path = "llm_predict_results_zs_llava.csv" 175 | rag_path = "llm_predict_results_rag_llava.csv" 176 | process = 'predict' # predict, extract, rag, rag_extract, select, selected_extract 177 | rag_sample_num = 5 178 | searching_file_name = 'I_g3_im2gps3k' 179 | 180 | processor = LlavaNextProcessor.from_pretrained(model_path) 181 | model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto") 182 | 183 | # pandarallel.initialize(progress_bar=True, nb_workers=4) 184 | args.add_argument('--root_path', type=str, default=root_path) 185 | args.add_argument('--text_path', type=str, default=text_path) 186 | args.add_argument('--image_path', type=str, default=image_path) 187 | args.add_argument('--result_path', type=str, default=result_path) 188 | args.add_argument('--rag_path', type=str, default=rag_path) 189 | args.add_argument('--process', type=str, default=process) 190 | args.add_argument('--rag_sample_num', type=int, default=rag_sample_num) 191 | args.add_argument('--searching_file_name', type=str, default=searching_file_name) 192 | args = args.parse_args() 193 | print(args) 194 | args.model = model 195 | args.processor = processor 196 | 197 | run(args) -------------------------------------------------------------------------------- /run_G3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import time 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | from utils.utils import 
MP16Dataset 8 | from utils.G3 import G3 9 | from accelerate import Accelerator, DistributedDataParallelKwargs 10 | import warnings 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | def train_1epoch(dataloader, eval_dataloader, earlystopper, model, vision_processor, text_processor, optimizer, scheduler, device, accelerator=None): 15 | model.train() 16 | t = tqdm(dataloader, disable=not accelerator.is_local_main_process) 17 | for i, (images, texts, longitude, latitude) in enumerate(t): 18 | texts = text_processor(text=texts, padding='max_length', truncation=True, return_tensors='pt', max_length=77) 19 | images = images.to(device) 20 | texts = texts.to(device) 21 | longitude = longitude.to(device).float() 22 | latitude = latitude.to(device).float() 23 | optimizer.zero_grad() 24 | 25 | output = model(images, texts, longitude, latitude, return_loss=True) 26 | loss = output['loss'] 27 | 28 | # loss.backward() 29 | accelerator.backward(loss) 30 | optimizer.step() 31 | if i % 1 == 0: 32 | t.set_description('step {}, loss {}, lr {}'.format(i, loss.item(), scheduler.get_last_lr()[0])) 33 | scheduler.step() 34 | 35 | 36 | def main(): 37 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 38 | accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) 39 | 40 | # fine-tune 41 | device = "cuda" if torch.cuda.is_available() else "cpu" 42 | # device = 'cpu' 43 | model = G3(device).to(device) 44 | location_encoder_dict = torch.load('location_encoder.pth') # from geoclip 45 | model.location_encoder.load_state_dict(location_encoder_dict) 46 | 47 | dataset = MP16Dataset(vision_processor = model.vision_processor, text_processor = model.text_processor) 48 | dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=5) 49 | 50 | 51 | params = [] 52 | for name, param in model.named_parameters(): 53 | if param.requires_grad: 54 | print(name, param.size()) 55 | params.append(param) 56 | 57 | optimizer = torch.optim.AdamW([param for name,param in model.named_parameters() if param.requires_grad], lr=3e-5, weight_decay=1e-6) 58 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.87) 59 | 60 | model, optimizer, dataloader, scheduler = accelerator.prepare( 61 | model, optimizer, dataloader, scheduler 62 | ) 63 | 64 | eval_dataloader = None 65 | earlystopper = None 66 | for epoch in range(10): 67 | train_1epoch(dataloader, eval_dataloader, earlystopper, model, model.vision_processor, model.text_processor, optimizer, scheduler, device, accelerator) 68 | unwrapped_model = accelerator.unwrap_model(model) 69 | torch.save(unwrapped_model, 'checkpoints/g3_{}_.pth'.format(epoch)) 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /utils/G3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pandas as pd 5 | import itertools 6 | from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel 7 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 8 | from .rff.layers import GaussianEncoding 9 | from pyproj import Proj, Transformer 10 | 11 | class LocationEncoderCapsule(nn.Module): 12 | def __init__(self, sigma): 13 | super(LocationEncoderCapsule, self).__init__() 14 | rff_encoding = GaussianEncoding(sigma=sigma, input_size=2, encoded_size=256) 15 | self.km = sigma 16 | self.capsule = nn.Sequential(rff_encoding, 17 | 
nn.Linear(512, 1024), 18 | nn.ReLU(), 19 | nn.Linear(1024, 1024), 20 | nn.ReLU(), 21 | nn.Linear(1024, 1024), 22 | nn.ReLU()) 23 | self.head = nn.Sequential(nn.Linear(1024, 512)) 24 | 25 | def forward(self, x): 26 | x = self.capsule(x) 27 | x = self.head(x) 28 | return x 29 | 30 | class CustomLocationEncoder(nn.Module): 31 | def __init__(self, sigma=[2**0, 2**4, 2**8]): 32 | super(CustomLocationEncoder, self).__init__() 33 | 34 | self.sigma = sigma 35 | self.n = len(self.sigma) 36 | 37 | for i, s in enumerate(self.sigma): 38 | self.add_module('LocEnc' + str(i), LocationEncoderCapsule(sigma=s)) 39 | 40 | proj_wgs84 = Proj('epsg:4326') 41 | proj_mercator = Proj('epsg:3857') 42 | self.transformer = Transformer.from_proj(proj_wgs84, proj_mercator, always_xy=True) 43 | 44 | def forward(self, input): 45 | lat = input[:, 0].float().detach().cpu().numpy() 46 | lon = input[:, 1].float().detach().cpu().numpy() 47 | projected_lon_lat = self.transformer.transform(lon, lat) 48 | location = [] 49 | for coord in zip(*projected_lon_lat): 50 | location.append([coord[1],coord[0]]) 51 | location = torch.Tensor(location).to('cuda') 52 | location = location / 20037508.3427892 53 | 54 | location_features = torch.zeros(location.shape[0], 512).to('cuda') 55 | 56 | for i in range(self.n): 57 | location_features += self._modules['LocEnc' + str(i)](location) 58 | 59 | return location_features 60 | 61 | 62 | class G3(torch.nn.Module): 63 | def __init__(self, device): 64 | super(G3, self).__init__() 65 | self.device = device 66 | 67 | clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") 68 | self.vision_model = clip_model.vision_model 69 | self.text_model = clip_model.text_model 70 | self.vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 71 | self.text_processor = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 72 | self.vision_projection = clip_model.visual_projection 73 | self.text_projection = clip_model.text_projection 74 | 75 | self.logit_scale1 = nn.Parameter(torch.tensor(3.99)) 76 | self.logit_scale2 = nn.Parameter(torch.tensor(3.99)) 77 | self.logit_scale3 = nn.Parameter(torch.tensor(3.99)) 78 | 79 | self.location_encoder = CustomLocationEncoder() # output batch_size, 3, 512 80 | #self.location_encoder = LocationEncoder(sigma=[2**0, 2**4, 2**8]) 81 | self.vision_projection_else_1 = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768)) 82 | self.text_projection_else = nn.Sequential(nn.Linear(768,768), nn.ReLU(), nn.Linear(768, 768)) 83 | 84 | self.vision_projection_else_2 = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768)) 85 | self.location_projection_else = nn.Sequential(nn.Linear(512,512), nn.ReLU(), nn.Linear(512, 768)) 86 | 87 | # freeze CLIP 88 | self.vision_model.requires_grad_(False) 89 | self.vision_projection.requires_grad_(False) 90 | self.text_model.requires_grad_(False) 91 | self.text_projection.requires_grad_(False) 92 | 93 | def forward(self, images, texts, longitude, latitude, return_loss=True): 94 | 95 | vision_output = self.vision_model(images)[1] 96 | text_output = self.text_model(**texts)[1] 97 | image_embeds = self.vision_projection(vision_output) 98 | text_embeds = self.text_projection(text_output) # batch_size, 512 99 | this_batch_locations = torch.stack((latitude, longitude), dim=1) 100 | location_embeds = self.location_encoder(this_batch_locations) 101 | 102 | # phase _1 103 | image_embeds_1 = self.vision_projection_else_1(image_embeds) 104 | text_embeds_1 = 
self.text_projection_else(text_embeds.reshape(text_embeds.shape[0], -1)) 105 | 106 | # normalized features 107 | image_embeds_1 = image_embeds_1 / image_embeds_1.norm(p=2, dim=-1, keepdim=True) 108 | text_embeds_1 = text_embeds_1 / text_embeds_1.norm(p=2, dim=-1, keepdim=True) 109 | 110 | # image with texts 111 | logit_scale = self.logit_scale1.exp() 112 | logits_per_texts_with_images = torch.matmul(text_embeds_1, image_embeds_1.t()) * logit_scale 113 | logits_per_images_with_texts = logits_per_texts_with_images.t() 114 | if return_loss: loss1 = self.clip_loss(logits_per_texts_with_images) 115 | 116 | loss_phase_1 = None 117 | if return_loss: 118 | loss_phase_1 = loss1 119 | 120 | # phase _2 121 | image_embeds_2 = self.vision_projection_else_2(image_embeds) 122 | location_embeds_2 = self.location_projection_else(location_embeds.reshape(location_embeds.shape[0], -1)) 123 | 124 | # normalized features 125 | image_embeds_2 = image_embeds_2 / image_embeds_2.norm(p=2, dim=-1, keepdim=True) 126 | location_embeds_2 = location_embeds_2 / location_embeds_2.norm(p=2, dim=-1, keepdim=True) 127 | 128 | # image with location 129 | logit_scale = self.logit_scale2.exp() 130 | logits_per_locations_with_images = torch.matmul(location_embeds_2, image_embeds_2.t()) * logit_scale 131 | logits_per_images_with_locations = logits_per_locations_with_images.t() 132 | loss_phase_2 = None 133 | if return_loss: loss_phase_2 = self.clip_loss(logits_per_locations_with_images) 134 | 135 | loss = loss_phase_1 + loss_phase_2 136 | 137 | return { 138 | 'logits_per_texts_with_images': logits_per_texts_with_images, 139 | 'logits_per_images_with_texts': logits_per_images_with_texts, 140 | 'logits_per_locations_with_images': logits_per_locations_with_images, 141 | 'logits_per_images_with_locations': logits_per_images_with_locations, 142 | 'logits_per_locations_with_texts': None, 143 | 'logits_per_texts_with_locations': None, 144 | 'loss': loss, 145 | 'vision_output': vision_output, 146 | 'text_output': text_output, 147 | 'image_embeds': image_embeds, 148 | 'text_embeds': text_embeds 149 | } 150 | 151 | 152 | def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor: 153 | return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) 154 | 155 | 156 | def clip_loss(self, similarity: torch.Tensor) -> torch.Tensor: 157 | caption_loss = self.contrastive_loss(similarity) 158 | image_loss = self.contrastive_loss(similarity.t()) 159 | return (caption_loss + image_loss) / 2.0 160 | -------------------------------------------------------------------------------- /utils/rff/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch import Tensor 5 | 6 | 7 | def sample_b(sigma: float, size: tuple) -> Tensor: 8 | r"""Matrix of size :attr:`size` sampled from from :math:`\mathcal{N}(0, \sigma^2)` 9 | 10 | Args: 11 | sigma (float): standard deviation 12 | size (tuple): size of the matrix sampled 13 | 14 | See :class:`~rff.layers.GaussianEncoding` for more details 15 | """ 16 | return torch.randn(size) * sigma 17 | 18 | 19 | @torch.jit.script 20 | def gaussian_encoding( 21 | v: Tensor, 22 | b: Tensor) -> Tensor: 23 | r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})` 24 | 25 | Args: 26 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 27 | b (Tensor): projection matrix of shape :math:`(\text{encoded_layer_size}, 
\text{input_size})` 28 | 29 | Returns: 30 | Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot \text{encoded_layer_size})` 31 | 32 | See :class:`~rff.layers.GaussianEncoding` for more details. 33 | """ 34 | vp = 2 * np.pi * v @ b.T 35 | return torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1) 36 | 37 | 38 | @torch.jit.script 39 | def basic_encoding( 40 | v: Tensor) -> Tensor: 41 | r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})` 42 | 43 | Args: 44 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 45 | 46 | Returns: 47 | Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot \text{input_size})` 48 | 49 | See :class:`~rff.layers.BasicEncoding` for more details. 50 | """ 51 | vp = 2 * np.pi * v 52 | return torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1) 53 | 54 | 55 | @torch.jit.script 56 | def positional_encoding( 57 | v: Tensor, 58 | sigma: float, 59 | m: int) -> Tensor: 60 | r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)` 61 | where :math:`j \in \{0, \dots, m-1\}` 62 | 63 | Args: 64 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 65 | sigma (float): constant chosen based upon the domain of :attr:`v` 66 | m (int): [description] 67 | 68 | Returns: 69 | Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})` 70 | 71 | See :class:`~rff.layers.PositionalEncoding` for more details. 72 | """ 73 | j = torch.arange(m, device=v.device) 74 | coeffs = 2 * np.pi * sigma ** (j / m) 75 | vp = coeffs * torch.unsqueeze(v, -1) 76 | vp_cat = torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1) 77 | return vp_cat.flatten(-2, -1) 78 | -------------------------------------------------------------------------------- /utils/rff/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from typing import Optional 4 | from torch import Tensor 5 | from . import functional 6 | 7 | class GaussianEncoding(nn.Module): 8 | """Layer for mapping coordinates using random Fourier features""" 9 | 10 | def __init__(self, sigma: Optional[float] = None, 11 | input_size: Optional[float] = None, 12 | encoded_size: Optional[float] = None, 13 | b: Optional[Tensor] = None): 14 | r""" 15 | Args: 16 | sigma (Optional[float]): standard deviation 17 | input_size (Optional[float]): the number of input dimensions 18 | encoded_size (Optional[float]): the number of dimensions the `b` matrix maps to 19 | b (Optional[Tensor], optional): Optionally specify a :attr:`b` matrix already sampled 20 | Raises: 21 | ValueError: 22 | If :attr:`b` is provided and one of :attr:`sigma`, :attr:`input_size`, 23 | or :attr:`encoded_size` is provided. If :attr:`b` is not provided and one of 24 | :attr:`sigma`, :attr:`input_size`, or :attr:`encoded_size` is not provided. 
25 | """ 26 | super().__init__() 27 | if b is None: 28 | if sigma is None or input_size is None or encoded_size is None: 29 | raise ValueError( 30 | 'Arguments "sigma," "input_size," and "encoded_size" are required.') 31 | 32 | b = functional.sample_b(sigma, (encoded_size, input_size)) 33 | elif sigma is not None or input_size is not None or encoded_size is not None: 34 | raise ValueError('Only specify the "b" argument when using it.') 35 | self.b = nn.parameter.Parameter(b, requires_grad=False) 36 | 37 | def forward(self, v: Tensor) -> Tensor: 38 | r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})` 39 | 40 | Args: 41 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 42 | 43 | Returns: 44 | Tensor: Tensor mapping using random fourier features of shape :math:`(N, *, 2 \cdot \text{encoded_size})` 45 | """ 46 | return functional.gaussian_encoding(v, self.b) 47 | 48 | 49 | class BasicEncoding(nn.Module): 50 | """Layer for mapping coordinates using the basic encoding""" 51 | 52 | def forward(self, v: Tensor) -> Tensor: 53 | r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})` 54 | 55 | Args: 56 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 57 | 58 | Returns: 59 | Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot \text{input_size})` 60 | """ 61 | return functional.basic_encoding(v) 62 | 63 | 64 | class PositionalEncoding(nn.Module): 65 | """Layer for mapping coordinates using the positional encoding""" 66 | 67 | def __init__(self, sigma: float, m: int): 68 | r""" 69 | Args: 70 | sigma (float): frequency constant 71 | m (int): number of frequencies to map to 72 | """ 73 | super().__init__() 74 | self.sigma = sigma 75 | self.m = m 76 | 77 | def forward(self, v: Tensor) -> Tensor: 78 | r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)` 79 | 80 | Args: 81 | v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})` 82 | 83 | Returns: 84 | Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})` 85 | """ 86 | return functional.positional_encoding(v, self.sigma, self.m) 87 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import numpy as np 5 | import tarfile 6 | import pickle 7 | from tqdm import tqdm 8 | from transformers import CLIPVisionModel, CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPModel 9 | from torchvision.datasets import VisionDataset 10 | from typing import Callable, Optional 11 | from torchvision.io import ImageReadMode, read_image 12 | from pathlib import Path 13 | import pandas as pd 14 | from torch.utils.data import DataLoader 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | from io import BytesIO 18 | from PIL import ImageFile 19 | from torch.utils.data import get_worker_info 20 | ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow truncated images to be loaded 21 | 22 | class MP16Dataset(VisionDataset): 23 | 24 | def __init__(self, root_path='./data/', text_data_path='MP16_Pro_places365.csv', image_data_path='mp-16-images.tar', member_info_path='tar_index.pkl', vision_processor= None, text_processor=None): 25 | super().__init__(self) 26 | self.root_path = root_path 27 | self.text_data_path = text_data_path 
28 |         self.image_data_path = image_data_path
29 |         self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
30 |         self.text_data['IMG_ID'] = self.text_data['IMG_ID'].apply(lambda x: x.replace('/', '_'))
31 |         # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
32 |         print('read text data success')
33 |         worker = get_worker_info()
34 |         worker = worker.id if worker else None
35 |         self.tar_obj = {worker: tarfile.open(os.path.join(root_path, image_data_path))}
36 |         # self.tar = tarfile.open(os.path.join(root_path, image_data_path))
37 | 
38 |         if os.path.exists(os.path.join(self.root_path, member_info_path)):
39 |             with open(os.path.join(self.root_path, member_info_path), 'rb') as f:
40 |                 self.tar_index = pickle.load(f)
41 |             all_image_names = list(self.tar_index.keys())
42 |             print('load tar index success')
43 |         else:
44 |             print('tar index not found, building...')
45 |             self.tar_index = {}
46 |             all_image_names = []
47 |             for member in tqdm(self.tar_obj[worker]):
48 |                 if member.name.endswith('.jpg') and member.size > 5120:
49 |                     self.tar_index[member.name.split('/')[2]] = member
50 |                     all_image_names.append(member.name.split('/')[2])
51 |             print('tar index building success')
52 |             with open(os.path.join(self.root_path, member_info_path), 'wb') as f:
53 |                 pickle.dump(self.tar_index, f)
54 |         all_image_names = set(all_image_names)
55 | 
56 |         self.text_data = self.text_data[self.text_data['country'].notnull()]
57 |         self.text_data = self.text_data[self.text_data['IMG_ID'].isin(all_image_names)]
58 |         print('data rows: ', self.text_data.shape[0])
59 | 
60 |         # location from str to float
61 |         self.text_data.loc[:,'LON'] = self.text_data['LON'].astype(float)
62 |         self.text_data.loc[:,'LAT'] = self.text_data['LAT'].astype(float)
63 |         print('location from str to float success')
64 | 
65 |         # image transform
66 |         self.transform = T.Resize(size=(512, 512))
67 |         self.transform_totensor = T.ToTensor()
68 | 
69 |         self.vision_processor = vision_processor
70 |         self.text_processor = text_processor
71 | 
72 |         # Define the contrast transforms here
73 |         self.contrast_transforms = T.Compose([
74 |             T.RandomHorizontalFlip(),
75 |             T.RandomResizedCrop(size=224),
76 |             T.RandomApply([
77 |                 T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
78 |             ], p=0.8),
79 |             T.RandomGrayscale(p=0.2),
80 |             T.GaussianBlur(kernel_size=9),
81 |             T.ToTensor()
82 |             # T.Normalize((0.5,), (0.5,))
83 |         ])
84 | 
85 |         # self.text_data.to_csv('/data/mp-16/MP16_Pro_filtered.csv', index=False)
86 | 
87 |     def caption_generation(self, row):
88 |         pass
89 | 
90 |     def __getitem__(self, index):
91 |         image_path = self.text_data.iloc[index]['IMG_ID']
92 |         text = ''
93 |         neighbourhood, city, county, state, region, country, continent = self.text_data.iloc[index][['neighbourhood', 'city', 'county', 'state', 'region', 'country', 'continent']]
94 |         # location_elements = [element for element in [neighbourhood, city, state, country] if element is not np.nan and str(element) != 'nan']
95 |         location_elements = [element for element in [city, state, country] if element is not np.nan and str(element) != 'nan']
96 |         text = 'A street view photo taken in '+', '.join(location_elements)
97 | 
98 |         longitude = self.text_data.iloc[index]['LON']
99 |         latitude = self.text_data.iloc[index]['LAT']
100 |         # read the image from self.tar
101 |         worker = get_worker_info()
102 |         worker = worker.id if worker else None
103 |         if worker not in self.tar_obj:
104 |             self.tar_obj[worker] = 
tarfile.open(os.path.join(self.root_path, self.image_data_path)) 105 | image = self.tar_obj[worker].extractfile(self.tar_index[image_path]) 106 | image = Image.open(image) 107 | 108 | if image.mode != 'RGB': 109 | image = image.convert('RGB') 110 | 111 | if self.vision_processor: 112 | image = self.vision_processor(images=image, return_tensors='pt')['pixel_values'].reshape(3,224,224) 113 | 114 | return image, text, longitude, latitude 115 | 116 | def __len__(self): 117 | return len(self.text_data) 118 | 119 | class im2gps3kDataset(VisionDataset): 120 | def __init__(self, root_path='./data/im2gps3k', text_data_path='im2gps3k_places365.csv', image_data_path='images/', vision_processor= None, text_processor=None): 121 | super().__init__(self) 122 | print('start loading im2gps...') 123 | self.root_path = root_path 124 | self.text_data_path = text_data_path 125 | self.image_data_path = image_data_path 126 | self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path)) 127 | # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images 128 | print('read text data success') 129 | 130 | # location from str to float 131 | self.text_data.loc[:,'LAT'] = self.text_data['LAT'].astype(float) 132 | self.text_data.loc[:,'LON'] = self.text_data['LON'].astype(float) 133 | print('location from str to float success') 134 | 135 | self.vision_processor = vision_processor 136 | self.text_processor = text_processor 137 | 138 | self.tencrop = T.TenCrop(224) 139 | 140 | def __getitem__(self, index): 141 | image_path = self.text_data.iloc[index]['IMG_ID'] 142 | text = image_path 143 | 144 | longitude = self.text_data.iloc[index]['LON'] 145 | latitude = self.text_data.iloc[index]['LAT'] 146 | 147 | image = Image.open(os.path.join(self.root_path, self.image_data_path, image_path)) 148 | 149 | if image.mode != 'RGB': 150 | image = image.convert('RGB') 151 | 152 | # image = self.tencrop(image) # for tencrop 153 | 154 | if self.vision_processor: 155 | image = self.vision_processor(images=image, return_tensors='pt')['pixel_values'].reshape(-1,224,224) 156 | 157 | return image, text, longitude, latitude 158 | 159 | def __len__(self): 160 | return len(self.text_data) 161 | 162 | 163 | class yfcc4kDataset(VisionDataset): 164 | def __init__(self, root_path='./data/yfcc4k', text_data_path='yfcc4k_places365.csv', image_data_path='images/', vision_processor= None, text_processor=None): 165 | super().__init__(self) 166 | print('start loading yfcc4k...') 167 | self.root_path = root_path 168 | self.text_data_path = text_data_path 169 | self.image_data_path = image_data_path 170 | self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path)) 171 | # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images 172 | print('read text data success') 173 | 174 | # location from str to float 175 | self.text_data.loc[:,'LAT'] = self.text_data['LAT'].astype(float) 176 | self.text_data.loc[:,'LON'] = self.text_data['LON'].astype(float) 177 | print('location from str to float success') 178 | 179 | self.vision_processor = vision_processor 180 | self.text_processor = text_processor 181 | 182 | def __getitem__(self, index): 183 | image_path = self.text_data.iloc[index]['IMG_ID'] 184 | text = image_path 185 | 186 | longitude = self.text_data.iloc[index]['LON'] 187 | latitude = self.text_data.iloc[index]['LAT'] 188 | 189 | image = Image.open(os.path.join(self.root_path, self.image_data_path, image_path)) 190 | 191 | if image.mode 
!= 'RGB': 192 | image = image.convert('RGB') 193 | 194 | if self.vision_processor: 195 | image = self.vision_processor(images=image, return_tensors='pt')['pixel_values'].reshape(-1,224,224) 196 | 197 | return image, text, longitude, latitude 198 | 199 | def __len__(self): 200 | return len(self.text_data) 201 | --------------------------------------------------------------------------------
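A minimal sketch of how the encoders in utils/ compose, assuming the repository root is on PYTHONPATH and a CUDA device is available (CustomLocationEncoder moves its intermediate tensors to 'cuda' internally); the coordinates below are illustrative only:

import torch
from utils.rff.layers import GaussianEncoding
from utils.G3 import CustomLocationEncoder

# Random Fourier features: a 2-D input becomes 2 * encoded_size features
rff = GaussianEncoding(sigma=2.0, input_size=2, encoded_size=256)
print(rff(torch.rand(4, 2)).shape)            # torch.Size([4, 512])

# Multi-scale location encoder: (latitude, longitude) in degrees -> 512-D embedding
encoder = CustomLocationEncoder().to('cuda')
coords = torch.tensor([[48.8584, 2.2945],     # example coordinates: Paris
                       [40.6892, -74.0445]])  # example coordinates: New York
print(encoder(coords).shape)                  # torch.Size([2, 512])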