├── models
│   └── .keep
├── outputs
│   └── .keep
├── datasets
│   └── .keep
├── LICENSE
├── README.md
├── .gitignore
├── train_net.py
└── gen_figure.py

/models/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/outputs/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/datasets/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 ASU CICI Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# explainable-geoai

## Overview
This repository contains the dataset and code for the work "Explainable GeoAI: Performance assessment of deep learning model visualization techniques to understand AI learning processes with geospatial data".

## Installation and Prerequisites
This project is built on the [PyTorch](https://pytorch.org) framework and its model-interpretability library [Captum](https://captum.ai). To install PyTorch and Captum, please check their official installation guides for details.
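
For a quick start, a pip-based install along the following lines should work; the exact PyTorch command depends on your platform and CUDA version, so treat this as a sketch and prefer the official PyTorch install selector:
```bash
# CPU-only example; adjust for your CUDA version
pip install torch torchvision
pip install captum
```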

## Getting Started
### Training a New Model
This project uses VGG-16 as the demo model. The trained model is saved under [`models`](./models). To train a new model, you can run:
```bash
python train_net.py
```

### Generating Visualizations
To generate visualizations for a given image, specify the image name `<image_name>`. You can run:
```bash
python gen_figure.py --image-name <image_name>
# example:
python gen_figure.py --image-name crater_000001
```

After generation finishes, you will have four images in the [`outputs`](./outputs) folder:
* `<image_name>_original.png`: the original image with bounding-box labels
* `<image_name>_occ.png`: Occlusion visualization
* `<image_name>_cam.png`: Grad-CAM visualization
* `<image_name>_ig.png`: Integrated Gradients visualization

The image names used for the figures in the manuscript are:
* Figure 2: volcano_000039
* Figure 3: meander_000059
* Figure 4: crater_000083
* Figure 5: dunes_000010
* Figure 6: volcano_000055
* Figure 7: icebergtongue_000067
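
To regenerate all of the manuscript figures in one pass, a small shell loop over these image names should work, assuming the dataset images and a trained model are in place:
```bash
for name in volcano_000039 meander_000059 crater_000083 dunes_000010 volcano_000055 icebergtongue_000067; do
    python gen_figure.py --image-name "$name"
done
```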

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# self defined
*.jpg
*.xml
*.pickle
*.pth

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
import copy
import os
import pickle
import random
import time

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models

# dataset class
class_names = ['crater', 'dunes', 'hill', 'icebergtongue', 'lake', 'meander', 'river', 'volcano']
cls_to_idx = {'crater': 0, 'dunes': 1, 'hill': 2, 'icebergtongue': 3, 'lake': 4, 'meander': 5, 'river': 6, 'volcano': 7}

class terrainDataset(Dataset):

    def __init__(self, sets='train', transform=None):
        self.dir = './datasets/natural_feature_dataset/'

        if sets == 'train':
            self.img_list = pickle.load(open(os.path.join(self.dir, 'train_set.pickle'), 'rb'))
        elif sets == 'test':
            self.img_list = pickle.load(open(os.path.join(self.dir, 'test_set.pickle'), 'rb'))

        self.transform = transform
        random.shuffle(self.img_list)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img_path = os.path.join(self.dir, 'images', self.img_list[index])
        img_cls = self.img_list[index].split('_')[0]  # class name is encoded in the file name, e.g. crater_000001
        cls_idx = cls_to_idx[img_cls]
        img = Image.open(img_path)
        img = img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        sample = {'image': img, 'class': cls_idx, 'name': self.img_list[index]}

        return sample


# initialize datasets and transformation (ImageNet normalization statistics)
transform = transforms.Compose([transforms.Resize([224, 224]),
                                transforms.ToTensor(),
                                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

train_set = terrainDataset(sets='train', transform=transform)
test_set = terrainDataset(sets='test', transform=transform)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
dataloader = {'train': train_loader, 'val': test_loader}
dataset_sizes = {'train': len(train_set), 'val': len(test_set)}


# model initialization: ImageNet-pretrained VGG-16 with the last classifier layer replaced for the 8 terrain classes
# (torchvision >= 0.13 deprecates pretrained=True in favor of the weights= argument)
model = models.vgg16(pretrained=True)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(num_ftrs, 8)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


# loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 30

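# The loop below follows the common PyTorch fine-tuning recipe: each epoch runs a
# training phase and a validation phase, accumulates loss and accuracy over the
# corresponding dataloader, and keeps a copy of the weights whenever the validation
# accuracy improves; the best weights are written to ./models/vgg16_best.pth at the end.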
# train vgg16
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        for sample in dataloader[phase]:
            inputs = sample['image']
            labels = sample['class']
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))
torch.save(best_model_wts, './models/vgg16_best.pth')

--------------------------------------------------------------------------------
/gen_figure.py:
--------------------------------------------------------------------------------
import argparse
import copy
import os
import xml.etree.ElementTree as et

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from captum.attr import IntegratedGradients, LayerGradCam, Occlusion
from PIL import Image
from torchvision import models


def gen_figure(img_name):
    # dataset class
    class_names = ['crater', 'dunes', 'hill', 'icebergtongue', 'lake', 'meander', 'river', 'volcano']
    cls_to_idx = {'crater': 0, 'dunes': 1, 'hill': 2, 'icebergtongue': 3, 'lake': 4, 'meander': 5, 'river': 6, 'volcano': 7}


    # transformation (must match the preprocessing used in train_net.py)
    transform = transforms.Compose([transforms.Resize([224, 224]),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])


    # load image
    img_dir = './datasets/natural_feature_dataset/'
    img_path = os.path.join(img_dir, 'images', img_name + '.jpg')
    img = Image.open(img_path)
    width, height = img.size
    img = img.convert('RGB')
    img = transform(img)
    img = img[None, :]


    # load trained model
    model = models.vgg16(pretrained=True)  # weights are replaced by the fine-tuned checkpoint below
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = torch.nn.Linear(num_ftrs, 8)
    # map_location lets the checkpoint load on CPU-only machines as well
    model.load_state_dict(torch.load('./models/vgg16_best.pth', map_location='cpu'))
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    img = img.to(device)
    model.eval()


    # run prediction on the image
    gt_label = img_name.split('_')[0]
    gt_target = cls_to_idx[gt_label]

    m = torch.nn.Softmax(dim=1)

    with torch.no_grad():
        output = m(model(img))
    gt_prob = output[0, gt_target].item()
    predict_target = torch.argmax(output)
    predict_prob = output[0, predict_target].item()
    predict_label = class_names[predict_target]
    print('running prediction on {}...'.format(img_name))
    print('Label: {} ({})'.format(gt_label, gt_prob))
    print('Prediction: {} ({})'.format(predict_label, predict_prob))
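
    # Note: the three attribution maps below (Grad-CAM, Occlusion, Integrated Gradients)
    # are all computed for the predicted class (predict_target) rather than the
    # ground-truth label, so for a misclassified image the heatmaps show the evidence
    # behind the wrong prediction.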

    # draw original images
    fig, ax = plt.subplots()
    inp = copy.deepcopy(img)
    inp = np.transpose(inp.squeeze().cpu().detach().numpy(), (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # undo the normalization for display
    inp = np.clip(inp, 0, 1)
    ax.imshow(inp)
    ax.axis('off')

    # draw labels
    tree = et.parse(os.path.join(img_dir, 'labels', img_name + '.xml'))
    root = tree.getroot()
    wratio = 224.0 / width
    hratio = 224.0 / height

    for bb in root.iter('bndbox'):
        # assumes Pascal VOC child order: xmin, ymin, xmax, ymax
        xmin = int(bb[0].text) * wratio
        ymin = int(bb[1].text) * hratio
        xmax = int(bb[2].text) * wratio
        ymax = int(bb[3].text) * hratio
        rec = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, edgecolor='b', fill=False)
        ax.add_patch(rec)

    plt.savefig('./outputs/{}_original.png'.format(img_name), bbox_inches='tight', dpi=300)
    plt.close()


    # Grad-CAM (model.features[30] is the last layer of the VGG-16 feature extractor)
    gridspec = {'width_ratios': [1, 1, 0.1]}
    fig, ax = plt.subplots(1, 3, figsize=(11, 4.7), gridspec_kw=gridspec)
    layer_gc = LayerGradCam(model, model.features[30])
    attri_gc = layer_gc.attribute(img, predict_target)
    cam = attri_gc.squeeze().cpu().detach().numpy()
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam))
    cam = np.uint8(cam * 255)  # scale to 0-255 to visualize
    # Image.ANTIALIAS was removed in newer Pillow releases; LANCZOS is the equivalent filter
    cam = np.uint8(Image.fromarray(cam).resize((224, 224), Image.LANCZOS)) / 255
    heatmap = ax[0].imshow(cam, cmap='jet')
    ax[0].axis('off')

    # Grad-CAM overlaid on the original image
    ax[1].imshow(inp)
    ax[1].axis('off')
    ax[1].imshow(cam, cmap='jet', alpha=0.3)
    plt.colorbar(heatmap, cax=ax[2])
    plt.savefig('./outputs/{}_cam.png'.format(img_name), bbox_inches='tight', dpi=300)
    plt.close()


    # occlusion
    gridspec = {'width_ratios': [1, 1, 0.1]}
    fig, ax = plt.subplots(1, 3, figsize=(11, 4.7), gridspec_kw=gridspec)
    occ = Occlusion(model)
    attri_occ = occ.attribute(img, target=predict_target, sliding_window_shapes=(3, 40, 40), show_progress=True)
    attri_occ = np.transpose(attri_occ.squeeze().cpu().detach().numpy(), (1, 2, 0))
    attri_occ = attri_occ[:, :, 0]
    attri_occ = (attri_occ - np.min(attri_occ)) / (np.max(attri_occ) - np.min(attri_occ))
    heatmap = ax[0].imshow(attri_occ, cmap='jet')
    ax[0].axis('off')

    # occlusion overlaid on the image
    ax[1].imshow(inp)
    ax[1].axis('off')
    ax[1].imshow(attri_occ, cmap='jet', alpha=0.3)
    plt.colorbar(heatmap, cax=ax[2])
    plt.savefig('./outputs/{}_occ.png'.format(img_name), bbox_inches='tight', dpi=300)
    plt.close()
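
    # Integrated Gradients post-processing below: attributions are summed over the RGB
    # channels, made absolute, roughly the 50 largest values are zeroed out as outliers,
    # and the result is min-max normalized before plotting.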

    # Integrated Gradients
    gridspec = {'width_ratios': [1, 1, 0.1]}
    fig, ax = plt.subplots(1, 3, figsize=(11, 4.7), gridspec_kw=gridspec)
    ig = IntegratedGradients(model)
    attri_ig = ig.attribute(img, target=predict_target)
    attri_ig = np.transpose(attri_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
    attri_ig = np.sum(attri_ig, axis=2)
    attri_ig = np.abs(attri_ig)
    attri_sorted = np.sort(attri_ig.flatten())
    attri_ig = (attri_ig < attri_sorted[-50]) * attri_ig
    attri_ig = (attri_ig - np.min(attri_ig)) / (np.max(attri_ig) - np.min(attri_ig))
    heatmap = ax[0].imshow(attri_ig, cmap='jet')
    ax[0].axis('off')

    # ig on the image
    ax[1].imshow(inp)
    ax[1].axis('off')
    ax[1].imshow(attri_ig, cmap='jet', alpha=0.3)
    plt.colorbar(heatmap, cax=ax[2])
    plt.savefig('./outputs/{}_ig.png'.format(img_name), bbox_inches='tight', dpi=300)
    plt.close()


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image-name",
        type=str,
        help="image file name",
        required=True,
    )

    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()
    gen_figure(args.image_name)

--------------------------------------------------------------------------------