├── .gitignore
├── LICENSE
├── README.md
├── dataset_builder_multi_label.py
├── find_best_model.py
├── geoguessr_dataset.py
├── get_images.py
├── main.py
├── notebook
│   └── GeoGuessr_AI_Demo.ipynb
├── save_production_model.py
└── utils
    └── tensor_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Machine Learning
132 | /models
133 | /images
134 | /geoguessr_dataset
135 | /runs
136 | /tensorboard
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Alex
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GeoGuessr AI
2 |
3 | [Open in Colab](https://colab.research.google.com/github/Stelath/geoguessr-ai/blob/main/notebook/GeoGuessr_AI_Demo.ipynb)
4 |
5 | This was a personal project to learn more about CNNs and MLPs by building an ML model that can reliably guess where a Google Street View image was taken within one of five US cities. The idea was inspired by the game GeoGuessr, in which the player is dropped at a random Google Street View location and has to guess where in the world they are based on the Street View alone.
6 |
7 | ## Update
8 | Stanford grad students have recently released a model called PIGEON, which significantly outperforms ours. I suggest you check it out [here](https://huggingface.co/geolocal/StreetCLIP).
9 |
10 | ## Project Overview
11 |
12 | ### Creating a Dataset
13 |
14 | To train the model, I first had to create a reasonably large dataset of Google Street View images. To do this, I wrote a Python script ([get_images.py](https://github.com/Stelath/geoguessr-ai/blob/main/get_images.py "get_images.py")) that downloads a large set of photographs of 5 US cities from the Google Street View API. The API needs latitude and longitude coordinates, so I used an address book from [Open Addresses](https://openaddresses.io/) to get the coordinates of random street addresses in each of the 5 cities.
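As a rough illustration of this step, the sketch below reads line-delimited GeoJSON features and builds the query parameters for one Street View request. The parameter values mirror those in `get_images.py`, but the helper names `load_coordinates` and `build_params`, the in-memory sample data, and the `YOUR_API_KEY` placeholder are assumptions made for the example. No request is actually sent.

```python
import io
import json
import random


def load_coordinates(lines):
    """Collect [longitude, latitude] pairs from line-delimited GeoJSON features."""
    coords = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        feature = json.loads(line)
        coords.append(feature["geometry"]["coordinates"])  # GeoJSON order: [lon, lat]
    return coords


def build_params(lon, lat, key):
    """Build Street View query parameters for one address (hypothetical helper)."""
    # One of four compass directions, jittered by up to 15 degrees either way
    heading = random.randint(0, 3) * 90 + random.randint(-15, 15)
    return {
        "key": key,
        "size": "640x640",
        "location": f"{lat},{lon}",  # the API expects latitude first
        "heading": str(heading),
        "pitch": "20",
        "fov": "90",
    }


# Two made-up features standing in for an Open Addresses export
sample = io.StringIO(
    '{"type": "Feature", "geometry": {"type": "Point", "coordinates": [-77.0365, 38.8977]}}\n'
    '{"type": "Feature", "geometry": {"type": "Point", "coordinates": [-87.6298, 41.8781]}}\n'
)
coords = load_coordinates(sample)
lon, lat = coords[0]
params = build_params(lon, lat, key="YOUR_API_KEY")
print(params["location"])
```

Note that GeoJSON stores longitude before latitude, so the two values have to be swapped when building the `location` parameter.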
15 |
16 | ### Formatting the Images
17 |
18 | Before training the model, I reformatted the targets so the AI could guess the location more easily. The GPS coordinates are turned into multi-label targets: each coordinate becomes a sign bit followed by one one-hot array of length ten per digit, covering the digits zero through nine.
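The encoding can be sketched in a few lines of plain Python. This is a simplified illustration rather than the code in `dataset_builder_multi_label.py` (which uses NumPy); the function names `encode_coordinate` and `decode_coordinate` are hypothetical, and the layout of 3 integer digits plus 4 decimal digits is taken from that script's defaults.

```python
def encode_coordinate(value, decimals=4, int_digits=3):
    """Encode one coordinate as a sign bit followed by a one-hot vector per digit."""
    scaled = round(abs(value) * 10 ** decimals)
    digits = str(scaled).zfill(int_digits + decimals)  # e.g. 38.8977 -> "0388977"
    label = [1.0 if value < 0 else 0.0]  # sign bit: 1 means negative
    for digit in digits:
        one_hot = [0.0] * 10
        one_hot[int(digit)] = 1.0
        label.extend(one_hot)
    return label


def decode_coordinate(label, decimals=4):
    """Invert the encoding by taking the arg-max of each group of ten."""
    sign = -1 if label[0] >= 0.5 else 1
    digits = ""
    for start in range(1, len(label), 10):
        group = label[start:start + 10]
        digits += str(group.index(max(group)))
    return sign * int(digits) / 10 ** decimals


lat = encode_coordinate(38.8977)
lon = encode_coordinate(-77.0365)
target = lat + lon  # latitude label followed by longitude label
print(len(target))
print(decode_coordinate(lat), decode_coordinate(lon))
```

Each coordinate takes 1 + 7 × 10 = 71 values, so a latitude and longitude pair gives the 142 outputs that the model is created with (`num_classes=142`).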
19 |
20 | ### Training the Model
21 |
22 | After trying a couple of different model architectures, I settled on a Wide ResNet with 50 layers, which gave results comparable to a Wide ResNet with 100 layers but used far less GPU memory. The model was trained on a dataset of 50,000 images; it performed best about 20 epochs in and then slowly began to overfit.
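One way to spot the point where overfitting starts is to look for the epoch with the lowest validation loss in the `model.log` file that `main.py` writes. The sketch below assumes the log line format used in `main.py`; the function names are hypothetical, and the three log lines are synthetic placeholders, not results from this project.

```python
def parse_log_line(line):
    """Parse one model.log line into (epoch, {metric name: value})."""
    name, *fields = line.strip().split(", ")
    epoch = int(name.split("-")[1])  # "model-3" -> 3
    metrics = {}
    for field in fields:
        key, value = field.split(": ")
        metrics[key] = float(value)
    return epoch, metrics


def best_epoch(lines):
    """Return (epoch, validation loss) for the lowest validation loss."""
    parsed = [parse_log_line(line) for line in lines if line.strip()]
    epoch, metrics = min(parsed, key=lambda item: item[1]["Val Loss"])
    return epoch, metrics["Val Loss"]


# Synthetic placeholder lines in the format written by main.py
example_log = [
    "model-0, Accuracy: 0.0, Loss: 0.0412, Val Accuracy: 0.0, Val Loss: 0.0405\n",
    "model-1, Accuracy: 0.01, Loss: 0.0388, Val Accuracy: 0.0, Val Loss: 0.0391\n",
    "model-2, Accuracy: 0.03, Loss: 0.0361, Val Accuracy: 0.0, Val Loss: 0.0398\n",
]
print(best_epoch(example_log))
```

Because `main.py` validates on only a few randomly chosen batches per epoch, the logged validation loss is noisy; `find_best_model.py` re-evaluates each saved checkpoint on a larger part of the validation set.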
23 |
24 | ### How Could it be Improved?
25 |
26 | Several changes could help. One is a custom model that is better suited to guessing locations than an architecture designed for plain image classification. Adding far more layers would let the model pick up more complexity, although this would require a larger GPU. Arguably the biggest improvement, and one that would follow the GeoGuessr idea more closely, would be to train a 3D CNN on an array of images from each location, giving the model far more data to work with when making predictions.
27 |
28 | ### Takeaway
29 |
30 | While the model is by no means perfect, it is surprisingly accurate given the limited input it receives. Interestingly, this reveals that while many would regard American cities as similar, there are clearly significant differences in their landscapes, so much so that the AI was able to take advantage of them to at least predict the correct city the majority of the time.
31 |
32 | ## Instructions
33 |
34 | ### Run Pretrained Model
35 |
36 | To run a pretrained model, you can download it from [Google Drive](https://drive.google.com/file/d/1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW/view?usp=sharing) or use the [Google Colab notebook](https://colab.research.google.com/github/Stelath/geoguessr-ai/blob/main/notebook/GeoGuessr_AI_Demo.ipynb).
37 |
38 | ### Train a Model
39 | You can use the [get_images.py](https://github.com/Stelath/geoguessr-ai/blob/main/get_images.py "get_images.py") script to download a database of images through Google Cloud. You will have to set up a Google Cloud account, and keep in mind that Google allows 28,500 free Google Street View API calls each month (anything more and you will be charged $0.007 per image). Alternatively, you can download the database of Google Street View images I created [here](https://www.kaggle.com/stelath/city-street-view-dataset).
40 |
41 | Once you have a database of images, running the [dataset_builder_multi_label.py](https://github.com/Stelath/geoguessr-ai/blob/main/dataset_builder_multi_label.py) script will preprocess all of the images; running [main.py](https://github.com/Stelath/geoguessr-ai/blob/main/main.py) will then begin training the model.
42 |
43 | Here is a set of commands that would be used to train a model on 25,000 images (keep in mind you will need a cities folder containing `.geojson` files from [Open Addresses](https://openaddresses.io/)):
44 | ```
45 | python -m get_images --cities cities/ --output images/ --icount 25000 --key (YOUR GSV API KEY HERE)
46 | python -m dataset_builder_multi_label --file images/picture_coords.csv --images images/ --output geoguessr_dataset/
47 | python -m main geoguessr_dataset/ -a wide_resnet50_2 -b 16 --lr 0.0001 -j 6 --checkpoint-step 1
48 | ```
49 |
--------------------------------------------------------------------------------
/dataset_builder_multi_label.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from PIL import Image
4 | from tqdm import tqdm
5 | from random import randint
6 | import numpy as np
7 | import os
8 |
9 | def get_args():
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("--file", help="The CSV file to read and extract GPS coordinates from", required=True, type=str)
12 |     parser.add_argument("--images", help="The path to the images folder, (defaults to: images/)", default='images/', type=str)
13 |     parser.add_argument("--output", help="The output folder", required=True, type=str)
14 |     return parser.parse_args()
15 |
16 | args = get_args()
17 |
18 | targets_train = []
19 | targets_val = []
20 |
21 | def multi_label(num, decimals=4):
22 |     # Encode a coordinate as a sign bit followed by one one-hot vector per digit
23 |
24 |     if decimals > 0:
25 |         num = str(round(float(num) * (10 ** decimals)))
26 |     else:
27 |         num = str(round(float(num)))
28 |
29 |     if num[0] == '-':
30 |         label = np.array([1])
31 |         num = num[1:]
32 |     else:
33 |         label = np.array([0])
34 |
35 |     # Left-pad with zeros to 3 integer digits plus `decimals` fractional digits
36 |     num = num.zfill(decimals + 3)
37 |
38 |     for digit in num:
39 |         label = np.concatenate((label, np.eye(10)[int(digit)]))
40 |
41 |     return label
42 |
43 | def get_data(coord, coord_index):
44 |     lat, lon = coord[0], coord[1]
45 |
46 |     img_path = os.path.join(args.images, f'street_view_{coord_index}.jpg')
47 |     img = Image.open(img_path)
48 |
49 |     lat_multi_label = multi_label(lat)
50 |     lon_multi_label = multi_label(lon)
51 |
52 |     target = np.concatenate((lat_multi_label, lon_multi_label))
53 |
54 |     return [img, target]
55 |
56 | def main():
57 |     with open(args.file, 'r') as f:
58 |         coords_reader = csv.reader(f)
59 |         coords = []
60 |         for row in coords_reader:
61 |             coords.append(row)
62 |
63 |
64 |     train_data_path = os.path.join(args.output, 'train')
65 |     os.makedirs(train_data_path, exist_ok=True)
66 |     val_data_path = os.path.join(args.output, 'val')
67 |     os.makedirs(val_data_path, exist_ok=True)
68 |
69 |     val_count = 0
70 |     train_count = 0
71 |
72 |     for coord_index, coord in enumerate(tqdm(coords)):
73 |         if randint(0, 9) == 0:  # roughly 10% of the images go to the validation set
74 |             data = get_data(coord, coord_index)
75 |             val_data_path = os.path.join(args.output, f'val/street_view_{val_count}.jpg')
76 |             data[0].save(val_data_path)
77 |             targets_val.append(data[1])
78 |             val_count += 1
79 |         else:
80 |             data = get_data(coord, coord_index)
81 |             train_data_path = os.path.join(args.output, f'train/street_view_{train_count}.jpg')
82 |             data[0].save(train_data_path)
83 |             targets_train.append(data[1])
84 |             train_count += 1
85 |
86 |     np.save(os.path.join(args.output, 'train/targets.npy'), np.array(targets_train))
87 |     np.save(os.path.join(args.output, 'val/targets.npy'), np.array(targets_val))
88 |
89 |     print('Train Files:', train_count)
90 |     print('Val Files:', val_count)
91 |
92 | if __name__ == '__main__':
93 |     main()
94 |
--------------------------------------------------------------------------------
/find_best_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | import numpy as np
5 | from datetime import datetime
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.utils.data
10 | import torchvision.models as models
11 | from geoguessr_dataset import GeoGuessrDataset
12 |
13 | model_names = sorted(name for name in models.__dict__
14 |     if name.islower() and not name.startswith("__")
15 |     and callable(models.__dict__[name]))
16 |
17 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Best Model Locator')
18 | parser.add_argument('data', metavar='DIR',
19 |                     help='path to dataset')
20 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
21 |                     choices=model_names,
22 |                     help='model architecture: ' +
23 |                         ' | '.join(model_names) +
24 |                         ' (default: resnet50)')
25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
26 |                     help='number of data loading workers (default: 4)')
27 | parser.add_argument('-b', '--batch-size', default=64, type=int,
28 |                     metavar='N',
29 |                     help='batch size (default: 64), this is the total '
30 |                          'batch size of the GPU')
31 | parser.add_argument('--models-dir', default='models', type=str)
32 |
33 | start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
34 | args = parser.parse_args()
35 | all_loss = []
36 |
37 | def fwd_pass(model, data, targets, loss_function, optimizer, train=False):
38 |     data = data.cuda()
39 |     targets = targets.cuda()
40 |
41 |     if train:
42 |         model.zero_grad()
43 |
44 |     outputs = model(data)
45 |     matches = [(torch.where(i >= 0.5, 1, 0) == j).all() for i, j in zip(outputs, targets)]
46 |     acc = matches.count(True) / len(matches)
47 |     loss = loss_function(outputs, targets)
48 |
49 |     if train:
50 |         loss.backward()
51 |         optimizer.step()
52 |
53 |     return acc, loss
54 |
55 | def test(val_loader, model, loss_function):
56 |     model.eval()
57 |     acc = []
58 |     loss = []
59 |
60 |     for idx, sample in enumerate(tqdm(val_loader)):
61 |         if idx < 256:  # cap evaluation at the first 256 batches
62 |             data, target = sample
63 |             with torch.no_grad():  # inference only, no gradients needed
64 |                 batch_acc, batch_loss = fwd_pass(model, data, target, loss_function, None)
65 |             acc.append(batch_acc)
66 |             loss.append(batch_loss.cpu().numpy())
67 |
68 |     # Average the per-batch metrics
69 |     val_acc = np.mean(acc)
70 |     val_loss = np.mean(loss)
71 |     return val_acc, val_loss
72 |
73 |
74 | def main():
75 |     model_files = sorted(f for f in os.listdir(args.models_dir) if f.endswith('.pth'))  # checkpoints only, skips model.log
76 |     valdir = os.path.join(args.data, 'val')
77 |     val_dataset = GeoGuessrDataset(valdir)
78 |
79 |     val_loader = torch.utils.data.DataLoader(
80 |         val_dataset, batch_size=args.batch_size, shuffle=False,
81 |         num_workers=args.workers, pin_memory=True)
82 |
83 |     print("=> creating model '{}'".format(args.arch))
84 |     model = models.__dict__[args.arch](pretrained=False, progress=True, num_classes=142)
85 |     model = nn.Sequential(
86 |         model,
87 |         nn.Sigmoid()
88 |     )
89 |     model.cuda()
90 |
91 |     loss_function = nn.BCELoss()
92 |
93 |     for model_file in tqdm(model_files):
94 |         model_path = os.path.join(args.models_dir, model_file)
95 |         print("=> loading model '{}'".format(model_path))
96 |         checkpoint = torch.load(model_path)
97 |         model.load_state_dict(checkpoint['model_state_dict'])
98 |         print("=> loaded model '{}' (epoch {})".format(model_path, checkpoint['epoch']))
99 |
100 |         val_acc, val_loss = test(val_loader, model, loss_function)
101 |         all_loss.append(val_loss)
102 |         print("=> val_acc: {:.4f}, val_loss: {:.4f}".format(val_acc, val_loss))
103 |
104 |     min_value = min(all_loss)
105 |     min_index = all_loss.index(min_value)
106 |     print("=> best model: {}, loss: {}".format(model_files[min_index], min_value))
107 |
108 | if __name__ == '__main__':
109 |     main()
110 |
--------------------------------------------------------------------------------
/geoguessr_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import os, os.path
6 | from PIL import Image
7 |
8 | class GeoGuessrDataset(Dataset):
9 |     def __init__(self, data_dir):
10 |         self.data_dir = data_dir
11 |         self.targets = np.load(os.path.join(data_dir, 'targets.npy'), allow_pickle=True)
12 |
13 |     def __len__(self):
14 |         return len(os.listdir(self.data_dir)) - 1  # every file except targets.npy is an image
15 |
16 |     def __getitem__(self, idx):
17 |         data_path = os.path.join(self.data_dir, f'street_view_{idx}.jpg')
18 |
19 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 |                                          std=[0.229, 0.224, 0.225])
21 |         transform = transforms.Compose([
22 |             transforms.Resize(256),
23 |             transforms.ToTensor(),
24 |             normalize,
25 |         ])
26 |
27 |         img = pil_loader(data_path)
28 |         data = transform(img)
29 |
30 |         target = torch.tensor(self.targets[idx], dtype=torch.float)
31 |
32 |         return data, target
33 |
34 | def pil_loader(path: str) -> Image.Image:
35 |     with open(path, "rb") as f:
36 |         img = Image.open(f)
37 |         return img.convert("RGB")
--------------------------------------------------------------------------------
/get_images.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from tqdm import tqdm
3 | import os
4 | import json
5 | from random import randint
6 | import argparse
7 | from csv import writer
8 | # Consider using https://osmnx.readthedocs.io/en/stable/osmnx.html#osmnx.utils_geo.sample_points for street gps coords
9 | # Stack Overflow on how to get the coordinates: https://stackoverflow.com/questions/68367074/how-to-generate-random-lat-long-points-within-geographical-boundaries
10 |
11 | def get_args():
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument("--cities", help="The folder full of addresses per city to read and extract GPS coordinates from", required=True, type=str)
14 |     parser.add_argument("--output", help="The output folder where the images will be stored, (defaults to: images/)", default='images/', type=str)
15 |     parser.add_argument("--icount", help="The amount of images to pull (defaults to 25,000)", default=25000, type=int)
16 |     parser.add_argument("--key", help="Your Google Street View API Key", type=str, required=True)
17 |     return parser.parse_args()
18 |
19 | args = get_args()
20 | url = 'https://maps.googleapis.com/maps/api/streetview'
21 | cities = []
22 |
23 | def load_cities():
24 |     for city in os.listdir(args.cities):
25 |         with open(os.path.join(args.cities, city)) as f:
26 |             coordinates = []
27 |             print(f'Loading {city} addresses...')
28 |             for line in tqdm(f):
29 |                 data = json.loads(line)
30 |                 coordinates.append(data['geometry']['coordinates'])
31 |         cities.append(coordinates)
32 |
33 | def main():
34 |     # Open and create all the necessary files & folders
35 |     os.makedirs(args.output, exist_ok=True)
36 |
37 |     load_cities()
38 |
39 |     coord_output_file = open(os.path.join(args.output, 'picture_coords.csv'), 'w', newline='')
40 |     csv_writer = writer(coord_output_file)
41 |
42 |     cities_count = [0] * len(cities)  # images pulled per city, accumulated across the whole run
43 |     for i in tqdm(range(args.icount)):
44 |         # Pick a random city, then a random address within it
45 |         city_index = randint(0, len(cities) - 1)
46 |         city = cities[city_index]
47 |         cities_count[city_index] += 1
48 |         addressLoc = city[randint(0, len(city) - 1)]
49 |         city.remove(addressLoc) # Remove the address from the list so we don't get the same one twice
50 |         # Set the parameters for the API call to Google Street View
51 |         params = {
52 |             'key': args.key,
53 |             'size': '640x640',
54 |             'location': str(addressLoc[1]) + ',' + str(addressLoc[0]),
55 |             'heading': str((randint(0, 3) * 90) + randint(-15, 15)),
56 |             'pitch': '20',
57 |             'fov': '90'
58 |         }
59 |
60 |         response = requests.get(url, params)
61 |
62 |         # Save the image to the output folder
63 |         with open(os.path.join(args.output, f'street_view_{i}.jpg'), "wb") as file:
64 |             file.write(response.content)
65 |
66 |         # Save the coordinates to the output file
67 |         csv_writer.writerow([addressLoc[1], addressLoc[0]])
68 |
69 |     coord_output_file.close()
70 |
71 |     for i in range(len(cities_count)):
72 |         city_count = cities_count[i]
73 |         city_name = os.listdir(args.cities)[i]
74 |         print(f'{city_count} images pulled from {city_name}')
75 |
76 | if __name__ == '__main__':
77 |     main()
78 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Training script for the GeoGuessr AI model
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 | import numpy as np
6 | from datetime import datetime
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data
11 | import torchvision.models as models
12 | from torch.utils.tensorboard import SummaryWriter
13 | from utils.tensor_utils import round_tensor
14 | from geoguessr_dataset import GeoGuessrDataset
15 |
16 | model_names = sorted(name for name in models.__dict__
17 |     if name.islower() and not name.startswith("__")
18 |     and callable(models.__dict__[name]))
19 |
20 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Training')
21 | parser.add_argument('data', metavar='DIR',
22 |                     help='path to dataset')
23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
24 |                     choices=model_names,
25 |                     help='model architecture: ' +
26 |                         ' | '.join(model_names) +
27 |                         ' (default: resnet50)')
28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
29 |                     help='number of data loading workers (default: 4)')
30 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
31 |                     help='number of total epochs to run')
32 | parser.add_argument('--checkpoint-step', default=1, type=int, metavar='N',
33 |                     help='how often (in epochs) to save the model (default: 1)')
34 | parser.add_argument('-b', '--batch-size', default=64, type=int,
35 |                     metavar='N',
36 |                     help='batch size (default: 64), this is the total '
37 |                          'batch size of the GPU')
38 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
39 |                     metavar='LR', help='learning rate for optimizer', dest='lr')
40 | parser.add_argument('-p', '--print-freq', default=10, type=int,
41 |                     metavar='N', help='print frequency (default: 10)')
42 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
43 |                     help='path to latest checkpoint (default: none)')
44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
45 |                     help='evaluate model on validation set')
47 | start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
48 | args = parser.parse_args()
49 |
50 | def fwd_pass(model, data, targets, loss_function, optimizer, train=False):
51 |     device = next(model.parameters()).device  # run on whichever device holds the model (GPU or CPU)
52 |     data, targets = data.to(device), targets.to(device)
53 |
54 |     if train:
55 |         model.zero_grad()
56 |
57 |     outputs = model(data)
58 |     matches = [(torch.where(i >= 0.5, 1, 0) == j).all() for i, j in zip(outputs, targets)]
59 |     acc = matches.count(True) / len(matches)
60 |     loss = loss_function(outputs.float(), targets)
61 |
62 |     if train:
63 |         loss.backward()
64 |         optimizer.step()
65 |
66 |     return acc, loss
67 |
68 | def test(val_loader, model, loss_function, optimizer):
69 |     random = np.random.randint(len(val_loader))  # validate on up to 4 batches starting at a random index
70 |
71 |     model.eval()
72 |     acc = []
73 |     loss = []
74 |
75 |     for idx, sample in enumerate(val_loader):
76 |         if idx >= random and idx < random + 4:
77 |             data, targets = sample
78 |             with torch.no_grad():
79 |                 val_acc, val_loss = fwd_pass(model, data, targets, loss_function, optimizer)
80 |             acc.append(val_acc)
81 |             loss.append(val_loss.cpu().numpy())
82 |
83 |     val_acc = np.mean(acc)
84 |     val_loss = np.mean(loss)
85 |     return val_acc, val_loss
86 |
87 | def train(train_loader, val_loader, model, loss_function, optimizer, epochs, start_epoch=0):
88 |     with open(f'models/{start_time}/model.log', 'a') as f:
89 |         for epoch in range(start_epoch, epochs):
90 |             model.train()
91 |
92 |             train_acc = []
93 |             train_loss = []
94 |
95 |             for idx, sample in enumerate(tqdm(train_loader)):
96 |                 data, target = sample
97 |                 acc, loss = fwd_pass(model, data, target, loss_function, optimizer, train=True)
98 |                 train_acc.append(acc)
99 |                 train_loss.append(loss.cpu().detach().numpy())
100 |
101 |             acc = np.mean(train_acc)
102 |             loss = np.mean(train_loss)
103 |
104 |             val_acc, val_loss = test(val_loader, model, loss_function, optimizer)
105 |
106 |             # Add accuracy and loss to tensorboard
107 |             writer.add_scalar('Loss/train', loss, epoch)
108 |             writer.add_scalar('Accuracy/train', acc, epoch)
109 |             writer.add_scalar('Loss/test', val_loss, epoch)
110 |             writer.add_scalar('Accuracy/test', val_acc, epoch)
111 |
112 |             # Log Accuracy and Loss
113 |             log = f'model-{epoch}, Accuracy: {round(float(acc), 2)}, Loss: {round(float(loss), 4)}, Val Accuracy: {round(float(val_acc), 2)}, Val Loss: {round(float(val_loss), 4)}\n'
114 |             print(log, end='')
115 |             f.write(log)
116 |             f.flush()  # keep model.log up to date if training is interrupted
117 |
118 |             if epoch % args.checkpoint_step == 0:
119 |                 print('Saving model...')
120 |                 torch.save({
121 |                     'epoch': epoch,
122 |                     'model_state_dict': model.state_dict(),
123 |                     'optimizer_state_dict': optimizer.state_dict(),
124 |                     'loss': loss
125 |                 }, f'models/{start_time}/model-{epoch}.pth')
126 |
127 | def main():
128 |     global writer
129 |     writer = SummaryWriter(f'tensorboard/{start_time}')
130 |
131 |     os.makedirs(f'models/{start_time}', exist_ok=True)
132 |
133 |     traindir = os.path.join(args.data, 'train')
134 |     valdir = os.path.join(args.data, 'val')
135 |     train_dataset = GeoGuessrDataset(traindir)
136 |     val_dataset = GeoGuessrDataset(valdir)
137 |
138 |     train_loader = torch.utils.data.DataLoader(
139 |         train_dataset, batch_size=args.batch_size, shuffle=True,
140 |         num_workers=args.workers, pin_memory=True)
141 |
142 |     val_loader = torch.utils.data.DataLoader(
143 |         val_dataset, batch_size=args.batch_size, shuffle=False,
144 |         num_workers=args.workers, pin_memory=True)
145 |
146 |     print("=> creating model '{}'".format(args.arch))
147 |     model = models.__dict__[args.arch](pretrained=False, progress=True, num_classes=142)
148 |     model = nn.Sequential(
149 |         model,
150 |         nn.Sigmoid()  # independent per-label probabilities, matching find_best_model.py and the demo notebook
151 |     )
152 |
153 |     loss_function = nn.BCELoss()  # the targets are multi-label, so use binary cross-entropy per output
154 |
155 |     if torch.cuda.is_available():
156 |         print('Using GPU')
157 |         device = torch.device("cuda")
158 |     else:
159 |         print('Using CPU')
160 |         device = torch.device("cpu")
161 |     model = model.to(device)
162 |     loss_function = loss_function.to(device)
163 |
164 |     optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=1e-4)
165 |
166 |     start_epoch = 0
167 |
168 |     if args.resume:
169 |         checkpoint = torch.load(args.resume, map_location=device)
170 |         model.load_state_dict(checkpoint['model_state_dict'])
171 |         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
172 |         start_epoch = checkpoint['epoch'] + 1
173 |         print(f'Resuming from epoch {start_epoch}')
174 |
175 |     EPOCHS = args.epochs
176 |     train(train_loader=train_loader, val_loader=val_loader, model=model, loss_function=loss_function, optimizer=optimizer, epochs=EPOCHS, start_epoch=start_epoch)
177 |
178 | if __name__ == '__main__':
179 |     main()
180 |
--------------------------------------------------------------------------------
/notebook/GeoGuessr_AI_Demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "GeoGuessr-AI Demo",
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyP5rZRIlS0xIT6+iHUu0CqK",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
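27 | "[Open in Colab](https://colab.research.google.com/github/Stelath/geoguessr-ai/blob/main/notebook/GeoGuessr_AI_Demo.ipynb)"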
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "source": [
33 | "# GeoGuessr-AI Demo"
34 | ],
35 | "metadata": {
36 | "id": "XYROTdQlbE6x"
37 | }
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "source": [
42 | "Disclaimer, this model was only trained on pictures from 5 cities (Washington, City of New York, Chicago, Detroit, and San Francisco) and will therefore only be able to predict the location of photos taken in those cities."
43 | ],
44 | "metadata": {
45 | "id": "mPqEzY1MIfwQ"
46 | }
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "source": [
51 | "## Setup"
52 | ],
53 | "metadata": {
54 | "id": "hToy8jcKbOjx"
55 | }
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 1,
60 | "metadata": {
61 | "colab": {
62 | "base_uri": "https://localhost:8080/"
63 | },
64 | "id": "d17FwqlWXsTd",
65 | "outputId": "9dec4987-f9a7-438e-e3cf-0c1425203300"
66 | },
67 | "outputs": [
68 | {
69 | "output_type": "stream",
70 | "name": "stdout",
71 | "text": [
72 | "Downloading...\n",
73 | "From: https://drive.google.com/uc?id=1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW&confirm=t\n",
74 | "To: /content/geoguessr_production_model.pt\n",
75 | "100% 269M/269M [00:01<00:00, 176MB/s]\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "!gdown 'https://drive.google.com/uc?id=1VJpeLJp6jC8IUfKy6cAtZ9WZcX1TTutW&confirm=t'"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "source": [
86 | "import torch\n",
87 | "import torchvision.models as models\n",
88 | "import torchvision.transforms as transforms\n",
89 | "import torch.nn as nn\n",
90 | "\n",
91 | "import numpy as np\n",
92 | "from PIL import Image\n",
93 | "import os\n",
94 | "\n",
95 | "from google.colab import files\n",
96 | "\n",
97 | "model = models.wide_resnet50_2(pretrained=False, progress=True, num_classes=142)\n",
98 | "model = nn.Sequential(\n",
99 | "    model,\n",
100 | "    nn.Sigmoid()\n",
101 | ")\n",
102 | "model_file = torch.load('geoguessr_production_model.pt', map_location=torch.device('cpu'))\n",
103 | "model.load_state_dict(model_file)\n",
104 | "model.eval()\n",
105 | "print('Loaded Model')\n",
106 | "\n",
107 | "def reformat(arr, guess_num=1):\n",
108 | "    num = ''\n",
109 | "    if arr[0] >= 0.5:\n",
110 | "        num += '-'\n",
111 | "\n",
112 | "    arr = arr[1:]\n",
113 | "\n",
114 | "    for idx in range(0, len(arr), 10):\n",
115 | "        if idx == 30:\n",
116 | "            num += '.'\n",
117 | "        num += str(np.where(arr[idx:idx+10] == np.partition(arr[idx:idx+10].flatten(), -guess_num)[-guess_num])[0][0] % 10)\n",
118 | "\n",
119 | "    return num\n",
120 | "\n",
121 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
122 | "                                 std=[0.229, 0.224, 0.225])\n",
123 | "transform = transforms.Compose([\n",
124 | "    transforms.Resize((256, 256)),\n",
125 | "    transforms.ToTensor(),\n",
126 | "    normalize,\n",
127 | "])"
128 | ],
129 | "metadata": {
130 | "colab": {
131 | "base_uri": "https://localhost:8080/"
132 | },
133 | "id": "AiJcPzD0bmT2",
134 | "outputId": "1787a5a5-67c4-4381-dbfe-09f9e7aa0c28"
135 | },
136 | "execution_count": 2,
137 | "outputs": [
138 | {
139 | "output_type": "stream",
140 | "name": "stdout",
141 | "text": [
142 | "Loaded Model\n"
143 | ]
144 | }
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "source": [
150 | "## Run Model"
151 | ],
152 | "metadata": {
153 | "id": "vbWYVDdocvN0"
154 | }
155 | },
156 | {
157 | "cell_type": "code",
158 | "source": [
159 | "#@title ## Load File\n",
160 | "\n",
161 | "reference_file = files.upload()\n",
162 | "reference_file = list(reference_file.keys())[0]\n",
163 | "\n",
164 | "img_path = os.path.join(reference_file)\n",
165 | "img = Image.open(img_path)\n",
166 | "data = transform(img)\n",
167 | "\n",
168 | "from matplotlib import pyplot as plt\n",
169 | "img = data.permute(1, 2, 0)\n",
170 | "plt.imshow(img, interpolation='nearest')\n",
171 | "plt.show()"
172 | ],
173 | "metadata": {
174 | "colab": {
189 | "base_uri": "https://localhost:8080/",
190 | "height": 342
191 | },
192 | "cellView": "form",
193 | "id": "hKdi894nbyKp",
194 | "outputId": "662520b0-9ae2-4406-873a-49adfb3fe830"
195 | },
196 | "execution_count": null,
197 | "outputs": [
198 | {
199 | "output_type": "display_data",
200 | "data": {
201 | "text/html": [
202 | "\n",
203 | " \n",
205 | " \n",
209 | " "
210 | ],
211 | "text/plain": [
212 | ""
213 | ]
214 | },
215 | "metadata": {}
216 | },
217 | {
218 | "output_type": "stream",
219 | "name": "stdout",
220 | "text": [
221 | "Saving sample_image.jpeg to sample_image.jpeg\n"
222 | ]
223 | },
224 | {
225 | "output_type": "stream",
226 | "name": "stderr",
227 | "text": [
228 | "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
229 | ]
230 | },
231 | {
232 | "output_type": "display_data",
233 | "data": {
234 | "image/png": "\n",
235 | "text/plain": [
236 | ""
237 | ]
238 | },
239 | "metadata": {
240 | "needs_background": "light"
241 | }
242 | }
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "source": [
248 | "#@title ## Make Prediction\n",
249 | "#@markdown Run the model to get a latitude and longitude prediction for the image. The model prints its most likely location first, then a backup guess that uses the second most likely longitude digits.\n",
250 | "\n",
251 | "with torch.no_grad():\n",
252 | "    data = data.view(-1, 3, 256, 256)\n",
253 | "    output = model(data)\n",
254 | "\n",
255 | "target_split_len = int(len(output.cpu().numpy()[0])/2)\n",
256 | "output_reformatted = reformat(output.cpu().numpy()[0][:target_split_len]) + ' ' + reformat(output.cpu().numpy()[0][target_split_len:])\n",
257 | "output_reformatted2 = reformat(output.cpu().numpy()[0][:target_split_len]) + ' ' + reformat(output.cpu().numpy()[0][target_split_len:], 2)\n",
258 | "print('First Guess:', output_reformatted)\n",
259 | "print('Second Guess:', output_reformatted2)"
260 | ],
261 | "metadata": {
262 | "colab": {
263 | "base_uri": "https://localhost:8080/"
264 | },
265 | "id": "XWavPNf1c6Fy",
266 | "outputId": "be208b02-a41e-4399-fbf0-6bc9eb377624",
267 | "cellView": "form"
268 | },
269 | "execution_count": null,
270 | "outputs": [
271 | {
272 | "output_type": "stream",
273 | "name": "stdout",
274 | "text": [
275 | "First Guess: 037.7838 -122.4105\n",
276 | "Second Guess: 037.7838 -073.3263\n"
277 | ]
278 | }
279 | ]
280 | }
281 | ]
282 | }
--------------------------------------------------------------------------------
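The notebook's `reformat` helper turns the model's sigmoid outputs into a coordinate string. Below is a minimal pure-Python sketch of that decoding. It assumes the layout implied by the notebook code: one sign score, then blocks of ten scores (one block per digit), with the decimal point after the third digit. The names `decode_coordinate` and `one_hot_coordinate` are hypothetical and not part of the repository, and ties between equal scores may resolve differently than in the NumPy version.

```python
def decode_coordinate(scores, rank=1):
    """Decode one coordinate from [sign, 10 scores per digit, ...].

    rank=1 picks the highest-scoring digit in each block, rank=2 the
    second highest, mirroring the notebook's guess_num argument.
    """
    text = '-' if scores[0] >= 0.5 else ''
    digit_scores = scores[1:]
    for start in range(0, len(digit_scores), 10):
        if start == 30:
            text += '.'  # three integer digits, then the fraction
        group = digit_scores[start:start + 10]
        # Indices of the block ordered from highest to lowest score.
        ranked = sorted(range(len(group)), key=lambda i: group[i], reverse=True)
        text += str(ranked[rank - 1])
    return text


def one_hot_coordinate(text):
    """Build an idealized score vector for a string such as '-122.4105' (illustration only)."""
    scores = [1.0 if text.startswith('-') else 0.0]
    for ch in text.lstrip('-').replace('.', ''):
        group = [0.0] * 10
        group[int(ch)] = 1.0
        scores.extend(group)
    return scores


if __name__ == '__main__':
    print(decode_coordinate(one_hot_coordinate('-122.4105')))
    print(decode_coordinate(one_hot_coordinate('037.7838')))
```

Encoding each digit as its own ten-way choice is what lets the notebook produce a backup guess: asking for `rank=2` swaps in the runner-up digit of every block without running the model again.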
/save_production_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.utils.data
5 |
6 | parser = argparse.ArgumentParser(description='PyTorch GeoGuessr AI Production Model Saver')
7 | parser.add_argument('modelpath', metavar='PATH',
8 |                     help='path to the training checkpoint to export')
9 |
10 | args = parser.parse_args()
11 |
12 | def main():
13 |     checkpoint = torch.load(args.modelpath, map_location='cpu')  # load on CPU so no GPU is required
14 |     torch.save(checkpoint['model_state_dict'], 'geoguessr_production_model.pt')  # keep the weights only
15 |
16 | if __name__ == '__main__':
17 |     main()
--------------------------------------------------------------------------------
/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def round_tensor(tensor, decimals=4):
4 |     return torch.round(tensor * 10 ** decimals) / (10 ** decimals)
--------------------------------------------------------------------------------
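`round_tensor` above rounds to a fixed number of decimals by scaling up, rounding to the nearest integer, and scaling back down. The same arithmetic on a plain Python float can be sketched as follows. `round_scalar` is a hypothetical name and the sample inputs are made up for illustration.

```python
def round_scalar(value, decimals=4):
    """Scale up, round to the nearest integer, then scale back down."""
    factor = 10 ** decimals
    return round(value * factor) / factor


if __name__ == '__main__':
    print(round_scalar(37.78384))                 # four decimals by default
    print(round_scalar(-122.41052, decimals=2))   # two decimals
```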