├── dataset_apis ├── topology_dist │ └── .gitkeep ├── texas.py ├── cornell.py ├── wisconsin.py ├── cora.py ├── citeseer.py ├── dblp.py └── pubmed.py ├── assets └── framework.PNG ├── .gitignore ├── global_var.py ├── adversarial.py ├── eval.py ├── test_runs.py ├── config.yaml ├── augmentation.py ├── README.md ├── train.py ├── eval_utils.py └── model.py /dataset_apis/topology_dist/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/framework.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhuYun97/RoSA/HEAD/assets/framework.PNG -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # checkpoints 2 | checkpoints/ 3 | runs_ckpts_Cora 4 | runs_ckpts_Citeseer 5 | 6 | # records for running 7 | runs/ 8 | 9 | # dataset 10 | datasets/ 11 | dataset_apis/topology_dist/*.pt 12 | 13 | # logs 14 | logs/ 15 | 16 | 17 | # all pyc files_ 18 | **/__pycache__ 19 | 20 | **/.vscode 21 | **/ipynb_checkpoints 22 | *.ipynb 23 | 24 | transform_ckpt.py -------------------------------------------------------------------------------- /dataset_apis/texas.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import WebKB 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | return WebKB(root='~/datasets', name='Texas', transform=trans) 6 | 7 | def load_eval_trainset(): 8 | return WebKB(root='~/datasets', name='Texas') 9 | 10 | def load_testset(): 11 | return WebKB(root='~/datasets', name='Texas') -------------------------------------------------------------------------------- /global_var.py: -------------------------------------------------------------------------------- 1 | def _init(): # initialize 2 | global _global_dict 3 | _global_dict = {} 4 | 5 | def set_value(key, value): 6 | # define a global value 7 | _global_dict[key] = value 8 | 9 | def get_value(key): 10 | # get a pre-defined global value 11 | try: 12 | return _global_dict[key] 13 | except: 14 | print('access'+key+'is failing\r\n') -------------------------------------------------------------------------------- /dataset_apis/cornell.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import WebKB 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | return WebKB(root='~/datasets', name='Cornell', transform=trans) 6 | 7 | def load_eval_trainset(): 8 | return WebKB(root='~/datasets', name='Cornell') 9 | 10 | def load_testset(): 11 | return WebKB(root='~/datasets', name='Cornell') -------------------------------------------------------------------------------- /dataset_apis/wisconsin.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import WebKB 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | return WebKB(root='~/datasets', name='Wisconsin', transform=trans) 6 | 7 | def load_eval_trainset(): 8 | return WebKB(root='~/datasets', name='Wisconsin') 9 | 10 | def load_testset(): 11 | return WebKB(root='~/datasets', name='Wisconsin') -------------------------------------------------------------------------------- /dataset_apis/cora.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import Planetoid 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | dataset = Planetoid(root='~/datasets', name='Cora', transform=T.Compose([trans])) 6 | return dataset 7 | 8 | def load_eval_trainset(): 9 | return Planetoid(root='~/datasets', name='Cora') 10 | 11 | def load_testset(): 12 | return Planetoid(root='~/datasets', name='Cora') -------------------------------------------------------------------------------- /dataset_apis/citeseer.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import Planetoid 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | dataset = Planetoid(root='~/datasets', name='Citeseer', transform=T.Compose([trans])) 6 | return dataset 7 | 8 | def load_eval_trainset(): 9 | return Planetoid(root='~/datasets', name='Citeseer') 10 | 11 | def load_testset(): 12 | return Planetoid(root='~/datasets', name='Citeseer') -------------------------------------------------------------------------------- /dataset_apis/dblp.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import CitationFull 2 | import torch_geometric.transforms as T 3 | 4 | def load_trainset(trans): 5 | dataset = CitationFull(root='~/datasets', name='dblp', transform=T.Compose([trans])) 6 | return dataset 7 | 8 | def load_eval_trainset(): 9 | return CitationFull(root='~/datasets', name='dblp') 10 | 11 | def load_testset(): 12 | return CitationFull(root='~/datasets', name='dblp') -------------------------------------------------------------------------------- /dataset_apis/pubmed.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import Planetoid 2 | import torch_geometric.transforms as T 3 | 4 | 5 | def load_trainset(trans): 6 | dataset = Planetoid(root='~/datasets', name='Pubmed', transform=T.Compose([trans])) 7 | return dataset 8 | 9 | def load_eval_trainset(): 10 | return Planetoid(root='~/datasets', name='Pubmed') 11 | 12 | def load_testset(): 13 | return Planetoid(root='~/datasets', name='Pubmed') -------------------------------------------------------------------------------- /adversarial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def ad_training(model, node_attack, perturb_shape, args, device): 4 | model.train() 5 | 6 | perturb = torch.FloatTensor(*perturb_shape).uniform_(-args.step_size, args.step_size).to(device) 7 | perturb.requires_grad_() 8 | 9 | loss = node_attack(perturb) 10 | loss /= args.m 11 | 12 | for i in range(args.m-1): 13 | loss.backward() 14 | perturb_data = perturb.detach() + args.step_size * torch.sign(perturb.grad.detach()) 15 | perturb.data = perturb_data.data 16 | perturb.grad[:] = 0 17 | 18 | loss = node_attack(perturb) 19 | loss /= args.m 20 | 21 | return loss -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from eval_utils import eval 2 | import argparse 3 | import random 4 | from model import RoSA 5 | from eval_utils import eval 6 | import yaml 7 | from yaml import SafeLoader 8 | import torch 9 | 10 | 11 | def str2bool(v): 12 | if isinstance(v, bool): 13 | return v 14 | if v.lower() in ('yes','y','true','t','1'): 15 | return True 16 | if v.lower() in ('no','n','false','f','0'): 17 | return False 18 | else: 19 | raise argparse.ArgumentTypeError('Boolean value expected.') 20 | def get_parser(): 21 | parser = argparse.ArgumentParser(description='Description: Script to run our model.') 22 | parser.add_argument('--config', type=str, default='config.yaml') 23 | # dataset 24 | parser.add_argument('--dataset',help='Cora, Citeseer , Pubmed, etc. Default=Cora', default='Cora') 25 | 26 | return parser 27 | 28 | device = "cuda" if torch.cuda.is_available() else "cpu" 29 | 30 | if __name__ == "__main__": 31 | parser = get_parser() 32 | try: 33 | args = parser.parse_args() 34 | except: 35 | exit() 36 | config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] 37 | 38 | # combine args and config 39 | for k, v in config.items(): 40 | args.__setattr__(k, v) 41 | 42 | # repeated experiment 43 | torch.manual_seed(args.seed) 44 | random.seed(12345) 45 | 46 | model = RoSA( 47 | # model 48 | encoder=args.encoder, 49 | # shape 50 | input_dim=args.input_dim, 51 | # model configuration 52 | layer_num=args.layer_num, 53 | hidden=args.hidden, 54 | proj_shape=(args.proj_middim, args.proj_outdim), 55 | ).to(device) 56 | 57 | model.load_state_dict(torch.load(f"checkpoints/{args.dataset}/best.pt")) 58 | 59 | test_acc = eval(args, model, device) -------------------------------------------------------------------------------- /test_runs.py: -------------------------------------------------------------------------------- 1 | from cgi import test 2 | from collections import OrderedDict 3 | from eval_utils import eval 4 | import argparse 5 | import random 6 | from model import RoSA 7 | from eval_utils import eval 8 | import yaml 9 | from yaml import SafeLoader 10 | import torch 11 | import numpy as np 12 | from tqdm import tqdm 13 | import os, wget 14 | 15 | 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes','y','true','t','1'): 20 | return True 21 | if v.lower() in ('no','n','false','f','0'): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError('Boolean value expected.') 25 | def get_parser(): 26 | parser = argparse.ArgumentParser(description='Description: Script to run our model.') 27 | parser.add_argument('--config', type=str, default='config.yaml') 28 | # dataset 29 | parser.add_argument('--dataset',help='Cora, Citeseer , Pubmed, etc. Default=Cora', default='Cora') 30 | 31 | return parser 32 | 33 | device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | eval_seeds = { 36 | 'Cora': 8085, 37 | 'Citeseer': 2230 38 | } 39 | 40 | download_url = 'https://raw.githubusercontent.com/ZhuYun97/RoSA_ckpts/main/' 41 | 42 | if __name__ == "__main__": 43 | parser = get_parser() 44 | try: 45 | args = parser.parse_args() 46 | except: 47 | exit() 48 | config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] 49 | 50 | # combine args and config 51 | for k, v in config.items(): 52 | args.__setattr__(k, v) 53 | 54 | # repeated experiment 55 | seed = eval_seeds[args.dataset] 56 | torch.manual_seed(seed) 57 | random.seed(seed) 58 | np.random.seed(seed) 59 | torch.manual_seed(seed) 60 | if torch.cuda.device_count() > 0: 61 | torch.cuda.manual_seed_all(seed) 62 | torch.backends.cudnn.deterministic=True 63 | 64 | model = RoSA( 65 | # model 66 | encoder=args.encoder, 67 | # shape 68 | input_dim=args.input_dim, 69 | # model configuration 70 | layer_num=args.layer_num, 71 | hidden=args.hidden, 72 | proj_shape=(args.proj_middim, args.proj_outdim), 73 | ).to(device) 74 | 75 | dir = f'./runs_ckpts_{args.dataset}' 76 | exist_ckpts = os.path.exists(dir) 77 | if not exist_ckpts: 78 | os.makedirs(dir) 79 | 80 | test_acc_list = [] 81 | progress = tqdm(range(20)) 82 | for i in progress: 83 | file = os.path.join(dir, f'{args.dataset}_run{i}.pt') 84 | if not os.path.exists(file): 85 | # download non-existing ckpt 86 | file_url = os.path.join(download_url, f'runs_ckpts_{args.dataset}/{args.dataset}_run{i}.pt') 87 | wget.download(file_url, file) 88 | pretrained_dicts = torch.load(file) 89 | model.load_state_dict(pretrained_dicts) 90 | test_acc = eval(args, model, device) 91 | test_acc_list.append(test_acc) 92 | progress.set_postfix({'RUN': i, 'ACC': test_acc*100}) 93 | print(f"After 20 runs, test acc: {round(np.mean(test_acc_list)*100, 2)} ± {round(np.std(test_acc_list)*100, 2)}") -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | Cora: 2 | seed: 39788 3 | learning_rate: 0.01 4 | input_dim: 1433 5 | hidden: 128 6 | proj_middim: 128 7 | proj_outdim: 128 8 | encoder: gcn 9 | layer_num: 2 10 | walk_step: 10 11 | restart_ratio: 0.5 12 | graph_num: 128 13 | drop_edge_rate_1: 0.2 14 | drop_edge_rate_2: 0.3 15 | drop_feature_rate_1: 0.2 16 | drop_feature_rate_2: 0.3 17 | tau: 0.4 18 | epochs: 500 19 | weight_decay: 0.0005 20 | optimizer: sgd 21 | patience: 500 22 | inductive: False 23 | Citeseer: 24 | seed: 38108 25 | learning_rate: 0.01 26 | input_dim: 3703 27 | hidden: 256 28 | proj_middim: 256 29 | proj_outdim: 256 30 | encoder: gcn 31 | layer_num: 2 32 | walk_step: 10 33 | restart_ratio: 0.5 34 | graph_num: 128 35 | drop_edge_rate_1: 0.4 36 | drop_edge_rate_2: 0.5 37 | drop_feature_rate_1: 0.4 38 | drop_feature_rate_2: 0.5 39 | tau: 0.7 40 | epochs: 300 41 | weight_decay: 0.0005 42 | optimizer: sgd 43 | patience: 300 44 | inductive: False 45 | Pubmed: 46 | seed: 23344 47 | learning_rate: 0.001 48 | input_dim: 500 49 | hidden: 256 50 | proj_middim: 256 51 | proj_outdim: 256 # 128 52 | encoder: gcn 53 | layer_num: 2 54 | walk_step: 20 55 | restart_ratio: 0.8 56 | graph_num: 256 57 | drop_edge_rate_1: 0.4 58 | drop_edge_rate_2: 0.1 59 | drop_feature_rate_1: 0.0 60 | drop_feature_rate_2: 0.2 61 | tau: 0.1 62 | epochs: 500 63 | weight_decay: 0.0005 64 | optimizer: adamw 65 | patience: 500 66 | inductive: False 67 | DBLP: 68 | seed: 83521 69 | learning_rate: 0.001 70 | input_dim: 1639 71 | hidden: 256 72 | proj_middim: 256 73 | proj_outdim: 256 74 | encoder: gcn 75 | layer_num: 2 76 | walk_step: 10 77 | restart_ratio: 0.5 78 | graph_num: 128 79 | drop_edge_rate_1: 0.1 80 | drop_edge_rate_2: 0.2 81 | drop_feature_rate_1: 0.2 82 | drop_feature_rate_2: 0.3 83 | tau: 0.8 84 | epochs: 500 85 | weight_decay: 0.0005 86 | optimizer: adamw 87 | patience: 500 88 | inductive: False 89 | Cornell: 90 | seed: 39788 91 | learning_rate: 0.001 92 | input_dim: 1703 93 | hidden: 64 94 | proj_middim: 64 95 | proj_outdim: 64 96 | encoder: gcn 97 | layer_num: 2 98 | walk_step: 10 99 | restart_ratio: 0.0 100 | graph_num: 256 101 | drop_edge_rate_1: 0.2 102 | drop_edge_rate_2: 0.3 103 | drop_feature_rate_1: 0.2 104 | drop_feature_rate_2: 0.3 105 | tau: 0.4 106 | epochs: 200 107 | weight_decay: 0.0005 108 | optimizer: sgd 109 | patience: 20 110 | inductive: False 111 | Texas: 112 | seed: 39788 113 | learning_rate: 0.001 114 | input_dim: 1703 115 | hidden: 64 116 | proj_middim: 64 117 | proj_outdim: 64 118 | encoder: gcn 119 | layer_num: 2 120 | walk_step: 20 121 | restart_ratio: 0.0 122 | graph_num: 256 123 | drop_edge_rate_1: 0.2 124 | drop_edge_rate_2: 0.3 125 | drop_feature_rate_1: 0.2 126 | drop_feature_rate_2: 0.3 127 | tau: 0.4 128 | epochs: 200 129 | weight_decay: 0.0005 130 | optimizer: sgd 131 | patience: 20 132 | inductive: False 133 | Wisconsin: 134 | seed: 39788 135 | learning_rate: 0.001 136 | input_dim: 1703 137 | hidden: 64 138 | proj_middim: 64 139 | proj_outdim: 64 140 | encoder: gcn 141 | layer_num: 2 142 | walk_step: 10 143 | restart_ratio: 0.0 144 | graph_num: 256 145 | drop_edge_rate_1: 0.2 146 | drop_edge_rate_2: 0.3 147 | drop_feature_rate_1: 0.2 148 | drop_feature_rate_2: 0.3 149 | tau: 0.4 150 | epochs: 200 151 | weight_decay: 0.0005 152 | optimizer: sgd 153 | patience: 20 154 | inductive: False 155 | 156 | -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | from torch_geometric.utils import subgraph, to_undirected, remove_isolated_nodes, dropout_adj, remove_self_loops 4 | from torch_geometric.utils.num_nodes import maybe_num_nodes 5 | import copy 6 | from torch_sparse import SparseTensor 7 | 8 | 9 | def add_remaining_selfloop_for_isolated_nodes(edge_index, num_nodes): 10 | num_nodes = max(maybe_num_nodes(edge_index), num_nodes) 11 | # only add self-loop on isolated nodes 12 | # edge_index, _ = remove_self_loops(edge_index) 13 | loop_index = torch.arange(0, num_nodes, dtype=torch.long, device=edge_index.device) 14 | connected_nodes_indices = torch.cat([edge_index[0], edge_index[1]]).unique() 15 | mask = torch.ones(num_nodes, dtype=torch.bool) 16 | mask[connected_nodes_indices] = False 17 | loops_for_isolatd_nodes = loop_index[mask] 18 | loops_for_isolatd_nodes = loops_for_isolatd_nodes.unsqueeze(0).repeat(2, 1) 19 | edge_index = torch.cat([edge_index, loops_for_isolatd_nodes], dim=1) 20 | return edge_index 21 | 22 | 23 | class RWR: 24 | """ Every node in the graph will get a random path 25 | 26 | A stochastic data augmentation module that transforms a complete graph into many subgraphs through random walking 27 | the subgraphs which contain the same center nodes are positive pairs, otherwise they are negative pairs 28 | """ 29 | 30 | def __init__(self, walk_step=50, graph_num=128, restart_ratio=0.5, inductive=False, aligned=False, **args): 31 | self.walk_steps = walk_step 32 | self.graph_num = graph_num 33 | self.restart_ratio = restart_ratio 34 | self.inductive = inductive 35 | self.aligned = aligned 36 | 37 | def __call__(self, graph): 38 | graph = copy.deepcopy(graph) # modified on the copy 39 | assert self.walk_steps > 1 40 | # remove isolated nodes (or we can construct edges for these nodes) 41 | if self.inductive: 42 | train_node_idx = torch.where(graph.train_mask == True)[0] 43 | graph.edge_index, _ = subgraph(train_node_idx, graph.edge_index) # remove val and test nodes (val and test are considered as isolated nodes) 44 | edge_index, _, mask = remove_isolated_nodes(graph.edge_index, num_nodes=graph.x.shape[0]) # remove all ioslated nodes and re-index nodes 45 | graph.x = graph.x[mask] 46 | edge_index = to_undirected(graph.edge_index) 47 | edge_index = add_remaining_selfloop_for_isolated_nodes(edge_index, graph.x.shape[0]) 48 | graph.edge_index = edge_index 49 | 50 | node_num = graph.x.shape[0] 51 | graph_num = min(self.graph_num, node_num) 52 | start_nodes = torch.randperm(node_num)[:graph_num] 53 | edge_index = graph.edge_index 54 | 55 | value = torch.arange(edge_index.size(1)) 56 | self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], 57 | value=value, 58 | sparse_sizes=(node_num, node_num)).t() 59 | 60 | view1_list = [] 61 | view2_list = [] 62 | 63 | views_cnt = 1 if self.aligned else 2 64 | for view_idx in range(views_cnt): 65 | current_nodes = start_nodes.clone() 66 | history = start_nodes.clone().unsqueeze(0) 67 | signs = torch.ones(graph_num, dtype=torch.bool).unsqueeze(0) 68 | for i in range(self.walk_steps): 69 | seed = torch.rand([graph_num]) 70 | nei = self.adj_t.sample(1, current_nodes).squeeze() 71 | sign = seed < self.restart_ratio 72 | nei[sign] = start_nodes[sign] 73 | history = torch.cat((history, nei.unsqueeze(0)), dim=0) 74 | signs = torch.cat((signs, sign.unsqueeze(0)), dim=0) 75 | current_nodes = nei 76 | history = history.T 77 | signs = signs.T 78 | 79 | for i in range(graph_num): 80 | path = history[i] 81 | sign = signs[i] 82 | node_idx = path.unique() 83 | sources = path[:-1].numpy().tolist() 84 | targets = path[1:].numpy().tolist() 85 | sub_edges = torch.IntTensor([sources, targets]).type_as(graph.edge_index) 86 | sub_edges = sub_edges.T[~sign[1:]].T 87 | # undirectional 88 | if sub_edges.shape[1] != 0: 89 | sub_edges = to_undirected(sub_edges) 90 | view = self.adjust_idx(sub_edges, node_idx, graph, path[0].item()) 91 | 92 | if self.aligned: 93 | view1_list.append(view) 94 | view2_list.append(copy.deepcopy(view)) 95 | else: 96 | if view_idx == 0: 97 | view1_list.append(view) 98 | else: 99 | view2_list.append(view) 100 | return (view1_list, view2_list) 101 | 102 | def adjust_idx(self, edge_index, node_idx, full_g, center_idx): 103 | '''re-index the nodes and edge index 104 | 105 | In the subgraphs, some nodes are droppped. We need to change the node index in edge_index in order to corresponds 106 | nodes' index to edge index 107 | ''' 108 | node_idx_map = {j : i for i, j in enumerate(node_idx.numpy().tolist())} 109 | sources_idx = list(map(node_idx_map.get, edge_index[0].numpy().tolist())) 110 | target_idx = list(map(node_idx_map.get, edge_index[1].numpy().tolist())) 111 | 112 | edge_index = torch.IntTensor([sources_idx, target_idx]).type_as(full_g.edge_index) 113 | x_view = Data(edge_index=edge_index, x=full_g.x[node_idx], center=node_idx_map[center_idx], original_idx=node_idx) 114 | return x_view 115 | 116 | 117 | class DropEdgeAndFeature(object): 118 | def __init__(self, d_fea=0.2, d_edge=0.3): 119 | self.drop_feature_ratio=d_fea 120 | self.drop_edge_ratio=d_edge 121 | 122 | def set_drop_ratio(self, d_fea, d_edge): 123 | self.drop_feature_ratio=d_fea 124 | self.drop_edge_ratio=d_edge 125 | 126 | def set_drop_edge_ratio(self, d_edge): 127 | self.drop_edge_ratio=d_edge 128 | 129 | def set_drop_fea_ratio(self, d_fea): 130 | self.drop_feature_ratio=d_fea 131 | 132 | def drop_fea_edge(self, g): 133 | new_g = copy.deepcopy(g) 134 | edge_index = self._drop_edge(new_g) 135 | x = self._drop_feature(new_g) 136 | new_g.x = x 137 | new_g.edge_index = edge_index 138 | return new_g 139 | 140 | def _drop_feature(self, g): 141 | drop_mask = torch.empty( 142 | (g.x.size(1), ), 143 | dtype=torch.float32, 144 | device=g.x.device).uniform_(0, 1) < self.drop_feature_ratio 145 | x = g.x.clone() 146 | x[:, drop_mask] = 0 147 | return x 148 | 149 | def _drop_edge(self, g): 150 | edge_index, _ = dropout_adj(g.edge_index, p=self.drop_edge_ratio) 151 | return edge_index -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning 2 | This repo is Pytorch implemention of
[RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning](https://www.ijcai.org/proceedings/2022/0527.pdf)[\[PPT\]](https://docs.google.com/presentation/d/1RnpXrjyaojJ3Tiqu92NImtYI4cD2XVmr/edit?usp=sharing&ouid=103567337311720006952&rtpof=true&sd=true)[\[appendix\]](https://arxiv.org/abs/2204.13846)

3 | Yun Zhu\*, Jianhao Guo\*, Fei Wu, Siliang Tang†

4 | In IJCAI 2022
5 | 6 | ## Overview 7 | This is the first work dedicated to solving non-aligned node-node graph contrastive learning problems. To tackle the non-aligned problem, we introduce a novel graph-based optimal transport algorithm, g-EMD, which does not require explicit node-node correspondence and can fully utilize graph topological and attributive information for non-aligned node-node contrasting. Moreover, to compensate for the possible information loss caused by non-aligned sub-sampling, we propose a nontrivial unsupervised graph adversarial training to improve the diversity of sub-sampling and strengthen the robustness of the model. The overview of our method is depicted as: 8 | ![FRAMEWORK](./assets/framework.PNG) 9 | 10 | ## Files 11 | ``` 12 | . 13 | ├── dataset_apis # Code process datasets. 14 | │ ├── topology_dist # Storing the distance of the shortest path (SPD) between vi and vj. 15 | │ ├── citeseer.py # processing for citeseer dataset. 16 | │ ├── cora.py # processing for cora dataset. 17 | │ ├── dblp.py # processing for dblp dataset. 18 | │ ├── pubmed.py # processing for pubmed dataset. 19 | │ ├── cornell.py # processing for cornell dataset. 20 | │ ├── wisconsin.py # processing for wisconsin dataset. 21 | │ ├── texas.py # processing for texas dataset. 22 | │ └── ... # More datasets will be added. 23 | │ 24 | ├── adversarial.py # Code for unsupervised adversarial training. 25 | ├── augmentation.py # Code for augmentation. 26 | ├── config.yaml # Configurations for our method. 27 | ├── eval_utils.py # The toolkits for evaluation. 28 | ├── eval.py # Code for evaluation. 29 | ├── global_var.py # Code for storing global variable. 30 | ├── model.py # Code for building up model. 31 | ├── train.py # Training process. 32 | ├── test_runs.py # Reproduce the results reported in our paper 33 | └── ... 34 | ``` 35 | 36 | ## Setup 37 | Recommand you to set up a Python virtual environment with the required dependencies as follows: 38 | ``` 39 | conda create -n rosa python==3.9 40 | conda activate rosa 41 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 42 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.8.0+cu111.html 43 | ``` 44 | ## Usage 45 | **Command for training model on Cora dataset** 46 | ```bash 47 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=Cora --config=config.yaml --ad=True --rectified=True 48 | ``` 49 | For efficient usage, you can run as below: 50 | ```bash 51 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=Cora --config=config.yaml --ad=False --rectified=False 52 | ``` 53 | Now supported datasets include Cora, Citeseer, Pubmed, DBLP, Cornell, Wisconsin, Texas. More datasets are coming soon! 54 | 55 | **Command for testing model on Cora dataset**
56 | After training, the best checkpoint will be stored in `checkpoints\\` dir. Then you can test the checkpoint through this command: 57 | ```bash 58 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset=Cora --config=config.yaml 59 | ``` 60 | 61 | **Command for reproducing the results in the paper**
62 | Because our results are based on 20 independent experiments. We provide the all(20) pre-trained ckpts of RoSA for Cora and Citeseer. As for other larger datasets, you can contact us by email if you need.
63 | The usage is quite simple, you just run the command below. The ckpts will be downloaded automatically and the pre-trained models will be tested. 64 | ```bash 65 | CUDA_VISIBLE_DEVICES=0 python test_runs.py --dataset=Cora 66 | ``` 67 | > After 20 runs, test acc: 84.56 ± 0.67 68 | 69 | ```bash 70 | CUDA_VISIBLE_DEVICES=0 python test_runs.py --dataset=Citeseer 71 | ``` 72 | > After 20 runs, test acc: 73.5 ± 0.45 73 | 74 | ### Illustration of arguements 75 | 76 | ``` 77 | --dataset: default Cora, [Cora, Citeseer, Pubmed, DBLP, Cornell, Wisconsin, Texas] can also be choosen 78 | --rectified: defalut False, use rectified cost matrix instead of vanilla cost matrix (if True) 79 | --ad: default False, use unsupervised adversarial training (if True) 80 | --aligned: default False, use aligned views (if True) 81 | ``` 82 | 83 | ### More experiments 84 | We conduct experimetns on other five commonly used datasets with RoSA, the results show in Table 1. RoSA reaches SOTA on these datasets which proves the effectiveness of our method. 85 | | Method | Wiki-CS | Amazon-Computers | Amazon-Photo | Coauthor-CS | Coauthor-Physics | 86 | | :----:| :----------: | :----------: | :----------: | :----------: | :----------: | 87 | | DGI | 75.35 ± 0.14 | 83.95 ± 0.47 | 91.61 ± 0.22 | 92.15 ± 0.63 | 94.51 ± 0.52 | 88 | | GMI | 74.85 ± 0.08 | 82.21 ± 0.31 | 90.68 ± 0.17 | OOM | OOM | 89 | | MVGRL | 77.52 ± 0.08 | 87.52 ± 0.11 | 91.74 ± 0.07 | 92.11 ± 0.12 | 95.33 ± 0.03 | 90 | | GRACE | 78.19 ± 0.01 | 87.46 ± 0.22 | 92.15 ± 0.24 | 92.93 ± 0.01 | 95.26 ± 0.02 | 91 | | GCA | 78.35 ± 0.05 | 88.94 ± 0.15 | 92.53 ± 0.16 | 93.10 ± 0.01 | 95.73 ± 0.03 | 92 | | BGRL | 79.36 ± 0.53 | 89.68 ± 0.31 | 92.87 ± 0.27 | 93.21 ± 0.18 | 95.56 ± 0.12 | 93 | | RoSA | **80.11 ± 0.10** | **90.12 ± 0.26** | **93.67 ± 0.07** | **93.23 ± 0.13** | **95.76 ± 0.09** | 94 | 95 | 96 | | | Hidden size | Batch size | Learning rate | Walk length | Epochs | tau | p_{e,1} | p_{e,1} | p_{f,1} | p_{f,1} | 97 | | :---- | :---------- | :--------- | :------------ | :----------- | :----- | :-- | :-------- | :-------- | :-------- | :-------- | 98 | | Wiki-CS | 256 | 256 | 1e-3 | 10 | 500 | 0.5 | 0.2 | 0.3 | 0.2 | 0.3 | 99 | | Amazon-Computers | 128 | 256 | 1e-3 | 10 | 500 | 0.2 | 0.4 | 0.5 | 0.1 | 0.2 | 100 | | Amazon-Photo | 256 | 256 | 1e-3 | 10 | 500 | 0.3 | 0.2 | 0.3 | 0.2 | 0.3 | 101 | | Coauthor-CS | 256 | 128 | 1e-3 | 10 | 100 | 0.1 | 0.2 | 0.3 | 0.2 | 0.3 | 102 | | Coauthor-Physics | 128 | 256 | 1e-3 | 10 | 100 | 0.5 | 0.2 | 0.3 | 0.2 | 0.3 | 103 | 104 | In additions, we use `prelu` as activation function and use `adamw` optimizer with `5e-4 weight decay` for all experimetns. The restart ratio of random walking is 0.5. 105 | 106 | ## Citation 107 | If you use this code for you research, please cite our paper. 108 | ``` 109 | @inproceedings{ijcai2022-527, 110 | title = {RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning}, 111 | author = {Zhu, Yun and Guo, Jianhao and Wu, Fei and Tang, Siliang}, 112 | booktitle = {Proceedings of the Thirty-First International Joint Conference on 113 | Artificial Intelligence, {IJCAI-22}}, 114 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 115 | editor = {Lud De Raedt}, 116 | pages = {3795--3801}, 117 | year = {2022}, 118 | month = {7}, 119 | note = {Main Track}, 120 | doi = {10.24963/ijcai.2022/527}, 121 | url = {https://doi.org/10.24963/ijcai.2022/527}, 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Batch 2 | import numpy as np 3 | import torch 4 | import os 5 | import argparse 6 | import importlib 7 | import random 8 | from torch_geometric.data import DataLoader 9 | from torch_geometric.utils import contains_isolated_nodes, to_networkx 10 | from augmentation import RWR, DropEdgeAndFeature 11 | from model import RoSA 12 | from adversarial import ad_training 13 | from eval_utils import eval 14 | import global_var 15 | import networkx as nx 16 | import yaml 17 | from yaml import SafeLoader 18 | from tqdm import tqdm 19 | 20 | 21 | device = "cuda" if torch.cuda.is_available() else "cpu" 22 | 23 | def train(args): 24 | trans = RWR(walk_step=args.walk_step, graph_num=args.graph_num, restart_ratio=args.restart_ratio, aligned=args.aligned, inductive=args.inductive) 25 | load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_trainset') 26 | dataset = load_dataset(trans) 27 | 28 | train_loader = DataLoader(dataset, batch_size=1, shuffle=False) 29 | 30 | # Model 31 | model = RoSA( 32 | # model 33 | encoder=args.encoder, 34 | # shape 35 | input_dim=args.input_dim, 36 | # model configuration 37 | layer_num=args.layer_num, 38 | hidden=args.hidden, 39 | proj_shape=(args.proj_middim, args.proj_outdim), 40 | # loss 41 | is_rectified=args.rectified, 42 | T = args.tau, 43 | topo_t=args.topo_t 44 | ).to(device) 45 | 46 | # optimizer 47 | if args.optimizer == 'sgd': 48 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9) 49 | elif args.optimizer == 'adam': 50 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 51 | elif args.optimizer == 'adamw': 52 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 53 | 54 | # save checkpoints 55 | path = f'./checkpoints/{args.dataset}/' 56 | if not os.path.exists(path): 57 | os.makedirs(path) 58 | 59 | patience = args.patience # early stopping 60 | stop_cnt = 0 61 | best = 9999 62 | best_t = 0 63 | 64 | loop = tqdm(range(args.epochs)) 65 | for epoch in loop: 66 | model.train() 67 | for idx, graphs in enumerate(train_loader): 68 | model.train() 69 | optimizer.zero_grad() 70 | # Because the augmentation function, the processed batch in loader will be [batch, batch, batch], we should collect them into one batch 71 | view1_list = [] 72 | view2_list = [] 73 | assert len(graphs[0]) == len(graphs[1]) 74 | all_graphs_num = len(graphs[0]) 75 | for i in range(all_graphs_num): 76 | view1_list.extend(graphs[0][i].to_data_list()) 77 | view2_list.extend(graphs[1][i].to_data_list()) 78 | 79 | # shuffle views list 80 | shuf_idx = np.random.permutation(all_graphs_num) 81 | view1_list_shuff = [view1_list[i] for i in shuf_idx] 82 | view2_list_shuff = [view2_list[i] for i in shuf_idx] 83 | 84 | views1 = Batch().from_data_list(view1_list_shuff).to(device) 85 | views2 = Batch().from_data_list(view2_list_shuff).to(device) 86 | # additional augmentation 87 | drop = DropEdgeAndFeature(d_fea=args.drop_feature_rate_1, d_edge=args.drop_edge_rate_1) 88 | views1 = drop.drop_fea_edge(views1) 89 | drop.set_drop_ratio(d_fea=args.drop_feature_rate_2, d_edge=args.drop_edge_rate_2) 90 | views2 = drop.drop_fea_edge(views2) 91 | 92 | if args.ad: 93 | def node_attack(perturb): 94 | views1.x += perturb 95 | return model(views1, views2, views1.batch, views2.batch) 96 | 97 | loss = ad_training(model, node_attack, views1.x.shape, args, device) 98 | else: 99 | loss = model(views1, views2, views1.batch, views2.batch) 100 | loop.set_postfix(loss = loss.item()) 101 | loss.backward() 102 | optimizer.step() 103 | 104 | if loss < best: 105 | stop_cnt = 0 106 | best = loss 107 | best_t = epoch + 1 108 | torch.save(model.state_dict(), os.path.join(path, 'best.pt')) 109 | else: 110 | stop_cnt += 1 111 | 112 | if stop_cnt >= patience: 113 | print("early stopping") 114 | break 115 | 116 | if patience < args.epochs: 117 | print('Loading {}th epoch'.format(best_t)) 118 | model.load_state_dict(torch.load(os.path.join(path, 'best.pt'))) 119 | 120 | return model 121 | 122 | def register_topology(args): 123 | MAX_HOP = 100 124 | load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_eval_trainset') 125 | data = load_dataset()[0] 126 | topo_file = f"./dataset_apis/topology_dist/{args.dataset.lower()}_padding.pt" 127 | exist = os.path.isfile(topo_file) 128 | if not exist: 129 | node_num = data.x.shape[0] 130 | G = to_networkx(data) 131 | generator = dict(nx.shortest_path_length(G)) 132 | topology_dist = torch.zeros((node_num+1, node_num+1)) # we shift the node index with 1, in order to store 0-index for padding nodes 133 | mask = torch.zeros((node_num+1, node_num+1)).bool() 134 | 135 | topology_dist[0, :] = 1000 # used for padding nodes 136 | topology_dist[:, 0] = 1000 137 | 138 | for i in tqdm(range(1, node_num+1)): 139 | # print(f"processing {i}-th node") 140 | for j in range(1, node_num+1): 141 | if j-1 in generator[i-1].keys(): 142 | topology_dist[i][j] = generator[i-1][j-1] 143 | else: 144 | topology_dist[i][j] = MAX_HOP 145 | mask[i][j] = True # record nodes that do not have connections 146 | torch.save(topology_dist, topo_file) 147 | else: 148 | topology_dist = torch.load(topo_file) 149 | global_var._init() 150 | global_var.set_value("topology_dist", topology_dist) 151 | 152 | def str2bool(v): 153 | if isinstance(v, bool): 154 | return v 155 | if v.lower() in ('yes','y','true','t','1'): 156 | return True 157 | if v.lower() in ('no','n','false','f','0'): 158 | return False 159 | else: 160 | raise argparse.ArgumentTypeError('Boolean value expected.') 161 | def get_parser(): 162 | parser = argparse.ArgumentParser(description='Description: Script to run our model.') 163 | parser.add_argument('--config', type=str, default='config.yaml') 164 | # dataset 165 | parser.add_argument('--dataset',help='Cora, Citeseer , Pubmed, etc. Default=Cora', default='Cora') 166 | # augmentation 167 | parser.add_argument('--aligned', type=str2bool, help='aligned views or not', default=False) 168 | # adversarial training 169 | parser.add_argument('--ad', type=str2bool, help='combine with adversarial training', default=False) 170 | parser.add_argument('--step-size', type=float, default=1e-3) 171 | parser.add_argument('--m', type=int, default=3) 172 | # loss 173 | parser.add_argument('--rectified', type=str2bool, help='use rectified cost matrix', default=False) 174 | parser.add_argument('--topo_t', type=int, help='temperature for sigmoid', default=2) 175 | 176 | return parser 177 | 178 | if __name__ == "__main__": 179 | parser = get_parser() 180 | try: 181 | args = parser.parse_args() 182 | except: 183 | exit() 184 | config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] 185 | 186 | # combine args and config 187 | for k, v in config.items(): 188 | args.__setattr__(k, v) 189 | 190 | # repeated experiment 191 | torch.manual_seed(args.seed) 192 | random.seed(12345) 193 | 194 | if args.rectified: 195 | register_topology(args) 196 | 197 | model = train(args) 198 | test_acc = eval(args, model, device) 199 | 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | from sklearn.metrics import accuracy_score, f1_score 4 | from sklearn.linear_model import LogisticRegression 5 | from sklearn.model_selection import train_test_split, GridSearchCV, ShuffleSplit 6 | from sklearn.multiclass import OneVsRestClassifier 7 | from sklearn.preprocessing import normalize, OneHotEncoder 8 | import importlib 9 | import torch 10 | from torch_sparse import SparseTensor 11 | from model import LogReg 12 | 13 | 14 | # borrow from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py] 15 | def fit_logistic_regression(X, y, data_random_seed=1, repeat=1): 16 | # transfrom targets to one-hot vector 17 | one_hot_encoder = OneHotEncoder(categories='auto', sparse=False) 18 | 19 | y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(np.bool) 20 | 21 | # normalize x 22 | X = normalize(X, norm='l2') 23 | 24 | # set random state 25 | rng = np.random.RandomState(data_random_seed) # this will ensure the dataset will be split exactly the same 26 | # throughout training 27 | 28 | accuracies = [] 29 | for _ in range(repeat): 30 | # different random split after each repeat 31 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=rng) 32 | 33 | # grid search with one-vs-rest classifiers 34 | logreg = LogisticRegression(solver='liblinear') 35 | c = 2.0 ** np.arange(-10, 11) 36 | cv = ShuffleSplit(n_splits=5, test_size=0.5) 37 | clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), param_grid=dict(estimator__C=c), 38 | n_jobs=5, cv=cv, verbose=0) 39 | clf.fit(X_train, y_train) 40 | 41 | y_pred = clf.predict_proba(X_test) 42 | y_pred = np.argmax(y_pred, axis=1) 43 | y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(np.bool) 44 | 45 | test_acc = accuracy_score(y_test, y_pred) 46 | accuracies.append(test_acc) 47 | return accuracies 48 | 49 | 50 | # borrow from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py] 51 | def fit_logistic_regression_preset_splits(X, y, train_masks, val_masks, test_mask): 52 | # transfrom targets to one-hot vector 53 | one_hot_encoder = OneHotEncoder(categories='auto', sparse=False) 54 | y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(np.bool) 55 | 56 | # normalize x 57 | X = normalize(X, norm='l2') 58 | 59 | accuracies = [] 60 | for split_id in range(train_masks.shape[1]): 61 | # get train/val/test masks 62 | train_mask, val_mask = train_masks[:, split_id], val_masks[:, split_id] 63 | 64 | # make custom cv 65 | X_train, y_train = X[train_mask], y[train_mask] 66 | X_val, y_val = X[val_mask], y[val_mask] 67 | X_test, y_test = X[test_mask], y[test_mask] 68 | 69 | # grid search with one-vs-rest classifiers 70 | best_test_acc, best_acc = 0, 0 71 | for c in 2.0 ** np.arange(-10, 11): 72 | clf = OneVsRestClassifier(LogisticRegression(solver='liblinear', C=c)) 73 | clf.fit(X_train, y_train) 74 | 75 | y_pred = clf.predict_proba(X_val) 76 | y_pred = np.argmax(y_pred, axis=1) 77 | y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(np.bool) 78 | val_acc = accuracy_score(y_val, y_pred) 79 | if val_acc > best_acc: 80 | best_acc = val_acc 81 | y_pred = clf.predict_proba(X_test) 82 | y_pred = np.argmax(y_pred, axis=1) 83 | y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(np.bool) 84 | best_test_acc = accuracy_score(y_test, y_pred) 85 | 86 | accuracies.append(best_test_acc) 87 | print(np.mean(accuracies)) 88 | return accuracies 89 | 90 | 91 | def repeat(n_times): 92 | def decorator(f): 93 | @functools.wraps(f) 94 | def wrapper(*args, **kwargs): 95 | results = [f(*args, **kwargs) for _ in range(n_times)] 96 | statistics = {} 97 | for key in results[0].keys(): 98 | values = [r[key] for r in results] 99 | statistics[key] = { 100 | 'mean': np.mean(values), 101 | 'std': np.std(values)} 102 | print_statistics(statistics, f.__name__) 103 | return statistics 104 | return wrapper 105 | return decorator 106 | 107 | 108 | def prob_to_one_hot(y_pred): 109 | ret = np.zeros(y_pred.shape, np.bool) 110 | indices = np.argmax(y_pred, axis=1) 111 | for i in range(y_pred.shape[0]): 112 | ret[i][indices[i]] = True 113 | return ret 114 | 115 | 116 | def print_statistics(statistics, function_name): 117 | print(f'(E) | {function_name}:', end=' ') 118 | for i, key in enumerate(statistics.keys()): 119 | mean = statistics[key]['mean'] 120 | std = statistics[key]['std'] 121 | print(f'{key}={mean:.4f}+-{std:.4f}', end='') 122 | if i != len(statistics.keys()) - 1: 123 | print(',', end=' ') 124 | else: 125 | print() 126 | 127 | 128 | # borrow from GRACE [https://github.com/CRIPAC-DIG/GRACE/blob/master/eval.py] 129 | # @repeat(5) 130 | def label_classification(embeddings, y, ratio=0.1): 131 | X = embeddings.detach().cpu().numpy() 132 | Y = y.detach().cpu().numpy() 133 | Y = Y.reshape(-1, 1) 134 | onehot_encoder = OneHotEncoder(categories='auto').fit(Y) 135 | Y = onehot_encoder.transform(Y).toarray().astype(np.bool) 136 | 137 | X = normalize(X, norm='l2') 138 | 139 | X_train, X_test, y_train, y_test = train_test_split(X, Y, 140 | test_size=1 - ratio) 141 | 142 | logreg = LogisticRegression(solver='liblinear') 143 | c = 2.0 ** np.arange(-10, 10) 144 | 145 | clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), 146 | param_grid=dict(estimator__C=c), n_jobs=8, cv=5, 147 | verbose=0) 148 | clf.fit(X_train, y_train) 149 | 150 | y_pred = clf.predict_proba(X_test) 151 | y_pred = prob_to_one_hot(y_pred) 152 | 153 | acc = accuracy_score(y_test, y_pred) 154 | # return {"accuracy": acc} 155 | return acc 156 | 157 | 158 | def heter_eval(model, z, y, train_mask, val_mask, test_mask, device): 159 | model.eval() 160 | num_classes = y.max().item()+1 161 | 162 | xent = torch.nn.CrossEntropyLoss() 163 | # z = torch.nn.functional.normalize(z, p=2, dim=1) 164 | log = LogReg(model.hidden, num_classes).to(device) 165 | opt = torch.optim.Adam(log.parameters(), lr=1e-2, weight_decay=0.0) 166 | 167 | train_embs = z[train_mask] 168 | val_embs = z[val_mask] 169 | test_embs = z[test_mask] 170 | 171 | best_acc_from_val = torch.zeros(1).cuda() 172 | best_val = torch.zeros(1).cuda() 173 | best_t = 0 174 | 175 | log.train() 176 | for i in range(100): 177 | opt.zero_grad() 178 | 179 | logits = log(train_embs) 180 | loss = xent(logits, y[train_mask].long()) 181 | 182 | with torch.no_grad(): 183 | ltra = log(train_embs) 184 | lv = log(val_embs) 185 | lt = log(test_embs) 186 | ltra_preds = torch.argmax(ltra, dim=1) 187 | lv_preds = torch.argmax(lv, dim=1) 188 | lt_preds = torch.argmax(lt, dim=1) 189 | train_acc = torch.sum(ltra_preds == y[train_mask]).float() / train_mask.sum() 190 | val_acc = torch.sum(lv_preds == y[val_mask]).float() / val_mask.sum() 191 | test_acc = torch.sum(lt_preds == y[test_mask]).float() / test_mask.sum() 192 | 193 | if val_acc > best_val: 194 | best_acc_from_val = test_acc 195 | best_val = val_acc 196 | best_t = i 197 | 198 | loss.backward() 199 | opt.step() 200 | return best_acc_from_val.cpu().item() 201 | 202 | 203 | def eval(args, model, device): 204 | model.eval() 205 | load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_eval_trainset') 206 | dataset = load_dataset() 207 | data = dataset[0].to(device) 208 | z = model.embed(data) 209 | if args.dataset.lower() in ['cora', 'citeseer', 'pubmed', 'dblp']: 210 | acc = label_classification(z, data.y, ratio=0.1) 211 | elif args.dataset.lower() in ['amazon_photos', 'amazon_computers', 'coauthor_cs', 'coauthor_physics']: 212 | acc = fit_logistic_regression(z, data.y) 213 | elif args.dataset.lower() == 'wikics': 214 | acc = fit_logistic_regression_preset_splits(z, data.y, data.train_mask, data.val_mask, data.test_mask) 215 | elif args.dataset.lower() in ["cornell", "wisconsin", "texas"]: 216 | acc_list = [] 217 | for run in range(data.train_mask.shape[1]): # These datasets contains 10 different splits. Note: In our paper, we run 20 independent experiments. Different experiments use different splits. 218 | train_mask = data.train_mask[:, run%10] 219 | val_mask = data.val_mask[:, run%10] 220 | test_mask = data.test_mask[:, run%10] 221 | acc = heter_eval(model, z, data.y, train_mask, val_mask, test_mask, device) 222 | acc_list.append(acc) 223 | acc = np.mean(acc_list) 224 | print("Test acc: {:.2f}±{:.2f}".format(np.mean(acc_list)*100, np.std(acc_list)*100)) 225 | else: 226 | raise NotImplementedError(f"{args.dataset} is not supported!") 227 | return acc 228 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.data import Data 4 | from torch_geometric.utils import to_dense_batch 5 | from torch_geometric.nn import global_mean_pool 6 | from torch_geometric.nn import GCNConv 7 | from torch_geometric.nn.inits import glorot 8 | from torch.nn import ModuleList 9 | import global_var 10 | 11 | 12 | class RoSA(torch.nn.Module): 13 | def __init__( 14 | self, 15 | encoder='gcn', 16 | input_dim=1433, 17 | 18 | # model configuration 19 | layer_num=2, # layers of encoder 20 | hidden=128, # encoder hidden size 21 | proj_shape=(128, 128), # hidden size for predictor 22 | 23 | T = 0.4, # temperature 24 | is_rectified=True, # use rectified cost matrix 25 | topo_t = 2 # temperature for calculating re-scale ratios with topology dist 26 | ): 27 | super(RoSA, self).__init__() 28 | 29 | # load components 30 | if encoder == 'gcn': 31 | Encoder = GCN 32 | elif encoder == 'sage-gcn': 33 | Encoder = GraphSAGE_GCN 34 | else: 35 | raise NotImplementedError(f'{encoder} is not implemented!') 36 | self.encoder = Encoder(input_dim=input_dim, layer_num=layer_num, hidden=hidden) 37 | self.T = T 38 | self.rectified = is_rectified 39 | self.topo_t = topo_t 40 | self.hidden = hidden 41 | 42 | # adapative size for mlp's input dim 43 | fake_x = torch.rand((2, input_dim)) 44 | fake_edge_index = torch.LongTensor([[0], [0]]) 45 | fake_g = Data(x=fake_x, edge_index=fake_edge_index) 46 | # fake_graph = Data(x=fake_x, edge_index=fake_edge_index, batch=torch.LongTensor([0])) 47 | with torch.no_grad(): 48 | rep = self.encoder(fake_g) 49 | hid = rep.shape[-1] 50 | self.projector = Projector([hid, *proj_shape]) 51 | 52 | def gen_rep(self, data): 53 | h = self.encoder(data) 54 | z = self.projector(h) 55 | return h, z 56 | 57 | def sim(self, reps1, reps2): 58 | reps1_unit = F.normalize(reps1, dim=-1) 59 | reps2_unit = F.normalize(reps2, dim=-1) 60 | if len(reps1.shape) == 2: 61 | sim_mat = torch.einsum("ik,jk->ij", [reps1_unit, reps2_unit]) 62 | elif len(reps1.shape) == 3: 63 | sim_mat = torch.einsum('bik,bjk->bij', [reps1_unit, reps2_unit]) 64 | else: 65 | print(f"{len(reps1.shape)} dimension tensor is not supported for this function!") 66 | return sim_mat 67 | 68 | def topology_dist(self, node_idx1, node_idx2): 69 | full_topology_dist = global_var.get_value('topology_dist').cuda() 70 | 71 | batch_size = node_idx1.shape[0] 72 | batch_subpology_dist = [full_topology_dist.index_select(dim=0, index=node_idx1[i]).index_select(dim=1, index=node_idx2[i]) for i in range(batch_size)] 73 | batch_subpology_dist = torch.stack(batch_subpology_dist) 74 | return batch_subpology_dist 75 | 76 | def _batched_semi_emd_loss(self, out1, avg_out1, out2, avg_out2, lamb=20, rescale_ratio=None): 77 | assert out1.shape[0] == out2.shape[0] and avg_out1.shape == avg_out2.shape 78 | 79 | cost_matrix = 1-self.sim(out1, out2) 80 | if rescale_ratio is not None: 81 | cost_matrix = cost_matrix * rescale_ratio 82 | 83 | # Sinkhorn iteration 84 | iter_times = 5 85 | with torch.no_grad(): 86 | r = torch.bmm(out1, avg_out2.transpose(1,2)) 87 | r[r<=0] = 1e-8 88 | r = r / r.sum(dim=1, keepdim=True) 89 | c = torch.bmm(out2, avg_out1.transpose(1,2)) 90 | c[c<=0] = 1e-8 91 | c = c / c.sum(dim=1, keepdim=True) 92 | P = torch.exp(-1*lamb*cost_matrix) 93 | u = (torch.ones_like(c)/c.shape[1]) 94 | for i in range(iter_times): 95 | v = torch.div(r, torch.bmm(P, u)) 96 | u = torch.div(c, torch.bmm(P.transpose(1,2), v)) 97 | u = u.squeeze(dim=-1) 98 | v = v.squeeze(dim=-1) 99 | transport_matrix = torch.bmm(torch.bmm(matrix_diag(v), P), matrix_diag(u)) 100 | assert cost_matrix.shape == transport_matrix.shape 101 | 102 | # S = torch.mul(transport_matrix, 1-cost_matrix).sum(dim=1).sum(dim=1, keepdim=True) 103 | emd = torch.mul(transport_matrix, cost_matrix).sum(dim=1).sum(dim=1, keepdim=True) 104 | S = 2-2*emd 105 | return S 106 | 107 | def batched_semi_emd_loss(self, reps1, reps2, batch1, batch2, original_idx1=None, original_idx2=None): 108 | batch_size1 = batch1.max().cpu().item() 109 | batch_size2 = batch2.max().cpu().item() 110 | assert batch_size1 == batch_size2 111 | batch_size = batch_size1+1 112 | # avg nodes rep 113 | y_online1_pooling = global_mean_pool(reps1, batch1) 114 | y_online2_pooling = global_mean_pool(reps2, batch2) 115 | avg_out1 = y_online1_pooling[:, None, :] # (B,1,D) 116 | avg_out2 = y_online2_pooling[:, None, :] # (B,1,D) 117 | 118 | # x reps from sparse to dense 119 | out1, mask1 = to_dense_batch(reps1, batch1, fill_value=1e-8) # (B,N,D), B means batch size, N means the number of nodes, D means hidden size 120 | out2, mask2 = to_dense_batch(reps2, batch2, fill_value=1e-8) # (B,M,D) 121 | 122 | if original_idx1 != None and original_idx2 != None: 123 | dense_original_idx1, idx_mask1 = to_dense_batch(original_idx1, batch=batch1, fill_value=0) 124 | dense_original_idx2, idx_mask2 = to_dense_batch(original_idx2, batch=batch2, fill_value=0) 125 | topology_dist = self.topology_dist(dense_original_idx1, dense_original_idx2) 126 | rescale_ratio = torch.sigmoid(topology_dist/self.topo_t) 127 | topo_mask = torch.bmm(idx_mask1[:,:,None].float(), idx_mask2[:,None,:].float()).bool() 128 | loss_pos = self._batched_semi_emd_loss(out1, avg_out1, out2, avg_out2, rescale_ratio=rescale_ratio) * 2 129 | else: 130 | loss_pos = self._batched_semi_emd_loss(out1, avg_out1, out2, avg_out2) 131 | 132 | T = self.T # temperature 133 | f = lambda x: torch.exp(x/T) 134 | 135 | total_neg_loss = 0 136 | # completely create negative samples 137 | neg_index = list(range(batch_size)) 138 | for i in range((batch_size-1)): 139 | neg_index.insert(0, neg_index.pop(-1)) 140 | out1_perm = out1[neg_index].clone() 141 | out2_perm = out2[neg_index].clone() 142 | avg_out1_perm = avg_out1[neg_index].clone() 143 | avg_out2_perm = avg_out2[neg_index].clone() 144 | total_neg_loss += f(self._batched_semi_emd_loss(out1, avg_out1, out1_perm, avg_out1_perm)) + f(self._batched_semi_emd_loss(out1, avg_out1, out2_perm, avg_out2_perm)) 145 | 146 | loss = -torch.log(f(loss_pos) / (total_neg_loss)) 147 | 148 | return loss 149 | 150 | 151 | 152 | def gen_loss(self, reps1, reps2, batch1=None, batch2=None, original_idx1=None, original_idx2=None): 153 | # data with batch 154 | loss1 = self.batched_semi_emd_loss(reps1, reps2, batch1, batch2, original_idx1=original_idx1, original_idx2=original_idx2) 155 | loss2 = self.batched_semi_emd_loss(reps2, reps1, batch2, batch1, original_idx1=original_idx2, original_idx2=original_idx1) 156 | loss = (loss1*0.5 + loss2*0.5).mean() 157 | return loss 158 | 159 | def forward(self, view1, view2, batch1=None, batch2=None): 160 | h1, z1 = self.gen_rep(view1) 161 | h2, z2 = self.gen_rep(view2) 162 | 163 | if hasattr(view1, 'original_idx') and self.rectified: 164 | original_idx1, original_idx2 = view1.original_idx+1, view2.original_idx+1 # shift one, zero-index is stored for padding nodes 165 | loss = self.gen_loss(z1, z2, batch1, batch2, original_idx1, original_idx2) 166 | else: 167 | loss = self.gen_loss(z1, z2, batch1, batch2) 168 | 169 | return loss 170 | 171 | def embed(self, data): 172 | h = self.encoder(data) 173 | return h.detach() 174 | 175 | 176 | class GCN(torch.nn.Module): 177 | def __init__(self, input_dim, layer_num=2, hidden=128): 178 | super(GCN, self).__init__() 179 | self.layer_num = layer_num 180 | self.hidden = hidden 181 | self.input_dim = input_dim 182 | 183 | self.convs = ModuleList() 184 | if self.layer_num > 1: 185 | self.convs.append(GCNConv(input_dim, hidden*2)) 186 | for i in range(layer_num-2): 187 | self.convs.append(GCNConv(hidden*2, hidden*2)) 188 | glorot(self.convs[i].weight) 189 | self.convs.append(GCNConv(hidden*2, hidden)) 190 | glorot(self.convs[-1].weight) 191 | 192 | else: # one layer gcn 193 | self.convs.append(GCNConv(input_dim, hidden)) 194 | glorot(self.convs[-1].weight) 195 | 196 | 197 | def forward(self, data): 198 | x, edge_index = data.x, data.edge_index 199 | for i in range(self.layer_num): 200 | x = F.relu(self.convs[i](x, edge_index)) 201 | return x 202 | 203 | 204 | class GraphSAGE_GCN(torch.nn.Module): 205 | def __init__(self, input_dim, layer_num=3, hidden=512): 206 | super().__init__() 207 | self.convs = torch.nn.ModuleList() 208 | self.layer = layer_num 209 | self.acts = torch.nn.ModuleList() 210 | self.norms = torch.nn.ModuleList() 211 | 212 | for i in range(self.layer): 213 | if i == 0: 214 | self.convs.append(SAGEConv(input_dim, hidden, root_weight=True)) 215 | else: 216 | self.convs.append(SAGEConv(hidden, hidden, root_weight=True)) 217 | # self.acts.append(torch.nn.PReLU(hidden)) 218 | self.acts.append(torch.nn.ELU()) 219 | self.norms.append(torch.nn.BatchNorm1d(hidden)) 220 | 221 | def forward(self, data): 222 | x, edge_index = data.x, data.edge_index 223 | for i in range(self.layer): 224 | x = self.acts[i](self.norms[i](self.convs[i](x, edge_index))) 225 | return x 226 | 227 | 228 | class Projector(torch.nn.Module): 229 | def __init__(self, shape=()): 230 | super(Projector, self).__init__() 231 | if len(shape) < 3: 232 | raise Exception("Wrong shape for Projector") 233 | 234 | self.main = torch.nn.Sequential( 235 | torch.nn.Linear(shape[0], shape[1]), 236 | torch.nn.BatchNorm1d(shape[1]), 237 | torch.nn.ReLU(), 238 | torch.nn.Linear(shape[1], shape[2]) 239 | ) 240 | 241 | def forward(self, x): 242 | return self.main(x) 243 | 244 | 245 | def matrix_diag(diagonal): 246 | N = diagonal.shape[-1] 247 | shape = diagonal.shape[:-1] + (N, N) 248 | device, dtype = diagonal.device, diagonal.dtype 249 | result = torch.zeros(shape, dtype=dtype, device=device) 250 | indices = torch.arange(result.numel(), device=device).reshape(shape) 251 | indices = indices.diagonal(dim1=-2, dim2=-1) 252 | result.view(-1)[indices] = diagonal 253 | return result 254 | 255 | class LogReg(torch.nn.Module): 256 | def __init__(self, ft_in, nb_classes): 257 | super(LogReg, self).__init__() 258 | self.fc = torch.nn.Linear(ft_in, nb_classes) 259 | 260 | for m in self.modules(): 261 | self.weights_init(m) 262 | 263 | def weights_init(self, m): 264 | if isinstance(m, torch.nn.Linear): 265 | torch.nn.init.xavier_uniform_(m.weight.data) 266 | # torch.nn.init.xavier_normal_(m.weight.data) 267 | if m.bias is not None: 268 | m.bias.data.fill_(0.0) 269 | 270 | def forward(self, seq): 271 | ret = self.fc(seq) 272 | return ret 273 | --------------------------------------------------------------------------------