├── .gitignore ├── run.sh ├── README.md ├── parse.py ├── eval.py ├── load_data.py ├── main.py ├── logger.py ├── model.py ├── dataset.py └── data_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | results 3 | data 4 | elliptic.sh 5 | .nfs00000000030228380000b2bd -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # GCN backbone 2 | python main.py --dataset cora --backbone gcn --weight_decay 5e-5 --tau 1 --dropout 0.2 --env_type graph --combine_result --store 3 | python main.py --dataset citeseer --backbone gcn --weight_decay 5e-5 --tau 1 --dropout 0.1 --env_type graph --combine_result --store 4 | python main.py --dataset pubmed --backbone gcn --weight_decay 5e-5 --tau 2 --dropout 0.2 --env_type graph --combine_result --store 5 | python main.py --dataset arxiv --backbone gcn --weight_decay 0.0005 --tau 1 --dropout 0.2 --env_type node --variant --store 6 | python main.py --dataset twitch --backbone gcn --weight_decay 5e-5 --tau 3 --dropout 0 --env_type graph --store 7 | python main.py --dataset elliptic --backbone gcn --weight_decay 0.001 --tau 1 --K 3 --dropout 0.2 --env_type node --variant --num_layers 3 --hidden_channels 32 --store 8 | 9 | # GAT backbone 10 | python main.py --dataset cora --backbone gat --weight_decay 0 --tau 3 --dropout 0.2 --env_type graph --combine_result --store 11 | python main.py --dataset citeseer --backbone gat --weight_decay 0 --tau 3 --dropout 0.2 --env_type graph --combine_result --store 12 | python main.py --dataset pubmed --backbone gat --weight_decay 5e-5 --tau 1 --dropout 0.2 --env_type graph --combine_result --store 13 | python main.py --dataset arxiv --backbone gat --weight_decay 5e-5 --tau 2 --dropout 0.2 --env_type graph --store 14 | python main.py --dataset twitch --backbone gat --weight_decay 5e-5 --tau 2 --dropout 0 --env_type graph --store 15 | python main.py --dataset elliptic --backbone gat --weight_decay 0.0005 --tau 2 --dropout 0.1 --env_type graph --store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CaNet 2 | The official implementation of the WWW 2024 Oral paper "Graph Out-of-Distribution Generalization via Causal Intervention" 3 | 4 | Related material: [[Paper](https://arxiv.org/pdf/2402.11494.pdf)], [[Blog (Chinese)](https://zhuanlan.zhihu.com/p/709125359)], [[Blog (English)](https://medium.com/towards-data-science/towards-generalization-on-graphs-from-invariance-to-causality-c81a174ac37b)], [[Slides](https://qitianwu.github.io/assets/publications/www2024-canet/slides.pdf)] 5 | 6 | ## What's new 7 | 8 | [2024.02.08] We release the code for the model on six datasets. More detailed information will be added soon. 9 | 10 | ## Model and Results 11 | 12 | Our model coordinates two key components: 1) an environment estimator that infers pseudo environment labels, and 2) a mixture-of-experts GNN predictor with feature propagation 13 | units conditioned on the pseudo environments.
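For intuition, here is a minimal sketch of how the two components interact within a single layer: the estimator produces per-node pseudo-environment logits, a Gumbel-Softmax sample over those logits weights K expert units, and the experts' outputs are mixed per node. This is a simplified illustration only (plain linear experts, no graph propagation, hypothetical names such as `ToyCaNetLayer`); see `CaNetConv` and `CaNet` in `model.py` for the actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyCaNetLayer(nn.Module):
    """Minimal sketch of one CaNet-style layer: estimate pseudo environments, then mix K experts."""
    def __init__(self, hidden_dim, K, tau=1.0):
        super().__init__()
        self.env_enc = nn.Linear(hidden_dim, K)      # environment estimator (node-level logits)
        self.experts = nn.ModuleList(                # one feature-transform "expert" per pseudo environment
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(K)])
        self.tau = tau                               # Gumbel-Softmax temperature

    def forward(self, h):
        logits = self.env_enc(h)                                    # [N, K]
        if self.training:
            e = F.gumbel_softmax(logits, tau=self.tau, dim=-1)      # differentiable sampling during training
        else:
            e = F.softmax(logits, dim=-1)                           # soft assignment at test time
        expert_out = torch.stack([f(h) for f in self.experts], 1)   # [N, K, D]
        return (e.unsqueeze(-1) * expert_out).sum(dim=1)            # [N, D], mixture over experts
```

In the full model each expert is a GCN- or GAT-style propagation unit rather than a linear layer, and a regularization term on the environment assignment (weighted by `--lamda`) is added to the training loss.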
14 | 15 | image 16 | 17 | ## Dataset 18 | 19 | One can download the datasets Planetoid (Cora, Citeseer, Pubmed), Arxiv, Twitch, and Elliptic from the Google Drive link below: 20 | 21 | https://drive.google.com/drive/folders/1FAPWghoyGp9vzr1xmnndpmLLFS1OgBDa?usp=sharing 22 | 23 | ## Dependencies 24 | 25 | Python 3.8, PyTorch 1.13.0, PyTorch Geometric 2.1.0, NumPy 1.23.4 26 | 27 | ## Run the code 28 | 29 | Please refer to the bash script `run.sh` for the commands that run the training and evaluation pipeline on the six datasets. 30 | 31 | ## Acknowledgement 32 | 33 | The implementation of the training and evaluation pipeline is based on [EERM](https://github.com/qitianwu/GraphOOD-EERM). 34 | 35 | ## Citation 36 | 37 | If you find our code and model useful, please cite our work. Thank you! 38 | 39 | ```bibtex 40 | @inproceedings{wu2024canet, 41 | title = {Graph Out-of-Distribution Generalization via Causal Intervention}, 42 | author = {Qitian Wu and Fan Nie and Chenxiao Yang and Tianyi Bao and Junchi Yan}, 43 | booktitle = {The Web Conference}, 44 | year = {2024} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /parse.py: -------------------------------------------------------------------------------- 1 | def parser_add_main_args(parser): 2 | # setup and protocol 3 | parser.add_argument('--dataset', type=str, default='cora') 4 | parser.add_argument('--data_dir', type=str, 5 | default='/mnt/nas/home/niefan/ODgraph-energy/data/') 6 | parser.add_argument('--device', type=int, default=0, 7 | help='which gpu to use if any (default: 0)') 8 | parser.add_argument('--cpu', action='store_true') 9 | parser.add_argument('--seed', type=int, default=123) 10 | parser.add_argument('--runs', type=int, default=5, 11 | help='number of distinct runs') 12 | parser.add_argument('--epochs', type=int, default=500) 13 | 14 | # model network 15 | parser.add_argument('--hidden_channels', type=int, default=64) 16 | parser.add_argument('--num_layers', type=int, default=2, 17 | help='number of layers for deep methods') 18 | 19 | # CaNet 20 | parser.add_argument('--backbone_type', type=str, default='gcn', choices=['gcn', 'gat']) 21 | parser.add_argument('--K', type=int, default=3, 22 | help='num of domains, each for one graph convolution filter') 23 | parser.add_argument('--tau', type=float, default=1, 24 | help='temperature for Gumbel Softmax') 25 | parser.add_argument('--env_type', type=str, default='node', choices=['node', 'graph']) 26 | parser.add_argument('--lamda', type=float, default=1.0, 27 | help='weight for regularization') 28 | parser.add_argument('--variant', action='store_true', help='set to use variant') 29 | 30 | # training 31 | parser.add_argument('--weight_decay', type=float, default=5e-4) 32 | parser.add_argument('--dropout', type=float, default=0.0) 33 | parser.add_argument('--lr', type=float, default=0.01) 34 | parser.add_argument('--use_bn', action='store_true', help='use batch norm') 35 | 36 | # display and utility 37 | parser.add_argument('--display_step', type=int, 38 | default=1, help='how often to print') 39 | parser.add_argument('--store_result', action='store_true', 40 | help='whether to store results') 41 | parser.add_argument('--combine_result', action='store_true', 42 | help='whether to combine all the ood environments') 43 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F
3 | import numpy as np 4 | from sklearn.metrics import roc_auc_score, f1_score 5 | 6 | def eval_f1(y_true, y_pred): 7 | y_true = y_true.detach().cpu().numpy() 8 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 9 | f1 = f1_score(y_true, y_pred, average='macro') 10 | return f1 11 | 12 | def eval_acc(y_true, y_pred): 13 | acc_list = [] 14 | y_true = y_true.detach().cpu().numpy() 15 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 16 | 17 | for i in range(y_true.shape[1]): 18 | is_labeled = y_true[:, i] == y_true[:, i] 19 | correct = y_true[is_labeled, i] == y_pred[is_labeled, i] 20 | acc_list.append(float(np.sum(correct))/len(correct)) 21 | 22 | return sum(acc_list)/len(acc_list) 23 | 24 | 25 | def eval_rocauc(y_true, y_pred): 26 | """ adapted from ogb 27 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/evaluate.py""" 28 | rocauc_list = [] 29 | y_true = y_true.detach().cpu().numpy() 30 | if y_true.shape[1] == 1: 31 | # use the predicted class for single-class classification 32 | y_pred = F.softmax(y_pred, dim=-1)[:, 1].unsqueeze(1).cpu().numpy() 33 | else: 34 | y_pred = y_pred.detach().cpu().numpy() 35 | 36 | for i in range(y_true.shape[1]): 37 | # AUC is only defined when there is at least one positive data. 38 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 39 | is_labeled = y_true[:, i] == y_true[:, i] 40 | score = roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) 41 | 42 | rocauc_list.append(score) 43 | 44 | if len(rocauc_list) == 0: 45 | raise RuntimeError( 46 | 'No positively labeled data available. Cannot compute ROC-AUC.') 47 | 48 | return sum(rocauc_list)/len(rocauc_list) 49 | 50 | 51 | @torch.no_grad() 52 | def evaluate_full(model, dataset, eval_func): 53 | model.eval() 54 | 55 | train_idx, valid_idx, test_in_idx, test_ood_idx = dataset.train_idx, dataset.valid_idx, dataset.test_in_idx, dataset.test_ood_idx 56 | y = dataset.y.cpu() 57 | out = model(dataset.x, dataset.edge_index).cpu() 58 | 59 | train_acc = eval_func(y[train_idx], out[train_idx]) 60 | valid_acc = eval_func(y[valid_idx], out[valid_idx]) 61 | test_in_acc = eval_func(y[test_in_idx], out[test_in_idx]) 62 | test_ood_accs = [] 63 | for t in test_ood_idx: 64 | test_ood_accs.append(eval_func(y[t], out[t])) 65 | result = [train_acc, valid_acc, test_in_acc] + test_ood_accs 66 | 67 | return result 68 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import numpy as np 3 | import scipy.sparse 4 | import torch 5 | import csv 6 | import json 7 | from os import path 8 | 9 | DATAPATH = '../../data/' 10 | 11 | def load_fb100(filename): 12 | # e.g. 
filename = Rutgers89 or Cornell5 or Wisconsin87 or Amherst41 13 | # columns are: student/faculty, gender, major, 14 | # second major/minor, dorm/house, year/ high school 15 | # 0 denotes missing entry 16 | mat = scipy.io.loadmat(DATAPATH + 'facebook100/' + filename + '.mat') 17 | A = mat['A'] 18 | metadata = mat['local_info'] 19 | return A, metadata 20 | 21 | def load_twitch(lang): 22 | assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset' 23 | filepath = DATAPATH + f"twitch/{lang}" 24 | label = [] 25 | node_ids = [] 26 | src = [] 27 | targ = [] 28 | uniq_ids = set() 29 | with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f: 30 | reader = csv.reader(f) 31 | next(reader) 32 | for row in reader: 33 | node_id = int(row[5]) 34 | # handle FR case of non-unique rows 35 | if node_id not in uniq_ids: 36 | uniq_ids.add(node_id) 37 | label.append(int(row[2]=="True")) 38 | node_ids.append(int(row[5])) 39 | 40 | node_ids = np.array(node_ids, dtype=np.int) 41 | with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f: 42 | reader = csv.reader(f) 43 | next(reader) 44 | for row in reader: 45 | src.append(int(row[0])) 46 | targ.append(int(row[1])) 47 | with open(f"{filepath}/musae_{lang}_features.json", 'r') as f: 48 | j = json.load(f) 49 | src = np.array(src) 50 | targ = np.array(targ) 51 | label = np.array(label) 52 | inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)} 53 | reorder_node_ids = np.zeros_like(node_ids) 54 | for i in range(label.shape[0]): 55 | reorder_node_ids[i] = inv_node_ids[i] 56 | 57 | n = label.shape[0] 58 | A = scipy.sparse.csr_matrix((np.ones(len(src)), 59 | (np.array(src), np.array(targ))), 60 | shape=(n,n)) 61 | features = np.zeros((n,3170)) 62 | for node, feats in j.items(): 63 | if int(node) >= n: 64 | continue 65 | features[int(node), np.array(feats, dtype=int)] = 1 66 | # features = features[:, np.sum(features, axis=0) != 0] # remove zero cols. 
not need for cross graph task 67 | new_label = label[reorder_node_ids] 68 | label = new_label 69 | 70 | return A, label, features 71 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os, random 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch_geometric.utils import to_undirected 9 | from torch_scatter import scatter 10 | from torch_geometric.data import ShaDowKHopSampler 11 | 12 | from logger import Logger 13 | from dataset import * 14 | from data_utils import normalize, gen_normalized_adjs, to_sparse_tensor, \ 15 | load_fixed_splits, rand_splits, get_gpu_memory_map, count_parameters, reindex_env 16 | from eval import evaluate_full, eval_acc, eval_rocauc, eval_f1 17 | from parse import parser_add_main_args 18 | from model import * 19 | import time 20 | 21 | 22 | # NOTE: for consistent data splits, see data_utils.rand_train_test_idx 23 | def fix_seed(seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | ### Parse args ### 31 | parser = argparse.ArgumentParser(description='General Training Pipeline') 32 | parser_add_main_args(parser) 33 | args = parser.parse_args() 34 | print(args) 35 | 36 | fix_seed(args.seed) 37 | 38 | if args.cpu: 39 | device = torch.device("cpu") 40 | else: 41 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 42 | 43 | ### Load and preprocess data ### 44 | # multi-graph datasets, divide graphs into train/valid/test 45 | if args.dataset == 'twitch': 46 | dataset = load_twitch_dataset(args.data_dir, train_num=3) 47 | elif args.dataset == 'elliptic': 48 | dataset = load_elliptic_dataset(args.data_dir, train_num=5) 49 | # single-graph datasets, divide nodes into train/valid/test 50 | elif args.dataset == 'arxiv': 51 | dataset = load_arxiv_dataset(args.data_dir, train_num=3) 52 | # synthetic datasets, add spurious node features 53 | elif args.dataset in ('cora', 'citeseer', 'pubmed'): 54 | dataset = load_synthetic_dataset(args.data_dir, args.dataset, train_num=3, combine=args.combine_result) 55 | else: 56 | raise ValueError('Invalid dataname') 57 | 58 | if len(dataset.y.shape) == 1: 59 | dataset.y = dataset.y.unsqueeze(1) 60 | 61 | c = max(dataset.y.max().item() + 1, dataset.y.shape[1]) 62 | d = dataset.x.shape[1] 63 | n = dataset.num_nodes 64 | 65 | print(f"dataset {args.dataset}: all nodes {dataset.num_nodes} | edges {dataset.edge_index.size(1)} | " 66 | + f"classes {c} | feats {d}") 67 | print(f"train nodes {dataset.train_idx.shape[0]} | valid nodes {dataset.valid_idx.shape[0]} | " 68 | f"test in nodes {dataset.test_in_idx.shape[0]}") 69 | m = "" 70 | for i in range(len(dataset.test_ood_idx)): 71 | m += f"test ood{i+1} nodes {dataset.test_ood_idx[i].shape[0]} " 72 | print(m) 73 | print(f'[INFO] env numbers: {dataset.env_num} train env numbers: {dataset.train_env_num}') 74 | 75 | ### Load method ### 76 | is_multilabel = args.dataset in ('proteins', 'ppi') 77 | 78 | model = CaNet(d, c, args, device).to(device) 79 | 80 | if args.dataset in ('elliptic', 'twitch'): 81 | criterion = nn.BCEWithLogitsLoss(reduction='mean') 82 | else: 83 | criterion = nn.CrossEntropyLoss(reduction='mean') 84 | 85 | if args.dataset in ('twitch'): 86 | eval_func = eval_rocauc 87 | elif 
args.dataset in ('elliptic'): 88 | eval_func = eval_f1 89 | else: 90 | eval_func = eval_acc 91 | 92 | logger = Logger(args.runs, args) 93 | 94 | model.train() 95 | print('MODEL:', model) 96 | 97 | tr_acc, val_acc = [], [] 98 | 99 | ### Training loop ### 100 | for run in range(args.runs): 101 | model.reset_parameters() 102 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 103 | best_val = float('-inf') 104 | 105 | dataset.x, dataset.y, dataset.edge_index, dataset.env = \ 106 | dataset.x.to(device), dataset.y.to(device), dataset.edge_index.to(device), dataset.env.to(device) 107 | 108 | for epoch in range(args.epochs): 109 | model.train() 110 | optimizer.zero_grad() 111 | loss = model.loss_compute(dataset, criterion, args) 112 | loss.backward() 113 | optimizer.step() 114 | result = evaluate_full(model, dataset, eval_func) 115 | logger.add_result(run, result) 116 | 117 | tr_acc.append(result[0]) 118 | val_acc.append(result[2]) 119 | 120 | if epoch % args.display_step == 0: 121 | m = f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * result[0]:.2f}%, Valid: {100 * result[1]:.2f}%, Test In: {100 * result[2]:.2f}% ' 122 | for i in range(len(result)-3): 123 | m += f'Test OOD{i+1}: {100 * result[i+3]:.2f}% ' 124 | print(m) 125 | logger.print_statistics(run) 126 | 127 | 128 | logger.print_statistics() 129 | if args.store_result: 130 | logger.output(args) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from datetime import datetime 4 | from texttable import Texttable 5 | import os 6 | import numpy as np 7 | 8 | class Logger(object): 9 | """ Adapted from https://github.com/snap-stanford/ogb/ """ 10 | 11 | def __init__(self, runs, info=None): 12 | self.info = info 13 | self.results = [[] for _ in range(runs)] 14 | 15 | def add_result(self, run, result): 16 | assert len(result) >= 4 17 | assert run >= 0 and run < len(self.results) 18 | self.results[run].append(result) 19 | 20 | def print_statistics(self, run=None): 21 | if run is not None: 22 | result = 100 * torch.tensor(self.results[run]) 23 | argmax = result[:, 1].argmax().item() 24 | print(f'Run {run + 1:02d}:') 25 | print(f'Highest Train: {result[:, 0].max():.2f}') 26 | print(f'Highest Valid: {result[:, 1].max():.2f}') 27 | print(f'Highest In Test: {result[:, 2].max():.2f}') 28 | for i in range(result.size(1)-3): 29 | print(f'Highest OOD Test: {result[:, i+3].max():.2f}') 30 | print(f'Chosen epoch: {argmax+1}') 31 | print(f'Final Train: {result[argmax, 0]:.2f}') 32 | print(f'Final In Test: {result[argmax, 2]:.2f}') 33 | for i in range(result.size(1)-3): 34 | print(f'Final OOD Test: {result[argmax, i+3]:.2f}') 35 | self.test = result[argmax, 2] 36 | else: 37 | result = 100 * torch.tensor(self.results) 38 | best_results = [] 39 | for r in result: 40 | train_high = r[:, 0].max().item() 41 | valid_high = r[:, 1].max().item() 42 | test_in_high = r[:, 2].max().item() 43 | test_ood_high = [] 44 | for i in range(r.size(1) - 3): 45 | test_ood_high += [r[:, i+3].max().item()] 46 | train_final = r[r[:, 1].argmax(), 0].item() 47 | test_in_final = r[r[:, 1].argmax(), 2].item() 48 | test_ood_final = [] 49 | for i in range(r.size(1) - 3): 50 | test_ood_final += [r[r[:, 1].argmax(), i+3].item()] 51 | best_result = [train_high, valid_high, test_in_high] + test_ood_high + [train_final, test_in_final] + test_ood_final 52 | 
best_results.append(best_result) 53 | 54 | best_result = torch.tensor(best_results) 55 | print(f'All runs:') 56 | r = best_result[:, 0] 57 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 58 | r = best_result[:, 1] 59 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 60 | r = best_result[:, 2] 61 | print(f'Highest In Test: {r.mean():.2f} ± {r.std():.2f}') 62 | ood_size = result[0].size(1)-3 63 | for i in range(ood_size): 64 | r = best_result[:, i+3] 65 | print(f'Highest OOD Test: {r.mean():.2f} ± {r.std():.2f}') 66 | r = best_result[:, ood_size+3] 67 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 68 | r = best_result[:, ood_size+4] 69 | print(f' Final In Test: {r.mean():.2f} ± {r.std():.2f}') 70 | for i in range(ood_size): 71 | r = best_result[:, i+5+ood_size] 72 | print(f' Final OOD Test: {r.mean():.2f} ± {r.std():.2f}') 73 | 74 | def output(self, args): 75 | result = 100 * torch.tensor(self.results) 76 | best_results = [] 77 | for r in result: 78 | train_high = r[:, 0].max().item() 79 | valid_high = r[:, 1].max().item() 80 | test_in_high = r[:, 2].max().item() 81 | test_ood_high = [] 82 | for i in range(r.size(1) - 3): 83 | test_ood_high += [r[:, i+3].max().item()] 84 | train_final = r[r[:, 1].argmax(), 0].item() 85 | test_in_final = r[r[:, 1].argmax(), 2].item() 86 | test_ood_final = [] 87 | for i in range(r.size(1) - 3): 88 | test_ood_final += [r[r[:, 1].argmax(), i+3].item()] 89 | best_result = [train_high, valid_high, test_in_high] + test_ood_high + [train_final, test_in_final] + test_ood_final 90 | best_results.append(best_result) 91 | best_result = torch.tensor(best_results) 92 | _dict = vars(args) 93 | table = Texttable() 94 | table.add_row(["Parameter", "Value"]) 95 | for k in _dict: 96 | table.add_row([k, _dict[k]]) 97 | 98 | if not os.path.exists(f'results/{args.dataset}/{args.backbone_type}'): 99 | os.makedirs(f'results/{args.dataset}/{args.backbone_type}') 100 | datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S") 101 | filename = f'results/{args.dataset}/{args.backbone_type}/lr_{args.lr}.wd_{args.weight_decay}.tau_{args.tau}.K_{args.K}.dp_{args.dropout}.env_{args.env_type}.{datetime_now}.txt' 102 | with open(f"{filename}", 'a') as f: 103 | f.write(table.draw()) 104 | f.write(f'\nAll runs:\n') 105 | r = best_result[:, 0] 106 | f.write(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}\n') 107 | r = best_result[:, 1] 108 | f.write(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}\n') 109 | r = best_result[:, 2] 110 | f.write(f'Highest In Test: {r.mean():.2f} ± {r.std():.2f}\n') 111 | ood_size = result[0].size(1)-3 112 | for i in range(ood_size): 113 | r = best_result[:, i+3] 114 | f.write(f'Highest OOD Test: {r.mean():.2f} ± {r.std():.2f}\n') 115 | r = best_result[:, ood_size+3] 116 | f.write(f' Final Train: {r.mean():.2f} ± {r.std():.2f}\n') 117 | r = best_result[:, ood_size+4] 118 | f.write(f' Final In Test: {r.mean():.2f} ± {r.std():.2f}\n') 119 | for i in range(ood_size): 120 | r = best_result[:, i+5+ood_size] 121 | f.write(f' Final OOD Test: {r.mean():.2f} ± {r.std():.2f}\n') 122 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch_geometric.utils import erdos_renyi_graph, remove_self_loops, add_self_loops, degree, add_remaining_self_loops 8 | from 
data_utils import sys_normalized_adjacency, sparse_mx_to_torch_sparse_tensor 9 | from torch_sparse import SparseTensor, matmul 10 | 11 | def gcn_conv(x, edge_index): 12 | N = x.shape[0] 13 | row, col = edge_index 14 | d = degree(col, N).float() 15 | d_norm_in = (1. / d[col]).sqrt() 16 | d_norm_out = (1. / d[row]).sqrt() 17 | value = torch.ones_like(row) * d_norm_in * d_norm_out 18 | value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0) 19 | adj = SparseTensor(row=col, col=row, value=value, sparse_sizes=(N, N)) 20 | return matmul(adj, x) # [N, D] 21 | 22 | class GraphConvolutionBase(nn.Module): 23 | 24 | def __init__(self, in_features, out_features, residual=False): 25 | super(GraphConvolutionBase, self).__init__() 26 | self.residual = residual 27 | self.in_features = in_features 28 | 29 | self.out_features = out_features 30 | self.weight = Parameter(torch.FloatTensor(self.in_features, self.out_features)) 31 | if self.residual: 32 | self.weight_r = Parameter(torch.FloatTensor(self.in_features, self.out_features)) 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | stdv = 1. / math.sqrt(self.out_features) 37 | self.weight.data.uniform_(-stdv, stdv) 38 | self.weight_r.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, x, adj, x0): 41 | hi = gcn_conv(x, adj) 42 | output = torch.mm(hi, self.weight) 43 | if self.residual: 44 | output = output + torch.mm(x, self.weight_r) 45 | return output 46 | 47 | class CaNetConv(nn.Module): 48 | 49 | def __init__(self, in_features, out_features, K, residual=True, backbone_type='gcn', variant=False, device=None): 50 | super(CaNetConv, self).__init__() 51 | self.backbone_type = backbone_type 52 | self.out_features = out_features 53 | self.residual = residual 54 | if backbone_type == 'gcn': 55 | self.weights = Parameter(torch.FloatTensor(K, in_features*2, out_features)) 56 | elif backbone_type == 'gat': 57 | self.leakyrelu = nn.LeakyReLU() 58 | self.weights = nn.Parameter(torch.zeros(K, in_features, out_features)) 59 | self.a = nn.Parameter(torch.zeros(K, 2 * out_features, 1)) 60 | self.K = K 61 | self.device = device 62 | self.variant = variant 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | stdv = 1. 
/ math.sqrt(self.out_features) 67 | self.weights.data.uniform_(-stdv, stdv) 68 | if self.backbone_type == 'gat': 69 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 70 | 71 | def specialspmm(self, adj, spm, size, h): 72 | adj = SparseTensor(row=adj[0], col=adj[1], value=spm, sparse_sizes=size) 73 | return matmul(adj, h) 74 | 75 | def forward(self, x, adj, e, weights=None): 76 | if weights == None: 77 | weights = self.weights 78 | if self.backbone_type == 'gcn': 79 | if not self.variant: 80 | hi = gcn_conv(x, adj) 81 | else: 82 | adj = torch.sparse_coo_tensor(adj, torch.ones(adj.shape[1]).to(self.device), size=(x.shape[0],x.shape[0])).to(self.device) 83 | hi = torch.sparse.mm(adj, x) 84 | hi = torch.cat([hi, x], 1) 85 | hi = hi.unsqueeze(0).repeat(self.K, 1, 1) # [K, N, D*2] 86 | outputs = torch.matmul(hi, weights) # [K, N, D] 87 | outputs = outputs.transpose(1, 0) # [N, K, D] 88 | elif self.backbone_type == 'gat': 89 | xi = x.unsqueeze(0).repeat(self.K, 1, 1) # [K, N, D] 90 | h = torch.matmul(xi, weights) # [K, N, D] 91 | N = x.size()[0] 92 | adj, _ = remove_self_loops(adj) 93 | adj, _ = add_self_loops(adj, num_nodes=N) 94 | edge_h = torch.cat((h[:, adj[0, :], :], h[:, adj[1, :], :]), dim=2) # [K, E, 2*D] 95 | logits = self.leakyrelu(torch.matmul(edge_h, self.a)).squeeze(2) 96 | logits_max , _ = torch.max(logits, dim=1, keepdim=True) 97 | edge_e = torch.exp(logits-logits_max) # [K, E] 98 | 99 | outputs = [] 100 | eps = 1e-8 101 | for k in range(self.K): 102 | edge_e_k = edge_e[k, :] # [E] 103 | e_expsum_k = self.specialspmm(adj, edge_e_k, torch.Size([N, N]), torch.ones(N, 1).cuda()) + eps 104 | assert not torch.isnan(e_expsum_k).any() 105 | 106 | hi_k = self.specialspmm(adj, edge_e_k, torch.Size([N, N]), h[k]) 107 | hi_k = torch.div(hi_k, e_expsum_k) # [N, D] 108 | outputs.append(hi_k) 109 | outputs = torch.stack(outputs, dim=1) # [N, K, D] 110 | 111 | es = e.unsqueeze(2).repeat(1, 1, self.out_features) # [N, K, D] 112 | output = torch.sum(torch.mul(es, outputs), dim=1) # [N, D] 113 | 114 | if self.residual: 115 | output = output + x 116 | 117 | return output 118 | 119 | class CaNet(nn.Module): 120 | def __init__(self, d, c, args, device): 121 | super(CaNet, self).__init__() 122 | self.convs = nn.ModuleList() 123 | for _ in range(args.num_layers): 124 | self.convs.append(CaNetConv(args.hidden_channels, args.hidden_channels, args.K, backbone_type=args.backbone_type, residual=True, device=device, variant=args.variant)) 125 | self.fcs = nn.ModuleList() 126 | self.fcs.append(nn.Linear(d, args.hidden_channels)) 127 | self.fcs.append(nn.Linear(args.hidden_channels, c)) 128 | self.env_enc = nn.ModuleList() 129 | for _ in range(args.num_layers): 130 | if args.env_type == 'node': 131 | self.env_enc.append(nn.Linear(args.hidden_channels, args.K)) 132 | elif args.env_type == 'graph': 133 | self.env_enc.append(GraphConvolutionBase(args.hidden_channels, args.K, residual=True)) 134 | else: 135 | raise NotImplementedError 136 | self.act_fn = nn.ReLU() 137 | self.dropout = args.dropout 138 | self.num_layers = args.num_layers 139 | self.tau = args.tau 140 | self.env_type = args.env_type 141 | self.device = device 142 | 143 | def reset_parameters(self): 144 | for conv in self.convs: 145 | conv.reset_parameters() 146 | for fc in self.fcs: 147 | fc.reset_parameters() 148 | for enc in self.env_enc: 149 | enc.reset_parameters() 150 | 151 | def forward(self, x, adj, idx=None, training=False): 152 | self.training = training 153 | x = F.dropout(x, self.dropout, training=self.training) 154 | h = 
self.act_fn(self.fcs[0](x)) 155 | h0 = h.clone() 156 | 157 | reg = 0 158 | for i,con in enumerate(self.convs): 159 | h = F.dropout(h, self.dropout, training=self.training) 160 | if self.training: 161 | if self.env_type == 'node': 162 | logit = self.env_enc[i](h) 163 | else: 164 | logit = self.env_enc[i](h, adj, h0) 165 | e = F.gumbel_softmax(logit, tau=self.tau, dim=-1) 166 | reg += self.reg_loss(e, logit) 167 | else: 168 | if self.env_type == 'node': 169 | e = F.softmax(self.env_enc[i](h), dim=-1) 170 | else: 171 | e = F.softmax(self.env_enc[i](h, adj, h0), dim=-1) 172 | h = self.act_fn(con(h, adj, e)) 173 | 174 | h = F.dropout(h, self.dropout, training=self.training) 175 | out = self.fcs[-1](h) 176 | if self.training: 177 | return out, reg / self.num_layers 178 | else: 179 | return out 180 | 181 | def reg_loss(self, z, logit, logit_0 = None): 182 | log_pi = logit - torch.logsumexp(logit, dim=-1, keepdim=True).repeat(1, logit.size(1)) 183 | return torch.mean(torch.sum( 184 | torch.mul(z, log_pi), dim=1)) 185 | 186 | def sup_loss_calc(self, y, pred, criterion, args): 187 | if args.dataset in ('twitch', 'elliptic'): 188 | if y.shape[1] == 1: 189 | true_label = F.one_hot(y, y.max() + 1).squeeze(1) 190 | else: 191 | true_label = y 192 | loss = criterion(pred, true_label.squeeze(1).to(torch.float)) 193 | else: 194 | out = F.log_softmax(pred, dim=1) 195 | target = y.squeeze(1) 196 | loss = criterion(out, target) 197 | return loss 198 | 199 | def loss_compute(self, d, criterion, args): 200 | logits, reg_loss = self.forward(d.x, d.edge_index, idx=d.train_idx, training=True) 201 | sup_loss = self.sup_loss_calc(d.y[d.train_idx], logits[d.train_idx], criterion, args) 202 | loss = sup_loss + args.lamda * reg_loss 203 | return loss -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import scipy 7 | import scipy.io 8 | from sklearn.preprocessing import label_binarize 9 | import torch_geometric.transforms as T 10 | from data_utils import even_quantile_labels, to_sparse_tensor 11 | 12 | from torch_geometric.datasets import Planetoid, Amazon, Coauthor, Twitch, PPI, Reddit 13 | from torch_geometric.transforms import NormalizeFeatures, RadiusGraph 14 | from torch_geometric.data import Data, Batch 15 | from torch_geometric.utils import stochastic_blockmodel_graph, subgraph, homophily, to_dense_adj, dense_to_sparse 16 | 17 | from torch_geometric.nn import GCNConv, SGConv, SAGEConv, GATConv 18 | 19 | 20 | import pickle as pkl 21 | import os 22 | 23 | class GCN_gen(nn.Module): 24 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 25 | super(GCN_gen, self).__init__() 26 | 27 | self.convs = nn.ModuleList() 28 | self.convs.append( 29 | GCNConv(in_channels, hidden_channels)) 30 | for _ in range(num_layers - 2): 31 | self.convs.append( 32 | GCNConv(hidden_channels, hidden_channels)) 33 | self.convs.append( 34 | GCNConv(hidden_channels, out_channels)) 35 | 36 | self.activation = F.relu 37 | 38 | def reset_parameters(self): 39 | for conv in self.convs: 40 | conv.reset_parameters() 41 | 42 | def forward(self, x, edge_index, edge_weight=None): 43 | for i, conv in enumerate(self.convs[:-1]): 44 | x = conv(x, edge_index, edge_weight) 45 | x = self.activation(x) 46 | x = self.convs[-1](x, edge_index) 47 | return x 48 | 49 | def 
load_twitch_dataset(data_dir, train_num=3, train_ratio=0.5, valid_ratio=0.25): 50 | transform = T.NormalizeFeatures() 51 | sub_graphs = ['DE', 'PT', 'RU', 'ES', 'FR', 'EN'] 52 | x_list, edge_index_list, y_list, env_list = [], [], [], [] 53 | node_idx_list = [] 54 | idx_shift = 0 55 | for i, g in enumerate(sub_graphs): 56 | torch_dataset = Twitch(root=f'{data_dir}Twitch', 57 | name=g, transform=transform) 58 | data = torch_dataset[0] 59 | x, edge_index, y = data.x, data.edge_index, data.y 60 | x_list.append(x) 61 | y_list.append(y) 62 | edge_index_list.append(edge_index + idx_shift) 63 | env_list.append(torch.ones(x.size(0)) * i) 64 | node_idx_list.append(torch.arange(data.num_nodes) + idx_shift) 65 | 66 | idx_shift += data.num_nodes 67 | x = torch.cat(x_list, dim=0) 68 | y = torch.cat(y_list, dim=0) 69 | edge_index = torch.cat(edge_index_list, dim=1) 70 | env = torch.cat(env_list, dim=0) 71 | dataset = Data(x=x, edge_index=edge_index, y=y) 72 | dataset.env = env 73 | dataset.env_num = len(sub_graphs) 74 | dataset.train_env_num = train_num 75 | 76 | assert (train_num <= 5) 77 | 78 | ind_idx = torch.cat(node_idx_list[:train_num], dim=0) 79 | idx = torch.randperm(ind_idx.size(0)) 80 | train_idx_ind = idx[:int(idx.size(0) * train_ratio)] 81 | valid_idx_ind = idx[int(idx.size(0) * train_ratio) : int(idx.size(0) * (train_ratio + valid_ratio))] 82 | test_idx_ind = idx[int(idx.size(0) * (train_ratio + valid_ratio)):] 83 | dataset.train_idx = ind_idx[train_idx_ind] 84 | dataset.valid_idx = ind_idx[valid_idx_ind] 85 | dataset.test_in_idx = ind_idx[test_idx_ind] 86 | dataset.test_ood_idx = [node_idx_list[-1]] if train_num>=4 else node_idx_list[train_num:] 87 | 88 | return dataset 89 | 90 | def load_synthetic_dataset(data_dir, name, env_num=6, train_num=3, train_ratio=0.5, valid_ratio=0.25, combine=False): 91 | transform = T.NormalizeFeatures() 92 | torch_dataset = Planetoid(root=f'{data_dir}Planetoid', 93 | name=name, transform=transform) 94 | preprocess_dir = os.path.join(data_dir, 'Planetoid', name) 95 | 96 | data = torch_dataset[0] 97 | 98 | edge_index = data.edge_index 99 | x = data.x 100 | d = x.shape[1] 101 | 102 | preprocess_dir = os.path.join(preprocess_dir, 'gen') 103 | if not os.path.exists(preprocess_dir): 104 | os.makedirs(preprocess_dir) 105 | spu_feat_num = 10 106 | class_num = data.y.max().item() + 1 107 | 108 | node_idx_list = [torch.arange(data.num_nodes) + i*data.num_nodes for i in range(env_num)] 109 | 110 | file_path = preprocess_dir + f'/{class_num}-{spu_feat_num}-{env_num}.pkl' 111 | if not os.path.exists(file_path): 112 | 113 | print("creating new synthetic data...") 114 | x_list, edge_index_list, y_list, env_list = [], [], [], [] 115 | idx_shift = 0 116 | 117 | # Generator_y = GCN_gen(in_channels=d, hidden_channels=10, out_channels=class_num, num_layers=2) 118 | Generator_x = GCN_gen(in_channels=class_num, hidden_channels=10, out_channels=spu_feat_num, num_layers=2) 119 | Generator_noise = nn.Linear(env_num, spu_feat_num) 120 | 121 | with torch.no_grad(): 122 | for i in range(env_num): 123 | label_new = F.one_hot(data.y, class_num).squeeze(1).float() 124 | context_ = torch.zeros(x.size(0), env_num) 125 | context_[:, i] = 1 126 | x2 = Generator_x(label_new, edge_index) + Generator_noise(context_) 127 | x2 += torch.ones_like(x2).normal_(0, 0.1) 128 | x_new = torch.cat([x, x2], dim=1) 129 | 130 | x_list.append(x_new) 131 | y_list.append(data.y) 132 | edge_index_list.append(edge_index + idx_shift) 133 | env_list.append(torch.ones(x.size(0)) * i) 134 | 135 | idx_shift += 
data.num_nodes 136 | 137 | x = torch.cat(x_list, dim=0) 138 | y = torch.cat(y_list, dim=0) 139 | edge_index = torch.cat(edge_index_list, dim=1) 140 | env = torch.cat(env_list, dim=0) 141 | dataset = Data(x=x, edge_index=edge_index, y=y) 142 | dataset.env = env 143 | 144 | with open(file_path, 'wb') as f: 145 | pkl.dump((dataset), f, pkl.HIGHEST_PROTOCOL) 146 | else: 147 | print("using existing synthetic data...") 148 | with open(file_path, 'rb') as f: 149 | dataset = pkl.load(f) 150 | 151 | assert (train_num <= env_num-1) 152 | 153 | ind_idx = torch.cat(node_idx_list[:train_num], dim=0) 154 | idx = torch.randperm(ind_idx.size(0)) 155 | train_idx_ind = idx[:int(idx.size(0) * train_ratio)] 156 | valid_idx_ind = idx[int(idx.size(0) * train_ratio): int(idx.size(0) * (train_ratio + valid_ratio))] 157 | test_idx_ind = idx[int(idx.size(0) * (train_ratio + valid_ratio)):] 158 | dataset.train_idx = ind_idx[train_idx_ind] 159 | dataset.valid_idx = ind_idx[valid_idx_ind] 160 | dataset.test_in_idx = ind_idx[test_idx_ind] 161 | 162 | if combine: 163 | dataset.test_ood_idx = [node_idx_list[-1]] if train_num==env_num-1 else [torch.cat(node_idx_list[train_num:], dim=0)] # Combine three ood environments 164 | else: 165 | dataset.test_ood_idx = [node_idx_list[-1]] if train_num==env_num-1 else node_idx_list[train_num:] # Test three ood environments respectively 166 | 167 | dataset.env_num = env_num 168 | dataset.train_env_num = train_num 169 | 170 | return dataset 171 | 172 | 173 | def load_arxiv_dataset(data_dir, train_num=3, train_ratio=0.5, valid_ratio=0.25, inductive=True): 174 | from ogb.nodeproppred import NodePropPredDataset 175 | 176 | ogb_dataset = NodePropPredDataset(name='ogbn-arxiv', root=f'{data_dir}/ogb') 177 | 178 | node_years = ogb_dataset.graph['node_year'] 179 | 180 | edge_index = torch.as_tensor(ogb_dataset.graph['edge_index']) 181 | node_feat = torch.as_tensor(ogb_dataset.graph['node_feat']) 182 | label = torch.as_tensor(ogb_dataset.labels) 183 | 184 | year_bound = [2005, 2010, 2012, 2014, 2016, 2018, 2021] 185 | env = torch.zeros(label.shape[0]) 186 | for n in range(node_years.shape[0]): 187 | year = int(node_years[n]) 188 | for i in range(len(year_bound)-1): 189 | if year >= year_bound[i+1]: 190 | continue 191 | else: 192 | env[n] = i 193 | break 194 | 195 | dataset = Data(x=node_feat, edge_index=edge_index, y=label) 196 | dataset.env = env 197 | dataset.env_num = len(year_bound) 198 | dataset.train_env_num = train_num 199 | 200 | ind_mask = (node_years < year_bound[train_num]).squeeze(1) 201 | idx = torch.arange(dataset.num_nodes) 202 | ind_idx = idx[ind_mask] 203 | idx_ = torch.randperm(ind_idx.size(0)) 204 | train_idx_ind = idx_[:int(idx_.size(0) * train_ratio)] 205 | valid_idx_ind = idx_[int(idx_.size(0) * train_ratio): int(idx_.size(0) * (train_ratio + valid_ratio))] 206 | test_idx_ind = idx_[int(idx_.size(0) * (train_ratio + valid_ratio)):] 207 | dataset.train_idx = ind_idx[train_idx_ind] 208 | dataset.valid_idx = ind_idx[valid_idx_ind] 209 | dataset.test_in_idx = ind_idx[test_idx_ind] 210 | 211 | dataset.test_ood_idx = [] 212 | 213 | for i in range(train_num, len(year_bound)-1): 214 | ood_mask_i = ((node_years >= year_bound[i]) * (node_years < year_bound[i+1])).squeeze(1) 215 | dataset.test_ood_idx.append(idx[ood_mask_i]) 216 | 217 | return dataset 218 | 219 | 220 | def load_elliptic_dataset(data_dir, train_num=5, train_ratio=0.5, valid_ratio=0.25): 221 | 222 | sub_graphs = range(0, 49) 223 | x_list, edge_index_list, y_list, mask_list, env_list = [], [], [], [], [] 224 | 
node_idx_list = [] 225 | idx_shift = 0 226 | for i in sub_graphs: 227 | result = pkl.load(open('{}/elliptic/{}.pkl'.format(data_dir, i), 'rb')) 228 | A, label, features = result 229 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long) 230 | x = torch.tensor(features, dtype=torch.float) 231 | y = torch.tensor(label) 232 | 233 | x_list.append(x) 234 | y_list.append(y) 235 | mask = (y >= 0) 236 | edge_index_list.append(edge_index + idx_shift) 237 | env_list.append(torch.ones(x.size(0)) * i) 238 | node_idx_list.append(torch.arange(x.shape[0])[mask] + idx_shift) 239 | 240 | idx_shift += x.shape[0] 241 | 242 | x = torch.cat(x_list, dim=0) 243 | y = torch.cat(y_list, dim=0) 244 | edge_index = torch.cat(edge_index_list, dim=1) 245 | env = torch.cat(env_list, dim=0) 246 | dataset = Data(x=x, edge_index=edge_index, y=y) 247 | dataset.env = env 248 | dataset.env_num = len(sub_graphs) 249 | dataset.train_env_num = train_num 250 | 251 | ind_idx = torch.cat(node_idx_list[:train_num], dim=0) 252 | idx = torch.randperm(ind_idx.size(0)) 253 | train_idx_ind = idx[:int(idx.size(0) * train_ratio)] 254 | valid_idx_ind = idx[int(idx.size(0) * train_ratio): int(idx.size(0) * (train_ratio + valid_ratio))] 255 | test_idx_ind = idx[int(idx.size(0) * (train_ratio + valid_ratio)):] 256 | dataset.train_idx = ind_idx[train_idx_ind] 257 | dataset.valid_idx = ind_idx[valid_idx_ind] 258 | dataset.test_in_idx = ind_idx[test_idx_ind] 259 | 260 | ood_margin = 4 261 | dataset.test_ood_idx = [] 262 | for k in range((len(sub_graphs) - train_num*2) // ood_margin - 1): 263 | ood_idx_k = [node_idx_list[l] for l in range(train_num*2 + ood_margin * k, train_num*2 + ood_margin * (k + 1))] 264 | dataset.test_ood_idx.append(torch.cat(ood_idx_k, dim=0)) 265 | return dataset -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from scipy import sparse as sp 9 | 10 | from torch_sparse import SparseTensor 11 | from collections import Counter 12 | 13 | 14 | def rand_splits(node_idx, train_prop=.5, valid_prop=.25): 15 | """ randomly splits label into train/valid/test splits """ 16 | splits = {} 17 | n = node_idx.size(0) 18 | 19 | train_num = int(n * train_prop) 20 | valid_num = int(n * valid_prop) 21 | 22 | perm = torch.as_tensor(np.random.permutation(n)) 23 | 24 | train_indices = perm[:train_num] 25 | val_indices = perm[train_num:train_num + valid_num] 26 | test_indices = perm[train_num + valid_num:] 27 | 28 | splits['train'] = node_idx[train_indices] 29 | splits['valid'] = node_idx[val_indices] 30 | splits['test'] = node_idx[test_indices] 31 | 32 | return splits 33 | 34 | 35 | def load_fixed_splits(data_dir, dataset, name, protocol): 36 | splits_lst = [] 37 | if name in ['cora', 'citeseer', 'pubmed'] and protocol == 'semi': 38 | splits = {} 39 | splits['train'] = torch.as_tensor( 40 | dataset.train_mask.nonzero().squeeze(1)) 41 | splits['valid'] = torch.as_tensor( 42 | dataset.val_mask.nonzero().squeeze(1)) 43 | splits['test'] = torch.as_tensor( 44 | dataset.test_mask.nonzero().squeeze(1)) 45 | splits_lst.append(splits) 46 | elif name in ['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel', 'film', 'cornell', 'texas', 'wisconsin']: 47 | for i in range(10): 48 | splits_file_path = '{}/geom-gcn/splits/{}'.format( 49 | data_dir, name) + 
'_split_0.6_0.2_'+str(i)+'.npz' 50 | splits = {} 51 | with np.load(splits_file_path) as splits_file: 52 | splits['train'] = torch.BoolTensor(splits_file['train_mask']) 53 | splits['valid'] = torch.BoolTensor(splits_file['val_mask']) 54 | splits['test'] = torch.BoolTensor(splits_file['test_mask']) 55 | splits_lst.append(splits) 56 | else: 57 | raise NotImplementedError 58 | 59 | return splits_lst 60 | 61 | 62 | def even_quantile_labels(vals, nclasses, verbose=True): 63 | """ partitions vals into nclasses by a quantile based split, 64 | where the first class is less than the 1/nclasses quantile, 65 | second class is less than the 2/nclasses quantile, and so on 66 | 67 | vals is np array 68 | returns an np array of int class labels 69 | """ 70 | label = -1 * np.ones(vals.shape[0], dtype=np.int) 71 | interval_lst = [] 72 | lower = -np.inf 73 | for k in range(nclasses - 1): 74 | upper = np.quantile(vals, (k + 1) / nclasses) 75 | interval_lst.append((lower, upper)) 76 | inds = (vals >= lower) * (vals < upper) 77 | label[inds] = k 78 | lower = upper 79 | label[vals >= lower] = nclasses - 1 80 | interval_lst.append((lower, np.inf)) 81 | if verbose: 82 | print('Class Label Intervals:') 83 | for class_idx, interval in enumerate(interval_lst): 84 | print(f'Class {class_idx}: [{interval[0]}, {interval[1]})]') 85 | return label 86 | 87 | 88 | def to_planetoid(dataset): 89 | """ 90 | Takes in a NCDataset and returns the dataset in H2GCN Planetoid form, as follows: 91 | x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object; 92 | tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object; 93 | allx => the feature vectors of both labeled and unlabeled training instances 94 | (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object; 95 | y => the one-hot labels of the labeled training instances as numpy.ndarray object; 96 | ty => the one-hot labels of the test instances as numpy.ndarray object; 97 | ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object; 98 | graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict 99 | object; 100 | split_idx => The ogb dictionary that contains the train, valid, test splits 101 | """ 102 | split_idx = dataset.get_idx_split('random', 0.25) 103 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 104 | 105 | graph, label = dataset[0] 106 | 107 | label = torch.squeeze(label) 108 | 109 | print("generate x") 110 | x = graph['node_feat'][train_idx].numpy() 111 | x = sp.csr_matrix(x) 112 | 113 | tx = graph['node_feat'][test_idx].numpy() 114 | tx = sp.csr_matrix(tx) 115 | 116 | allx = graph['node_feat'].numpy() 117 | allx = sp.csr_matrix(allx) 118 | 119 | y = F.one_hot(label[train_idx]).numpy() 120 | ty = F.one_hot(label[test_idx]).numpy() 121 | ally = F.one_hot(label).numpy() 122 | 123 | edge_index = graph['edge_index'].T 124 | 125 | graph = defaultdict(list) 126 | 127 | for i in range(0, label.shape[0]): 128 | graph[i].append(i) 129 | 130 | for start_edge, end_edge in edge_index: 131 | graph[start_edge.item()].append(end_edge.item()) 132 | 133 | return x, tx, allx, y, ty, ally, graph, split_idx 134 | 135 | 136 | def to_sparse_tensor(edge_index, edge_feat, num_nodes): 137 | """ converts the edge_index into SparseTensor 138 | """ 139 | num_edges = edge_index.size(1) 140 | 141 | (row, col), N, E = edge_index, num_nodes, num_edges 142 | perm = (col * N + row).argsort() 143 | row, col = row[perm], col[perm] 
144 | 145 | value = edge_feat[perm] 146 | adj_t = SparseTensor(row=col, col=row, value=value, 147 | sparse_sizes=(N, N), is_sorted=True) 148 | 149 | # Pre-process some important attributes. 150 | adj_t.storage.rowptr() 151 | adj_t.storage.csr2csc() 152 | 153 | return adj_t 154 | 155 | 156 | def normalize(edge_index): 157 | """ normalizes the edge_index 158 | """ 159 | adj_t = edge_index.set_diag() 160 | deg = adj_t.sum(dim=1).to(torch.float) 161 | deg_inv_sqrt = deg.pow(-0.5) 162 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 163 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) 164 | return adj_t 165 | 166 | 167 | def gen_normalized_adjs(dataset): 168 | """ returns the normalized adjacency matrix 169 | """ 170 | row, col = dataset.graph['edge_index'] 171 | N = dataset.graph['num_nodes'] 172 | adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N)) 173 | deg = adj.sum(dim=1).to(torch.float) 174 | D_isqrt = deg.pow(-0.5) 175 | D_isqrt[D_isqrt == float('inf')] = 0 176 | 177 | DAD = D_isqrt.view(-1, 1) * adj * D_isqrt.view(1, -1) 178 | DA = D_isqrt.view(-1, 1) * D_isqrt.view(-1, 1) * adj 179 | AD = adj * D_isqrt.view(1, -1) * D_isqrt.view(1, -1) 180 | return DAD, DA, AD 181 | 182 | 183 | def convert_to_adj(edge_index, n_node): 184 | '''convert from pyg format edge_index to n by n adj matrix''' 185 | adj = torch.zeros((n_node, n_node)) 186 | row, col = edge_index 187 | adj[row, col] = 1 188 | return adj 189 | 190 | 191 | def get_gpu_memory_map(): 192 | """Get the current gpu usage. 193 | Returns 194 | ------- 195 | usage: dict 196 | Keys are device ids as integers. 197 | Values are memory usage as integers in MB. 198 | """ 199 | result = subprocess.check_output( 200 | [ 201 | 'nvidia-smi', '--query-gpu=memory.used', 202 | '--format=csv,nounits,noheader' 203 | ], encoding='utf-8') 204 | # Convert lines into a dictionary 205 | gpu_memory = np.array([int(x) for x in result.strip().split('\n')]) 206 | # gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 207 | return gpu_memory 208 | 209 | 210 | def count_parameters(model): 211 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 212 | 213 | 214 | dataset_drive_url = { 215 | 'snap-patents': '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia', 216 | 'pokec': '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y', 217 | 'yelp-chi': '1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ', 218 | } 219 | 220 | splits_drive_url = { 221 | 'snap-patents': '12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N', 222 | 'pokec': '1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_', 223 | } 224 | 225 | 226 | def split_into_groups(g): 227 | """ 228 | Args: 229 | - g (Tensor): Vector of groups 230 | Returns: 231 | - groups (Tensor): Unique groups present in g 232 | - group_indices (list): List of Tensors, where the i-th tensor is the indices of the 233 | elements of g that equal groups[i]. 234 | Has the same length as len(groups). 235 | - unique_counts (Tensor): Counts of each element in groups. 236 | Has the same length as len(groups). 
237 | """ 238 | unique_groups, unique_counts = torch.unique( 239 | g, sorted=False, return_counts=True) 240 | group_indices = [] 241 | for group in unique_groups: 242 | group_indices.append( 243 | torch.nonzero(g == group, as_tuple=True)[0]) 244 | return unique_groups, group_indices, unique_counts 245 | 246 | 247 | def reindex_env(dataset, debug=False): 248 | """ 249 | reindex the environments 250 | make sure the environments on train set is 0 ~ (train_env_num - 1) 251 | """ 252 | idx_map = {} 253 | train_idx_set = set(dataset.train_idx) 254 | new_env = torch.zeros_like(dataset.env) 255 | for idx in train_idx_set: 256 | if int(dataset.env[idx]) not in idx_map: 257 | idx_map[int(dataset.env[idx])] = len(idx_map) 258 | new_env[idx] = idx_map[int(dataset.env[idx])] 259 | 260 | train_env_num = len(idx_map) 261 | 262 | for i in range(len(dataset.env)): 263 | if int(dataset.env[i]) not in idx_map: 264 | idx_map[int(dataset.env[i])] = len(idx_map) 265 | new_env[i] = idx_map[int(dataset.env[i])] 266 | 267 | if debug: 268 | print('[INFO] reindex the environments') 269 | print(idx_map) 270 | print(Counter(dataset.env.cpu().tolist())) 271 | print(Counter(new_env.cpu().tolist())) 272 | 273 | new_env = new_env.to(dataset.env.device) 274 | dataset.env = new_env 275 | return train_env_num 276 | 277 | 278 | def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: 279 | """This function converts target class indices to one-hot vectors, given 280 | the number of classes. 281 | Args: 282 | targets (Tensor): The ground truth label of the prediction 283 | with shape (N, 1) 284 | classes (int): the number of classes. 285 | Returns: 286 | Tensor: Processed loss values. 287 | """ 288 | assert (torch.max(targets).item() < classes), \ 289 | 'Class Index must be less than number of classes' 290 | one_hot_targets = torch.zeros((targets.shape[0], classes), 291 | dtype=torch.long, 292 | device=targets.device) 293 | one_hot_targets.scatter_(1, targets.long(), 1) 294 | return one_hot_targets 295 | 296 | 297 | def sys_normalized_adjacency(adj, size=None): 298 | adj = sp.coo_matrix(adj, size) 299 | adj = adj + sp.eye(adj.shape[0]) 300 | row_sum = np.array(adj.sum(1)) 301 | row_sum=(row_sum==0)*1+row_sum 302 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 303 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 304 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 305 | return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo() 306 | 307 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 308 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 309 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 310 | indices = torch.from_numpy( 311 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 312 | values = torch.from_numpy(sparse_mx.data) 313 | shape = torch.Size(sparse_mx.shape) 314 | return torch.sparse.FloatTensor(indices, values, shape) --------------------------------------------------------------------------------