├── dataset_apis
├── topology_dist
│ └── .gitkeep
├── texas.py
├── cornell.py
├── wisconsin.py
├── cora.py
├── citeseer.py
├── dblp.py
└── pubmed.py
├── assets
└── framework.PNG
├── .gitignore
├── global_var.py
├── adversarial.py
├── eval.py
├── test_runs.py
├── config.yaml
├── augmentation.py
├── README.md
├── train.py
├── eval_utils.py
└── model.py
/dataset_apis/topology_dist/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/framework.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhuYun97/RoSA/HEAD/assets/framework.PNG
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # checkpoints
2 | checkpoints/
3 | runs_ckpts_Cora
4 | runs_ckpts_Citeseer
5 |
6 | # records for running
7 | runs/
8 |
9 | # dataset
10 | datasets/
11 | dataset_apis/topology_dist/*.pt
12 |
13 | # logs
14 | logs/
15 |
16 |
17 | # all pyc files_
18 | **/__pycache__
19 |
20 | **/.vscode
21 | **/ipynb_checkpoints
22 | *.ipynb
23 |
24 | transform_ckpt.py
--------------------------------------------------------------------------------
/dataset_apis/texas.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import WebKB
2 | import torch_geometric.transforms as T
3 |
def _texas(**kwargs):
    # Single place that knows the dataset root and name.
    return WebKB(root='~/datasets', name='Texas', **kwargs)

def load_trainset(trans):
    """Return the WebKB-Texas dataset with *trans* applied on access."""
    return _texas(transform=trans)

def load_eval_trainset():
    """Return the untransformed WebKB-Texas dataset (evaluation-time training)."""
    return _texas()

def load_testset():
    """Return the untransformed WebKB-Texas dataset (testing)."""
    return _texas()
--------------------------------------------------------------------------------
/global_var.py:
--------------------------------------------------------------------------------
1 | def _init(): # initialize
2 | global _global_dict
3 | _global_dict = {}
4 |
5 | def set_value(key, value):
6 | # define a global value
7 | _global_dict[key] = value
8 |
9 | def get_value(key):
10 | # get a pre-defined global value
11 | try:
12 | return _global_dict[key]
13 | except:
14 | print('access'+key+'is failing\r\n')
--------------------------------------------------------------------------------
/dataset_apis/cornell.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import WebKB
2 | import torch_geometric.transforms as T
3 |
def _cornell(**kwargs):
    # Single place that knows the dataset root and name.
    return WebKB(root='~/datasets', name='Cornell', **kwargs)

def load_trainset(trans):
    """Return the WebKB-Cornell dataset with *trans* applied on access."""
    return _cornell(transform=trans)

def load_eval_trainset():
    """Return the untransformed WebKB-Cornell dataset (evaluation-time training)."""
    return _cornell()

def load_testset():
    """Return the untransformed WebKB-Cornell dataset (testing)."""
    return _cornell()
--------------------------------------------------------------------------------
/dataset_apis/wisconsin.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import WebKB
2 | import torch_geometric.transforms as T
3 |
def _wisconsin(**kwargs):
    # Single place that knows the dataset root and name.
    return WebKB(root='~/datasets', name='Wisconsin', **kwargs)

def load_trainset(trans):
    """Return the WebKB-Wisconsin dataset with *trans* applied on access."""
    return _wisconsin(transform=trans)

def load_eval_trainset():
    """Return the untransformed WebKB-Wisconsin dataset (evaluation-time training)."""
    return _wisconsin()

def load_testset():
    """Return the untransformed WebKB-Wisconsin dataset (testing)."""
    return _wisconsin()
--------------------------------------------------------------------------------
/dataset_apis/cora.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import Planetoid
2 | import torch_geometric.transforms as T
3 |
def _cora(transform=None):
    # Central constructor so the root/name pair lives in one place.
    # Planetoid's own default for ``transform`` is None, so passing it
    # through explicitly is equivalent to omitting it.
    return Planetoid(root='~/datasets', name='Cora', transform=transform)

def load_trainset(trans):
    """Return Planetoid-Cora with *trans* (wrapped in a Compose) applied on access."""
    return _cora(transform=T.Compose([trans]))

def load_eval_trainset():
    """Return the untransformed Planetoid-Cora dataset (evaluation-time training)."""
    return _cora()

def load_testset():
    """Return the untransformed Planetoid-Cora dataset (testing)."""
    return _cora()
--------------------------------------------------------------------------------
/dataset_apis/citeseer.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import Planetoid
2 | import torch_geometric.transforms as T
3 |
def _citeseer(transform=None):
    # Central constructor so the root/name pair lives in one place.
    # Planetoid's own default for ``transform`` is None, so passing it
    # through explicitly is equivalent to omitting it.
    return Planetoid(root='~/datasets', name='Citeseer', transform=transform)

def load_trainset(trans):
    """Return Planetoid-Citeseer with *trans* (wrapped in a Compose) applied on access."""
    return _citeseer(transform=T.Compose([trans]))

def load_eval_trainset():
    """Return the untransformed Planetoid-Citeseer dataset (evaluation-time training)."""
    return _citeseer()

def load_testset():
    """Return the untransformed Planetoid-Citeseer dataset (testing)."""
    return _citeseer()
--------------------------------------------------------------------------------
/dataset_apis/dblp.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import CitationFull
2 | import torch_geometric.transforms as T
3 |
def _dblp(transform=None):
    # Central constructor so the root/name pair lives in one place.
    # CitationFull's own default for ``transform`` is None, so passing it
    # through explicitly is equivalent to omitting it.
    return CitationFull(root='~/datasets', name='dblp', transform=transform)

def load_trainset(trans):
    """Return CitationFull-DBLP with *trans* (wrapped in a Compose) applied on access."""
    return _dblp(transform=T.Compose([trans]))

def load_eval_trainset():
    """Return the untransformed CitationFull-DBLP dataset (evaluation-time training)."""
    return _dblp()

def load_testset():
    """Return the untransformed CitationFull-DBLP dataset (testing)."""
    return _dblp()
--------------------------------------------------------------------------------
/dataset_apis/pubmed.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import Planetoid
2 | import torch_geometric.transforms as T
3 |
4 |
def _pubmed(transform=None):
    # Central constructor so the root/name pair lives in one place.
    # Planetoid's own default for ``transform`` is None, so passing it
    # through explicitly is equivalent to omitting it.
    return Planetoid(root='~/datasets', name='Pubmed', transform=transform)

def load_trainset(trans):
    """Return Planetoid-Pubmed with *trans* (wrapped in a Compose) applied on access."""
    return _pubmed(transform=T.Compose([trans]))

def load_eval_trainset():
    """Return the untransformed Planetoid-Pubmed dataset (evaluation-time training)."""
    return _pubmed()

def load_testset():
    """Return the untransformed Planetoid-Pubmed dataset (testing)."""
    return _pubmed()
--------------------------------------------------------------------------------
/adversarial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def ad_training(model, node_attack, perturb_shape, args, device):
    """Run one multi-step adversarial (FLAG-style) training step.

    A perturbation tensor added to the input features is updated ``args.m``
    times by sign-gradient ascent on the loss, while each inner ``backward()``
    also accumulates 1/m-scaled gradients into the model parameters. The
    caller runs the final ``backward()`` and the optimizer step.

    Args:
        model: the model being trained; only switched to train mode here.
        node_attack: closure that takes the perturbation tensor and returns a
            scalar loss. Assumed differentiable w.r.t. the perturbation
            (TODO confirm against the closure built in train.py).
        perturb_shape: shape of the perturbation (matches the node features).
        args: needs ``step_size`` (init range and ascent step) and ``m``
            (number of inner loss evaluations; must be >= 1).
        device: device the perturbation is created on.

    Returns:
        The final loss tensor, already divided by ``args.m``.
    """
    model.train()

    # Initialise the perturbation uniformly in [-step_size, step_size].
    perturb = torch.FloatTensor(*perturb_shape).uniform_(-args.step_size, args.step_size).to(device)
    perturb.requires_grad_()

    loss = node_attack(perturb)
    loss /= args.m  # each of the m inner losses contributes 1/m

    for i in range(args.m-1):
        # Accumulates gradients into both the model parameters and ``perturb``.
        loss.backward()
        # Sign-gradient ASCENT: move the perturbation to maximise the loss.
        perturb_data = perturb.detach() + args.step_size * torch.sign(perturb.grad.detach())
        # Write through .data so ``perturb`` stays the same autograd leaf.
        perturb.data = perturb_data.data
        perturb.grad[:] = 0  # reset only the perturbation's gradient

        loss = node_attack(perturb)
        loss /= args.m

    return loss
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from eval_utils import eval
2 | import argparse
3 | import random
4 | from model import RoSA
5 | from eval_utils import eval
6 | import yaml
7 | from yaml import SafeLoader
8 | import torch
9 |
10 |
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Real bools pass through unchanged; strings are matched case-insensitively
    against common spellings. Anything else raises ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes','y','true','t','1'):
        return True
    if lowered in ('no','n','false','f','0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def get_parser():
    """Build the command-line parser (--config and --dataset flags)."""
    parser = argparse.ArgumentParser(
        description='Description: Script to run our model.')
    # path to the YAML file holding per-dataset hyperparameters
    parser.add_argument('--config', type=str, default='config.yaml')
    # dataset
    parser.add_argument(
        '--dataset',
        help='Cora, Citeseer , Pubmed, etc. Default=Cora',
        default='Cora',
    )
    return parser
27 |
# Pick the GPU when available; evaluation also works on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":
    parser = get_parser()
    try:
        args = parser.parse_args()
    except:
        # NOTE(review): this bare except also swallows SystemExit raised by
        # argparse on bad flags; the try/except is likely removable.
        exit()
    # Load the per-dataset hyperparameter section from the YAML config.
    config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset]

    # combine args and config: every config entry becomes an attribute on args
    for k, v in config.items():
        args.__setattr__(k, v)

    # repeated experiment: fix RNG seeds for a reproducible evaluation
    torch.manual_seed(args.seed)
    random.seed(12345)

    # Rebuild the model with the same architecture hyperparameters used
    # during training so the checkpoint's state_dict matches.
    model = RoSA(
        # model
        encoder=args.encoder,
        # shape
        input_dim=args.input_dim,
        # model configuration
        layer_num=args.layer_num,
        hidden=args.hidden,
        proj_shape=(args.proj_middim, args.proj_outdim),
    ).to(device)

    # Restore the best checkpoint written by train.py.
    model.load_state_dict(torch.load(f"checkpoints/{args.dataset}/best.pt"))

    # Evaluate the pretrained encoder (see eval_utils.eval).
    test_acc = eval(args, model, device)
--------------------------------------------------------------------------------
/test_runs.py:
--------------------------------------------------------------------------------
1 | from cgi import test
2 | from collections import OrderedDict
3 | from eval_utils import eval
4 | import argparse
5 | import random
6 | from model import RoSA
7 | from eval_utils import eval
8 | import yaml
9 | from yaml import SafeLoader
10 | import torch
11 | import numpy as np
12 | from tqdm import tqdm
13 | import os, wget
14 |
15 |
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Real bools pass through unchanged; strings are matched case-insensitively
    against common spellings. Anything else raises ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes','y','true','t','1'):
        return True
    if lowered in ('no','n','false','f','0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def get_parser():
    """Build the command-line parser (--config and --dataset flags)."""
    parser = argparse.ArgumentParser(
        description='Description: Script to run our model.')
    # path to the YAML file holding per-dataset hyperparameters
    parser.add_argument('--config', type=str, default='config.yaml')
    # dataset
    parser.add_argument(
        '--dataset',
        help='Cora, Citeseer , Pubmed, etc. Default=Cora',
        default='Cora',
    )
    return parser
32 |
# Pick the GPU when available; evaluation also works on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed seeds used to reproduce the numbers reported in the paper.
eval_seeds = {
    'Cora': 8085,
    'Citeseer': 2230
}

# Base URL (note trailing slash) hosting the released pre-trained checkpoints.
download_url = 'https://raw.githubusercontent.com/ZhuYun97/RoSA_ckpts/main/'

if __name__ == "__main__":
    parser = get_parser()
    try:
        args = parser.parse_args()
    except:
        exit()
    # Load the per-dataset hyperparameter section from the YAML config.
    config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset]

    # combine args and config: every config entry becomes an attribute on args
    for k, v in config.items():
        args.__setattr__(k, v)

    # Seed every RNG once for a reproducible evaluation. A KeyError here
    # means no checkpoints have been released for the requested dataset.
    seed = eval_seeds[args.dataset]
    torch.manual_seed(seed)  # fix: was called twice in the original
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # Rebuild the model with the architecture used to produce the checkpoints.
    model = RoSA(
        # model
        encoder=args.encoder,
        # shape
        input_dim=args.input_dim,
        # model configuration
        layer_num=args.layer_num,
        hidden=args.hidden,
        proj_shape=(args.proj_middim, args.proj_outdim),
    ).to(device)

    # Local cache for the downloaded checkpoints (renamed from ``dir``,
    # which shadowed the builtin).
    ckpt_dir = f'./runs_ckpts_{args.dataset}'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    test_acc_list = []
    progress = tqdm(range(20))
    for i in progress:
        file = os.path.join(ckpt_dir, f'{args.dataset}_run{i}.pt')
        if not os.path.exists(file):
            # Download the missing checkpoint. Build the URL by string
            # formatting, not os.path.join, which would insert backslashes
            # on Windows and break the URL.
            file_url = f'{download_url}runs_ckpts_{args.dataset}/{args.dataset}_run{i}.pt'
            wget.download(file_url, file)
        pretrained_dicts = torch.load(file)
        model.load_state_dict(pretrained_dicts)
        test_acc = eval(args, model, device)
        test_acc_list.append(test_acc)
        progress.set_postfix({'RUN': i, 'ACC': test_acc*100})
    print(f"After 20 runs, test acc: {round(np.mean(test_acc_list)*100, 2)} ± {round(np.std(test_acc_list)*100, 2)}")
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | Cora:
2 | seed: 39788
3 | learning_rate: 0.01
4 | input_dim: 1433
5 | hidden: 128
6 | proj_middim: 128
7 | proj_outdim: 128
8 | encoder: gcn
9 | layer_num: 2
10 | walk_step: 10
11 | restart_ratio: 0.5
12 | graph_num: 128
13 | drop_edge_rate_1: 0.2
14 | drop_edge_rate_2: 0.3
15 | drop_feature_rate_1: 0.2
16 | drop_feature_rate_2: 0.3
17 | tau: 0.4
18 | epochs: 500
19 | weight_decay: 0.0005
20 | optimizer: sgd
21 | patience: 500
22 | inductive: False
23 | Citeseer:
24 | seed: 38108
25 | learning_rate: 0.01
26 | input_dim: 3703
27 | hidden: 256
28 | proj_middim: 256
29 | proj_outdim: 256
30 | encoder: gcn
31 | layer_num: 2
32 | walk_step: 10
33 | restart_ratio: 0.5
34 | graph_num: 128
35 | drop_edge_rate_1: 0.4
36 | drop_edge_rate_2: 0.5
37 | drop_feature_rate_1: 0.4
38 | drop_feature_rate_2: 0.5
39 | tau: 0.7
40 | epochs: 300
41 | weight_decay: 0.0005
42 | optimizer: sgd
43 | patience: 300
44 | inductive: False
45 | Pubmed:
46 | seed: 23344
47 | learning_rate: 0.001
48 | input_dim: 500
49 | hidden: 256
50 | proj_middim: 256
51 | proj_outdim: 256 # 128
52 | encoder: gcn
53 | layer_num: 2
54 | walk_step: 20
55 | restart_ratio: 0.8
56 | graph_num: 256
57 | drop_edge_rate_1: 0.4
58 | drop_edge_rate_2: 0.1
59 | drop_feature_rate_1: 0.0
60 | drop_feature_rate_2: 0.2
61 | tau: 0.1
62 | epochs: 500
63 | weight_decay: 0.0005
64 | optimizer: adamw
65 | patience: 500
66 | inductive: False
67 | DBLP:
68 | seed: 83521
69 | learning_rate: 0.001
70 | input_dim: 1639
71 | hidden: 256
72 | proj_middim: 256
73 | proj_outdim: 256
74 | encoder: gcn
75 | layer_num: 2
76 | walk_step: 10
77 | restart_ratio: 0.5
78 | graph_num: 128
79 | drop_edge_rate_1: 0.1
80 | drop_edge_rate_2: 0.2
81 | drop_feature_rate_1: 0.2
82 | drop_feature_rate_2: 0.3
83 | tau: 0.8
84 | epochs: 500
85 | weight_decay: 0.0005
86 | optimizer: adamw
87 | patience: 500
88 | inductive: False
89 | Cornell:
90 | seed: 39788
91 | learning_rate: 0.001
92 | input_dim: 1703
93 | hidden: 64
94 | proj_middim: 64
95 | proj_outdim: 64
96 | encoder: gcn
97 | layer_num: 2
98 | walk_step: 10
99 | restart_ratio: 0.0
100 | graph_num: 256
101 | drop_edge_rate_1: 0.2
102 | drop_edge_rate_2: 0.3
103 | drop_feature_rate_1: 0.2
104 | drop_feature_rate_2: 0.3
105 | tau: 0.4
106 | epochs: 200
107 | weight_decay: 0.0005
108 | optimizer: sgd
109 | patience: 20
110 | inductive: False
111 | Texas:
112 | seed: 39788
113 | learning_rate: 0.001
114 | input_dim: 1703
115 | hidden: 64
116 | proj_middim: 64
117 | proj_outdim: 64
118 | encoder: gcn
119 | layer_num: 2
120 | walk_step: 20
121 | restart_ratio: 0.0
122 | graph_num: 256
123 | drop_edge_rate_1: 0.2
124 | drop_edge_rate_2: 0.3
125 | drop_feature_rate_1: 0.2
126 | drop_feature_rate_2: 0.3
127 | tau: 0.4
128 | epochs: 200
129 | weight_decay: 0.0005
130 | optimizer: sgd
131 | patience: 20
132 | inductive: False
133 | Wisconsin:
134 | seed: 39788
135 | learning_rate: 0.001
136 | input_dim: 1703
137 | hidden: 64
138 | proj_middim: 64
139 | proj_outdim: 64
140 | encoder: gcn
141 | layer_num: 2
142 | walk_step: 10
143 | restart_ratio: 0.0
144 | graph_num: 256
145 | drop_edge_rate_1: 0.2
146 | drop_edge_rate_2: 0.3
147 | drop_feature_rate_1: 0.2
148 | drop_feature_rate_2: 0.3
149 | tau: 0.4
150 | epochs: 200
151 | weight_decay: 0.0005
152 | optimizer: sgd
153 | patience: 20
154 | inductive: False
155 |
156 |
--------------------------------------------------------------------------------
/augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 | from torch_geometric.utils import subgraph, to_undirected, remove_isolated_nodes, dropout_adj, remove_self_loops
4 | from torch_geometric.utils.num_nodes import maybe_num_nodes
5 | import copy
6 | from torch_sparse import SparseTensor
7 |
8 |
def add_remaining_selfloop_for_isolated_nodes(edge_index, num_nodes):
    """Append a self-loop for every node that has no incident edge.

    Args:
        edge_index: LongTensor of shape (2, E) listing directed edges.
        num_nodes: lower bound on the node count; the effective count is the
            max of this and the largest index referenced in ``edge_index``.

    Returns:
        ``edge_index`` with one (i, i) edge appended per isolated node i.
    """
    # Inline equivalent of torch_geometric's maybe_num_nodes(edge_index):
    # infer the node count from the largest referenced node index.
    inferred = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    num_nodes = max(inferred, num_nodes)
    # only add self-loops on isolated nodes
    loop_index = torch.arange(0, num_nodes, dtype=torch.long, device=edge_index.device)
    connected_nodes_indices = torch.cat([edge_index[0], edge_index[1]]).unique()
    # Fix: allocate the mask on edge_index's device; the original CPU-only
    # mask would fail to index a CUDA ``loop_index``.
    mask = torch.ones(num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[connected_nodes_indices] = False
    loops_for_isolated_nodes = loop_index[mask].unsqueeze(0).repeat(2, 1)
    return torch.cat([edge_index, loops_for_isolated_nodes], dim=1)
21 |
22 |
class RWR:
    """Random-walk-with-restart augmentation: every start node gets a random path.

    A stochastic data augmentation module that transforms a graph into many
    subgraphs through random walking. Subgraphs that share the same center
    (start) node are positive pairs; otherwise they are negative pairs.
    Calling an instance on a graph returns two lists of subgraph views.
    """

    def __init__(self, walk_step=50, graph_num=128, restart_ratio=0.5, inductive=False, aligned=False, **args):
        # number of walk steps taken from each start node
        self.walk_steps = walk_step
        # number of subgraphs (start nodes) sampled per call
        self.graph_num = graph_num
        # probability of teleporting back to the start node at each step
        self.restart_ratio = restart_ratio
        # if True, restrict walks to training nodes (val/test removed first)
        self.inductive = inductive
        # if True, the two views of each center node are identical copies
        self.aligned = aligned

    def __call__(self, graph):
        """Sample two lists of random-walk subgraph views from *graph*."""
        graph = copy.deepcopy(graph) # modified on the copy
        assert self.walk_steps > 1
        # remove isolated nodes (or we can construct edges for these nodes)
        if self.inductive:
            train_node_idx = torch.where(graph.train_mask == True)[0]
            graph.edge_index, _ = subgraph(train_node_idx, graph.edge_index) # remove val and test nodes (val and test are considered as isolated nodes)
            edge_index, _, mask = remove_isolated_nodes(graph.edge_index, num_nodes=graph.x.shape[0]) # remove all isolated nodes and re-index nodes
            graph.x = graph.x[mask]
        # make the graph undirected, then give any remaining isolated node a
        # self-loop so the sparse sampler always has a neighbor to draw
        edge_index = to_undirected(graph.edge_index)
        edge_index = add_remaining_selfloop_for_isolated_nodes(edge_index, graph.x.shape[0])
        graph.edge_index = edge_index

        node_num = graph.x.shape[0]
        graph_num = min(self.graph_num, node_num)
        # random start (center) nodes, one subgraph per start node
        start_nodes = torch.randperm(node_num)[:graph_num]
        edge_index = graph.edge_index

        # CSR-style adjacency used for fast neighbor sampling
        value = torch.arange(edge_index.size(1))
        self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                       value=value,
                       sparse_sizes=(node_num, node_num)).t()

        view1_list = []
        view2_list = []

        # aligned mode walks once and copies; otherwise walk twice
        views_cnt = 1 if self.aligned else 2
        for view_idx in range(views_cnt):
            current_nodes = start_nodes.clone()
            # history[t] holds the node of every walk at step t;
            # signs[t] marks walks that restarted at step t (row 0 all True)
            history = start_nodes.clone().unsqueeze(0)
            signs = torch.ones(graph_num, dtype=torch.bool).unsqueeze(0)
            for i in range(self.walk_steps):
                seed = torch.rand([graph_num])
                # sample one neighbor per current node
                nei = self.adj_t.sample(1, current_nodes).squeeze()
                sign = seed < self.restart_ratio
                # walks that restart jump back to their start node
                nei[sign] = start_nodes[sign]
                history = torch.cat((history, nei.unsqueeze(0)), dim=0)
                signs = torch.cat((signs, sign.unsqueeze(0)), dim=0)
                current_nodes = nei
            history = history.T
            signs = signs.T

            for i in range(graph_num):
                path = history[i]
                sign = signs[i]
                node_idx = path.unique()
                sources = path[:-1].numpy().tolist()
                targets = path[1:].numpy().tolist()
                sub_edges = torch.IntTensor([sources, targets]).type_as(graph.edge_index)
                # drop the fake transitions created by restart jumps
                sub_edges = sub_edges.T[~sign[1:]].T
                # undirectional
                if sub_edges.shape[1] != 0:
                    sub_edges = to_undirected(sub_edges)
                view = self.adjust_idx(sub_edges, node_idx, graph, path[0].item())

                if self.aligned:
                    view1_list.append(view)
                    view2_list.append(copy.deepcopy(view))
                else:
                    if view_idx == 0:
                        view1_list.append(view)
                    else:
                        view2_list.append(view)
        return (view1_list, view2_list)

    def adjust_idx(self, edge_index, node_idx, full_g, center_idx):
        '''Re-index the nodes and edge index of a sampled subgraph.

        In the subgraphs most of the original nodes are dropped, so the node
        indices in ``edge_index`` are remapped to the compact 0..k-1 range
        matching the rows of the sliced feature matrix.
        '''
        node_idx_map = {j : i for i, j in enumerate(node_idx.numpy().tolist())}
        sources_idx = list(map(node_idx_map.get, edge_index[0].numpy().tolist()))
        target_idx = list(map(node_idx_map.get, edge_index[1].numpy().tolist()))

        edge_index = torch.IntTensor([sources_idx, target_idx]).type_as(full_g.edge_index)
        # keep the original indices and the (remapped) center node on the view
        x_view = Data(edge_index=edge_index, x=full_g.x[node_idx], center=node_idx_map[center_idx], original_idx=node_idx)
        return x_view
115 |
116 |
class DropEdgeAndFeature(object):
    """Random edge dropping plus random feature-column masking.

    Applied as an extra augmentation on top of the sampled subgraph views.
    """

    def __init__(self, d_fea=0.2, d_edge=0.3):
        self.drop_feature_ratio = d_fea
        self.drop_edge_ratio = d_edge

    def set_drop_ratio(self, d_fea, d_edge):
        """Update both drop ratios at once."""
        self.drop_feature_ratio = d_fea
        self.drop_edge_ratio = d_edge

    def set_drop_edge_ratio(self, d_edge):
        """Update only the edge-drop ratio."""
        self.drop_edge_ratio = d_edge

    def set_drop_fea_ratio(self, d_fea):
        """Update only the feature-drop ratio."""
        self.drop_feature_ratio = d_fea

    def drop_fea_edge(self, g):
        """Return a deep copy of *g* with edges dropped and features masked."""
        augmented = copy.deepcopy(g)
        # keep the original RNG call order: edges first, then features
        pruned_edges = self._drop_edge(augmented)
        masked_features = self._drop_feature(augmented)
        augmented.x = masked_features
        augmented.edge_index = pruned_edges
        return augmented

    def _drop_feature(self, g):
        """Zero whole feature columns, each chosen i.i.d. with prob drop_feature_ratio."""
        column_mask = torch.empty(
            (g.x.size(1), ),
            dtype=torch.float32,
            device=g.x.device).uniform_(0, 1) < self.drop_feature_ratio
        masked = g.x.clone()
        masked[:, column_mask] = 0
        return masked

    def _drop_edge(self, g):
        """Drop edges i.i.d. with prob drop_edge_ratio via dropout_adj."""
        pruned, _ = dropout_adj(g.edge_index, p=self.drop_edge_ratio)
        return pruned
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning
 2 | This repo is the PyTorch implementation of
[RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning](https://www.ijcai.org/proceedings/2022/0527.pdf)[\[PPT\]](https://docs.google.com/presentation/d/1RnpXrjyaojJ3Tiqu92NImtYI4cD2XVmr/edit?usp=sharing&ouid=103567337311720006952&rtpof=true&sd=true)[\[appendix\]](https://arxiv.org/abs/2204.13846)
3 | Yun Zhu\*, Jianhao Guo\*, Fei Wu, Siliang Tang†
4 | In IJCAI 2022
5 |
6 | ## Overview
7 | This is the first work dedicated to solving non-aligned node-node graph contrastive learning problems. To tackle the non-aligned problem, we introduce a novel graph-based optimal transport algorithm, g-EMD, which does not require explicit node-node correspondence and can fully utilize graph topological and attributive information for non-aligned node-node contrasting. Moreover, to compensate for the possible information loss caused by non-aligned sub-sampling, we propose a nontrivial unsupervised graph adversarial training to improve the diversity of sub-sampling and strengthen the robustness of the model. The overview of our method is depicted as:
8 | 
9 |
10 | ## Files
11 | ```
12 | .
13 | ├── dataset_apis # Code process datasets.
14 | │ ├── topology_dist # Storing the distance of the shortest path (SPD) between vi and vj.
15 | │ ├── citeseer.py # processing for citeseer dataset.
16 | │ ├── cora.py # processing for cora dataset.
17 | │ ├── dblp.py # processing for dblp dataset.
18 | │ ├── pubmed.py # processing for pubmed dataset.
19 | │ ├── cornell.py # processing for cornell dataset.
20 | │ ├── wisconsin.py # processing for wisconsin dataset.
21 | │ ├── texas.py # processing for texas dataset.
22 | │ └── ... # More datasets will be added.
23 | │
24 | ├── adversarial.py # Code for unsupervised adversarial training.
25 | ├── augmentation.py # Code for augmentation.
26 | ├── config.yaml # Configurations for our method.
27 | ├── eval_utils.py # The toolkits for evaluation.
28 | ├── eval.py # Code for evaluation.
29 | ├── global_var.py # Code for storing global variable.
30 | ├── model.py # Code for building up model.
31 | ├── train.py # Training process.
32 | ├── test_runs.py # Reproduce the results reported in our paper
33 | └── ...
34 | ```
35 |
36 | ## Setup
 37 | We recommend setting up a Python virtual environment with the required dependencies as follows:
38 | ```
39 | conda create -n rosa python==3.9
40 | conda activate rosa
41 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
42 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.8.0+cu111.html
43 | ```
44 | ## Usage
45 | **Command for training model on Cora dataset**
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=Cora --config=config.yaml --ad=True --rectified=True
48 | ```
49 | For efficient usage, you can run as below:
50 | ```bash
51 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=Cora --config=config.yaml --ad=False --rectified=False
52 | ```
53 | Now supported datasets include Cora, Citeseer, Pubmed, DBLP, Cornell, Wisconsin, Texas. More datasets are coming soon!
54 |
55 | **Command for testing model on Cora dataset**
 56 | After training, the best checkpoint will be stored in the `checkpoints/` directory. Then you can test the checkpoint through this command:
57 | ```bash
58 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset=Cora --config=config.yaml
59 | ```
60 |
61 | **Command for reproducing the results in the paper**
 62 | Our results are based on 20 independent experiments, so we provide all 20 pre-trained ckpts of RoSA for Cora and Citeseer. For other, larger datasets, you can contact us by email if you need them.
63 | The usage is quite simple, you just run the command below. The ckpts will be downloaded automatically and the pre-trained models will be tested.
64 | ```bash
65 | CUDA_VISIBLE_DEVICES=0 python test_runs.py --dataset=Cora
66 | ```
67 | > After 20 runs, test acc: 84.56 ± 0.67
68 |
69 | ```bash
70 | CUDA_VISIBLE_DEVICES=0 python test_runs.py --dataset=Citeseer
71 | ```
72 | > After 20 runs, test acc: 73.5 ± 0.45
73 |
 74 | ### Illustration of arguments
75 |
76 | ```
 77 | --dataset: default Cora, [Cora, Citeseer, Pubmed, DBLP, Cornell, Wisconsin, Texas] can also be chosen
 78 | --rectified: default False, use rectified cost matrix instead of vanilla cost matrix (if True)
79 | --ad: default False, use unsupervised adversarial training (if True)
80 | --aligned: default False, use aligned views (if True)
81 | ```
82 |
83 | ### More experiments
 84 | We conduct experiments on five other commonly used datasets with RoSA; the results are shown in Table 1. RoSA reaches SOTA on these datasets, which proves the effectiveness of our method.
85 | | Method | Wiki-CS | Amazon-Computers | Amazon-Photo | Coauthor-CS | Coauthor-Physics |
86 | | :----:| :----------: | :----------: | :----------: | :----------: | :----------: |
87 | | DGI | 75.35 ± 0.14 | 83.95 ± 0.47 | 91.61 ± 0.22 | 92.15 ± 0.63 | 94.51 ± 0.52 |
88 | | GMI | 74.85 ± 0.08 | 82.21 ± 0.31 | 90.68 ± 0.17 | OOM | OOM |
89 | | MVGRL | 77.52 ± 0.08 | 87.52 ± 0.11 | 91.74 ± 0.07 | 92.11 ± 0.12 | 95.33 ± 0.03 |
90 | | GRACE | 78.19 ± 0.01 | 87.46 ± 0.22 | 92.15 ± 0.24 | 92.93 ± 0.01 | 95.26 ± 0.02 |
91 | | GCA | 78.35 ± 0.05 | 88.94 ± 0.15 | 92.53 ± 0.16 | 93.10 ± 0.01 | 95.73 ± 0.03 |
92 | | BGRL | 79.36 ± 0.53 | 89.68 ± 0.31 | 92.87 ± 0.27 | 93.21 ± 0.18 | 95.56 ± 0.12 |
93 | | RoSA | **80.11 ± 0.10** | **90.12 ± 0.26** | **93.67 ± 0.07** | **93.23 ± 0.13** | **95.76 ± 0.09** |
94 |
95 |
 96 | | | Hidden size | Batch size | Learning rate | Walk length | Epochs | tau | p_{e,1} | p_{e,2} | p_{f,1} | p_{f,2} |
97 | | :---- | :---------- | :--------- | :------------ | :----------- | :----- | :-- | :-------- | :-------- | :-------- | :-------- |
98 | | Wiki-CS | 256 | 256 | 1e-3 | 10 | 500 | 0.5 | 0.2 | 0.3 | 0.2 | 0.3 |
99 | | Amazon-Computers | 128 | 256 | 1e-3 | 10 | 500 | 0.2 | 0.4 | 0.5 | 0.1 | 0.2 |
100 | | Amazon-Photo | 256 | 256 | 1e-3 | 10 | 500 | 0.3 | 0.2 | 0.3 | 0.2 | 0.3 |
101 | | Coauthor-CS | 256 | 128 | 1e-3 | 10 | 100 | 0.1 | 0.2 | 0.3 | 0.2 | 0.3 |
102 | | Coauthor-Physics | 128 | 256 | 1e-3 | 10 | 100 | 0.5 | 0.2 | 0.3 | 0.2 | 0.3 |
103 |
 104 | In addition, we use `prelu` as the activation function and the `adamw` optimizer with `5e-4` weight decay for all experiments. The restart ratio of random walking is 0.5.
105 |
106 | ## Citation
 107 | If you use this code for your research, please cite our paper.
108 | ```
109 | @inproceedings{ijcai2022-527,
110 | title = {RoSA: A Robust Self-Aligned Framework for Node-Node Graph Contrastive Learning},
111 | author = {Zhu, Yun and Guo, Jianhao and Wu, Fei and Tang, Siliang},
112 | booktitle = {Proceedings of the Thirty-First International Joint Conference on
113 | Artificial Intelligence, {IJCAI-22}},
114 | publisher = {International Joint Conferences on Artificial Intelligence Organization},
115 | editor = {Lud De Raedt},
116 | pages = {3795--3801},
117 | year = {2022},
118 | month = {7},
119 | note = {Main Track},
120 | doi = {10.24963/ijcai.2022/527},
121 | url = {https://doi.org/10.24963/ijcai.2022/527},
122 | }
123 | ```
124 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Batch
2 | import numpy as np
3 | import torch
4 | import os
5 | import argparse
6 | import importlib
7 | import random
8 | from torch_geometric.data import DataLoader
9 | from torch_geometric.utils import contains_isolated_nodes, to_networkx
10 | from augmentation import RWR, DropEdgeAndFeature
11 | from model import RoSA
12 | from adversarial import ad_training
13 | from eval_utils import eval
14 | import global_var
15 | import networkx as nx
16 | import yaml
17 | from yaml import SafeLoader
18 | from tqdm import tqdm
19 |
20 |
21 | device = "cuda" if torch.cuda.is_available() else "cpu"
22 |
def train(args):
    """Self-supervised pre-training loop for RoSA.

    Builds the dataset with the RWR augmentation, trains the model with the
    contrastive objective (optionally with adversarial training), checkpoints
    on the best training loss, and returns the best model.
    """
    # RWR returns two lists of subgraph views per access of the dataset.
    trans = RWR(walk_step=args.walk_step, graph_num=args.graph_num, restart_ratio=args.restart_ratio, aligned=args.aligned, inductive=args.inductive)
    # Resolve the dataset module by name, e.g. dataset_apis.cora.load_trainset.
    load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_trainset')
    dataset = load_dataset(trans)

    # Single-graph datasets: one batch carries the whole (augmented) graph.
    train_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Model
    model = RoSA(
        # model
        encoder=args.encoder,
        # shape
        input_dim=args.input_dim,
        # model configuration
        layer_num=args.layer_num,
        hidden=args.hidden,
        proj_shape=(args.proj_middim, args.proj_outdim),
        # loss
        is_rectified=args.rectified,
        T = args.tau,
        topo_t=args.topo_t
    ).to(device)

    # optimizer (selected by config; sgd uses 0.9 momentum)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # save checkpoints
    path = f'./checkpoints/{args.dataset}/'
    if not os.path.exists(path):
        os.makedirs(path)

    # Early stopping on the (unsupervised) training loss.
    patience = args.patience # early stopping
    stop_cnt = 0
    best = 9999  # best (lowest) loss seen so far
    best_t = 0   # 1-based epoch of the best loss

    loop = tqdm(range(args.epochs))
    for epoch in loop:
        model.train()
        for idx, graphs in enumerate(train_loader):
            model.train()
            optimizer.zero_grad()
            # The augmentation makes each loader item a pair of lists of
            # batches, so flatten them into two flat lists of Data views.
            view1_list = []
            view2_list = []
            assert len(graphs[0]) == len(graphs[1])
            all_graphs_num = len(graphs[0])
            for i in range(all_graphs_num):
                view1_list.extend(graphs[0][i].to_data_list())
                view2_list.extend(graphs[1][i].to_data_list())

            # shuffle both view lists with the SAME permutation so that
            # positive pairs stay at matching positions
            shuf_idx = np.random.permutation(all_graphs_num)
            view1_list_shuff = [view1_list[i] for i in shuf_idx]
            view2_list_shuff = [view2_list[i] for i in shuf_idx]

            views1 = Batch().from_data_list(view1_list_shuff).to(device)
            views2 = Batch().from_data_list(view2_list_shuff).to(device)
            # additional augmentation: drop edges/features per view
            drop = DropEdgeAndFeature(d_fea=args.drop_feature_rate_1, d_edge=args.drop_edge_rate_1)
            views1 = drop.drop_fea_edge(views1)
            drop.set_drop_ratio(d_fea=args.drop_feature_rate_2, d_edge=args.drop_edge_rate_2)
            views2 = drop.drop_fea_edge(views2)

            if args.ad:
                # Closure for adversarial training: perturbs view1 features
                # in place, then recomputes the contrastive loss.
                def node_attack(perturb):
                    views1.x += perturb
                    return model(views1, views2, views1.batch, views2.batch)

                loss = ad_training(model, node_attack, views1.x.shape, args, device)
            else:
                loss = model(views1, views2, views1.batch, views2.batch)
            loop.set_postfix(loss = loss.item())
            loss.backward()
            optimizer.step()

        # checkpoint whenever the last batch's loss improves
        if loss < best:
            stop_cnt = 0
            best = loss
            best_t = epoch + 1
            torch.save(model.state_dict(), os.path.join(path, 'best.pt'))
        else:
            stop_cnt += 1

        if stop_cnt >= patience:
            print("early stopping")
            break

    # Reload the best checkpoint only when early stopping was possible.
    if patience < args.epochs:
        print('Loading {}th epoch'.format(best_t))
        model.load_state_dict(torch.load(os.path.join(path, 'best.pt')))

    return model
121 |
def register_topology(args):
    """Precompute (or load) all-pairs shortest-path distances and register them globally.

    Builds an (N+1) x (N+1) table where row/column 0 is reserved for padding
    nodes (distance 1000) and unreachable pairs get MAX_HOP. The table is
    cached under dataset_apis/topology_dist/ and published through `global_var`
    so the rectified loss can rescale its cost matrix by topological distance.
    """
    MAX_HOP = 100  # sentinel distance for disconnected node pairs
    load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_eval_trainset')
    data = load_dataset()[0]
    topo_file = f"./dataset_apis/topology_dist/{args.dataset.lower()}_padding.pt"
    if os.path.isfile(topo_file):
        topology_dist = torch.load(topo_file)
    else:
        node_num = data.x.shape[0]
        G = to_networkx(data)
        # all-pairs shortest path lengths: {src: {dst: hops, ...}, ...}
        generator = dict(nx.shortest_path_length(G))
        # node indices are shifted by 1 so index 0 can represent padding nodes
        topology_dist = torch.zeros((node_num+1, node_num+1))

        topology_dist[0, :] = 1000  # used for padding nodes
        topology_dist[:, 0] = 1000

        for i in tqdm(range(1, node_num+1)):
            # hoist the per-source lookup out of the inner loop
            dist_from_i = generator[i-1]
            for j in range(1, node_num+1):
                # unreachable pairs fall back to MAX_HOP
                # (the original also built a boolean `mask` of these pairs,
                # but it was never used or saved — dropped here)
                topology_dist[i][j] = dist_from_i.get(j-1, MAX_HOP)
        torch.save(topology_dist, topo_file)
    global_var._init()
    global_var.set_value("topology_dist", topology_dist)
151 |
def str2bool(v):
    """Parse a CLI flag value into a bool, accepting common truthy/falsy spellings."""
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ('yes', 'y', 'true', 't', '1'):
        return True
    if normalized in ('no', 'n', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def get_parser():
    """Assemble the command-line parser used by the training script."""
    arg_specs = [
        (('--config',), dict(type=str, default='config.yaml')),
        # dataset
        (('--dataset',), dict(help='Cora, Citeseer , Pubmed, etc. Default=Cora', default='Cora')),
        # augmentation
        (('--aligned',), dict(type=str2bool, help='aligned views or not', default=False)),
        # adversarial training
        (('--ad',), dict(type=str2bool, help='combine with adversarial training', default=False)),
        (('--step-size',), dict(type=float, default=1e-3)),
        (('--m',), dict(type=int, default=3)),
        # loss
        (('--rectified',), dict(type=str2bool, help='use rectified cost matrix', default=False)),
        (('--topo_t',), dict(type=int, help='temperature for sigmoid', default=2)),
    ]
    parser = argparse.ArgumentParser(description='Description: Script to run our model.')
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)

    return parser
177 |
# Script entry point: parse CLI args, merge the per-dataset YAML config,
# seed the RNGs, optionally precompute topology distances, then train + evaluate.
if __name__ == "__main__":
    parser = get_parser()
    try:
        args = parser.parse_args()
    except:  # NOTE(review): bare except silently swallows all parse errors (incl. --help's SystemExit)
        exit()
    # per-dataset hyperparameters live under the dataset's key in the YAML file
    config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset]

    # combine args and config (YAML values override CLI defaults)
    for k, v in config.items():
        args.__setattr__(k, v)

    # repeated experiment: fix seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(12345)

    # cache pairwise shortest-path distances needed by the rectified cost matrix
    if args.rectified:
        register_topology(args)

    model = train(args)
    # `device` is presumably defined near the top of this file — TODO confirm
    test_acc = eval(args, model, device)
200 |
201 |
202 |
203 |
204 |
205 |
--------------------------------------------------------------------------------
/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import functools
3 | from sklearn.metrics import accuracy_score, f1_score
4 | from sklearn.linear_model import LogisticRegression
5 | from sklearn.model_selection import train_test_split, GridSearchCV, ShuffleSplit
6 | from sklearn.multiclass import OneVsRestClassifier
7 | from sklearn.preprocessing import normalize, OneHotEncoder
8 | import importlib
9 | import torch
10 | from torch_sparse import SparseTensor
11 | from model import LogReg
12 |
13 |
14 | # borrow from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py]
def fit_logistic_regression(X, y, data_random_seed=1, repeat=1):
    """Linear evaluation: fit one-vs-rest logistic regression on embeddings.

    Splits (X, y) into 20% train / 80% test `repeat` times, grid-searches the
    regularization strength C, and returns the list of test accuracies.
    Borrowed from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py].
    """
    # transform targets to one-hot vectors
    # NOTE(review): `sparse=False` was renamed to `sparse_output` in scikit-learn 1.2 —
    # confirm the pinned sklearn version before upgrading.
    one_hot_encoder = OneHotEncoder(categories='auto', sparse=False)

    # BUG FIX: `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
    # exact dtype it aliased.
    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool)

    # normalize x
    X = normalize(X, norm='l2')

    # set random state: this ensures the dataset is split exactly the same way
    # throughout training
    rng = np.random.RandomState(data_random_seed)

    accuracies = []
    for _ in range(repeat):
        # different random split after each repeat
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=rng)

        # grid search with one-vs-rest classifiers over C = 2^-10 .. 2^10
        logreg = LogisticRegression(solver='liblinear')
        c = 2.0 ** np.arange(-10, 11)
        cv = ShuffleSplit(n_splits=5, test_size=0.5)
        clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), param_grid=dict(estimator__C=c),
                           n_jobs=5, cv=cv, verbose=0)
        clf.fit(X_train, y_train)

        # predict class indices, then re-encode one-hot to compare against y
        y_pred = clf.predict_proba(X_test)
        y_pred = np.argmax(y_pred, axis=1)
        y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool)

        test_acc = accuracy_score(y_test, y_pred)
        accuracies.append(test_acc)
    return accuracies
48 |
49 |
50 | # borrow from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py]
def fit_logistic_regression_preset_splits(X, y, train_masks, val_masks, test_mask):
    """Linear evaluation on datasets with fixed train/val/test splits (e.g. WikiCS).

    For each split column, sweeps the regularization strength C, selects by
    validation accuracy, and collects the corresponding test accuracy.
    Borrowed from BGRL [https://github.com/nerdslab/bgrl/blob/main/bgrl/logistic_regression_eval.py].
    """
    # transform targets to one-hot vectors
    one_hot_encoder = OneHotEncoder(categories='auto', sparse=False)
    # BUG FIX: `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
    # exact dtype it aliased.
    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(bool)

    # normalize x
    X = normalize(X, norm='l2')

    accuracies = []
    for split_id in range(train_masks.shape[1]):
        # get train/val masks for this split (the test mask is shared)
        train_mask, val_mask = train_masks[:, split_id], val_masks[:, split_id]

        X_train, y_train = X[train_mask], y[train_mask]
        X_val, y_val = X[val_mask], y[val_mask]
        X_test, y_test = X[test_mask], y[test_mask]

        # model selection by validation accuracy over C = 2^-10 .. 2^10
        best_test_acc, best_acc = 0, 0
        for c in 2.0 ** np.arange(-10, 11):
            clf = OneVsRestClassifier(LogisticRegression(solver='liblinear', C=c))
            clf.fit(X_train, y_train)

            y_pred = clf.predict_proba(X_val)
            y_pred = np.argmax(y_pred, axis=1)
            y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool)
            val_acc = accuracy_score(y_val, y_pred)
            if val_acc > best_acc:
                best_acc = val_acc
                # re-predict on the test split only when validation improves
                y_pred = clf.predict_proba(X_test)
                y_pred = np.argmax(y_pred, axis=1)
                y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(bool)
                best_test_acc = accuracy_score(y_test, y_pred)

        accuracies.append(best_test_acc)
        print(np.mean(accuracies))
    return accuracies
89 |
90 |
def repeat(n_times):
    """Decorator factory: run the wrapped evaluation `n_times` and report mean/std.

    The wrapped function must return a dict of metric -> value; the wrapper
    aggregates each metric over the runs, prints a summary line, and returns
    {metric: {'mean': ..., 'std': ...}}.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            runs = [f(*args, **kwargs) for _ in range(n_times)]
            statistics = {
                key: {'mean': np.mean([run[key] for run in runs]),
                      'std': np.std([run[key] for run in runs])}
                for key in runs[0]
            }
            print_statistics(statistics, f.__name__)
            return statistics
        return wrapper
    return decorator
106 |
107 |
def prob_to_one_hot(y_pred):
    """Convert an (N, C) probability/score matrix to a boolean one-hot matrix.

    The argmax column of each row is set to True; ties resolve to the first
    maximal column (numpy argmax semantics, same as the original loop).
    """
    # BUG FIX: `np.bool` was removed in NumPy 1.24; use the builtin `bool` dtype.
    ret = np.zeros(y_pred.shape, dtype=bool)
    indices = np.argmax(y_pred, axis=1)
    # vectorized fancy indexing replaces the original per-row Python loop
    ret[np.arange(y_pred.shape[0]), indices] = True
    return ret
114 |
115 |
def print_statistics(statistics, function_name):
    """Pretty-print per-metric mean/std for `function_name` on a single line."""
    parts = [
        f"{key}={vals['mean']:.4f}+-{vals['std']:.4f}"
        for key, vals in statistics.items()
    ]
    print(f'(E) | {function_name}:', ', '.join(parts))
126 |
127 |
128 | # borrow from GRACE [https://github.com/CRIPAC-DIG/GRACE/blob/master/eval.py]
129 | # @repeat(5)
def label_classification(embeddings, y, ratio=0.1):
    """Linear evaluation with a `ratio` train split.

    Converts torch tensors to numpy, fits a grid-searched one-vs-rest logistic
    regression, and returns the test accuracy as a float.
    Borrowed from GRACE [https://github.com/CRIPAC-DIG/GRACE/blob/master/eval.py].
    """
    X = embeddings.detach().cpu().numpy()
    Y = y.detach().cpu().numpy()
    Y = Y.reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y)
    # BUG FIX: `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
    # exact dtype it aliased.
    Y = onehot_encoder.transform(Y).toarray().astype(bool)

    X = normalize(X, norm='l2')

    X_train, X_test, y_train, y_test = train_test_split(X, Y,
                                                        test_size=1 - ratio)

    # grid search over C = 2^-10 .. 2^9 with one-vs-rest classifiers
    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=8, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_pred = clf.predict_proba(X_test)
    y_pred = prob_to_one_hot(y_pred)

    acc = accuracy_score(y_test, y_pred)
    return acc
156 |
157 |
def heter_eval(model, z, y, train_mask, val_mask, test_mask, device):
    """Linear-probe evaluation for heterophilous datasets (Cornell/Wisconsin/Texas).

    Trains a logistic-regression head on the frozen embeddings `z` for 100
    epochs and returns the test accuracy at the epoch with the best validation
    accuracy (model selection on the validation split).
    """
    model.eval()
    num_classes = y.max().item()+1

    xent = torch.nn.CrossEntropyLoss()
    log = LogReg(model.hidden, num_classes).to(device)
    opt = torch.optim.Adam(log.parameters(), lr=1e-2, weight_decay=0.0)

    train_embs = z[train_mask]
    val_embs = z[val_mask]
    test_embs = z[test_mask]

    # BUG FIX: these accumulators were hard-coded `.cuda()`, which crashes on
    # CPU-only machines even though the caller already supplies `device`.
    best_acc_from_val = torch.zeros(1, device=device)
    best_val = torch.zeros(1, device=device)

    log.train()
    for i in range(100):
        opt.zero_grad()

        logits = log(train_embs)
        loss = xent(logits, y[train_mask].long())

        # evaluate on val/test without tracking gradients
        # (the original also recomputed train accuracy each epoch but never
        # used it — dropped as dead work)
        with torch.no_grad():
            lv = log(val_embs)
            lt = log(test_embs)
            lv_preds = torch.argmax(lv, dim=1)
            lt_preds = torch.argmax(lt, dim=1)
            val_acc = torch.sum(lv_preds == y[val_mask]).float() / val_mask.sum()
            test_acc = torch.sum(lt_preds == y[test_mask]).float() / test_mask.sum()

        # keep the test accuracy observed at the best validation epoch
        if val_acc > best_val:
            best_acc_from_val = test_acc
            best_val = val_acc

        loss.backward()
        opt.step()
    return best_acc_from_val.cpu().item()
201 |
202 |
def eval(args, model, device):
    """Linear-probe evaluation of the trained encoder on `args.dataset`.

    Embeds the full evaluation graph with the frozen encoder and dispatches to
    the dataset-appropriate evaluation protocol.
    NOTE(review): some branches return a float while `fit_logistic_regression`
    returns a list of accuracies — confirm what callers expect.
    """
    model.eval()
    # dataset loaders live in dataset_apis/<name>.py and expose load_eval_trainset
    load_dataset = getattr(importlib.import_module(f"dataset_apis.{args.dataset.lower()}"), 'load_eval_trainset')
    dataset = load_dataset()
    data = dataset[0].to(device)
    z = model.embed(data)  # detached node embeddings
    if args.dataset.lower() in ['cora', 'citeseer', 'pubmed', 'dblp']:
        # random 10%/90% train/test split with grid-searched logistic regression
        acc = label_classification(z, data.y, ratio=0.1)
    elif args.dataset.lower() in ['amazon_photos', 'amazon_computers', 'coauthor_cs', 'coauthor_physics']:
        # NOTE(review): z is a torch tensor here but the helper appears to expect
        # a numpy array (it calls sklearn's normalize directly) — verify
        acc = fit_logistic_regression(z, data.y)
    elif args.dataset.lower() == 'wikics':
        acc = fit_logistic_regression_preset_splits(z, data.y, data.train_mask, data.val_mask, data.test_mask)
    elif args.dataset.lower() in ["cornell", "wisconsin", "texas"]:
        # heterophilous WebKB datasets: evaluate once per preset split, report the mean
        acc_list = []
        for run in range(data.train_mask.shape[1]): # These datasets contains 10 different splits. Note: In our paper, we run 20 independent experiments. Different experiments use different splits.
            train_mask = data.train_mask[:, run%10]
            val_mask = data.val_mask[:, run%10]
            test_mask = data.test_mask[:, run%10]
            acc = heter_eval(model, z, data.y, train_mask, val_mask, test_mask, device)
            acc_list.append(acc)
        acc = np.mean(acc_list)
        print("Test acc: {:.2f}±{:.2f}".format(np.mean(acc_list)*100, np.std(acc_list)*100))
    else:
        raise NotImplementedError(f"{args.dataset} is not supported!")
    return acc
228 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.data import Data
4 | from torch_geometric.utils import to_dense_batch
5 | from torch_geometric.nn import global_mean_pool
6 | from torch_geometric.nn import GCNConv
7 | from torch_geometric.nn.inits import glorot
8 | from torch.nn import ModuleList
9 | import global_var
10 |
11 |
class RoSA(torch.nn.Module):
    """Robust self-aligned graph contrastive model.

    Two augmented views are encoded with a shared GNN, projected with an MLP
    head, and contrasted with a batched (semi-)EMD loss whose transport plan is
    approximated by Sinkhorn iteration. With `is_rectified` on, the transport
    cost matrix is re-scaled by pairwise topological distances that
    `register_topology` publishes through `global_var`.
    """

    def __init__(
        self,
        encoder='gcn',
        input_dim=1433,

        # model configuration
        layer_num=2,  # layers of encoder
        hidden=128,  # encoder hidden size
        proj_shape=(128, 128),  # hidden size for predictor

        T = 0.4,  # temperature
        is_rectified=True,  # use rectified cost matrix
        topo_t = 2  # temperature for calculating re-scale ratios with topology dist
    ):
        super(RoSA, self).__init__()

        # load components
        if encoder == 'gcn':
            Encoder = GCN
        elif encoder == 'sage-gcn':
            Encoder = GraphSAGE_GCN
        else:
            raise NotImplementedError(f'{encoder} is not implemented!')
        self.encoder = Encoder(input_dim=input_dim, layer_num=layer_num, hidden=hidden)
        self.T = T
        self.rectified = is_rectified
        self.topo_t = topo_t
        self.hidden = hidden

        # infer the projector's input width by pushing a tiny fake graph through
        # the encoder, keeping the MLP in sync with whichever encoder was chosen
        fake_x = torch.rand((2, input_dim))
        fake_edge_index = torch.LongTensor([[0], [0]])
        fake_g = Data(x=fake_x, edge_index=fake_edge_index)
        with torch.no_grad():
            rep = self.encoder(fake_g)
        hid = rep.shape[-1]
        self.projector = Projector([hid, *proj_shape])

    def gen_rep(self, data):
        """Return (encoder output h, projected representation z)."""
        h = self.encoder(data)
        z = self.projector(h)
        return h, z

    def sim(self, reps1, reps2):
        """Cosine-similarity matrix for 2-D (N,D) or batched 3-D (B,N,D) inputs."""
        reps1_unit = F.normalize(reps1, dim=-1)
        reps2_unit = F.normalize(reps2, dim=-1)
        if len(reps1.shape) == 2:
            sim_mat = torch.einsum("ik,jk->ij", [reps1_unit, reps2_unit])
        elif len(reps1.shape) == 3:
            sim_mat = torch.einsum('bik,bjk->bij', [reps1_unit, reps2_unit])
        else:
            # BUG FIX: the original only printed a message here and then
            # returned an unbound `sim_mat` (NameError); fail loudly instead.
            raise ValueError(f"{len(reps1.shape)} dimension tensor is not supported for this function!")
        return sim_mat

    def topology_dist(self, node_idx1, node_idx2):
        """Gather pairwise shortest-path distances for two batched index tensors.

        Selects rows/cols of the precomputed padded all-pairs table (registered
        by `register_topology`) per graph in the batch.
        """
        # BUG FIX: follow the index tensors' device instead of hard-coding
        # `.cuda()`, so CPU-only runs work too.
        full_topology_dist = global_var.get_value('topology_dist').to(node_idx1.device)

        batch_size = node_idx1.shape[0]
        batch_subpology_dist = [full_topology_dist.index_select(dim=0, index=node_idx1[i]).index_select(dim=1, index=node_idx2[i]) for i in range(batch_size)]
        batch_subpology_dist = torch.stack(batch_subpology_dist)
        return batch_subpology_dist

    def _batched_semi_emd_loss(self, out1, avg_out1, out2, avg_out2, lamb=20, rescale_ratio=None):
        """Batched EMD-style similarity between two dense node-set views.

        out1/out2: (B,N,D)/(B,M,D) dense node reps; avg_out*: (B,1,D) graph means.
        Returns a (B,1) similarity score S = 2 - 2*EMD.
        """
        assert out1.shape[0] == out2.shape[0] and avg_out1.shape == avg_out2.shape

        cost_matrix = 1-self.sim(out1, out2)
        if rescale_ratio is not None:
            # rectified cost: up-weight topologically distant node pairs
            cost_matrix = cost_matrix * rescale_ratio

        # Sinkhorn iteration: approximate the optimal transport plan without
        # tracking gradients (only the cost matrix carries gradient)
        iter_times = 5
        with torch.no_grad():
            # marginals r, c: affinity of each node to the other view's mean
            r = torch.bmm(out1, avg_out2.transpose(1,2))
            r[r<=0] = 1e-8
            r = r / r.sum(dim=1, keepdim=True)
            c = torch.bmm(out2, avg_out1.transpose(1,2))
            c[c<=0] = 1e-8
            c = c / c.sum(dim=1, keepdim=True)
            P = torch.exp(-1*lamb*cost_matrix)
            u = (torch.ones_like(c)/c.shape[1])
            for i in range(iter_times):
                v = torch.div(r, torch.bmm(P, u))
                u = torch.div(c, torch.bmm(P.transpose(1,2), v))
            u = u.squeeze(dim=-1)
            v = v.squeeze(dim=-1)
            transport_matrix = torch.bmm(torch.bmm(matrix_diag(v), P), matrix_diag(u))
        assert cost_matrix.shape == transport_matrix.shape

        emd = torch.mul(transport_matrix, cost_matrix).sum(dim=1).sum(dim=1, keepdim=True)
        S = 2-2*emd
        return S

    def batched_semi_emd_loss(self, reps1, reps2, batch1, batch2, original_idx1=None, original_idx2=None):
        """InfoNCE-style contrastive loss over per-graph EMD similarities.

        Positives pair the two views of the same graph; negatives are built by
        cyclically permuting the batch (graph i vs every other graph, both views).
        """
        batch_size1 = batch1.max().cpu().item()
        batch_size2 = batch2.max().cpu().item()
        assert batch_size1 == batch_size2
        batch_size = batch_size1+1
        # avg nodes rep: one mean embedding per graph
        y_online1_pooling = global_mean_pool(reps1, batch1)
        y_online2_pooling = global_mean_pool(reps2, batch2)
        avg_out1 = y_online1_pooling[:, None, :] # (B,1,D)
        avg_out2 = y_online2_pooling[:, None, :] # (B,1,D)

        # node reps from sparse to dense: (B,N,D) / (B,M,D)
        out1, _ = to_dense_batch(reps1, batch1, fill_value=1e-8)
        out2, _ = to_dense_batch(reps2, batch2, fill_value=1e-8)

        # IDIOM FIX: `is not None` instead of `!= None`; also dropped the
        # original's `topo_mask`, which was computed but never used.
        if original_idx1 is not None and original_idx2 is not None:
            dense_original_idx1, _ = to_dense_batch(original_idx1, batch=batch1, fill_value=0)
            dense_original_idx2, _ = to_dense_batch(original_idx2, batch=batch2, fill_value=0)
            topology_dist = self.topology_dist(dense_original_idx1, dense_original_idx2)
            # farther node pairs get larger re-scale ratios (rectified cost)
            rescale_ratio = torch.sigmoid(topology_dist/self.topo_t)
            loss_pos = self._batched_semi_emd_loss(out1, avg_out1, out2, avg_out2, rescale_ratio=rescale_ratio) * 2
        else:
            loss_pos = self._batched_semi_emd_loss(out1, avg_out1, out2, avg_out2)

        T = self.T # temperature
        f = lambda x: torch.exp(x/T)

        total_neg_loss = 0
        # build negatives by rotating the batch one position at a time
        neg_index = list(range(batch_size))
        for i in range((batch_size-1)):
            neg_index.insert(0, neg_index.pop(-1))
            out1_perm = out1[neg_index].clone()
            out2_perm = out2[neg_index].clone()
            avg_out1_perm = avg_out1[neg_index].clone()
            avg_out2_perm = avg_out2[neg_index].clone()
            # intra-view (view1 vs permuted view1) and inter-view negatives
            total_neg_loss += f(self._batched_semi_emd_loss(out1, avg_out1, out1_perm, avg_out1_perm)) + f(self._batched_semi_emd_loss(out1, avg_out1, out2_perm, avg_out2_perm))

        loss = -torch.log(f(loss_pos) / (total_neg_loss))

        return loss

    def gen_loss(self, reps1, reps2, batch1=None, batch2=None, original_idx1=None, original_idx2=None):
        """Symmetrized contrastive loss averaged over both view orderings."""
        loss1 = self.batched_semi_emd_loss(reps1, reps2, batch1, batch2, original_idx1=original_idx1, original_idx2=original_idx2)
        loss2 = self.batched_semi_emd_loss(reps2, reps1, batch2, batch1, original_idx1=original_idx2, original_idx2=original_idx1)
        loss = (loss1*0.5 + loss2*0.5).mean()
        return loss

    def forward(self, view1, view2, batch1=None, batch2=None):
        """Encode both views and return the (possibly rectified) contrastive loss."""
        h1, z1 = self.gen_rep(view1)
        h2, z2 = self.gen_rep(view2)

        if hasattr(view1, 'original_idx') and self.rectified:
            # shift by one: index 0 in the topology table is reserved for padding nodes
            original_idx1, original_idx2 = view1.original_idx+1, view2.original_idx+1
            loss = self.gen_loss(z1, z2, batch1, batch2, original_idx1, original_idx2)
        else:
            loss = self.gen_loss(z1, z2, batch1, batch2)

        return loss

    def embed(self, data):
        """Return detached encoder embeddings for downstream evaluation."""
        h = self.encoder(data)
        return h.detach()
174 |
175 |
class GCN(torch.nn.Module):
    """Multi-layer GCN encoder with ReLU activations.

    Intermediate layers are twice as wide (2*hidden) as the final output layer.
    """

    def __init__(self, input_dim, layer_num=2, hidden=128):
        super(GCN, self).__init__()
        self.layer_num = layer_num
        self.hidden = hidden
        self.input_dim = input_dim

        self.convs = ModuleList()
        if self.layer_num > 1:
            # layer widths: input -> 2*hidden x (layer_num-1) -> hidden
            dims = [input_dim] + [hidden*2] * (layer_num - 1) + [hidden]
            for in_dim, out_dim in zip(dims[:-1], dims[1:]):
                conv = GCNConv(in_dim, out_dim)
                # BUG FIX: the original applied glorot via `self.convs[i].weight`
                # with off-by-one indexing, leaving some layers uninitialized
                # (e.g. the first conv when layer_num == 2). Initialize every
                # layer uniformly as it is created.
                glorot(conv.weight)
                self.convs.append(conv)
        else: # one layer gcn
            conv = GCNConv(input_dim, hidden)
            glorot(conv.weight)
            self.convs.append(conv)

    def forward(self, data):
        """Return node embeddings after `layer_num` GCNConv+ReLU layers."""
        x, edge_index = data.x, data.edge_index
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return x
202 |
203 |
class GraphSAGE_GCN(torch.nn.Module):
    """GraphSAGE encoder: SAGEConv -> BatchNorm1d -> ELU per layer."""

    def __init__(self, input_dim, layer_num=3, hidden=512):
        super().__init__()
        # BUG FIX: SAGEConv is never imported at module level (only GCNConv is),
        # so selecting encoder='sage-gcn' raised NameError; import it here.
        from torch_geometric.nn import SAGEConv

        self.convs = torch.nn.ModuleList()
        self.layer = layer_num
        self.acts = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        for i in range(self.layer):
            in_dim = input_dim if i == 0 else hidden
            self.convs.append(SAGEConv(in_dim, hidden, root_weight=True))
            self.acts.append(torch.nn.ELU())
            self.norms.append(torch.nn.BatchNorm1d(hidden))

    def forward(self, data):
        """Return node embeddings after `layer_num` SAGE layers."""
        x, edge_index = data.x, data.edge_index
        for conv, norm, act in zip(self.convs, self.norms, self.acts):
            x = act(norm(conv(x, edge_index)))
        return x
226 |
227 |
class Projector(torch.nn.Module):
    """Two-layer MLP projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    `shape` must supply at least (input_dim, middle_dim, output_dim).
    """

    def __init__(self, shape=()):
        super(Projector, self).__init__()
        if len(shape) < 3:
            raise Exception("Wrong shape for Projector")

        in_dim, mid_dim, out_dim = shape[0], shape[1], shape[2]
        layers = [
            torch.nn.Linear(in_dim, mid_dim),
            torch.nn.BatchNorm1d(mid_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(mid_dim, out_dim),
        ]
        self.main = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Project representations through the MLP head."""
        return self.main(x)
243 |
244 |
def matrix_diag(diagonal):
    """Embed the last dimension of `diagonal` as a batch of diagonal matrices.

    For input of shape (..., N) returns (..., N, N) with the input values on
    the main diagonal and zeros elsewhere, preserving dtype and device —
    exactly torch.diag_embed's contract, so delegate to it.
    """
    return torch.diag_embed(diagonal)
254 |
class LogReg(torch.nn.Module):
    """Single-linear-layer logistic-regression probe for linear evaluation."""

    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = torch.nn.Linear(ft_in, nb_classes)

        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for every Linear submodule."""
        if not isinstance(m, torch.nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq):
        """Return raw class logits for the input embeddings."""
        return self.fc(seq)
273 |
--------------------------------------------------------------------------------