├── .gitignore ├── README.md ├── data ├── aminer_dataset.py ├── create_aminer_dataset.py ├── data.py ├── parse_aminer_citation.py └── rank_aminer.py ├── diagnostics └── ppi_test.py ├── main.py ├── maml.py ├── models ├── __init__.py ├── autoencoder.py ├── layers.py └── models.py ├── plotting ├── plot_comet.py └── plot_wandb.py ├── scripts ├── run_adamic_baseline.sh ├── run_baselines.sh ├── run_hyperparam.sh ├── run_maml.sh ├── run_plotter.sh ├── run_plotter_aminer_grad_wandb.sh ├── run_plotter_enzymes.sh ├── run_plotter_firstmmdb.sh ├── run_plotter_firstmmdb_grad_wandb.sh ├── run_plotter_ppi_grad_wandb.sh ├── run_plotter_ppi_wandb.sh ├── run_plotter_reddit.sh ├── run_ppi_best_gs.sh └── run_random.sh ├── utils └── utils.py └── vgae.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | *.pkl 3 | settings.json 4 | *.env 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Meta-Learning on Graphs ### 2 | This repository contains code for the arXiv preprint: 3 | "Link Prediction from Sparse Data Using Meta Learning" 4 | by: Avishek Joey Bose, Ankit Jain, Piero Molino, William L. Hamilton 5 | 6 | arXiv link: https://arxiv.org/abs/1912.09867 7 | If this repository is helpful in your research, please consider citing us. 8 | 9 | ``` 10 | @article{bose2019meta, 11 | title={Meta-Graph: Few Shot Link Prediction via Meta Learning}, 12 | author={Bose, Avishek Joey and Jain, Ankit and Molino, Piero and Hamilton, William L}, 13 | journal={arXiv preprint arXiv:1912.09867}, 14 | year={2019} 15 | } 16 | ``` 17 | 18 | Some Requirements: 19 | - pytorch geometric 20 | - scikit-learn==0.22 21 | - comet_ml 22 | - wandb 23 | - grakel 24 | - torchviz 25 | 26 | This codebase has many command-line flags, so it is important to familiarize yourself with all of them before running experiments.
27 | The easiest accesspoint to the codebase is using some prepared scripts in the scripts folder. 28 | 29 | Here are some sample commands: 30 | 31 | ## Running Graph Signature on PPI 32 | `python3 main.py --meta_train_edge_ratio=0.1 --model='VGAE' 33 | --encoder='GraphSignature' --epochs=46 --use_gcn_sig --concat_fixed_feats 34 | --inner_steps=2 --inner-lr=2.24e-3 --meta-lr=2.727e-3 --clip_grad 35 | --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 36 | --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.1'` 37 | 38 | This command will run the Meta-Graph algorithm using 10% training edges for each graph. 39 | It will also use the default GraphSignature function as the encoder in a VGAE. The `--use_gcn_sig` 40 | flag will force the GraphSignature to use a GCN style signature function and finally 41 | order 2 will perform second order optimization. 42 | -------------------------------------------------------------------------------- /data/aminer_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from os import listdir 5 | from os.path import isfile, join 6 | import ipdb 7 | from tqdm import tqdm 8 | from operator import itemgetter 9 | import torch 10 | from torch_geometric.data import Data, Dataset, InMemoryDataset, download_url, extract_zip 11 | import networkx as nx 12 | import numpy as np 13 | from torch_geometric.data.dataset import files_exist 14 | from torch_geometric.utils import remove_self_loops 15 | 16 | class Aminer_Dataset(Dataset): 17 | def __init__(self, root, name, k_core, reprocess=False, ego=False, transform=None, pre_transform=None): 18 | self.name = name 19 | self.raw_graphs_path = join(root,'raw') 20 | self.k_core = k_core 21 | self.reprocess = reprocess 22 | self.ego = ego 23 | print("Path is %s" %(self.raw_graphs_path)) 24 | onlyfiles = [f for f in listdir(self.raw_graphs_path) if isfile(join(self.raw_graphs_path, f))] 25 | 
self.raw_files = onlyfiles 26 | super(Aminer_Dataset, self).__init__(root, transform, pre_transform) 27 | 28 | @property 29 | def raw_file_names(self): 30 | if files_exist(self.processed_paths): # pragma: no cover 31 | return 32 | onlyfiles = [f for f in listdir(self.raw_graphs_path) if isfile(join(self.raw_graphs_path, f))] 33 | self.raw_files = onlyfiles 34 | return onlyfiles 35 | 36 | @property 37 | def processed_file_names(self): 38 | onlyfiles = [f for f in listdir(self.processed_dir) if isfile(join(self.processed_dir, f))] 39 | return onlyfiles 40 | 41 | def __len__(self): 42 | return len(self.processed_file_names) 43 | 44 | def _process(self): 45 | if self.reprocess: 46 | self.process() 47 | if files_exist(self.processed_paths): # pragma: no cover 48 | return 49 | print('Processing...') 50 | makedirs(self.processed_dir) 51 | self.process() 52 | print('Done!') 53 | 54 | def _download(self): 55 | pass 56 | 57 | def download(self): 58 | if not os.path.exists(self.root): 59 | os.mkdir(self.root) 60 | 61 | if not os.path.exists(self.raw_dir): 62 | os.mkdir(self.raw_dir) 63 | 64 | if not os.path.exists(self.processed_dir): 65 | os.mkdir(self.processed_dir) 66 | print("Manually copy *.pkl files to AMINER/raw dir") 67 | 68 | def process(self): 69 | i = 0 70 | print("Path is %s" %(self.raw_graphs_path)) 71 | onlyfiles = [f for f in listdir(self.raw_graphs_path) if isfile(join(self.raw_graphs_path, f))] 72 | self.raw_files = onlyfiles 73 | G_len_list = [] 74 | print("Beginning to Process AMINER graphs") 75 | for j,raw_path in tqdm(enumerate(self.raw_files),total=len(self.raw_files)): 76 | path = path = osp.join(self.raw_dir, raw_path) 77 | # Read data from `raw_path`. 
78 | G = nx.read_gpickle(path) 79 | # Compute K-core of Graph 80 | G = nx.k_core(G, k=self.k_core) 81 | # Re-index nodes from 0 to len(G) 82 | mapping = dict(zip(G, range(0, len(G)))) 83 | G = nx.relabel_nodes(G, mapping) 84 | if self.ego: 85 | # Randomly sample node 86 | node_and_degree = G.degree() 87 | rand_node = np.random.choice(len(G),1) 88 | G = nx.ego_graph(G,rand_node[0],radius=2) 89 | mapping = dict(zip(G, range(0, len(G)))) 90 | G = nx.relabel_nodes(G, mapping) 91 | 92 | all_embeddings = nx.get_node_attributes(G,'emb') 93 | x = np.array([val for (key,val) in all_embeddings.items()]) 94 | x = torch.from_numpy(x).to(torch.float) 95 | edge_index = torch.tensor(list(G.edges)).t().contiguous() 96 | edge_index, _ = remove_self_loops(edge_index) 97 | data = Data(edge_index=edge_index, x=x) 98 | G_len_list.append(len(G)) 99 | 100 | if self.pre_filter is not None and not self.pre_filter(data): 101 | continue 102 | 103 | if self.pre_transform is not None: 104 | data = self.pre_transform(data) 105 | 106 | if self.ego: 107 | ego_dir = self.processed_dir + '/ego/' 108 | if not os.path.exists(ego_dir): 109 | os.makedirs(ego_dir) 110 | torch.save(data, ego_dir + 'ego_data_{}.pt'.format(i)) 111 | else: 112 | torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i))) 113 | i += 1 114 | 115 | print("Avg %d, Min G %d, Max G %d", (sum(G_len_list)/len(G_len_list),\ 116 | min(G_len_list), max(G_len_list))) 117 | 118 | def get(self, idx): 119 | if self.ego: 120 | data = torch.load(osp.join(self.processed_dir, 'ego/ego_data_{}.pt'.format(idx))) 121 | else: 122 | data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx))) 123 | return data 124 | 125 | 126 | -------------------------------------------------------------------------------- /data/create_aminer_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import ipdb 4 | from tqdm import tqdm 5 | import argparse 6 | from os 
import listdir 7 | from os.path import isfile, join 8 | import pickle 9 | import joblib 10 | from collections import Counter 11 | from shutil import copyfile 12 | import networkx as nx 13 | import spacy 14 | import nltk 15 | import numpy as np 16 | 17 | nltk.download('stopwords') 18 | nltk_stopwords = nltk.corpus.stopwords.words('english') 19 | data_path = '/home/joey.bose/dblp_papers_v11.txt' 20 | save_path_base = '/home/joey.bose/aminer_data/' 21 | load_path_rank_base = '/home/joey.bose/aminer_data_ranked/fos/' 22 | save_path_graph_base = '/home/joey.bose/aminer_data_ranked/graphs/' 23 | raw_save_path = '/home/joey.bose/aminer_data_ranked/aminer_raw.txt' 24 | spacy_nlp = spacy.load('en_core_web_sm') 25 | glove_path = '/home/joey.bose/docker_temp/meta-graph/meta-graph/glove.840B.300d.txt' 26 | 27 | class Lang: 28 | def __init__(self): 29 | super(Lang, self).__init__() 30 | self.word2index = {} 31 | self.word2count = {} 32 | self.index2word = {} 33 | self.n_words = 0 # Count default tokens 34 | 35 | def index_words(self, sentence): 36 | for word in sentence: 37 | self.index_word(word) 38 | 39 | def index_word(self, word): 40 | if word not in self.word2index: 41 | self.word2index[word] = self.n_words 42 | self.word2count[word] = 1 43 | self.index2word[self.n_words] = word 44 | self.n_words += 1 45 | else: 46 | self.word2count[word] += 1 47 | 48 | def gen_embeddings(vocab, file, emb_size, emb_dim): 49 | """ 50 | Generate an initial embedding matrix for word_dict. 51 | If an embedding file is not given or a word is not in the embedding file, 52 | a randomly initialized vector will be used. 
53 | """ 54 | # embeddings = np.random.randn(vocab.n_words, emb_size) * 0.01 55 | embeddings = np.zeros((vocab.n_words, emb_size)) 56 | print('Embeddings: %d x %d' % (vocab.n_words, emb_size)) 57 | if file is not None: 58 | print('Loading embedding file: %s' % file) 59 | pre_trained = 0 60 | for line in open(file).readlines(): 61 | sp = line.split() 62 | if(len(sp) == emb_dim + 1): 63 | if sp[0] in vocab.word2index: 64 | pre_trained += 1 65 | embeddings[vocab.word2index[sp[0]]] = [float(x) for x in sp[1:]] 66 | else: 67 | print(sp[0]) 68 | print('Pre-trained: %d (%.2f%%)' % (pre_trained, pre_trained * 100.0 / vocab.n_words)) 69 | return embeddings 70 | 71 | def process_raw_abstracts(vocab): 72 | with open(raw_save_path,"r",encoding="utf8") as f: 73 | for line in tqdm(f,total=13304586): 74 | tokens = nltk.tokenize.word_tokenize(line) 75 | tokens = [token for token in tokens if not token in nltk_stopwords] 76 | vocab.index_words(tokens) 77 | 78 | def get_node_embed(text,vocab): 79 | sum_embed = 0 80 | for word in text: 81 | embed = embeddings[vocab.word2index[word]] 82 | sum_embed += embed 83 | return sum_embed 84 | def check_graph(G): 85 | total_nodes = len(G.nodes) 86 | no_emb_nodes = 0 87 | nodes_to_delete = [] 88 | for node_str in G.nodes: 89 | try: 90 | emb = G.node[node_str]['emb'] 91 | except: 92 | no_emb_nodes += 1 93 | nodes_to_delete.append(node_str) 94 | print("%d Nodes and %d missing nodes in G " %(total_nodes, no_emb_nodes)) 95 | G.remove_nodes_from(nodes_to_delete) 96 | return G 97 | 98 | def process_line(G, line, vocab=None): 99 | try: 100 | fos = data['fos'] 101 | abstract = data['indexed_abstract'] 102 | paper_id = data['id'] 103 | references_id = data['references'] 104 | text = list(abstract['InvertedIndex'].keys()) 105 | text =" ".join(text) 106 | if args.process_raw: 107 | with open(raw_save_path,"a+") as f: 108 | f.write(text) 109 | f.write('\n') 110 | 111 | '''Create Node Embedding if Node doesn't exist ''' 112 | if vocab is not None: 113 | 
tokens = nltk.tokenize.word_tokenize(text) 114 | tokens = [token for token in tokens if not token in nltk_stopwords] 115 | node_emb = get_node_embed(tokens,vocab) 116 | 117 | for field in fos: 118 | name = field['name'] 119 | for ref in references_id: 120 | G.add_edge(paper_id, ref) 121 | G.node[paper_id]['emb'] = node_emb 122 | except: 123 | return G 124 | 125 | return G 126 | 127 | if __name__ == '__main__': 128 | """ 129 | Create Aminer-Citation v-11 Graphs 130 | """ 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('--topk', type=int, default='100') 133 | parser.add_argument("--process_raw", action="store_true", default=False, 134 | help='Process Raw Data') 135 | parser.add_argument("--make_vocab", action="store_true", default=False, 136 | help='Create Vocab from the raw abstract data') 137 | args = parser.parse_args() 138 | onlyfiles = [f for f in listdir(load_path_rank_base) if isfile(join(load_path_rank_base, f))] 139 | vocab = Lang() 140 | if args.make_vocab: 141 | process_raw_abstracts(vocab) 142 | joblib.dump(vocab, "aminer_100_vocab.pkl") 143 | print("Done generating vocab") 144 | embeddings = gen_embeddings(vocab,file=glove_path,emb_size=300,emb_dim=300) 145 | joblib.dump(embeddings, "aminer_100_embed.pkl") 146 | print("Done") 147 | exit() 148 | else: 149 | vocab = joblib.load("aminer_100_vocab.pkl") 150 | embeddings = joblib.load("aminer_100_embed.pkl") 151 | 152 | for i, file_ in tqdm(enumerate(onlyfiles),total=len(onlyfiles)): 153 | file_path = load_path_rank_base + file_ 154 | G = nx.Graph() 155 | with open(file_path,'r', encoding="utf8") as f: 156 | for line in f: 157 | data = json.loads(line) 158 | G = process_line(G,data,vocab) 159 | G = check_graph(G) 160 | print("%s has %d Nodes and %d edges" %(file_,len(G),len(G.edges))) 161 | if not os.path.exists(save_path_graph_base): 162 | os.mkdir(save_path_graph_base) 163 | save_path_graph = save_path_graph_base + file_.split('.')[0] + '_graph.pkl' 164 | 
nx.write_gpickle(G,save_path_graph) 165 | 166 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import os 3 | import os.path as osp 4 | import json 5 | import torch 6 | import numpy as np 7 | import ipdb 8 | import torch_geometric.transforms as T 9 | from torch_geometric.data import DataLoader,DataListLoader 10 | import ssl 11 | from utils.utils import calculate_max_nodes_in_dataset, filter_dataset 12 | import urllib 13 | from random import shuffle 14 | from torch_geometric.datasets import Planetoid,PPI,TUDataset 15 | from .aminer_dataset import Aminer_Dataset 16 | 17 | def load_dataset(name,args): 18 | ssl._create_default_https_context = ssl._create_unverified_context 19 | path = osp.join( 20 | osp.dirname(osp.realpath(__file__)), '..', 'data', name) 21 | args.fail_counter = 0 22 | args.resplit = True 23 | if name == 'PPI': 24 | train_dataset = PPI(path, split='train',transform=T.NormalizeFeatures()) 25 | val_dataset = PPI(path, split='val',transform=T.NormalizeFeatures()) 26 | test_dataset = PPI(path, split='test',transform=T.NormalizeFeatures()) 27 | args.num_features = train_dataset.num_features 28 | max_nodes = calculate_max_nodes_in_dataset(train_dataset + val_dataset + test_dataset,\ 29 | args.min_nodes) 30 | total_graphs = len(train_dataset) + len(val_dataset) + len(test_dataset) 31 | print("Total Graphs in PPI %d" %(total_graphs)) 32 | else: 33 | if name == 'ENZYMES': 34 | dataset = list(TUDataset(path,name,use_node_attr=True,\ 35 | transform=T.NormalizeFeatures())) 36 | shuffle(dataset) 37 | elif name =='REDDIT-MULTI-12K': 38 | dataset = list(TUDataset(path,name)) 39 | shuffle(dataset) 40 | max_nodes = calculate_max_nodes_in_dataset(dataset,args.min_nodes) 41 | dataset = filter_dataset(dataset,args.min_nodes,max_nodes) 42 | args.feats = 
torch.randn(max_nodes,args.num_fixed_features,requires_grad=False) 43 | assert(args.use_fixed_feats or args.use_same_fixed_feats) 44 | elif name =='FIRSTMM_DB': 45 | dataset = list(TUDataset(path,name)) 46 | shuffle(dataset) 47 | max_nodes = calculate_max_nodes_in_dataset(dataset,args.min_nodes) 48 | elif name =='DD': 49 | dataset = list(TUDataset(path,name)) 50 | shuffle(dataset) 51 | max_nodes = calculate_max_nodes_in_dataset(dataset,args.min_nodes) 52 | dataset = filter_dataset(dataset,args.min_nodes,args.max_nodes) 53 | elif name =='AMINER': 54 | if args.opus: 55 | path = '/mnt/share/ankit.jain/meta-graph-data/AMINER/' 56 | dataset = Aminer_Dataset(path,name, args.k_core, \ 57 | reprocess=args.reprocess, ego=args.ego) 58 | dataset = [dataset[i] for i in range(0,len(dataset))] 59 | shuffle(dataset) 60 | max_nodes = calculate_max_nodes_in_dataset(dataset,args.min_nodes) 61 | dataset = filter_dataset(dataset,args.min_nodes,args.max_nodes) 62 | elif name == 'Cora': 63 | dataset = Planetoid(path, "Cora", T.NormalizeFeatures()) 64 | else: 65 | raise NotImplementedError 66 | num_graphs = len(dataset) 67 | print("%d Graphs in Dataset" %(num_graphs)) 68 | if num_graphs == 1: 69 | train_dataset = dataset 70 | val_dataset = dataset 71 | test_dataset = dataset 72 | else: 73 | train_cutoff = int(np.round(args.train_ratio*num_graphs)) 74 | val_cutoff = train_cutoff + int(np.round(args.val_ratio*num_graphs)) 75 | train_dataset = dataset[:train_cutoff] 76 | val_dataset = dataset[train_cutoff:val_cutoff] 77 | test_dataset = dataset[val_cutoff:] 78 | try: 79 | args.num_features = train_dataset[0].x.shape[1] 80 | except: 81 | ## TODO: Load Fixed Random Features 82 | print("Using Fixed Features") 83 | args.num_features = args.num_fixed_features 84 | if args.concat_fixed_feats: 85 | args.num_features = args.num_features + args.num_concat_features 86 | print("Node Features: %d" %(args.num_features)) 87 | train_loader = DataListLoader(train_dataset, batch_size=args.train_batch_size, 
shuffle=False) 88 | val_loader = DataListLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False) 89 | test_loader = DataListLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False) 90 | 91 | return train_loader,val_loader,test_loader 92 | 93 | 94 | -------------------------------------------------------------------------------- /data/parse_aminer_citation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import csv 4 | import pandas as pd 5 | import ipdb 6 | from tqdm import tqdm 7 | 8 | data_path = '/home/joey.bose/dblp_papers_v11.txt' 9 | save_path_base = '/home/joey.bose/aminer_data/' 10 | 11 | def process_line(line, fail_counter): 12 | try: 13 | fos = data['fos'] 14 | abstract = data['indexed_abstract'] 15 | for field in fos: 16 | name = field['name'] 17 | save_path = save_path_base + name.replace(" ", "") + '.txt' 18 | 19 | if not os.path.exists(save_path_base): 20 | os.mkdir(save_path_base) 21 | 22 | with open(save_path,"a+") as f: 23 | f.write(json.dumps(line)) 24 | f.write('\n') 25 | except: 26 | fail_counter +=1 27 | if fail_counter % 100 ==0: 28 | print("Failed on a File | Total Fails %d" %(fail_counter)) 29 | return fail_counter 30 | 31 | if __name__ == '__main__': 32 | """ 33 | Parse Aminer-Citation v-11 34 | """ 35 | fail_counter = 0 36 | with open(data_path,'r', encoding="utf8") as f: 37 | for line in tqdm(f,total=4107340): 38 | data = json.loads(line) 39 | fail_counter = process_line(data,fail_counter) 40 | -------------------------------------------------------------------------------- /data/rank_aminer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import ipdb 4 | from tqdm import tqdm 5 | from os import listdir 6 | from os.path import isfile, join 7 | import pickle 8 | import joblib 9 | from collections import Counter 10 | from shutil import copyfile 11 | 12 | data_path = 
'/home/joey.bose/dblp_papers_v11.txt' 13 | save_path_base = '/home/joey.bose/aminer_data/' 14 | save_path_rank_base = '/home/joey.bose/aminer_data_ranked/' 15 | 16 | def file_len(fname): 17 | with open(fname) as f: 18 | for i, l in enumerate(f): 19 | pass 20 | return i + 1 21 | 22 | if __name__ == '__main__': 23 | """ 24 | Rank Aminer-Citation v-11 25 | """ 26 | fail_counter = 0 27 | onlyfiles = [f for f in listdir(save_path_base) if isfile(join(save_path_base, f))] 28 | file_dict = {} 29 | # Rank top 100 biggest files 30 | Topk = 100 31 | 32 | for i, file_ in tqdm(enumerate(onlyfiles),total=len(onlyfiles)): 33 | file_path = save_path_base + file_ 34 | num_lines = file_len(file_path) 35 | file_dict[file_] = num_lines 36 | 37 | topk_files = sorted(file_dict.items(), key=lambda x:-x[1])[:Topk] 38 | print(topk_files) 39 | f = open("aminer_file_dict.pkl","wb") 40 | pickle.dump(file_dict,f) 41 | f.close() 42 | 43 | # Move Files 44 | if not os.path.exists(save_path_rank_base): 45 | os.mkdir(save_path_rank_base) 46 | 47 | for i, file_tuple in tqdm(enumerate(topk_files),total=Topk): 48 | file_name = file_tuple[0] 49 | src_path = save_path_base + file_name 50 | move_path = save_path_rank_base + file_name 51 | copyfile(src_path,move_path) 52 | -------------------------------------------------------------------------------- /diagnostics/ppi_test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import ipdb 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.datasets import PPI 6 | from torch_geometric.data import DataLoader 7 | from torch_geometric.nn import GATConv 8 | from sklearn.metrics import f1_score 9 | 10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') 11 | train_dataset = PPI(path, split='train') 12 | val_dataset = PPI(path, split='test') 13 | test_dataset = PPI(path, split='test') 14 | train_loader = DataLoader(train_dataset, batch_size=1, 
shuffle=True) 15 | val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) 16 | test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) 17 | 18 | 19 | class Net(torch.nn.Module): 20 | def __init__(self): 21 | super(Net, self).__init__() 22 | self.conv1 = GATConv(train_dataset.num_features, 256, heads=4) 23 | self.lin1 = torch.nn.Linear(train_dataset.num_features, 4 * 256) 24 | self.conv2 = GATConv(4 * 256, 256, heads=4) 25 | self.lin2 = torch.nn.Linear(4 * 256, 4 * 256) 26 | self.conv3 = GATConv( 27 | 4 * 256, train_dataset.num_classes, heads=6, concat=False) 28 | self.lin3 = torch.nn.Linear(4 * 256, train_dataset.num_classes) 29 | 30 | def forward(self, x, edge_index): 31 | x = F.elu(self.conv1(x, edge_index) + self.lin1(x)) 32 | x = F.elu(self.conv2(x, edge_index) + self.lin2(x)) 33 | x = self.conv3(x, edge_index) + self.lin3(x) 34 | return x 35 | 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | model = Net().to(device) 39 | loss_op = torch.nn.BCEWithLogitsLoss() 40 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005) 41 | 42 | def train(): 43 | model.train() 44 | 45 | total_loss = 0 46 | for data in train_loader: 47 | num_graphs = data.num_graphs 48 | data.batch = None 49 | data = data.to(device) 50 | optimizer.zero_grad() 51 | loss = loss_op(model(data.x, data.edge_index), data.y) 52 | total_loss += loss.item() * num_graphs 53 | loss.backward() 54 | optimizer.step() 55 | return total_loss / len(train_loader.dataset) 56 | 57 | def test(loader): 58 | model.eval() 59 | 60 | ys, preds = [], [] 61 | for data in loader: 62 | ys.append(data.y) 63 | with torch.no_grad(): 64 | out = model(data.x.to(device), data.edge_index.to(device)) 65 | preds.append((out > 0).float().cpu()) 66 | 67 | y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() 68 | return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 69 | 70 | 71 | for epoch in range(1, 101): 72 | loss = train() 73 | acc = 
test(val_loader) 74 | print('Epoch: {:02d}, Loss: {:.4f}, F1: {:.4f}'.format(epoch, loss, acc)) 75 | -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd 3 | from collections import OrderedDict 4 | from torch.optim import Optimizer 5 | from torch.nn import Module 6 | from typing import Dict, List, Callable, Union 7 | from utils.utils import custom_clip_grad_norm_, val, test, EarlyStopping, monitor_grad_norm, monitor_grad_norm_2, monitor_weight_norm 8 | from torchviz import make_dot 9 | import ipdb 10 | import wandb 11 | from copy import deepcopy 12 | 13 | def replace_grad(parameter_gradients, parameter_name): 14 | def replace_grad_(module): 15 | return parameter_gradients[parameter_name] 16 | 17 | return replace_grad_ 18 | 19 | 20 | def meta_gradient_step(model, 21 | args, 22 | data_batch, 23 | optimiser, 24 | inner_train_steps, 25 | inner_lr, 26 | order, 27 | graph_id, 28 | mode, 29 | inner_avg_auc_list, 30 | inner_avg_ap_list, 31 | epoch, 32 | batch_id, 33 | train, 34 | inner_test_auc_array=None, 35 | inner_test_ap_array=None): 36 | """ 37 | Perform a gradient step on a meta-learner. 
38 | # Arguments 39 | model: Base model of the meta-learner being trained 40 | optimiser: Optimiser to calculate gradient step from loss 41 | loss_fn: Loss function to calculate between predictions and outputs 42 | data_batch: Input samples for all few shot tasks 43 | meta-gradients after applying the update to 44 | inner_train_steps: Number of gradient steps to fit the fast weights during each inner update 45 | inner_lr: Learning rate used to update the fast weights on the inner update 46 | order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the 47 | query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated 48 | weights on the query with respect to the original weights). 49 | graph_id: The ID of graph currently being trained 50 | train: Whether to update the meta-learner weights at the end of the episode. 51 | inner_test_auc_array: Final Test AUC array where we train to convergence 52 | inner_test_ap_array: Final Test AP array where we train to convergence 53 | """ 54 | 55 | task_gradients = [] 56 | task_losses = [] 57 | task_predictions = [] 58 | auc_list = [] 59 | ap_list = [] 60 | torch.autograd.set_detect_anomaly(True) 61 | for idx, data_graph in enumerate(data_batch): 62 | data_graph.train_mask = data_graph.val_mask = data_graph.test_mask = data_graph.y = None 63 | data_graph.batch = None 64 | num_nodes = data_graph.num_nodes 65 | 66 | if args.use_fixed_feats: 67 | perm = torch.randperm(args.feats.size(0)) 68 | perm_idx = perm[:num_nodes] 69 | data_graph.x = args.feats[perm_idx] 70 | elif args.use_same_fixed_feats: 71 | node_feats = args.feats[0].unsqueeze(0).repeat(num_nodes,1) 72 | data_graph.x = node_feats 73 | 74 | if args.concat_fixed_feats: 75 | if data_graph.x.shape[1] < args.num_features: 76 | concat_feats = torch.randn(num_nodes,args.num_concat_features,requires_grad=False) 77 | data_graph.x = torch.cat((data_graph.x,concat_feats),1) 78 | 79 | # 
Val Ratio is Fixed at 0.1 80 | meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio 81 | 82 | ''' Check if Data is split''' 83 | try: 84 | x, train_pos_edge_index = data_graph.x.to(args.dev), data_graph.train_pos_edge_index.to(args.dev) 85 | data = data_graph 86 | except: 87 | data_graph.x.cuda() 88 | data = model.split_edges(data_graph,val_ratio=args.meta_val_edge_ratio,test_ratio=meta_test_edge_ratio) 89 | 90 | # Additional Failure Checks for small graphs 91 | if data.val_pos_edge_index.size()[1] == 0 or data.test_pos_edge_index.size()[1] == 0: 92 | args.fail_counter += 1 93 | print("Failed on Graph %d" %(graph_id)) 94 | continue 95 | 96 | try: 97 | x, train_pos_edge_index = data.x.to(args.dev), data.train_pos_edge_index.to(args.dev) 98 | test_pos_edge_index, test_neg_edge_index = data.test_pos_edge_index.to(args.dev),\ 99 | data.test_neg_edge_index.to(args.dev) 100 | except: 101 | print("Failed Splitting data on Graph %d" %(graph_id)) 102 | continue 103 | 104 | data_shape = x.shape[2:] 105 | create_graph = (True if order == 2 else False) and train 106 | 107 | # Create a fast model using the current meta model weights 108 | fast_weights = OrderedDict(model.named_parameters()) 109 | early_stopping = EarlyStopping(patience=args.patience, verbose=False) 110 | 111 | # Train the model for `inner_train_steps` iterations 112 | for inner_batch in range(inner_train_steps): 113 | # Perform update of model weights 114 | z = model.encode(x, train_pos_edge_index, fast_weights, inner_loop=True) 115 | loss = model.recon_loss(z, train_pos_edge_index) 116 | if args.model in ['VGAE']: 117 | kl_loss = args.kl_anneal*(1 / num_nodes) * model.kl_loss() 118 | loss = loss + kl_loss 119 | # print("Inner KL Loss: %f" %(kl_loss.item())) 120 | if not args.train_only_gs: 121 | gradients = torch.autograd.grad(loss, fast_weights.values(),\ 122 | allow_unused=args.allow_unused, create_graph=create_graph) 123 | gradients = [0 if grad is None else grad for grad in 
gradients] 124 | if args.wandb: 125 | wandb.log({"Inner_Train_loss":loss.item()}) 126 | 127 | if args.clip_grad: 128 | # for grad in gradients: 129 | custom_clip_grad_norm_(gradients,args.clip) 130 | grad_norm = monitor_grad_norm_2(gradients) 131 | if args.wandb: 132 | inner_grad_norm_metric = 'Inner_Grad_Norm' 133 | wandb.log({inner_grad_norm_metric:grad_norm}) 134 | 135 | ''' Only do this if its the final test set eval ''' 136 | if args.final_test and inner_batch % 5 ==0: 137 | inner_test_auc, inner_test_ap = test(model, x, train_pos_edge_index, 138 | data.test_pos_edge_index, data.test_neg_edge_index,fast_weights) 139 | val_pos_edge_index = data.val_pos_edge_index.to(args.dev) 140 | val_loss = val(model,args, x,val_pos_edge_index,data.num_nodes,fast_weights) 141 | early_stopping(val_loss, model) 142 | my_step = int(inner_batch / 5) 143 | inner_test_auc_array[graph_id][my_step] = inner_test_auc 144 | inner_test_ap_array[graph_id][my_step] = inner_test_ap 145 | 146 | # Update weights manually 147 | if not args.train_only_gs and args.clip_weight: 148 | fast_weights = OrderedDict( 149 | (name, torch.clamp((param - inner_lr * grad),-args.clip_weight_val,args.clip_weight_val)) 150 | for ((name, param), grad) in zip(fast_weights.items(), gradients) 151 | ) 152 | elif not args.train_only_gs: 153 | fast_weights = OrderedDict( 154 | (name, param - inner_lr * grad) 155 | for ((name, param), grad) in zip(fast_weights.items(), gradients) 156 | ) 157 | 158 | if early_stopping.early_stop: 159 | print("Early stopping for Graph %d | AUC: %f AP: %f" \ 160 | %(graph_id, inner_test_auc, inner_test_ap)) 161 | my_step = int(epoch / 5) 162 | inner_test_auc_array[graph_id][my_step:,] = inner_test_auc 163 | inner_test_ap_array[graph_id][my_step:,] = inner_test_ap 164 | break 165 | 166 | # Do a pass of the model on the validation data from the current task 167 | val_pos_edge_index = data.val_pos_edge_index.to(args.dev) 168 | z_val = model.encode(x, val_pos_edge_index, fast_weights, 
inner_loop=False) 169 | loss_val = model.recon_loss(z_val, val_pos_edge_index) 170 | if args.model in ['VGAE']: 171 | kl_loss = args.kl_anneal*(1 / num_nodes) * model.kl_loss() 172 | # print("Outer KL Loss: %f" %(kl_loss.item())) 173 | loss_val = loss_val + kl_loss 174 | 175 | if args.wandb: 176 | wandb.log({"Inner_Val_loss":loss_val.item()}) 177 | # print("Inner Val Loss %f" % (loss_val.item())) 178 | 179 | ##TODO: Is this backward call needed here? Not sure because original repo has it 180 | # https://github.com/oscarknagg/few-shot/blob/master/few_shot/maml.py#L84 181 | if args.extra_backward: 182 | loss_val.backward(retain_graph=True) 183 | 184 | # Get post-update accuracies 185 | auc, ap = test(model, x, train_pos_edge_index, 186 | data.test_pos_edge_index, data.test_neg_edge_index,fast_weights) 187 | 188 | auc_list.append(auc) 189 | ap_list.append(ap) 190 | inner_avg_auc_list.append(auc) 191 | inner_avg_ap_list.append(ap) 192 | 193 | # Accumulate losses and gradients 194 | graph_id += 1 195 | task_losses.append(loss_val) 196 | if order == 1: 197 | gradients = torch.autograd.grad(loss_val, fast_weights.values(), create_graph=create_graph) 198 | named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)} 199 | task_gradients.append(named_grads) 200 | 201 | if len(auc_list) > 0 and len(ap_list) > 0 and batch_id % 5 == 0: 202 | print('Epoch {:01d} Inner Graph Batch: {:01d}, Inner-Update AUC: {:.4f}, AP: {:.4f}'.format(epoch,batch_id,sum(auc_list)/len(auc_list),sum(ap_list)/len(ap_list))) 203 | if args.comet: 204 | if len(ap_list) > 0: 205 | auc_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AUC' 206 | ap_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AP' 207 | avg_auc_metric = mode + '_Inner_Batch_Graph' + '_AUC' 208 | avg_ap_metric = mode + '_Inner_Batch_Graph' + '_AP' 209 | args.experiment.log_metric(auc_metric,sum(auc_list)/len(auc_list),step=epoch) 210 | 
args.experiment.log_metric(ap_metric,sum(ap_list)/len(ap_list),step=epoch) 211 | args.experiment.log_metric(avg_auc_metric,sum(auc_list)/len(auc_list),step=epoch) 212 | args.experiment.log_metric(avg_ap_metric,sum(ap_list)/len(ap_list),step=epoch) 213 | if args.wandb: 214 | if len(ap_list) > 0: 215 | auc_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AUC' 216 | ap_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AP' 217 | avg_auc_metric = mode + '_Inner_Batch_Graph' + '_AUC' 218 | avg_ap_metric = mode + '_Inner_Batch_Graph' + '_AP' 219 | wandb.log({auc_metric:sum(auc_list)/len(auc_list),ap_metric:sum(ap_list)/len(ap_list),\ 220 | avg_auc_metric:sum(auc_list)/len(auc_list),avg_ap_metric:sum(ap_list)/len(ap_list)}) 221 | 222 | meta_batch_loss = torch.Tensor([0]) 223 | if order == 1: 224 | if train and len(task_losses) != 0: 225 | sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0) 226 | for k in task_gradients[0].keys()} 227 | hooks = [] 228 | for name, param in model.named_parameters(): 229 | hooks.append( 230 | param.register_hook(replace_grad(sum_task_gradients, name)) 231 | ) 232 | 233 | model.train() 234 | optimiser.zero_grad() 235 | # Dummy pass in order to create `loss` variable 236 | # Replace dummy gradients with mean task gradients using hooks 237 | ## TODO: Double check if you really need functional forward here 238 | z_dummy = model.encode(torch.zeros(x.shape[0],x.shape[1]).float().cuda(), \ 239 | torch.zeros(train_pos_edge_index.shape[0],train_pos_edge_index.shape[1]).long().cuda(), fast_weights) 240 | loss = model.recon_loss(z_dummy,torch.zeros(train_pos_edge_index.shape[0],\ 241 | train_pos_edge_index.shape[1]).long().cuda()) 242 | loss.backward() 243 | optimiser.step() 244 | 245 | for h in hooks: 246 | h.remove() 247 | meta_batch_loss = torch.stack(task_losses).mean() 248 | return graph_id, meta_batch_loss, inner_avg_auc_list, inner_avg_ap_list 249 | 250 | elif order == 2: 251 | if 
len(task_losses) != 0: 252 | model.train() 253 | optimiser.zero_grad() 254 | meta_batch_loss = torch.stack(task_losses).mean() 255 | 256 | if train: 257 | meta_batch_loss.backward() 258 | if args.clip_grad: 259 | torch.nn.utils.clip_grad_norm_(model.parameters(),args.clip) 260 | grad_norm = monitor_grad_norm(model) 261 | if args.wandb: 262 | outer_grad_norm_metric = 'Outer_Grad_Norm' 263 | wandb.log({outer_grad_norm_metric:grad_norm}) 264 | 265 | optimiser.step() 266 | if args.clip_weight: 267 | for p in model.parameters(): 268 | p.data.clamp_(-args.clip_weight_val,args.clip_weight_val) 269 | return graph_id, meta_batch_loss, inner_avg_auc_list, inner_avg_ap_list 270 | else: 271 | raise ValueError('Order must be either 1 or 2.') 272 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeybose/Meta-Graph/e35ef1b6c6969a2bb0b0c6bba0b43a269ede3cac/models/__init__.py -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import ipdb 4 | import torch 5 | import numpy as np 6 | from sklearn.metrics import roc_auc_score, average_precision_score 7 | from torch_geometric.utils import to_undirected 8 | 9 | # from ..inits import reset 10 | 11 | EPS = 1e-15 12 | LOG_VAR_MAX = 10 13 | LOG_VAR_MIN = EPS 14 | 15 | def reset(nn): 16 | def _reset(item): 17 | if hasattr(item, 'reset_parameters'): 18 | item.reset_parameters() 19 | 20 | if nn is not None: 21 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 22 | for item in nn.children(): 23 | _reset(item) 24 | else: 25 | _reset(nn) 26 | 27 | def negative_sampling(pos_edge_index, num_nodes): 28 | idx = (pos_edge_index[0] * num_nodes + pos_edge_index[1]) 29 | idx = idx.to(torch.device('cpu')) 
30 | 31 | rng = range(num_nodes**2) 32 | perm = torch.tensor(random.sample(rng, idx.size(0))) 33 | mask = torch.from_numpy(np.isin(perm, idx).astype(np.uint8)) 34 | rest = mask.nonzero().view(-1) 35 | while rest.numel() > 0: # pragma: no cover 36 | tmp = torch.tensor(random.sample(rng, rest.size(0))) 37 | mask = torch.from_numpy(np.isin(tmp, idx).astype(np.uint8)) 38 | perm[rest] = tmp 39 | rest = mask.nonzero().view(-1) 40 | 41 | row, col = perm / num_nodes, perm % num_nodes 42 | return torch.stack([row, col], dim=0).to(pos_edge_index.device) 43 | 44 | 45 | class InnerProductDecoder(torch.nn.Module): 46 | r"""The inner product decoder from the `"Variational Graph Auto-Encoders" 47 | `_ paper 48 | 49 | .. math:: 50 | \sigma(\mathbf{Z}\mathbf{Z}^{\top}) 51 | 52 | where :math:`\mathbf{Z} \in \mathbb{R}^{N \times d}` denotes the latent 53 | space produced by the encoder.""" 54 | 55 | def forward(self, z, edge_index, sigmoid=True): 56 | r"""Decodes the latent variables :obj:`z` into edge probabilties for 57 | the given node-pairs :obj:`edge_index`. 58 | 59 | Args: 60 | z (Tensor): The latent space :math:`\mathbf{Z}`. 61 | sigmoid (bool, optional): If set to :obj:`False`, does not apply 62 | the logistic sigmoid function to the output. 63 | (default: :obj:`True`) 64 | """ 65 | value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1) 66 | return torch.sigmoid(value) if sigmoid else value 67 | 68 | def forward_all(self, z, sigmoid=True): 69 | r"""Decodes the latent variables :obj:`z` into a probabilistic dense 70 | adjacency matrix. 71 | 72 | Args: 73 | z (Tensor): The latent space :math:`\mathbf{Z}`. 74 | sigmoid (bool, optional): If set to :obj:`False`, does not apply 75 | the logistic sigmoid function to the output. 
76 | (default: :obj:`True`) 77 | """ 78 | adj = torch.matmul(z, z.t()) 79 | return torch.sigmoid(adj) if sigmoid else adj 80 | 81 | 82 | class MyGAE(torch.nn.Module): 83 | r"""The Graph Auto-Encoder model from the 84 | `"Variational Graph Auto-Encoders" `_ 85 | paper based on user-defined encoder and decoder models. 86 | 87 | Args: 88 | encoder (Module): The encoder module. 89 | decoder (Module, optional): The decoder module. If set to :obj:`None`, 90 | will default to the 91 | :class:`torch_geometric.nn.models.InnerProductDecoder`. 92 | (default: :obj:`None`) 93 | """ 94 | 95 | def __init__(self, encoder, decoder=None): 96 | super(MyGAE, self).__init__() 97 | self.encoder = encoder 98 | self.decoder = InnerProductDecoder() if decoder is None else decoder 99 | 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | reset(self.encoder) 104 | reset(self.decoder) 105 | 106 | def encode(self, *args, **kwargs): 107 | r"""Runs the encoder and computes node-wise latent variables.""" 108 | return self.encoder(*args, **kwargs) 109 | 110 | def decode(self, *args, **kwargs): 111 | r"""Runs the decoder and computes edge probabilties.""" 112 | return self.decoder(*args, **kwargs) 113 | 114 | def split_edges(self, data, val_ratio=0.05, test_ratio=0.1): 115 | r"""Splits the edges of a :obj:`torch_geometric.data.Data` object 116 | into positve and negative train/val/test edges. 117 | 118 | Args: 119 | data (Data): The data object. 120 | val_ratio (float, optional): The ratio of positive validation 121 | edges. (default: :obj:`0.05`) 122 | test_ratio (float, optional): The ratio of positive test 123 | edges. (default: :obj:`0.1`) 124 | """ 125 | 126 | assert 'batch' not in data # No batch-mode. 127 | 128 | row, col = data.edge_index 129 | data.edge_index = None 130 | 131 | # Return upper triangular portion. 
132 | mask = row < col 133 | row, col = row[mask], col[mask] 134 | 135 | n_v = int(math.floor(val_ratio * row.size(0))) 136 | n_t = int(math.floor(test_ratio * row.size(0))) 137 | 138 | # Positive edges. 139 | perm = torch.randperm(row.size(0)) 140 | row, col = row[perm], col[perm] 141 | 142 | r, c = row[:n_v], col[:n_v] 143 | data.val_pos_edge_index = torch.stack([r, c], dim=0) 144 | r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t] 145 | data.test_pos_edge_index = torch.stack([r, c], dim=0) 146 | 147 | r, c = row[n_v + n_t:], col[n_v + n_t:] 148 | data.train_pos_edge_index = torch.stack([r, c], dim=0) 149 | data.train_pos_edge_index = to_undirected(data.train_pos_edge_index) 150 | 151 | # Negative edges. 152 | num_nodes = data.num_nodes 153 | neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8) 154 | neg_adj_mask = neg_adj_mask.triu(diagonal=1) 155 | neg_adj_mask[row, col] = 0 156 | 157 | neg_row, neg_col = neg_adj_mask.nonzero().t() 158 | perm = torch.tensor(random.sample(range(neg_row.size(0)), n_v + n_t)) 159 | perm = perm.to(torch.long) 160 | neg_row, neg_col = neg_row[perm], neg_col[perm] 161 | 162 | neg_adj_mask[neg_row, neg_col] = 0 163 | data.train_neg_adj_mask = neg_adj_mask 164 | 165 | row, col = neg_row[:n_v], neg_col[:n_v] 166 | data.val_neg_edge_index = torch.stack([row, col], dim=0) 167 | 168 | row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t] 169 | data.test_neg_edge_index = torch.stack([row, col], dim=0) 170 | 171 | return data 172 | 173 | def recon_loss(self, z, pos_edge_index): 174 | r"""Given latent variables :obj:`z`, computes the binary cross 175 | entropy loss for positive edges :obj:`pos_edge_index` and negative 176 | sampled edges. 177 | 178 | Args: 179 | z (Tensor): The latent space :math:`\mathbf{Z}`. 180 | pos_edge_index (LongTensor): The positive edges to train against. 
181 | """ 182 | 183 | pos_loss = -torch.log( 184 | self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean() 185 | 186 | neg_edge_index = negative_sampling(pos_edge_index, z.size(0)) 187 | neg_loss = -torch.log( 188 | 1 - self.decoder(z, neg_edge_index, sigmoid=True) + EPS).mean() 189 | 190 | return pos_loss + neg_loss 191 | 192 | def test(self, z, pos_edge_index, neg_edge_index): 193 | r"""Given latent variables :obj:`z`, positive edges 194 | :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`, 195 | computes area under the ROC curve (AUC) and average precision (AP) 196 | scores. 197 | 198 | Args: 199 | z (Tensor): The latent space :math:`\mathbf{Z}`. 200 | pos_edge_index (LongTensor): The positive edges to evaluate 201 | against. 202 | neg_edge_index (LongTensor): The negative edges to evaluate 203 | against. 204 | """ 205 | pos_y = z.new_ones(pos_edge_index.size(1)) 206 | neg_y = z.new_zeros(neg_edge_index.size(1)) 207 | y = torch.cat([pos_y, neg_y], dim=0) 208 | 209 | pos_pred = self.decoder(z, pos_edge_index, sigmoid=True) 210 | neg_pred = self.decoder(z, neg_edge_index, sigmoid=True) 211 | pred = torch.cat([pos_pred, neg_pred], dim=0) 212 | 213 | y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy() 214 | 215 | return roc_auc_score(y, pred), average_precision_score(y, pred) 216 | 217 | 218 | class MyVGAE(MyGAE): 219 | r"""The Variational Graph Auto-Encoder model from the 220 | `"Variational Graph Auto-Encoders" `_ 221 | paper. 222 | 223 | Args: 224 | encoder (Module): The encoder module to compute :math:`\mu` and 225 | :math:`\log\sigma^2`. 226 | decoder (Module, optional): The decoder module. If set to :obj:`None`, 227 | will default to the 228 | :class:`torch_geometric.nn.models.InnerProductDecoder`. 
229 | (default: :obj:`None`) 230 | """ 231 | 232 | def __init__(self, encoder, decoder=None): 233 | super(MyVGAE, self).__init__(encoder, decoder=decoder) 234 | 235 | def reparametrize(self, mu, logvar): 236 | if self.training: 237 | return mu + torch.randn_like(logvar) * torch.exp(logvar) 238 | else: 239 | return mu 240 | 241 | def encode(self, *args, **kwargs): 242 | """""" 243 | self.__mu__, self.__logvar__ = self.encoder(*args, **kwargs) 244 | self.__logvar__ = torch.clamp(self.__logvar__,min=LOG_VAR_MIN,max=LOG_VAR_MAX) 245 | # self.__logvar__ = torch.clamp(self.__logvar__,max=LOG_VAR_MAX) 246 | z = self.reparametrize(self.__mu__, self.__logvar__) 247 | return z 248 | 249 | def kl_loss(self, mu=None, logvar=None): 250 | r"""Computes the KL loss, either for the passed arguments :obj:`mu` 251 | and :obj:`logvar`, or based on latent variables from last encoding. 252 | 253 | Args: 254 | mu (Tensor, optional): The latent space for :math:`\mu`. If set to 255 | :obj:`None`, uses the last computation of :math:`mu`. 256 | (default: :obj:`None`) 257 | logvar (Tensor, optional): The latent space for 258 | :math:`\log\sigma^2`. 
If set to :obj:`None`, uses the last 259 | computation of :math:`\log\sigma^2`.(default: :obj:`None`) 260 | """ 261 | mu = self.__mu__ if mu is None else mu 262 | logvar = self.__logvar__ if logvar is None else logvar 263 | # print("KL Variance is %f" %(torch.mean(logvar.exp()).item())) 264 | return -0.5 * torch.mean( 265 | torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1)) 266 | 267 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from torch_geometric.datasets import Planetoid,PPI 5 | import torch_geometric.transforms as T 6 | from torch_geometric.nn import GATConv, GCNConv, GAE, VGAE 7 | from torch.nn import Parameter 8 | from torch_scatter import scatter_add 9 | from torch_geometric.nn.conv import MessagePassing 10 | from torch_geometric.utils import add_remaining_self_loops 11 | from torch.distributions import Normal 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from utils.utils import uniform 15 | import ipdb 16 | 17 | def glorot(tensor): 18 | if tensor is not None: 19 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 20 | tensor.data.uniform_(-stdv, stdv) 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | class MetaGCNConv(MessagePassing): 27 | r"""The graph convolutional operator from the `"Semi-supervised 28 | Classfication with Graph Convolutional Networks" 29 | `_ paper 30 | .. math:: 31 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 32 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 33 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 34 | adjacency matrix with inserted self-loops and 35 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 36 | Args: 37 | in_channels (int): Size of each input sample. 
38 | out_channels (int): Size of each output sample. 39 | improved (bool, optional): If set to :obj:`True`, the layer computes 40 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 41 | (default: :obj:`False`) 42 | cached (bool, optional): If set to :obj:`True`, the layer will cache 43 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 44 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 45 | (default: :obj:`False`) 46 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 47 | an additive bias. (default: :obj:`True`) 48 | **kwargs (optional): Additional arguments of 49 | :class:`torch_geometric.nn.conv.MessagePassing`. 50 | """ 51 | 52 | def __init__(self, 53 | in_channels, 54 | out_channels, 55 | improved=False, 56 | cached=False, 57 | bias=True, 58 | **kwargs): 59 | super(MetaGCNConv, self).__init__(aggr='add', **kwargs) 60 | 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.improved = improved 64 | self.cached = cached 65 | self.cached_result = None 66 | 67 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 68 | 69 | if bias: 70 | self.bias = Parameter(torch.Tensor(out_channels)) 71 | else: 72 | self.register_parameter('bias', None) 73 | 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | glorot(self.weight) 78 | zeros(self.bias) 79 | self.cached_result = None 80 | 81 | @staticmethod 82 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 83 | if edge_weight is None: 84 | edge_weight = torch.ones((edge_index.size(1), ), 85 | dtype=dtype, 86 | device=edge_index.device) 87 | 88 | fill_value = 1 if not improved else 2 89 | edge_index, edge_weight = add_remaining_self_loops( 90 | edge_index, edge_weight, fill_value, num_nodes) 91 | 92 | row, col = edge_index 93 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 94 | deg_inv_sqrt = deg.pow(-0.5) 95 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 96 | 97 | return 
edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 98 | 99 | def forward(self, x, edge_index, weight, bias, edge_weight=None, gamma=None, beta=None): 100 | """""" 101 | x = torch.matmul(x, weight) 102 | 103 | ''' FiLM part ''' 104 | if gamma is not None and beta is not None: 105 | x = torch.mul(x, gamma) + beta 106 | 107 | if not self.cached or self.cached_result is None: 108 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, 109 | self.improved, x.dtype) 110 | self.cached_result = edge_index, norm 111 | edge_index, norm = self.cached_result 112 | 113 | return self.propagate(edge_index, x=x, norm=norm, bias=bias) 114 | 115 | def message(self, x_j, norm): 116 | return norm.view(-1, 1) * x_j 117 | 118 | def update(self, aggr_out, bias): 119 | if self.bias is not None: 120 | aggr_out = aggr_out + bias 121 | return aggr_out 122 | 123 | def __repr__(self): 124 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 125 | self.out_channels) 126 | 127 | class MetaGatedGCNConv(MessagePassing): 128 | r"""The graph convolutional operator from the `"Semi-supervised 129 | Classfication with Graph Convolutional Networks" 130 | `_ paper 131 | .. math:: 132 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 133 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 134 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 135 | adjacency matrix with inserted self-loops and 136 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 137 | Args: 138 | in_channels (int): Size of each input sample. 139 | out_channels (int): Size of each output sample. 140 | improved (bool, optional): If set to :obj:`True`, the layer computes 141 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 
142 | (default: :obj:`False`) 143 | cached (bool, optional): If set to :obj:`True`, the layer will cache 144 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 145 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 146 | (default: :obj:`False`) 147 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 148 | an additive bias. (default: :obj:`True`) 149 | **kwargs (optional): Additional arguments of 150 | :class:`torch_geometric.nn.conv.MessagePassing`. 151 | """ 152 | 153 | def __init__(self, 154 | in_channels, 155 | out_channels, 156 | improved=False, 157 | cached=False, 158 | bias=True, 159 | gating=None, 160 | **kwargs): 161 | super(MetaGatedGCNConv, self).__init__(aggr='add', **kwargs) 162 | 163 | self.in_channels = in_channels 164 | self.out_channels = out_channels 165 | self.improved = improved 166 | self.cached = cached 167 | self.cached_result = None 168 | self.gating = gating 169 | if gating == 'signature': 170 | self.weight_1 = Parameter(torch.Tensor(in_channels, out_channels)) 171 | self.weight_2 = Parameter(torch.Tensor(1, 1)) 172 | self.gating_weights = Parameter(torch.Tensor(out_channels)) 173 | 174 | elif gating == 'weights': 175 | self.weight_1 = Parameter(torch.Tensor(in_channels, out_channels)) 176 | self.weight_2 = Parameter(torch.Tensor(in_channels, out_channels)) 177 | self.gating_weights = Parameter(torch.Tensor(out_channels)) 178 | 179 | elif gating == 'signature_cond': 180 | self.weight_1 = Parameter(torch.Tensor(in_channels, out_channels)) 181 | self.weight_2 = Parameter(torch.Tensor(1, 1)) 182 | self.gating_weights = Parameter(torch.Tensor(in_channels, out_channels)) 183 | 184 | elif gating == 'weights_cond': 185 | self.weight_1 = Parameter(torch.Tensor(in_channels, out_channels)) 186 | self.weight_2 = Parameter(torch.Tensor(in_channels, out_channels)) 187 | self.gating_weights = Parameter(torch.Tensor(in_channels, out_channels)) 188 | 189 | else: 190 | self.weight_1 = Parameter(torch.Tensor(in_channels, 
out_channels)) 191 | self.weight_2 = Parameter(torch.Tensor(1, 1)) 192 | self.gating_weights = Parameter(torch.Tensor(1, 1)) 193 | 194 | if bias: 195 | self.bias = Parameter(torch.Tensor(out_channels)) 196 | else: 197 | self.register_parameter('bias', None) 198 | 199 | self.reset_parameters() 200 | 201 | def reset_parameters(self): 202 | glorot(self.weight_1) 203 | glorot(self.weight_2) 204 | zeros(self.bias) 205 | 206 | if self.gating.endswith('cond'): 207 | glorot(self.gating_weights) 208 | else: 209 | zeros(self.gating_weights) 210 | 211 | self.cached_result = None 212 | 213 | @staticmethod 214 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 215 | if edge_weight is None: 216 | edge_weight = torch.ones((edge_index.size(1), ), 217 | dtype=dtype, 218 | device=edge_index.device) 219 | 220 | fill_value = 1 if not improved else 2 221 | edge_index, edge_weight = add_remaining_self_loops( 222 | edge_index, edge_weight, fill_value, num_nodes) 223 | 224 | row, col = edge_index 225 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 226 | deg_inv_sqrt = deg.pow(-0.5) 227 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 228 | 229 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 230 | 231 | def forward(self, x, edge_index, weight_1, weight_2, bias, gating_weights, edge_weight=None, gamma=None, beta=None): 232 | """""" 233 | ''' FiLM part ''' 234 | if gamma is not None and beta is not None: 235 | if self.gating == 'signature': 236 | alpha = torch.sigmoid(gating_weights) 237 | gamma = torch.mul(alpha, gamma) + torch.mul(1 - alpha, torch.ones_like(gamma)) 238 | beta = torch.mul(alpha, beta) + torch.mul(1 - alpha, torch.ones_like(beta)) 239 | 240 | x = torch.matmul(x, weight_1) 241 | x = torch.mul(x, gamma) + beta 242 | 243 | elif self.gating == 'weights': 244 | alpha = torch.sigmoid(gating_weights) 245 | 246 | x_1 = torch.matmul(x, weight_1) 247 | x_2 = torch.matmul(x, weight_2) 248 | x_2 = torch.mul(x_2, gamma) + 
beta 249 | 250 | x = torch.mul(alpha, x_1) + torch.mul(1 - alpha, x_2) 251 | 252 | elif self.gating == 'signature_cond': 253 | alpha = torch.sigmoid(torch.matmul(x, gating_weights)) 254 | gamma = torch.mul(alpha, gamma) + torch.mul(1 - alpha, torch.ones_like(gamma)) 255 | beta = torch.mul(alpha, beta) + torch.mul(1 - alpha, torch.ones_like(beta)) 256 | 257 | x = torch.matmul(x, weight_1) 258 | x = torch.mul(x, gamma) + beta 259 | 260 | elif self.gating == 'weights_cond': 261 | alpha = torch.sigmoid(torch.matmul(x, gating_weights)) 262 | 263 | x_1 = torch.matmul(x, weight_1) 264 | x_2 = torch.matmul(x, weight_2) 265 | x_2 = torch.mul(x_2, gamma) + beta 266 | 267 | x = torch.mul(alpha, x_1) + torch.mul(1 - alpha, x_2) 268 | 269 | else: 270 | x = torch.matmul(x, weight_1) 271 | x = torch.mul(x, gamma) + beta 272 | 273 | if not self.cached or self.cached_result is None: 274 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, 275 | self.improved, x.dtype) 276 | self.cached_result = edge_index, norm 277 | edge_index, norm = self.cached_result 278 | 279 | return self.propagate(edge_index, x=x, norm=norm, bias=bias) 280 | 281 | def message(self, x_j, norm): 282 | return norm.view(-1, 1) * x_j 283 | 284 | def update(self, aggr_out, bias): 285 | if self.bias is not None: 286 | aggr_out = aggr_out + bias 287 | return aggr_out 288 | 289 | def __repr__(self): 290 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 291 | self.out_channels) 292 | 293 | class MetaGRUCell(nn.Module): 294 | 295 | """ 296 | An implementation of GRUCell with Functional Ops. 
297 | """ 298 | def __init__(self, input_size, hidden_size, bias=True): 299 | super(MetaGRUCell, self).__init__() 300 | self.input_size = input_size 301 | self.hidden_size = hidden_size 302 | self.bias = bias 303 | self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias) 304 | self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias) 305 | self.reset_parameters() 306 | 307 | def reset_parameters(self): 308 | std = 1.0 / math.sqrt(self.hidden_size) 309 | for w in self.parameters(): 310 | w.data.uniform_(-std, std) 311 | 312 | def forward(self, x, hidden, weights, keys): 313 | x = x.view(-1, x.size(1)) 314 | gate_x = F.linear(x,weights[keys[1]],weights[keys[2]]) 315 | gate_h = F.linear(x,weights[keys[3]],weights[keys[4]]) 316 | gate_x = gate_x.squeeze() 317 | gate_h = gate_h.squeeze() 318 | i_r, i_i, i_n = gate_x.chunk(3, 1) 319 | h_r, h_i, h_n = gate_h.chunk(3, 1) 320 | resetgate = torch.sigmoid(i_r + h_r) 321 | inputgate = torch.sigmoid(i_i + h_i) 322 | newgate = torch.tanh(i_n + (resetgate * h_n)) 323 | hy = newgate + inputgate * (hidden - newgate) 324 | return hy 325 | 326 | class MetaGatedGraphConv(MessagePassing): 327 | r"""The gated graph convolution operator from the `"Gated Graph Sequence 328 | Neural Networks" `_ paper 329 | .. math:: 330 | \mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0} 331 | \mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta} 332 | \cdot \mathbf{h}_j^{(l)} 333 | \mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, 334 | \mathbf{h}_i^{(l)}) 335 | up to representation :math:`\mathbf{h}_i^{(L)}`. 336 | The number of input channels of :math:`\mathbf{x}_i` needs to be less or 337 | equal than :obj:`out_channels`. 338 | Args: 339 | out_channels (int): Size of each input sample. 340 | num_layers (int): The sequence length :math:`L`. 341 | aggr (string, optional): The aggregation scheme to use 342 | (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). 
343 | (default: :obj:`"add"`) 344 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 345 | an additive bias. (default: :obj:`True`) 346 | **kwargs (optional): Additional arguments of 347 | :class:`torch_geometric.nn.conv.MessagePassing`. 348 | """ 349 | 350 | def __init__(self, 351 | out_channels, 352 | num_layers, 353 | aggr='add', 354 | bias=True, 355 | **kwargs): 356 | super(MetaGatedGraphConv, self).__init__(aggr=aggr, **kwargs) 357 | 358 | self.out_channels = out_channels 359 | self.num_layers = num_layers 360 | 361 | self.weight = Parameter(torch.Tensor(num_layers, out_channels, out_channels)) 362 | self.rnn = MetaGRUCell(out_channels, out_channels, bias=bias) 363 | 364 | self.reset_parameters() 365 | 366 | def reset_parameters(self): 367 | uniform(self.out_channels, self.weight) 368 | self.rnn.reset_parameters() 369 | 370 | def forward(self, x, edge_index, weights, keys, edge_weight=None): 371 | """""" 372 | h = x if x.dim() == 2 else x.unsqueeze(-1) 373 | if h.size(1) > self.out_channels: 374 | raise ValueError('The number of input channels is not allowed to ' 375 | 'be larger than the number of output channels') 376 | 377 | if h.size(1) < self.out_channels: 378 | zero = h.new_zeros(h.size(0), self.out_channels - h.size(1)) 379 | h = torch.cat([h, zero], dim=1) 380 | 381 | for i in range(self.num_layers): 382 | m = torch.matmul(h, weights[keys[0]][i]) 383 | m = self.propagate(edge_index, x=m, edge_weight=edge_weight) 384 | h = self.rnn(m, h, weights, keys) 385 | 386 | return h 387 | 388 | def message(self, x_j, edge_weight): 389 | if edge_weight is not None: 390 | return edge_weight.view(-1, 1) * x_j 391 | return x_j 392 | 393 | def __repr__(self): 394 | return '{}({}, num_layers={})'.format( 395 | self.__class__.__name__, self.out_channels, self.num_layers) 396 | 397 | class MetaGATConv(MessagePassing): 398 | r"""The graph attentional operator from the `"Graph Attention Networks" 399 | `_ paper 400 | .. 
math:: 401 | \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + 402 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, 403 | where the attention coefficients :math:`\alpha_{i,j}` are computed as 404 | .. math:: 405 | \alpha_{i,j} = 406 | \frac{ 407 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 408 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] 409 | \right)\right)} 410 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} 411 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 412 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] 413 | \right)\right)}. 414 | Args: 415 | in_channels (int): Size of each input sample. 416 | out_channels (int): Size of each output sample. 417 | heads (int, optional): Number of multi-head-attentions. 418 | (default: :obj:`1`) 419 | concat (bool, optional): If set to :obj:`False`, the multi-head 420 | attentions are averaged instead of concatenated. 421 | (default: :obj:`True`) 422 | negative_slope (float, optional): LeakyReLU angle of the negative 423 | slope. (default: :obj:`0.2`) 424 | dropout (float, optional): Dropout probability of the normalized 425 | attention coefficients which exposes each node to a stochastically 426 | sampled neighborhood during training. (default: :obj:`0`) 427 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 428 | an additive bias. (default: :obj:`True`) 429 | **kwargs (optional): Additional arguments of 430 | :class:`torch_geometric.nn.conv.MessagePassing`. 
431 | """ 432 | 433 | def __init__(self, in_channels, out_channels, heads=1, concat=True, 434 | negative_slope=0.2, dropout=0, bias=True, **kwargs): 435 | super(MetaGATConv, self).__init__(aggr='add', **kwargs) 436 | 437 | self.in_channels = in_channels 438 | self.out_channels = out_channels 439 | self.heads = heads 440 | self.concat = concat 441 | self.negative_slope = negative_slope 442 | self.dropout = dropout 443 | 444 | self.weight = Parameter( 445 | torch.Tensor(in_channels, heads * out_channels)) 446 | self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels)) 447 | 448 | if bias and concat: 449 | self.bias = Parameter(torch.Tensor(heads * out_channels)) 450 | elif bias and not concat: 451 | self.bias = Parameter(torch.Tensor(out_channels)) 452 | else: 453 | self.register_parameter('bias', None) 454 | 455 | self.reset_parameters() 456 | 457 | def reset_parameters(self): 458 | glorot(self.weight) 459 | glorot(self.att) 460 | zeros(self.bias) 461 | 462 | def forward(self, x, edge_index, weight, bias, size=None): 463 | """""" 464 | if size is None and torch.is_tensor(x): 465 | edge_index, _ = remove_self_loops(edge_index) 466 | edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 467 | 468 | if torch.is_tensor(x): 469 | x = torch.matmul(x, weight) 470 | else: 471 | x = (None if x[0] is None else torch.matmul(x[0], self.weight), 472 | None if x[1] is None else torch.matmul(x[1], self.weight)) 473 | 474 | return self.propagate(edge_index, size=size, x=x, bias=bias) 475 | 476 | def message(self, edge_index_i, x_i, x_j, size_i): 477 | # Compute attention coefficients. 
478 | x_j = x_j.view(-1, self.heads, self.out_channels) 479 | if x_i is None: 480 | alpha = (x_j * self.att[:, :, self.out_channels:]).sum(dim=-1) 481 | else: 482 | x_i = x_i.view(-1, self.heads, self.out_channels) 483 | alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) 484 | 485 | alpha = F.leaky_relu(alpha, self.negative_slope) 486 | alpha = softmax(alpha, edge_index_i, size_i) 487 | 488 | # Sample attention coefficients stochastically. 489 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 490 | 491 | return x_j * alpha.view(-1, self.heads, 1) 492 | 493 | def update(self, aggr_out): 494 | if self.concat is True: 495 | aggr_out = aggr_out.view(-1, self.heads * self.out_channels) 496 | else: 497 | aggr_out = aggr_out.mean(dim=1) 498 | 499 | if self.bias is not None: 500 | aggr_out = aggr_out + self.bias 501 | return aggr_out 502 | 503 | def __repr__(self): 504 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 505 | self.in_channels, 506 | self.out_channels, self.heads) 507 | 508 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from torch_geometric.datasets import Planetoid,PPI 5 | import torch_geometric.transforms as T 6 | from torch_geometric.nn import GATConv, GCNConv, GAE, VGAE 7 | from torch.nn import Parameter 8 | from torch_scatter import scatter_add 9 | from torch_geometric.nn.conv import MessagePassing 10 | from torch_geometric.utils import add_remaining_self_loops 11 | from torch.distributions import Normal 12 | from torch import nn 13 | from .layers import MetaGCNConv, MetaGatedGraphConv, MetaGRUCell, MetaGatedGCNConv 14 | import torch.nn.functional as F 15 | from utils.utils import uniform 16 | import ipdb 17 | 18 | def glorot(tensor): 19 | if tensor is not None: 20 | stdv = math.sqrt(6.0 / (tensor.size(-2) 
+ tensor.size(-1))) 21 | tensor.data.uniform_(-stdv, stdv) 22 | 23 | def zeros(tensor): 24 | if tensor is not None: 25 | tensor.data.fill_(0) 26 | 27 | class Encoder(torch.nn.Module): 28 | def __init__(self, args, in_channels, out_channels): 29 | super(Encoder, self).__init__() 30 | self.args = args 31 | self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=False) 32 | if args.model in ['GAE']: 33 | self.conv2 = GCNConv(2 * out_channels, out_channels, cached=False) 34 | elif args.model in ['VGAE']: 35 | self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=False) 36 | self.conv_logvar = GCNConv( 37 | 2 * out_channels, out_channels, cached=False) 38 | 39 | def forward(self, x, edge_index): 40 | x = F.relu(self.conv1(x, edge_index)) 41 | if self.args.model in ['GAE']: 42 | return self.conv2(x, edge_index) 43 | elif self.args.model in ['VGAE']: 44 | return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index) 45 | 46 | class MetaMLPEncoder(torch.nn.Module): 47 | def __init__(self, args, in_channels, out_channels): 48 | super(MetaMLPEncoder, self).__init__() 49 | self.args = args 50 | self.fc1 = nn.Linear(in_channels, 2 * out_channels, bias=True) 51 | if args.model in ['GAE']: 52 | self.fc_mu = nn.Linear(2 * out_channels, out_channels, bias=True) 53 | elif args.model in ['VGAE']: 54 | self.fc_mu = nn.Linear(2 * out_channels, out_channels, bias=True) 55 | self.fc_logvar = nn.Linear(2 * out_channels, out_channels, bias=True) 56 | 57 | def forward(self, x, edge_index, weights, inner_loop=True): 58 | x = F.relu(F.linear(x, weights['encoder.fc1.weight'],weights['encoder.fc1.bias'])) 59 | if self.args.model in ['GAE']: 60 | return F.relu(F.linear(x, weights['encoder.fc_mu.weight'],weights['encoder.fc_mu.bias'])) 61 | elif self.args.model in ['VGAE']: 62 | return F.relu(F.linear(x,weights['encoder.fc_mu.weight'],\ 63 | weights['encoder.fc_mu.bias'])),F.relu(F.linear(x,\ 64 | weights['encoder.fc_logvar.weight'],weights['encoder.fc_logvar.bias'])) 65 | 66 | 
class MLPEncoder(torch.nn.Module): 67 | def __init__(self, args, in_channels, out_channels): 68 | super(MLPEncoder, self).__init__() 69 | self.args = args 70 | self.fc1 = nn.Linear(in_channels, 2 * out_channels, bias=True) 71 | self.fc2 = nn.Linear(2 * out_channels, out_channels, bias=True) 72 | 73 | def forward(self, x, edge_index): 74 | x = F.relu(self.fc1(x)) 75 | x = F.relu(self.fc2(x)) 76 | return x 77 | 78 | class GraphSignature(torch.nn.Module): 79 | def __init__(self, args, in_channels, out_channels): 80 | super(GraphSignature, self).__init__() 81 | self.args = args 82 | if self.args.use_gcn_sig: 83 | self.conv1 = MetaGCNConv(in_channels, 2*out_channels, cached=False) 84 | self.fc1 = nn.Linear(2*out_channels, 2*out_channels, bias=True) 85 | self.fc2 = nn.Linear(2*out_channels, 2*out_channels, bias=True) 86 | self.fc3 = nn.Linear(2*out_channels, out_channels, bias=True) 87 | self.fc4 = nn.Linear(2*out_channels, out_channels, bias=True) 88 | else: 89 | self.gated_conv1 = MetaGatedGraphConv(in_channels, args.num_gated_layers) 90 | self.fc1 = nn.Linear(in_channels, 2 * out_channels, bias=True) 91 | self.fc2 = nn.Linear(in_channels, 2 * out_channels, bias=True) 92 | self.fc3 = nn.Linear(in_channels, out_channels, bias=True) 93 | self.fc4 = nn.Linear(in_channels, out_channels, bias=True) 94 | 95 | def forward(self, x, edge_index, weights, keys): 96 | if self.args.use_gcn_sig: 97 | x = F.relu(self.conv1(x, edge_index, \ 98 | weights['encoder.signature.conv1.weight'],\ 99 | weights['encoder.signature.conv1.bias'])) 100 | else: 101 | x = F.relu(self.gated_conv1(x, edge_index, weights,keys)) 102 | 103 | x = x.sum(0) 104 | x_gamma_1 = F.linear(x, weights['encoder.signature.fc1.weight'],\ 105 | weights['encoder.signature.fc1.bias']) 106 | x_beta_1 = F.linear(x, weights['encoder.signature.fc2.weight'],\ 107 | weights['encoder.signature.fc2.bias']) 108 | x_gamma_2 = F.linear(x, weights['encoder.signature.fc3.weight'],\ 109 | weights['encoder.signature.fc3.bias']) 110 | 
x_beta_2 = F.linear(x, weights['encoder.signature.fc4.weight'],\ 111 | weights['encoder.signature.fc4.bias']) 112 | return torch.tanh(x_gamma_1), torch.tanh(x_beta_1),\ 113 | torch.tanh(x_gamma_2), torch.tanh(x_beta_2) 114 | 115 | class MetaSignatureEncoder(torch.nn.Module): 116 | def __init__(self, args, in_channels, out_channels): 117 | super(MetaSignatureEncoder, self).__init__() 118 | self.args = args 119 | self.conv1 = MetaGCNConv(in_channels, 2 * out_channels, cached=False) 120 | if args.model in ['GAE']: 121 | self.conv2 = MetaGCNConv(2 * out_channels, out_channels, cached=False) 122 | elif args.model in ['VGAE']: 123 | self.conv_mu = MetaGCNConv(2 * out_channels, out_channels, cached=False) 124 | self.conv_logvar = MetaGCNConv( 125 | 2 * out_channels, out_channels, cached=False) 126 | # in_channels is the input feature dim 127 | self.signature = GraphSignature(args, in_channels, out_channels) 128 | 129 | def forward(self, x, edge_index, weights, inner_loop=True): 130 | keys = list(weights.keys()) 131 | sig_keys = [key for key in keys if 'signature' in key] 132 | if inner_loop: 133 | with torch.no_grad(): 134 | sig_gamma_1, sig_beta_1, sig_gamma_2, sig_beta_2 = self.signature(x, edge_index, weights, sig_keys) 135 | self.cache_sig_out = [sig_gamma_1,sig_beta_1,sig_gamma_2,sig_beta_2] 136 | else: 137 | sig_gamma_1, sig_beta_1, sig_gamma_2, sig_beta_2 = self.signature(x, edge_index, weights, sig_keys) 138 | self.cache_sig_out = [sig_gamma_1,sig_beta_1,sig_gamma_2,sig_beta_2] 139 | 140 | x = F.relu(self.conv1(x, edge_index, weights['encoder.conv1.weight'],\ 141 | weights['encoder.conv1.bias'], gamma=sig_gamma_1, beta=sig_beta_1)) 142 | if self.args.layer_norm: 143 | x = nn.LayerNorm(x.size()[1:], elementwise_affine=False)(x) 144 | if self.args.model in ['GAE']: 145 | x = self.conv2(x, edge_index,weights['encoder.conv2.weight'],\ 146 | weights['encoder.conv2.bias'],gamma=sig_gamma_2, beta=sig_beta_2) 147 | if self.args.layer_norm: 148 | x = 
nn.LayerNorm(x.size()[1:], elementwise_affine=False)(x) 149 | return x 150 | elif self.args.model in ['VGAE']: 151 | mu = self.conv_mu(x,edge_index,weights['encoder.conv_mu.weight'],\ 152 | weights['encoder.conv_mu.bias'], gamma=sig_gamma_2, beta=sig_beta_2) 153 | sig = self.conv_logvar(x,edge_index,weights['encoder.conv_logvar.weight'],\ 154 | weights['encoder.conv_logvar.bias'], gamma=sig_gamma_2, beta=sig_beta_2) 155 | if self.args.layer_norm: 156 | mu = nn.LayerNorm(mu.size()[1:], elementwise_affine=False)(mu) 157 | sig = nn.LayerNorm(sig.size()[1:], elementwise_affine=False)(sig) 158 | return mu, sig 159 | 160 | class MetaGatedSignatureEncoder(torch.nn.Module): 161 | def __init__(self, args, in_channels, out_channels): 162 | super(MetaGatedSignatureEncoder, self).__init__() 163 | self.args = args 164 | self.conv1 = MetaGatedGCNConv(in_channels, 2 * out_channels, gating=args.gating, cached=False) 165 | if args.model in ['GAE']: 166 | self.conv2 = MetaGatedGCNConv(2 * out_channels, out_channels, gating=args.gating, cached=False) 167 | elif args.model in ['VGAE']: 168 | self.conv_mu = MetaGatedGCNConv(2 * out_channels, out_channels, gating=args.gating, cached=False) 169 | self.conv_logvar = MetaGatedGCNConv( 170 | 2 * out_channels, out_channels, gating=args.gating, cached=False) 171 | # in_channels is the input feature dim 172 | self.signature = GraphSignature(args, in_channels, out_channels) 173 | 174 | def forward(self, x, edge_index, weights, inner_loop=True): 175 | keys = list(weights.keys()) 176 | sig_keys = [key for key in keys if 'signature' in key] 177 | if inner_loop: 178 | with torch.no_grad(): 179 | sig_gamma_1, sig_beta_1, sig_gamma_2, sig_beta_2 = self.signature(x, edge_index, weights, sig_keys) 180 | self.cache_sig_out = [sig_gamma_1,sig_beta_1,sig_gamma_2,sig_beta_2,\ 181 | torch.sigmoid(weights['encoder.conv1.gating_weights']),\ 182 | torch.sigmoid(weights['encoder.conv_mu.gating_weights']),\ 183 | 
torch.sigmoid(weights['encoder.conv_logvar.gating_weights'])] 184 | else: 185 | sig_gamma_1, sig_beta_1, sig_gamma_2, sig_beta_2 = self.signature(x, edge_index, weights, sig_keys) 186 | 187 | x = F.relu(self.conv1(x, edge_index,\ 188 | weights['encoder.conv1.weight_1'],\ 189 | weights['encoder.conv1.weight_2'],\ 190 | weights['encoder.conv1.bias'],\ 191 | weights['encoder.conv1.gating_weights'],\ 192 | gamma=sig_gamma_1, beta=sig_beta_1)) 193 | if self.args.layer_norm: 194 | x = nn.LayerNorm(x.size()[1:], elementwise_affine=False)(x) 195 | if self.args.model in ['GAE']: 196 | x = self.conv2(x, edge_index,\ 197 | weights['encoder.conv_mu.weight_1'],\ 198 | weights['encoder.conv_mu.weight_2'],\ 199 | weights['encoder.conv_mu.bias'],\ 200 | weights['encoder.conv_mu.gating_weights'],\ 201 | gamma=sig_gamma_2, beta=sig_beta_2) 202 | if self.args.layer_norm: 203 | x = nn.LayerNorm(x.size()[1:], elementwise_affine=False)(x) 204 | return x 205 | elif self.args.model in ['VGAE']: 206 | mu = self.conv_mu(x,edge_index,\ 207 | weights['encoder.conv_mu.weight_1'],\ 208 | weights['encoder.conv_mu.weight_2'],\ 209 | weights['encoder.conv_mu.bias'],\ 210 | weights['encoder.conv_mu.gating_weights'],\ 211 | gamma=sig_gamma_2, beta=sig_beta_2) 212 | sig = self.conv_logvar(x,edge_index,\ 213 | weights['encoder.conv_logvar.weight_1'],\ 214 | weights['encoder.conv_logvar.weight_2'],\ 215 | weights['encoder.conv_logvar.bias'],\ 216 | weights['encoder.conv_logvar.gating_weights'],\ 217 | gamma=sig_gamma_2, beta=sig_beta_2) 218 | if self.args.layer_norm: 219 | mu = nn.LayerNorm(mu.size()[1:], elementwise_affine=False)(mu) 220 | sig = nn.LayerNorm(sig.size()[1:], elementwise_affine=False)(sig) 221 | return mu, sig 222 | 223 | class MetaEncoder(torch.nn.Module): 224 | def __init__(self, args, in_channels, out_channels): 225 | super(MetaEncoder, self).__init__() 226 | self.args = args 227 | self.conv1 = MetaGCNConv(in_channels, 2 * out_channels, cached=False) 228 | if args.model in ['GAE']: 
229 | self.conv2 = MetaGCNConv(2 * out_channels, out_channels, cached=False) 230 | elif args.model in ['VGAE']: 231 | self.conv_mu = MetaGCNConv(2 * out_channels, out_channels, cached=False) 232 | self.conv_logvar = MetaGCNConv( 233 | 2 * out_channels, out_channels, cached=False) 234 | 235 | def forward(self, x, edge_index, weights, inner_loop=True): 236 | x = F.relu(self.conv1(x, edge_index, \ 237 | weights['encoder.conv1.weight'],weights['encoder.conv1.bias'])) 238 | if self.args.model in ['GAE']: 239 | return self.conv2(x, edge_index,\ 240 | weights['encoder.conv2.weight'],weights['encoder.conv2.bias']) 241 | elif self.args.model in ['VGAE']: 242 | return self.conv_mu(x,edge_index,weights['encoder.conv_mu.weight'],\ 243 | weights['encoder.conv_mu.bias']),\ 244 | self.conv_logvar(x,edge_index,weights['encoder.conv_logvar.weight'],\ 245 | weights['encoder.conv_logvar.bias']) 246 | 247 | class Net(torch.nn.Module): 248 | def __init__(self,train_dataset): 249 | super(Net, self).__init__() 250 | self.conv1 = GATConv(train_dataset.num_features, 256, heads=4) 251 | self.lin1 = torch.nn.Linear(train_dataset.num_features, 4 * 256) 252 | self.conv2 = GATConv(4 * 256, 256, heads=4) 253 | self.lin2 = torch.nn.Linear(4 * 256, 4 * 256) 254 | self.conv3 = GATConv( 255 | 4 * 256, train_dataset.num_classes, heads=6, concat=False) 256 | self.lin3 = torch.nn.Linear(4 * 256, train_dataset.num_classes) 257 | 258 | def forward(self, x, edge_index): 259 | x = F.elu(self.conv1(x, edge_index) + self.lin1(x)) 260 | x = F.elu(self.conv2(x, edge_index) + self.lin2(x)) 261 | x = self.conv3(x, edge_index) + self.lin3(x) 262 | return x 263 | 264 | -------------------------------------------------------------------------------- /plotting/plot_comet.py: -------------------------------------------------------------------------------- 1 | from comet_ml import API 2 | import argparse 3 | import csv 4 | import json 5 | import os 6 | from statistics import mean 7 | import wandb 8 | import matplotlib 
9 | import numpy as np 10 | import ipdb 11 | 12 | matplotlib.use('Agg') 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | 16 | # Set plotting style 17 | sns.set_context('paper', font_scale=1.3) 18 | sns.set_style('whitegrid') 19 | sns.set_palette('colorblind') 20 | plt.rcParams['text.usetex'] = False 21 | 22 | def SetPlotRC(): 23 | #If fonttype = 1 doesn't work with LaTeX, try fonttype 42. 24 | plt.rc('pdf',fonttype = 42) 25 | plt.rc('ps',fonttype = 42) 26 | 27 | def ApplyFont(ax): 28 | 29 | ticks = ax.get_xticklabels() + ax.get_yticklabels() 30 | 31 | text_size = 14.0 32 | 33 | for t in ticks: 34 | t.set_fontname('Times New Roman') 35 | t.set_fontsize(text_size) 36 | 37 | txt = ax.get_xlabel() 38 | txt_obj = ax.set_xlabel(txt) 39 | txt_obj.set_fontname('Times New Roman') 40 | txt_obj.set_fontsize(text_size) 41 | 42 | txt = ax.get_ylabel() 43 | txt_obj = ax.set_ylabe(txt) 44 | txt_obj.set_fontname('Times New Roman') 45 | txt_obj.set_fontsize(text_size) 46 | 47 | txt = ax.get_title() 48 | txt_obj = ax.set_title(txt) 49 | txt_obj.set_fontname('Times New Roman') 50 | txt_obj.set_fontsize(text_size) 51 | 52 | SetPlotRC() 53 | 54 | def connect_to_comet(comet_apikey,comet_restapikey,comet_username,comet_project): 55 | if os.path.isfile("settings.json"): 56 | with open("settings.json") as f: 57 | keys = json.load(f) 58 | comet_apikey = keys.get("apikey") 59 | comet_username = keys.get("username") 60 | comet_restapikey = keys.get("restapikey") 61 | 62 | print("COMET_REST_API_KEY=%s" %(comet_restapikey)) 63 | with open('.env', 'w') as writer: 64 | writer.write("COMET_API_KEY=%s\n" %(comet_apikey)) 65 | writer.write("COMET_REST_API_KEY=%s\n" %(comet_restapikey)) 66 | 67 | comet_api = API() 68 | return comet_api, comet_username, comet_project 69 | 70 | def data_to_extract(username,args): 71 | labels = {} 72 | labels['title'] = "PPI Link Prediction" 73 | labels['x_label'] = "Iterations" 74 | labels['y_label'] = "Percent" 75 | if args.local: 76 | param_str = 
'Local' 77 | else: 78 | param_str = 'Global' 79 | 80 | labels['train_metric_auc'] = "Train_" + param_str + "_Graph_" 81 | labels['train_metric_ap'] = "Train_" + param_str + "_Graph_" 82 | labels['test_metric_auc'] = "Test_" + param_str + "_Graph_" 83 | labels['test_metric_ap'] = "Test_" + param_str + "_Graph_" 84 | if username == "joeybose": 85 | labels['experiments_key'] = [[args.one_maml],\ 86 | [args.two_maml],\ 87 | [args.random_exp],\ 88 | [args.no_finetune],\ 89 | [args.finetune],\ 90 | [args.adamic_adar]\ 91 | ] 92 | if args.local: 93 | labels['experiments_name'] = ['1-MAML','2-MAML', 'NoFinetune',\ 94 | 'Finetune','Adamic-Adar'] 95 | else: 96 | labels['experiments_name'] = ['1-MAML','2-MAML', 'Random', 'NoFinetune',\ 97 | 'Finetune'] 98 | 99 | return labels 100 | 101 | def data_to_extract_ppi(username,args): 102 | labels = {} 103 | labels['title'] = "PPI Link Prediction" 104 | labels['x_label'] = "Iterations" 105 | labels['y_label'] = "Percent" 106 | if args.local: 107 | param_str = 'Local' 108 | else: 109 | param_str = 'Global' 110 | 111 | labels['train_metric_auc'] = "Train_" + param_str + "_Graph_" 112 | labels['train_metric_ap'] = "Train_" + param_str + "_Graph_" 113 | labels['test_metric_auc'] = "Test_" + param_str + "_Graph_" 114 | labels['test_metric_ap'] = "Test_" + param_str + "_Graph_" 115 | if username == "joeybose": 116 | labels['experiments_key'] = [[args.two_maml],\ 117 | [args.concat],\ 118 | [args.random_exp],\ 119 | [args.no_finetune],\ 120 | [args.finetune],\ 121 | [args.adamic_adar],\ 122 | [args.mlp],\ 123 | [args.graph_sig]\ 124 | ] 125 | if args.local: 126 | labels['experiments_name'] = ['2-MAML', '2-MAML-Concat','NoFinetune',\ 127 | 'Finetune','Adamic-Adar','MLP', 'Inner-GraphSig'] 128 | else: 129 | labels['experiments_name'] = ['2-MAML', '2-MAML-Concat', 'Random', 'NoFinetune',\ 130 | 'Finetune'] 131 | args.local_block = 3 132 | args.global_block = [2] 133 | return labels 134 | 135 | def data_to_extract_enzymes(username,args): 136 | 
labels = {} 137 | labels['title'] = "Enzymes Link Prediction" 138 | labels['x_label'] = "Iterations" 139 | labels['y_label'] = "Percent" 140 | if args.local: 141 | param_str = 'Local' 142 | else: 143 | param_str = 'Global' 144 | 145 | labels['train_metric_auc'] = "Train_" + param_str + "_Graph_" 146 | labels['train_metric_ap'] = "Train_" + param_str + "_Graph_" 147 | labels['test_metric_auc'] = "Test_" + param_str + "_Graph_" 148 | labels['test_metric_ap'] = "Test_" + param_str + "_Graph_" 149 | if username == "joeybose": 150 | labels['experiments_key'] = [[args.two_maml],\ 151 | [args.concat],\ 152 | [args.random_exp],\ 153 | [args.no_finetune],\ 154 | [args.finetune],\ 155 | [args.mlp],\ 156 | [args.adamic_adar]\ 157 | ] 158 | if args.local: 159 | labels['experiments_name'] = ['2-MAML', '2-MAML-Concat','NoFinetune',\ 160 | 'Finetune','MLP','Adamic-Adar'] 161 | else: 162 | labels['experiments_name'] = ['2-MAML', '2-MAML-Concat', 'Random', 'NoFinetune',\ 163 | 'Finetune'] 164 | 165 | args.local_block = 3 166 | args.global_block = [2,6] 167 | return labels 168 | 169 | def data_to_extract_reddit(username,args): 170 | labels = {} 171 | labels['title'] = "Reddit Link Prediction" 172 | labels['x_label'] = "Iterations" 173 | labels['y_label'] = "Percent" 174 | if args.local: 175 | param_str = 'Local_Batch' 176 | else: 177 | param_str = 'Global_Batch' 178 | 179 | labels['train_metric_auc'] = "Train_" + param_str + "_Graph_" 180 | labels['train_metric_ap'] = "Train_" + param_str + "_Graph_" 181 | labels['test_metric_auc'] = "Test_" + param_str + "_Graph_" 182 | labels['test_metric_ap'] = "Test_" + param_str + "_Graph_" 183 | if username == "joeybose": 184 | labels['experiments_key'] = [[args.two_maml],\ 185 | [args.random_exp],\ 186 | [args.no_finetune],\ 187 | [args.finetune],\ 188 | [args.adamic_adar]\ 189 | ] 190 | if args.local: 191 | labels['experiments_name'] = ['2-MAML', 'NoFinetune',\ 192 | 'Finetune','Adamic-Adar'] 193 | else: 194 | labels['experiments_name'] = 
['2-MAML', 'Random', 'NoFinetune',\ 195 | 'Finetune'] 196 | 197 | args.local_block = 3 198 | args.global_block = [2] 199 | return labels 200 | 201 | def truncate_exp(data_experiments): 202 | last_data_points = [run[-1] for data_run in data_experiments for run in data_run] 203 | run_end_times = [timestep for timestep, value in last_data_points] 204 | earliest_end_time = min(run_end_times) 205 | 206 | clean_data_experiments = [] 207 | for exp in data_experiments: 208 | clean_data_runs = [] 209 | for run in exp: 210 | clean_data_runs.append({x: y for x, y in run if x <= earliest_end_time}) 211 | clean_data_experiments.append(clean_data_runs) 212 | 213 | return clean_data_experiments 214 | 215 | def get_data(args, title, x_label, y_label, labels_list, data, COMET_API_KEY,\ 216 | COMET_REST_API_KEY,comet_username,comet_project): 217 | if not title or not x_label or not y_label or not labels_list: 218 | print("Error!!! Ensure filename, x and y labels,\ 219 | and metric are present.") 220 | exit(1) 221 | 222 | train_auc = labels_list['train_metric_auc'] 223 | train_ap = labels_list['train_metric_ap'] 224 | test_auc = labels_list['test_metric_auc'] 225 | test_ap = labels_list['test_metric_ap'] 226 | 227 | comet_api, comet_username, comet_project = connect_to_comet(COMET_API_KEY,\ 228 | COMET_REST_API_KEY,\ 229 | comet_username,\ 230 | comet_project) 231 | 232 | # Accumulate data for all experiments. 233 | data_experiments_auc = [] 234 | data_experiments_ap = [] 235 | for i, runs in enumerate(data): 236 | # Accumulate data for all runs of a given experiment. 
237 | if i >= args.local_block and not args.local: 238 | break 239 | if (i in args.global_block) and args.local: 240 | continue 241 | data_runs_auc = [] 242 | data_runs_ap = [] 243 | if len(runs) > 0: 244 | for exp_key in runs: 245 | try: 246 | raw_data = comet_api.get("%s/%s/%s" %(comet_username,\ 247 | comet_project, exp_key)) 248 | if args.mode == 'Train': 249 | for j in range(0,args.num_train_graphs): 250 | metric_auc = train_auc + str(j) + "_AUC" 251 | metric_ap = train_ap + str(j) + "_AP" 252 | data_points_auc = raw_data.metrics_raw[metric_auc] 253 | data_points_ap = raw_data.metrics_raw[metric_ap] 254 | data_points_auc = [[point[0]+1,point[1]] for point in data_points_auc] 255 | data_points_ap = [[point[0]+1,point[1]] for point in data_points_ap] 256 | data_runs_auc.append(data_points_auc) 257 | data_runs_ap.append(data_points_ap) 258 | elif args.mode =='Test' and args.dataset =='Reddit': 259 | for k in range(0,args.num_test_graphs): 260 | metric_auc = test_auc + str(k) + "_AUC" 261 | metric_ap = test_ap + str(k) + "_AP" 262 | data_points_auc = raw_data.metrics_raw[metric_auc] 263 | data_points_ap = raw_data.metrics_raw[metric_ap] 264 | data_points_auc = [[point[0]+1,point[1]] for point in data_points_auc] 265 | data_points_ap = [[point[0]+1,point[1]] for point in data_points_ap] 266 | data_runs_auc.append(data_points_auc) 267 | data_runs_ap.append(data_points_ap) 268 | else: 269 | for k in range(0,args.num_test_graphs): 270 | metric_auc = test_auc + str(k) + "_AUC" 271 | metric_ap = test_ap + str(k) + "_AP" 272 | data_points_auc = raw_data.metrics_raw[metric_auc] 273 | data_points_ap = raw_data.metrics_raw[metric_ap] 274 | data_points_auc = [[point[0]+1,point[1]] for point in data_points_auc] 275 | data_points_ap = [[point[0]+1,point[1]] for point in data_points_ap] 276 | data_runs_auc.append(data_points_auc) 277 | data_runs_ap.append(data_points_ap) 278 | 279 | data_experiments_auc.append(data_runs_auc) 280 | data_experiments_ap.append(data_runs_ap) 281 | 
except: 282 | print("Failed on %s" %(exp_key)) 283 | 284 | clean_data_experiments_auc = truncate_exp(data_experiments_auc) 285 | clean_data_experiments_ap = truncate_exp(data_experiments_ap) 286 | return clean_data_experiments_auc, clean_data_experiments_ap 287 | 288 | def plot(**kwargs): 289 | labels = kwargs.get('labels') 290 | data = kwargs.get('data') 291 | 292 | # Setup figure 293 | fig = plt.figure(figsize=(12, 8)) 294 | ax = plt.subplot() 295 | 296 | for label in (ax.get_xticklabels()): 297 | label.set_fontname('Arial') 298 | label.set_fontsize(20) 299 | for label in (ax.get_yticklabels()): 300 | label.set_fontname('Arial') 301 | label.set_fontsize(20) 302 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 303 | plt.yticks(np.arange(0, 1, 0.1)) 304 | ax.xaxis.get_offset_text().set_fontsize(10) 305 | axis_font = {'fontname': 'Arial', 'size': '24'} 306 | colors = sns.color_palette('colorblind', n_colors=len(data)) 307 | 308 | # Plot data 309 | for runs, label, color in zip(data, labels.get('experiments_name'), colors): 310 | unique_x_values = set() 311 | for run in runs: 312 | for key in run.keys(): 313 | unique_x_values.add(key) 314 | x_values = sorted(unique_x_values) 315 | 316 | # Plot mean and standard deviation of all runs 317 | y_values_mean = [] 318 | y_values_std = [] 319 | 320 | for x in x_values: 321 | y_values_mean.append(mean([run.get(x) for run in runs if run.get(x)])) 322 | y_values_std.append(np.std([run.get(x) for run in runs if run.get(x)])) 323 | 324 | x_values.insert(0,0) 325 | y_values_mean.insert(0,0) 326 | y_values_std.insert(0,0) 327 | print("%s average result after graphs %f" %(label,y_values_mean[-1])) 328 | # Plot std 329 | ax.fill_between(x_values, np.add(np.array(y_values_mean), np.array(y_values_std)), 330 | np.subtract(np.array(y_values_mean), np.array(y_values_std)), 331 | alpha=0.3, 332 | edgecolor=color, facecolor=color) 333 | # Plot mean 334 | plt.plot(x_values, y_values_mean, color=color, linewidth=1.5, 
label=label) 335 | 336 | # Label figure 337 | ax.legend(loc='lower right', prop={'size': 16}) 338 | ax.set_xlabel(labels.get('x_label'), **axis_font) 339 | ax.set_ylabel(labels.get('y_label'), **axis_font) 340 | fig.subplots_adjust(bottom=0.2) 341 | fig.subplots_adjust(left=0.2) 342 | 343 | # remove grid lines 344 | ax.grid(False) 345 | plt.grid(b=False, color='w') 346 | return fig 347 | 348 | def main(args): 349 | Joey_COMET_API_KEY="Ht9lkWvTm58fRo9ccgpabq5zV" 350 | Joey_COMET_REST_API_KEY="gvhm1m1y8OUTnPRqJarpeTapL" 351 | comet_project = args.comet_project 352 | comet_username = "joeybose" 353 | if args.dataset == 'PPI': 354 | extract_func = data_to_extract_ppi 355 | elif args.dataset == 'Reddit': 356 | comet_project = 'meta-graph-reddit' 357 | extract_func = data_to_extract_reddit 358 | 359 | labels = extract_func("joeybose",args) 360 | data_experiments_auc, data_experiments_ap = get_data(args,labels.get('title'), labels.get('x_label'),\ 361 | labels.get('y_label'), labels,\ 362 | labels.get('experiments_key'),COMET_API_KEY=Joey_COMET_API_KEY,\ 363 | COMET_REST_API_KEY=Joey_COMET_REST_API_KEY,\ 364 | comet_project=comet_project,\ 365 | comet_username=comet_username) 366 | fig_auc = plot(labels=labels, data=data_experiments_auc) 367 | fig_ap = plot(labels=labels, data=data_experiments_ap) 368 | if args.local: 369 | param_str = '_Local_' 370 | else: 371 | param_str = '_Global_' 372 | fig_auc.savefig('../plots_datasets/'+ args.dataset + '/' + args.file_str + 373 | param_str + args.mode +'_new_AUC.png') 374 | fig_ap.savefig('../plots_datasets/'+ args.dataset + '/' + args.file_str + 375 | param_str+ args.mode + '_new_AP.png') 376 | 377 | if __name__ == '__main__': 378 | parser = argparse.ArgumentParser() 379 | parser.add_argument('--source_filename', default='plot_source.csv') 380 | parser.add_argument("--local", action="store_true", default=False) 381 | parser.add_argument('--mode', type=str, default='Train') 382 | parser.add_argument('--file_str', type=str, 
default='') 383 | parser.add_argument('--one_maml', type=str, default='') 384 | parser.add_argument('--two_maml', type=str, default='') 385 | parser.add_argument('--concat', type=str, default='') 386 | parser.add_argument('--random_exp', type=str, default='') 387 | parser.add_argument('--no_finetune', type=str, default='') 388 | parser.add_argument('--finetune', type=str, default='') 389 | parser.add_argument('--adamic_adar', type=str, default='') 390 | parser.add_argument('--mlp', type=str, default='') 391 | parser.add_argument('--graph_sig', type=str, default='') 392 | parser.add_argument('--comet_project', type=str, default='meta-graph') 393 | parser.add_argument('--dataset', type=str, default='PPI') 394 | parser.add_argument('--local_block', type=int, default=0) 395 | parser.add_argument('--global_block', type=int, default=2) 396 | args = parser.parse_args() 397 | 398 | if args.dataset == 'PPI': 399 | args.num_train_graphs = 20 400 | args.num_test_graphs = 2 401 | elif args.dataset == 'ENZYMES': 402 | args.num_train_graphs = 10 403 | args.num_test_graphs = 10 404 | elif args.dataset == 'Reddit': 405 | args.num_train_graphs = 10 406 | args.num_test_graphs = 10 407 | else: 408 | raise NotImplementedError 409 | 410 | main(args) 411 | 412 | -------------------------------------------------------------------------------- /scripts/run_adamic_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.1' --adamic_adar_baseline --comet 8 | python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.2' --adamic_adar_baseline --comet 9 | python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 
--namestr='2-MAML Adamic PPI Ratio=0.3' --adamic_adar_baseline --comet 10 | python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.4' --adamic_adar_baseline --comet 11 | python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.5' --adamic_adar_baseline --comet 12 | python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.6' --adamic_adar_baseline --comet 13 | python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML Adamic PPI Ratio=0.7' --adamic_adar_baseline --comet 14 | -------------------------------------------------------------------------------- /scripts/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | # No Finetuning 8 | #python vgae.py --model='VGAE' --epochs=2500 --comet --meta_train_edge_ratio=0.1 --namestr='NoFinetune Ratio=0.1' 9 | #python vgae.py --model='VGAE' --epochs=2500 --comet --meta_train_edge_ratio=0.2 --namestr='NoFinetune Ratio=0.2' 10 | #python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.3 --namestr='NoFinetune Ratio=0.3' 11 | python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.4 --namestr='NoFinetune Ratio=0.4' 12 | python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.5 --namestr='NoFinetune Ratio=0.5' 13 | python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.6 --namestr='NoFinetune Ratio=0.6' 14 | python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.7 --namestr='NoFinetune Ratio=0.7' 15 | python vgae.py --epochs=2500 --model='VGAE' --comet --meta_train_edge_ratio=0.8
--namestr='NoFinetune Ratio=0.8' 16 | 17 | ## Finetuning 18 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.1 --namestr='Finetune Ratio=0.1' 19 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.2 --namestr='Finetune Ratio=0.2' 20 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.3 --namestr='Finetune Ratio=0.3' 21 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.4 --namestr='Finetune Ratio=0.4' 22 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.5 --namestr='Finetune Ratio=0.5' 23 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.6 --namestr='Finetune Ratio=0.6' 24 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.7 --namestr='Finetune Ratio=0.7' 25 | python vgae.py --epochs=2500 --model='VGAE' --finetune --comet --meta_train_edge_ratio=0.8 --namestr='Finetune Ratio=0.8' 26 | 27 | # No Concat Finetuning 28 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.1 --namestr='Concat NoFinetune Ratio=0.1' 29 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.2 --namestr='Concat NoFinetune Ratio=0.2' 30 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.3 --namestr='Concat NoFinetune Ratio=0.3' 31 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.4 --namestr='Concat NoFinetune Ratio=0.4' 32 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.5 --namestr='Concat NoFinetune Ratio=0.5' 33 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.6 --namestr='Concat NoFinetune Ratio=0.6' 34 | python
vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.7 --namestr='Concat NoFinetune Ratio=0.7' 35 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --comet --meta_train_edge_ratio=0.8 --namestr='Concat NoFinetune Ratio=0.8' 36 | 37 | ## Concat Finetuning 38 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.1 --namestr='Concat Finetune Ratio=0.1' 39 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.2 --namestr='Concat Finetune Ratio=0.2' 40 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.3 --namestr='Concat Finetune Ratio=0.3' 41 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.4 --namestr='Concat Finetune Ratio=0.4' 42 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.5 --namestr='Concat Finetune Ratio=0.5' 43 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.6 --namestr='Concat Finetune Ratio=0.6' 44 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.7 --namestr='Concat Finetune Ratio=0.7' 45 | python vgae.py --epochs=2500 --model='VGAE' --concat_fixed_feats --finetune --comet --meta_train_edge_ratio=0.8 --namestr='Concat Finetune Ratio=0.8' 46 | 47 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.2 --namestr='ENZYMES NoFinetune Ratio=0.2' 48 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.3 --namestr='ENZYMES NoFinetune Ratio=0.3' 49 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet
--meta_train_edge_ratio=0.4 --namestr='ENZYMES NoFinetune Ratio=0.4' 50 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.5 --namestr='ENZYMES NoFinetune Ratio=0.5' 51 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.6 --namestr='ENZYMES NoFinetune Ratio=0.6' 52 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.7 --namestr='ENZYMES NoFinetune Ratio=0.7' 53 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --comet --meta_train_edge_ratio=0.8 --namestr='ENZYMES NoFinetune Ratio=0.8' 54 | 55 | # Finetuning 56 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.2 --namestr='ENZYMES Finetune Ratio=0.2' 57 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.3 --namestr='ENZYMES Finetune Ratio=0.3' 58 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.4 --namestr='ENZYMES Finetune Ratio=0.4' 59 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.5 --namestr='ENZYMES Finetune Ratio=0.5' 60 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.6 --namestr='ENZYMES Finetune Ratio=0.6' 61 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.7 --namestr='ENZYMES Finetune Ratio=0.7' 62 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.8 --namestr='ENZYMES Finetune Ratio=0.8' 63 | 64 | #Concat feats 65 |
python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.2 --namestr='ENZYMES Concat NoFinetune Ratio=0.2' 66 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.3 --namestr='ENZYMES Concat NoFinetune Ratio=0.3' 67 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.4 --namestr='ENZYMES Concat NoFinetune Ratio=0.4' 68 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.5 --namestr='ENZYMES Concat NoFinetune Ratio=0.5' 69 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.6 --namestr='ENZYMES Concat NoFinetune Ratio=0.6' 70 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.7 --namestr='ENZYMES Concat NoFinetune Ratio=0.7' 71 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --comet --meta_train_edge_ratio=0.8 --namestr='ENZYMES Concat NoFinetune Ratio=0.8' 72 | 73 | # Finetuning 74 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.2 --namestr='ENZYMES Concat Finetune Ratio=0.2' 75 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.3 --namestr='ENZYMES Concat Finetune Ratio=0.3' 76 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.4 --namestr='ENZYMES Concat Finetune Ratio=0.4' 77 | python vgae.py
--epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.5 --namestr='ENZYMES Concat Finetune Ratio=0.5' 78 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.6 --namestr='ENZYMES Concat Finetune Ratio=0.6' 79 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.7 --namestr='ENZYMES Concat Finetune Ratio=0.7' 80 | python vgae.py --epochs=2500 --model='VGAE' --train_batch_size=4 --concat_fixed_feats --dataset=ENZYMES --finetune --comet --meta_train_edge_ratio=0.8 --namestr='ENZYMES Concat Finetune Ratio=0.8' 81 | -------------------------------------------------------------------------------- /scripts/run_hyperparam.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | echo "Running Batch Size sweep PPI experiments" 7 | 8 | python main.py --meta_train_edge_ratio=0.2 --epochs=10 --train_batch_size=1 --model='VGAE' --order=2 --namestr='BS 1 2-MAML PPI Ratio=0.2' --comet 9 | python main.py --meta_train_edge_ratio=0.2 --epochs=10 --train_batch_size=2 --model='VGAE' --order=2 --namestr='BS 2 2-MAML PPI Ratio=0.2' --comet 10 | python main.py --meta_train_edge_ratio=0.2 --epochs=10 --train_batch_size=4 --model='VGAE' --order=2 --namestr='BS 4 2-MAML PPI Ratio=0.2' --comet 11 | python main.py --meta_train_edge_ratio=0.2 --epochs=10 --train_batch_size=5 --model='VGAE' --order=2 --namestr='BS 5 2-MAML PPI Ratio=0.2' --comet 12 | python main.py --meta_train_edge_ratio=0.2 --epochs=10 --train_batch_size=10 --model='VGAE' --order=2 --namestr='BS 10 2-MAML PPI Ratio=0.2' --comet 13 | -------------------------------------------------------------------------------- /scripts/run_maml.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | echo "Running PPI experiments" 7 | 8 | #python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.1' --comet 9 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.2' --comet 10 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.3' --comet 11 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.4' --comet 12 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.5' --comet 13 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.6' --comet 14 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.7' --comet 15 | #python main.py --meta_train_edge_ratio=0.8 --model='VGAE' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML PPI Ratio=0.8' --comet 16 | 17 | #python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.1' --comet 18 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.2' --comet 19 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.3' --comet 20 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --encoder='MLP' --epochs=50 
--train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.4' --comet 21 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.5' --comet 22 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.6' --comet 23 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML MLP PPI Ratio=0.7' --comet 24 | 25 | echo "Running PPI Graph Signature experiments" 26 | python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.1' --comet 27 | python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.2' --comet 28 | python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.3' --comet 29 | python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.4' --comet 30 | python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.5' --comet 31 | python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.6' --comet 32 | python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --encoder='GraphSignature' --inner-lr=1e-3 --epochs=50 --train_batch_size=1 --order=2 --namestr='2-MAML GS PPI Ratio=0.7' --comet 33 | 34 | 
echo "Running PPI with concat-feats experiments" 35 | 36 | #python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.1' --comet 37 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.2' --comet 38 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.3' --comet 39 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.4' --comet 40 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.5' --comet 41 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.6' --comet 42 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --order=2 --namestr='2-MAML Concat PPI Ratio=0.7' --comet 43 | 44 | echo "Running ENZYMES experiments" 45 | 46 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES Ratio=0.2' --comet 47 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES Ratio=0.3' --comet 48 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES Ratio=0.4' --comet 49 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES 
Ratio=0.5' --comet 50 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES Ratio=0.6' --comet 51 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML ENZYMES Ratio=0.7' --comet 52 | 53 | 54 | echo "Running ENZYMES with concat-feats experiments" 55 | 56 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.2' --comet 57 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.3' --comet 58 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.4' --comet 59 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.5' --comet 60 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.6' --comet 61 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --concat_fixed_feats --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML Concat ENZYMES Ratio=0.7' --comet 62 | 63 | echo "Running ENZYMES MLP experiments" 64 | 65 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.2' --comet 66 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES 
--order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.3' --comet 67 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.4' --comet 68 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.5' --comet 69 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.6' --comet 70 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --encoder='MLP' --epochs=50 --train_batch_size=1 --dataset=ENZYMES --order=2 --namestr='2-MAML MLP ENZYMES Ratio=0.7' --comet 71 | 72 | echo "Running ENZYMES Random Baseline experiments" 73 | 74 | #python main.py --epochs=50 --meta_train_edge_ratio=0.1 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.1' --comet 75 | #python main.py --epochs=50 --meta_train_edge_ratio=0.2 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.2' --comet 76 | #python main.py --epochs=50 --meta_train_edge_ratio=0.3 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.3' --comet 77 | #python main.py --epochs=50 --meta_train_edge_ratio=0.4 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.4' --comet 78 | #python main.py --epochs=50 --meta_train_edge_ratio=0.5 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.5' --comet 79 | #python main.py --epochs=50 --meta_train_edge_ratio=0.6 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.6' --comet 80 | #python main.py --epochs=50 --meta_train_edge_ratio=0.7 --dataset=ENZYMES --random_baseline --namestr='Random ENZYMES Baseline Ratio=0.7' --comet 81 | 82 | echo "Running ENZYMES Adamic Baseline 
experiments" 83 | 84 | #python main.py --meta_train_edge_ratio=0.1 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.1' --adamic_adar_baseline --comet 85 | #python main.py --meta_train_edge_ratio=0.2 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.2' --adamic_adar_baseline --comet 86 | #python main.py --meta_train_edge_ratio=0.3 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.3' --adamic_adar_baseline --comet 87 | #python main.py --meta_train_edge_ratio=0.4 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.4' --adamic_adar_baseline --comet 88 | #python main.py --meta_train_edge_ratio=0.5 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.5' --adamic_adar_baseline --comet 89 | #python main.py --meta_train_edge_ratio=0.6 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.6' --adamic_adar_baseline --comet 90 | #python main.py --meta_train_edge_ratio=0.7 --model='VGAE' --epochs=50 --dataset=ENZYMES --train_batch_size=1 --order=2 --namestr='2-MAML ENZYMES Adamic Ratio=0.7' --adamic_adar_baseline --comet 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /scripts/run_plotter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | ### 0.1 8 | python plot_comet.py --mode='Train' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' 
--finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 9 | python plot_comet.py --mode='Test' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 10 | python plot_comet.py --local --mode='Train' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 11 | python plot_comet.py --local --mode='Test' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 12 | 13 | #### 0.2 14 | python plot_comet.py --mode='Train' --file_str='0.2' --two_maml='cf3af6f492144444950e847d67e747ea' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='0e07866b88914e3881333263d74b2a2a' --finetune='767dee3c84234285bada554a36c86ce5' --adamic_adar='c27f337d03574000aa1c17bce24c0ff5' --mlp='6006de62f3544474ba00f53d4388240e' --graph_sig='cd266a4b5b044cf7a27eff8b9b20c7d5' 15 | python plot_comet.py --mode='Test' --file_str='0.2' 
--two_maml='cf3af6f492144444950e847d67e747ea' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='0e07866b88914e3881333263d74b2a2a' --finetune='767dee3c84234285bada554a36c86ce5' --adamic_adar='c27f337d03574000aa1c17bce24c0ff5' --mlp='6006de62f3544474ba00f53d4388240e' --graph_sig='cd266a4b5b044cf7a27eff8b9b20c7d5' 16 | python plot_comet.py --local --mode='Train' --file_str='0.2' --two_maml='cf3af6f492144444950e847d67e747ea' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='0e07866b88914e3881333263d74b2a2a' --finetune='767dee3c84234285bada554a36c86ce5' --adamic_adar='c27f337d03574000aa1c17bce24c0ff5' --mlp='6006de62f3544474ba00f53d4388240e' --graph_sig='cd266a4b5b044cf7a27eff8b9b20c7d5' 17 | python plot_comet.py --local --mode='Test' --file_str='0.2' --two_maml='cf3af6f492144444950e847d67e747ea' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='0e07866b88914e3881333263d74b2a2a' --finetune='767dee3c84234285bada554a36c86ce5' --adamic_adar='c27f337d03574000aa1c17bce24c0ff5' --mlp='6006de62f3544474ba00f53d4388240e' --graph_sig='cd266a4b5b044cf7a27eff8b9b20c7d5' 18 | 19 | ### 0.3 20 | python plot_comet.py --mode='Train' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' --graph_sig='9f0330a4667d40e6aed822fb0a69a468' 21 | python plot_comet.py --mode='Test' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' 
--adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' --graph_sig='9f0330a4667d40e6aed822fb0a69a468' 22 | python plot_comet.py --local --mode='Train' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' --graph_sig='9f0330a4667d40e6aed822fb0a69a468' 23 | python plot_comet.py --local --mode='Test' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' --graph_sig='9f0330a4667d40e6aed822fb0a69a468' 24 | 25 | ### 0.4 26 | python plot_comet.py --mode='Train' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' --graph_sig='717b2e248c1e4315a5be2473a29da1f8' 27 | python plot_comet.py --mode='Test' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' --graph_sig='717b2e248c1e4315a5be2473a29da1f8' 28 | python plot_comet.py --local --mode='Train' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' 
--concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' --graph_sig='717b2e248c1e4315a5be2473a29da1f8' 29 | python plot_comet.py --local --mode='Test' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' --graph_sig='717b2e248c1e4315a5be2473a29da1f8' 30 | 31 | ### 0.5 32 | python plot_comet.py --mode='Train' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' --graph_sig='5ba32a33da814c468b4755c96f755f6c' 33 | python plot_comet.py --mode='Test' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' --graph_sig='5ba32a33da814c468b4755c96f755f6c' 34 | python plot_comet.py --local --mode='Train' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' 
--mlp='e7b22ed70d5f4df682cc3989591545a5' --graph_sig='5ba32a33da814c468b4755c96f755f6c' 35 | python plot_comet.py --local --mode='Test' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' --graph_sig='5ba32a33da814c468b4755c96f755f6c' 36 | 37 | ### 0.6 38 | python plot_comet.py --mode='Train' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' --graph_sig='b92dadad1f5f47dc927a3f9015807a2f' 39 | python plot_comet.py --mode='Test' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' --graph_sig='b92dadad1f5f47dc927a3f9015807a2f' 40 | python plot_comet.py --local --mode='Train' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' --graph_sig='b92dadad1f5f47dc927a3f9015807a2f' 41 | python plot_comet.py --local --mode='Test' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' 
--random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' --graph_sig='b92dadad1f5f47dc927a3f9015807a2f' 42 | 43 | ### 0.7 44 | python plot_comet.py --mode='Train' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' --graph_sig='fb9de0595da047bb9614d7c8bddf190c' 45 | python plot_comet.py --mode='Test' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' --graph_sig='fb9de0595da047bb9614d7c8bddf190c' 46 | python plot_comet.py --local --mode='Train' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' --graph_sig='fb9de0595da047bb9614d7c8bddf190c' 47 | python plot_comet.py --local --mode='Test' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' 
--graph_sig='fb9de0595da047bb9614d7c8bddf190c' 48 | 49 | -------------------------------------------------------------------------------- /scripts/run_plotter_aminer_grad_wandb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | echo "AMINER 0.1" 8 | ### 0.1 9 | python plot_wandb.py --mode='Train' --file_str='0.1' --two_maml='6r1yr00u' --no_finetune='4ivd72vi' --finetune='dxi5g89l' --mlp='x2le9exg' --graph_sig='32ovw30r' --dataset='AMINER' --get_grad_steps 10 | echo "AMINER 0.2" 11 | ### 0.2 12 | python plot_wandb.py --mode='Train' --file_str='0.2' --two_maml='2evksgb8' --no_finetune='h1l3y9di' --finetune='wqd4c252' --mlp='zqdw5c0g' --graph_sig='32ovw30r' --dataset='AMINER' --get_grad_steps 13 | echo "AMINER 0.3" 14 | ### 0.3 15 | python plot_wandb.py --mode='Train' --file_str='0.3' --two_maml='xefnxixv' --no_finetune='x1j83n08' --finetune='vd54v95d' --mlp='jrjl09cr' --graph_sig='mqhij7tm' --dataset='AMINER' --get_grad_steps 16 | echo "AMINER 0.4" 17 | ### 0.4 18 | python plot_wandb.py --mode='Train' --file_str='0.4' --two_maml='5ny31pj1' --no_finetune='ivg4ymhc' --finetune='lq2xbxly' --mlp='jrjl09cr' --graph_sig='b5xbsxht' --dataset='AMINER' --get_grad_steps 19 | echo "AMINER 0.5" 20 | ### 0.5 21 | python plot_wandb.py --mode='Train' --file_str='0.5' --two_maml='f6vjm2t1' --no_finetune='uc30rjw4' --finetune='uhzcx9tv' --mlp='la0af8s8' --graph_sig='a10djfgb' --dataset='AMINER' --get_grad_steps 22 | echo "AMINER 0.6" 23 | ### 0.6 24 | python plot_wandb.py --mode='Train' --file_str='0.6' --two_maml='jmk7lwgb' --no_finetune='4wrbcf4t' --finetune='yf37ohc3' --mlp='pkec71vq' --graph_sig='21qnqbcz' --dataset='AMINER' --get_grad_steps 25 | echo "AMINER 0.7" 26 | ### 0.7 27 | python plot_wandb.py --mode='Train' --file_str='0.7' --two_maml='yuyjl239' --no_finetune='n718az0j' --finetune='uovbbb2w' --mlp='sa0yq1q8' --graph_sig='pnj299hl' --dataset='AMINER'
--get_grad_steps 28 | -------------------------------------------------------------------------------- /scripts/run_plotter_enzymes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | #### 0.1 8 | #python plot_comet.py --mode='Train' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' 9 | #python plot_comet.py --mode='Test' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' 10 | #python plot_comet.py --local --mode='Train' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' 11 | #python plot_comet.py --local --mode='Test' --file_str='0.1' --two_maml='6a5f16fa27ee4975ac8826119d5555a6' --concat='f592cbe24cc94453a8c24caa428ba19b' --random_exp='641efaec558a43d3aaa321159221ed5c' --no_finetune='1a59966fd8a54f43bbb4cb4cfb2020ef' --finetune='7a429694366d4f53be3f9f7b4a5e45e5' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' 12 | 13 | ### 0.2 14 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.2' --two_maml='78db36aeadcd4c8d8e32f716c4a85863' 
--concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='bc1923dc5a4f42cf9d1f5bb9f368e8ff' --finetune='c320a18f3cd2496ab0af3a5a1768d7e3' --adamic_adar='9e9de9c3eafb41fbb03e8e33607077e7' --mlp='137a1c326f604e48920d60af65432eb9' 15 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.2' --two_maml='78db36aeadcd4c8d8e32f716c4a85863' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='bc1923dc5a4f42cf9d1f5bb9f368e8ff' --finetune='c320a18f3cd2496ab0af3a5a1768d7e3' --adamic_adar='9e9de9c3eafb41fbb03e8e33607077e7' --mlp='137a1c326f604e48920d60af65432eb9' 16 | python plot_comet.py --dataset='ENZYMES' --local --mode='Train' --file_str='0.2' --two_maml='78db36aeadcd4c8d8e32f716c4a85863' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='bc1923dc5a4f42cf9d1f5bb9f368e8ff' --finetune='c320a18f3cd2496ab0af3a5a1768d7e3' --adamic_adar='9e9de9c3eafb41fbb03e8e33607077e7' --mlp='137a1c326f604e48920d60af65432eb9' 17 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.2' --two_maml='78db36aeadcd4c8d8e32f716c4a85863' --concat='4a8c166c208448f3b59e5059a5c1773c' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='bc1923dc5a4f42cf9d1f5bb9f368e8ff' --finetune='c320a18f3cd2496ab0af3a5a1768d7e3' --adamic_adar='9e9de9c3eafb41fbb03e8e33607077e7' --mlp='137a1c326f604e48920d60af65432eb9' 18 | 19 | ## 0.3 20 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.3' --two_maml='e3202b0474fd41f8824e6d1fa13d7c1d' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='f285cc7ca3d9485c8ba4eaeea6bd54dc' --finetune='f7104f4c96914e378bb1e5733cc061ba' --adamic_adar='3eb9517a98984d2d98c57cf4e89244ca' --mlp='95643f80b0b348a48117c13321910025' 21 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.3' 
--two_maml='e3202b0474fd41f8824e6d1fa13d7c1d' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='f285cc7ca3d9485c8ba4eaeea6bd54dc' --finetune='f7104f4c96914e378bb1e5733cc061ba' --adamic_adar='3eb9517a98984d2d98c57cf4e89244ca' --mlp='95643f80b0b348a48117c13321910025' 22 | python plot_comet.py --dataset='ENZYMES' --local --mode='Train' --file_str='0.3' --two_maml='e3202b0474fd41f8824e6d1fa13d7c1d' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='f285cc7ca3d9485c8ba4eaeea6bd54dc' --finetune='f7104f4c96914e378bb1e5733cc061ba' --adamic_adar='3eb9517a98984d2d98c57cf4e89244ca' --mlp='95643f80b0b348a48117c13321910025' 23 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.3' --two_maml='e3202b0474fd41f8824e6d1fa13d7c1d' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='f285cc7ca3d9485c8ba4eaeea6bd54dc' --finetune='f7104f4c96914e378bb1e5733cc061ba' --adamic_adar='3eb9517a98984d2d98c57cf4e89244ca' --mlp='95643f80b0b348a48117c13321910025' 24 | 25 | ## 0.4 26 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.4' --two_maml='4203f23f7d294bc09ef7eab0d7dae814' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='d198387a81074a92a74b330266f3c2e2' --finetune='4f7360c275784f8f9cff1081da02923b' --adamic_adar='43997010f15e4cf1be4fdbfefbcd41e3' --mlp='5d0adea90ffa49f296cc355a64243b46' 27 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.4' --two_maml='4203f23f7d294bc09ef7eab0d7dae814' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='d198387a81074a92a74b330266f3c2e2' --finetune='4f7360c275784f8f9cff1081da02923b' --adamic_adar='43997010f15e4cf1be4fdbfefbcd41e3' --mlp='5d0adea90ffa49f296cc355a64243b46' 28 | python plot_comet.py 
--dataset='ENZYMES' --local --mode='Train' --file_str='0.4' --two_maml='4203f23f7d294bc09ef7eab0d7dae814' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='d198387a81074a92a74b330266f3c2e2' --finetune='4f7360c275784f8f9cff1081da02923b' --adamic_adar='43997010f15e4cf1be4fdbfefbcd41e3' --mlp='5d0adea90ffa49f296cc355a64243b46' 29 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.4' --two_maml='4203f23f7d294bc09ef7eab0d7dae814' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='d198387a81074a92a74b330266f3c2e2' --finetune='4f7360c275784f8f9cff1081da02923b' --adamic_adar='43997010f15e4cf1be4fdbfefbcd41e3' --mlp='5d0adea90ffa49f296cc355a64243b46' 30 | 31 | ## 0.5 32 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.5' --two_maml='72f242b60a474c3699a5968abdc5a7e9' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='311e40aa6b85458aac2ac9e206ff97d5' --finetune='8322bbf4cddb468ea61a13a43b3fb051' --adamic_adar='5d320c94acd747c1a92d02daef6e518e' --mlp='ffef67f171904fddb61b9b7d8fca6273' 33 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.5' --two_maml='72f242b60a474c3699a5968abdc5a7e9' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='311e40aa6b85458aac2ac9e206ff97d5' --finetune='8322bbf4cddb468ea61a13a43b3fb051' --adamic_adar='5d320c94acd747c1a92d02daef6e518e' --mlp='ffef67f171904fddb61b9b7d8fca6273' 34 | python plot_comet.py --dataset='ENZYMES' --local --mode='Train' --file_str='0.5' --two_maml='72f242b60a474c3699a5968abdc5a7e9' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='311e40aa6b85458aac2ac9e206ff97d5' --finetune='8322bbf4cddb468ea61a13a43b3fb051' --adamic_adar='5d320c94acd747c1a92d02daef6e518e' 
--mlp='ffef67f171904fddb61b9b7d8fca6273' 35 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.5' --two_maml='72f242b60a474c3699a5968abdc5a7e9' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='311e40aa6b85458aac2ac9e206ff97d5' --finetune='8322bbf4cddb468ea61a13a43b3fb051' --adamic_adar='5d320c94acd747c1a92d02daef6e518e' --mlp='ffef67f171904fddb61b9b7d8fca6273' 36 | 37 | ## 0.6 38 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.6' --two_maml='c292a19b44a04753aa0f1cb6f6689af1' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='87fa62cbf9104819a7fc995249534d72' --finetune='544d993de28a42779fe44e4ee19fc291' --adamic_adar='53c3858b46fc4759bfa90178981af110' --mlp='8a7879a57f2d484b8b881bc7b7c36a37' 39 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.6' --two_maml='c292a19b44a04753aa0f1cb6f6689af1' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='87fa62cbf9104819a7fc995249534d72' --finetune='544d993de28a42779fe44e4ee19fc291' --adamic_adar='53c3858b46fc4759bfa90178981af110' --mlp='8a7879a57f2d484b8b881bc7b7c36a37' 40 | python plot_comet.py --dataset='ENZYMES' --local --mode='Train' --file_str='0.6' --two_maml='c292a19b44a04753aa0f1cb6f6689af1' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='87fa62cbf9104819a7fc995249534d72' --finetune='544d993de28a42779fe44e4ee19fc291' --adamic_adar='53c3858b46fc4759bfa90178981af110' --mlp='8a7879a57f2d484b8b881bc7b7c36a37' 41 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.6' --two_maml='c292a19b44a04753aa0f1cb6f6689af1' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='87fa62cbf9104819a7fc995249534d72' --finetune='544d993de28a42779fe44e4ee19fc291' 
--adamic_adar='53c3858b46fc4759bfa90178981af110' --mlp='8a7879a57f2d484b8b881bc7b7c36a37' 42 | 43 | ## 0.7 44 | python plot_comet.py --dataset='ENZYMES' --mode='Train' --file_str='0.7' --two_maml='ab319ac14f094e199a41bf3a982327a4' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='d5b808b929494346b5c8b7de5580a470' --finetune='0d12edc5ad4b4b1d964405171a70b0d1' --adamic_adar='d6ba9b13c4f8444bab53c97d8c839709' --mlp='04feb80fb1ee4b548eee5747804fe491' 45 | python plot_comet.py --dataset='ENZYMES' --mode='Test' --file_str='0.7' --two_maml='ab319ac14f094e199a41bf3a982327a4' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='d5b808b929494346b5c8b7de5580a470' --finetune='0d12edc5ad4b4b1d964405171a70b0d1' --adamic_adar='d6ba9b13c4f8444bab53c97d8c839709' --mlp='04feb80fb1ee4b548eee5747804fe491' 46 | python plot_comet.py --dataset='ENZYMES' --local --mode='Train' --file_str='0.7' --two_maml='ab319ac14f094e199a41bf3a982327a4' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='d5b808b929494346b5c8b7de5580a470' --finetune='0d12edc5ad4b4b1d964405171a70b0d1' --adamic_adar='d6ba9b13c4f8444bab53c97d8c839709' --mlp='04feb80fb1ee4b548eee5747804fe491' 47 | python plot_comet.py --dataset='ENZYMES' --local --mode='Test' --file_str='0.7' --two_maml='ab319ac14f094e199a41bf3a982327a4' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='d5b808b929494346b5c8b7de5580a470' --finetune='0d12edc5ad4b4b1d964405171a70b0d1' --adamic_adar='d6ba9b13c4f8444bab53c97d8c839709' --mlp='04feb80fb1ee4b548eee5747804fe491' 48 | 49 | -------------------------------------------------------------------------------- /scripts/run_plotter_firstmmdb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into 
the script" 6 | 7 | ### 0.1 8 | python plot_wandb.py --mode='Train' --file_str='0.1' --two_maml='a2bjgpdt' --random_exp='12uofdeu' --concat='kwj1z43t' --no_finetune='xq1gtt92' --finetune='z6vs4ixd' --adamic_adar='69mzklke' --mlp='um7d9qos' --graph_sig='lq3t724e' --graph_sig_concat='7j4lo0nj' --graph_sig_random='12uofdeu' --dataset='FIRSTMM_DB' --no_bar_plot --global_block 3 4 5 9 | python plot_wandb.py --mode='Test' --file_str='0.1' --two_maml='a2bjgpdt' --random_exp='12uofdeu' --concat='kwj1z43t' --no_finetune='xq1gtt92' --finetune='z6vs4ixd' --adamic_adar='69mzklke' --mlp='um7d9qos' --graph_sig='lq3t724e' --graph_sig_concat='7j4lo0nj' --graph_sig_random='12uofdeu' --dataset='FIRSTMM_DB' --no_bar_plot --global_block 3 4 5 10 | python plot_wandb.py --local --mode='Train' --file_str='0.1' --two_maml='a2bjgpdt' --random_exp='12uofdeu' --concat='kwj1z43t' --no_finetune='xq1gtt92' --finetune='z6vs4ixd' --adamic_adar='69mzklke' --mlp='um7d9qos' --graph_sig='lq3t724e' --graph_sig_concat='7j4lo0nj' --graph_sig_random='12uofdeu' --dataset='FIRSTMM_DB' --no_bar_plot --global_block 3 4 5 11 | python plot_wandb.py --local --mode='Test' --file_str='0.1' --two_maml='a2bjgpdt' --random_exp='12uofdeu' --concat='kwj1z43t' --no_finetune='xq1gtt92' --finetune='z6vs4ixd' --adamic_adar='69mzklke' --mlp='um7d9qos' --graph_sig='lq3t724e' --graph_sig_concat='7j4lo0nj' --graph_sig_random='12uofdeu' --dataset='FIRSTMM_DB' --no_bar_plot --global_block 3 4 5 12 | 13 | #### 0.2 14 | #python plot_wandb.py --mode='Train' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 15 | #python plot_wandb.py --mode='Test' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' 
--adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 16 | #python plot_wandb.py --local --mode='Train' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 17 | #python plot_wandb.py --local --mode='Test' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 18 | 19 | ##### 0.3 20 | #python plot_wandb.py --mode='Train' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 21 | #python plot_wandb.py --mode='Test' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 22 | #python plot_wandb.py --local --mode='Train' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 23 | #python plot_wandb.py --local --mode='Test' --file_str='0.3' 
--two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 24 | 25 | ##### 0.4 26 | #python plot_wandb.py --mode='Train' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 27 | #python plot_wandb.py --mode='Test' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 28 | #python plot_wandb.py --local --mode='Train' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 29 | #python plot_wandb.py --local --mode='Test' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 30 | 31 | ##### 0.5 32 | ##python plot_wandb.py --mode='Train' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 
--global_block 1 2 3 4 33 | ##python plot_wandb.py --mode='Test' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 34 | ##python plot_wandb.py --local --mode='Train' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 35 | ##python plot_wandb.py --local --mode='Test' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 36 | 37 | ##### 0.6 38 | ##python plot_wandb.py --mode='Train' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 39 | ##python plot_wandb.py --mode='Test' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 40 | ##python plot_wandb.py --local --mode='Train' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 41 | ##python plot_wandb.py --local 
--mode='Test' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='lq3t724e' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 1 --global_block 1 2 3 4 42 | 43 | #### 0.7 44 | #python plot_wandb.py --mode='Train' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 45 | #python plot_wandb.py --mode='Test' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 46 | #python plot_wandb.py --local --mode='Train' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 47 | #python plot_wandb.py --local --mode='Test' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='FIRSTMM_DB' --local_block 2 --global_block 3 4 5 48 | -------------------------------------------------------------------------------- /scripts/run_plotter_firstmmdb_grad_wandb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | echo "PPI 0.1" 8 
| ### 0.1 9 | python plot_wandb.py --mode='Train' --file_str='0.1' --two_maml='28oseu07' --no_finetune='s96dkq1z' --finetune='df8e2mhh' --mlp='gm0z3yxd' --graph_sig='bs0c1glm' --dataset='FIRSTMM_DB' --get_grad_steps 10 | echo "PPI 0.2" 11 | ### 0.2 12 | python plot_wandb.py --mode='Train' --file_str='0.2' --two_maml='g6k3381h' --no_finetune='jpsfaism' --finetune='vowc1u48' --mlp='cgvz7534' --graph_sig='o8ewiyol' --dataset='FIRSTMM_DB' --get_grad_steps 13 | echo "PPI 0.3" 14 | ### 0.3 15 | python plot_wandb.py --mode='Train' --file_str='0.3' --two_maml='0lo19msp' --no_finetune='tlnxol4d' --finetune='nad2ex1q' --mlp='joqd1918' --graph_sig='c4t93v8h' --dataset='FIRSTMM_DB' --get_grad_steps 16 | echo "PPI 0.4" 17 | ### 0.4 18 | python plot_wandb.py --mode='Train' --file_str='0.4' --two_maml='c8d28g6x' --no_finetune='tfincgtq' --finetune='o53cszvf' --mlp='6absh1lw' --graph_sig='cq9wlh5n' --dataset='FIRSTMM_DB' --get_grad_steps 19 | echo "PPI 0.5" 20 | ### 0.5 21 | python plot_wandb.py --mode='Train' --file_str='0.5' --two_maml='pahhhnd2' --no_finetune='0dg1jrid' --finetune='bipbpuqx' --mlp='z38ximym' --graph_sig='5gdfdlys' --dataset='FIRSTMM_DB' --get_grad_steps 22 | echo "PPI 0.6" 23 | ### 0.6 24 | python plot_wandb.py --mode='Train' --file_str='0.6' --two_maml='gwokit1o' --no_finetune='ud6hpkp5' --finetune='y5c191dh' --mlp='x50hlrep' --graph_sig='u1g20x5n' --dataset='FIRSTMM_DB' --get_grad_steps 25 | echo "PPI 0.7" 26 | ### 0.7 27 | python plot_wandb.py --mode='Train' --file_str='0.7' --two_maml='7weacjzd' --no_finetune='ajwxpznq' --finetune='deyr9bj7' --mlp='x7j9c5az' --graph_sig='wpqy10x1' --dataset='FIRSTMM_DB' --get_grad_steps 28 | -------------------------------------------------------------------------------- /scripts/run_plotter_ppi_grad_wandb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | echo "PPI 0.1" 8 | ### 0.1 9 | python 
plot_wandb.py --mode='Train' --file_str='0.1' --two_maml='l5n28018' --concat='hjvwed7o' --no_finetune='x6w700c7' --finetune='15l49u26' --mlp='plu09evs' --graph_sig='o0pqonmb' --dataset='PPI' --get_grad_steps 10 | echo "PPI 0.2" 11 | ### 0.2 12 | python plot_wandb.py --mode='Train' --file_str='0.2' --two_maml='665pydub' --concat='j9ia97gm' --no_finetune='dqvtbbf3' --finetune='tbs43bdn ' --mlp='3yxlpwbv' --graph_sig='8f6kukzs' --dataset='PPI' --get_grad_steps 13 | echo "PPI 0.3" 14 | ### 0.3 15 | python plot_wandb.py --mode='Train' --file_str='0.3' --two_maml='azytfs4u' --concat='dip9z64b' --no_finetune='fusq205s' --finetune='zxhtln38' --mlp='fmmvh6s7' --graph_sig='pzocn7d6' --dataset='PPI' --get_grad_steps 16 | echo "PPI 0.4" 17 | ### 0.4 18 | python plot_wandb.py --mode='Train' --file_str='0.4' --two_maml='cmyavyxz' --concat='1vhfzajp' --no_finetune='c3tl70d3' --finetune='v8dt60jw' --mlp='z6vc0iql' --graph_sig='csl5vhy7' --dataset='PPI' --get_grad_steps 19 | echo "PPI 0.5" 20 | ### 0.5 21 | python plot_wandb.py --mode='Train' --file_str='0.5' --two_maml='mgrs9sss' --concat='ot4c8qm4' --no_finetune='aam4zuim' --finetune='35e8mecc' --mlp='ohrk5bzd' --graph_sig='ww7z5tsf' --dataset='PPI' --get_grad_steps 22 | echo "PPI 0.6" 23 | ### 0.6 24 | python plot_wandb.py --mode='Train' --file_str='0.6' --two_maml='c8qqc6ti' --concat='jn8kw9pn' --no_finetune='zg5xlims' --finetune='lowk3n6j' --mlp='i5fgp3eh' --graph_sig='mtlffk3p' --dataset='PPI' --get_grad_steps 25 | echo "PPI 0.7" 26 | ### 0.7 27 | python plot_wandb.py --mode='Train' --file_str='0.7' --two_maml='my3soycv' --concat='jbkm9yoj' --no_finetune='4n4t5xmp' --finetune='tplubzay' --mlp='mqskgi0m' --graph_sig='amvq1kn7' --dataset='PPI' --get_grad_steps 28 | -------------------------------------------------------------------------------- /scripts/run_plotter_ppi_wandb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the 
script" 6 | 7 | ### 0.1 8 | python plot_wandb.py --mode='Train' --file_str='0.1' --two_maml='51a137h0' --random_exp='958jmtdt' --concat='5h6tv117' --no_finetune='oicast9p' --finetune='jwtqqcnx' --adamic_adar='jym8wrdk' --mlp='ycict122' --graph_sig='fra1kxzs' --graph_sig_concat='f5n0421a' --graph_sig_random='acwam7wd' --dataset='PPI' --no_bar_plot --global_block 3 4 5 9 | python plot_wandb.py --mode='Test' --file_str='0.1' --two_maml='51a137h0' --random_exp='958jmtdt' --concat='5h6tv117' --no_finetune='oicast9p' --finetune='jwtqqcnx' --adamic_adar='jym8wrdk' --mlp='ycict122' --graph_sig='fra1kxzs' --graph_sig_concat='f5n0421a' --graph_sig_random='acwam7wd' --dataset='PPI' --no_bar_plot --global_block 3 4 5 10 | python plot_wandb.py --local --mode='Train' --file_str='0.1' --two_maml='51a137h0' --random_exp='958jmtdt' --concat='5h6tv117' --no_finetune='oicast9p' --finetune='jwtqqcnx' --adamic_adar='jym8wrdk' --mlp='ycict122' --graph_sig='fra1kxzs' --graph_sig_concat='f5n0421a' --graph_sig_random='acwam7wd' --dataset='PPI' --no_bar_plot --global_block 3 4 5 11 | python plot_wandb.py --local --mode='Test' --file_str='0.1' --two_maml='51a137h0' --random_exp='958jmtdt' --concat='5h6tv117' --no_finetune='oicast9p' --finetune='jwtqqcnx' --adamic_adar='jym8wrdk' --mlp='ycict122' --graph_sig='fra1kxzs' --graph_sig_concat='f5n0421a' --graph_sig_random='acwam7wd' --dataset='PPI' --no_bar_plot --global_block 3 4 5 12 | 13 | #### 0.2 14 | #python plot_wandb.py --mode='Train' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 15 | #python plot_wandb.py --mode='Test' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' 
--graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 16 | #python plot_wandb.py --local --mode='Train' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 17 | #python plot_wandb.py --local --mode='Test' --file_str='0.2' --two_maml='2fx6yzar' --random_exp='r5nf1r3g' --concat='nwrxg66k' --no_finetune='mfd3suff' --finetune='ufpmxpe7' --adamic_adar='xk73qql6' --mlp='ycm7128c' --graph_sig='18z1zlhp' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 18 | 19 | ##### 0.3 20 | #python plot_wandb.py --mode='Train' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 21 | #python plot_wandb.py --mode='Test' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 22 | #python plot_wandb.py --local --mode='Train' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' --finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 23 | #python plot_wandb.py --local --mode='Test' --file_str='0.3' --two_maml='m5fxvkfv' --random_exp='so4007b8' --concat='icvqm55w' --no_finetune='drusjtez' 
--finetune='b1qbrhga' --adamic_adar='le29sm1i' --mlp='th44x964' --graph_sig='2se6o3nq' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 24 | 25 | ##### 0.4 26 | #python plot_wandb.py --mode='Train' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 27 | #python plot_wandb.py --mode='Test' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 28 | #python plot_wandb.py --local --mode='Train' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 29 | #python plot_wandb.py --local --mode='Test' --file_str='0.4' --two_maml='kiin2ltu' --random_exp='nzqt7x0m' --concat='1qh49vh0' --no_finetune='3ogmputq' --finetune='vezexdms' --adamic_adar='k3d9t34j' --mlp='43n81tbp' --graph_sig='eejt9gu9' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 30 | 31 | ##### 0.5 32 | ##python plot_wandb.py --mode='Train' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 33 | ##python plot_wandb.py --mode='Test' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' 
--no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 34 | ##python plot_wandb.py --local --mode='Train' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 35 | ##python plot_wandb.py --local --mode='Test' --file_str='0.5' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 36 | 37 | ##### 0.6 38 | ##python plot_wandb.py --mode='Train' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 39 | ##python plot_wandb.py --mode='Test' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 40 | ##python plot_wandb.py --local --mode='Train' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' --graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 41 | ##python plot_wandb.py --local --mode='Test' --file_str='0.6' --two_maml='24q2qh3e' --random_exp='' --no_finetune='vn2xjqev' --finetune='a4sk262k' --adamic_adar='vlv2mv1t' --mlp='17llheum' 
--graph_sig='fra1kxzs' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 1 --global_block 1 2 3 4 42 | 43 | #### 0.7 44 | #python plot_wandb.py --mode='Train' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 45 | #python plot_wandb.py --mode='Test' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 46 | #python plot_wandb.py --local --mode='Train' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 47 | #python plot_wandb.py --local --mode='Test' --file_str='0.7' --two_maml='12xe6dux' --random_exp='xga4kp7j' --concat='q78a3qsr' --no_finetune='pr5xaguf' --finetune='op2llf77' --adamic_adar='yhnq5i55' --mlp='3f541vpt' --graph_sig='m583zz8y' --graph_sig_concat='' --graph_sig_random='' --dataset='PPI' --local_block 2 --global_block 3 4 5 48 | -------------------------------------------------------------------------------- /scripts/run_plotter_reddit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | 7 | ### 0.1 8 | #python plot_comet.py --mode='Train' --file_str='0.1' --two_maml='e43bc12bb01e4a61ac71d1ac8e72e9ee' --random_exp='51f0b547d6244037a0ec6ba0f8797094' --no_finetune='46a6f3c347e148658f61282eaa5de522' 
--finetune='162273c8357240afba4e775c93a371c4' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 9 | #python plot_comet.py --mode='Test' --file_str='0.1' --two_maml='e43bc12bb01e4a61ac71d1ac8e72e9ee' --random_exp='51f0b547d6244037a0ec6ba0f8797094' --no_finetune='46a6f3c347e148658f61282eaa5de522' --finetune='162273c8357240afba4e775c93a371c4' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 10 | #python plot_comet.py --local --mode='Train' --file_str='0.1' --two_maml='e43bc12bb01e4a61ac71d1ac8e72e9ee' --random_exp='51f0b547d6244037a0ec6ba0f8797094' --no_finetune='46a6f3c347e148658f61282eaa5de522' --finetune='162273c8357240afba4e775c93a371c4' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 11 | #python plot_comet.py --local --mode='Test' --file_str='0.1' --two_maml='e43bc12bb01e4a61ac71d1ac8e72e9ee' --random_exp='51f0b547d6244037a0ec6ba0f8797094' --no_finetune='46a6f3c347e148658f61282eaa5de522' --finetune='162273c8357240afba4e775c93a371c4' --adamic_adar='adf7008762bd4490b0b768e7132bc19a' --mlp='0dd87316bab94bac87e9a006ce169e42' --graph_sig='2cd268ae17034e148297a7f74c6d10e2' 12 | 13 | #### 0.2 14 | python plot_comet.py --mode='Train' --file_str='0.2' --two_maml='8c9bb6e070854810929ac9b918a78d41' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='eef25efe8ed34ac1b33a6dcd38862fcd' --finetune='5aba1f5222ab492db059d05f2f73dcf9' --adamic_adar='5fc4c223f4f94029bcc365b9b6b9d582' --mlp='e78e58f70a024e3cb2fe7b0f3edc9c21' --graph_sig='cae4f903f75042c3ac990e9c1193f8d8' 15 | python plot_comet.py --mode='Test' --file_str='0.2' --two_maml='8c9bb6e070854810929ac9b918a78d41' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='eef25efe8ed34ac1b33a6dcd38862fcd' --finetune='5aba1f5222ab492db059d05f2f73dcf9' 
--adamic_adar='5fc4c223f4f94029bcc365b9b6b9d582' --mlp='e78e58f70a024e3cb2fe7b0f3edc9c21' --graph_sig='cae4f903f75042c3ac990e9c1193f8d8' 16 | python plot_comet.py --local --mode='Train' --file_str='0.2' --two_maml='8c9bb6e070854810929ac9b918a78d41' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='eef25efe8ed34ac1b33a6dcd38862fcd' --finetune='5aba1f5222ab492db059d05f2f73dcf9' --adamic_adar='5fc4c223f4f94029bcc365b9b6b9d582' --mlp='e78e58f70a024e3cb2fe7b0f3edc9c21' --graph_sig='cae4f903f75042c3ac990e9c1193f8d8' 17 | python plot_comet.py --local --mode='Test' --file_str='0.2' --two_maml='8c9bb6e070854810929ac9b918a78d41' --random_exp='c51c563a6895478e8aa429e40ef5d14e' --no_finetune='eef25efe8ed34ac1b33a6dcd38862fcd' --finetune='5aba1f5222ab492db059d05f2f73dcf9' --adamic_adar='5fc4c223f4f94029bcc365b9b6b9d582' --mlp='e78e58f70a024e3cb2fe7b0f3edc9c21' --graph_sig='cae4f903f75042c3ac990e9c1193f8d8' 18 | 19 | ### 0.3 20 | #python plot_comet.py --mode='Train' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' 21 | #python plot_comet.py --mode='Test' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' 22 | #python plot_comet.py --local --mode='Train' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' 
--finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' 23 | #python plot_comet.py --local --mode='Test' --file_str='0.3' --two_maml='22802834fd384d64a8fb078589f8e0a8' --concat='1bf0450c8e8e41c0b701a40a8091ac16' --random_exp='e061e0fe4b13422a9a44b35f80259f30' --no_finetune='fa0ae0e43ca84df19b6f9eab68542694' --finetune='4ea3da08557b49be87ae79fdbeac10be' --adamic_adar='989b304c58604bfebc11f071cdde177b' --mlp='8cfc4185ced74afa97714e8dc14606c8' 24 | 25 | ### 0.4 26 | #python plot_comet.py --mode='Train' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' 27 | #python plot_comet.py --mode='Test' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' 28 | #python plot_comet.py --local --mode='Train' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' --adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' 29 | #python plot_comet.py --local --mode='Test' --file_str='0.4' --two_maml='beb78b6f6bc44c6e9ce3fb7d540cdf4f' --concat='d3cb5649c3c441f69fbba291f82fd4c2' --random_exp='2521991ea277485cb14ed3605939409f' --no_finetune='21b621c907584cb4a28ae4f91e5f6faa' --finetune='8b4affacfacc4e6cafa1952aaa6c7bf0' 
--adamic_adar='7a964dd1f63c4dbfb1509b1ef908b5d0' --mlp='53431b4b83974e8aa9977b427e98cd86' 30 | 31 | ### 0.5 32 | #python plot_comet.py --mode='Train' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' 33 | #python plot_comet.py --mode='Test' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' 34 | #python plot_comet.py --local --mode='Train' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' 35 | #python plot_comet.py --local --mode='Test' --file_str='0.5' --two_maml='3538df7d70ec4c0684f897beba80ec91' --concat='2f6297002f3740048ea66deaafed7860' --random_exp='4b399a1a19c84f0e9e7df8db337b1753' --no_finetune='71f42f50f2f2482d82761578ac8e1c73' --finetune='86b0e4725ea344fa81e683de5aca1a6e' --adamic_adar='1b5544f3fb6f485cb20bad881bdbaa27' --mlp='e7b22ed70d5f4df682cc3989591545a5' 36 | 37 | ### 0.6 38 | #python plot_comet.py --mode='Train' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' 
--adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' 39 | #python plot_comet.py --mode='Test' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' 40 | #python plot_comet.py --local --mode='Train' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' 41 | #python plot_comet.py --local --mode='Test' --file_str='0.6' --two_maml='add12cfa5031483daef86550bbfabf2c' --concat='5b80f4ed3d3747a8889dfd3173497fc9' --random_exp='086fd40e1228495fa788a630cf934281' --no_finetune='30c58d4ede6a4a84aa0341449fe8f19a' --finetune='e7875cbc15294122b4de122cbd8b8d4c' --adamic_adar='822f0054a86d4a129006b78616e14a39' --mlp='656e00cbcbcd4c7d9233b96eeaeb3b36' 42 | 43 | ### 0.7 44 | #python plot_comet.py --mode='Train' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' 45 | #python plot_comet.py --mode='Test' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' 
--mlp='45cc52506c4b4e0cacc50bd2018e8d2e' 46 | #python plot_comet.py --local --mode='Train' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' 47 | #python plot_comet.py --local --mode='Test' --file_str='0.7' --two_maml='fd0822c2375440ef94d6452a37de16fc' --concat='b8a13dc4027f455b91bc726e914404f9' --random_exp='308283b6b84946cfa267e82affa58156' --no_finetune='9c9a8ebd787041fbb051f56bd4b006c6' --finetune='9066b66bc6a6434fb941b45f342d5aa2' --adamic_adar='269465a957e34e659bb32db2e3190ca5' --mlp='45cc52506c4b4e0cacc50bd2018e8d2e' 48 | 49 | -------------------------------------------------------------------------------- /scripts/run_ppi_best_gs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | python3 main.py --meta_train_edge_ratio=0.1 --model='VGAE' --encoder='GraphSignature' --epochs=46 --use_gcn_sig --concat_fixed_feats --inner_steps=2 --inner-lr=2.24e-3 --meta-lr=2.727e-3 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.1' --comet --wandb 7 | python3 main.py --meta_train_edge_ratio=0.2 --model='VGAE' --encoder='GraphSignature' --epochs=32 --use_gcn_sig --concat_fixed_feats --inner_steps=23 --inner-lr=4.949e-2 --meta-lr=2.834e-3 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.2' --comet --wandb 8 | python3 main.py --meta_train_edge_ratio=0.3 --model='VGAE' --encoder='GraphSignature' --epochs=39 --use_gcn_sig --concat_fixed_feats --inner_steps=15 --inner-lr=3.545e-3 --meta-lr=1.493e-2 --clip_grad --patience=2000 
--train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.3' --comet --wandb 9 | python3 main.py --meta_train_edge_ratio=0.4 --model='VGAE' --encoder='GraphSignature' --epochs=36 --use_gcn_sig --concat_fixed_feats --inner_steps=30 --inner-lr=8.618e-4 --meta-lr=1.1192e-2 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.4' --comet --wandb 10 | python3 main.py --meta_train_edge_ratio=0.5 --model='VGAE' --encoder='GraphSignature' --epochs=7 --use_gcn_sig --concat_fixed_feats --inner_steps=15 --inner-lr=6.07e-3 --meta-lr=1.337e-2 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.5' --comet --wandb 11 | python3 main.py --meta_train_edge_ratio=0.6 --model='VGAE' --encoder='GraphSignature' --epochs=36 --use_gcn_sig --concat_fixed_feats --inner_steps=9 --inner-lr=4.14e-4 --meta-lr=1.42e-3 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.6' --comet --wandb 12 | python3 main.py --meta_train_edge_ratio=0.7 --model='VGAE' --encoder='GraphSignature' --epochs=18 --use_gcn_sig --concat_fixed_feats --inner_steps=29 --inner-lr=2.592e-2 --meta-lr=1.729e-3 --clip_grad --patience=2000 --train_batch_size=1 --dataset=PPI --order=2 --namestr='2-MAML_Concat_Patience_Best_GS_PPI_Ratio=0.7' --comet --wandb 13 | -------------------------------------------------------------------------------- /scripts/run_random.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | echo "Getting into the script" 6 | #echo "Running PPI experiments" 7 | 8 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.1 --random_baseline --namestr='Random Baseline Ratio=0.1' --comet & 9 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.2 
--random_baseline --namestr='Random Baseline Ratio=0.2' --comet & 10 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.3 --random_baseline --namestr='Random Baseline Ratio=0.3' --comet & 11 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.4 --random_baseline --namestr='Random Baseline Ratio=0.4' --comet & 12 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.5 --random_baseline --namestr='Random Baseline Ratio=0.5' --comet & 13 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.6 --random_baseline --namestr='Random Baseline Ratio=0.6' --comet & 14 | #python main.py --model='VGAE' --epochs=50 --meta_train_edge_ratio=0.7 --random_baseline --namestr='Random Baseline Ratio=0.7' --comet & 15 | 16 | #echo "Running ENZYMES experiments" 17 | 18 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.1 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.1' --comet & 19 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.2 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.2' --comet & 20 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.3 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.3' --comet & 21 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.4 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.4' --comet & 22 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.5 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.5' --comet & 23 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.6 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.6' --comet & 24 | #python main.py --model='VGAE' --epochs=50 --dataset=ENZYMES --meta_train_edge_ratio=0.7 --random_baseline --namestr='ENZYMES Random Baseline Ratio=0.7' 
--comet & 25 | 26 | #wait 27 | 28 | echo "Running REDDIT experiments" 29 | 30 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.1 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.1' --comet & 31 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.2 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.2' --comet & 32 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.3 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.3' --comet & 33 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.4 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.4' --comet & 34 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.5 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.5' --comet & 35 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.6 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.6' --comet & 36 | python main.py --model='VGAE' --epochs=50 --dataset=REDDIT-MULTI-12K --use_fixed_feats --meta_train_edge_ratio=0.7 --random_baseline --namestr='REDDIT-MULTI-12K Random Baseline Ratio=0.7' --comet & 37 | -------------------------------------------------------------------------------- /vgae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import os.path as osp 4 | from comet_ml import Experiment 5 | import argparse 6 | import torch 7 | import torch.nn.functional as F 8 | from torch_geometric.datasets import Planetoid,PPI 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import GATConv, GCNConv, GAE, 
VGAE 11 | from torch_geometric.data import DataLoader 12 | import numpy as np 13 | from data import load_dataset 14 | from models import * 15 | from utils import global_test, test, EarlyStopping, seed_everything 16 | import json 17 | import ipdb 18 | 19 | def train(model, args, x, train_pos_edge_index, num_nodes, optimizer): 20 | model.train() 21 | optimizer.zero_grad() 22 | z = model.encode(x, train_pos_edge_index) 23 | loss = model.recon_loss(z, train_pos_edge_index) 24 | if args.model in ['VGAE']: 25 | loss = loss + (1 / num_nodes) * model.kl_loss() 26 | loss.backward() 27 | optimizer.step() 28 | 29 | def val(model, args, x, val_pos_edge_index, num_nodes): 30 | model.eval() 31 | with torch.no_grad(): 32 | z = model.encode(x, val_pos_edge_index) 33 | loss = model.recon_loss(z, val_pos_edge_index) 34 | if args.model in ['VGAE']: 35 | loss = loss + (1 / num_nodes) * model.kl_loss() 36 | return loss.item() 37 | 38 | def test(model, x, train_pos_edge_index, pos_edge_index, neg_edge_index): 39 | model.eval() 40 | with torch.no_grad(): 41 | z = model.encode(x, train_pos_edge_index) 42 | return model.test(z, pos_edge_index, neg_edge_index) 43 | 44 | def main(args): 45 | assert args.model in ['GAE', 'VGAE'] 46 | kwargs = {'GAE': GAE, 'VGAE': VGAE} 47 | kwargs_enc = {'GCN': Encoder, 'FC': MLPEncoder, 'MLP': MetaMLPEncoder, 48 | 'GraphSignature': MetaSignatureEncoder} 49 | 50 | path = osp.join( 51 | osp.dirname(osp.realpath(__file__)), '..', 'data', args.dataset) 52 | train_loader, val_loader, test_loader = load_dataset(args.dataset,args) 53 | model = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev) 54 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 55 | total_loss = 0 56 | graph_id = 0 57 | train_auc_array = np.zeros((len(train_loader)*args.train_batch_size, int(args.epochs/5 + 1))) 58 | train_ap_array = np.zeros((len(train_loader)*args.train_batch_size, int(args.epochs/5 + 1))) 59 | val_auc_array = 
np.zeros((len(val_loader)*args.test_batch_size, int(args.epochs/5 + 1))) 60 | val_ap_array = np.zeros((len(val_loader)*args.test_batch_size, int(args.epochs/5 + 1))) 61 | test_auc_array = np.zeros((len(test_loader)*args.test_batch_size, int(args.epochs/5 + 1))) 62 | test_ap_array = np.zeros((len(test_loader)*args.test_batch_size, int(args.epochs/5 + 1))) 63 | 64 | if args.finetune: 65 | for i,data_batch in enumerate(train_loader): 66 | for idx, data in enumerate(data_batch): 67 | data.train_mask = data.val_mask = data.test_mask = data.y = None 68 | data.batch = None 69 | num_nodes = data.num_nodes 70 | meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio 71 | if args.use_fixed_feats: 72 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 73 | perm = torch.randperm(args.feats.size(0)) 74 | perm_idx = perm[:num_nodes] 75 | data.x = args.feats[perm_idx] 76 | 77 | if args.concat_fixed_feats: 78 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 
79 | concat_feats = torch.randn(num_nodes,args.num_concat_features,requires_grad=False) 80 | data.x = torch.cat((data.x,concat_feats),1) 81 | try: 82 | data = model.split_edges(data,val_ratio=args.meta_val_edge_ratio,test_ratio=meta_test_edge_ratio) 83 | except: 84 | args.fail_counter += 1 85 | print("Failed on Graph %d" %(graph_id)) 86 | continue 87 | 88 | # Additional Failure Checks for small graphs 89 | if data.val_pos_edge_index.size()[1] == 0 or data.test_pos_edge_index.size()[1] == 0: 90 | args.fail_counter += 1 91 | print("Failed on Graph %d" %(graph_id)) 92 | continue 93 | 94 | x, train_pos_edge_index = data.x.to(args.dev), data.train_pos_edge_index.to(args.dev) 95 | for epoch in range(0, args.epochs): 96 | if not args.random_baseline: 97 | if args.train_with_val: 98 | train_pos_edge_index = torch.cat([data.train_pos_edge_index, 99 | data.val_pos_edge_index], dim=1).cuda() 100 | train(model,args,x,train_pos_edge_index,data.num_nodes,optimizer) 101 | auc, ap = test(model, x, train_pos_edge_index, 102 | data.test_pos_edge_index, data.test_neg_edge_index) 103 | 104 | if epoch % 5 == 0: 105 | my_step = int(epoch / 5) 106 | train_auc_array[graph_id][my_step] = auc 107 | train_ap_array[graph_id][my_step] = ap 108 | 109 | ''' Save after every graph ''' 110 | auc_metric = 'Train_Local_Batch_Graph_' + str(graph_id) +'_AUC' 111 | ap_metric = 'Train_Local_Batch_Graph_' + str(graph_id) +'_AP' 112 | for val_idx in range(0,train_auc_array.shape[1]): 113 | auc = train_auc_array[graph_id][val_idx] 114 | ap = train_ap_array[graph_id][val_idx] 115 | if args.comet: 116 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 117 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 118 | if args.wandb: 119 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 120 | auc = train_auc_array[graph_id][val_idx - 1] 121 | ap = train_ap_array[graph_id][val_idx - 1] 122 | print('Train Graph {:01d}, AUC: {:.4f}, AP: {:.4f}'.format(graph_id, auc, ap)) 123 | graph_id += 1 124 | 
save_path = '../saved_models/vgae.pt' 125 | # torch.save(model.state_dict(), save_path) 126 | 127 | auc_metric = 'Train_Complete' +'_AUC' 128 | ap_metric = 'Train_Complete' +'_AP' 129 | #Remove All zero rows 130 | train_auc_array = train_auc_array[~np.all(train_auc_array == 0, axis=1)] 131 | train_ap_array = train_ap_array[~np.all(train_ap_array == 0, axis=1)] 132 | train_aggr_auc = np.sum(train_auc_array,axis=0)/len(train_loader) 133 | train_aggr_ap = np.sum(train_ap_array,axis=0)/len(train_loader) 134 | for val_idx in range(0,train_auc_array.shape[1]): 135 | auc = train_aggr_auc[val_idx] 136 | ap = train_aggr_ap[val_idx] 137 | if args.comet: 138 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 139 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 140 | if args.wandb: 141 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 142 | 143 | ''' Start Validation ''' 144 | if args.do_val: 145 | val_graph_id = 0 146 | for i,data_batch in enumerate(val_loader): 147 | ''' Re-init Optimizerr ''' 148 | if args.finetune: 149 | val_model = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev) 150 | val_model.load_state_dict(model.state_dict()) 151 | optimizer = torch.optim.Adam(val_model.parameters(), lr=args.lr) 152 | else: 153 | val_model = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev) 154 | optimizer = torch.optim.Adam(val_model.parameters(), lr=args.lr) 155 | early_stopping = EarlyStopping(patience=args.patience, verbose=False) 156 | 157 | for idx, data in enumerate(data_batch): 158 | data.train_mask = data.val_mask = data.test_mask = data.y = None 159 | data.batch = None 160 | num_nodes = data.num_nodes 161 | if args.use_fixed_feats: 162 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 
163 | perm = torch.randperm(args.feats.size(0)) 164 | perm_idx = perm[:num_nodes] 165 | data.x = args.feats[perm_idx] 166 | 167 | if args.concat_fixed_feats: 168 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 169 | concat_feats = torch.randn(num_nodes,args.num_concat_features,requires_grad=False) 170 | data.x = torch.cat((data.x,concat_feats),1) 171 | 172 | # Val Ratio is Fixed at 0.1 173 | meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio 174 | try: 175 | data = val_model.split_edges(data,val_ratio=args.meta_val_edge_ratio,test_ratio=meta_test_edge_ratio) 176 | except: 177 | args.fail_counter += 1 178 | print("Failed on Graph %d" %(val_graph_id)) 179 | continue 180 | 181 | # Additional Failure Checks for small graphs 182 | if data.val_pos_edge_index.size()[1] == 0 or data.test_pos_edge_index.size()[1] == 0: 183 | args.fail_counter += 1 184 | print("Failed on Graph %d" %(val_graph_id)) 185 | continue 186 | 187 | x, train_pos_edge_index = data.x.to(args.dev), data.train_pos_edge_index.to(args.dev) 188 | val_pos_edge_index = data.val_pos_edge_index.to(args.dev) 189 | for epoch in range(0, args.epochs): 190 | if not args.random_baseline: 191 | train(val_model,args,x,train_pos_edge_index,data.num_nodes,optimizer) 192 | val_loss = val(model,args,x,val_pos_edge_index,data.num_nodes) 193 | early_stopping(val_loss, val_model) 194 | auc, ap = test(val_model, x, train_pos_edge_index, 195 | data.test_pos_edge_index, data.test_neg_edge_index) 196 | 197 | if early_stopping.early_stop: 198 | print("Early stopping for Graph %d | AUC: %f AP: %f" \ 199 | %(val_graph_id, auc, ap)) 200 | my_step = int(epoch / 5) 201 | val_auc_array[val_graph_id][my_step:,] = auc 202 | val_ap_array[val_graph_id][my_step:,] = ap 203 | break 204 | 205 | if epoch % 5 == 0: 206 | my_step = int(epoch / 5) 207 | val_auc_array[val_graph_id][my_step] = auc 208 | val_ap_array[val_graph_id][my_step] = ap 209 | 210 | ''' Save after every 
graph ''' 211 | auc_metric = 'Val_Local_Batch_Graph_' + str(val_graph_id) +'_AUC' 212 | ap_metric = 'Val_Local_Batch_Graph_' + str(val_graph_id) +'_AP' 213 | for val_idx in range(0,val_auc_array.shape[1]): 214 | auc = val_auc_array[val_graph_id][val_idx] 215 | ap = val_ap_array[val_graph_id][val_idx] 216 | if args.comet: 217 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 218 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 219 | if args.wandb: 220 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 221 | 222 | print('Val Graph {:01d}, AUC: {:.4f}, AP: {:.4f}'.format(val_graph_id, auc, ap)) 223 | val_graph_id += 1 224 | 225 | auc_metric = 'Val_Complete' +'_AUC' 226 | ap_metric = 'Val_Complete' +'_AP' 227 | 228 | #Remove All zero rows 229 | val_auc_array = val_auc_array[~np.all(val_auc_array == 0, axis=1)] 230 | val_ap_array = val_ap_array[~np.all(val_ap_array == 0, axis=1)] 231 | val_aggr_auc = np.sum(val_auc_array,axis=0)/len(val_loader) 232 | val_aggr_ap = np.sum(val_ap_array,axis=0)/len(val_loader) 233 | max_auc = np.max(val_aggr_auc) 234 | max_ap = np.max(val_aggr_ap) 235 | for val_idx in range(0,val_auc_array.shape[1]): 236 | auc = val_aggr_auc[val_idx] 237 | ap = val_aggr_ap[val_idx] 238 | if args.comet: 239 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 240 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 241 | if args.wandb: 242 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 243 | auc = val_aggr_auc[val_idx -1] 244 | ap = val_aggr_ap[val_idx - 1] 245 | print('Val Complete AUC: {:.4f}, AP: {:.4f}'.format(auc, ap)) 246 | print('Val Max AUC: {:.4f}, AP: {:.4f}'.format(max_auc, max_ap)) 247 | val_eval_metric = 0.5*max_auc + 0.5*max_ap 248 | return val_eval_metric 249 | 250 | ''' Start Testing ''' 251 | if not args.do_val: 252 | test_graph_id = 0 253 | for i,data_batch in enumerate(test_loader): 254 | ''' Re-init Optimizerr ''' 255 | if args.finetune: 256 | test_model = 
kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev) 257 | test_model.load_state_dict(model.state_dict()) 258 | optimizer = torch.optim.Adam(test_model.parameters(), lr=args.lr) 259 | else: 260 | test_model = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev) 261 | optimizer = torch.optim.Adam(test_model.parameters(), lr=args.lr) 262 | early_stopping = EarlyStopping(patience=args.patience, verbose=False) 263 | 264 | for idx, data in enumerate(data_batch): 265 | data.train_mask = data.val_mask = data.test_mask = data.y = None 266 | data.batch = None 267 | num_nodes = data.num_nodes 268 | if args.use_fixed_feats: 269 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 270 | perm = torch.randperm(args.feats.size(0)) 271 | perm_idx = perm[:num_nodes] 272 | data.x = args.feats[perm_idx] 273 | 274 | if args.concat_fixed_feats: 275 | ##TODO: Should this be a fixed embedding table instead of generating this each time? 
276 | concat_feats = torch.randn(num_nodes,args.num_concat_features,requires_grad=False) 277 | data.x = torch.cat((data.x,concat_feats),1) 278 | 279 | # Val Ratio is Fixed at 0.1 280 | meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio 281 | try: 282 | data = test_model.split_edges(data,val_ratio=args.meta_val_edge_ratio,test_ratio=meta_test_edge_ratio) 283 | except: 284 | args.fail_counter += 1 285 | print("Failed on Graph %d" %(test_graph_id)) 286 | continue 287 | 288 | # Additional Failure Checks for small graphs 289 | if data.val_pos_edge_index.size()[1] == 0 or data.test_pos_edge_index.size()[1] == 0: 290 | args.fail_counter += 1 291 | print("Failed on Graph %d" %(test_graph_id)) 292 | continue 293 | 294 | x, train_pos_edge_index = data.x.to(args.dev), data.train_pos_edge_index.to(args.dev) 295 | val_pos_edge_index = data.val_pos_edge_index.to(args.dev) 296 | for epoch in range(0, args.epochs): 297 | if not args.random_baseline: 298 | if args.train_with_val: 299 | train_pos_edge_index =torch.cat([data.train_pos_edge_index, 300 | data.val_pos_edge_index], dim=1).cuda() 301 | val_loss = val(model,args,x,train_pos_edge_index,data.num_nodes) 302 | early_stopping(val_loss, test_model) 303 | else: 304 | val_loss = val(model,args,x,val_pos_edge_index,data.num_nodes) 305 | early_stopping(val_loss, test_model) 306 | train(test_model,args,x,train_pos_edge_index,data.num_nodes,optimizer) 307 | auc, ap = test(test_model, x, train_pos_edge_index, 308 | data.test_pos_edge_index, data.test_neg_edge_index) 309 | 310 | if early_stopping.early_stop: 311 | print("Early stopping for Graph %d | AUC: %f AP: %f" \ 312 | %(test_graph_id, auc, ap)) 313 | my_step = int(epoch / 5) 314 | test_auc_array[test_graph_id][my_step:,] = auc 315 | test_ap_array[test_graph_id][my_step:,] = ap 316 | break 317 | 318 | if epoch % 5 == 0: 319 | my_step = int(epoch / 5) 320 | test_auc_array[test_graph_id][my_step] = auc 321 | test_ap_array[test_graph_id][my_step] = ap 322 
| 323 | ''' Save after every graph ''' 324 | auc_metric = 'Test_Local_Batch_Graph_' + str(test_graph_id) +'_AUC' 325 | ap_metric = 'Test_Local_Batch_Graph_' + str(test_graph_id) +'_AP' 326 | for val_idx in range(0,test_auc_array.shape[1]): 327 | auc = test_auc_array[test_graph_id][val_idx] 328 | ap = test_ap_array[test_graph_id][val_idx] 329 | if args.comet: 330 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 331 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 332 | if args.wandb: 333 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 334 | 335 | print('Test Graph {:01d}, AUC: {:.4f}, AP: {:.4f}'.format(test_graph_id, auc, ap)) 336 | test_graph_id += 1 337 | if not os.path.exists('../saved_models/'): 338 | os.makedirs('../saved_models/') 339 | save_path = '../saved_models/vgae.pt' 340 | # torch.save(model.state_dict(), save_path) 341 | 342 | auc_metric = 'Test_Complete' +'_AUC' 343 | ap_metric = 'Test_Complete' +'_AP' 344 | #Remove All zero rows 345 | test_auc_array = test_auc_array[~np.all(test_auc_array == 0, axis=1)] 346 | test_ap_array = test_ap_array[~np.all(test_ap_array == 0, axis=1)] 347 | 348 | test_aggr_auc = np.sum(test_auc_array,axis=0)/len(test_loader) 349 | test_aggr_ap = np.sum(test_ap_array,axis=0)/len(test_loader) 350 | max_auc = np.max(test_aggr_auc) 351 | max_ap = np.max(test_aggr_ap) 352 | for val_idx in range(0,test_auc_array.shape[1]): 353 | auc = test_aggr_auc[val_idx] 354 | ap = test_aggr_ap[val_idx] 355 | if args.comet: 356 | args.experiment.log_metric(auc_metric,auc,step=val_idx) 357 | args.experiment.log_metric(ap_metric,ap,step=val_idx) 358 | if args.wandb: 359 | wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx}) 360 | auc = test_aggr_auc[val_idx -1] 361 | ap = test_aggr_ap[val_idx -1] 362 | print('Test Complete AUC: {:.4f}, AP: {:.4f}'.format(auc, ap)) 363 | print('Test Max AUC: {:.4f}, AP: {:.4f}'.format(max_auc, max_ap)) 364 | test_eval_metric = 0.5*max_auc + 0.5*max_ap 365 | return test_eval_metric 366 | 367 | 
if __name__ == '__main__':
    # Process command-line arguments, optionally load credentials from
    # settings.json, configure comet/wandb logging, then call main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='VGAE')
    parser.add_argument('--num_channels', type=int, default=16)
    parser.add_argument('--epochs', default=251, type=int)
    parser.add_argument('--dataset', type=str, default='PPI')
    parser.add_argument("--finetune", action="store_true", default=False,
                        help='Finetune from previous graph')
    parser.add_argument('--train_batch_size', default=1, type=int)
    parser.add_argument('--num_gated_layers', default=4, type=int,
                        help='Number of layers to use for the Gated Graph Conv Layer')
    parser.add_argument('--encoder', type=str, default='GCN')
    parser.add_argument('--test_batch_size', default=1, type=int)
    parser.add_argument('--num_fixed_features', default=20, type=int)
    parser.add_argument('--num_concat_features', default=10, type=int)
    parser.add_argument('--meta_train_edge_ratio', type=float, default=0.2)
    parser.add_argument('--meta_val_edge_ratio', type=float, default=0.2)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='Used to split number of graphs for training if not provided')
    parser.add_argument("--concat_fixed_feats", action="store_true", default=False,
                        help='Concatenate random node features to current node features')
    parser.add_argument("--use_fixed_feats", action="store_true", default=False,
                        help='Use a random node features')
    parser.add_argument('--val_ratio', type=float, default=0.1,
                        help='Used to split number of graphs for validation if not provided')
    parser.add_argument('--train_with_val', default=False, action='store_true',
                        help='Combine Train + Val edges')
    parser.add_argument('--do_val', default=False, action='store_true',
                        help='Do Validation')
    parser.add_argument("--comet", action="store_true", default=False,
                        help='Use comet for logging')
    parser.add_argument("--comet_username", type=str, default="joeybose",
                        help='Username for comet logging')
    parser.add_argument('--seed', type=int, default=12345, metavar='S',
                        help='random seed (default: 12345)')
    # Never ship a real API key as a default: fall back to the environment
    # (settings.json below can still override it).
    parser.add_argument("--comet_apikey", type=str,
                        default=os.environ.get("COMET_API_KEY"),
                        help='Api for comet logging (defaults to $COMET_API_KEY)')
    parser.add_argument("--wandb", action="store_true", default=False,
                        help='Use wandb for logging')
    # Previously this attribute only existed when settings.json was present,
    # so `--wandb` without that file crashed with AttributeError.
    parser.add_argument("--wandb_apikey", type=str, default=None,
                        help='Api key for wandb logging (optional)')
    parser.add_argument('--debug', default=False, action='store_true',
                        help='Debug')
    parser.add_argument('--model_path', type=str, default="mnist_cnn.pt",
                        help='where to save/load')
    parser.add_argument("--random_baseline", action="store_true", default=False,
                        help='Use a Random Baseline')
    parser.add_argument('--k_core', type=int, default=5, help="K-core for Graph")
    parser.add_argument('--patience', type=int, default=200,
                        help="Patience used by the training loop (e.g. early stopping)")
    parser.add_argument("--reprocess", action="store_true", default=False,
                        help='Reprocess AMINER dataset')
    parser.add_argument("--ego", action="store_true", default=False,
                        help='Reprocess AMINER as ego dataset')
    parser.add_argument('--opus', default=False, action='store_true',
                        help='Change AMINER File Path for Opus')
    parser.add_argument('--max_nodes', type=int, default=50000,
                        help='Max Nodes needed for a graph to be included')
    parser.add_argument('--namestr', type=str, default='Meta-Graph',
                        help='additional info in output filename to describe experiments')
    parser.add_argument('--min_nodes', type=int, default=1000,
                        help='Min Nodes needed for a graph to be included')
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    # Fix random seed
    seed_everything(args.seed)

    # Optional settings file with logging credentials. Missing keys keep the
    # CLI/env values instead of raising KeyError. A distinct name is used so
    # no module-level name (e.g. a `data` import) is rebound here.
    if os.path.isfile("settings.json"):
        with open('settings.json') as f:
            settings = json.load(f)
        args.comet_apikey = settings.get("apikey", args.comet_apikey)
        args.comet_username = settings.get("username", args.comet_username)
        args.wandb_apikey = settings.get("wandbapikey", args.wandb_apikey)

    args.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Logging project name per dataset; unknown datasets share 'meta-graph'.
    project_names = {
        'PPI': 'meta-graph-ppi',
        'REDDIT-MULTI-12K': 'meta-graph-reddit',
        'FIRSTMM_DB': 'meta-graph-firstmmdb',
        'DD': 'meta-graph-dd',
        'AMINER': 'meta-graph-aminer',
    }
    project_name = project_names.get(args.dataset, 'meta-graph')

    if args.comet:
        experiment = Experiment(api_key=args.comet_apikey,
                                project_name=project_name,
                                workspace=args.comet_username)
        experiment.set_name(args.namestr)
        args.experiment = experiment

    if args.wandb:
        # Only export a key when one was supplied; otherwise leave any existing
        # WANDB_API_KEY / wandb login untouched.
        if args.wandb_apikey:
            os.environ['WANDB_API_KEY'] = args.wandb_apikey
        wandb.init(project=project_name, name=args.namestr)

    print(vars(args))
    eval_metric = main(args)