├── preprocess_iwildcam.sh
├── requirements.txt
├── preprocess_mountain_zebra.sh
├── process_taxonomy_ott.py
├── resnet.py
├── LICENSE
├── utils.py
├── process_taxonomy_inat.py
├── .gitignore
├── pytorchtools.py
├── gen_utils
│   ├── analyze_img_time.py
│   ├── eval_kge_specie_wise.py
│   ├── analyze_img_loc.py
│   ├── dump_kge_pred_specie_wise.py
│   ├── dump_imageonly_pred_specie_wise.py
│   └── analyze_taxonomy_model.py
├── README.md
├── dataset.py
├── model_st.py
├── eval.py
├── preprocess_data_iwildcam.py
├── model.py
├── preprocess_data_mountain_zebra.py
├── run_image_only_model.py
├── dataset_baseline.py
├── gps_locations.json
├── main.py
└── run_kge_model_baseline.py
/preprocess_iwildcam.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://files.opentreeoflife.org/ott/ott3.3/ott3.3.tgz 4 | tar -xvzf ott3.3.tgz 5 | 6 | python process_taxonomy_ott.py 7 | python preprocess_data_iwildcam.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | pandas==1.5.3 3 | numpy==1.24.2 4 | Pillow==9.4.0 5 | scipy==1.10.1 6 | tensorboard==2.12.2 7 | torchvision==0.15.1 8 | tqdm==4.64.1 9 | wilds==2.0.0 10 | matplotlib==3.7.1 11 | -------------------------------------------------------------------------------- /preprocess_mountain_zebra.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://www.inaturalist.org/taxa/inaturalist-taxonomy.dwca.zip 4 | mkdir inaturalist-taxonomy.dwca/ 5 | unzip -d inaturalist-taxonomy.dwca/ inaturalist-taxonomy.dwca.zip 6 | 7 | python process_taxonomy_inat.py 8 | 9 | mv data/snapshot_mountain_zebra/MTZ_public/MTZ_S1 data/snapshot_mountain_zebra/ 10 | 11 | python preprocess_data_mountain_zebra.py --data-dir data/snapshot_mountain_zebra/ --dataset-prefix mountain_zebra --species-common-names-file data/snapshot_mountain_zebra/category_to_label_map.json 12 | -------------------------------------------------------------------------------- /process_taxonomy_ott.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import string 3 | from tqdm import tqdm 4 | import re 5 | 6 | ott_taxonomy = pd.read_csv("ott3.3/taxonomy.tsv", sep="\t") 7 | ott_taxonomy = ott_taxonomy.loc[:, ['uid', 'parent_uid', 'name', 'rank', 'sourceinfo', 'uniqname', 'flags']] 8 | 9 | punctuation_string = string.punctuation 10 | taxonomy_category = list(ott_taxonomy.name) 11 | 12 | for i in tqdm(range(len(taxonomy_category))): 13 | taxonomy_category[i] = ' '.join(taxonomy_category[i].split()) 14 | taxonomy_category[i] = taxonomy_category[i].translate(str.maketrans('', '', string.punctuation)) 15 | taxonomy_category[i] = taxonomy_category[i].lower() 16 | taxonomy_category[i] = re.sub(' +', ' ', taxonomy_category[i]) 17 | 18 | ott_taxonomy_2 = ott_taxonomy.copy() 19 | ott_taxonomy_2.name = taxonomy_category 20 | 21 | ott_taxonomy_2.to_csv('ott_taxonomy.csv', index=False) -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import resnet18, resnet50 4 | 5 | class Resnet18(nn.Module): 6 | def __init__(self, args): 7 | super(Resnet18, self).__init__() 8 | self.args = args 9 | self.image_embedding = 
resnet18(pretrained=True) 10 | self.needs_y = False 11 | 12 | self.image_embedding.fc = nn.Linear(512, 182) 13 | nn.init.xavier_uniform_(self.image_embedding.fc.weight.data) 14 | 15 | def forward(self, x): 16 | emb_h = self.image_embedding(x) 17 | 18 | return emb_h 19 | 20 | class Resnet50(nn.Module): 21 | def __init__(self, args): 22 | super(Resnet50, self).__init__() 23 | self.args = args 24 | self.image_embedding = resnet50(pretrained=True) 25 | self.needs_y = False 26 | 27 | self.image_embedding.fc = nn.Linear(2048, 182) 28 | nn.init.xavier_uniform_(self.image_embedding.fc.weight.data) 29 | 30 | def forward(self, x): 31 | emb_h = self.image_embedding(x) 32 | 33 | return emb_h -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OSU Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import csv 4 | import argparse 5 | import random 6 | from pathlib import Path 7 | import numpy as np 8 | import torch 9 | import pandas as pd 10 | import re 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | try: 15 | from torch_geometric.data import Batch 16 | except ImportError: 17 | pass 18 | 19 | def set_seed(seed): 20 | """Sets seed""" 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed(seed) 23 | torch.manual_seed(seed) 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | torch.backends.cudnn.benchmark = False 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def move_to(obj, device): 31 | if isinstance(obj, dict): 32 | return {k: move_to(v, device) for k, v in obj.items()} 33 | elif isinstance(obj, list): 34 | return [move_to(v, device) for v in obj] 35 | elif isinstance(obj, float) or isinstance(obj, int): 36 | return obj 37 | else: 38 | # Assume obj is a Tensor or other type 39 | # (like Batch, for MolPCBA) that supports .to(device) 40 | return obj.to(device) 41 | 42 | def detach_and_clone(obj): 43 | if torch.is_tensor(obj): 44 | return obj.detach().clone() 45 | elif isinstance(obj, dict): 46 | return {k: detach_and_clone(v) for k, v in obj.items()} 47 | elif isinstance(obj, list): 48 | return [detach_and_clone(v) for v in obj] 49 | elif isinstance(obj, float) or isinstance(obj, int): 50 | return obj 51 | else: 52 | raise TypeError("Invalid type for detach_and_clone") 53 | 54 | def collate_list(vec): 55 | """ 56 | If vec is a list of Tensors, it concatenates them all along the first dimension. 57 | 58 | If vec is a list of lists, it joins these lists together, but does not attempt to 59 | recursively collate. This allows each element of the list to be, e.g., its own dict. 60 | 61 | If vec is a list of dicts (with the same keys in each dict), it returns a single dict 62 | with the same keys. For each key, it recursively collates all entries in the list. 
63 | """ 64 | if not isinstance(vec, list): 65 | raise TypeError("collate_list must take in a list") 66 | elem = vec[0] 67 | if torch.is_tensor(elem): 68 | return torch.cat(vec) 69 | elif isinstance(elem, list): 70 | return [obj for sublist in vec for obj in sublist] 71 | elif isinstance(elem, dict): 72 | return {k: collate_list([d[k] for d in vec]) for k in elem} 73 | else: 74 | raise TypeError("Elements of the list to collate must be tensors or dicts.") 75 | -------------------------------------------------------------------------------- /process_taxonomy_inat.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import string 3 | from tqdm import tqdm 4 | import re 5 | import json 6 | 7 | inat_taxonomy = pd.read_csv("inaturalist-taxonomy.dwca/taxa.csv") 8 | inat_taxonomy.fillna('') 9 | 10 | inat_taxonomy = inat_taxonomy.loc[:, ['id', 'parentNameUsageID', 'scientificName', 'taxonRank']] 11 | inat_taxonomy.columns = ['uid', 'parent_uid', 'name', 'taxonRank'] 12 | 13 | punctuation_string = string.punctuation 14 | taxonomy_category = list(inat_taxonomy.name) 15 | 16 | for i in tqdm(range(len(taxonomy_category))): 17 | taxonomy_category[i] = ' '.join(taxonomy_category[i].split()) 18 | taxonomy_category[i] = taxonomy_category[i].translate(str.maketrans('', '', string.punctuation)) 19 | taxonomy_category[i] = taxonomy_category[i].lower() 20 | taxonomy_category[i] = re.sub(' +', ' ', taxonomy_category[i]) 21 | 22 | inat_taxonomy_2 = inat_taxonomy.copy() 23 | inat_taxonomy_2.name = taxonomy_category 24 | 25 | # replace all parent ids by just ids 26 | parent_uids = list(inat_taxonomy_2.parent_uid) 27 | 28 | parent_uids_new = [] 29 | 30 | for x in parent_uids: 31 | if isinstance(x, str): 32 | parent_uids_new.append(x.replace('https://www.inaturalist.org/taxa/','')) 33 | else: 34 | parent_uids_new.append('') 35 | 36 | inat_taxonomy_2.parent_uid = parent_uids_new 37 | 38 | inat_taxonomy_2 = inat_taxonomy_2.loc[inat_taxonomy_2['parent_uid'] != ''] 39 | 40 | taxon = inat_taxonomy_2 41 | taxon = taxon.fillna(0) 42 | taxon = taxon.loc[:, ["uid", "parent_uid"]] 43 | taxon.columns = ["h", "t"] 44 | taxon.insert(loc=1, column="r", value=1) 45 | taxon.insert(loc=1, column="datatype_h", value="id") 46 | taxon.insert(loc=4, column="datatype_t", value="id") 47 | taxon.insert(loc=5, column="split", value="train") 48 | taxon.columns = ["h", "datatype_h", "r", "t", "datatype_t", "split"] 49 | 50 | son = list(taxon["h"]) 51 | father = list(taxon["t"]) 52 | paths = {} 53 | 54 | for i in tqdm(range(len(son))): 55 | if isinstance(father[i], str) and len(father[i])==0: 56 | print('flag 1') 57 | continue 58 | 59 | paths[int(float(son[i]))] = int(float(father[i])) 60 | 61 | taxon_id_to_name = json.load(open('data/snapshot_mountain_zebra/taxon_id_to_name_lila.json')) 62 | category_to_label_map = json.load(open('data/snapshot_mountain_zebra/category_to_label_map_lila.json')) 63 | taxon_name_to_id = {v:k for k,v in taxon_id_to_name.items()} 64 | 65 | category_names = [] 66 | 67 | for x in tqdm(category_to_label_map): 68 | if category_to_label_map[x] in taxon_name_to_id: 69 | category_names.append(taxon_name_to_id[category_to_label_map[x]]) 70 | else: 71 | print(category_to_label_map[x]) 72 | 73 | leaf_node = category_names 74 | leaf_nodes = [] 75 | for item in tqdm(leaf_node): 76 | if int(float(item)) not in leaf_nodes: 77 | leaf_nodes.append(int(float(item))) 78 | 79 | list_paths = [] 80 | def get_paths(leaf_node, paths, nodes_list): 81 | while leaf_node in 
paths.keys(): 82 | # print(leaf_node,"->",paths[leaf_node]) 83 | nodes_list.append(leaf_node) 84 | leaf_node = paths[leaf_node] 85 | 86 | def get_path_nodes(leaf_nodes, paths): 87 | nodes_list = [] 88 | for item in leaf_nodes: 89 | get_paths(item,paths,nodes_list) 90 | return nodes_list 91 | 92 | paths_nodes = get_path_nodes(leaf_nodes, paths) 93 | 94 | taxon["h"] = paths.keys() 95 | taxon["t"] = paths.values() 96 | 97 | taxon = taxon.loc[(taxon['h'].isin(paths_nodes)) & (taxon['t'].isin(paths_nodes)),:] 98 | # taxon = taxon.reset_index() 99 | 100 | print('len(taxon) = {}'.format(len(taxon))) 101 | 102 | out_file = 'data/snapshot_mountain_zebra/taxon.csv' 103 | taxon.to_csv(out_file, index=False) 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /pytorchtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, ckpt_path='checkpoint.pt', best_ckpt_path='best_checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function. 18 | Default: print 19 | """ 20 | self.patience = patience 21 | self.verbose = verbose 22 | self.counter = 0 23 | self.best_score = None 24 | self.early_stop = False 25 | self.val_loss_min = np.Inf 26 | self.delta = delta 27 | self.ckpt_path = ckpt_path 28 | self.best_ckpt_path = best_ckpt_path 29 | self.trace_func = trace_func 30 | 31 | def __call__(self, val_loss, model, dense_optimizer, sparse_optimizer=None): 32 | 33 | score = -val_loss 34 | 35 | if self.best_score is None: 36 | self.best_score = score 37 | # self.save_checkpoint(val_loss, model, dense_optimizer, sparse_optimizer) 38 | 39 | if self.verbose: 40 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving best ckpt ...') 41 | 42 | if sparse_optimizer: 43 | ckpt_dict = {'model': model.state_dict(), 'sparse_optimizer':sparse_optimizer.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 44 | else: 45 | ckpt_dict = {'model': model.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 46 | 47 | torch.save(ckpt_dict, self.best_ckpt_path) 48 | 49 | self.val_loss_min = val_loss 50 | elif score < self.best_score + self.delta: 51 | self.counter += 1 52 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 53 | self.save_checkpoint(val_loss, model, dense_optimizer, sparse_optimizer) 54 | if self.counter >= self.patience: 55 | self.early_stop = True 56 | else: 57 | self.best_score = score 58 | if self.verbose: 59 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best ckpt ...') 60 | 61 | if sparse_optimizer: 62 | ckpt_dict = {'model': model.state_dict(), 'sparse_optimizer':sparse_optimizer.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 63 | else: 64 | ckpt_dict = {'model': model.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 65 | 66 | torch.save(ckpt_dict, self.best_ckpt_path) 67 | 68 | self.save_checkpoint(val_loss, model, dense_optimizer, sparse_optimizer) 69 | 70 | self.val_loss_min = val_loss 71 | self.counter = 0 72 | 73 | def save_checkpoint(self, val_loss, model, dense_optimizer, sparse_optimizer): 74 | '''Saves model when validation loss decrease.''' 75 | if self.verbose: 76 | self.trace_func(f'Saving ckpt ...') 77 | 78 | if sparse_optimizer: 79 | ckpt_dict = {'model': model.state_dict(), 'sparse_optimizer':sparse_optimizer.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 80 | else: 81 | ckpt_dict = {'model': model.state_dict(), 'dense_optimizer':dense_optimizer.state_dict()} 82 | 83 | torch.save(ckpt_dict, self.ckpt_path) 84 | -------------------------------------------------------------------------------- /gen_utils/analyze_img_time.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import argparse 5 | import os 6 | import numpy as np 7 | import re 8 | from tqdm import tqdm 9 | import sklearn.cluster as cluster 10 | from collections import defaultdict, Counter 11 | from re import match 12 | 13 | def getSeparated(item): 14 | m = match(r"(.*)-(.*)-(.*) (.*):(.*):(\d{2})", item) 15 | years, month, day, hour, minutes, second = m.groups() 16 | return int(years), int(month), int(day), int(hour), int(minutes), int(second) 17 | 18 | def create_image_id_dict(datacsv_id): 19 | image_id_dict = {} 20 | 21 | for i in range(len(datacsv_id)): 22 | image_filename = datacsv_id.iloc[i, 0] 23 | specie_id = int(float(datacsv_id.iloc[i, -3])) 24 | 25 | image_id_dict[image_filename] = specie_id 26 | 27 | return image_id_dict 28 | 29 | def bhattacharyya_distance(distribution1, distribution2): 30 | """ Estimate Bhattacharyya Distance (between General Distributions) 31 | 32 | Args: 33 | distribution1: a sample distribution 1 34 | distribution2: a sample distribution 2 35 | 36 | Returns: 37 | Bhattacharyya distance 38 | """ 39 | sq = 0 40 | for i in range(len(distribution1)): 41 | sq += np.sqrt(distribution1[i]*distribution2[i]) 42 | 43 | return -np.log(sq) 44 | 45 | def calc_dist_train_val(centroid_counters, centroid_counters_val, idx1, idx2): 46 | counter_0_keys = list(set(centroid_counters[idx1].keys()) | set(centroid_counters_val[idx2].keys())) 47 | 48 
| counter_0_train_dist, counter_0_val_dist = np.zeros(len(counter_0_keys)), np.zeros(len(counter_0_keys)) 49 | 50 | for item in centroid_counters[idx1]: 51 | counter_0_train_dist[counter_0_keys.index(item)] += centroid_counters[idx1][item] 52 | counter_0_train_dist = counter_0_train_dist/np.sum(counter_0_train_dist) 53 | 54 | for item in centroid_counters_val[idx2]: 55 | counter_0_val_dist[counter_0_keys.index(item)] += centroid_counters_val[idx2][item] 56 | counter_0_val_dist = counter_0_val_dist/np.sum(counter_0_val_dist) 57 | 58 | counter_0_dist = bhattacharyya_distance(counter_0_train_dist, counter_0_val_dist) 59 | return counter_0_dist 60 | 61 | if __name__=='__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 64 | parser.add_argument('--seed', type=int, default=813765) 65 | args = parser.parse_args() 66 | 67 | np.random.seed(args.seed) 68 | 69 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 70 | 71 | datacsv_time_train = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'time') & (datacsv['split'] == 'train'), :] 72 | datacsv_time_val = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'time') & (datacsv['split'] == 'val'), :] 73 | 74 | 75 | datacsv_id_train = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'id') & (datacsv['split'] == 'train'), :] 76 | datacsv_id_val = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'id') & (datacsv['split'] == 'val'), :] 77 | 78 | # compute distribution over hh, mm, ss 79 | image_time_dict_train, image_time_dict_val = {}, {} 80 | 81 | for i in range(len(datacsv_time_train)): 82 | image_filename = datacsv_time_train.iloc[i, 0] 83 | time = datacsv_time_train.iloc[i, -3] 84 | image_time_dict_train[image_filename] = time 85 | 86 | for i in range(len(datacsv_time_val)): 87 | image_filename = datacsv_time_val.iloc[i, 0] 88 | time = datacsv_time_val.iloc[i, -3] 89 | image_time_dict_val[image_filename] = time 90 | 91 | image_id_dict_train = create_image_id_dict(datacsv_id_train) 92 | image_id_dict_val = create_image_id_dict(datacsv_id_val) 93 | 94 | # calculate confus matrix for different hours 95 | species_c_hour_train = [Counter() for _ in range(24)] 96 | 97 | for img_filename in tqdm(image_time_dict_train): 98 | time = image_time_dict_train[img_filename] 99 | yyyy, mon, dd, hh, mm, ss = getSeparated(time) 100 | 101 | species_id = image_id_dict_train[img_filename] 102 | species_c_hour_train[hh].update([species_id]) 103 | 104 | species_c_hour_val = [Counter() for _ in range(24)] 105 | 106 | for img_filename in tqdm(image_time_dict_val): 107 | time = image_time_dict_val[img_filename] 108 | yyyy, mon, dd, hh, mm, ss = getSeparated(time) 109 | 110 | species_id = image_id_dict_val[img_filename] 111 | species_c_hour_val[hh].update([species_id]) 112 | 113 | confus_mat = np.ones((24, 24))*np.inf 114 | 115 | for i in range(24): 116 | for j in range(24): 117 | if len(species_c_hour_train[i])>0 and len(species_c_hour_val[j])>0: 118 | confus_mat[i, j] = calc_dist_train_val(species_c_hour_train, species_c_hour_val, i, j) 119 | 120 | 121 | fig = plt.figure() 122 | cax = plt.matshow(confus_mat) 123 | fig.colorbar(cax) 124 | plt.savefig('time_corr_analysis.png') 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /gen_utils/eval_kge_specie_wise.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict 14 | import math 15 | 16 | sys.path.append('../') 17 | 18 | from model import MKGE 19 | from resnet import Resnet18, Resnet50 20 | 21 | from tqdm import tqdm 22 | from utils import collate_list, detach_and_clone, move_to 23 | import torch.optim as optim 24 | from torch.utils.data import Dataset, DataLoader 25 | from wilds.common.metrics.all_metrics import Accuracy 26 | from PIL import Image 27 | from dataset import iWildCamOTTDataset 28 | 29 | if __name__=='__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 32 | parser.add_argument('--img-dir', type=str, default='../iwildcam_v2.0/imgs/') 33 | parser.add_argument('--split', type=str, default='val') 34 | parser.add_argument('--seed', type=int, default=813765) 35 | 36 | parser.add_argument('--y-pred-path-1', type=str, default=None, help='path to y_pred 1 predictions') 37 | parser.add_argument('--y-pred-path-2', type=str, default=None, help='path to y_pred 2 predictions') 38 | 39 | parser.add_argument('--debug', action='store_true') 40 | parser.add_argument('--no-cuda', action='store_true') 41 | parser.add_argument('--use-subtree', action='store_true', help='use truncated OTT') 42 | parser.add_argument('--batch_size', type=int, default=16) 43 | 44 | parser.add_argument('--embedding-dim', type=int, default=512) 45 | parser.add_argument('--location_input_dim', type=int, default=2) 46 | parser.add_argument('--time_input_dim', type=int, default=1) 47 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 48 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 49 | 50 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 51 | parser.add_argument('--use-data-subset', action='store_true') 52 | parser.add_argument('--subset-size', type=int, default=10) 53 | 54 | 55 | args = parser.parse_args() 56 | 57 | print('args = {}'.format(args)) 58 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 59 | 60 | # Set random seed 61 | torch.manual_seed(args.seed) 62 | np.random.seed(args.seed) 63 | random.seed(args.seed) 64 | 65 | y_1_pred_dict = json.load(open(args.y_pred_path_1)) 66 | y_2_pred_dict = json.load(open(args.y_pred_path_2)) 67 | 68 | total = 0 69 | train_c = dict([(156, 48007), (1, 10267), (14, 7534), (0, 4078), (5, 4023), (2, 3986), (27, 3584), (54, 3177), (15, 3091), (30, 2740), (31, 2642), (57, 2401), (17, 1966), (12, 1913), (24, 1751), (158, 1709), (160, 1542), (48, 1530), (52, 1444), (32, 1428), (13, 1246), (155, 1168), (33, 1150), (11, 1042), (53, 977), (165, 949), (55, 904), (159, 865), (9, 771), (16, 730), (3, 716), (56, 684), (8, 605), (10, 538), (7, 531), (64, 459), (41, 457), (6, 450), (37, 433), (46, 380), (74, 367), (101, 350), (70, 290), (29, 243), (106, 201), (58, 200), (44, 194), (80, 190), (45, 180), (4, 161), (61, 158), (40, 146), (28, 136), (162, 128), (36, 117), (130, 110), (67, 108), (21, 106), (35, 102), (65, 100), (82, 100), (88, 92), (71, 87), (18, 81), (102, 80), (161, 80), (170, 80), (25, 75), (77, 73), (50, 70), (62, 62), (100, 60), (97, 60), (34, 55), 
(43, 50), (79, 48), (157, 46), (111, 44), (94, 39), (59, 38), (19, 38), (47, 36), (98, 32), (39, 30), (85, 30), (22, 29), (90, 29), (84, 29), (121, 28), (63, 25), (38, 24), (173, 23), (83, 21), (110, 21), (139, 20), (69, 20), (95, 19), (86, 18), (72, 18), (127, 15), (129, 15), (26, 15), (75, 15), (154, 15), (93, 14), (76, 13), (87, 13), (81, 13), (109, 12), (108, 12), (120, 12), (123, 12), (60, 12), (96, 12), (145, 11), (131, 10), (149, 10), (177, 10), (178, 10), (23, 9), (122, 9), (42, 9), (103, 9), (134, 9), (135, 9), (153, 9), (164, 9), (66, 8), (20, 8), (116, 8), (114, 7), (125, 7), (172, 7), (107, 6), (119, 6), (99, 6), (133, 6), (140, 6), (142, 6), (146, 6), (147, 6), (179, 6), (180, 6), (181, 6), (118, 5), (163, 5), (104, 4), (112, 4), (167, 4), (113, 3), (115, 3), (117, 3), (78, 3), (92, 3), (126, 3), (128, 3), (91, 3), (68, 3), (137, 3), (138, 3), (143, 3), (144, 3), (51, 3), (150, 3), (152, 3), (89, 2), (49, 2), (132, 2), (136, 2), (169, 2), (166, 2), (124, 1), (73, 1), (105, 1), (141, 1), (148, 1), (151, 1), (168, 1), (171, 1), (174, 1), (175, 1), (176, 1)]) 70 | 71 | threshold = 100 72 | 73 | acc_1_avg = 0.0 74 | acc_2_avg = 0.0 75 | 76 | for label_id in y_1_pred_dict: 77 | # print(type(label_id)) 78 | 79 | if train_c[int(label_id)] <= threshold: 80 | y_1_pred = y_1_pred_dict[label_id] 81 | y_2_pred = y_2_pred_dict[label_id] 82 | 83 | assert len(y_1_pred)==len(y_2_pred) 84 | 85 | if len(y_1_pred)>0: 86 | acc_1 = y_1_pred.count(1)*100.0/len(y_1_pred) 87 | acc_2 = y_2_pred.count(1)*100.0/len(y_1_pred) 88 | print(f'label_id = {label_id}, acc_1 = {acc_1:.2f}, acc_2 = {acc_2:.2f}, train_count = {train_c[int(label_id)]}') 89 | 90 | acc_1_avg += acc_1 91 | acc_2_avg += acc_2 92 | 93 | total += 1 94 | 95 | acc_1_avg = acc_1_avg/total 96 | acc_2_avg = acc_2_avg/total 97 | 98 | print('acc_1_avg = {}'.format(acc_1_avg)) 99 | print('acc_2_avg = {}'.format(acc_2_avg)) 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # COSMO 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bringing-back-the-context-camera-trap-species/image-classification-on-iwildcam2020-wilds)](https://paperswithcode.com/sota/image-classification-on-iwildcam2020-wilds?p=bringing-back-the-context-camera-trap-species) 5 | 6 | ## [CIKM'24] Reviving the Context: Camera Trap Species Classification as Link Prediction on Multimodal Knowledge Graphs 7 | 8 | **Paper**: https://arxiv.org/pdf/2401.00608.pdf 9 | 10 | **Project webpage**: https://osu-nlp-group.github.io/COSMO/ 11 | 12 | **Authors**: Vardaan Pahuja, Weidi Luo, Yu Gu, Cheng-Hao Tu, Hong-You Chen, Tanya Berger-Wolf, Charles Stewart, Song Gao, Wei-Lun Chao, and Yu Su 13 | 14 | ## Installation 15 | 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Data Preprocessing 21 | 22 | ### iWildCam2020-WILDS 23 | ``` 24 | bash preprocess_iwildcam.sh 25 | ``` 26 | Note: The dir. `data/iwildcam_v2.0/train/` contains images for all splits. 27 | 28 | ### Snapshot Mountain Zebra 29 | 1. Download snapshot_mountain_zebra.zip from [this link](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/pahuja_9_buckeyemail_osu_edu/EWI05mXQsopNskBo78a_l_ABSZJHl0uCsdNMu72aXmNNiA?e=LOtm5Q) and uncompress it into a directory `data/snapshot_mountain_zebra/`. 30 | 2. 
Download images using the command `gsutil -m cp -r "gs://public-datasets-lila/snapshot-safari/MTZ/MTZ_public" data/snapshot_mountain_zebra/` 31 | 3. Run `bash preprocess_mountain_zebra.sh` 32 | 33 | 34 | ## Training 35 | 36 | Note: The commands below use the DistMult model by default. Use the following hyperparameter configuration: 37 | 38 | - For iWildCam2020-WILDS, set `DATA_DIR` to `data/iwildcam_v2.0/`, `IMG_DIR` to `data/iwildcam_v2.0/train/`, and `DATASET` to `iwildcam`. 39 | - For Snapshot Mountain Zebra, set `DATA_DIR` to `data/snapshot_mountain_zebra/`, `IMG_DIR` to `data/snapshot_mountain_zebra/`, and `DATASET` to `mountain_zebra`. 40 | - For ConvE, add `--kg-embed-model conve --embedding-dim 200` to the args. 41 | 42 | 43 | ### Image-only model (ERM baseline) 44 | ``` 45 | python -u run_image_only_model.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR > CKPT_DIR/log.txt 46 | ``` 47 | 48 | ### COSMO, no-context baseline 49 | ``` 50 | python -u main.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR > CKPT_DIR/log.txt 51 | ``` 52 | 53 | ### COSMO, taxonomy 54 | ``` 55 | python -u main.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR --add-id-id > CKPT_DIR/log.txt 56 | ``` 57 | 58 | ### COSMO, location 59 | ``` 60 | python -u main.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR --add-image-location > CKPT_DIR/log.txt 61 | ``` 62 | 63 | ### COSMO, time 64 | ``` 65 | python -u main.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR --add-image-time > CKPT_DIR/log.txt 66 | ``` 67 | 68 | ### COSMO, taxonomy + location + time 69 | ``` 70 | python -u main.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR --add-id-id --add-image-time --add-image-location > CKPT_DIR/log.txt 71 | ``` 72 | 73 | ### MLP-concat baseline 74 | ``` 75 | python -u run_kge_model_baseline.py --data-dir DATA_DIR --img-dir IMG_DIR --save-dir CKPT_DIR --embedding-dim 512 --use-subtree --only-hour --time_input_dim 1 --early-stopping-patience 10 > CKPT_DIR/log.txt 76 | 77 | ``` 78 | 79 | ## Evaluation 80 | 81 | ### Evaluate a model (specify split) 82 | ``` 83 | python eval.py --ckpt-path <ckpt_path> --split test --data-dir DATA_DIR --img-dir IMG_DIR 84 | ``` 85 | 86 | ## Error Analysis 87 | 88 | ### Taxonomy analysis 89 | ``` 90 | cd gen_utils/ 91 | python analyze_taxonomy_model.py --data-dir DATA_DIR --img-dir IMG_DIR --ckpt-1-path <ckpt_1_path> --ckpt-2-path <ckpt_2_path> 92 | ``` 93 | 94 | ### Plot location correlation analysis 95 | ``` 96 | cd gen_utils/ 97 | python analyze_img_loc.py --data-dir DATA_DIR 98 | ``` 99 | 100 | ### Plot time correlation analysis 101 | ``` 102 | cd gen_utils/ 103 | python analyze_img_time.py --data-dir DATA_DIR 104 | ``` 105 | 106 | 107 | ### Under-represented Species Analysis 108 | 109 | #### Dump predictions for baseline image-only model 110 | ``` 111 | cd gen_utils/ 112 | python dump_imageonly_pred_specie_wise.py --ckpt-path <ckpt_path> --split test --out-dir <out_dir> 113 | ``` 114 | 115 | #### Dump predictions for COSMO model 116 | ``` 117 | cd gen_utils/ 118 | python dump_kge_pred_specie_wise.py --ckpt-path <ckpt_path> --split test --out-dir <out_dir> 119 | ``` 120 | 121 | #### Compare performance for under-represented species 122 | ``` 123 | cd gen_utils/ 124 | python eval_kge_specie_wise.py --y-pred-path-1 <y_pred_path_1> --y-pred-path-2 <y_pred_path_2> 125 | ``` 126 | 127 | ## Citation 128 | ``` 129 | @inproceedings{10.1145/3627673.3679545, 130 | author = {Pahuja, Vardaan and Luo, Weidi and Gu, Yu and Tu, Cheng-Hao and Chen, Hong-You and Berger-Wolf, Tanya and Stewart, Charles and Gao, Song and Chao, Wei-Lun and Su, Yu},
131 | title = {Reviving the Context: Camera Trap Species Classification as Link Prediction on Multimodal Knowledge Graphs}, 132 | year = {2024}, 133 | isbn = {9798400704369}, 134 | publisher = {Association for Computing Machinery}, 135 | address = {New York, NY, USA}, 136 | url = {https://doi.org/10.1145/3627673.3679545}, 137 | doi = {10.1145/3627673.3679545}, 138 | booktitle = {Proceedings of the 33rd ACM International Conference on Information and Knowledge Management}, 139 | pages = {1825–1835}, 140 | numpages = {11}, 141 | keywords = {KG link prediction, camera traps, multimodal knowledge graph, species classification}, 142 | location = {Boise, ID, USA}, 143 | series = {CIKM '24} 144 | } 145 | ``` 146 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import re 5 | from math import pi 6 | from re import match 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | import json 11 | import pickle 12 | import pandas as pd 13 | 14 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 15 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 16 | 17 | class iWildCamOTTDataset(Dataset): 18 | def __init__(self, datacsv, mode, args, entity2id, target_list, head_type=None, tail_type=None): 19 | super(iWildCamOTTDataset, self).__init__() 20 | if head_type is not None and tail_type is not None: 21 | self.datacsv = datacsv.loc[(datacsv['datatype_h'] == head_type) & (datacsv['datatype_t'] == tail_type) & ( 22 | datacsv['split'] == mode), :] 23 | print("length of {}2{} dataset = {}".format(head_type, tail_type, len(self.datacsv))) 24 | else: 25 | self.datacsv = datacsv.loc[datacsv['split'] == mode, :] 26 | print("length of alltype dataset = {}".format(len(self.datacsv))) 27 | self.args = args 28 | self.mode = mode 29 | self.entity2id = entity2id 30 | self.target_list = target_list 31 | self.entity_to_species_id = {self.target_list[i, 0].item():i for i in range(len(self.target_list))} 32 | 33 | # print(self.entity_to_species_id) 34 | 35 | if args.use_data_subset: 36 | train_indices = np.random.choice(np.arange(len(self.datacsv)), size=args.subset_size, replace=False) 37 | self.datacsv = self.datacsv.iloc[train_indices] 38 | 39 | if head_type == 'image' and tail_type == 'location': 40 | datacsv_loc = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'location')] 41 | 42 | self.location_to_id = {} 43 | 44 | for i in range(len(datacsv_loc)): 45 | loc = datacsv_loc.iloc[i, 3] 46 | 47 | assert loc[0] == '[' 48 | assert loc[-1] == ']' 49 | # print(loc) 50 | if loc not in self.location_to_id: 51 | self.location_to_id[loc] = len(self.location_to_id) 52 | 53 | self.all_locs = torch.stack(list(map(lambda x:getNumber(x), self.location_to_id.keys()))) 54 | 55 | self.all_timestamps = None 56 | 57 | if head_type == 'image' and tail_type == 'time': 58 | datacsv_time = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'time')] 59 | self.time_to_id = {} 60 | 61 | for i in range(len(datacsv_time)): 62 | time = datacsv_time.iloc[i, 3] 63 | 64 | _, hour = get_separate_time(time) 65 | 66 | _HOUR_RAD = 2 * pi / 24 67 | 68 | h1, h2 = point(hour, _HOUR_RAD) 69 | 70 | time = hour 71 | 72 | if time not in self.time_to_id: 73 | self.time_to_id[time] = len(self.time_to_id) 74 | 75 | self.all_timestamps = 
torch.stack(list(map(lambda x:torch.tensor(x), self.time_to_id.keys()))) 76 | if len(self.all_timestamps.size())==1: 77 | self.all_timestamps = self.all_timestamps.unsqueeze(-1) 78 | 79 | def __len__(self): 80 | return len(self.datacsv) 81 | 82 | def __getitem__(self, idx): 83 | 84 | head_type = self.datacsv.iloc[idx, 1] 85 | tail_type = self.datacsv.iloc[idx, 4] 86 | head = self.datacsv.iloc[idx, 0] 87 | relation = self.datacsv.iloc[idx, 2] 88 | tail = self.datacsv.iloc[idx, 3] 89 | 90 | # for tail extract 91 | h = None 92 | t = None 93 | 94 | if tail_type == "id": 95 | if head_type in ["image", "location"]: 96 | t = torch.tensor([self.entity_to_species_id[self.entity2id[str(int(float(tail)))]]], dtype=torch.long).squeeze(-1) 97 | else: 98 | t = torch.tensor([self.entity2id[str(int(float(tail)))]], dtype=torch.long).squeeze(-1) 99 | 100 | elif tail_type == "location": 101 | t = self.location_to_id[tail] 102 | 103 | elif tail_type == "time": 104 | tail = datatime_divide(tail, self.args) 105 | t = self.time_to_id[tail] 106 | 107 | # for head extract 108 | if head_type == "id": 109 | h = torch.tensor([self.entity2id[str(int(float(head)))]], dtype=torch.long).squeeze(-1) 110 | 111 | elif head_type == "image": 112 | img = Image.open(os.path.join(self.args.img_dir, head)).convert('RGB') 113 | 114 | transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)]) 115 | h = transform_steps(img) 116 | 117 | elif head_type == "location": 118 | h = getNumber(head) 119 | 120 | # for r extract 121 | r = torch.tensor([int(relation)]) 122 | 123 | return h, r, t 124 | 125 | def getNumber(x): 126 | return torch.tensor(np.fromstring(x[1:-1], dtype=float, sep=' '), dtype=torch.float) 127 | 128 | def get_separate_time(item): 129 | m = match(r"(.*)-(.*)-(.*) (.*):(.*):(\d{2})", item) 130 | years, month, day, hour, minutes, second = m.groups() 131 | return float(month), float(hour) 132 | 133 | def datatime_divide(timestamp, args): 134 | month, hour = get_separate_time(timestamp) 135 | 136 | _HOUR_RAD = 2 * pi / 24 137 | 138 | h1, h2 = point(hour, _HOUR_RAD) 139 | 140 | return hour 141 | 142 | def point(m, rad): 143 | from math import sin, cos 144 | # place on circle 145 | return sin(m * rad), cos(m * rad) 146 | 147 | 148 | -------------------------------------------------------------------------------- /model_st.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch import Tensor 5 | from typing import Tuple 6 | 7 | from torchvision.models import resnet18, resnet50 8 | from torchvision.models import ResNet18_Weights, ResNet50_Weights 9 | import pretrainedmodels 10 | import ssl 11 | ssl._create_default_https_context = ssl._create_unverified_context # for pretrainedmodels 12 | 13 | class MKGE(nn.Module): 14 | def __init__(self, args, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None, all_loc_times=None): 15 | super(MKGE, self).__init__() 16 | self.args = args 17 | self.num_ent_uid = num_ent_uid 18 | 19 | self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False) 20 | 21 | if self.args.use_learned_loc_embed: 22 | self.location_embedding = torch.nn.Embedding(len(all_locs), args.embedding_dim) 23 | else: 24 | self.location_embedding = MLP(args.location_input_dim, args.embedding_dim, 
args.mlp_location_numlayer) 25 | 26 | self.time_embedding = MLP(args.time_input_dim, args.embedding_dim, args.mlp_time_numlayer) 27 | # print(self.time_embedding) 28 | # print(self.location_embedding) 29 | 30 | if self.args.img_embed_model == 'resnet50': 31 | self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) 32 | self.image_embedding.fc = nn.Linear(2048, args.embedding_dim) 33 | elif self.args.img_embed_model == 'resnet18': 34 | self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) 35 | self.image_embedding.fc = nn.Linear(512, args.embedding_dim) 36 | # self.image_embedding.fc = nn.Linear(512, 182) 37 | else: 38 | raise NotImplementedError 39 | 40 | self.target_list = target_list 41 | if all_locs is not None: 42 | self.all_locs = all_locs.to(device) 43 | if all_timestamps is not None: 44 | self.all_timestamps = all_timestamps.to(device) 45 | if all_loc_times is not None: 46 | self.all_loc_times = all_loc_times.to(device) 47 | 48 | #print(self.all_locs) 49 | 50 | if self.args.add_inverse_rels: 51 | num_relations = 4 52 | else: 53 | num_relations = 2 54 | 55 | self.act = nn.PReLU() 56 | 57 | self.mlp = nn.Linear(3*args.embedding_dim, args.embedding_dim) 58 | self.layer_norm = nn.LayerNorm(3*args.embedding_dim) 59 | 60 | self.classifier = nn.Linear(args.embedding_dim, len(self.target_list)) 61 | 62 | self.args = args 63 | self.device = device 64 | 65 | self.init() 66 | 67 | def init(self): 68 | nn.init.xavier_uniform_(self.ent_embedding.weight.data) 69 | # nn.init.xavier_uniform_(self.rel_embedding.weight.data) 70 | 71 | if self.args.img_embed_model in ['resnet18', 'resnet50']: 72 | nn.init.xavier_uniform_(self.image_embedding.fc.weight.data) 73 | 74 | if self.args.use_learned_loc_embed: 75 | nn.init.xavier_uniform_(self.location_embedding.weight.data) 76 | 77 | nn.init.xavier_uniform_(self.mlp.weight.data) 78 | 79 | nn.init.xavier_uniform_(self.classifier.weight.data) 80 | 81 | # @profile 82 | def forward_ce(self, graph, image, time, location=None): 83 | 84 | # create a graph using location and time attributes of the image 85 | # print('graph.n_id = {}'.format(graph.n_id)) 86 | 87 | # node ids: 88 | # : 0 89 | # T: 1 90 | # L: 2 91 | 92 | # edge ids: 93 | # (, T): 0 94 | # (, L): 1 95 | 96 | # gather initial node embedding 97 | batch_size = image.size(0) 98 | 99 | img_embed = self.image_embedding(image) 100 | 101 | # print('img_embed = {}'.format(img_embed)) 102 | 103 | time_emb = self.time_embedding(time) 104 | 105 | if location is not None: 106 | loc_emb = self.location_embedding(location) 107 | 108 | if location is not None: 109 | node_emb = torch.stack([img_embed, time_emb, loc_emb], dim=1) # [batch, n_nodes, hid_dim] 110 | else: 111 | node_emb = torch.stack([img_embed, time_emb], dim=1) 112 | 113 | node_emb = node_emb.view(node_emb.size(0), -1) 114 | 115 | node_emb = self.layer_norm(node_emb) 116 | 117 | img_context_emb = self.mlp(node_emb) 118 | 119 | img_context_emb = self.act(img_context_emb) 120 | 121 | # project the embeddding using a linear layer to compute label distribution 122 | score = self.classifier(img_context_emb) 123 | # print('score = {}'.format(score.size())) 124 | 125 | return score 126 | 127 | 128 | class MLP(nn.Module): 129 | def __init__(self, 130 | input_dim, 131 | output_dim, 132 | num_layers=3, 133 | p_dropout=0.0, 134 | bias=True): 135 | 136 | super().__init__() 137 | 138 | self.input_dim = input_dim 139 | self.output_dim = output_dim 140 | 141 | self.p_dropout = p_dropout 142 | step_size = (input_dim - output_dim) // 
num_layers 143 | hidden_dims = [output_dim + (i * step_size) 144 | for i in reversed(range(num_layers))] 145 | 146 | mlp = list() 147 | layer_indim = input_dim 148 | for hidden_dim in hidden_dims: 149 | mlp.extend([nn.Linear(layer_indim, hidden_dim, bias), 150 | nn.Dropout(p=self.p_dropout, inplace=True), 151 | nn.PReLU()]) 152 | 153 | layer_indim = hidden_dim 154 | 155 | self.mlp = nn.Sequential(*mlp) 156 | 157 | # initiate weights 158 | self.init() 159 | 160 | def forward(self, x): 161 | return self.mlp(x) 162 | 163 | def init(self): 164 | for param in self.parameters(): 165 | nn.init.uniform_(param) 166 | -------------------------------------------------------------------------------- /gen_utils/analyze_img_loc.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import argparse 5 | import os 6 | import numpy as np 7 | import re 8 | from tqdm import tqdm 9 | import sklearn.cluster as cluster 10 | from collections import defaultdict, Counter 11 | 12 | def getNumber(x): 13 | return np.fromstring(x[1:-1], sep=' ', dtype=float) 14 | 15 | def plot_loc_viz(image_loc_dict, image_loc_dict_val, image_loc_dict_test): 16 | plt.figure() 17 | 18 | # train 19 | X_train, Y_train = [], [] 20 | 21 | for img_filename in tqdm(image_loc_dict): 22 | x, y = getNumber(image_loc_dict[img_filename]) 23 | X_train.append(x) 24 | Y_train.append(y) 25 | 26 | scale = 200.0 * np.random.rand(len(X_train)) 27 | plt.subplot(1, 4, 1) 28 | plt.scatter(X_train, Y_train, label='train', s=scale, alpha=0.3) 29 | plt.legend() 30 | 31 | # val 32 | X_val, Y_val = [], [] 33 | 34 | for img_filename in tqdm(image_loc_dict_val): 35 | x, y = getNumber(image_loc_dict_val[img_filename]) 36 | X_val.append(x) 37 | Y_val.append(y) 38 | 39 | scale = 200.0 * np.random.rand(len(X_val)) 40 | plt.subplot(1, 4, 2) 41 | plt.scatter(X_val, Y_val, label='val', s=scale, alpha=0.3) 42 | plt.legend() 43 | 44 | # test 45 | X_test, Y_test = [], [] 46 | 47 | for img_filename in tqdm(image_loc_dict_test): 48 | x, y = getNumber(image_loc_dict_test[img_filename]) 49 | X_test.append(x) 50 | Y_test.append(y) 51 | 52 | scale = 200.0 * np.random.rand(len(X_test)) 53 | plt.subplot(1, 4, 3) 54 | plt.scatter(X_test, Y_test, label='test', s=scale, alpha=0.3) 55 | plt.legend() 56 | 57 | plt.subplot(1, 4, 4) 58 | scale = 200.0 * np.random.rand(len(X_train)) 59 | plt.scatter(X_train, Y_train, label='train', s=scale, alpha=0.3) 60 | scale = 200.0 * np.random.rand(len(X_val)) 61 | plt.scatter(X_val, Y_val, label='val', s=scale, alpha=0.3) 62 | scale = 200.0 * np.random.rand(len(X_test)) 63 | plt.scatter(X_test, Y_test, label='test', s=scale, alpha=0.3) 64 | plt.legend() 65 | 66 | plt.savefig('locs_splits.png') 67 | 68 | def plot_loc_hist(n_species_loc): 69 | # plot histogram 70 | plt.figure() 71 | plt.hist(n_species_loc) 72 | plt.xlabel('No. of species') 73 | plt.ylabel('No. 
of locations') 74 | plt.savefig('n_species_loc_hist.png') 75 | 76 | def create_image_id_dict(datacsv_id): 77 | image_id_dict = {} 78 | 79 | for i in range(len(datacsv_id)): 80 | image_filename = datacsv_id.iloc[i, 0] 81 | specie_id = int(float(datacsv_id.iloc[i, -3])) 82 | 83 | image_id_dict[image_filename] = specie_id 84 | 85 | return image_id_dict 86 | 87 | def bhattacharyya_distance(distribution1, distribution2): 88 | """ Estimate Bhattacharyya Distance (between General Distributions) 89 | 90 | Args: 91 | distribution1: a sample distribution 1 92 | distribution2: a sample distribution 2 93 | 94 | Returns: 95 | Bhattacharyya distance 96 | """ 97 | sq = 0 98 | for i in range(len(distribution1)): 99 | sq += np.sqrt(distribution1[i]*distribution2[i]) 100 | 101 | return -np.log(sq) 102 | 103 | def calc_dist_train_val(centroid_counters, centroid_counters_val, idx1, idx2): 104 | counter_0_keys = list(set(centroid_counters[idx1].keys()) | set(centroid_counters_val[idx2].keys())) 105 | 106 | counter_0_train_dist, counter_0_val_dist = np.zeros(len(counter_0_keys)), np.zeros(len(counter_0_keys)) 107 | 108 | for item in centroid_counters[idx1]: 109 | counter_0_train_dist[counter_0_keys.index(item)] += centroid_counters[idx1][item] 110 | counter_0_train_dist = counter_0_train_dist/np.sum(counter_0_train_dist) 111 | 112 | for item in centroid_counters_val[idx2]: 113 | counter_0_val_dist[counter_0_keys.index(item)] += centroid_counters_val[idx2][item] 114 | counter_0_val_dist = counter_0_val_dist/np.sum(counter_0_val_dist) 115 | 116 | 117 | counter_0_dist = bhattacharyya_distance(counter_0_train_dist, counter_0_val_dist) 118 | return counter_0_dist 119 | 120 | if __name__=='__main__': 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--data-dir', type=str, default='iwildcam_v2.0/') 123 | parser.add_argument('--seed', type=int, default=813765) 124 | args = parser.parse_args() 125 | 126 | np.random.seed(args.seed) 127 | 128 | mode = 'train' 129 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 130 | 131 | datacsv_loc = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'location') & (datacsv['split'] == mode), :] 132 | datacsv_loc_val = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'location') & (datacsv['split'] == 'val'), :] 133 | 134 | datacsv_id = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'id') & (datacsv['split'] == mode), :] 135 | datacsv_id_val = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'id') & (datacsv['split'] == 'val'), :] 136 | 137 | image_loc_dict, image_loc_dict_val, image_loc_dict_test = {}, {}, {} 138 | loc_image_dict = defaultdict(list) 139 | all_locs = set() 140 | 141 | for i in range(len(datacsv_loc)): 142 | image_filename = datacsv_loc.iloc[i, 0] 143 | loc = datacsv_loc.iloc[i, -3] 144 | image_loc_dict[image_filename] = loc 145 | loc_image_dict[loc].append(image_filename) 146 | all_locs.add(loc) 147 | 148 | for i in range(len(datacsv_loc_val)): 149 | image_filename = datacsv_loc_val.iloc[i, 0] 150 | loc = datacsv_loc_val.iloc[i, -3] 151 | image_loc_dict_val[image_filename] = loc 152 | 153 | all_locs = list(all_locs) 154 | 155 | assert len(image_loc_dict) == len(datacsv_loc) 156 | 157 | all_locs_arr = np.array(list(map(lambda x:getNumber(x), all_locs))) 158 | 159 | (centroid, label, _) = cluster.k_means(all_locs_arr, n_clusters=6) 160 | 161 | centroid_counters = [Counter() for _ in range(6)] 162 | 163 | 
image_id_dict = create_image_id_dict(datacsv_id) 164 | image_id_dict_val = create_image_id_dict(datacsv_id_val) 165 | 166 | all_species = list(set(image_id_dict.values())) 167 | 168 | n_species = len(all_species) 169 | 170 | colors = np.random.rand(n_species) 171 | 172 | loc_species_dict = defaultdict(list) 173 | 174 | for img_filename in tqdm(image_loc_dict): 175 | loc = image_loc_dict[img_filename] 176 | species_id = image_id_dict[img_filename] 177 | 178 | loc_species_dict[loc].append(species_id) 179 | 180 | n_species_loc = [len(set(loc_species_dict[loc])) for loc in loc_species_dict] 181 | 182 | n_avg_species_loc = np.average(n_species_loc) 183 | 184 | for loc in tqdm(loc_species_dict): 185 | centroid_counters[label[all_locs.index(loc)]].update(loc_species_dict[loc]) 186 | 187 | # plot locations of train/val/test 188 | plot_loc_viz(image_loc_dict, image_loc_dict_val, image_loc_dict_test) 189 | 190 | centroid_counters_val = [Counter() for _ in range(6)] 191 | 192 | for img_filename in tqdm(image_loc_dict_val): 193 | loc = getNumber(image_loc_dict_val[img_filename]) 194 | 195 | # find closest of (x, y) to each of train's centroid points 196 | loc_dist = np.linalg.norm(loc - centroid, axis=-1) 197 | cluster_id = np.argmin(loc_dist) 198 | 199 | # assign centroid label to this point 200 | species_id = image_id_dict_val[img_filename] 201 | centroid_counters_val[cluster_id].update([species_id]) 202 | 203 | 204 | confus_mat = np.ones((6, 6))*np.inf 205 | 206 | for i in range(6): 207 | for j in range(6): 208 | if len(centroid_counters[i])>0 and len(centroid_counters_val[j])>0: 209 | confus_mat[i, j] = calc_dist_train_val(centroid_counters, centroid_counters_val, i, j) 210 | 211 | 212 | fig = plt.figure() 213 | cax = plt.matshow(confus_mat) 214 | fig.colorbar(cax) 215 | plt.savefig('loc_corr_analysis.png') 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict 14 | import math 15 | 16 | sys.path.append('../') 17 | 18 | from model import MKGE 19 | from resnet import Resnet18, Resnet50 20 | 21 | from tqdm import tqdm 22 | from utils import collate_list, detach_and_clone, move_to 23 | import torch.optim as optim 24 | from torch.utils.data import Dataset, DataLoader 25 | from wilds.common.metrics.all_metrics import Accuracy 26 | from PIL import Image 27 | from dataset import iWildCamOTTDataset 28 | 29 | def evaluate(model, val_loader, target_list, args): 30 | model.eval() 31 | torch.set_grad_enabled(False) 32 | 33 | epoch_y_true = [] 34 | epoch_y_pred = [] 35 | 36 | batch_idx = 0 37 | for labeled_batch in tqdm(val_loader): 38 | h, r, t = labeled_batch 39 | h = move_to(h, args.device) 40 | r = move_to(r, args.device) 41 | t = move_to(t, args.device) 42 | 43 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'id')) 44 | 45 | batch_results = { 46 | 'y_true': t.cpu(), 47 | 'y_pred': outputs.cpu(), 48 | } 49 | 50 | y_true = detach_and_clone(batch_results['y_true']) 51 | epoch_y_true.append(y_true) 52 | y_pred = detach_and_clone(batch_results['y_pred']) 53 | y_pred = y_pred.argmax(-1) 54 | 55 | epoch_y_pred.append(y_pred) 56 | 57 | batch_idx += 1 58 | if args.debug: 
59 | break 60 | 61 | epoch_y_pred = collate_list(epoch_y_pred) 62 | epoch_y_true = collate_list(epoch_y_true) 63 | 64 | metrics = [ 65 | Accuracy(prediction_fn=None), 66 | ] 67 | 68 | results = {} 69 | 70 | for i in range(len(metrics)): 71 | results.update({ 72 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 73 | }) 74 | 75 | print(f'Eval., split: {args.split}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}') 76 | 77 | return 78 | 79 | def _get_id(dict, key): 80 | id = dict.get(key, None) 81 | if id is None: 82 | id = len(dict) 83 | dict[key] = id 84 | return id 85 | 86 | def generate_target_list(data, entity2id): 87 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 88 | sub = list(sub['t']) 89 | categories = [] 90 | for item in tqdm(sub): 91 | if entity2id[str(int(float(item)))] not in categories: 92 | categories.append(entity2id[str(int(float(item)))]) 93 | # print('categories = {}'.format(categories)) 94 | print("No. of target categories = {}".format(len(categories))) 95 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 96 | 97 | 98 | 99 | if __name__=='__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--dataset', choices=['iwildcam', 'mountain_zebra'], default='iwildcam') 102 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 103 | parser.add_argument('--img-dir', type=str, default='../iwildcam_v2.0/imgs/') 104 | parser.add_argument('--split', type=str, default='val') 105 | parser.add_argument('--seed', type=int, default=813765) 106 | parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt') 107 | parser.add_argument('--debug', action='store_true') 108 | parser.add_argument('--no-cuda', action='store_true') 109 | parser.add_argument('--use-subtree', action='store_true', help='use truncated OTT') 110 | parser.add_argument('--batch_size', type=int, default=16) 111 | 112 | parser.add_argument('--kg-embed-model', choices=['distmult', 'conve'], default='distmult') 113 | parser.add_argument('--embedding-dim', type=int, default=512) 114 | parser.add_argument('--location_input_dim', type=int, default=2) 115 | parser.add_argument('--time_input_dim', type=int, default=1) 116 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 117 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 118 | 119 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 120 | parser.add_argument('--use-data-subset', action='store_true') 121 | parser.add_argument('--subset-size', type=int, default=10) 122 | 123 | # ConvE hyperparams 124 | parser.add_argument('--embedding-shape1', type=int, default=20, help='The first dimension of the reshaped 2D embedding. The second dimension is infered. Default: 20') 125 | parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.') 126 | parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.') 127 | parser.add_argument('--feat-drop', type=float, default=0.2, help='Dropout for the convolutional features. Default: 0.2.') 128 | parser.add_argument('--use-bias', action='store_true', default=True, help='Use a bias in the convolutional layer. Default: True') 129 | parser.add_argument('--hidden-size', type=int, default=9728, help='The side of the hidden layer. The required size changes with the size of the embeddings. 
Default: 9728 (embedding size 200).') 130 | 131 | args = parser.parse_args() 132 | 133 | print('args = {}'.format(args)) 134 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 135 | print(args.device) 136 | 137 | # Set random seed 138 | torch.manual_seed(args.seed) 139 | np.random.seed(args.seed) 140 | random.seed(args.seed) 141 | 142 | if args.dataset == 'iwildcam': 143 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 144 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree.json') 145 | else: 146 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'data_triples.csv'), low_memory=False) 147 | entity_id_file = os.path.join(args.data_dir, 'entity2id.json') 148 | 149 | if not os.path.exists(entity_id_file): 150 | entity2id = {} # each of triple types have their own entity2id 151 | 152 | for i in tqdm(range(datacsv.shape[0])): 153 | if datacsv.iloc[i,1] == "id": 154 | _get_id(entity2id, str(int(float(datacsv.iloc[i,0])))) 155 | 156 | if datacsv.iloc[i,-2] == "id": 157 | _get_id(entity2id, str(int(float(datacsv.iloc[i,-3])))) 158 | json.dump(entity2id, open(entity_id_file, 'w')) 159 | else: 160 | entity2id = json.load(open(entity_id_file, 'r')) 161 | 162 | num_ent_id = len(entity2id) 163 | 164 | print('len(entity2id) = {}'.format(len(entity2id))) 165 | 166 | target_list = generate_target_list(datacsv, entity2id) 167 | 168 | val_image_to_id_dataset = iWildCamOTTDataset(datacsv, args.split, args, entity2id, target_list, head_type="image", tail_type="id") 169 | print('len(val_image_to_id_dataset) = {}'.format(len(val_image_to_id_dataset))) 170 | 171 | val_loader = DataLoader( 172 | val_image_to_id_dataset, 173 | shuffle=False, # Do not shuffle eval datasets 174 | sampler=None, 175 | batch_size=args.batch_size, 176 | num_workers=0, 177 | pin_memory=True) 178 | 179 | model = MKGE(args, num_ent_id, target_list, args.device) 180 | 181 | model.to(args.device) 182 | 183 | # restore from ckpt 184 | if args.ckpt_path: 185 | ckpt = torch.load(args.ckpt_path, map_location=args.device) 186 | model.load_state_dict(ckpt['model'], strict=False) 187 | print('ckpt loaded...') 188 | 189 | evaluate(model, val_loader, target_list, args) 190 | -------------------------------------------------------------------------------- /preprocess_data_iwildcam.py: -------------------------------------------------------------------------------- 1 | from wilds import get_dataset 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | from tqdm import tqdm 6 | 7 | def gps(x): 8 | return np.array([x["latitude"], x["longitude"]]) 9 | 10 | # Load the full dataset, and download it if necessary 11 | dataset = get_dataset(dataset="iwildcam", download=True) 12 | 13 | metadata = pd.read_csv("data/iwildcam_v2.0/metadata.csv") 14 | categories = pd.read_csv("data/iwildcam_v2.0/categories.csv") 15 | 16 | # the map is iwildcam_id_to_name {y:name,....} 17 | k = list(categories.y) 18 | v = list(categories.name) 19 | 20 | iwildcam_id_to_name = {} 21 | for i in range(len(k)): 22 | iwildcam_id_to_name[k[i]] = v[i] 23 | 24 | iwildcam_name_to_id = {v:k for k,v in iwildcam_id_to_name.items()} 25 | 26 | # the map processing (replaces iwildcam category ids by species names) 27 | metadata_y = list(metadata.y) 28 | for i in range(len(metadata_y)): 29 | metadata_y[i] = iwildcam_id_to_name[metadata_y[i]] 30 | 31 | metadata.y = metadata_y 32 | metadata = metadata.loc[:, ["split", "location", "y", "datetime", "filename"]] 33 | 
metadata.columns = ["split", "location", "name", "datetime", "filename"] 34 | 35 | # store time used data 36 | time_used = metadata 37 | # load pre_used_taxonomy.csv to get dic{name:uid} 38 | taxon = pd.read_csv("ott_taxonomy.csv") 39 | 40 | k = list(taxon.name) 41 | v = list(taxon.uid) 42 | 43 | # the map is taxon_name_to_id {name:uid,....} 44 | taxon_name_to_id = {} 45 | for i in range(len(k)): 46 | taxon_name_to_id[k[i]] = v[i] 47 | 48 | taxon_id_to_name = {x:y for x,y in zip(taxon.uid, taxon.name)} 49 | json.dump(taxon_id_to_name, open('data/iwildcam_v2.0/taxon_id_to_name.json', 'w'), indent=1) 50 | 51 | category_offset_non_intersection = max(taxon_name_to_id.values()) + 1 52 | 53 | 54 | meta_categories = list(set([x for x in metadata.name])) 55 | 56 | ott_categories = list(taxon.name) 57 | 58 | intersection_categories = list(set(ott_categories) & set(meta_categories)) 59 | 60 | # intersection of iwildcam and OTT 61 | metadata_intersection = metadata.loc[metadata["name"].isin(intersection_categories), :].copy() 62 | 63 | # non-interesection part 64 | metadata_non_intersection = metadata.loc[~metadata["name"].isin(intersection_categories), :].copy() 65 | 66 | # replace name by uid in metadata_intersection 67 | metadata_name = list(metadata_intersection.name) 68 | for i in range(len(metadata_name)): 69 | metadata_name[i] = taxon_name_to_id[metadata_name[i]] 70 | 71 | metadata_intersection.name = metadata_name 72 | metadata_intersection.columns = ["split", "location", "uid", "datetime", "filename"] 73 | 74 | metadata_non_intersection_name = list(metadata_non_intersection.name) 75 | non_intersection_uids = set() 76 | overall_id_to_name = {} 77 | 78 | for i in range(len(metadata_non_intersection_name)): 79 | specie_name = metadata_non_intersection_name[i] 80 | metadata_non_intersection_name[i] = iwildcam_name_to_id[specie_name] + category_offset_non_intersection 81 | non_intersection_uids.add(iwildcam_name_to_id[specie_name]) 82 | overall_id_to_name[metadata_non_intersection_name[i]] = specie_name 83 | 84 | metadata_non_intersection.name = metadata_non_intersection_name 85 | 86 | intersection_uids = set([iwildcam_name_to_id[taxon_id_to_name[x]] for x in metadata_intersection.uid]) 87 | 88 | for specie_id in intersection_uids: 89 | overall_id_to_name[taxon_name_to_id[iwildcam_id_to_name[specie_id]]] = iwildcam_id_to_name[specie_id] 90 | 91 | common = non_intersection_uids & intersection_uids 92 | common = [iwildcam_id_to_name[x] for x in common] 93 | 94 | json.dump(overall_id_to_name, open('data/iwildcam_v2.0/overall_id_to_name.json', 'w')) 95 | 96 | # re-name name column 97 | metadata_non_intersection.columns = ["split", "location", "uid", "datetime", "filename"] 98 | 99 | # concatenate metadata_intersection and metadata_non_intersection 100 | metadata = pd.concat([metadata_intersection, metadata_non_intersection]) 101 | 102 | # store uid used dataset 103 | uid_used = metadata 104 | 105 | gps_data = pd.read_json('gps_locations.json') 106 | gps_data = gps_data.transpose() 107 | gps_data.insert(loc=2, column="location", value=gps_data.index.to_list()) 108 | gps_data = gps_data.sort_index(ascending=True) 109 | 110 | gps_data["GPS"] = gps_data.apply(gps, axis=1) 111 | 112 | k = list(gps_data.location) 113 | v = list(gps_data.GPS) 114 | 115 | # find the species that have GPS in metadata 116 | metadata = metadata.loc[metadata["location"].isin(k), :].copy() 117 | 118 | 119 | # the map is dic {location:GPS,....} 120 | dic = {} 121 | for i in range(len(k)): 122 | dic[k[i]] = v[i] 123 | 124 | # make 
location to GPS 125 | metadata_location = list(metadata.location) 126 | for i in range(len(metadata_location)): 127 | metadata_location[i] = dic[metadata_location[i]] 128 | 129 | metadata.location = metadata_location 130 | 131 | # store GPS used data 132 | gps_used = metadata 133 | 134 | taxon = taxon.fillna(0) 135 | taxon = taxon.loc[:, ["uid", "parent_uid"]] 136 | taxon.columns = ["h", "t"] 137 | taxon.insert(loc=1, column="r", value=1) 138 | taxon.insert(loc=1, column="datatype_h", value="id") 139 | taxon.insert(loc=4, column="datatype_t", value="id") 140 | taxon.insert(loc=5, column="split", value="train") 141 | taxon.columns = ["h", "datatype_h", "r", "t", "datatype_t", "split"] 142 | 143 | takeLocation = gps_used.loc[:, ["filename", "location", "split"]] 144 | takeLocation.insert(loc=1, column="r", value=2) 145 | takeLocation.insert(loc=1, column="datatype_h", value="image") 146 | takeLocation.insert(loc=4, column="datatype_t", value="location") 147 | takeLocation.columns = ["h", "datatype_h", "r", "t", "datatype_t", "split"] 148 | 149 | takeTime = time_used.loc[:, ["filename", "datetime", "split"]] 150 | takeTime.insert(loc=1, column="r", value=0) 151 | takeTime.insert(loc=1, column="datatype_h", value="image") 152 | takeTime.insert(loc=4, column="datatype_t", value="time") 153 | takeTime.columns = ["h", "datatype_h", "r", "t", "datatype_t", "split"] 154 | 155 | imageIsIn = uid_used.loc[:, ["filename", "uid", "split"]] 156 | imageIsIn.insert(loc=1, column="r", value=3) 157 | imageIsIn.insert(loc=1, column="datatype_h", value="image") 158 | imageIsIn.insert(loc=4, column="datatype_t", value="id") 159 | imageIsIn.columns = ["h", "datatype_h", "r", "t", "datatype_t", "split"] 160 | 161 | a = pd.concat([taxon, imageIsIn], ignore_index=True) 162 | a = pd.concat([a, takeTime], ignore_index=True) 163 | a = pd.concat([a, takeLocation], ignore_index=True) 164 | 165 | inner = a.loc[(a["datatype_h"]=="image") & (a["datatype_t"]=="id"),:].copy() 166 | 167 | ott = a.loc[(a["datatype_h"]=="id") & (a["datatype_t"]=="id"),:].copy() 168 | 169 | son = list(ott["h"]) 170 | father = list(ott["t"]) 171 | paths = {} 172 | for i in tqdm(range(len(son))): 173 | paths[int(float(son[i]))] = int(float(father[i])) 174 | 175 | leaf_node = list(inner.t) 176 | leaf_nodes = [] 177 | for item in tqdm(leaf_node): 178 | if int(float(item)) not in leaf_nodes: 179 | leaf_nodes.append(int(float(item))) 180 | 181 | list_paths = [] 182 | def get_paths(leaf_node, paths, nodes_list): 183 | while leaf_node in paths.keys(): 184 | # print(leaf_node,"->",paths[leaf_node]) 185 | nodes_list.append(leaf_node) 186 | leaf_node = paths[leaf_node] 187 | 188 | def get_path_nodes(leaf_nodes, paths): 189 | nodes_list = [] 190 | for item in leaf_nodes: 191 | get_paths(item,paths,nodes_list) 192 | return nodes_list 193 | 194 | paths_nodes = get_path_nodes(leaf_nodes, paths) 195 | 196 | ott["h"] = paths.keys() 197 | ott["t"] = paths.values() 198 | 199 | ott = ott.loc[(ott['h'].isin(paths_nodes)) & (ott['t'].isin(paths_nodes)),:] 200 | ott = ott.reset_index() 201 | 202 | a = a.loc[(a["datatype_h"] != "id"),:] 203 | a.reset_index() 204 | 205 | dataset = pd.concat([ott, a], ignore_index=True) 206 | dataset = dataset.iloc[:,1:] 207 | dataset.to_csv("data/iwildcam_v2.0/dataset_subtree.csv",index = False) 208 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn 
import functional as F
4 | from torch import Tensor
5 | from typing import Tuple
6 | 
7 | from torchvision.models import resnet18, resnet50
8 | from torchvision.models import ResNet18_Weights, ResNet50_Weights
9 | 
10 | class MKGE(nn.Module):
11 |     def __init__(self, args, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
12 |         super(MKGE, self).__init__()
13 |         self.args = args
14 |         self.num_ent_uid = num_ent_uid
15 | 
16 |         self.num_relations = 4
17 | 
18 |         self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False)
19 |         self.rel_embedding = torch.nn.Embedding(self.num_relations, args.embedding_dim, sparse=False)
20 | 
21 |         if self.args.kg_embed_model == 'conve':
22 |             self.inp_drop = torch.nn.Dropout(args.input_drop)
23 |             self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
24 |             self.feature_map_drop = torch.nn.Dropout2d(args.feat_drop)
25 | 
26 |             self.emb_dim1 = args.embedding_shape1 # important parameter for ConvE
27 |             self.emb_dim2 = args.embedding_dim // self.emb_dim1
28 | 
29 |             self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias)
30 |             self.bn0 = torch.nn.BatchNorm2d(1)
31 |             self.bn1 = torch.nn.BatchNorm2d(32)
32 |             self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)
33 |             self.fc = torch.nn.Linear(args.hidden_size, args.embedding_dim)
34 | 
35 |         self.location_embedding = MLP(args.location_input_dim, args.embedding_dim, args.mlp_location_numlayer)
36 | 
37 |         self.time_embedding = MLP(args.time_input_dim, args.embedding_dim, args.mlp_time_numlayer)
38 | 
39 |         if self.args.img_embed_model == 'resnet50':
40 |             self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
41 |             self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
42 |         else:
43 |             self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
44 |             self.image_embedding.fc = nn.Linear(512, args.embedding_dim)
45 | 
46 |         self.target_list = target_list
47 |         if all_locs is not None:
48 |             self.all_locs = all_locs.to(device)
49 |         if all_timestamps is not None:
50 |             self.all_timestamps = all_timestamps.to(device)
51 | 
52 |         self.args = args
53 |         self.device = device
54 | 
55 |         self.init()
56 | 
57 |     def init(self):
58 |         nn.init.xavier_uniform_(self.ent_embedding.weight.data)
59 |         nn.init.xavier_uniform_(self.rel_embedding.weight.data)
60 |         nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
61 | 
62 |     def forward(self, h, r, t):
63 | 
64 |         emb_h = self.batch_embedding_concat_h(h)
65 | 
66 |         emb_r = self.rel_embedding(r.squeeze(-1))
67 | 
68 |         if self.args.kg_embed_model == 'distmult':
69 |             emb_t = self.batch_embedding_concat_h(t)
70 |             score = torch.sum(emb_h * emb_r * emb_t, -1)
71 | 
72 |         elif self.args.kg_embed_model == 'conve':
73 |             e1_embedded = emb_h.view(-1, 1, self.emb_dim1, self.emb_dim2) # [batch, 1, emb_dim1, emb_dim2]
74 |             rel_embedded = emb_r.view(-1, 1, self.emb_dim1, self.emb_dim2) # [batch, 1, emb_dim1, emb_dim2]
75 | 
76 |             stacked_inputs = torch.cat([e1_embedded, rel_embedded], 2) # [batch, 1, 2*emb_dim1, emb_dim2]
77 | 
78 |             stacked_inputs = self.bn0(stacked_inputs)
79 |             x = self.inp_drop(stacked_inputs)
80 |             x = self.conv1(x)
81 |             x = self.bn1(x)
82 |             x = F.relu(x)
83 |             x = self.feature_map_drop(x)
84 |             x = x.view(x.shape[0], -1)
85 |             x = self.fc(x)
86 |             x = self.hidden_drop(x)
87 |             x = self.bn2(x)
88 |             x = F.relu(x)
89 |             score = torch.sum(x * self.batch_embedding_concat_h(t), -1) # score against the tail embedding, mirroring the DistMult branch
90 |         else:
91 |             raise NotImplementedError
92 | 
93 |         return score
94 | 
95 |     # @profile
96 |     def forward_ce(self, h, r, t, triple_type=None):
97 |         emb_h = 
self.batch_embedding_concat_h(h) # [batch, hid] 98 | 99 | emb_r = self.rel_embedding(r.squeeze(-1)) # [batch, hid] 100 | 101 | if self.args.kg_embed_model == 'distmult': 102 | emb_hr = emb_h * emb_r # [batch, hid] 103 | elif self.args.kg_embed_model == 'conve': 104 | emb_h = emb_h.view(-1, 1, self.emb_dim1, self.emb_dim2) # [batch, 1, emb_dim1, emb_dim2] 105 | emb_r = emb_r.view(-1, 1, self.emb_dim1, self.emb_dim2) # [batch, 1, emb_dim1, emb_dim2] 106 | 107 | stacked_inputs = torch.cat([emb_h, emb_r], 2) # [batch, 1, 2*emb_dim1, emb_dim2] 108 | 109 | stacked_inputs = self.bn0(stacked_inputs) 110 | x = self.inp_drop(stacked_inputs) 111 | x = self.conv1(x) 112 | x = self.bn1(x) 113 | x = F.relu(x) 114 | x = self.feature_map_drop(x) 115 | x = x.view(x.shape[0], -1) 116 | x = self.fc(x) 117 | x = self.hidden_drop(x) 118 | x = self.bn2(x) 119 | emb_hr = F.relu(x) 120 | else: 121 | raise NotImplementedError 122 | 123 | if triple_type == ('image', 'id'): 124 | score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T) # [batch, n_ent] 125 | elif triple_type == ('id', 'id'): 126 | score = torch.mm(emb_hr, self.ent_embedding.weight.T) # [batch, n_ent] 127 | elif triple_type == ('image', 'location'): 128 | loc_emb = self.location_embedding(self.all_locs) # computed for each batch 129 | score = torch.mm(emb_hr, loc_emb.T) 130 | elif triple_type == ('image', 'time'): 131 | time_emb = self.time_embedding(self.all_timestamps) 132 | score = torch.mm(emb_hr, time_emb.T) 133 | else: 134 | raise NotImplementedError 135 | 136 | return score 137 | 138 | def batch_embedding_concat_h(self, e1): 139 | e1_embedded = None 140 | 141 | if len(e1.size())==1 or e1.size(1) == 1: # uid 142 | # print('ent_embedding = {}'.format(self.ent_embedding.weight.size())) 143 | e1_embedded = self.ent_embedding(e1.squeeze(-1)) 144 | elif e1.size(1) == 15: # time 145 | e1_embedded = self.time_embedding(e1) 146 | elif e1.size(1) == 2: # GPS 147 | e1_embedded = self.location_embedding(e1) 148 | elif e1.size(1) == 3: # Image 149 | e1_embedded = self.image_embedding(e1) 150 | 151 | return e1_embedded 152 | 153 | 154 | class MLP(nn.Module): 155 | def __init__(self, 156 | input_dim, 157 | output_dim, 158 | num_layers=3, 159 | p_dropout=0.0, 160 | bias=True): 161 | 162 | super().__init__() 163 | 164 | self.input_dim = input_dim 165 | self.output_dim = output_dim 166 | 167 | self.p_dropout = p_dropout 168 | step_size = (input_dim - output_dim) // num_layers 169 | hidden_dims = [output_dim + (i * step_size) 170 | for i in reversed(range(num_layers))] 171 | 172 | mlp = list() 173 | layer_indim = input_dim 174 | for hidden_dim in hidden_dims: 175 | mlp.extend([nn.Linear(layer_indim, hidden_dim, bias), 176 | nn.Dropout(p=self.p_dropout, inplace=True), 177 | nn.PReLU()]) 178 | 179 | layer_indim = hidden_dim 180 | 181 | self.mlp = nn.Sequential(*mlp) 182 | 183 | # initialize weights 184 | self.init() 185 | 186 | def forward(self, x): 187 | return self.mlp(x) 188 | 189 | def init(self): 190 | for param in self.parameters(): 191 | nn.init.uniform_(param) 192 | -------------------------------------------------------------------------------- /gen_utils/dump_kge_pred_specie_wise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from 
collections import defaultdict 14 | import math 15 | 16 | sys.path.append('../') 17 | 18 | from model import MKGE 19 | 20 | from tqdm import tqdm 21 | from utils import collate_list, detach_and_clone, move_to 22 | import torch.optim as optim 23 | from torch.utils.data import Dataset, DataLoader 24 | from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 25 | from PIL import Image 26 | from dataset import iWildCamOTTDataset 27 | 28 | def evaluate(model, val_loader, target_list, node_parent_map, args): 29 | model.eval() 30 | torch.set_grad_enabled(False) 31 | 32 | epoch_y_true = [] 33 | epoch_y_pred = [] 34 | 35 | batch_idx = 0 36 | 37 | y_pred_dict = {} 38 | 39 | for label_id in range(182): 40 | y_pred_dict[label_id] = [] 41 | 42 | for labeled_batch in tqdm(val_loader): 43 | h, r, t = labeled_batch 44 | h = move_to(h, args.device) 45 | r = move_to(r, args.device) 46 | t = move_to(t, args.device) 47 | 48 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'id')) 49 | 50 | batch_results = { 51 | 'y_true': t.cpu(), 52 | 'y_pred': outputs.cpu(), 53 | } 54 | 55 | y_true = detach_and_clone(batch_results['y_true']) 56 | epoch_y_true.append(y_true) 57 | y_pred = detach_and_clone(batch_results['y_pred']) 58 | y_pred = y_pred.argmax(-1) 59 | 60 | epoch_y_pred.append(y_pred) 61 | 62 | for i in range(y_true.size(0)): 63 | x = (y_pred[i] == y_true[i]).long().item() 64 | y_pred_dict[y_true[i].item()].append(x) # 1 means prediction matches label, 0 otherwise. Used for calculating F1 score. 65 | 66 | batch_idx += 1 67 | if args.debug: 68 | break 69 | 70 | epoch_y_pred = collate_list(epoch_y_pred) 71 | epoch_y_true = collate_list(epoch_y_true) 72 | 73 | metrics = [ 74 | Accuracy(prediction_fn=None), 75 | Recall(prediction_fn=None, average='macro'), 76 | F1(prediction_fn=None, average='macro'), 77 | ] 78 | 79 | results = {} 80 | 81 | for i in range(len(metrics)): 82 | results.update({ 83 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 84 | }) 85 | 86 | print(f'Eval., split: {args.split}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100:.2f}') 87 | 88 | return y_pred_dict 89 | 90 | def _get_id(dict, key): 91 | id = dict.get(key, None) 92 | if id is None: 93 | id = len(dict) 94 | dict[key] = id 95 | return id 96 | 97 | def generate_target_list(data, entity2id): 98 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 99 | sub = list(sub['t']) 100 | categories = [] 101 | for item in tqdm(sub): 102 | if entity2id[str(int(float(item)))] not in categories: 103 | categories.append(entity2id[str(int(float(item)))]) 104 | # print('categories = {}'.format(categories)) 105 | print("No. 
of target categories = {}".format(len(categories))) 106 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 107 | 108 | if __name__=='__main__': 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 111 | parser.add_argument('--img-dir', type=str, default='../iwildcam_v2.0/imgs/') 112 | parser.add_argument('--split', type=str, default='val') 113 | parser.add_argument('--seed', type=int, default=813765) 114 | parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt') 115 | parser.add_argument('--out-dir', type=str) 116 | parser.add_argument('--debug', action='store_true') 117 | parser.add_argument('--no-cuda', action='store_true') 118 | parser.add_argument('--batch_size', type=int, default=16) 119 | 120 | parser.add_argument('--embedding-dim', type=int, default=512) 121 | parser.add_argument('--location_input_dim', type=int, default=2) 122 | parser.add_argument('--time_input_dim', type=int, default=1) 123 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 124 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 125 | 126 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 127 | parser.add_argument('--use-data-subset', action='store_true') 128 | parser.add_argument('--subset-size', type=int, default=10) 129 | 130 | parser.add_argument('--kg-embed-model', choices=['distmult', 'conve'], default='distmult') 131 | 132 | # ConvE hyperparams 133 | parser.add_argument('--embedding-shape1', type=int, default=20, help='The first dimension of the reshaped 2D embedding. The second dimension is infered. Default: 20') 134 | parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.') 135 | parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.') 136 | parser.add_argument('--feat-drop', type=float, default=0.2, help='Dropout for the convolutional features. Default: 0.2.') 137 | parser.add_argument('--use-bias', action='store_true', default=True, help='Use a bias in the convolutional layer. Default: True') 138 | parser.add_argument('--hidden-size', type=int, default=9728, help='The side of the hidden layer. The required size changes with the size of the embeddings. 
Default: 9728 (embedding size 200).') 139 | 140 | args = parser.parse_args() 141 | 142 | print('args = {}'.format(args)) 143 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 144 | 145 | # Set random seed 146 | torch.manual_seed(args.seed) 147 | np.random.seed(args.seed) 148 | random.seed(args.seed) 149 | 150 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 151 | 152 | # construct OTT parent map 153 | datacsv_id_id = datacsv.loc[(datacsv['datatype_h'] == 'id') & (datacsv['datatype_t'] == 'id')] 154 | node_parent_map = {} 155 | 156 | for idx in range(len(datacsv_id_id)): 157 | node = int(float(datacsv.iloc[idx, 0])) 158 | parent = int(float(datacsv.iloc[idx, -3])) 159 | 160 | node_parent_map[node] = parent 161 | 162 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree.json') 163 | 164 | if not os.path.exists(entity_id_file): 165 | entity2id = {} # each of triple types have their own entity2id 166 | 167 | for i in tqdm(range(datacsv.shape[0])): 168 | if datacsv.iloc[i,1] == "id": 169 | _get_id(entity2id, str(int(float(datacsv.iloc[i,0])))) 170 | 171 | if datacsv.iloc[i,-2] == "id": 172 | _get_id(entity2id, str(int(float(datacsv.iloc[i,-3])))) 173 | json.dump(entity2id, open(entity_id_file, 'w')) 174 | else: 175 | entity2id = json.load(open(entity_id_file, 'r')) 176 | 177 | num_ent_id = len(entity2id) 178 | 179 | print('len(entity2id) = {}'.format(len(entity2id))) 180 | 181 | # print('entity2id = {}'.format(entity2id)) 182 | id2entity = {v:k for k,v in entity2id.items()} 183 | 184 | target_list = generate_target_list(datacsv, entity2id) 185 | 186 | val_image_to_id_dataset = iWildCamOTTDataset(datacsv, args.split, args, entity2id, target_list, head_type="image", tail_type="id") 187 | print('len(val_image_to_id_dataset) = {}'.format(len(val_image_to_id_dataset))) 188 | 189 | val_loader = DataLoader( 190 | val_image_to_id_dataset, 191 | shuffle=False, # Do not shuffle eval datasets 192 | sampler=None, 193 | batch_size=args.batch_size, 194 | num_workers=4, 195 | pin_memory=True) 196 | 197 | model = MKGE(args, num_ent_id, target_list, args.device) 198 | 199 | model.to(args.device) 200 | 201 | # restore from ckpt 202 | if args.ckpt_path: 203 | ckpt = torch.load(args.ckpt_path) 204 | model.load_state_dict(ckpt['model'], strict=False) 205 | print('ckpt loaded...') 206 | 207 | y_pred_dict = evaluate(model, val_loader, target_list, node_parent_map, args) 208 | 209 | json.dump(y_pred_dict, open(os.path.join(args.out_dir, 'y_pred_dict_{}.json'.format(args.split)), 'w')) 210 | 211 | 212 | -------------------------------------------------------------------------------- /gen_utils/dump_imageonly_pred_specie_wise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict 14 | import math 15 | 16 | sys.path.append('../') 17 | 18 | from model import MKGE 19 | from resnet import Resnet18, Resnet50 20 | 21 | from tqdm import tqdm 22 | from utils import collate_list, detach_and_clone, move_to 23 | import torch.optim as optim 24 | from torch.utils.data import Dataset, DataLoader 25 | from wilds.common.metrics.all_metrics import Accuracy 26 | from PIL import Image 27 | from dataset 
import iWildCamOTTDataset 28 | import torchvision.transforms as transforms 29 | 30 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 31 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 32 | 33 | def print_ancestors(species_id, node_parent_map, target_list, taxon_id_to_name, overall_id_to_name): 34 | out = [] 35 | curr_node = species_id 36 | 37 | while True: 38 | if str(curr_node) in taxon_id_to_name: 39 | out.append(taxon_id_to_name[str(curr_node)]) 40 | else: 41 | out.append(overall_id_to_name[str(curr_node)]) 42 | break 43 | 44 | if curr_node not in node_parent_map: 45 | break 46 | curr_node = node_parent_map[curr_node] 47 | 48 | print(' --> '.join(out)) 49 | 50 | def evaluate(model, val_loader, args): 51 | model.eval() 52 | torch.set_grad_enabled(False) 53 | 54 | epoch_y_true = [] 55 | epoch_y_pred = [] 56 | 57 | batch_idx = 0 58 | 59 | y_pred_dict = {} 60 | 61 | for label_id in range(182): 62 | y_pred_dict[label_id] = [] 63 | 64 | for labeled_batch in tqdm(val_loader): 65 | x, y_true = labeled_batch 66 | x = move_to(x, args.device) 67 | y_true = move_to(y_true, args.device) 68 | 69 | outputs = model(x) 70 | 71 | batch_results = { 72 | # 'g': g, 73 | 'y_true': y_true.cpu(), 74 | 'y_pred': outputs.cpu(), 75 | # 'metadata': metadata, 76 | } 77 | 78 | y_true = detach_and_clone(batch_results['y_true']) 79 | epoch_y_true.append(y_true) 80 | y_pred = detach_and_clone(batch_results['y_pred']) 81 | y_pred = y_pred.argmax(-1) 82 | 83 | 84 | epoch_y_pred.append(y_pred) 85 | 86 | for i in range(y_true.size(0)): 87 | x = (y_pred[i] == y_true[i]).long().item() 88 | y_pred_dict[y_true[i].item()].append(x) # 1 means prediction matches label, 0 otherwise. Used for calculating F1 score. 89 | 90 | batch_idx += 1 91 | if args.debug: 92 | break 93 | 94 | epoch_y_pred = collate_list(epoch_y_pred) 95 | epoch_y_true = collate_list(epoch_y_true) 96 | 97 | metrics = [ 98 | Accuracy(prediction_fn=None), 99 | ] 100 | 101 | results = {} 102 | 103 | for i in range(len(metrics)): 104 | results.update({ 105 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 106 | }) 107 | 108 | print(f'Eval., split: {args.split}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}') 109 | 110 | return y_pred_dict 111 | 112 | def _get_id(dict, key): 113 | id = dict.get(key, None) 114 | if id is None: 115 | id = len(dict) 116 | dict[key] = id 117 | return id 118 | 119 | def generate_target_list(data, entity2id): 120 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 121 | sub = list(sub['t']) 122 | categories = [] 123 | for item in tqdm(sub): 124 | if entity2id[str(int(float(item)))] not in categories: 125 | categories.append(entity2id[str(int(float(item)))]) 126 | # print('categories = {}'.format(categories)) 127 | print("No. 
of target categories = {}".format(len(categories))) 128 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 129 | 130 | class iWildCamDataset(Dataset): 131 | def __init__(self, datacsv, root, img_dir, mode, entity2id, target_list): # dic_data <- datas 132 | super(iWildCamDataset, self).__init__() 133 | self.mode = mode 134 | self.datacsv = datacsv.loc[datacsv['split'] == mode, :] 135 | self.root = root 136 | self.img_dir = img_dir 137 | self.entity2id = entity2id 138 | self.target_list = target_list 139 | self.entity_to_species_id = {self.target_list[i, 0].item():i for i in range(len(self.target_list))} 140 | 141 | def __len__(self): 142 | return len(self.datacsv) 143 | 144 | def __getitem__(self, idx): 145 | y = torch.tensor([self.entity_to_species_id[self.entity2id[str(int(float(self.datacsv.iloc[idx, -3])))]]], dtype=torch.long).squeeze() 146 | 147 | img = Image.open(os.path.join(self.img_dir, self.datacsv.iloc[idx, 0])).convert('RGB') 148 | 149 | transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)]) 150 | x = transform_steps(img) 151 | 152 | return x, y 153 | 154 | if __name__=='__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 157 | parser.add_argument('--img-dir', type=str, default='../iwildcam_v2.0/imgs/') 158 | parser.add_argument('--split', type=str, default='val') 159 | parser.add_argument('--seed', type=int, default=813765) 160 | parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt') 161 | parser.add_argument('--out-dir', type=str) 162 | parser.add_argument('--debug', action='store_true') 163 | parser.add_argument('--no-cuda', action='store_true') 164 | parser.add_argument('--use-subtree', action='store_true', help='use truncated OTT') 165 | parser.add_argument('--batch_size', type=int, default=16) 166 | 167 | parser.add_argument('--embedding-dim', type=int, default=512) 168 | parser.add_argument('--location_input_dim', type=int, default=2) 169 | parser.add_argument('--time_input_dim', type=int, default=1) 170 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 171 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 172 | 173 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 174 | parser.add_argument('--use-data-subset', action='store_true') 175 | parser.add_argument('--subset-size', type=int, default=10) 176 | 177 | args = parser.parse_args() 178 | 179 | print('args = {}'.format(args)) 180 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 181 | 182 | # Set random seed 183 | torch.manual_seed(args.seed) 184 | np.random.seed(args.seed) 185 | random.seed(args.seed) 186 | 187 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 188 | 189 | # construct OTT parent map 190 | datacsv_id_id = datacsv.loc[(datacsv['datatype_h'] == 'id') & (datacsv['datatype_t'] == 'id')] 191 | node_parent_map = {} 192 | 193 | for idx in range(len(datacsv_id_id)): 194 | node = int(float(datacsv.iloc[idx, 0])) 195 | parent = int(float(datacsv.iloc[idx, -3])) 196 | 197 | node_parent_map[node] = parent 198 | 199 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv')) 200 | datacsv = datacsv.loc[(datacsv["datatype_h"] == "image") & 
(datacsv["datatype_t"] == "id")] 201 | 202 | entity2id = {} # each of triple types have their own entity2id 203 | 204 | for i in tqdm(range(datacsv.shape[0])): 205 | _get_id(entity2id, str(int(float(datacsv.iloc[i,-3])))) 206 | 207 | print('len(entity2id) = {}'.format(len(entity2id))) 208 | 209 | target_list = generate_target_list(datacsv, entity2id) 210 | 211 | val_dataset = iWildCamDataset(datacsv, os.path.join('iwildcam_v2.0', 'imgs/'), args.img_dir, args.split, entity2id, target_list) 212 | 213 | id2entity = {v:k for k,v in entity2id.items()} 214 | 215 | val_loader = DataLoader( 216 | val_dataset, 217 | shuffle=False, # Do not shuffle eval datasets 218 | sampler=None, 219 | batch_size=args.batch_size, 220 | num_workers=4, 221 | pin_memory=True) 222 | 223 | model = Resnet50(args) 224 | model.to(args.device) 225 | 226 | # restore from ckpt 227 | if args.ckpt_path: 228 | ckpt = torch.load(args.ckpt_path) 229 | model.load_state_dict(ckpt['model'], strict=False) 230 | print('ckpt loaded...') 231 | 232 | y_pred_dict = evaluate(model, val_loader, args) 233 | 234 | json.dump(y_pred_dict, open(os.path.join(args.out_dir, 'y_pred_dict_{}.json'.format(args.split)), 'w')) 235 | 236 | 237 | -------------------------------------------------------------------------------- /preprocess_data_mountain_zebra.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import json 4 | from tqdm import tqdm 5 | import argparse 6 | import random, string 7 | import os 8 | 9 | 10 | if __name__=='__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data-dir', type=str, default='data/snapshot_mountain_zebra/') 13 | parser.add_argument('--use-loc-canonical-id', action='store_true') 14 | parser.add_argument('--no-drop-nonexist-imgs', action='store_true') 15 | parser.add_argument('--split-dataset', action='store_true', help='randomly split into train/val/test splits') 16 | parser.add_argument('--no-datetime', action='store_true', help='ignore date/time') 17 | parser.add_argument('--no-location', action='store_true', help='ignore location') 18 | parser.add_argument('--species-common-names-file', type=str, default='data/snapshot_mountain_zebra/category_to_label_map.json') 19 | parser.add_argument('--img-prefix', type=str, default='') 20 | parser.add_argument('--dataset-prefix', type=str, default='') 21 | 22 | 23 | args = parser.parse_args() 24 | 25 | annot_file = os.path.join(args.data_dir, 'annotations.json') 26 | loc_file = os.path.join(args.data_dir, 'locations.csv') 27 | category_to_label_map = json.load(open(args.species_common_names_file, 'r')) 28 | 29 | annotations_json = json.load(open(annot_file)) 30 | 31 | taxon_id_to_name_filename = 'snapshot_mountain_zebra/taxon_id_to_name_lila.json' 32 | 33 | taxon_id_to_name = json.load(open(taxon_id_to_name_filename, 'r')) 34 | taxon_name_to_id = {v:k for k,v in taxon_id_to_name.items()} 35 | 36 | print('len(taxon_name_to_id) = {}'.format(len(taxon_name_to_id))) 37 | print('len(taxon_id_to_name) = {}'.format(len(taxon_id_to_name))) 38 | 39 | if os.path.exists(loc_file) and not args.use_loc_canonical_id: 40 | location_coordinates = pd.read_csv(loc_file) 41 | else: 42 | location_coordinates = None 43 | 44 | img_json = annotations_json['images'] 45 | img_json = [x for x in img_json if args.no_drop_nonexist_imgs or os.path.exists(os.path.join(args.data_dir, args.img_prefix, x['file_name']))] 46 | # print(img_json[0].keys()) 47 | 48 | # add y labels to metadata 49 | 
annotations = annotations_json['annotations'] 50 | 51 | annotations_image_id = [x['image_id'] for x in annotations] 52 | annotations_category_id = [x['category_id'] for x in annotations] 53 | 54 | annotations_df = pd.DataFrame(list(zip(annotations_image_id, annotations_category_id)), columns=['image_id', 'category_id']) 55 | 56 | metadata = annotations_df 57 | 58 | if 'caltech' in args.data_dir: 59 | datetime_field = 'date_captured' 60 | else: 61 | datetime_field = 'datetime' 62 | 63 | # add image filename 64 | img_ids = [x['id'] for x in tqdm(img_json)] 65 | img_filenames = [(args.img_prefix + x['file_name']) for x in img_json] 66 | 67 | img_loc = [x['location'] for x in img_json] 68 | 69 | if not args.no_datetime: 70 | img_datetime = [x[datetime_field] for x in img_json] 71 | img_df = pd.DataFrame(list(zip(img_ids, img_filenames, img_loc, img_datetime)), columns=['image_id', 'filename', 'location', 'datetime']) 72 | else: 73 | img_df = pd.DataFrame(list(zip(img_ids, img_filenames, img_loc)), columns=['image_id', 'filename', 'location']) 74 | 75 | # construct a df with location paired to split 76 | 77 | # TODO; check if list of locations is in order 78 | locs = list(img_df.location) 79 | splits = [] 80 | 81 | split_json_file = open(os.path.join(args.data_dir, 'splits.json')) 82 | split_json = json.load(split_json_file) 83 | 84 | train_locs = set(split_json['splits']['train']) 85 | val_locs = set(split_json['splits']['val']) 86 | 87 | if 'test' in split_json['splits']: 88 | test_locs = set(split_json['splits']['test']) 89 | else: 90 | test_locs = set() 91 | 92 | for loc in locs: 93 | if loc in train_locs: 94 | splits.append('train') 95 | elif loc in val_locs: 96 | splits.append('val') 97 | elif loc in test_locs: 98 | splits.append('test') 99 | 100 | print('len(img_df) = {}'.format(len(img_df))) 101 | print('len(splits) = {}'.format(len(splits))) 102 | 103 | img_df = img_df.assign(split=splits) 104 | # print(img_df.head()) 105 | print(img_df[img_df['split']=='val']) 106 | print(img_df[img_df['split']=='test']) 107 | 108 | img_df = img_df.drop_duplicates(subset=['image_id']) 109 | 110 | if location_coordinates is not None: 111 | location_coordinates.columns = ['location', 'elevation', 'geometry'] 112 | img_df = pd.merge(img_df, location_coordinates, how='left', left_on=['location'], right_on=['location']) # [location', 'date', 'image_id', 'category_id', 'filename'] 113 | 114 | # replace location by actual (lat, lon) coordinates 115 | 116 | locs = [img_df.iloc[i, -1] for i in range(len(img_df))] 117 | locs = [np.array(x.replace('c(','').replace(')','').split(', ')).astype(float) for x in locs] 118 | # print(locs) 119 | 120 | img_df.location = locs 121 | 122 | # print(img_df.head()) 123 | # print(img_df.columns) 124 | elif not args.no_location: 125 | locs = list(img_df.location) 126 | locs = ['{}_{}'.format(loc, args.dataset_prefix) for loc in locs] 127 | img_df.location = locs 128 | 129 | 130 | metadata = metadata.drop_duplicates(subset=['image_id'], keep=False) 131 | 132 | # print duplicates 133 | # ids = metadata['image_id'] 134 | # print(metadata[ids.isin(ids[ids.duplicated()])].sort_values('image_id')) 135 | 136 | print('len(img_df) = {}'.format(len(img_df))) 137 | print('len(metadata) before = {}'.format(len(metadata))) 138 | 139 | metadata = pd.merge(metadata, img_df, how='inner', left_on=['image_id'], right_on=['image_id']) # [location', 'date', 'image_id', 'category_id', 'filename'] 140 | print(metadata.columns) 141 | print(metadata.head()) 142 | 143 | print('len(metadata) after = 
{}'.format(len(metadata))) 144 | 145 | # add category names 146 | category = annotations_json['categories'] 147 | # print('species_labels = {}'.format(species_labels)) 148 | print('len(category) before = {}'.format(len(category))) 149 | # print(category) 150 | 151 | if 'ena24' in args.data_dir: 152 | for item in category: 153 | item['name'] = item['name'].lower() 154 | 155 | category = [x for x in category if x['name'] in category_to_label_map] 156 | 157 | print('len(category) after = {}'.format(len(category))) 158 | # print('category after = {}'.format([x['name'] for x in category])) 159 | 160 | category_ids = [x['id'] for x in category] 161 | category_names = [taxon_name_to_id[category_to_label_map[x['name']]] for x in category] 162 | 163 | # for x in category_names: 164 | # print(x in all_taxons) 165 | # assert x in all_taxons 166 | 167 | category_df = pd.DataFrame(list(zip(category_ids, category_names)), columns=['category_id', 'name']) 168 | metadata = pd.merge(metadata, category_df, how='inner', left_on=['category_id'], right_on=['category_id']) # [location', 'date', 'image_id', 'category_id', 'filename', 'name'] 169 | 170 | print('len(metadata) = {}'.format(len(metadata))) 171 | 172 | print(metadata.columns) 173 | 174 | if args.split_dataset: 175 | splits = ['train'] * len(metadata) 176 | print('len(splits) = {}'.format(len(splits))) 177 | 178 | n_val_samples = int(0.15 * len(metadata)) 179 | n_test_samples = int(0.15 * len(metadata)) 180 | splits[:n_val_samples] = ['val']*n_val_samples 181 | splits[n_val_samples : n_val_samples+n_test_samples] = ['test']*n_test_samples 182 | random.shuffle(splits) 183 | 184 | print('len(splits) = {}'.format(len(splits))) 185 | print(splits.count('train')) 186 | print(splits.count('val')) 187 | print(splits.count('test')) 188 | 189 | metadata = metadata.assign(split=splits) 190 | 191 | # create category_id_to_name 192 | category_id_to_name = {x['id']:x['name'] for x in category} 193 | 194 | taxon = pd.read_csv("snapshot_mountain_zebra/taxon.csv") 195 | print('len(taxon) = {}'.format(len(taxon))) 196 | 197 | if not args.no_location: 198 | takeLocation = metadata.loc[:, ['filename', 'location', 'split']] 199 | takeLocation.insert(loc=1, column='r', value=2) 200 | takeLocation.insert(loc=1, column='datatype_h', value='image') 201 | takeLocation.insert(loc=4, column='datatype_t', value='location') 202 | takeLocation.insert(loc=6, column='dataset', value=args.dataset_prefix) 203 | takeLocation.columns = ['h', 'datatype_h', 'r', 't', 'datatype_t', 'split', 'dataset'] 204 | print(takeLocation.head()) 205 | 206 | if not args.no_datetime: 207 | takeTime = metadata.loc[:, ['filename', 'datetime', 'split']] 208 | takeTime.insert(loc=1, column='r', value=0) 209 | takeTime.insert(loc=1, column='datatype_h', value='image') 210 | takeTime.insert(loc=4, column='datatype_t', value='time') 211 | takeTime.insert(loc=6, column='dataset', value=args.dataset_prefix) 212 | takeTime.columns = ['h', 'datatype_h', 'r', 't', 'datatype_t', 'split', 'dataset'] 213 | print(takeTime.head()) 214 | 215 | imageIsIn = metadata.loc[:, ['filename', 'name', 'split']] 216 | imageIsIn.insert(loc=1, column='r', value=3) 217 | imageIsIn.insert(loc=1, column='datatype_h', value='image') 218 | imageIsIn.insert(loc=4, column='datatype_t', value='id') 219 | imageIsIn.insert(loc=6, column='dataset', value=args.dataset_prefix) 220 | imageIsIn.columns = ['h', 'datatype_h', 'r', 't', 'datatype_t', 'split', 'dataset'] 221 | print(imageIsIn.head()) 222 | 223 | dataset = pd.concat([taxon, imageIsIn, 
takeTime, takeLocation], ignore_index=True) 224 | 225 | out_file = os.path.join(args.data_dir, 'data_triples.csv') 226 | dataset.to_csv(out_file, index=False) 227 | 228 | -------------------------------------------------------------------------------- /run_image_only_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torchvision 10 | import sys 11 | from collections import defaultdict 12 | import math 13 | import torchvision.transforms as transforms 14 | 15 | from resnet import Resnet50 16 | 17 | from tqdm import tqdm 18 | from utils import collate_list, detach_and_clone, move_to 19 | import torch.optim as optim 20 | from torch.utils.data import Dataset, DataLoader 21 | from wilds.common.metrics.all_metrics import Accuracy 22 | from PIL import Image 23 | from pytorchtools import EarlyStopping 24 | 25 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 26 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 27 | 28 | def run_epoch(model, train_loader, val_loader, optimizer, epoch, args, early_stopping, train): 29 | 30 | if train: 31 | model.train() 32 | torch.set_grad_enabled(True) 33 | else: 34 | model.eval() 35 | torch.set_grad_enabled(False) 36 | 37 | epoch_y_true = [] 38 | epoch_y_pred = [] 39 | 40 | batches = train_loader if train else val_loader 41 | 42 | batches = tqdm(batches) 43 | last_batch_idx = len(batches)-1 44 | 45 | criterion = nn.CrossEntropyLoss() 46 | 47 | batch_idx = 0 48 | for labeled_batch in batches: 49 | if train: 50 | x, y_true = labeled_batch 51 | x = move_to(x, args.device) 52 | y_true = move_to(y_true, args.device) 53 | 54 | outputs = model(x) 55 | 56 | batch_results = { 57 | # 'g': g, 58 | 'y_true': y_true.cpu(), 59 | 'y_pred': outputs.cpu(), 60 | # 'metadata': metadata, 61 | } 62 | 63 | # compute objective 64 | loss = criterion(batch_results['y_pred'], batch_results['y_true']) 65 | batch_results['objective'] = loss.item() 66 | loss.backward() 67 | 68 | # update model and logs based on effective batch 69 | optimizer.step() 70 | model.zero_grad() 71 | 72 | else: 73 | x, y_true = labeled_batch 74 | x = move_to(x, args.device) 75 | y_true = move_to(y_true, args.device) 76 | 77 | outputs = model(x) 78 | 79 | batch_results = { 80 | # 'g': g, 81 | 'y_true': y_true.cpu(), 82 | 'y_pred': outputs.cpu(), 83 | # 'metadata': metadata, 84 | } 85 | 86 | batch_results['objective'] = criterion(batch_results['y_pred'], batch_results['y_true']).item() 87 | 88 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 89 | y_pred = detach_and_clone(batch_results['y_pred']) 90 | y_pred = y_pred.argmax(-1) 91 | 92 | epoch_y_pred.append(y_pred) 93 | 94 | effective_batch_idx = batch_idx + 1 95 | 96 | batch_idx += 1 97 | if args.debug and batch_idx > 100: 98 | break 99 | 100 | epoch_y_pred = collate_list(epoch_y_pred) 101 | epoch_y_true = collate_list(epoch_y_true) 102 | # epoch_metadata = collate_list(epoch_metadata) 103 | 104 | metrics = [ 105 | Accuracy(prediction_fn=None), 106 | ] 107 | 108 | results = {} 109 | 110 | for i in range(len(metrics)): 111 | results.update({ 112 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 113 | }) 114 | 115 | results_str = ( 116 | f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n" 117 | ) 118 | 119 | if not train: # just for eval. 
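# The validation accuracy is negated before being handed to EarlyStopping on the next line, since EarlyStopping (as in the usual pytorchtools implementation) treats its first argument as a loss-like score where lower is better; main() then checks early_stopping.early_stop to decide when to halt training. The exact checkpointing behaviour depends on pytorchtools.py, which is not shown here.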
120 | early_stopping(-1*results[metrics[0].agg_metric_field], model, optimizer) 121 | 122 | results['epoch'] = epoch 123 | # if dataset['verbose']: 124 | print('Epoch eval:\n') 125 | print(results_str) 126 | 127 | return results, epoch_y_pred 128 | 129 | class iWildCamDataset(Dataset): 130 | def __init__(self, datacsv, img_dir, mode, entity2id, target_list): # dic_data <- datas 131 | super(iWildCamDataset, self).__init__() 132 | self.mode = mode 133 | self.datacsv = datacsv.loc[datacsv['split'] == mode, :] 134 | self.img_dir = img_dir 135 | self.entity2id = entity2id 136 | self.target_list = target_list 137 | self.entity_to_species_id = {self.target_list[i, 0].item():i for i in range(len(self.target_list))} 138 | 139 | def __len__(self): 140 | return len(self.datacsv) 141 | 142 | def __getitem__(self, idx): 143 | y = torch.tensor([self.entity_to_species_id[self.entity2id[str(int(float(self.datacsv.iloc[idx, 3])))]]], dtype=torch.long).squeeze() 144 | 145 | img = Image.open(os.path.join(self.img_dir, self.datacsv.iloc[idx, 0])).convert('RGB') 146 | 147 | transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)]) 148 | x = transform_steps(img) 149 | 150 | return x, y 151 | 152 | def _get_id(dict, key): 153 | id = dict.get(key, None) 154 | if id is None: 155 | id = len(dict) 156 | dict[key] = id 157 | return id 158 | 159 | def generate_target_list(data, entity2id): 160 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 161 | sub = list(sub['t']) 162 | categories = [] 163 | for item in tqdm(sub): 164 | if entity2id[str(int(float(item)))] not in categories: 165 | categories.append(entity2id[str(int(float(item)))]) 166 | print("No. 
of target categories = {}".format(len(categories))) 167 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 168 | 169 | def main(): 170 | 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument('--dataset', choices=['iwildcam', 'mountain_zebra'], default='iwildcam') 173 | parser.add_argument('--data-dir', type=str, default='iwildcam_v2.0/') 174 | parser.add_argument('--img-dir', type=str, default='iwildcam_v2.0/imgs/') 175 | parser.add_argument('--batch_size', type=int, default=16) 176 | parser.add_argument('--n_epochs', type=int, default=12) 177 | parser.add_argument('--lr', type=float, default=3e-5) 178 | parser.add_argument('--weight_decay', type=float, default=0.0) 179 | parser.add_argument('--device', type=int, nargs='+', default=[0]) 180 | parser.add_argument('--seed', type=int, default=813765) 181 | parser.add_argument('--save-dir', type=str, default='ckpts/toy/') 182 | parser.add_argument('--debug', action='store_true') 183 | parser.add_argument('--early-stopping-patience', type=int, default=5, help='early stop if metric does not improve for x epochs') 184 | 185 | args = parser.parse_args() 186 | 187 | print('args = {}'.format(args)) 188 | 189 | # Set device 190 | if torch.cuda.is_available(): 191 | device_count = torch.cuda.device_count() 192 | if len(args.device) > device_count: 193 | raise ValueError(f"Specified {len(args.device)} devices, but only {device_count} devices found.") 194 | 195 | device_str = ",".join(map(str, args.device)) 196 | os.environ["CUDA_VISIBLE_DEVICES"] = device_str 197 | args.device = torch.device("cuda") 198 | else: 199 | args.device = torch.device("cpu") 200 | 201 | # Set random seed 202 | # set_seed(args.seed) 203 | torch.manual_seed(args.seed) 204 | np.random.seed(args.seed) 205 | random.seed(args.seed) 206 | 207 | if args.dataset == 'iwildcam': 208 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 209 | else: 210 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'data_triples.csv'), low_memory=False) 211 | 212 | datacsv = datacsv.loc[(datacsv["datatype_h"] == "image") & (datacsv["datatype_t"] == "id")] 213 | 214 | entity2id = {} # each of triple types have their own entity2id 215 | 216 | for i in tqdm(range(datacsv.shape[0])): 217 | _get_id(entity2id, str(int(float(datacsv.iloc[i,3])))) 218 | 219 | print('len(entity2id) = {}'.format(len(entity2id))) 220 | 221 | target_list = generate_target_list(datacsv, entity2id) 222 | 223 | train_dataset = iWildCamDataset(datacsv, args.img_dir, 'train', entity2id, target_list) 224 | val_dataset = iWildCamDataset(datacsv, args.img_dir, 'val', entity2id, target_list) 225 | 226 | 227 | train_loader = DataLoader( 228 | train_dataset, 229 | shuffle=True, # Shuffle training dataset 230 | sampler=None, 231 | batch_size=args.batch_size, 232 | num_workers=4, 233 | pin_memory=True) 234 | 235 | val_loader = DataLoader( 236 | val_dataset, 237 | shuffle=False, # Do not shuffle eval datasets 238 | sampler=None, 239 | batch_size=args.batch_size, 240 | num_workers=4, 241 | pin_memory=True) 242 | 243 | model = Resnet50(args) 244 | model.to(args.device) 245 | 246 | params = filter(lambda p: p.requires_grad, model.parameters()) 247 | optimizer = optim.Adam( 248 | params, 249 | lr=args.lr, 250 | weight_decay=args.weight_decay) 251 | 252 | best_val_metric = None 253 | early_stopping = EarlyStopping(patience=args.early_stopping_patience, verbose=True, ckpt_path=os.path.join(args.save_dir, 'model.pt'), best_ckpt_path=os.path.join(args.save_dir, 'best_model.pt')) 254 
| 255 | for epoch in range(args.n_epochs): 256 | print('\nEpoch [%d]:\n' % epoch) 257 | 258 | # First run training 259 | run_epoch(model, train_loader, val_loader, optimizer, epoch, args, early_stopping, train=True) 260 | 261 | # Then run val 262 | val_results, y_pred = run_epoch(model, train_loader, val_loader, optimizer, epoch, args, early_stopping, train=False) 263 | 264 | if early_stopping.early_stop: 265 | print("Early stopping...") 266 | break 267 | 268 | if __name__=='__main__': 269 | main() 270 | -------------------------------------------------------------------------------- /dataset_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import numpy as np 5 | import re 6 | from math import pi 7 | from re import match 8 | from PIL import Image 9 | from torchvision import transforms 10 | from torch_geometric.data import Data 11 | from torch.utils.data import Dataset, DataLoader 12 | import json 13 | import pickle 14 | import pandas as pd 15 | 16 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 17 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 18 | 19 | class iWildCamOTTDataset(Dataset): 20 | def __init__(self, datacsv, mode, args, entity2id, target_list, disjoint=True, output_subgraph=False, is_train=False): # dic_data <- datas 21 | super(iWildCamOTTDataset, self).__init__() 22 | 23 | self.datacsv_id = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'id') & (datacsv['split'] == mode), :] 24 | self.datacsv_loc = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'location') & (datacsv['split'] == mode), :] 25 | self.datacsv_time = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'time') & (datacsv['split'] == mode), :] 26 | 27 | # create dataframe with both location and time 28 | self.datacsv_loc_time_left = pd.merge(self.datacsv_loc, self.datacsv_time, how='left', left_on=['h','datatype_h','split'], right_on=['h','datatype_h','split']) 29 | loc = torch.stack([getNumber(x) for x in self.datacsv_loc.loc[:, 't'].values.tolist()], dim=0) 30 | 31 | # print(loc) 32 | # print('loc = {}'.format(loc.size())) 33 | self.loc_avg = loc.mean(dim=0) 34 | 35 | if args.dataset == 'iwildcam': 36 | time = torch.stack([torch.tensor(datatime_divide(x, args)) for x in self.datacsv_time.loc[:, 't'].values.tolist()]) 37 | else: 38 | time = torch.stack([torch.tensor(date_divide(x, args)) for x in self.datacsv_time.loc[:, 't'].values.tolist()]) 39 | 40 | # print(time) 41 | self.time_avg = time.mean(dim=0) 42 | 43 | # print('time = {}'.format(time.size())) 44 | 45 | # print('self.loc_avg = {}'.format(self.loc_avg)) 46 | # print('self.time_avg = {}'.format(self.time_avg)) 47 | 48 | self.datacsv_loc_time = pd.merge(self.datacsv_loc, self.datacsv_time, how='outer', left_on=['h','datatype_h','split'], right_on=['h','datatype_h','split']) 49 | 50 | # r remains id 2 (corr. 
to location) 51 | # h,datatype_h,r,t,datatype_t,split 52 | self.datacsv_loc_time = self.datacsv_loc_time.loc[:, ['h','t_x','t_y','split']] 53 | self.datacsv_loc_time.columns = ['h','location', 'time','split'] 54 | 55 | datacsv_ilt = pd.merge(self.datacsv_loc_time, self.datacsv_id, how='outer', left_on=['h','split'], right_on=['h','split']) 56 | datacsv_ilt = datacsv_ilt.loc[:, ['h','location', 'time', 't', 'split']] 57 | datacsv_ilt.columns = ['image','location', 'time', 'species_id', 'split'] 58 | 59 | # print(len(self.datacsv)) 60 | # print(self.datacsv.head()) 61 | 62 | self.datacsv = datacsv_ilt 63 | 64 | # print("The length of {}2{} dataset is {}".format(head_type, tail_type, len(self.datacsv))) 65 | 66 | self.args = args 67 | self.mode = mode 68 | self.entity2id = entity2id 69 | self.target_list = target_list 70 | self.entity_to_species_id = {self.target_list[i, 0].item():i for i in range(len(self.target_list))} 71 | 72 | # print(self.entity_to_species_id) 73 | 74 | if args.use_data_subset: 75 | train_indices = np.random.choice(np.arange(len(self.datacsv)), size=args.subset_size, replace=False) 76 | self.datacsv = self.datacsv.iloc[train_indices] 77 | 78 | 79 | # print('shape(self.datacsv) = {}'.format(self.datacsv.shape)) 80 | # print(self.datacsv.head()) 81 | datacsv_loc = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'location')] 82 | 83 | 84 | self.location_to_id = {} 85 | # print(datacsv_loc) 86 | 87 | if args.dataset == 'iwildcam': 88 | for i in range(len(datacsv_loc)): 89 | loc = datacsv_loc.iloc[i, 3] 90 | 91 | assert loc[0] == '[' 92 | assert loc[-1] == ']' 93 | # print(loc) 94 | if self.args.use_cluster_centroids_for_location: 95 | loc = self.loc_centroid_map[loc] 96 | 97 | if loc not in self.location_to_id: 98 | self.location_to_id[loc] = len(self.location_to_id) 99 | 100 | if self.args.use_location_breakdown: 101 | self.all_locs = torch.stack(list(map(lambda x:GPSToHMS(x), self.location_to_id.keys()))) 102 | else: 103 | self.all_locs = torch.stack(list(map(lambda x:getNumber(x), self.location_to_id.keys()))) 104 | 105 | # print(self.location_to_id) 106 | self.all_timestamps = None 107 | 108 | datacsv_time = datacsv.loc[(datacsv['datatype_h'] == 'image') & (datacsv['datatype_t'] == 'time')] 109 | self.time_to_id = {} 110 | 111 | for i in range(len(datacsv_time)): 112 | time = datacsv_time.iloc[i, 3] 113 | 114 | if self.args.dataset == 'iwildcam': 115 | month, hour = get_separate_time(time) 116 | else: 117 | # month = get_separate_date(time) 118 | # print(time) 119 | month, hour = get_separate_time(time) 120 | 121 | _HOUR_RAD = 2 * pi / 24 122 | _MONTH_RAD = 2 * pi / 12 123 | 124 | m1, m2 = point(month, _MONTH_RAD) 125 | 126 | if self.args.dataset == 'iwildcam': 127 | h1, h2 = point(hour, _HOUR_RAD) 128 | 129 | if self.args.only_hour: 130 | if self.args.use_circular_space: 131 | time = (h1, h2) 132 | else: 133 | time = (hour,) 134 | elif self.args.only_month or self.args.dataset in ['inat18', 'inat21_mammals']: 135 | if self.args.use_circular_space: 136 | time = (m1, m2) 137 | else: 138 | time = (month,) 139 | else: 140 | if self.args.use_circular_space: 141 | time = (m1, m2, h1, h2) 142 | else: 143 | time = (month, hour) 144 | 145 | if time not in self.time_to_id: 146 | self.time_to_id[time] = len(self.time_to_id) 147 | 148 | # print(self.time_to_id) 149 | self.all_timestamps = torch.stack(list(map(lambda x:torch.tensor(x), self.time_to_id.keys()))) 150 | if len(self.all_timestamps.size())==1: 151 | self.all_timestamps = 
self.all_timestamps.unsqueeze(-1) 152 | 153 | # print('all_timestamps = {}'.format(self.all_timestamps.size())) 154 | 155 | if self.args.img_embed_model in ['resnet18', 'resnet50']: 156 | self.transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)]) 157 | else: 158 | raise NotImplementedError 159 | 160 | def __len__(self): 161 | return len(self.datacsv) 162 | 163 | # @profile 164 | def __getitem__(self, idx): 165 | # 'image','location', 'time', 'species_id', 'split' 166 | 167 | image_filename = self.datacsv.iloc[idx, 0] 168 | img = Image.open(os.path.join(self.args.img_dir, image_filename)).convert('RGB') 169 | # transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)]) 170 | img = self.transform_steps(img) 171 | 172 | edge_index, edge_type = [], [] 173 | 174 | location = self.datacsv.iloc[idx, 1] 175 | 176 | location_inp = None 177 | if isinstance(location, float) and np.isnan(location): 178 | location_inp = self.loc_avg 179 | 180 | time = self.datacsv.iloc[idx, 2] 181 | 182 | time_inp = None 183 | if isinstance(time, float) and np.isnan(time): 184 | time_inp = self.time_avg 185 | 186 | species_id = self.datacsv.iloc[idx, 3] 187 | species_id = torch.tensor([self.entity_to_species_id[self.entity2id[str(int(float(species_id)))]]], dtype=torch.long).squeeze(-1) 188 | 189 | if location_inp is None: 190 | location_inp = getNumber(location) 191 | 192 | if time_inp is None: 193 | if self.args.dataset == 'iwildcam': 194 | time_inp = torch.tensor(datatime_divide(time, self.args)) 195 | else: 196 | time_inp = torch.tensor(date_divide(time, self.args)) 197 | 198 | return img, location_inp, time_inp, species_id 199 | 200 | 201 | def getNumber(x): 202 | # return torch.tensor(np.array(re.findall(r"\d+\.?\d*", x), dtype=float), dtype=torch.float) 203 | return torch.tensor(np.fromstring(x[1:-1], dtype=float, sep=' '), dtype=torch.float) 204 | 205 | def get_separate_time(item): 206 | m = match(r"(.*)-(.*)-(.*) (.*):(.*):(\d{2})", item) 207 | years, month, day, hour, minutes, second = m.groups() 208 | return float(month), float(hour) 209 | 210 | def get_separate_date(item): 211 | m = match(r"(.*)-(.*)-(.*)", item) 212 | years, month, day = m.groups() 213 | return float(month) 214 | 215 | def datatime_divide(timestamp, args): # season{0:spring, 1: summer 2:autumn, 3:winter} hor{0:day, 1:night} 216 | month, hour = get_separate_time(timestamp) 217 | 218 | _HOUR_RAD = 2 * pi / 24 219 | _MONTH_RAD = 2 * pi / 12 220 | 221 | m1, m2 = point(month, _MONTH_RAD) 222 | h1, h2 = point(hour, _HOUR_RAD) 223 | 224 | if args.only_hour: 225 | if args.use_circular_space: 226 | return (h1, h2) 227 | else: 228 | return (hour,) 229 | elif args.only_month: 230 | if args.use_circular_space: 231 | return (m1, m2) 232 | else: 233 | return (month,) 234 | 235 | # if hour < 5 or hour > 18: 236 | # day_night = 0 237 | # else: 238 | # day_night = 1 239 | # print('timestamp = {}, day_night = {}'.format(timestamp, day_night)) 240 | if args.use_circular_space: 241 | return (m1, m2, h1, h2) 242 | else: 243 | return (month, hour) 244 | 245 | def date_divide(timestamp, args): 246 | month = get_separate_date(timestamp) 247 | 248 | _MONTH_RAD = 2 * pi / 12 249 | 250 | m1, m2 = point(month, _MONTH_RAD) 251 | 252 | if args.use_circular_space: 253 | return (m1, m2) 254 | 
else: 255 | return (month,) 256 | 257 | def point(m, rad): 258 | from math import sin, cos 259 | # place on circle 260 | return sin(m * rad), cos(m * rad) 261 | 262 | 263 | def separate(year): 264 | regex = "^(?P<century>\d{0,2}?)(?P<decade>\d?)(?P<year>\d)$" 265 | return match(regex, year) 266 | 267 | 268 | def getSeparated(item): 269 | _MINUTE_RAD = 2 * pi / 60 270 | _HOUR_RAD = 2 * pi / 24 271 | _DAY_RAD = 2 * pi / 31 272 | _MONTH_RAD = 2 * pi / 12 273 | _YEAR_DECADE_RAD = 2 * pi / 10 274 | m = match(r"(.*)-(.*)-(.*) (.*):(.*):(\d{2})", item) 275 | years, month, day, hour, minutes, second = m.groups() 276 | separated = separate(years) 277 | c = int(separated.group('century')) 278 | decade = int(separated.group('decade')) 279 | year = int(separated.group('year')) 280 | dec1, dec2 = point(decade, _YEAR_DECADE_RAD) 281 | y1, y2 = point(year, _YEAR_DECADE_RAD) 282 | m1, m2 = point(int(month), _MONTH_RAD) 283 | d1, d2 = point(int(day), _DAY_RAD) 284 | h1, h2 = point(int(hour), _HOUR_RAD) 285 | min1, min2 = point(int(minutes), _MINUTE_RAD) 286 | sec1, sec2 = point(int(second), _MINUTE_RAD) 287 | return torch.tensor(np.array([c, dec1, dec2, y1, y2, m1, m2, d1, d2, h1, h2, min1, min2, sec1, sec2]), 288 | dtype=torch.float) 289 | 290 | def D2Dms(d_data): 291 | d_data = float(d_data) 292 | d = int(d_data) 293 | m = int((d_data-d)*60) 294 | s = ((d_data-d)*60-m)*60 295 | return d,m,s 296 | 297 | 298 | def GPSToHMS(x, parse_regex=True): 299 | # print('x = {}'.format(x)) 300 | 301 | if parse_regex: 302 | a = re.findall(r"\d+\.?\d*", x) 303 | else: 304 | a = x 305 | 306 | # print('a = {}'.format(a)) 307 | lon = a[0] 308 | lat = a[1] 309 | # print('lat = {}'.format(lat)) 310 | # print('lon = {}'.format(lon)) 311 | 312 | dl, ml, sl = D2Dms(lon) 313 | da, ma, sa = D2Dms(lat) 314 | 315 | # print(f'dl = {dl}, ml = {ml}, sl = {sl}') 316 | # print(f'da = {da}, ma = {ma}, sa = {sa}') 317 | 318 | _MINUTE_RAD = 2 * pi / 60 319 | _HOUR_RAD = 2 * pi / 24 320 | 321 | dl_1, dl_2 = point(int(dl), _HOUR_RAD) 322 | ml_1, ml_2 = point(int(ml), _MINUTE_RAD) 323 | sl_1, sl_2 = point(int(sl), _MINUTE_RAD) 324 | 325 | da_1, da_2 = point(int(da), _HOUR_RAD) 326 | ma_1, ma_2 = point(int(ma), _MINUTE_RAD) 327 | sa_1, sa_2 = point(int(sa), _MINUTE_RAD) 328 | 329 | # print(f'dl_1 = {dl_1}, dl_2 = {dl_2}, ml_1 = {ml_1}, ml_2 = {ml_2}, sl_1 = {sl_1}, sl_2 = {sl_2}') 330 | # print(f'da_1 = {da_1}, da_2 = {da_2}, ma_1 = {ma_1}, ma_2 = {ma_2}, sa_1 = {sa_1}, sa_2 = {sa_2}') 331 | 332 | return torch.tensor(np.array([dl_1, dl_2, ml_1, ml_2, sl_1, sl_2, da_1, da_2, ma_1, ma_2, sa_1, sa_2], dtype=float), dtype=torch.float) 333 | 334 | -------------------------------------------------------------------------------- /gen_utils/analyze_taxonomy_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict, Counter 14 | import math 15 | 16 | sys.path.append('../') 17 | 18 | from model import MKGE 19 | from resnet import Resnet18, Resnet50 20 | 21 | from tqdm import tqdm 22 | from utils import collate_list, detach_and_clone, move_to 23 | import torch.optim as optim 24 | from torch.utils.data import Dataset, DataLoader 25 | from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 26 | from PIL import Image 27 | from dataset import 
iWildCamOTTDataset 28 | 29 | def level(a, node_parent_map): 30 | if a not in node_parent_map: 31 | return 0 32 | parent = node_parent_map[a] 33 | return level(parent, node_parent_map)+1 34 | 35 | def height(a, parent_node_map): 36 | ans = -1 37 | 38 | if a not in parent_node_map: 39 | return 0 40 | 41 | for child in parent_node_map[a]: 42 | ans = max(ans, height(child, parent_node_map)) 43 | 44 | return ans+1 45 | 46 | 47 | def least_common_ancestor(a, b, node_parent_map): 48 | 49 | if level(a, node_parent_map) > level(b, node_parent_map): 50 | a, b = b, a 51 | 52 | # if both are not at same level then move lower node upwards 53 | d = level(b, node_parent_map) - level(a, node_parent_map) 54 | 55 | # node_parent_map[i] stores the parent of node i 56 | while d > 0: 57 | b = node_parent_map[b] 58 | d-= 1 59 | 60 | # base case if one was the ancestor of other node 61 | if a == b: 62 | return a 63 | 64 | # print('a = {}, b = {}'.format(a, b)) 65 | if a not in node_parent_map or b not in node_parent_map: 66 | return '805080' # return root as ancestor 67 | 68 | while node_parent_map[a] != node_parent_map[b]: 69 | a = node_parent_map[a] 70 | b = node_parent_map[b] 71 | 72 | # print('flag 1') 73 | 74 | return node_parent_map[a] 75 | 76 | 77 | def print_ancestors(species_id, node_parent_map, target_list, taxon_id_to_name, overall_id_to_name): 78 | out = [] 79 | curr_node = species_id 80 | 81 | while True: 82 | if str(curr_node) in taxon_id_to_name: 83 | out.append(taxon_id_to_name[str(curr_node)]) 84 | else: 85 | out.append(overall_id_to_name[str(curr_node)]) 86 | break 87 | 88 | if curr_node not in node_parent_map: 89 | break 90 | curr_node = node_parent_map[curr_node] 91 | 92 | print(' --> '.join(out)) 93 | 94 | def evaluate(model, val_loader, id2entity, overall_id_to_name, taxon_id_to_name, target_list, node_parent_map, parent_node_map, args): 95 | model.eval() 96 | torch.set_grad_enabled(False) 97 | 98 | epoch_y_true = [] 99 | epoch_y_pred = [] 100 | 101 | batch_idx = 0 102 | correct_idx = [] 103 | 104 | avg_lca_height = 0 105 | total = 0 106 | 107 | for labeled_batch in tqdm(val_loader): 108 | h, r, t = labeled_batch 109 | h = move_to(h, args.device) 110 | r = move_to(r, args.device) 111 | t = move_to(t, args.device) 112 | 113 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'id')) 114 | 115 | batch_results = { 116 | 'y_true': t.cpu(), 117 | 'y_pred': outputs.cpu(), 118 | } 119 | 120 | y_true = detach_and_clone(batch_results['y_true']) 121 | epoch_y_true.append(y_true) 122 | y_pred = detach_and_clone(batch_results['y_pred']) 123 | y_pred = y_pred.argmax(-1) 124 | 125 | b_range = torch.arange(y_pred.size()[0], device=args.device) 126 | 127 | arg_outputs = torch.argsort(outputs, dim=-1, descending=True) 128 | rank = 1 + torch.argsort(arg_outputs, dim=-1, descending=False)[b_range, y_true] 129 | # print('rank = {}'.format(rank)) 130 | 131 | for i in range(y_true.size(0)): 132 | if y_pred[i] == y_true[i]: 133 | correct_idx.append(batch_idx * args.batch_size + i) 134 | else: 135 | lca = least_common_ancestor(int(id2entity[target_list[y_pred[i]].item()]), int(id2entity[target_list[y_true[i]].item()]), node_parent_map) 136 | lca_height = height(lca, parent_node_map) 137 | 138 | avg_lca_height += lca_height 139 | total += 1 140 | 141 | epoch_y_pred.append(y_pred) 142 | 143 | batch_idx += 1 144 | if args.debug and batch_idx>10: 145 | break 146 | 147 | epoch_y_pred = collate_list(epoch_y_pred) 148 | epoch_y_true = collate_list(epoch_y_true) 149 | 150 | metrics = [ 151 | 
Accuracy(prediction_fn=None), 152 | Recall(prediction_fn=None, average='macro'), 153 | F1(prediction_fn=None, average='macro'), 154 | ] 155 | 156 | results = {} 157 | 158 | for i in range(len(metrics)): 159 | results.update({ 160 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 161 | }) 162 | 163 | print(f'Eval., split: {args.split}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100:.2f}') 164 | 165 | avg_lca_height = avg_lca_height/total 166 | 167 | # print('total = {}'.format(total)) 168 | # print('avg_lca_height = {}'.format(avg_lca_height)) 169 | 170 | return correct_idx, epoch_y_pred.tolist(), epoch_y_true.tolist(), avg_lca_height 171 | 172 | def _get_id(dict, key): 173 | id = dict.get(key, None) 174 | if id is None: 175 | id = len(dict) 176 | dict[key] = id 177 | return id 178 | 179 | def generate_target_list(data, entity2id): 180 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 181 | sub = list(sub['t']) 182 | categories = [] 183 | for item in tqdm(sub): 184 | if entity2id[str(int(float(item)))] not in categories: 185 | categories.append(entity2id[str(int(float(item)))]) 186 | # print('categories = {}'.format(categories)) 187 | print("No. of target categories = {}".format(len(categories))) 188 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 189 | 190 | def check_list_equal(list_1, list_2): 191 | for i in range(len(list_1)): 192 | if list_1[i] != list_2[i]: 193 | return False 194 | return True 195 | 196 | 197 | if __name__=='__main__': 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('--data-dir', type=str, default='../iwildcam_v2.0/') 200 | parser.add_argument('--img-dir', type=str, default='../iwildcam_v2.0/imgs/') 201 | parser.add_argument('--split', type=str, default='val') 202 | parser.add_argument('--seed', type=int, default=813765) 203 | 204 | parser.add_argument('--ckpt-1-path', type=str, default=None, help='path to ckpt 1 for restarting expt') 205 | parser.add_argument('--ckpt-2-path', type=str, default=None, help='path to ckpt 2 for restarting expt') 206 | 207 | parser.add_argument('--debug', action='store_true') 208 | parser.add_argument('--no-cuda', action='store_true') 209 | parser.add_argument('--batch_size', type=int, default=16) 210 | 211 | parser.add_argument('--embedding-dim', type=int, default=512) 212 | parser.add_argument('--location_input_dim', type=int, default=2) 213 | parser.add_argument('--time_input_dim', type=int, default=1) 214 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 215 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 216 | 217 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 218 | parser.add_argument('--use-data-subset', action='store_true') 219 | parser.add_argument('--subset-size', type=int, default=10) 220 | parser.add_argument('--add-id-id', action='store_true', help='add idtoid triples in addition to other triples for training') 221 | 222 | parser.add_argument('--kg-embed-model', choices=['distmult', 'conve'], default='distmult') 223 | 224 | # ConvE hyperparams 225 | parser.add_argument('--embedding-shape1', type=int, default=20, help='The first dimension of the reshaped 2D embedding. The second dimension is inferred. Default: 20') 226 | parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. 
Default: 0.3.') 227 | parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.') 228 | parser.add_argument('--feat-drop', type=float, default=0.2, help='Dropout for the convolutional features. Default: 0.2.') 229 | parser.add_argument('--use-bias', action='store_true', default=True, help='Use a bias in the convolutional layer. Default: True') 230 | parser.add_argument('--hidden-size', type=int, default=9728, help='The side of the hidden layer. The required size changes with the size of the embeddings. Default: 9728 (embedding size 200).') 231 | 232 | args = parser.parse_args() 233 | 234 | print('args = {}'.format(args)) 235 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 236 | 237 | # Set random seed 238 | torch.manual_seed(args.seed) 239 | np.random.seed(args.seed) 240 | random.seed(args.seed) 241 | 242 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 243 | 244 | # construct OTT parent map 245 | datacsv_id_id = datacsv.loc[(datacsv['datatype_h'] == 'id') & (datacsv['datatype_t'] == 'id')] 246 | 247 | node_parent_map = {} 248 | parent_node_map = defaultdict(list) 249 | 250 | for idx in range(len(datacsv_id_id)): 251 | node = int(float(datacsv.iloc[idx, 0])) 252 | parent = int(float(datacsv.iloc[idx, -3])) 253 | 254 | node_parent_map[node] = parent 255 | parent_node_map[parent].append(node) 256 | 257 | # print('node_parent_map = {}'.format(node_parent_map)) 258 | # sys.exit(0) 259 | 260 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree.json') 261 | 262 | if not os.path.exists(entity_id_file): 263 | entity2id = {} # each of triple types have their own entity2id 264 | 265 | for i in tqdm(range(datacsv.shape[0])): 266 | if datacsv.iloc[i,1] == "id": 267 | _get_id(entity2id, str(int(float(datacsv.iloc[i,0])))) 268 | 269 | if datacsv.iloc[i,-2] == "id": 270 | _get_id(entity2id, str(int(float(datacsv.iloc[i,-3])))) 271 | json.dump(entity2id, open(entity_id_file, 'w')) 272 | else: 273 | entity2id = json.load(open(entity_id_file, 'r')) 274 | 275 | num_ent_id = len(entity2id) 276 | 277 | print('len(entity2id) = {}'.format(len(entity2id))) 278 | 279 | # print('entity2id = {}'.format(entity2id)) 280 | id2entity = {v:k for k,v in entity2id.items()} 281 | 282 | target_list = generate_target_list(datacsv, entity2id) 283 | # print('target_list = {}'.format(target_list)) 284 | 285 | val_image_to_id_dataset = iWildCamOTTDataset(datacsv, args.split, args, entity2id, target_list, head_type="image", tail_type="id") 286 | print('len(val_image_to_id_dataset) = {}'.format(len(val_image_to_id_dataset))) 287 | 288 | val_loader = DataLoader( 289 | val_image_to_id_dataset, 290 | shuffle=False, # Do not shuffle eval datasets 291 | sampler=None, 292 | batch_size=args.batch_size, 293 | num_workers=0, 294 | pin_memory=True) 295 | 296 | 297 | model_1 = MKGE(args, num_ent_id, target_list, args.device) 298 | model_2 = MKGE(args, num_ent_id, target_list, args.device) 299 | 300 | model_1.to(args.device) 301 | model_2.to(args.device) 302 | 303 | overall_id_to_name = json.load(open(os.path.join(args.data_dir, 'overall_id_to_name.json'), 'r')) 304 | taxon_id_to_name = json.load(open(os.path.join(args.data_dir, 'taxon_id_to_name.json'), 'r')) 305 | 306 | taxon_id_to_name['8032203'] = 'empty' 307 | 308 | # restore from ckpt 309 | if args.ckpt_1_path: 310 | ckpt = torch.load(args.ckpt_1_path, map_location=args.device) 311 | 
model_1.load_state_dict(ckpt['model'], strict=False) 312 | print('ckpt loaded...') 313 | 314 | if args.ckpt_2_path: 315 | ckpt = torch.load(args.ckpt_2_path, map_location=args.device) 316 | model_2.load_state_dict(ckpt['model'], strict=False) 317 | print('ckpt loaded...') 318 | 319 | model_1_correct_idx, model_1_pred, model_1_true, lca_model_1 = evaluate(model_1, val_loader, id2entity, overall_id_to_name, taxon_id_to_name, target_list, node_parent_map, parent_node_map, args) 320 | 321 | model_2_correct_idx, model_2_pred, model_2_true, lca_model_2 = evaluate(model_2, val_loader, id2entity, overall_id_to_name, taxon_id_to_name, target_list, node_parent_map, parent_node_map, args) 322 | 323 | print('lca_model_1 = {}'.format(lca_model_1)) 324 | print('lca_model_2 = {}'.format(lca_model_2)) 325 | 326 | # model_1 - model_2 327 | model_1_correct = list(set(model_1_correct_idx) - set(model_2_correct_idx)) 328 | 329 | # print('len(model_1_correct) = {}'.format(len(model_1_correct))) 330 | 331 | assert check_list_equal(model_1_true, model_2_true) 332 | 333 | # show taxonomy for cases where model_1 is correct but model_2 is not 334 | 335 | true_pred_c = Counter() 336 | 337 | for idx in model_1_correct: 338 | model_1_true_label = model_1_true[idx] 339 | model_1_pred_label = model_1_pred[idx] 340 | model_2_pred_label = model_2_pred[idx] 341 | 342 | assert model_1_true_label == model_1_pred_label # model_1 is correct for this example, model_2 is incorrect 343 | 344 | print('true_label = {}, model_1_pred_label = {}, model_2_pred_label = {}'.format(overall_id_to_name[id2entity[target_list[model_1_true_label].item()]], overall_id_to_name[id2entity[target_list[model_1_pred_label].item()]], overall_id_to_name[id2entity[target_list[model_2_pred_label].item()]])) 345 | 346 | true_pred_c.update([(overall_id_to_name[id2entity[target_list[model_1_true_label].item()]], overall_id_to_name[id2entity[target_list[model_2_pred_label].item()]])]) 347 | 348 | # print taxonomy (list of ancestors) for y_true 349 | print('ancestors of y_true: ') 350 | print_ancestors(int(id2entity[target_list[model_1_true_label].item()]), node_parent_map, target_list, taxon_id_to_name, overall_id_to_name) 351 | 352 | print('ancestors of y_pred: ') 353 | print_ancestors(int(id2entity[target_list[model_2_pred_label].item()]), node_parent_map, target_list, taxon_id_to_name, overall_id_to_name) 354 | 355 | print('\n') 356 | 357 | print(true_pred_c.most_common()) 358 | 359 | 360 | 361 | 362 | 363 | -------------------------------------------------------------------------------- /gps_locations.json: -------------------------------------------------------------------------------- 1 | {"110": {"latitude": -2.5465793928905707, "longitude": 29.415316715226382}, "499": {"latitude": 0.3812788773177989, "longitude": 36.8764970596741}, "0": {"latitude": -2.6262412503467187, "longitude": 29.35055652840072}, "185": {"latitude": 0.4145722029954055, "longitude": 36.83183842771321}, "64": {"latitude": -19.789640623635297, "longitude": -58.65293658236221}, "456": {"latitude": 17.386912518489904, "longitude": -89.21441303895942}, "450": {"latitude": 0.30349604653751905, "longitude": 36.88682334252553}, "312": {"latitude": -2.557353111471825, "longitude": 29.388135443922582}, "395": {"latitude": 17.47085658062343, "longitude": -89.25412173316448}, "522": {"latitude": 0.37818839195510845, "longitude": 36.84870391664237}, "293": {"latitude": -2.5660196854954873, "longitude": 29.264567825620652}, "233": {"latitude": 0.31597154226593666, "longitude": 
36.8639790237097}, "509": {"latitude": 17.45973003390207, "longitude": -89.29153367242331}, "541": {"latitude": -2.580788550656127, "longitude": 29.396976826843545}, "531": {"latitude": 17.340220544240648, "longitude": -89.27484144414585}, "224": {"latitude": -2.5595452688202682, "longitude": 29.424445303107326}, "511": {"latitude": 17.439904284672217, "longitude": -89.41342659908402}, "444": {"latitude": 17.426287047940964, "longitude": -89.29927048665242}, "167": {"latitude": -2.5357441356427697, "longitude": 29.377802131072592}, "374": {"latitude": -1.885242313042777, "longitude": -76.84970379623749}, "151": {"latitude": 0.3596586871013336, "longitude": 36.912900761193235}, "9": {"latitude": -1.8542432186023636, "longitude": -76.84303058393503}, "282": {"latitude": -2.622327853110669, "longitude": 29.38838026891098}, "278": {"latitude": 0.4223472145962541, "longitude": 36.8212468917753}, "65": {"latitude": 0.3942274379637513, "longitude": 36.88438780117372}, "287": {"latitude": 17.368132706606215, "longitude": -89.27789470315696}, "73": {"latitude": 0.39570909648587754, "longitude": 36.871834127753296}, "123": {"latitude": -2.513776617328506, "longitude": 29.370985157878255}, "298": {"latitude": 0.47822735488168094, "longitude": 36.85475242946671}, "45": {"latitude": -1.8936431157718412, "longitude": -76.84158894064917}, "470": {"latitude": 17.54354318749816, "longitude": -89.3752821909996}, "242": {"latitude": 0.3730209825613568, "longitude": 36.89154172231868}, "131": {"latitude": 0.2998261240016198, "longitude": 36.888809462630654}, "518": {"latitude": 0.31058552282441704, "longitude": 36.9110880674045}, "124": {"latitude": 0.3661099234994124, "longitude": 36.88426357397884}, "489": {"latitude": 0.4184907280919391, "longitude": 36.83272593511341}, "230": {"latitude": -1.8440626992747937, "longitude": -76.92632491387245}, "457": {"latitude": -1.8787353302322152, "longitude": -76.92143820544406}, "387": {"latitude": 0.3776345355965877, "longitude": 36.88889947723915}, "492": {"latitude": 0.3723875424635886, "longitude": 36.84222434360628}, "390": {"latitude": -2.322555569447371, "longitude": 29.363897437858604}, "28": {"latitude": 0.4095609084003708, "longitude": 36.85158912415017}, "453": {"latitude": 17.48750633433792, "longitude": -89.35969966917423}, "109": {"latitude": 17.48670574225246, "longitude": -89.38450801813721}, "179": {"latitude": 0.36911597991826767, "longitude": 36.88247352058484}, "532": {"latitude": 0.4578707945636438, "longitude": 36.86856291306131}, "231": {"latitude": -2.538673842927624, "longitude": 29.40643435632873}, "409": {"latitude": 0.3980686784662569, "longitude": 36.85632244570259}, "14": {"latitude": 17.462611300994546, "longitude": -89.40753452209306}, "504": {"latitude": -2.6072618584119853, "longitude": 29.385725225195863}, "477": {"latitude": 0.3888009766362957, "longitude": 36.84335173200797}, "404": {"latitude": -1.8774590260595627, "longitude": -76.90619539604918}, "440": {"latitude": 17.50564923633344, "longitude": -89.34510662979936}, "520": {"latitude": 0.4513483520249824, "longitude": 36.84588023602884}, "422": {"latitude": 0.47825693821412835, "longitude": 36.83254542764155}, "10": {"latitude": 0.480283121689543, "longitude": 36.89220312851929}, "437": {"latitude": 0.5037197420359051, "longitude": 36.82887305620965}, "121": {"latitude": -19.848013454291902, "longitude": -58.61778840979919}, "443": {"latitude": 0.42159566745859034, "longitude": 36.89213977387095}, "53": {"latitude": 0.39701662905998314, "longitude": 36.84911889705895}, "119": 
{"latitude": 17.312882827663895, "longitude": -89.25041868624844}, "162": {"latitude": -2.5608803770503608, "longitude": 29.413144531252975}, "501": {"latitude": 0.472764657768066, "longitude": 36.86982812118411}, "333": {"latitude": -2.582070800125371, "longitude": 29.404693947567683}, "3": {"latitude": 17.354530364323523, "longitude": -89.33077059971222}, "132": {"latitude": 0.3692122664750138, "longitude": 36.83941767831211}, "106": {"latitude": 20.116901979578454, "longitude": 103.16998205946942}, "315": {"latitude": -1.8642956805841842, "longitude": -76.83690245696395}, "220": {"latitude": 0.41075006622792587, "longitude": 36.8723009357763}, "130": {"latitude": 0.29531409169537814, "longitude": 36.87202048158241}, "273": {"latitude": 0.3829593705848677, "longitude": 36.845312385623295}, "412": {"latitude": 0.47683976548682744, "longitude": 36.83851751103988}, "289": {"latitude": 0.46275700181129864, "longitude": 36.862931772670926}, "551": {"latitude": 0.46257812044235525, "longitude": 36.8404884721961}, "547": {"latitude": 17.44567495347379, "longitude": -89.35424276716635}, "193": {"latitude": 0.4119207376489255, "longitude": 36.820846232277844}, "177": {"latitude": 17.316871496101843, "longitude": -89.29183932513483}, "103": {"latitude": 0.27805555401582077, "longitude": 36.87395458388925}, "71": {"latitude": -2.595224856128364, "longitude": 29.395219821147176}, "105": {"latitude": 17.429667706513452, "longitude": -89.26020127000108}, "60": {"latitude": 17.45245169536857, "longitude": -89.35785667276214}, "385": {"latitude": -2.5894951586213066, "longitude": 29.390326689889275}, "229": {"latitude": 0.3932093347835856, "longitude": 36.9070151750833}, "145": {"latitude": -1.878347091415972, "longitude": -76.85460402687619}, "57": {"latitude": 17.433195051368905, "longitude": -89.23669240317896}, "122": {"latitude": 0.3834669038949662, "longitude": 36.854979246833906}, "206": {"latitude": 0.39755340561221975, "longitude": 36.92595419300959}, "108": {"latitude": 0.3943556244553961, "longitude": 36.8748229009666}, "375": {"latitude": 0.4946679268149559, "longitude": 36.84800370167744}, "329": {"latitude": 0.28282336700755756, "longitude": 36.90169218903675}, "127": {"latitude": 0.41811095589577957, "longitude": 36.88875543035903}, "458": {"latitude": 17.36647265390731, "longitude": -89.24049012393131}, "300": {"latitude": 0.38525357375101754, "longitude": 36.83489429196447}, "149": {"latitude": 17.501946163551107, "longitude": -89.38345162375447}, "400": {"latitude": 0.4855926290971324, "longitude": 36.88729886042743}, "415": {"latitude": 0.3312000722345205, "longitude": 36.891892723512605}, "370": {"latitude": -1.8770536154645834, "longitude": -76.83781456921474}, "22": {"latitude": 0.3207614226120245, "longitude": 36.900138334593095}, "449": {"latitude": 17.413248259248626, "longitude": -89.36910478923855}, "503": {"latitude": 0.36694227452076056, "longitude": 36.92751151163469}, "392": {"latitude": 17.323563176570584, "longitude": -89.21887925942633}, "417": {"latitude": -2.6126298482760455, "longitude": 29.4050298783102}, "301": {"latitude": 17.432647551432957, "longitude": -89.36824357719821}, "240": {"latitude": 0.33338342163611534, "longitude": 36.882462255025025}, "372": {"latitude": 0.2898824238009886, "longitude": 36.900216945578244}, "540": {"latitude": 0.37021142055911904, "longitude": 36.855812761992944}, "359": {"latitude": 0.4224310635946181, "longitude": 36.89581316336562}, "484": {"latitude": -19.8336680471005, "longitude": -58.610918689627304}, "264": {"latitude": 
-1.9420145610648762, "longitude": -76.17301619261994}, "140": {"latitude": 0.450756887540709, "longitude": 36.85953246413865}, "526": {"latitude": 17.477544788939728, "longitude": -89.21962200908158}, "512": {"latitude": 17.486799266608877, "longitude": -89.26162824679224}, "55": {"latitude": -1.885574860662518, "longitude": -76.83474114107756}, "508": {"latitude": 0.4138749991589727, "longitude": 36.89069681029452}, "379": {"latitude": 0.2881180444688006, "longitude": 36.88549290295677}, "471": {"latitude": 0.4122442349988146, "longitude": 36.86185136455476}, "497": {"latitude": 17.441534941874057, "longitude": -89.33932644761876}, "25": {"latitude": -2.622834357465289, "longitude": 29.38710228836127}, "259": {"latitude": 0.35874229563208, "longitude": 36.8783792336315}, "170": {"latitude": 0.3936811720959416, "longitude": 36.88813310063387}, "225": {"latitude": -2.5918260700357814, "longitude": 29.421881157340817}, "81": {"latitude": 0.3774443774973468, "longitude": 36.92131495280321}, "212": {"latitude": 17.33395180123941, "longitude": -89.23532347668366}, "408": {"latitude": 17.536713449953602, "longitude": -89.36949935289653}, "138": {"latitude": 0.3681844251455735, "longitude": 36.843269634152485}, "174": {"latitude": -19.824004415715304, "longitude": -58.62157045123212}, "31": {"latitude": 0.2970019375842059, "longitude": 36.85534422579122}, "416": {"latitude": -2.574455126420509, "longitude": 29.41710076742634}, "344": {"latitude": -1.8516349009701785, "longitude": -76.9309696054842}, "441": {"latitude": -19.842501421784846, "longitude": -58.61707949534127}, "8": {"latitude": 17.424200621752334, "longitude": -89.42048048427831}, "247": {"latitude": 0.42030337104609394, "longitude": 36.849365177780584}, "314": {"latitude": -2.53891885281452, "longitude": 29.381387436666873}, "144": {"latitude": -2.5722890297141343, "longitude": 29.399587661848912}, "328": {"latitude": 0.37718085102830123, "longitude": 36.86944191599825}, "251": {"latitude": 0.3770006528385197, "longitude": 36.90679905714171}, "112": {"latitude": 0.33322936837376993, "longitude": 36.86825667283611}, "306": {"latitude": 0.43636730018872955, "longitude": 36.87628456311947}, "51": {"latitude": 0.41915779941338316, "longitude": 36.90242404984711}, "291": {"latitude": 0.40593182746080636, "longitude": 36.91101405905633}, "226": {"latitude": 0.2862350369610934, "longitude": 36.86787608337893}, "254": {"latitude": 0.34412158962611333, "longitude": 36.88571069189406}, "434": {"latitude": -19.828335824111782, "longitude": -58.602082191791204}, "24": {"latitude": 0.39724054568750095, "longitude": 36.87267327752806}, "414": {"latitude": 0.3487033246800884, "longitude": 36.871134420633815}, "101": {"latitude": 0.42242543827055784, "longitude": 36.88658320163993}, "197": {"latitude": 0.36493893415735584, "longitude": 36.926233765109004}, "210": {"latitude": 17.41293389816128, "longitude": -89.32011001521958}, "91": {"latitude": 0.43389859837548245, "longitude": 36.861376533065645}, "290": {"latitude": 17.44264127986732, "longitude": -89.38775881768353}, "95": {"latitude": 0.42780379224436804, "longitude": 36.86167221895546}, "267": {"latitude": -1.9335396280450448, "longitude": -76.15850043677803}, "92": {"latitude": 17.536732220783115, "longitude": -89.29852334205884}, "249": {"latitude": 0.4193509144644843, "longitude": 36.8874792070224}, "261": {"latitude": 0.4003588712377196, "longitude": 36.892001912323856}, "176": {"latitude": 17.53551904845294, "longitude": -89.16194523701087}, "383": {"latitude": 0.4400926091787454, 
"longitude": 36.8499391135493}, "318": {"latitude": 0.46348606838429207, "longitude": 36.839141043009015}, "302": {"latitude": 0.3244326111241371, "longitude": 36.87252455701666}, "202": {"latitude": -2.523887951263243, "longitude": 29.39063328534747}, "218": {"latitude": 17.492381819738956, "longitude": -89.21560449646441}, "245": {"latitude": 17.26294364443537, "longitude": -90.38664553520124}, "410": {"latitude": 0.3859419414805248, "longitude": 36.8703722625851}, "113": {"latitude": 0.4265672374576081, "longitude": 36.841138579257155}, "125": {"latitude": 0.2954514171664175, "longitude": 36.897566470897324}, "529": {"latitude": 17.471844972255685, "longitude": -89.3779386416874}, "157": {"latitude": 0.4223354439540444, "longitude": 36.853573366218704}, "428": {"latitude": 17.486201882522035, "longitude": -89.34696920786806}, "159": {"latitude": 0.29771992089765453, "longitude": 36.90269370224629}, "351": {"latitude": 0.40951459423627673, "longitude": 36.87926753570763}, "373": {"latitude": 0.514976802973771, "longitude": 36.85684196258456}, "320": {"latitude": 17.49559922328219, "longitude": -89.30856133799814}, "258": {"latitude": 0.4265308065734449, "longitude": 36.89964646902503}, "152": {"latitude": 0.45862175498061547, "longitude": 36.862241291461714}, "82": {"latitude": -19.796040100276585, "longitude": -58.61505417063685}, "136": {"latitude": 0.28630780673164735, "longitude": 36.87372611217445}, "321": {"latitude": 0.33617864261335006, "longitude": 36.846555244339065}, "369": {"latitude": 0.29743851371518604, "longitude": 36.85592763201621}, "1": {"latitude": 0.31008948560970523, "longitude": 36.87881117620775}, "530": {"latitude": -2.5454882190190844, "longitude": 29.413556007395766}, "502": {"latitude": 17.526457225422796, "longitude": -89.25072300690697}, "338": {"latitude": 17.536896481052, "longitude": -89.35080836943763}, "366": {"latitude": 0.4044860383520864, "longitude": 36.85378651268177}, "169": {"latitude": 0.35758265712740117, "longitude": 36.84297716294369}, "516": {"latitude": -2.6194368615670536, "longitude": 29.406384111781836}, "521": {"latitude": -2.5195399187315743, "longitude": 29.39190832023515}, "129": {"latitude": -2.5524396866885315, "longitude": 29.396615911590544}, "356": {"latitude": -2.595351880715789, "longitude": 29.40627371938052}, "490": {"latitude": -2.470058060481568, "longitude": 29.301540157168102}, "421": {"latitude": -19.79086928442819, "longitude": -58.64803425407102}, "519": {"latitude": 17.406570152419864, "longitude": -89.34672636767576}, "69": {"latitude": 0.370441540557848, "longitude": 36.86976353155466}, "317": {"latitude": 0.37182166379606607, "longitude": 36.89841195379159}, "208": {"latitude": -2.569482223142066, "longitude": 29.392572592962967}, "277": {"latitude": 0.4316182596574106, "longitude": 36.83784976415324}, "134": {"latitude": 0.42596625430895535, "longitude": 36.861556294924334}, "515": {"latitude": 17.485280919746494, "longitude": -89.18087929186214}, "255": {"latitude": 0.39759425220151945, "longitude": 36.87502052782644}, "191": {"latitude": 0.39394057125886256, "longitude": 36.902268003389736}, "207": {"latitude": 0.3881978884859748, "longitude": 36.9269588730554}, "352": {"latitude": 0.5055720203467875, "longitude": 36.8571004152205}, "467": {"latitude": 0.36818170017229, "longitude": 36.89658942914988}, "461": {"latitude": 17.524400180479358, "longitude": -89.21131775277435}, "5": {"latitude": 17.484665796247842, "longitude": -89.41772988431285}, "244": {"latitude": -1.8567503573774151, "longitude": 
-76.89971211638309}, "308": {"latitude": -1.8557234293383176, "longitude": -76.81957611304658}, "358": {"latitude": 0.3662938418794768, "longitude": 36.90370368105118}, "120": {"latitude": 0.3696945646987966, "longitude": 36.92404464546971}, "549": {"latitude": -2.588867438391152, "longitude": 29.414090020238486}, "252": {"latitude": 0.44862033025045295, "longitude": 36.8600707405427}, "4": {"latitude": -19.795241430462376, "longitude": -58.6402436398495}, "161": {"latitude": 17.368935818706124, "longitude": -89.32730482451487}, "34": {"latitude": 17.423203371208736, "longitude": -89.18043304188137}, "150": {"latitude": 0.28342427400622433, "longitude": 36.87322711239579}, "397": {"latitude": 17.52372503658733, "longitude": -89.39799914819523}, "215": {"latitude": 0.362537416316581, "longitude": 36.88488873776693}, "455": {"latitude": 0.3940716108868149, "longitude": 36.834712384154834}, "98": {"latitude": 0.3193762678438628, "longitude": 36.85801522776731}, "6": {"latitude": 17.39978566993782, "longitude": -89.2469461264331}, "222": {"latitude": 0.5020779778505554, "longitude": 36.8674016347365}, "435": {"latitude": 0.32650435510101355, "longitude": 36.897721881895485}, "355": {"latitude": 0.3842413047109292, "longitude": 36.874609678567566}, "423": {"latitude": -1.9019177407764964, "longitude": -76.87396867425754}, "84": {"latitude": -1.9599186961502357, "longitude": -76.16815584543865}, "257": {"latitude": -2.6155067334355255, "longitude": 29.40035208335806}, "139": {"latitude": 0.415801634667476, "longitude": 36.86917893506737}, "548": {"latitude": -1.8657775934398815, "longitude": -76.88924294178437}, "46": {"latitude": 0.39268621424265876, "longitude": 36.83948054490969}, "368": {"latitude": 0.3568312367701896, "longitude": 36.883328353040454}, "292": {"latitude": -2.5280733863449028, "longitude": 29.387062845463717}, "217": {"latitude": -1.867009895352544, "longitude": -76.89438415655665}, "399": {"latitude": -19.84783345170883, "longitude": -58.624847784614566}, "527": {"latitude": -2.5815723966501793, "longitude": 29.40158812352173}, "430": {"latitude": 0.36447139337678314, "longitude": 36.83865575527206}, "20": {"latitude": 17.50893728307156, "longitude": -89.41245686063513}, "29": {"latitude": 0.48819039092735433, "longitude": 36.836628326432255}, "181": {"latitude": -1.8803122091103355, "longitude": -76.82636033014994}, "454": {"latitude": 0.3787448060641945, "longitude": 36.92616164717644}, "118": {"latitude": 0.47696140978121937, "longitude": 36.862298569256}, "48": {"latitude": 0.46456094214889354, "longitude": 36.84443947261405}, "483": {"latitude": 0.287346983618062, "longitude": 36.89067921540541}, "153": {"latitude": -19.85361019770085, "longitude": -58.615437022193134}, "305": {"latitude": 0.4525234880461403, "longitude": 36.84888119193651}, "59": {"latitude": 0.3368179325238944, "longitude": 36.840697236086534}, "327": {"latitude": 17.305012522645303, "longitude": -89.32744713319316}, "243": {"latitude": -2.551669708201669, "longitude": 29.390682379014187}} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict 14 | import math 15 | import torchvision.transforms as 
transforms 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | 19 | from model import MKGE 20 | from resnet import Resnet18, Resnet50 21 | 22 | from tqdm import tqdm 23 | from utils import collate_list, detach_and_clone, move_to 24 | import torch.optim as optim 25 | from torch.utils.data import Dataset, DataLoader 26 | from wilds.common.metrics.all_metrics import Accuracy 27 | from PIL import Image 28 | from dataset import iWildCamOTTDataset 29 | from pytorchtools import EarlyStopping 30 | 31 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 32 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 33 | 34 | ################# 35 | # image to id 36 | ################# 37 | 38 | def train_image_id(train_loader, model, optimizer, writer, args, epoch_id): 39 | epoch_y_true = [] 40 | epoch_y_pred = [] 41 | 42 | batch_idx = 0 43 | avg_loss_image_id = 0.0 44 | criterion_ce = nn.CrossEntropyLoss() 45 | 46 | for labeled_batch in tqdm(train_loader['image_to_id']): 47 | h, r, t = labeled_batch 48 | h = move_to(h, args.device) 49 | r = move_to(r, args.device) 50 | t = move_to(t, args.device) 51 | 52 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'id')) 53 | # outputs = model(h) 54 | 55 | batch_results = { 56 | 'y_true': t.cpu(), 57 | 'y_pred': outputs.cpu(), 58 | } 59 | 60 | # compute objective 61 | loss = criterion_ce(batch_results['y_pred'], batch_results['y_true']) 62 | batch_results['objective'] = loss.item() 63 | loss.backward() 64 | 65 | avg_loss_image_id += loss.item() 66 | 67 | # update model and logs based on effective batch 68 | optimizer.step() 69 | model.zero_grad() 70 | 71 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 72 | y_pred = detach_and_clone(batch_results['y_pred']) 73 | y_pred = y_pred.argmax(-1) 74 | 75 | epoch_y_pred.append(y_pred) 76 | 77 | batch_idx += 1 78 | if args.debug: 79 | break 80 | 81 | avg_loss_image_id = avg_loss_image_id/len(train_loader['image_to_id']) 82 | print('train/avg_loss_image_id = {}'.format(avg_loss_image_id)) 83 | writer.add_scalar('image_id_loss/train', avg_loss_image_id, epoch_id) 84 | 85 | epoch_y_pred = collate_list(epoch_y_pred) 86 | epoch_y_true = collate_list(epoch_y_true) 87 | 88 | metrics = [ 89 | Accuracy(prediction_fn=None), 90 | ] 91 | 92 | results = {} 93 | 94 | for i in range(len(metrics)): 95 | results.update({ 96 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 97 | }) 98 | 99 | 100 | results['epoch'] = epoch_id 101 | print(f'Train epoch {epoch_id}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100.0:.2f}') 102 | 103 | writer.add_scalar('acc_image_id/train', results[metrics[0].agg_metric_field]*100.0, epoch_id) 104 | 105 | ################# 106 | # id to id 107 | ################# 108 | def train_id_id(train_loader, model, optimizer, writer, args, epoch_id): 109 | epoch_y_true = [] 110 | epoch_y_pred = [] 111 | 112 | batch_idx = 0 113 | avg_loss_id_id = 0.0 114 | criterion_ce = nn.CrossEntropyLoss() 115 | 116 | for labeled_batch in tqdm(train_loader['id_to_id']): 117 | h, r, t = labeled_batch 118 | # print(h, r, t) 119 | h = move_to(h, args.device) 120 | r = move_to(r, args.device) 121 | t = move_to(t, args.device) 122 | 123 | outputs = model.forward_ce(h, r, t, triple_type=('id', 'id')) 124 | 125 | batch_results = { 126 | 'y_true': t.cpu(), 127 | 'y_pred': outputs.cpu(), 128 | } 129 | 130 | # compute objective 131 | loss = criterion_ce(batch_results['y_pred'], batch_results['y_true']) 132 | avg_loss_id_id += loss.item() 133 | 134 | # print('loss 
= {}'.format(loss.item())) 135 | batch_results['objective'] = loss.item() 136 | loss.backward() 137 | 138 | # update model and logs based on effective batch 139 | optimizer.step() 140 | model.zero_grad() 141 | 142 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 143 | y_pred = detach_and_clone(batch_results['y_pred']) 144 | y_pred = y_pred.argmax(-1) 145 | 146 | epoch_y_pred.append(y_pred) 147 | 148 | batch_idx += 1 149 | if args.debug: 150 | break 151 | 152 | avg_loss_id_id = avg_loss_id_id/len(train_loader['id_to_id']) 153 | print('avg_loss_id_id = {}'.format(avg_loss_id_id)) 154 | writer.add_scalar('avg_loss_id_id/train', avg_loss_id_id, epoch_id) 155 | 156 | epoch_y_pred = collate_list(epoch_y_pred) 157 | epoch_y_true = collate_list(epoch_y_true) 158 | 159 | metrics = [ 160 | Accuracy(prediction_fn=None), 161 | ] 162 | 163 | results = {} 164 | 165 | for i in range(len(metrics)): 166 | results.update({ 167 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 168 | }) 169 | 170 | results['epoch'] = epoch_id 171 | print(f'Train epoch {epoch_id}, id to id, Average acc: {results[metrics[0].agg_metric_field]*100.0:.2f}') 172 | writer.add_scalar('acc_id_id/train', results[metrics[0].agg_metric_field]*100.0, epoch_id) 173 | 174 | ################# 175 | # image to location 176 | ################# 177 | def train_image_location(train_loader, model, optimizer, writer, args, epoch_id): 178 | batch_idx = 0 179 | avg_loss_image_location = 0.0 180 | criterion_bce = nn.BCEWithLogitsLoss() 181 | 182 | for labeled_batch in tqdm(train_loader['image_to_location']): 183 | h, r, t = labeled_batch 184 | 185 | # print(h, r, t) 186 | # print(t) 187 | h = move_to(h, args.device) 188 | r = move_to(r, args.device) 189 | t = move_to(t, args.device) 190 | 191 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'location')) 192 | target = F.one_hot(t, num_classes=len(model.all_locs)).float() 193 | loss = criterion_bce(outputs, target) 194 | 195 | avg_loss_image_location += loss.item() 196 | 197 | loss.backward() 198 | 199 | # update model and logs based on effective batch 200 | optimizer.step() 201 | model.zero_grad() 202 | 203 | batch_idx += 1 204 | if args.debug: 205 | break 206 | 207 | avg_loss_image_location = avg_loss_image_location/len(train_loader['image_to_location']) 208 | print('avg_loss_image_location = {}'.format(avg_loss_image_location)) 209 | writer.add_scalar('avg_loss_image_location/train', avg_loss_image_location, epoch_id) 210 | 211 | ################# 212 | # image to time 213 | ################# 214 | def train_image_time(train_loader, model, optimizer, writer, args, epoch_id): 215 | batch_idx = 0 216 | avg_loss_image_time = 0.0 217 | criterion_ce = nn.CrossEntropyLoss() 218 | criterion_bce = nn.BCEWithLogitsLoss() 219 | 220 | for labeled_batch in tqdm(train_loader['image_to_time']): 221 | h, r, t = labeled_batch 222 | 223 | # print(h, r, t) 224 | h = move_to(h, args.device) 225 | r = move_to(r, args.device) 226 | t = move_to(t, args.device) 227 | 228 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'time')) 229 | target = F.one_hot(t, num_classes=len(model.all_timestamps)).float() 230 | loss = criterion_bce(outputs, target) 231 | 232 | avg_loss_image_time += loss.item() 233 | 234 | loss.backward() 235 | 236 | # update model and logs based on effective batch 237 | optimizer.step() 238 | model.zero_grad() 239 | 240 | batch_idx += 1 241 | if args.debug: 242 | break 243 | 244 | avg_loss_image_time = avg_loss_image_time/len(train_loader['image_to_time']) 245 | 
print('avg_loss_image_time = {}'.format(avg_loss_image_time)) 246 | writer.add_scalar('avg_loss_image_time/train', avg_loss_image_time, epoch_id) 247 | 248 | def train(model, train_loader, optimizer, epoch_id, writer, args): 249 | model.train() 250 | torch.set_grad_enabled(True) 251 | 252 | if args.add_id_id: 253 | train_id_id(train_loader, model, optimizer, writer, args, epoch_id) 254 | 255 | if args.add_image_location: 256 | train_image_location(train_loader, model, optimizer, writer, args, epoch_id) 257 | 258 | if args.add_image_time and not args.add_id_id: 259 | train_image_id(train_loader, model, optimizer, writer, args, epoch_id) 260 | 261 | if args.add_image_time: 262 | train_image_time(train_loader, model, optimizer, writer, args, epoch_id) 263 | 264 | train_image_id(train_loader, model, optimizer, writer, args, epoch_id) 265 | 266 | return 267 | 268 | def evaluate(model, val_loader, optimizer, early_stopping, epoch_id, writer, args): 269 | model.eval() 270 | torch.set_grad_enabled(False) 271 | criterion = nn.CrossEntropyLoss() 272 | 273 | epoch_y_true = [] 274 | epoch_y_pred = [] 275 | 276 | batch_idx = 0 277 | avg_loss_image_id = 0.0 278 | for labeled_batch in tqdm(val_loader): 279 | h, r, t = labeled_batch 280 | h = move_to(h, args.device) 281 | r = move_to(r, args.device) 282 | t = move_to(t, args.device) 283 | 284 | outputs = model.forward_ce(h, r, t, triple_type=('image', 'id')) 285 | 286 | batch_results = { 287 | 'y_true': t.cpu(), 288 | 'y_pred': outputs.cpu(), 289 | } 290 | 291 | batch_results['objective'] = criterion(batch_results['y_pred'], batch_results['y_true']).item() 292 | avg_loss_image_id += batch_results['objective'] 293 | 294 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 295 | y_pred = detach_and_clone(batch_results['y_pred']) 296 | y_pred = y_pred.argmax(-1) 297 | 298 | epoch_y_pred.append(y_pred) 299 | 300 | batch_idx += 1 301 | if args.debug: 302 | break 303 | 304 | epoch_y_pred = collate_list(epoch_y_pred) 305 | epoch_y_true = collate_list(epoch_y_true) 306 | 307 | metrics = [ 308 | Accuracy(prediction_fn=None), 309 | ] 310 | 311 | results = {} 312 | 313 | for i in range(len(metrics)): 314 | results.update({ 315 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 316 | }) 317 | 318 | results['epoch'] = epoch_id 319 | 320 | avg_loss_image_id = avg_loss_image_id/len(val_loader) 321 | 322 | early_stopping(-1*results[metrics[0].agg_metric_field], model, optimizer) 323 | 324 | print('val/avg_loss_image_id = {}'.format(avg_loss_image_id)) 325 | writer.add_scalar('image_id_loss/val', avg_loss_image_id, epoch_id) 326 | 327 | writer.add_scalar('acc_image_id/val', results[metrics[0].agg_metric_field]*100, epoch_id) 328 | 329 | print(f'Eval. epoch {epoch_id}, image to id, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}') 330 | 331 | return results, epoch_y_pred 332 | 333 | 334 | def _get_id(dict, key): 335 | id = dict.get(key, None) 336 | if id is None: 337 | id = len(dict) 338 | dict[key] = id 339 | return id 340 | 341 | def generate_target_list(data, entity2id): 342 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 343 | sub = list(sub['t']) 344 | categories = [] 345 | for item in tqdm(sub): 346 | if entity2id[str(int(float(item)))] not in categories: 347 | categories.append(entity2id[str(int(float(item)))]) 348 | # print('categories = {}'.format(categories)) 349 | print("No. 
of target categories = {}".format(len(categories))) 350 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 351 | 352 | def main(): 353 | 354 | parser = argparse.ArgumentParser() 355 | parser.add_argument('--dataset', choices=['iwildcam', 'mountain_zebra'], default='iwildcam') 356 | parser.add_argument('--data-dir', type=str, default='iwildcam_v2.0/') 357 | parser.add_argument('--img-dir', type=str, default='iwildcam_v2.0/imgs/') 358 | parser.add_argument('--batch_size', type=int, default=16) 359 | parser.add_argument('--n_epochs', type=int, default=12) 360 | parser.add_argument('--img-lr', type=float, default=3e-5, help='lr for img embed params') 361 | parser.add_argument('--lr', type=float, default=1e-3, help='default lr for all parameters') 362 | parser.add_argument('--loc-lr', type=float, default=1e-3, help='lr for location embedding') 363 | parser.add_argument('--time-lr', type=float, default=1e-3, help='lr for time embedding') 364 | parser.add_argument('--weight_decay', type=float, default=0.0) 365 | parser.add_argument('--device', type=int, nargs='+', default=[0]) 366 | parser.add_argument('--seed', type=int, default=813765) 367 | parser.add_argument('--save-dir', type=str, default='ckpts/toy/') 368 | parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt') 369 | parser.add_argument('--start-epoch', type=int, default=0, help='epoch id to restore model') 370 | parser.add_argument('--early-stopping-patience', type=int, default=5, help='early stop if metric does not improve for x epochs') 371 | parser.add_argument('--debug', action='store_true') 372 | parser.add_argument('--no-cuda', action='store_true') 373 | 374 | parser.add_argument('--kg-embed-model', choices=['distmult', 'conve'], default='distmult') 375 | parser.add_argument('--embedding-dim', type=int, default=512) 376 | parser.add_argument('--location_input_dim', type=int, default=2) 377 | parser.add_argument('--time_input_dim', type=int, default=1) 378 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 379 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 380 | 381 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50') 382 | parser.add_argument('--use-data-subset', action='store_true') 383 | parser.add_argument('--subset-size', type=int, default=10) 384 | 385 | parser.add_argument('--add-id-id', action='store_true', help='add idtoid triples in addition to other triples for training') 386 | parser.add_argument('--add-image-location', action='store_true', help='add imagetolocation triples in addition to other triples for training') 387 | parser.add_argument('--add-image-time', action='store_true', help='add imagetotime triples in addition to other triples for training') 388 | parser.add_argument('--omit-double-img-id', action='store_true', help='omit double image id after location') 389 | 390 | # ConvE hyperparams 391 | parser.add_argument('--embedding-shape1', type=int, default=20, help='The first dimension of the reshaped 2D embedding. The second dimension is inferred. Default: 20') 392 | parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.') 393 | parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.') 394 | parser.add_argument('--feat-drop', type=float, default=0.2, help='Dropout for the convolutional features. 
Default: 0.2.') 395 | parser.add_argument('--use-bias', action='store_true', default=True, help='Use a bias in the convolutional layer. Default: True') 396 | parser.add_argument('--hidden-size', type=int, default=9728, help='The side of the hidden layer. The required size changes with the size of the embeddings. Default: 9728 (embedding size 200).') 397 | 398 | args = parser.parse_args() 399 | 400 | print('args = {}'.format(args)) 401 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 402 | 403 | # Set random seed 404 | torch.manual_seed(args.seed) 405 | np.random.seed(args.seed) 406 | random.seed(args.seed) 407 | 408 | writer = SummaryWriter(log_dir=args.save_dir) 409 | 410 | if args.dataset == 'iwildcam': 411 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 412 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree.json') 413 | else: 414 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'data_triples.csv'), low_memory=False) 415 | entity_id_file = os.path.join(args.data_dir, 'entity2id.json') 416 | 417 | 418 | if not os.path.exists(entity_id_file): 419 | entity2id = {} # each of triple types have their own entity2id 420 | 421 | for i in tqdm(range(datacsv.shape[0])): 422 | if datacsv.iloc[i,1] == "id": 423 | _get_id(entity2id, str(int(float(datacsv.iloc[i,0])))) 424 | 425 | if datacsv.iloc[i,4] == "id": 426 | _get_id(entity2id, str(int(float(datacsv.iloc[i,3])))) 427 | json.dump(entity2id, open(entity_id_file, 'w')) 428 | else: 429 | entity2id = json.load(open(entity_id_file, 'r')) 430 | 431 | num_ent_id = len(entity2id) 432 | 433 | print('len(entity2id) = {}'.format(len(entity2id))) 434 | 435 | target_list = generate_target_list(datacsv, entity2id) 436 | 437 | train_image_to_id_dataset = iWildCamOTTDataset(datacsv, 'train', args, entity2id, target_list, head_type="image", tail_type="id") 438 | print('len(train_image_to_id_dataset) = {}'.format(len(train_image_to_id_dataset))) 439 | 440 | if args.add_id_id: 441 | train_id_to_id_dataset = iWildCamOTTDataset(datacsv, 'train', args, entity2id, target_list, head_type="id", tail_type="id") 442 | print('len(train_id_to_id_dataset) = {}'.format(len(train_id_to_id_dataset))) 443 | 444 | if args.add_image_location: 445 | train_image_to_location_dataset = iWildCamOTTDataset(datacsv, 'train', args, entity2id, target_list, head_type="image", tail_type="location") 446 | print('len(train_image_to_location_dataset) = {}'.format(len(train_image_to_location_dataset))) 447 | 448 | if args.add_image_time: 449 | train_image_to_time_dataset = iWildCamOTTDataset(datacsv, 'train', args, entity2id, target_list, head_type="image", tail_type="time") 450 | print('len(train_image_to_time_dataset) = {}'.format(len(train_image_to_time_dataset))) 451 | 452 | val_image_to_id_dataset = iWildCamOTTDataset(datacsv, 'val', args, entity2id, target_list, head_type="image", tail_type="id") 453 | print('len(val_image_to_id_dataset) = {}'.format(len(val_image_to_id_dataset))) 454 | 455 | model_kwargs = {} 456 | if args.kg_embed_model == 'conve': 457 | model_kwargs['drop_last'] = True 458 | 459 | train_loader_image_to_id = DataLoader( 460 | train_image_to_id_dataset, 461 | shuffle=True, # Shuffle training dataset 462 | sampler=None, 463 | batch_size=args.batch_size, 464 | num_workers=4, 465 | pin_memory=True, 466 | **model_kwargs) 467 | 468 | if args.add_id_id: 469 | train_loader_id_to_id = DataLoader( 470 | train_id_to_id_dataset, 471 | shuffle=True, # Shuffle 
training dataset 472 | sampler=None, 473 | batch_size=args.batch_size, 474 | num_workers=4, 475 | pin_memory=True, 476 | **model_kwargs) 477 | 478 | if args.add_image_location: 479 | train_loader_image_to_location = DataLoader( 480 | train_image_to_location_dataset, 481 | shuffle=True, # Shuffle training dataset 482 | sampler=None, 483 | batch_size=args.batch_size, 484 | num_workers=4, 485 | pin_memory=True, 486 | **model_kwargs) 487 | 488 | if args.add_image_time: 489 | train_loader_image_to_time = DataLoader( 490 | train_image_to_time_dataset, 491 | shuffle=True, # Shuffle training dataset 492 | sampler=None, 493 | batch_size=args.batch_size, 494 | num_workers=4, 495 | pin_memory=True, 496 | **model_kwargs) 497 | 498 | train_loaders = {} 499 | 500 | train_loaders['image_to_id'] = train_loader_image_to_id 501 | 502 | if args.add_id_id: 503 | train_loaders['id_to_id'] = train_loader_id_to_id 504 | 505 | if args.add_image_location: 506 | train_loaders['image_to_location'] = train_loader_image_to_location 507 | 508 | if args.add_image_time: 509 | train_loaders['image_to_time'] = train_loader_image_to_time 510 | 511 | val_loader = DataLoader( 512 | val_image_to_id_dataset, 513 | shuffle=False, # Do not shuffle eval datasets 514 | sampler=None, 515 | batch_size=args.batch_size, 516 | num_workers=4, 517 | pin_memory=True) 518 | 519 | kwargs = {} 520 | if args.add_image_time: 521 | kwargs['all_timestamps'] = train_image_to_time_dataset.all_timestamps 522 | if args.add_image_location: 523 | kwargs['all_locs'] = train_image_to_location_dataset.all_locs 524 | 525 | model = MKGE(args, num_ent_id, target_list, args.device, **kwargs) 526 | 527 | model.to(args.device) 528 | 529 | early_stopping = EarlyStopping(patience=args.early_stopping_patience, verbose=True, ckpt_path=os.path.join(args.save_dir, 'model.pt'), best_ckpt_path=os.path.join(args.save_dir, 'best_model.pt')) 530 | 531 | params_diff_lr = ['ent_embedding', 'rel_embedding', 'image_embedding', 'location_embedding', 'time_embedding'] 532 | 533 | optimizer_grouped_parameters = [ 534 | {"params": [param for p_name, param in model.named_parameters() if not any([x in p_name for x in params_diff_lr])]}, 535 | {"params": model.ent_embedding.parameters(), "lr": args.lr}, 536 | {"params": model.rel_embedding.parameters(), "lr": args.lr}, 537 | {"params": model.image_embedding.parameters(), "lr": args.img_lr}, 538 | {"params": model.location_embedding.parameters(), "lr": args.loc_lr}, 539 | {"params": model.time_embedding.parameters(), "lr": args.time_lr}, 540 | ] 541 | 542 | optimizer = optim.Adam( 543 | optimizer_grouped_parameters, 544 | lr=args.lr, 545 | weight_decay=args.weight_decay) 546 | 547 | # restore from ckpt 548 | if args.ckpt_path: 549 | print('ckpt loaded...') 550 | ckpt = torch.load(args.ckpt_path) 551 | model.load_state_dict(ckpt['model'], strict=False) 552 | optimizer.load_state_dict(ckpt['dense_optimizer']) 553 | 554 | for epoch_id in range(args.start_epoch, args.n_epochs): 555 | print('\nEpoch [%d]:\n' % epoch_id) 556 | 557 | # First run training 558 | train(model, train_loaders, optimizer, epoch_id, writer, args) 559 | 560 | # Then run val 561 | val_results, y_pred = evaluate(model, val_loader, optimizer, early_stopping, epoch_id, writer, args) 562 | 563 | if early_stopping.early_stop: 564 | print("Early stopping...") 565 | break 566 | 567 | writer.close() 568 | 569 | if __name__=='__main__': 570 | main() 571 | -------------------------------------------------------------------------------- /run_kge_model_baseline.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | import sys 12 | import json 13 | from collections import defaultdict 14 | import math 15 | import torchvision.transforms as transforms 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | try: 19 | import wandb 20 | except Exception as e: 21 | pass 22 | 23 | # from model import DistMult, ConvE 24 | from model_st import MKGE 25 | from resnet import Resnet18, Resnet50 26 | 27 | from tqdm import tqdm 28 | from utils import collate_list, detach_and_clone, move_to 29 | import torch.optim as optim 30 | from torch.utils.data import Dataset, DataLoader 31 | from torch_geometric.data import Dataset as DatasetGeometric, DataLoader as DataLoaderGeometric 32 | 33 | from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 34 | from PIL import Image 35 | from dataset_baseline import iWildCamOTTDataset 36 | from pytorchtools import EarlyStopping 37 | 38 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 39 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 40 | 41 | def make_infinite(dataloader): 42 | while True: 43 | yield from dataloader 44 | 45 | # @profile 46 | def train(train_loader, model, optimizer, writer, args, epoch_id, scheduler): 47 | model.train() 48 | torch.set_grad_enabled(True) 49 | 50 | epoch_y_true = [] 51 | epoch_y_pred = [] 52 | 53 | batch_idx = 0 54 | avg_loss_image_id = 0.0 55 | criterion_ce = nn.CrossEntropyLoss() 56 | 57 | for labeled_batch in tqdm(train_loader): 58 | # image, location, time, species_id = labeled_batch 59 | 60 | image, location, time, species_id = labeled_batch 61 | 62 | image = move_to(image, args.device) 63 | 64 | location = move_to(location, args.device) 65 | time = move_to(time, args.device) 66 | species_id = move_to(species_id, args.device) 67 | 68 | if args.dataset == 'mountain_zebra': 69 | location = None 70 | 71 | outputs = model.forward_ce(None, image, time, location) 72 | 73 | batch_results = { 74 | 'y_true': species_id.cpu(), 75 | 'y_pred': outputs.cpu(), 76 | } 77 | 78 | # compute objective 79 | loss = criterion_ce(batch_results['y_pred'], batch_results['y_true']) 80 | batch_results['objective'] = loss.item() 81 | loss.backward() 82 | 83 | avg_loss_image_id += loss.item() 84 | 85 | # update model and logs based on effective batch 86 | optimizer.step() 87 | model.zero_grad() 88 | 89 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 90 | y_pred = detach_and_clone(batch_results['y_pred']) 91 | y_pred = y_pred.argmax(-1) 92 | 93 | epoch_y_pred.append(y_pred) 94 | 95 | batch_idx += 1 96 | if args.debug: 97 | break 98 | 99 | # x = avg_loss_image_id/(batch_idx+1) 100 | # print(x) 101 | 102 | if scheduler is not None: 103 | scheduler.step() 104 | 105 | avg_loss_image_id = avg_loss_image_id/len(train_loader) 106 | print('train/avg_loss = {}'.format(avg_loss_image_id)) 107 | writer.add_scalar('loss/train', avg_loss_image_id, epoch_id) 108 | 109 | epoch_y_pred = collate_list(epoch_y_pred) 110 | epoch_y_true = collate_list(epoch_y_true) 111 | 112 | metrics = [ 113 | Accuracy(prediction_fn=None), 114 | Recall(prediction_fn=None, average='macro'), 115 | F1(prediction_fn=None, average='macro'), 116 | ] 117 | 118 | results = {} 119 | 120 | for i in range(len(metrics)): 121 | results.update({ 122 | 
**metrics[i].compute(epoch_y_pred, epoch_y_true), 123 | }) 124 | 125 | 126 | results['epoch'] = epoch_id 127 | print(f'Train epoch {epoch_id}, Average acc: {results[metrics[0].agg_metric_field]*100.0:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100.0:.2f}') 128 | 129 | writer.add_scalar('acc/train', results[metrics[0].agg_metric_field]*100.0, epoch_id) 130 | writer.add_scalar('f1_macro/train', results[metrics[2].agg_metric_field]*100.0, epoch_id) 131 | 132 | return epoch_y_pred, epoch_y_true 133 | 134 | def evaluate(model, val_loader, optimizer, early_stopping, epoch_id, writer, args): 135 | model.eval() 136 | torch.set_grad_enabled(False) 137 | criterion = nn.CrossEntropyLoss() 138 | 139 | epoch_y_true = [] 140 | epoch_y_pred = [] 141 | 142 | batch_idx = 0 143 | avg_loss_image_id = 0.0 144 | for labeled_batch in tqdm(val_loader): 145 | image, location, time, species_id = labeled_batch 146 | 147 | image = move_to(image, args.device) 148 | 149 | location = move_to(location, args.device) 150 | 151 | time = move_to(time, args.device) 152 | species_id = move_to(species_id, args.device) 153 | 154 | if args.dataset == 'mountain_zebra': 155 | location = None 156 | 157 | outputs = model.forward_ce(None, image, time, location) 158 | 159 | batch_results = { 160 | 'y_true': species_id.cpu(), 161 | 'y_pred': outputs.cpu(), 162 | } 163 | 164 | batch_results['objective'] = criterion(batch_results['y_pred'], batch_results['y_true']).item() 165 | avg_loss_image_id += batch_results['objective'] 166 | 167 | epoch_y_true.append(detach_and_clone(batch_results['y_true'])) 168 | y_pred = detach_and_clone(batch_results['y_pred']) 169 | y_pred = y_pred.argmax(-1) 170 | 171 | epoch_y_pred.append(y_pred) 172 | 173 | batch_idx += 1 174 | if args.debug: 175 | break 176 | 177 | epoch_y_pred = collate_list(epoch_y_pred) 178 | epoch_y_true = collate_list(epoch_y_true) 179 | 180 | metrics = [ 181 | Accuracy(prediction_fn=None), 182 | Recall(prediction_fn=None, average='macro'), 183 | F1(prediction_fn=None, average='macro'), 184 | ] 185 | 186 | results = {} 187 | 188 | for i in range(len(metrics)): 189 | results.update({ 190 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 191 | }) 192 | 193 | results['epoch'] = epoch_id 194 | 195 | avg_loss_image_id = avg_loss_image_id/len(val_loader) 196 | 197 | print('val/avg_loss = {}'.format(avg_loss_image_id)) 198 | writer.add_scalar('loss/val', avg_loss_image_id, epoch_id) 199 | 200 | writer.add_scalar('acc/val', results[metrics[0].agg_metric_field]*100.0, epoch_id) 201 | writer.add_scalar('f1_macro/val', results[metrics[2].agg_metric_field]*100.0, epoch_id) 202 | 203 | print(f'Eval. epoch {epoch_id}, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100:.2f}') 204 | 205 | return epoch_y_pred, epoch_y_true 206 | 207 | 208 | def _get_id(dict, key): 209 | id = dict.get(key, None) 210 | if id is None: 211 | id = len(dict) 212 | dict[key] = id 213 | return id 214 | 215 | def generate_target_list(data, entity2id): 216 | sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']] 217 | sub = list(sub['t']) 218 | categories = [] 219 | for item in tqdm(sub): 220 | if entity2id[str(int(float(item)))] not in categories: 221 | categories.append(entity2id[str(int(float(item)))]) 222 | # print('categories = {}'.format(categories)) 223 | print("No. 
of target categories = {}".format(len(categories))) 224 | return torch.tensor(categories, dtype=torch.long).unsqueeze(-1) 225 | 226 | def calc_agg_results(epoch_y_pred_ilt, epoch_y_true_ilt, epoch_y_pred_it, epoch_y_true_it): 227 | epoch_y_pred_overall, epoch_y_true_overall = epoch_y_pred_ilt.tolist(), epoch_y_true_ilt.tolist() 228 | 229 | epoch_y_pred_overall.extend(epoch_y_pred_it.tolist()) 230 | epoch_y_true_overall.extend(epoch_y_true_it.tolist()) 231 | 232 | epoch_y_pred_overall, epoch_y_true_overall = torch.tensor(epoch_y_pred_overall), torch.tensor(epoch_y_true_overall) 233 | 234 | metrics = [ 235 | Accuracy(prediction_fn=None), 236 | Recall(prediction_fn=None, average='macro'), 237 | F1(prediction_fn=None, average='macro'), 238 | ] 239 | 240 | results = {} 241 | 242 | for i in range(len(metrics)): 243 | results.update({ 244 | **metrics[i].compute(epoch_y_pred_overall, epoch_y_true_overall), 245 | }) 246 | 247 | return results, metrics 248 | 249 | ''' 250 | CUDA_VISIBLE_DEVICES=5 python run_kge_model_st.py --n_epochs 1 251 | ''' 252 | 253 | def main(): 254 | 255 | parser = argparse.ArgumentParser() 256 | parser.add_argument('--dataset', choices=['iwildcam', 'mountain_zebra'], default='iwildcam') 257 | parser.add_argument('--data-dir', type=str, default='iwildcam_v2.0/') 258 | parser.add_argument('--img-dir', type=str, default='iwildcam_v2.0/imgs/') 259 | parser.add_argument('--iwildcam-image-h5-path', type=str, default='/local/scratch/pahuja.9/iwildcam2020_images.h5') 260 | parser.add_argument('--batch_size', type=int, default=16) 261 | parser.add_argument('--n_epochs', type=int, default=12) 262 | parser.add_argument('--num-workers', type=int, default=4) 263 | parser.add_argument('--lr', type=float, default=1e-3, help='default lr for all parameters') 264 | parser.add_argument('--loc-lr', type=float, default=1e-3, help='lr for location embedding') 265 | parser.add_argument('--time-lr', type=float, default=1e-3, help='lr for time embedding') 266 | parser.add_argument('--weight_decay', type=float, default=0.0) 267 | parser.add_argument('--device', type=int, nargs='+', default=[0]) 268 | parser.add_argument('--seed', type=int, default=813765) 269 | parser.add_argument('--save-dir', type=str, default='ckpts/toy/') 270 | parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt') 271 | parser.add_argument('--start-epoch', type=int, default=0, help='epoch id to restore model') 272 | parser.add_argument('--early-stopping-patience', type=int, default=5, help='early stop if metric does not improve for x epochs') 273 | parser.add_argument('--debug', action='store_true') 274 | parser.add_argument('--no-cuda', action='store_true') 275 | parser.add_argument('--use-loss-es', action='store_true', help='use val. loss for early stopping') 276 | parser.add_argument('--add-inverse-rels', action='store_true', help='add inverse relations for R-GCN/CompGCN') 277 | 278 | parser.add_argument('--optimizer', choices=['adam', 'adamw'], default='adam') 279 | 280 | parser.add_argument('--embedding-dim', type=int, default=512) 281 | parser.add_argument('--location_input_dim', type=int, default=2) 282 | parser.add_argument('--time_input_dim', type=int, default=2, help='2 corresponds to hour and month. 
Change to 1 for just hour or month.') 283 | parser.add_argument('--location_time_input_dim', type=int, default=3) 284 | parser.add_argument('--mlp_location_numlayer', type=int, default=3) 285 | parser.add_argument('--mlp_time_numlayer', type=int, default=3) 286 | parser.add_argument('--mlp_location_time_numlayer', type=int, default=3) 287 | parser.add_argument('--loc-loss-coeff', type=float, default=1e0) 288 | parser.add_argument('--num-neg-frac', type=float, default=0.2) 289 | 290 | parser.add_argument('--use-distmult-model', action='store_true') 291 | parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50', 'inc-resnet-v2'], default='resnet50') 292 | parser.add_argument('--kg-embed-model', choices=['distmult', 'conve'], default='distmult') 293 | parser.add_argument('--use-subtree', action='store_true', help='use truncated OTT') 294 | parser.add_argument('--omit-taxon-ids', action='store_true', help='omit taxon ids in embedding') 295 | parser.add_argument('--use-h5', action='store_true', help='use hdf5 instead of raw images') 296 | parser.add_argument('--use-data-subset', action='store_true') 297 | parser.add_argument('--subset-size', type=int, default=10) 298 | parser.add_argument('--use-bce-for-location', action='store_true', help='use BCE loss for location') 299 | parser.add_argument('--use-ce-for-location', action='store_true', help='use CE loss for location') 300 | parser.add_argument('--use-bce-for-time', action='store_true', help='use BCE loss for time') 301 | parser.add_argument('--use-ce-for-time', action='store_true', help='use CE loss for time') 302 | parser.add_argument('--use-bce-for-location-time', action='store_true', help='use BCE loss for location-time') 303 | parser.add_argument('--use-ce-for-location-time', action='store_true', help='use CE loss for location-time') 304 | parser.add_argument('--use-cluster-centroids-for-location', action='store_true', help='use 6 cluster centroids for location') 305 | parser.add_argument('--use-learned-loc-embed', action='store_true', help='use learned embedding for location') 306 | 307 | parser.add_argument('--exclude-image-id', action='store_true', help='exclude image-id for training') 308 | parser.add_argument('--taxonomy-type', choices=['ott', 'standard'], default='ott') 309 | parser.add_argument('--add-reverse-id-id', action='store_true', help='add reversed triples for id-id') 310 | 311 | # options for img-time 312 | parser.add_argument('--only-hour', action='store_true', help='use only hour for img-time triples') 313 | parser.add_argument('--only-month', action='store_true', help='use only month for img-time triples') 314 | parser.add_argument('--use-circular-space', action='store_true', help='use circular space for hour and month') 315 | 316 | # ConvE hyperparams 317 | parser.add_argument('--embedding-shape1', type=int, default=20, help='The first dimension of the reshaped 2D embedding. The second dimension is inferred. Default: 20') 318 | parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.') 319 | parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.') 320 | parser.add_argument('--feat-drop', type=float, default=0.2, help='Dropout for the convolutional features. Default: 0.2.') 321 | parser.add_argument('--use-bias', action='store_true', default=True, help='Use a bias in the convolutional layer.
Default: True') 322 | parser.add_argument('--hidden-size', type=int, default=9728, help='The size of the hidden layer. The required size changes with the size of the embeddings. Default: 9728 (embedding size 200).') 323 | 324 | # experimental 325 | parser.add_argument('--use-location-breakdown', action='store_true', help='break down location into hh,mm,ss for lat.,long.') 326 | parser.add_argument('--use-prop-sampling', action='store_true', help='mix all dataloader samples except (img, id) proportionally to dataset size') 327 | parser.add_argument('--use-uniform-sampling', action='store_true', help='mix all dataloader samples except (img, id) uniformly') 328 | parser.add_argument('--freeze-mlp', action='store_true') 329 | 330 | args = parser.parse_args() 331 | 332 | 333 | print('args = {}'.format(args)) 334 | args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu') 335 | 336 | 337 | # Set random seed 338 | torch.manual_seed(args.seed) 339 | np.random.seed(args.seed) 340 | random.seed(args.seed) 341 | 342 | writer = SummaryWriter(log_dir=args.save_dir) 343 | 344 | # datacsv = pd.read_csv("../camera_trap/data_triples.csv") 345 | # datacsv = datacsv.loc[(datacsv["datatype_h"] == "image") & (datacsv["datatype_t"] == "id")] 346 | 347 | if args.dataset == 'iwildcam': 348 | if args.use_subtree: 349 | if args.taxonomy_type == 'standard': 350 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree_standard.csv'), low_memory=False) 351 | else: 352 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False) 353 | else: 354 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'data_triples.csv'), low_memory=False) 355 | 356 | if args.use_subtree: 357 | if args.taxonomy_type == 'standard': 358 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree_standard_st.json') 359 | else: 360 | entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree_st.json') 361 | else: 362 | entity_id_file = os.path.join(args.data_dir, 'entity2id_st.json') 363 | else: 364 | datacsv = pd.read_csv(os.path.join(args.data_dir, 'data_triples.csv'), low_memory=False) 365 | entity_id_file = os.path.join(args.data_dir, 'entity2id.json') 366 | 367 | 368 | if not os.path.exists(entity_id_file): 369 | entity2id = {} # each triple type has its own entity2id 370 | 371 | for i in tqdm(range(datacsv.shape[0])): 372 | if args.omit_taxon_ids and (datacsv.iloc[i,1] != 'image' or datacsv.iloc[i,4] != 'id'): 373 | continue 374 | 375 | if datacsv.iloc[i,1] == "id": 376 | _get_id(entity2id, str(int(float(datacsv.iloc[i,0])))) 377 | 378 | if datacsv.iloc[i,-2] == "id": 379 | _get_id(entity2id, str(int(float(datacsv.iloc[i,-3])))) 380 | json.dump(entity2id, open(entity_id_file, 'w')) 381 | else: 382 | entity2id = json.load(open(entity_id_file, 'r')) 383 | 384 | num_ent_id = len(entity2id) 385 | 386 | print('len(entity2id) = {}'.format(len(entity2id))) 387 | 388 | target_list = generate_target_list(datacsv, entity2id) 389 | 390 | train_ILT_dataset = iWildCamOTTDataset(datacsv, 'train', args, entity2id, target_list, disjoint=False, is_train=True) 391 | print('len(train_ILT_dataset) = {}'.format(len(train_ILT_dataset))) 392 | 393 | 394 | val_ILT_dataset = iWildCamOTTDataset(datacsv, 'val', args, entity2id, target_list, disjoint=True) 395 | print('len(val_ILT_dataset) = {}'.format(len(val_ILT_dataset))) 396 | 397 | 398 | model_kwargs = {} 399 | if args.kg_embed_model == 'conve': 400 | model_kwargs['drop_last'] = True 401 | 402 |
train_loader = DataLoader( 403 | train_ILT_dataset, 404 | shuffle=True, # Shuffle training dataset 405 | sampler=None, 406 | batch_size=args.batch_size, 407 | num_workers=args.num_workers, 408 | pin_memory=True, 409 | **model_kwargs) 410 | 411 | val_loader = DataLoader( 412 | val_ILT_dataset, 413 | shuffle=False, # Do not shuffle eval datasets 414 | sampler=None, 415 | batch_size=args.batch_size, 416 | num_workers=args.num_workers, 417 | pin_memory=True) 418 | 419 | 420 | kwargs = {} 421 | 422 | model = MKGE(args, num_ent_id, target_list, args.device, **kwargs) 423 | 424 | model.to(args.device) 425 | 426 | if args.freeze_mlp: 427 | for param in model.mlp.parameters(): 428 | param.requires_grad = False 429 | 430 | early_stopping = EarlyStopping(patience=args.early_stopping_patience, verbose=True, ckpt_path=os.path.join(args.save_dir, 'model.pt'), best_ckpt_path=os.path.join(args.save_dir, 'best_model.pt')) 431 | 432 | params_diff_lr = ['ent_embedding', 'image_embedding', 'location_embedding', 'time_embedding'] 433 | 434 | optimizer_grouped_parameters = [ 435 | {"params": [param for p_name, param in model.named_parameters() if not any([x in p_name for x in params_diff_lr])]}, 436 | {"params": model.ent_embedding.parameters(), "lr": args.lr}, 437 | {"params": model.image_embedding.parameters(), "lr": 3e-5}, 438 | {"params": model.location_embedding.parameters(), "lr": args.loc_lr}, 439 | {"params": model.time_embedding.parameters(), "lr": args.time_lr}, 440 | ] 441 | 442 | n_params_model = sum(torch.numel(param) for p_name, param in model.named_parameters()) 443 | n_params_optimizer = sum([sum([torch.numel(x) for x in group['params']]) for group in optimizer_grouped_parameters]) 444 | 445 | print('n_params_model = {}'.format(n_params_model)) 446 | print('n_params_optimizer = {}'.format(n_params_optimizer)) 447 | 448 | assert n_params_model == n_params_optimizer 449 | 450 | if args.optimizer == 'adam': 451 | optimizer = optim.Adam( 452 | optimizer_grouped_parameters, 453 | lr=args.lr, 454 | weight_decay=args.weight_decay) 455 | scheduler = None 456 | 457 | elif args.optimizer == 'adamw': 458 | # optimizer = optim.AdamW(model.parameters(), lr=args.lr) 459 | optimizer = optim.AdamW( 460 | optimizer_grouped_parameters, 461 | lr=args.lr) 462 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs*5, eta_min=1e-6) 463 | 464 | 465 | else: 466 | raise NotImplementedError 467 | 468 | # restore from ckpt 469 | if args.ckpt_path: 470 | print('ckpt loaded...') 471 | ckpt = torch.load(args.ckpt_path) 472 | model.load_state_dict(ckpt['model'], strict=False) 473 | optimizer.load_state_dict(ckpt['dense_optimizer']) 474 | 475 | for epoch_id in range(args.start_epoch, args.n_epochs): 476 | print('\nEpoch [%d]:\n' % epoch_id) 477 | 478 | # First run training 479 | epoch_y_pred, epoch_y_true = train(train_loader, model, optimizer, writer, args, epoch_id, scheduler) 480 | 481 | metrics = [ 482 | Accuracy(prediction_fn=None), 483 | Recall(prediction_fn=None, average='macro'), 484 | F1(prediction_fn=None, average='macro'), 485 | ] 486 | 487 | results = {} 488 | 489 | for i in range(len(metrics)): 490 | results.update({ 491 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 492 | }) 493 | 494 | print(f'Train epoch {epoch_id}, Average acc: {results[metrics[0].agg_metric_field]*100.0:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100.0:.2f}') 495 | 496 | writer.add_scalar('acc/train', results[metrics[0].agg_metric_field]*100.0, epoch_id) 497 | writer.add_scalar('f1_macro/train', 
results[metrics[2].agg_metric_field]*100.0, epoch_id) 498 | 499 | # Then run val 500 | epoch_y_pred, epoch_y_true = evaluate(model, val_loader, optimizer, early_stopping, epoch_id, writer, args) 501 | 502 | results = {} 503 | 504 | for i in range(len(metrics)): 505 | results.update({ 506 | **metrics[i].compute(epoch_y_pred, epoch_y_true), 507 | }) 508 | 509 | if args.use_loss_es: 510 | early_stopping(-1*results[metrics[2].agg_metric_field], model, optimizer) 511 | else: 512 | early_stopping(-1*results[metrics[0].agg_metric_field], model, optimizer) 513 | 514 | writer.add_scalar('acc/val', results[metrics[0].agg_metric_field]*100.0, epoch_id) 515 | writer.add_scalar('f1_macro/val', results[metrics[2].agg_metric_field]*100.0, epoch_id) 516 | 517 | print(f'Eval. epoch {epoch_id}, Average acc: {results[metrics[0].agg_metric_field]*100:.2f}, F1 macro: {results[metrics[2].agg_metric_field]*100:.2f}') 518 | 519 | if early_stopping.early_stop: 520 | print("Early stopping...") 521 | break 522 | 523 | writer.close() 524 | 525 | if __name__=='__main__': 526 | main() 527 | --------------------------------------------------------------------------------