├── .gitignore ├── LICENSE ├── README.md ├── dataset_builder_multi_label.py ├── find_best_model.py ├── geoguessr_dataset.py ├── get_images.py ├── main.py ├── notebook └── GeoGuessr_AI_Demo.ipynb ├── save_production_model.py └── utils └── tensor_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Machine Learning 132 | /models 133 | /images 134 | /geoguessr_dataset 135 | /runs 136 | /tensorboard -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeoGuessr AI 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Stelath/geoguessr-ai/blob/main/notebook/GeoGuessr_AI_Demo.ipynb) ![License](https://img.shields.io/github/license/Stelath/geoguessr-ai) 4 | 5 | This was a personal project, an opportunity to learn more about CNNs and MLPs by creating an ML model that could reliably guess a random location in one of five US cities given only a Google Street View image. The idea was inspired by the game GeoGuessr, where the player is dropped at a random Google Street View location and has to guess where in the world they are based on the Street View imagery. 6 | 7 | ## Update 8 | Stanford grad students have recently released a model called PIGEON which significantly outperforms ours. I suggest you check it out [here](https://huggingface.co/geolocal/StreetCLIP). 9 | 10 | ## Project Overview 11 | 12 | ### Creating a Dataset 13 | 14 | To train the model I first had to create a reasonably large dataset of Google Street View images. To do this I wrote a Python script ([get_images.py](https://github.com/Stelath/geoguessr-ai/blob/main/get_images.py "get_images.py")) to download a large set of photographs of 5 US cities from the Google Street View API. 
The Google Street View API needs latitude and longitude coordinates to serve an image, so I used an address book from [Open Addresses](https://openaddresses.io/) to get the coordinates of random street addresses in each of the 5 cities. 15 | 16 | ### Formatting the Images 17 | 18 | Before training the model I reformatted the targets so the AI could guess the location more easily: the GPS coordinates become multi-label targets, with each digit of a coordinate encoded as a one-hot array over the ten digits 0 through 9 (plus a sign bit per coordinate). 19 | 20 | ### Training the Model 21 | 22 | After trying a couple of different model architectures I settled on a Wide ResNet with 50 layers, which gave results comparable to a Wide ResNet with 100 layers but took far less GPU memory. The model was trained on a dataset of 50,000 images; it performed best 20 epochs in and then slowly began to overfit. 23 | 24 | ### How Could it be Improved? 25 | 26 | A custom model better suited to guessing locations, rather than a general image classifier, would likely help. Adding far more layers would let the model pick up more complexity, although this would require a larger GPU. Arguably the biggest improvement, and something that would follow the GeoGuessr idea more closely, would be to train a 3D CNN on an array of images from each location, giving the model far more data to work with when making predictions. 27 | 28 | ### Takeaway 29 | 30 | While the model is by no means perfect, it is surprisingly accurate given the limited input it receives. Interestingly, this reveals that while many would regard American cities as similar, there are clearly significant differences in their landscapes, so much so that the AI was able to take advantage of them to correctly predict at least the city it was in the majority of the time. 
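The target format described under *Formatting the Images* can be sketched as follows. This is a simplified illustration, not the repository's `multi_label` function: the names `encode_coord` and `decode_coord` are hypothetical, and the sketch assumes four decimal places and three integer digits per coordinate, which gives 71 values per coordinate and 142 per image (matching the `num_classes=142` used when the model is built).

```python
import numpy as np


def encode_coord(value, decimals=4, int_digits=3):
    """Encode one coordinate as a sign bit followed by one one-hot vector per digit."""
    value = float(value)
    # Shift the decimals into the integer part, e.g. 38.8977 -> 388977
    scaled = round(abs(value) * 10 ** decimals)
    # Left-pad with zeros so every coordinate has the same number of digits
    digits = str(scaled).zfill(int_digits + decimals)
    sign = [1.0 if value < 0 else 0.0]
    one_hots = [np.eye(10)[int(d)] for d in digits]
    return np.concatenate([sign] + one_hots)


def decode_coord(label, decimals=4):
    """Invert encode_coord by taking the most likely digit in each group of ten."""
    sign = -1 if label[0] >= 0.5 else 1
    digits = label[1:].reshape(-1, 10).argmax(axis=1)
    number = int("".join(str(d) for d in digits))
    return sign * number / 10 ** decimals


# A full target is the latitude label followed by the longitude label
target = np.concatenate((encode_coord(38.8977), encode_coord(-77.0365)))
print(target.shape)
print(decode_coord(target[:71]), decode_coord(target[71:]))
```

Decoding with `argmax` also works on the model's raw outputs, since each group of ten values acts as a small classifier over one digit.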
31 | 32 | ## Instructions 33 | 34 | ### Run Pretrained Model 35 | 36 | To run a pretrained model you can download it from Google Drive: [![Open In Drive](https://img.shields.io/badge/Google%20Drive-5383ec?style=flat&logo=googledrive&logoColor=5383ec&label=%E2%80%8B)](https://drive.google.com/file/d/1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW/view?usp=sharing) or you can use the Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Stelath/geoguessr-ai/blob/main/notebook/GeoGuessr_AI_Demo.ipynb). 37 | 38 | ### Train a Model 39 | You can use the [get_images.py](https://github.com/Stelath/geoguessr-ai/blob/main/get_images.py "get_images.py") script to download a database of images through Google Cloud. You will need to set up a Google Cloud account, and keep in mind that Google allows 28,500 free Google Street View API calls each month (anything more and you will be charged $0.007 per image). You can also download the database of Google Street View images I created [here](https://www.kaggle.com/stelath/city-street-view-dataset). 40 | 41 | Once you have a database of images, running the [dataset_builder_multi_label.py](https://github.com/Stelath/geoguessr-ai/blob/main/dataset_builder_multi_label.py) script will preprocess all of the images; running [main.py](https://github.com/Stelath/geoguessr-ai/blob/main/main.py) will then begin training the model. 
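As a rough guide to the Street View quota mentioned above, the cost of a download can be estimated like this. The figures are the ones quoted in this README and may no longer match Google's current pricing; `estimate_cost` is a hypothetical helper, not part of the repository, and it assumes nothing else uses the quota that month.

```python
FREE_CALLS_PER_MONTH = 28_500   # free monthly Street View calls, as quoted in this README
PRICE_PER_EXTRA_IMAGE = 0.007   # USD per image beyond the free tier, as quoted in this README


def estimate_cost(image_count: int) -> float:
    """Return the estimated USD cost of downloading image_count images in one month."""
    billable = max(0, image_count - FREE_CALLS_PER_MONTH)
    return billable * PRICE_PER_EXTRA_IMAGE


for count in (25_000, 50_000):
    print(f"{count} images: ${estimate_cost(count):.2f}")
```

Under these assumptions a 25,000-image download stays inside the free tier, while 50,000 images in a single month would cost roughly $150.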
42 | 43 | Here is a set of commands that would be used to train a model on 25,000 images (keep in mind you will need a cities folder containing `.geojson` files from [Open Addresses](https://openaddresses.io/)): 44 | ``` 45 | python -m get_images --cities cities/ --output images/ --icount 25000 --key (YOUR GSV API KEY HERE) 46 | python -m dataset_builder_multi_label --file images/picture_coords.csv --images images/ --output geoguessr_dataset/ 47 | python -m main geoguessr_dataset/ -a wide_resnet50_2 -b 16 --lr 0.0001 -j 6 --checkpoint-step 1 48 | ``` 49 | -------------------------------------------------------------------------------- /dataset_builder_multi_label.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from PIL import Image 4 | from tqdm import tqdm 5 | from random import randint 6 | import numpy as np 7 | import os 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--file", help="The CSV file to read and extract GPS coordinates from", required=True, type=str) 12 | parser.add_argument("--images", help="The path to the images folder, (defaults to: images/)", default='images/', type=str) 13 | parser.add_argument("--output", help="The output folder", required=True, type=str) 14 | return parser.parse_args() 15 | 16 | args = get_args() 17 | 18 | targets_train = [] 19 | targets_val = [] 20 | 21 | def multi_label(num, decimals=4): 22 | num_original = str(abs(float(num))) 23 | 24 | if decimals > 0: 25 | num = str(round(float(num) * (10 ** decimals))) 26 | else: 27 | num = str(round(float(num))) 28 | 29 | if num[0] == '-': 30 | label = np.array([1]) 31 | num = num[1:] 32 | else: 33 | label = np.array([0]) 34 | 35 | num = num.ljust(len(num_original.split('.')[0]) + decimals, '0') 36 | num = num.zfill(decimals + 3) 37 | 38 | for digit in num: 39 | label = np.concatenate((label, np.eye(10)[int(digit)])) 40 | 41 | return label 42 | 43 | def get_data(coord, 
coord_index): 44 | lat, lon = coord[0], coord[1] 45 | 46 | img_path = os.path.join(args.images, f'street_view_{coord_index}.jpg') 47 | img = Image.open(img_path) 48 | 49 | lat_multi_label = multi_label(lat) 50 | lon_multi_label = multi_label(lon) 51 | 52 | target = np.concatenate((lat_multi_label, lon_multi_label)) 53 | 54 | return [img, target] 55 | 56 | def main(): 57 | with open(args.file, 'r') as f: 58 | coords_reader = csv.reader(f) 59 | coords = [] 60 | for row in coords_reader: 61 | coords.append(row) 62 | 63 | 64 | train_data_path = os.path.join(args.output, 'train') 65 | os.makedirs(train_data_path, exist_ok=True) 66 | val_data_path = os.path.join(args.output, 'val') 67 | os.makedirs(val_data_path, exist_ok=True) 68 | 69 | val_count = 0 70 | train_count = 0 71 | 72 | for coord_index, coord in enumerate(tqdm(coords)): 73 | if randint(0, 9) == 0: 74 | data = get_data(coord, coord_index) 75 | val_data_path = os.path.join(args.output, f'val/street_view_{val_count}.jpg') 76 | data[0].save(val_data_path) 77 | targets_val.append(data[1]) 78 | val_count += 1 79 | else: 80 | data = get_data(coord, coord_index) 81 | train_data_path = os.path.join(args.output, f'train/street_view_{train_count}.jpg') 82 | data[0].save(train_data_path) 83 | targets_train.append(data[1]) 84 | train_count += 1 85 | 86 | np.save(os.path.join(args.output, f'train/targets.npy'), np.array(targets_train)) 87 | np.save(os.path.join(args.output, f'val/targets.npy'), np.array(targets_val)) 88 | 89 | print('Train Files:', train_count) 90 | print('Val Files:', val_count) 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /find_best_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import numpy as np 5 | from datetime import datetime 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.data 10 | 
import torchvision.models as models 11 | from geoguessr_dataset import GeoGuessrDataset 12 | 13 | model_names = sorted(name for name in models.__dict__ 14 | if name.islower() and not name.startswith("__") 15 | and callable(models.__dict__[name])) 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Best Model Locator') 18 | parser.add_argument('data', metavar='DIR', 19 | help='path to dataset') 20 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 21 | choices=model_names, 22 | help='model architecture: ' + 23 | ' | '.join(model_names) + 24 | ' (default: resnet50)') 25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 26 | help='number of data loading workers (default: 4)') 27 | parser.add_argument('-b', '--batch-size', default=64, type=int, 28 | metavar='N', 29 | help='batch size (default: 64), this is the total ' 30 | 'batch size of the GPU') 31 | parser.add_argument('--models-dir', default='models', type=str) 32 | 33 | start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 34 | args = parser.parse_args() 35 | all_loss = [] 36 | 37 | def fwd_pass(model, data, targets, loss_function, optimizer, train=False): 38 | data = data.cuda() 39 | targets = targets.cuda() 40 | 41 | if train: 42 | model.zero_grad() 43 | 44 | outputs = model(data) 45 | matches = [(torch.where(i >= 0.5, 1, 0) == j).all() for i, j in zip(outputs, targets)] 46 | acc = matches.count(True) / len(matches) 47 | loss = loss_function(outputs, targets) 48 | 49 | if train: 50 | loss.backward() 51 | optimizer.step() 52 | 53 | return acc, loss 54 | 55 | def test(val_loader, model, loss_function): 56 | model.eval() 57 | acc = [] 58 | loss = [] 59 | 60 | for idx, sample in enumerate(tqdm(val_loader)): 61 | if idx < 256: 62 | data, target = sample 63 | batch_acc, batch_loss = fwd_pass(model, data, target, loss_function, None) 64 | acc.append(batch_acc) 65 | loss.append(batch_loss.cpu().detach().numpy()) 66 | 67 | acc = np.mean(acc) 68 
| loss = np.mean(loss) 69 | 70 | val_acc = np.mean(acc) 71 | val_loss = np.mean(loss) 72 | return val_acc, val_loss 73 | 74 | def main(): 75 | torch.device("cuda") 76 | valdir = os.path.join(args.data, 'val') 77 | val_dataset = GeoGuessrDataset(valdir) 78 | 79 | val_loader = torch.utils.data.DataLoader( 80 | val_dataset, batch_size=args.batch_size, shuffle=False, 81 | num_workers=args.workers, pin_memory=True) 82 | 83 | print("=> creating model '{}'".format(args.arch)) 84 | model = models.__dict__[args.arch](pretrained=False, progress=True, num_classes=142) 85 | model = nn.Sequential( 86 | model, 87 | nn.Sigmoid() 88 | ) 89 | model.cuda() 90 | 91 | loss_function = nn.BCELoss() 92 | 93 | for model_path in tqdm(os.listdir(args.models_dir)): 94 | model_path = os.path.join(args.models_dir, model_path) 95 | print("=> loading model '{}'".format(model_path)) 96 | checkpoint = torch.load(model_path) 97 | model.load_state_dict(checkpoint['model_state_dict']) 98 | print("=> loaded model '{}' (epoch {})".format(model_path, checkpoint['epoch'])) 99 | 100 | val_acc, val_loss = test(val_loader, model, loss_function) 101 | all_loss.append(val_loss) 102 | print("=> val_acc: {:.4f}, val_loss: {:.4f}".format(val_acc, val_loss)) 103 | 104 | min_value = min(all_loss) 105 | min_index = all_loss.index(min_value) 106 | print("=> best model: {}, loss: {}".format(os.listdir(args.models_dir)[min_index], min_value)) 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /geoguessr_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import os, os.path 6 | from PIL import Image 7 | 8 | class GeoGuessrDataset(Dataset): 9 | def __init__(self, data_dir): 10 | self.data_dir = data_dir 11 | self.targets = np.load(os.path.join(data_dir, 
'targets.npy'), allow_pickle=True) 12 | 13 | def __len__(self): 14 | return len(os.listdir(self.data_dir)) - 1 15 | 16 | def __getitem__(self, idx): 17 | data_path = os.path.join(self.data_dir, f'street_view_{idx}.jpg') 18 | 19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 20 | std=[0.229, 0.224, 0.225]) 21 | transform = transforms.Compose([ 22 | transforms.Resize(256), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | 27 | img = pil_loader(data_path) 28 | data = transform(img) 29 | 30 | target = torch.tensor(self.targets[idx], dtype=torch.float) 31 | 32 | return data, target 33 | 34 | def pil_loader(path: str) -> Image.Image: 35 | with open(path, "rb") as f: 36 | img = Image.open(f) 37 | return img.convert("RGB") -------------------------------------------------------------------------------- /get_images.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from tqdm import tqdm 3 | import os 4 | import json 5 | from random import randint 6 | import argparse 7 | from csv import writer 8 | # Consider using https://osmnx.readthedocs.io/en/stable/osmnx.html#osmnx.utils_geo.sample_points for street gps coords 9 | # Stack Overflow on how to get the coordinates: https://stackoverflow.com/questions/68367074/how-to-generate-random-lat-long-points-within-geographical-boundaries 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--cities", help="The folder full of addresses per city to read and extract GPS coordinates from", required=True, type=str) 14 | parser.add_argument("--output", help="The output folder where the images will be stored, (defaults to: images/)", default='images/', type=str) 15 | parser.add_argument("--icount", help="The amount of images to pull (defaults to 25,000)", default=25000, type=int) 16 | parser.add_argument("--key", help="Your Google Street View API Key", type=str, required=True) 17 | return parser.parse_args() 18 | 19 | args 
= get_args() 20 | url = 'https://maps.googleapis.com/maps/api/streetview' 21 | cities = [] 22 | 23 | def load_cities(): 24 | for city in os.listdir(args.cities): 25 | with open(os.path.join(args.cities, city)) as f: 26 | coordinates = [] 27 | print(f'Loading {city} addresses...') 28 | for line in tqdm(f): 29 | data = json.loads(line) 30 | coordinates.append(data['geometry']['coordinates']) 31 | cities.append(coordinates) 32 | 33 | def main(): 34 | # Open and create all the necessary files & folders 35 | os.makedirs(args.output, exist_ok=True) 36 | 37 | load_cities() 38 | 39 | coord_output_file = open(os.path.join(args.output, 'picture_coords.csv'), 'w', newline='') 40 | csv_writer = writer(coord_output_file) 41 | # Per-city tally of pulled images, initialised once before the loop so it is not reset on every iteration 42 | cities_count = [0] * len(cities) 43 | for i in tqdm(range(args.icount)): 44 | # Pick a random city, then a random address within it 45 | city_index = randint(0, len(cities) - 1) 46 | city = cities[city_index] 47 | cities_count[city_index] += 1 48 | addressLoc = city[randint(0, len(city) - 1)] 49 | city.remove(addressLoc) # Remove the address from the list so we don't get the same one twice 50 | # Set the parameters for the API call to Google Street View 51 | params = { 52 | 'key': args.key, 53 | 'size': '640x640', 54 | 'location': str(addressLoc[1]) + ',' + str(addressLoc[0]), 55 | 'heading': str((randint(0, 3) * 90) + randint(-15, 15)), 56 | 'pitch': '20', 57 | 'fov': '90' 58 | } 59 | 60 | response = requests.get(url, params) 61 | 62 | # Save the image to the output folder 63 | with open(os.path.join(args.output, f'street_view_{i}.jpg'), "wb") as file: 64 | file.write(response.content) 65 | 66 | # Save the coordinates to the output file 67 | csv_writer.writerow([addressLoc[1], addressLoc[0]]) 68 | 69 | coord_output_file.close() 70 | 71 | for i in range(len(cities_count)): 72 | city_count = cities_count[i] 73 | city_name = os.listdir(args.cities)[i] 74 | print(f'{city_count} images pulled from {city_name}') 75 | 76 | if __name__ == '__main__': 77 | main() 78 | 
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | import numpy as np 6 | from datetime import datetime 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data 11 | import torchvision.models as models 12 | from torch.utils.tensorboard import SummaryWriter 13 | from utils.tensor_utils import round_tensor 14 | from geoguessr_dataset import GeoGuessrDataset 15 | 16 | model_names = sorted(name for name in models.__dict__ 17 | if name.islower() and not name.startswith("__") 18 | and callable(models.__dict__[name])) 19 | 20 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Training') 21 | parser.add_argument('data', metavar='DIR', 22 | help='path to dataset') 23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 24 | choices=model_names, 25 | help='model architecture: ' + 26 | ' | '.join(model_names) + 27 | ' (default: resnet50)') 28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 29 | help='number of data loading workers (default: 4)') 30 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 31 | help='number of total epochs to run') 32 | parser.add_argument('--checkpoint-step', default=1, type=int, metavar='N', 33 | help='how often (in epochs) to save the model (default: 1)') 34 | parser.add_argument('-b', '--batch-size', default=64, type=int, 35 | metavar='N', 36 | help='batch size (default: 64), this is the total ' 37 | 'batch size of the GPU') 38 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 39 | metavar='LR', help='learning rate for optimizer', dest='lr') 40 | parser.add_argument('-p', '--print-freq', default=10, type=int, 41 | metavar='N', help='print frequency (default: 10)') 42 | parser.add_argument('--resume', 
default='', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 45 | help='evaluate model on validation set') 46 | 47 | start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 48 | args = parser.parse_args() 49 | 50 | def fwd_pass(model, data, targets, loss_function, optimizer, train=False): 51 | data = data.cuda() 52 | targets = targets.cuda() 53 | 54 | if train: 55 | model.zero_grad() 56 | 57 | outputs = model(data) 58 | matches = [(torch.where(i >= 0.5, 1, 0) == j).all() for i, j in zip(outputs, targets)] 59 | acc = matches.count(True) / len(matches) 60 | loss = loss_function(outputs.float(), targets) 61 | 62 | if train: 63 | loss.backward() 64 | optimizer.step() 65 | 66 | return acc, loss 67 | 68 | def test(val_loader, model, loss_function, optimizer): 69 | random = np.random.randint(len(val_loader)) 70 | 71 | model.eval() 72 | acc = [] 73 | loss = [] 74 | 75 | for idx, sample in enumerate(val_loader): 76 | if idx >= random and idx < random + 4: 77 | data, targets = sample 78 | with torch.no_grad(): 79 | val_acc, val_loss = fwd_pass(model, data, targets, loss_function, optimizer) 80 | acc.append(val_acc) 81 | loss.append(val_loss.cpu().numpy()) 82 | 83 | val_acc = np.mean(acc) 84 | val_loss = np.mean(loss) 85 | return val_acc, val_loss 86 | 87 | def train(train_loader, val_loader, model, loss_function, optimizer, epochs, start_epoch=0): 88 | with open(f'models/{start_time}/model.log', 'a') as f: 89 | for epoch in range(start_epoch, epochs): 90 | model.train() 91 | 92 | train_acc = [] 93 | train_loss = [] 94 | 95 | for idx, sample in enumerate(tqdm(train_loader)): 96 | data, target = sample 97 | acc, loss = fwd_pass(model, data, target, loss_function, optimizer, train=True) 98 | train_acc.append(acc) 99 | train_loss.append(loss.cpu().detach().numpy()) 100 | 101 | acc = np.mean(train_acc) 102 | loss = np.mean(train_loss) 103 | 104 | val_acc, val_loss 
= test(val_loader, model, loss_function, optimizer) 105 | 106 | 107 | # Add accuracy and loss to tensorboard 108 | writer.add_scalar('Loss/train', loss, epoch) 109 | writer.add_scalar('Accuracy/train', acc, epoch) 110 | writer.add_scalar('Loss/test', val_loss, epoch) 111 | writer.add_scalar('Accuracy/test', val_acc, epoch) 112 | 113 | # Log Accuracy and Loss 114 | log = f'model-{epoch}, Accuracy: {round(float(acc), 2)}, Loss: {round(float(loss), 4)}, Val Accuracy: {round(float(val_acc), 2)}, Val Loss: {round(float(val_loss), 4)}\n' 115 | print(log, end='') 116 | f.write(log) 117 | 118 | if epoch % args.checkpoint_step == 0: 119 | print('Saving model...') 120 | torch.save({ 121 | 'epoch': epoch, 122 | 'model_state_dict': model.state_dict(), 123 | 'optimizer_state_dict': optimizer.state_dict(), 124 | 'loss': loss 125 | }, f'models/{start_time}/model-{epoch}.pth') 126 | 127 | def main(): 128 | global writer 129 | writer = SummaryWriter(f'tensorboard/{start_time}') 130 | 131 | os.makedirs(f'models/{start_time}', exist_ok=True) 132 | 133 | traindir = os.path.join(args.data, 'train') 134 | valdir = os.path.join(args.data, 'val') 135 | train_dataset = GeoGuessrDataset(traindir) 136 | val_dataset = GeoGuessrDataset(valdir) 137 | 138 | train_loader = torch.utils.data.DataLoader( 139 | train_dataset, batch_size=args.batch_size, shuffle=True, 140 | num_workers=args.workers, pin_memory=True) 141 | 142 | val_loader = torch.utils.data.DataLoader( 143 | val_dataset, batch_size=args.batch_size, shuffle=False, 144 | num_workers=args.workers, pin_memory=True) 145 | 146 | print("=> creating model '{}'".format(args.arch)) 147 | model = models.__dict__[args.arch](pretrained=False, progress=True, num_classes=142) 148 | model = nn.Sequential( 149 | model, 150 | nn.Softmax(dim=0) # NOTE: dim=0 normalises over the batch dimension; find_best_model.py and the demo notebook wrap the model in nn.Sigmoid() with nn.BCELoss() instead 151 | ) 152 | 153 | loss_function = nn.CrossEntropyLoss() 154 | 155 | if torch.cuda.is_available(): 156 | print('Using GPU') 157 | torch.device("cuda") 158 | model = 
model.cuda() 159 | loss_function = loss_function.cuda() 160 | else: 161 | print('Using CPU') 162 | torch.device("cpu") 163 | 164 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=1e-4) 165 | 166 | start_epoch = 0 167 | 168 | if not args.resume == '': 169 | checkpoint = torch.load(args.resume) 170 | model.load_state_dict(checkpoint['model_state_dict']) 171 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 172 | start_epoch = checkpoint['epoch'] + 1 173 | print(f'Resuming from epoch {start_epoch}') 174 | 175 | EPOCHS = args.epochs 176 | train(train_loader=train_loader, val_loader=val_loader, model=model, loss_function=loss_function, optimizer=optimizer, epochs=EPOCHS, start_epoch=start_epoch) 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /notebook/GeoGuessr_AI_Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "GeoGuessr-AI Demo", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyP5rZRIlS0xIT6+iHUu0CqK", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "# GeoGuessr-AI Demo" 34 | ], 35 | "metadata": { 36 | "id": "XYROTdQlbE6x" 37 | } 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "source": [ 42 | "Disclaimer, this model was only trained on pictures from 5 cities (Washington, City of New York, Chicago, Detroit, and San Francisco) and will therefore only be able to predict the location of photos taken in those cities." 
43 | ], 44 | "metadata": { 45 | "id": "mPqEzY1MIfwQ" 46 | } 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "## Setup" 52 | ], 53 | "metadata": { 54 | "id": "hToy8jcKbOjx" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 1, 60 | "metadata": { 61 | "colab": { 62 | "base_uri": "https://localhost:8080/" 63 | }, 64 | "id": "d17FwqlWXsTd", 65 | "outputId": "9dec4987-f9a7-438e-e3cf-0c1425203300" 66 | }, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "name": "stdout", 71 | "text": [ 72 | "Downloading...\n", 73 | "From: https://drive.google.com/uc?id=1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW&confirm=t\n", 74 | "To: /content/geoguessr_production_model.pt\n", 75 | "100% 269M/269M [00:01<00:00, 176MB/s]\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "!gdown 'https://drive.google.com/uc?id=1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW&confirm=t'" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "source": [ 86 | "import torch\n", 87 | "import torchvision.models as models\n", 88 | "import torchvision.transforms as transforms\n", 89 | "import torch.nn as nn\n", 90 | "\n", 91 | "import numpy as np\n", 92 | "from PIL import Image\n", 93 | "import os\n", 94 | "\n", 95 | "from google.colab import files\n", 96 | "\n", 97 | "model = models.wide_resnet50_2(pretrained=False, progress=True, num_classes=142)\n", 98 | "model = nn.Sequential(\n", 99 | " model,\n", 100 | " nn.Sigmoid()\n", 101 | ")\n", 102 | "model_file = torch.load('geoguessr_production_model.pt', map_location=torch.device('cpu'))\n", 103 | "model.load_state_dict(model_file)\n", 104 | "model.eval()\n", 105 | "print('Loaded Model')\n", 106 | "\n", 107 | "def reformat(arr, guess_num=1):\n", 108 | " num = ''\n", 109 | " if arr[0] >= 0.5:\n", 110 | " num += '-'\n", 111 | " \n", 112 | " arr = arr[1:]\n", 113 | "\n", 114 | " for idx in range(0, len(arr), 10):\n", 115 | " if idx == 30:\n", 116 | " num += '.'\n", 117 | " num += str(np.where(arr[idx:idx+10] == 
np.partition(arr[idx:idx+10].flatten(), -guess_num)[-guess_num])[0][0] % 10)\n", 118 | "\n", 119 | " return num\n", 120 | "\n", 121 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 122 | " std=[0.229, 0.224, 0.225])\n", 123 | "transform = transforms.Compose([\n", 124 | " transforms.Resize((256, 256)),\n", 125 | " transforms.ToTensor(),\n", 126 | " normalize,\n", 127 | " ])" 128 | ], 129 | "metadata": { 130 | "colab": { 131 | "base_uri": "https://localhost:8080/" 132 | }, 133 | "id": "AiJcPzD0bmT2", 134 | "outputId": "1787a5a5-67c4-4381-dbfe-09f9e7aa0c28" 135 | }, 136 | "execution_count": 2, 137 | "outputs": [ 138 | { 139 | "output_type": "stream", 140 | "name": "stdout", 141 | "text": [ 142 | "Loaded Model\n" 143 | ] 144 | } 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "source": [ 150 | "## Run Model" 151 | ], 152 | "metadata": { 153 | "id": "vbWYVDdocvN0" 154 | } 155 | }, 156 | { 157 | "cell_type": "code", 158 | "source": [ 159 | "#@title ## Load File\n", 160 | "\n", 161 | "reference_file = files.upload()\n", 162 | "reference_file = list(reference_file.keys())[0]\n", 163 | "\n", 164 | "img_path = os.path.join(reference_file)\n", 165 | "img = Image.open(img_path)\n", 166 | "data = transform(img)\n", 167 | "\n", 168 | "from matplotlib import pyplot as plt\n", 169 | "img = data.permute(1, 2, 0)\n", 170 | "plt.imshow(img, interpolation='nearest')\n", 171 | "plt.show()" 172 | ], 173 | "metadata": { 174 | "colab": { 175 | "resources": { 176 | "http://localhost:8080/nbextensions/google.colab/files.js": { 177 | "data": 
"Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24
gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmV
zb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgZG8gewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwo
gICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwoKICAgICAgbGV0IHBlcmNlbnREb25lID0gZmlsZURhdGEuYnl0ZUxlbmd0aCA9PT0gMCA/CiAgICAgICAgICAxMDAgOgogICAgICAgICAgTWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCk7CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPSBgJHtwZXJjZW50RG9uZX0lIGRvbmVgOwoKICAgIH0gd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCk7CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", 178 | "ok": true, 179 | "headers": [ 180 | [ 181 | "content-type", 182 | "application/javascript" 183 | ] 184 | ], 185 | "status": 200, 186 | "status_text": "" 187 | } 188 | }, 189 | "base_uri": "https://localhost:8080/", 190 | "height": 342 191 | }, 192 | "cellView": "form", 193 | "id": "hKdi894nbyKp", 194 | "outputId": "662520b0-9ae2-4406-873a-49adfb3fe830" 195 | }, 196 | "execution_count": null, 197 | "outputs": [ 198 | { 199 | "output_type": "display_data", 200 | "data": { 201 | "text/html": [ 202 | "\n", 203 | " \n", 205 | " \n", 206 | " Upload widget is only available when the cell has been executed in the\n", 207 | " current browser session. 
Please rerun this cell to enable.\n", 208 | " \n", 209 | " " 210 | ], 211 | "text/plain": [ 212 | "" 213 | ] 214 | }, 215 | "metadata": {} 216 | }, 217 | { 218 | "output_type": "stream", 219 | "name": "stdout", 220 | "text": [ 221 | "Saving sample_image.jpeg to sample_image.jpeg\n" 222 | ] 223 | }, 224 | { 225 | "output_type": "stream", 226 | "name": "stderr", 227 | "text": [ 228 | "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" 229 | ] 230 | }, 231 | { 232 | "output_type": "display_data", 233 | "data": { 234 | "image/png": "\n", 235 | "text/plain": [ 236 | "
" 237 | ] 238 | }, 239 | "metadata": { 240 | "needs_background": "light" 241 | } 242 | } 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "source": [ 248 | "#@title ## Make Prediction\n", 249 | "#@markdown Run the model to get a latitude and longitude prediction for the image. The model predicts the most likely location and also makes a backup guess.\n", 250 | "\n", 251 | "with torch.no_grad():\n", 252 | " data = data.view(-1,3,256,256)\n", 253 | " output = model(data)\n", 254 | "\n", 255 | "target_split_len = int(len(output.cpu().numpy()[0])/2)\n", 256 | "output_reformatted = reformat(output.cpu().numpy()[0][:target_split_len]) + ' ' + reformat(output.cpu().numpy()[0][target_split_len:])\n", 257 | "output_reformatted2 = reformat(output.cpu().numpy()[0][:target_split_len]) + ' ' + reformat(output.cpu().numpy()[0][target_split_len:], 2)\n", 258 | "print('First Guess:', output_reformatted)\n", 259 | "print('Second Guess:', output_reformatted2)" 260 | ], 261 | "metadata": { 262 | "colab": { 263 | "base_uri": "https://localhost:8080/" 264 | }, 265 | "id": "XWavPNf1c6Fy", 266 | "outputId": "be208b02-a41e-4399-fbf0-6bc9eb377624", 267 | "cellView": "form" 268 | }, 269 | "execution_count": null, 270 | "outputs": [ 271 | { 272 | "output_type": "stream", 273 | "name": "stdout", 274 | "text": [ 275 | "First Guess: 037.7838 -122.4105\n", 276 | "Second Guess: 037.7838 -073.3263\n" 277 | ] 278 | } 279 | ] 280 | } 281 | ] 282 | } -------------------------------------------------------------------------------- /save_production_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.utils.data 5 | 6 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Best Model Locator') 7 | parser.add_argument('modelpath', metavar='DIR', 8 | help='path to model') 9 | 10 | args = parser.parse_args() 11 | 12 | def main(): 13 | checkpoint = 
torch.load(args.modelpath) 14 | torch.save(checkpoint['model_state_dict'], 'geoguessr_production_model.pt') 15 | 16 | if __name__ == '__main__': 17 | main() -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def round_tensor(tensor, decimals=4): 4 | return torch.round(tensor * 10 ** decimals) / (10 ** decimals) --------------------------------------------------------------------------------
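Note on the output layout: the demo notebook's `reformat` helper reads each half of the 142-value model output as 71 values, one sign slot followed by seven groups of ten scores, one group per digit, with the decimal point placed after the third digit. The sketch below illustrates that layout in plain Python. `decode_coordinate` mirrors `reformat` for the top guess only, while `encode_coordinate` is a hypothetical inverse written for this illustration; the actual label builder in `dataset_builder_multi_label.py` is not shown here and may differ.

```python
def encode_coordinate(value):
    """Hypothetical encoder: 1 sign slot + 7 one-hot digit groups = 71 values."""
    # Zero-pad to three integer digits and four decimals, e.g. 37.7838 -> "0377838".
    digits = f"{abs(value):08.4f}".replace(".", "")
    vec = [1.0 if value < 0 else 0.0]
    for d in digits:
        group = [0.0] * 10
        group[int(d)] = 1.0
        vec.extend(group)
    return vec


def decode_coordinate(vec):
    """Pick the highest-scoring digit in each group of ten (the top guess)."""
    sign = "-" if vec[0] >= 0.5 else ""
    body = vec[1:]
    digits = []
    for start in range(0, len(body), 10):
        group = body[start:start + 10]
        digits.append(str(max(range(10), key=group.__getitem__)))
    # Three integer digits, then the decimal point, then four decimals.
    return sign + "".join(digits[:3]) + "." + "".join(digits[3:])


if __name__ == "__main__":
    lat = decode_coordinate(encode_coordinate(37.7838))
    lon = decode_coordinate(encode_coordinate(-122.4105))
    print(lat, lon)
```

The four decimal places used here match the default `decimals=4` in `round_tensor`, although whether the two are connected in the training code is an assumption.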