├── .gitignore
├── README.md
├── datasets
├── __init__.py
└── dataset
│ ├── Hindex.py
│ └── __init__.py
├── gcc.py
├── graphcontrol.py
├── models
├── __init__.py
├── encoder.py
├── gcc.py
├── gcc_graphcontrol.py
├── mlp.py
├── model_manager.py
└── pooler.py
├── node2vec.py
├── optimizers
└── __init__.py
├── png
└── framework.png
└── utils
├── __init__.py
├── args.py
├── augmentation.py
├── normalize.py
├── random.py
├── register.py
├── sampling.py
└── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/data
2 | checkpoint/
3 |
4 | # all pyc files_
5 | **/__pycache__
6 |
7 | **/.vscode
8 | **/ipynb_checkpoints
9 | *.ipynb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GraphControl: Adding Conditional Control to Universal Graph Pre-trained Models for Graph Domain Transfer Learning
2 |
3 | **Official implementation of paper**
[GraphControl: Adding Conditional Control to Universal Graph Pre-trained Models for Graph Domain Transfer Learning](https://arxiv.org/abs/2310.07365)
4 |
5 | Yun Zhu*, Yaoke Wang*, Haizhou Shi, Zhenshuo Zhang, Dian Jiao, Siliang Tang†
6 |
7 | In WWW 2024
8 |
9 | ## Overview
10 | This is the first work to solve the "transferability-specificity dilemma" in graph domain transfer learning. To address this challenge, we introduce an innovative deployment module coined as GraphControl, motivated by ControlNet, to realize better graph domain transfer learning. The overview of our method is depicted as:
11 |
12 | 
13 |
14 |
15 | ## Setup
16 |
17 | ```bash
18 | conda create -n GraphControl python==3.9
19 | conda activate GraphControl
20 | conda install pytorch==2.1.0 torchaudio==2.1.0 cudatoolkit=12.1 -c pytorch -c conda-forge
21 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
22 | ```
23 |
24 | ## Download GCC Pretrained Weight
25 |
26 | **Download GCC checkpoints**
27 | Download GCC checkpoint from https://drive.google.com/file/d/1lYW_idy9PwSdPEC7j9IH5I5Hc7Qv-22-/view and save it into `./checkpoint/gcc.pth`.
28 |
29 | ## For Attributed Graphs
30 |
31 | **Only GCC**
32 |
33 | ```bash
34 | CUDA_VISIBLE_DEVICES=0 python gcc.py --lr 1e-3 --epochs 100 --dataset Cora_ML --model GCC --use_adj --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
35 | ```
36 |
37 | **GCC with GraphControl**
38 |
39 | ```bash
40 | CUDA_VISIBLE_DEVICES=0 python graphcontrol.py --dataset Cora_ML --epochs 100 --lr 0.5 --optimizer adamw --weight_decay 5e-4 --threshold 0.17 --walk_steps 256 --restart 0.8 --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
41 | ```
42 |
43 | ## For Non-attributed Graphs
44 |
45 | **For non-attributed graphs, we first need to generate node attributes with node2vec**
46 |
47 | ```bash
48 | CUDA_VISIBLE_DEVICES=0 python node2vec.py --dataset Hindex --lr 1e-2 --epochs 100
49 | ```
50 |
51 | **Then we can train on it in the same way as for attributed graphs**
52 |
53 | ```bash
54 | CUDA_VISIBLE_DEVICES=0 python graphcontrol.py --dataset Hindex --epochs 100 --lr 0.1 --optimizer sgd --weight_decay 5e-4 --threshold 0.17 --walk_steps 256 --restart 0.5 --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
55 | ```
56 |
57 | ## Illustration of arguments
58 | ```
59 | --dataset: default Cora_ML, [Cora_ML, Photo, Physics, DBLP, usa, brazil, europe, Hindex] can also be chosen
60 | --model: default GCC_GraphControl, [GCC, GCC_GraphControl] can also be chosen. GCC refers to utilizing GCC as a pre-trained model and fine-tuning it on target data. On the other hand, GCC_GraphControl involves incorporating GraphControl with GCC to address the "transferability-specificity dilemma." Additional pre-trained models will be introduced in the updated version.
61 | ```
62 | More details and explanations are in `utils/args.py`
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.utils import to_undirected, homophily
2 | import torch_geometric.transforms as T
3 | import copy
4 | import torch
5 | import os
6 | import numpy as np
7 |
8 | from .dataset import Amazon, Coauthor, Airports, CitationFull, HindexDataset
9 | from utils.random import reset_random_seed
10 | from utils.transforms import obtain_attributes
11 |
12 |
13 | dataset_dict = {
14 | 'Photo': Amazon,
15 | 'Physics': Coauthor,
16 | 'usa': Airports,
17 | 'brazil': Airports,
18 | 'europe': Airports,
19 | 'DBLP': CitationFull,
20 | 'Cora_ML': CitationFull,
21 | 'Hindex': HindexDataset
22 | }
23 |
24 | PATH = './datasets/data'
25 |
def load_dataset(dataset_name, trans=None):
    """Instantiate the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Key into ``dataset_dict``; an unknown name raises
            ``KeyError``.
        trans: Optional transform. When given it is wrapped in
            ``T.Compose([trans])`` and passed as ``transform``; when ``None``
            no ``transform`` argument is passed at all.

    Returns:
        The constructed dataset object.
    """
    dataset_cls = dataset_dict[dataset_name]

    if dataset_name in ['Hindex']:
        # The Hindex dataset takes no ``name`` argument and lives in its own
        # sub-directory of the data root.
        kwargs = {'root': f'{PATH}/{dataset_name}'}
    else:
        kwargs = {'root': PATH, 'name': dataset_name}

    # Identity check instead of ``== None`` so that transform objects with a
    # custom ``__eq__`` cannot confuse the test.
    if trans is not None:
        kwargs['transform'] = T.Compose([trans])

    return dataset_cls(**kwargs)
37 |
38 |
class NodeDataset:
    """Single-graph node-classification dataset with optional random splits.

    Loads the dataset registered under ``dataset_name`` in ``dataset_dict``,
    converts its edges to undirected form, keeps backups of the original node
    attributes and edges, and, when the data carries no ``train_mask``,
    creates one random train/test split (10% train, rest test) per seed.
    """

    # Lower-cased names of the datasets whose node attributes are replaced by
    # precomputed Node2Vec embeddings when the embedding file exists.
    _NODE2VEC_DATASETS = ('usa', 'europe', 'brazil', 'hindex')

    def __init__(self, dataset_name, trans=None, n_seeds=None) -> None:
        """Load the dataset and prepare attributes and split masks.

        Args:
            dataset_name: Key into ``dataset_dict``.
            trans: Optional transform forwarded to the dataset constructor.
            n_seeds: Seeds used to draw the random splits; one mask column is
                created per seed, in the given order. Defaults to ``[0]``.
        """
        # ``None`` sentinel instead of a shared mutable default list.
        if n_seeds is None:
            n_seeds = [0]

        self.path = PATH
        self.dataset_name = dataset_name
        if dataset_name in ['Hindex']:
            self.dataset = dataset_dict[dataset_name](root=f'{self.path}/{dataset_name}', transform=trans)
        else:
            self.dataset = dataset_dict[dataset_name](root=f'{self.path}', name=dataset_name, transform=trans)

        self.num_classes = self.dataset.num_classes
        self.num_node_features = self.dataset.num_node_features

        assert len(self.dataset) == 1, "Training data consists of multiple graphs!"

        self.data = self.dataset[0]

        # parse it into undirected graph
        edge_index = to_undirected(self.data.edge_index)
        self.data.edge_index = edge_index
        self.num_nodes = self.data.x.shape[0]

        # backup original node attributes and edges
        self.backup_x = copy.deepcopy(self.data.x)
        self.backup_edges = copy.deepcopy(self.data.edge_index)
        self.random_split = False

        # For datasets without node attributes, we will use node embeddings from Node2Vec as their node attributes.
        # The name is compared case-insensitively: the airport datasets are
        # registered as 'usa' / 'europe' / 'brazil', which the previous
        # capitalised list never matched.
        attr_path = f'{PATH}/{dataset_name}/processed/node2vec.pt'
        if dataset_name.lower() in self._NODE2VEC_DATASETS and os.path.exists(attr_path):
            x = torch.load(attr_path)
            self.data.x = x.detach()

        # If the dataset does not contain preset splits, we will randomly split it into train:test=1:9 once per seed
        if not hasattr(self.data, 'train_mask'):
            self.random_split = True
            num_train = int(self.num_nodes*0.1)

            train_mask_list = []
            test_mask_list = []
            for seed in n_seeds:
                reset_random_seed(seed)

                rand_node_idx = torch.randperm(self.num_nodes)
                train_idx = rand_node_idx[:num_train]
                train_mask = torch.zeros(self.num_nodes).bool()
                train_mask[train_idx] = True

                test_mask = torch.ones_like(train_mask).bool()
                test_mask[train_idx] = False
                train_mask_list.append(train_mask.unsqueeze(1))
                test_mask_list.append(test_mask.unsqueeze(1))

            # Resulting masks have one column per seed: (num_nodes, len(n_seeds)).
            self.data.train_mask = torch.cat(train_mask_list, dim=1)
            self.data.test_mask = torch.cat(test_mask_list, dim=1)

    def generate_subgraph(self):
        """Placeholder; currently does nothing."""
        pass

    def split_train_test(self, split_ratio=0.8):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError('do not set parameter ')

    def to(self, device):
        """Move the wrapped graph data to ``device`` (in place on this object)."""
        self.data = self.data.to(device)

    def replace_node_attributes(self, use_adj, threshold, num_dim):
        """Overwrite ``data.x`` with the output of ``obtain_attributes``."""
        self.num_node_features = num_dim
        self.data.x = obtain_attributes(self.data, use_adj, threshold, num_dim)

    def obtain_node_attributes(self, use_adj, threshold=0.1, num_dim=32):
        """Return the output of ``obtain_attributes`` without modifying the data."""
        return obtain_attributes(self.data, use_adj, threshold, num_dim)

    def print_statistics(self):
        """Print size, homophily and class-ratio statistics of the dataset."""
        from collections import Counter
        if len(self.data.y.shape) >= 2:  # For one-hot labels
            y = self.data.y.argmax(1)
        else:
            y = self.data.y
        # Homophily is computed on class indices so one-hot labels are
        # handled the same way as the class-ratio statistics below.
        h = homophily(self.data.edge_index, y)
        count = Counter(y.tolist())
        total_num = sum(count.values())
        class_ratio = {}
        for key, value in count.items():
            r = round(value / total_num, 2)
            class_ratio[key] = r
        print(f'{self.dataset_name}: Number of nodes: {self.num_nodes}, Dimension of features: {self.num_node_features}, Number of edges: {self.data.edge_index.shape[1]}, Number of classes: {self.num_classes}, Homophily: {h}, Class ratio: {class_ratio}.')
        if self.random_split:
            print('The dataset does not contain preset splits, we randomly split the dataset twenty times. Train: teset = 1:9')
        else:
            print('We use the preset splits.')
130 |
131 | if __name__ == '__main__':
132 | dataset = NodeDataset('Hindex')
133 | print(dataset)
134 |
--------------------------------------------------------------------------------
/datasets/dataset/Hindex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path as osp
3 | from torch_geometric.data import InMemoryDataset, Data
4 | from collections import defaultdict
5 | import numpy as np
6 |
7 |
class HindexDataset(InMemoryDataset):
    """In-memory dataset built from an edge list and a node-label file.

    The raw files are expected under ``<root>/raw``; the processed graph is
    stored as ``data.pt`` in the processed directory.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        self.name = 'Hindex'
        self.root = root

        super().__init__(self.root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['aminer_hindex_rand20intop200_5000.edgelist', 'aminer_hindex_rand20intop200_5000.nodelabel']

    @property
    def processed_file_names(self):
        return ['data.pt']

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    def process(self):
        """Build the single ``Data`` object and save it to the processed path."""
        edge_index, y, self.node2id = self._preprocess(self.raw_paths[0], self.raw_paths[1])
        # Node features are a zero placeholder column; labels are stored as
        # class indices obtained from the one-hot matrix.
        data = Data(x=torch.zeros(y.size(0), 1), edge_index=edge_index, y=y.argmax(1))
        data_list = [data]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _preprocess(self, edge_list_path, node_label_path):
        """Parse the raw files into an edge index and a one-hot label matrix.

        Args:
            edge_list_path: Text file with one ``src dst`` integer pair per line.
            node_label_path: Text file with one ``node label`` integer pair per line.

        Returns:
            Tuple ``(edge_index, y, node2id)``: a ``2 x (2 * num_edges)``
            ``LongTensor`` holding both directions of every edge, a one-hot
            label matrix, and the raw-id to contiguous-id mapping.

        Raises:
            ValueError: If the label file mentions a node that never appears
                in the edge list.
        """
        # A plain dict (not ``defaultdict(int)``): looking up an unknown node
        # must fail instead of silently aliasing it to node 0.
        node2id = {}
        edge_list = []
        with open(edge_list_path) as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                x, y = map(int, fields)
                # Reindex
                if x not in node2id:
                    node2id[x] = len(node2id)
                if y not in node2id:
                    node2id[y] = len(node2id)
                edge_list.append([node2id[x], node2id[y]])
                edge_list.append([node2id[y], node2id[x]])

        num_nodes = len(node2id)
        nodes = []
        labels = []
        label2id = {}
        binarize = "Hindex" in self.name
        with open(node_label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                x, label = map(int, fields)
                if x not in node2id:
                    raise ValueError(
                        f'Node {x} in {node_label_path} does not appear in {edge_list_path}'
                    )
                if label not in label2id:
                    label2id[label] = len(label2id)
                nodes.append(node2id[x])
                # For Hindex the raw value is kept and binarised below;
                # otherwise the label is mapped to a contiguous class id.
                labels.append(label if binarize else label2id[label])
        if binarize:
            # Two classes: above the median raw value or not.
            median = np.median(labels)
            labels = [int(label > median) for label in labels]
        assert num_nodes == len(set(nodes))
        y = torch.zeros(num_nodes, len(label2id))
        y[nodes, labels] = 1
        return torch.LongTensor(edge_list).t(), y, node2id
--------------------------------------------------------------------------------
/datasets/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .Hindex import HindexDataset
2 | from torch_geometric.datasets import Amazon, Coauthor, Airports, CitationFull
3 |
--------------------------------------------------------------------------------
/gcc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import numpy as np
4 | from torch_geometric.loader import ShaDowKHopSampler, DataLoader
5 |
6 | from utils.random import reset_random_seed
7 | from utils.args import Arguments
8 | from models import load_model
9 | from datasets import NodeDataset
10 | from utils.transforms import process_attributes
11 | from utils.sampling import ego_graphs_sampler, collect_subgraphs
12 |
13 |
def preprocess(config, dataset_obj):
    """Sample subgraphs for the train/test nodes and wrap them in loaders.

    Args:
        config: Parsed arguments; reads ``batch_size``, ``walk_steps``,
            ``restart``, ``use_adj``, ``threshold`` and ``num_dim``.
        dataset_obj: ``NodeDataset`` whose ``data.train_mask`` /
            ``data.test_mask`` select the root nodes.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    kwargs = {'batch_size': config.batch_size, 'num_workers': 3, 'persistent_workers': True}

    print('generating subgraphs....')

    # ``squeeze(-1)`` instead of ``squeeze()``: the index stays 1-D even when
    # exactly one node is selected.
    train_idx = dataset_obj.data.train_mask.nonzero().squeeze(-1)
    test_idx = dataset_obj.data.test_mask.nonzero().squeeze(-1)

    train_graphs = collect_subgraphs(train_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)
    test_graphs = collect_subgraphs(test_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)

    if config.use_adj:
        # Plain loops: process_attributes is called for its side effect on
        # each subgraph, so no result list is built.
        for graphs in (train_graphs, test_graphs):
            for g in graphs:
                process_attributes(g, use_adj=config.use_adj, threshold=config.threshold, num_dim=config.num_dim)

        # NOTE(review): the original indentation was not recoverable; this
        # assignment is assumed to belong to the ``use_adj`` branch - confirm.
        dataset_obj.num_node_features = config.num_dim

    train_loader = DataLoader(train_graphs, shuffle=True, **kwargs)
    test_loader = DataLoader(test_graphs, **kwargs)

    return train_loader, test_loader
36 |
37 |
def main(config):
    """Fine-tune and evaluate the model once per seed and report mean/std accuracy.

    Args:
        config: Parsed arguments; reads ``dataset`` and ``seeds`` here and is
            forwarded to the helpers.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_obj = NodeDataset(config.dataset, n_seeds=config.seeds)
    dataset_obj.print_statistics()

    acc_list = []

    train_masks = dataset_obj.data.train_mask
    test_masks = dataset_obj.data.test_mask

    for split_idx, seed in enumerate(config.seeds):
        reset_random_seed(seed)

        if dataset_obj.random_split:
            # The random masks hold one column per entry of ``config.seeds``
            # (in order), so they are selected by position, not by seed
            # value; indexing by value breaks for seeds other than 0..n-1.
            dataset_obj.data.train_mask = train_masks[:, split_idx]
            dataset_obj.data.test_mask = test_masks[:, split_idx]

        train_loader, test_loader = preprocess(config, dataset_obj)
        model = load_model(dataset_obj.num_node_features, dataset_obj.num_classes, config).to(device)

        # training model
        train_subgraph(config, model, train_loader, device)
        acc = eval_subgraph(config, model, test_loader, device)

        acc_list.append(acc)
        print(f'Seed: {seed}, Accuracy: {acc:.4f}')

    final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
    print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
68 |
def train_subgraph(config, model, train_loader, device):
    """Train ``model`` on batched subgraphs with Adam and cross-entropy loss.

    Args:
        config: Parsed arguments; reads ``lr``, ``weight_decay`` and ``epochs``.
        model: Model exposing ``forward_subgraph``.
        train_loader: Iterable of subgraph batches.
        device: Device the batches are moved to.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    model.train()
    for _ in tqdm(range(config.epochs)):
        for batch in train_loader:
            batch = batch.to(device)
            optim.zero_grad()

            if not hasattr(batch, 'root_n_id'):
                batch.root_n_id = batch.root_n_index

            # Random per-column sign flip: the sign of eigenvectors is
            # arbitrary (drop this if eigen-decomposition is done on the
            # full graph).
            draws = torch.rand(batch.x.size(1)).to(device)
            signs = (draws >= 0.5).float() * 2.0 - 1.0
            batch.x = batch.x * signs.unsqueeze(0)

            logits = model.forward_subgraph(batch.x, batch.edge_index, batch.batch, batch.root_n_id)
            batch_loss = loss_fn(logits, batch.y)
            batch_loss.backward()
            optim.step()
89 |
90 |
def eval_subgraph(config, model, test_loader, device):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        config: Unused here; kept for a uniform call signature.
        model: Model exposing ``eval()`` and ``forward_subgraph``.
        test_loader: Iterable of subgraph batches.
        device: Device the batches are moved to.

    Returns:
        Fraction of correctly predicted labels (a Python float).
    """
    model.eval()

    correct = 0
    total_num = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            if not hasattr(batch, 'root_n_id'):
                batch.root_n_id = batch.root_n_index

            preds = model.forward_subgraph(batch.x, batch.edge_index, batch.batch, batch.root_n_id).argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total_num += batch.y.shape[0]
    acc = correct / total_num
    return acc
106 |
107 | if __name__ == '__main__':
108 | config = Arguments().parse_args()
109 |
110 | main(config)
--------------------------------------------------------------------------------
/graphcontrol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.loader import DataLoader
3 | from tqdm import tqdm
4 | import numpy as np
5 |
6 |
7 | from utils.random import reset_random_seed
8 | from utils.args import Arguments
9 | from utils.sampling import collect_subgraphs
10 | from utils.transforms import process_attributes, obtain_attributes
11 | from models import load_model
12 | from datasets import NodeDataset
13 | from optimizers import create_optimizer
14 |
15 |
def preprocess(config, dataset_obj, device):
    """Sample subgraphs for the train/test nodes and wrap them in loaders.

    Args:
        config: Parsed arguments; reads ``batch_size``, ``walk_steps``,
            ``restart``, ``use_adj``, ``threshold`` and ``num_dim``.
        dataset_obj: ``NodeDataset`` whose ``data.train_mask`` /
            ``data.test_mask`` select the root nodes.
        device: Unused here; kept for the existing call signature.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    kwargs = {'batch_size': config.batch_size, 'num_workers': 4, 'persistent_workers': True, 'pin_memory': True}

    print('generating subgraphs....')

    # ``squeeze(-1)`` instead of ``squeeze()``: the index stays 1-D even when
    # exactly one node is selected.
    train_idx = dataset_obj.data.train_mask.nonzero().squeeze(-1)
    test_idx = dataset_obj.data.test_mask.nonzero().squeeze(-1)

    train_graphs = collect_subgraphs(train_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)
    test_graphs = collect_subgraphs(test_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)

    # Plain loops: process_attributes is called for its side effect on each
    # subgraph, so no result list is built.
    for graphs in (train_graphs, test_graphs):
        for g in graphs:
            process_attributes(g, use_adj=config.use_adj, threshold=config.threshold, num_dim=config.num_dim)

    train_loader = DataLoader(train_graphs, shuffle=True, **kwargs)
    test_loader = DataLoader(test_graphs, **kwargs)

    return train_loader, test_loader
35 |
36 |
def finetune(config, model, train_loader, device, full_x_sim, test_loader):
    """Fine-tune the non-encoder parameters and return the best test accuracy.

    Parameters whose name contains ``'encoder'`` are frozen, the classifier
    is reset, and the remaining parameters are optimised. Every third epoch
    the model is evaluated on ``test_loader``; training stops once 15
    consecutive evaluations fail to improve on the best accuracy.

    Args:
        config: Parsed arguments; reads ``optimizer``, ``lr``,
            ``weight_decay`` and ``epochs``.
        model: Model exposing ``reset_classifier`` and ``forward_subgraph``.
        train_loader: Iterable of training subgraph batches.
        device: Device the batches are moved to.
        full_x_sim: Tensor indexed with each batch's ``original_idx``.
        test_loader: Iterable of evaluation subgraph batches.

    Returns:
        Best accuracy observed during the periodic evaluations.
    """
    # freeze the pre-trained encoder (left branch)
    for param_name, param in model.named_parameters():
        if 'encoder' in param_name:
            param.requires_grad = False

    model.reset_classifier()

    eval_steps, patience = 3, 15
    stale_evals = 0
    best_acc = 0

    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = create_optimizer(name=config.optimizer, parameters=trainable, lr=config.lr, weight_decay=config.weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    process_bar = tqdm(range(config.epochs))

    for epoch in process_bar:
        for data in train_loader:
            optimizer.zero_grad()
            # Evaluation switches the model to eval mode, so training mode is
            # re-enabled for every step.
            model.train()

            data = data.to(device)
            if not hasattr(data, 'root_n_id'):
                data.root_n_id = data.root_n_index

            # Random per-column sign flip of the input attributes.
            draws = torch.rand(data.x.size(1)).to(device)
            signs = (draws >= 0.5).float() * 2.0 - 1.0
            x = data.x * signs.unsqueeze(0)

            x_sim = full_x_sim[data.original_idx]
            preds = model.forward_subgraph(x, x_sim, data.edge_index, data.batch, data.root_n_id, frozen=True)

            loss = criterion(preds, data.y)
            loss.backward()
            optimizer.step()

        if epoch % eval_steps == 0:
            acc = eval_subgraph(config, model, test_loader, device, full_x_sim)
            process_bar.set_postfix({"Epoch": epoch, "Accuracy": f"{acc:.4f}"})
            if acc > best_acc:
                best_acc = acc
                stale_evals = 0
            else:
                stale_evals += 1

        if stale_evals == patience:
            break

    return best_acc
88 |
89 |
def main(config):
    """Fine-tune with GraphControl once per seed and report mean/std accuracy.

    Args:
        config: Parsed arguments; reads ``dataset``, ``seeds``, ``threshold``
            and ``num_dim`` here and is forwarded to the helpers.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_obj = NodeDataset(config.dataset, n_seeds=config.seeds)
    dataset_obj.print_statistics()

    # For large graph, we use cpu to preprocess it rather than gpu because of OOM problem.
    if dataset_obj.num_nodes < 30000:
        dataset_obj.to(device)
    x_sim = obtain_attributes(dataset_obj.data, use_adj=False, threshold=config.threshold).to(device)

    dataset_obj.to('cpu')  # Otherwise the deepcopy will raise an error
    num_node_features = config.num_dim

    train_masks = dataset_obj.data.train_mask
    test_masks = dataset_obj.data.test_mask

    acc_list = []

    for split_idx, seed in enumerate(config.seeds):
        reset_random_seed(seed)
        if dataset_obj.random_split:
            # The random masks hold one column per entry of ``config.seeds``
            # (in order), so they are selected by position, not by seed
            # value; indexing by value breaks for seeds other than 0..n-1.
            dataset_obj.data.train_mask = train_masks[:, split_idx]
            dataset_obj.data.test_mask = test_masks[:, split_idx]

        train_loader, test_loader = preprocess(config, dataset_obj, device)

        model = load_model(num_node_features, dataset_obj.num_classes, config)
        model = model.to(device)

        # finetuning model
        best_acc = finetune(config, model, train_loader, device, x_sim, test_loader)

        acc_list.append(best_acc)
        print(f'Seed: {seed}, Accuracy: {best_acc:.4f}')

    final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
    print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
128 |
129 |
def eval_subgraph(config, model, test_loader, device, full_x_sim):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        config: Unused here; kept for a uniform call signature.
        model: Model exposing ``eval()`` and ``forward_subgraph``.
        test_loader: Iterable of subgraph batches.
        device: Device the batches are moved to.
        full_x_sim: Tensor indexed with each batch's ``original_idx``.

    Returns:
        Fraction of correctly predicted labels (a Python float).
    """
    model.eval()

    correct = 0
    total_num = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            if not hasattr(batch, 'root_n_id'):
                batch.root_n_id = batch.root_n_index
            x_sim = full_x_sim[batch.original_idx]
            preds = model.forward_subgraph(batch.x, x_sim, batch.edge_index, batch.batch, batch.root_n_id, frozen=True).argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total_num += batch.y.shape[0]
    acc = correct / total_num
    return acc
145 |
146 | if __name__ == '__main__':
147 | config = Arguments().parse_args()
148 |
149 | main(config)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcc import GCC
2 | from .gcc_graphcontrol import GCC_GraphControl
3 |
4 | from .model_manager import load_model
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import ModuleList
4 | from torch_geometric.nn.inits import glorot
5 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, global_mean_pool
6 | from torch.nn import BatchNorm1d, Identity
7 | import torch.nn as nn
8 | from utils.register import register
9 |
10 |
def get_activation(name: str):
    """Return the activation callable registered under ``name``.

    Supported names: ``relu``, ``hardtanh``, ``elu``, ``leakyrelu``,
    ``prelu`` and ``rrelu``. ``prelu`` yields a freshly constructed
    ``torch.nn.PReLU`` module; the others are ``torch.nn.functional``
    functions. An unknown name raises ``KeyError``.
    """
    if name == 'prelu':
        # PReLU has a learnable parameter, so every caller gets its own module.
        return torch.nn.PReLU()

    functional_activations = {
        'relu': F.relu,
        'hardtanh': F.hardtanh,
        'elu': F.elu,
        'leakyrelu': F.leaky_relu,
        'rrelu': F.rrelu,
    }
    return functional_activations[name]
21 |
22 |
@register.encoder_register
class GCN_Encoder(torch.nn.Module):
    """Stack of ``GCNConv`` layers, each followed by a norm layer, activation and dropout.

    The norm layer is ``BatchNorm1d`` when ``use_bn`` is true and ``Identity``
    otherwise. When ``last_activation`` is false the activation is skipped
    after the final layer (dropout is still applied).
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True):
        super().__init__()
        self.layer_num = layer_num
        self.hidden = hidden_size
        self.input_dim = input_dim
        self.activation = get_activation(activation)
        self.dropout = torch.nn.Dropout(dropout)
        self.last_act = last_activation
        self.use_bn = use_bn

        # Any ``layer_num`` not greater than one builds exactly one layer.
        depth = layer_num if layer_num > 1 else 1
        # First layer maps input_dim -> hidden_size, the rest keep hidden_size.
        in_dims = [input_dim] + [hidden_size] * (depth - 1)

        self.convs = ModuleList([GCNConv(in_dim, hidden_size) for in_dim in in_dims])
        self.bns = ModuleList(
            [BatchNorm1d(hidden_size) if use_bn else Identity() for _ in range(depth)]
        )

    def forward(self, x, edge_index, edge_weight=None):
        """Apply all layers to ``x`` and return the resulting node embeddings."""
        final_idx = self.layer_num - 1
        for idx in range(self.layer_num):
            x = self.bns[idx](self.convs[idx](x, edge_index, edge_weight))
            # The activation is skipped only on the last layer, and only when
            # ``last_activation`` was disabled.
            if self.last_act or idx != final_idx:
                x = self.activation(x)
            x = self.dropout(x)
        return x

    def reset_parameters(self):
        """Re-initialise every convolution and, if used, every batch-norm layer."""
        for idx in range(self.layer_num):
            self.convs[idx].reset_parameters()
            if self.use_bn:
                self.bns[idx].reset_parameters()
85 |
@register.encoder_register
class GIN_Encoder(torch.nn.Module):
    """Stack of ``GINConv`` layers, each followed by a norm layer, activation and dropout.

    Each ``GINConv`` wraps a two-layer MLP (Linear, BatchNorm1d, ReLU,
    Linear). The outer norm layer is ``BatchNorm1d`` when ``use_bn`` is true
    and ``Identity`` otherwise.
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True):
        super().__init__()
        self.layer_num = layer_num
        self.hidden_size = hidden_size
        self.input_dim = input_dim
        self.activation = get_activation(activation)
        self.dropout = torch.nn.Dropout(dropout)
        self.last_act = last_activation
        self.use_bn = use_bn

        def build_conv(in_dim):
            # MLP applied inside a single GIN layer.
            return GINConv(nn.Sequential(nn.Linear(in_dim, hidden_size),
                                         nn.BatchNorm1d(hidden_size), nn.ReLU(),
                                         nn.Linear(hidden_size, hidden_size)))

        # Any ``layer_num`` not greater than one builds exactly one layer.
        depth = layer_num if layer_num > 1 else 1
        in_dims = [input_dim] + [hidden_size] * (depth - 1)

        self.convs = ModuleList([build_conv(in_dim) for in_dim in in_dims])
        self.bns = ModuleList(
            [BatchNorm1d(hidden_size) if use_bn else Identity() for _ in range(depth)]
        )

        self.readout = global_mean_pool

    def forward(self, x, edge_index, **kwargs):
        """Apply all layers to ``x``; extra keyword arguments are ignored."""
        final_idx = self.layer_num - 1
        for idx in range(self.layer_num):
            x = self.bns[idx](self.convs[idx](x, edge_index))
            if self.last_act or idx != final_idx:
                x = self.activation(x)
            x = self.dropout(x)

        return x

    def reset_parameters(self):
        """Re-initialise every convolution and, if used, every batch-norm layer."""
        for idx in range(self.layer_num):
            self.convs[idx].reset_parameters()
            if self.use_bn:
                self.bns[idx].reset_parameters()
141 |
142 |
@register.encoder_register
class GAT_Encoder(torch.nn.Module):
    """Stack of ``GATConv`` layers, each followed by batch norm, activation and dropout.

    When ``last_activation`` is false the activation is skipped after the
    final layer (dropout is still applied).
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True):
        super().__init__()
        self.layer_num = layer_num
        self.hidden = hidden_size
        self.input_dim = input_dim
        self.activation = get_activation(activation)
        self.dropout = torch.nn.Dropout(dropout)
        self.last_act = last_activation
        # Stored for interface parity with the other encoders; as before,
        # this encoder always applies BatchNorm1d regardless of the flag.
        self.use_bn = use_bn

        # Any ``layer_num`` not greater than one builds exactly one layer.
        depth = layer_num if layer_num > 1 else 1
        in_dims = [input_dim] + [hidden_size] * (depth - 1)

        self.convs = ModuleList([GATConv(in_dim, hidden_size) for in_dim in in_dims])
        # One norm layer per convolution: forward() and reset_parameters()
        # index ``self.bns[i]`` for every layer, so fewer norm layers than
        # convolutions raised IndexError for layer_num > 1.
        self.bns = ModuleList([BatchNorm1d(hidden_size) for _ in range(depth)])

    def forward(self, x, edge_index, **kwargs):
        """Apply all layers to ``x``; extra keyword arguments are ignored."""
        final_idx = self.layer_num - 1
        for idx in range(self.layer_num):
            x = self.bns[idx](self.convs[idx](x, edge_index))
            if self.last_act or idx != final_idx:
                x = self.activation(x)
            x = self.dropout(x)
        return x

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer."""
        for idx in range(self.layer_num):
            self.convs[idx].reset_parameters()
            self.bns[idx].reset_parameters()
180 |
181 |
@register.encoder_register
class MLP_Encoder(torch.nn.Module):
    """Stack of ``Linear`` layers, each followed by a norm layer, activation and dropout.

    The graph structure is not used: ``edge_index`` is accepted by
    ``forward`` only to match the other encoders. The norm layer is
    ``BatchNorm1d`` when ``use_bn`` is true and ``Identity`` otherwise.
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True):
        super().__init__()
        self.layer_num = layer_num
        self.hidden_size = hidden_size
        self.input_dim = input_dim
        self.activation = get_activation(activation)
        self.dropout = torch.nn.Dropout(dropout)
        self.last_act = last_activation
        self.use_bn = use_bn

        # Any ``layer_num`` not greater than one builds exactly one layer.
        depth = layer_num if layer_num > 1 else 1
        in_dims = [input_dim] + [hidden_size] * (depth - 1)

        self.convs = ModuleList([nn.Linear(in_dim, hidden_size) for in_dim in in_dims])
        self.bns = ModuleList(
            [BatchNorm1d(hidden_size) if use_bn else Identity() for _ in range(depth)]
        )

        self.readout = global_mean_pool

    def forward(self, x, edge_index, **kwargs):
        """Apply all layers to ``x``; ``edge_index`` and extra keyword arguments are ignored."""
        final_idx = self.layer_num - 1
        for idx in range(self.layer_num):
            x = self.bns[idx](self.convs[idx](x))
            if self.last_act or idx != final_idx:
                x = self.activation(x)
            x = self.dropout(x)

        return x

    def reset_parameters(self):
        """Re-initialise every linear layer and, if used, every batch-norm layer."""
        for idx in range(self.layer_num):
            self.convs[idx].reset_parameters()
            if self.use_bn:
                self.bns[idx].reset_parameters()
--------------------------------------------------------------------------------
/models/gcc.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_geometric.nn import GINConv
7 | from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
8 | import torch_geometric
9 | from utils.register import register
10 |
11 |
def change_params_key(params):
    """Rename GCC checkpoint keys in place to match this implementation.

    Two rewrites are applied to the mapping ``params``:

    * every key whose first dotted component is ``set2set`` is removed;
    * in every other key whose fourth dotted component is ``apply_func``,
      that component is renamed to ``nn``.

    Args:
        params: Mutable mapping from parameter name to value (e.g. a loaded
            state dict). It is modified in place; nothing is returned.
    """
    # Iterate over a snapshot of the keys because the mapping is mutated.
    for key in list(params.keys()):
        parts = key.split('.')

        if parts[0] == 'set2set':
            # Dropped outright. Handling this first avoids popping the same
            # key twice when a set2set key also contains ``apply_func``.
            params.pop(key)
            continue

        if len(parts) > 3 and parts[3] == 'apply_func':
            parts[3] = 'nn'
            params['.'.join(parts)] = params.pop(key)
25 |
26 |
@register.model_register
class GCC(nn.Module):
    """
    GCC encoder: an :class:`UnsupervisedGIN` backbone applied to node features
    that are optionally extended with a degree embedding and a root-node
    indicator.

    Parameters
    ----------
    positional_embedding_size : int
        Used to size the backbone input (see ``degree_input``).
    max_node_freq, max_edge_freq, max_degree : int
        Stored on the module; ``max_degree`` also clamps node degrees and
        sizes the degree embedding table.
    freq_embedding_size : int
        Accepted for signature compatibility; not used by this module.
    degree_embedding_size : int
        Width of the degree embedding (only created when ``degree_input``).
    output_dim : int
        Output width of the backbone and of ``lin_readout``.
    node_hidden_dim : int
        Hidden width of the backbone and of ``lin_readout``.
    num_layers : int
        Number of layers passed to the GIN backbone.
    norm : bool
        If True, L2-normalise the output of ``forward_subgraph``.
    gnn_model : str
        Stored on the module only; the backbone is always ``UnsupervisedGIN``.
    degree_input : bool
        If True, concatenate a degree embedding and a root indicator to ``x``.
    edge_hidden_dim, num_heads, num_step_set2set, num_layer_set2set,
    lstm_as_gate, num_classes
        Accepted for signature compatibility; not used by this module.
    """

    def __init__(
        self,
        positional_embedding_size=32,
        max_node_freq=8,
        max_edge_freq=8,
        max_degree=128,
        freq_embedding_size=32,
        degree_embedding_size=32,
        output_dim=32,
        node_hidden_dim=32,
        edge_hidden_dim=32,
        num_layers=6,
        num_heads=4,
        num_step_set2set=6,
        num_layer_set2set=3,
        norm=False,
        gnn_model="mpnn",
        degree_input=False,
        lstm_as_gate=False,
        num_classes=10
    ):
        super(GCC, self).__init__()

        # The trailing "+ 1" accounts for the root (ego) indicator column.
        if degree_input:
            node_input_dim = positional_embedding_size + degree_embedding_size + 1
        else:
            # NOTE(review): forward_subgraph passes ``x`` through unchanged in
            # this mode, so ``x`` must already have this many columns - confirm
            # against the caller.
            node_input_dim = positional_embedding_size + 1

        self.gnn = UnsupervisedGIN(
            num_layers=num_layers,
            num_mlp_layers=2,
            input_dim=node_input_dim,
            hidden_dim=node_hidden_dim,
            output_dim=output_dim,
            final_dropout=0.5,
            learn_eps=False,
            graph_pooling_type="sum",
            neighbor_pooling_type="sum",
            use_selayer=False,
        )
        self.gnn_model = gnn_model

        self.max_node_freq = max_node_freq
        self.max_edge_freq = max_edge_freq
        self.max_degree = max_degree
        self.degree_input = degree_input

        if degree_input:
            self.degree_embedding = nn.Embedding(
                num_embeddings=max_degree + 1, embedding_dim=degree_embedding_size
            )

        # Not used in forward_subgraph, but kept so that the module's
        # state_dict keys stay unchanged.
        self.lin_readout = nn.Sequential(
            nn.Linear(2 * node_hidden_dim, node_hidden_dim),
            nn.ReLU(),
            nn.Linear(node_hidden_dim, output_dim),
        )
        self.norm = norm

    def forward(self, x, edge_index, edge_weight=None, frozen=False, **kwargs):
        """Full-graph forward is not supported; always raises."""
        raise NotImplementedError('Please use --subsampling')

    def forward_subgraph(self, x, edge_index, batch, root_n_id, edge_weight=None, frozen=False, **kwargs):
        """Encode a batch of subgraphs into one embedding per graph.

        Parameters
        ----------
        x : tensor
            Node features, one row per node.
        edge_index : tensor
            Edge list; row 0 is used to count node degrees.
        batch : tensor
            Graph assignment of every node, forwarded to the backbone pooling.
        root_n_id : tensor or index
            Indices (into ``x``) of the root nodes, flagged in the indicator.
        edge_weight, frozen, **kwargs
            Accepted for interface compatibility; ignored here.

        Returns
        -------
        tensor
            The backbone's summed per-layer readout, L2-normalised when
            ``self.norm`` is set.
        """
        if self.degree_input:
            device = x.device
            num_nodes = x.shape[0]
            # Pass num_nodes explicitly: otherwise the degree vector is only
            # as long as the largest source index + 1 and cannot be
            # concatenated with x when trailing nodes have no outgoing edge.
            degrees = torch_geometric.utils.degree(
                edge_index[0], num_nodes=num_nodes
            ).long().to(device)
            ego_indicator = torch.zeros(num_nodes, dtype=torch.bool, device=device)
            ego_indicator[root_n_id] = True

            n_feat = torch.cat(
                (
                    x,
                    self.degree_embedding(degrees.clamp(0, self.max_degree)),
                    ego_indicator.unsqueeze(1).float(),
                ),
                dim=-1,
            )
        else:
            n_feat = x

        # The per-layer pooled outputs returned by the backbone are not needed.
        x, _ = self.gnn(n_feat, edge_index, batch)

        if self.norm:
            x = F.normalize(x, p=2, dim=-1, eps=1e-5)

        return x
169 |
170 |
class SELayer(nn.Module):
    """Squeeze-and-excitation gate.

    The input is averaged over dim 0, passed through a two-layer bottleneck
    ending in a sigmoid, and the resulting vector rescales the input.
    """

    def __init__(self, in_channels, se_channels):
        super().__init__()
        self.in_channels, self.se_channels = in_channels, se_channels

        squeeze = nn.Linear(in_channels, se_channels)
        excite = nn.Linear(se_channels, in_channels)
        self.encoder_decoder = nn.Sequential(squeeze, nn.ELU(), excite, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` multiplied by the gate computed from its mean over dim 0."""
        gate = self.encoder_decoder(x.mean(dim=0))
        return x * gate
195 |
196 |
class ApplyNodeFunc(nn.Module):
    """Node update: MLP, then a norm layer (BatchNorm1d or SELayer), then ReLU."""

    def __init__(self, mlp, use_selayer):
        super().__init__()
        self.mlp = mlp
        width = self.mlp.output_dim
        # The attribute is called ``bn`` in both cases.
        if use_selayer:
            self.bn = SELayer(width, int(np.sqrt(width)))
        else:
            self.bn = nn.BatchNorm1d(width)

    def forward(self, h):
        """Apply the MLP, the norm layer and ReLU to ``h`` in that order."""
        return F.relu(self.bn(self.mlp(h)))
214 |
215 |
class MLP(nn.Module):
    """Stack of linear layers whose last layer has no activation."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_selayer):
        """Build the layers.

        Parameters
        ----------
        num_layers: int
            The number of linear layers (must be at least 1)
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of every hidden layer
        output_dim: int
            The dimensionality of the output
        use_selayer: bool
            Use SELayer instead of BatchNorm1d after each hidden linear layer

        """
        super().__init__()
        self.linear_or_not = True  # flipped below for the multi-layer case
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        if num_layers == 1:
            # Single linear map, no norm layers.
            self.linear = nn.Linear(input_dim, output_dim)
            return

        self.linear_or_not = False
        # input -> hidden -> ... -> hidden -> output
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            [nn.Linear(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        )
        # One norm layer per hidden linear layer (none after the last one).
        self.batch_norms = torch.nn.ModuleList(
            [
                SELayer(hidden_dim, int(np.sqrt(hidden_dim)))
                if use_selayer
                else nn.BatchNorm1d(hidden_dim)
                for _ in range(num_layers - 1)
            ]
        )

    def forward(self, x):
        """Apply the single linear layer, or the hidden stack then the last linear."""
        if self.linear_or_not:
            return self.linear(x)

        h = x
        for idx in range(self.num_layers - 1):
            h = self.linears[idx](h)
            h = self.batch_norms[idx](h)
            h = F.relu(h)
        return self.linears[-1](h)
272 |
273 |
class UnsupervisedGIN(nn.Module):
    """GIN backbone that sums a linear readout of every layer's pooled output."""

    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
        use_selayer,
    ):
        """Configure the model.

        Parameters
        ----------
        num_layers: int
            Number of representation levels (the input plus ``num_layers - 1``
            GIN layers)
        num_mlp_layers: int
            The number of linear layers in each GIN layer's MLP
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The dimensionality of the readout output
        final_dropout: float
            Dropout ratio applied to each per-layer readout
        learn_eps: boolean
            Passed as the third positional argument of every ``GINConv``
        graph_pooling_type: str
            How to pool nodes of a graph ("sum", "mean" or "max")
        neighbor_pooling_type: str
            Accepted but not used by this implementation
        use_selayer: bool
            Use SELayer instead of BatchNorm1d as the norm layer

        """
        super().__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for depth in range(self.num_layers - 1):
            # Only the first GIN layer consumes the raw input width.
            in_width = input_dim if depth == 0 else hidden_dim
            mlp = MLP(num_mlp_layers, in_width, hidden_dim, hidden_dim, use_selayer)
            conv = GINConv(
                ApplyNodeFunc(mlp, use_selayer),
                0,
                self.learn_eps,
            )
            self.ginlayers.append(conv)

            if use_selayer:
                norm = SELayer(hidden_dim, int(np.sqrt(hidden_dim)))
            else:
                norm = nn.BatchNorm1d(hidden_dim)
            self.batch_norms.append(norm)

        # One linear readout per representation level (level 0 is the input).
        self.linears_prediction = torch.nn.ModuleList()
        for depth in range(num_layers):
            readout_in = input_dim if depth == 0 else hidden_dim
            self.linears_prediction.append(nn.Linear(readout_in, output_dim))

        self.drop = nn.Dropout(final_dropout)

        pooling_by_name = {
            "sum": global_add_pool,
            "mean": global_mean_pool,
            "max": global_max_pool,
        }
        if graph_pooling_type not in pooling_by_name:
            raise NotImplementedError
        self.pool = pooling_by_name[graph_pooling_type]

    def forward(self, x, edge_index, batch):
        """Return ``(summed readout, pooled outputs of the GIN layers)``.

        The second element excludes the pooled raw input.
        """
        # Representation at every level, starting with the raw input.
        hidden_rep = [x]
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = conv(hidden_rep[-1], edge_index)
            h = F.relu(norm(h))
            hidden_rep.append(h)

        score_over_layer = 0
        all_outputs = []
        # Pool each level over the graphs and accumulate its readout.
        for level, h in enumerate(hidden_rep):
            pooled_h = self.pool(h, batch)
            all_outputs.append(pooled_h)
            score_over_layer += self.drop(self.linears_prediction[level](pooled_h))

        return score_over_layer, all_outputs[1:]
389 |
390 |
--------------------------------------------------------------------------------
/models/gcc_graphcontrol.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_geometric.nn import GINConv
7 | from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
8 | from utils.register import register
9 | import copy
10 | from .gcc import GCC
11 |
@register.model_register
class GCC_GraphControl(nn.Module):
    """A GCC encoder plus a trainable copy joined through zero-initialised linear layers.

    Required keyword arguments (all forwarded to ``GCC``):
    ``positional_embedding_size``, ``node_hidden_dim`` and ``num_classes``.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        input_dim = kwargs['positional_embedding_size']
        hidden_size = kwargs['node_hidden_dim']
        output_dim = kwargs['num_classes']

        self.encoder = GCC(**kwargs)
        self.trainable_copy = copy.deepcopy(self.encoder)

        self.zero_conv1 = torch.nn.Linear(input_dim, input_dim)
        self.zero_conv2 = torch.nn.Linear(hidden_size, hidden_size)

        self.linear_classifier = torch.nn.Linear(hidden_size, output_dim)

        # Replace the parameters of both "zero" layers with all-zero tensors.
        with torch.no_grad():
            for layer, width in ((self.zero_conv1, input_dim), (self.zero_conv2, hidden_size)):
                layer.weight = torch.nn.Parameter(torch.zeros(width, width))
                layer.bias = torch.nn.Parameter(torch.zeros(width))

        self.prompt = torch.nn.Parameter(torch.normal(mean=0, std=0.01, size=(1, input_dim)))

    def forward(self, x, edge_index, edge_weight=None, frozen=False, **kwargs):
        """Full-graph forward is not supported; always raises."""
        raise NotImplementedError('Please use --subsampling')

    def reset_classifier(self):
        """Re-initialise the linear classification head."""
        self.linear_classifier.reset_parameters()

    def forward_subgraph(self, x, x_sim, edge_index, batch, root_n_id, edge_weight=None, frozen=False, **kwargs):
        """Classify subgraphs from the frozen encoder plus the control branch.

        ``frozen`` must be True; otherwise NotImplementedError is raised.
        ``edge_weight`` and extra keyword arguments are ignored.
        """
        if not frozen:
            raise NotImplementedError('Please freeze pre-trained models')

        # Frozen branch: evaluation mode and no gradient tracking.
        with torch.no_grad():
            self.encoder.eval()
            frozen_out = self.encoder.forward_subgraph(x, edge_index, batch, root_n_id)

        # Control branch: condition projected by zero_conv1 and added to x.
        control_in = self.zero_conv1(x_sim)
        control_in = control_in + x

        # for simplicity, we use edge_index to calculate degrees
        control_out = self.trainable_copy.forward_subgraph(control_in, edge_index, batch, root_n_id)
        control_out = self.zero_conv2(control_out)

        merged = control_out + frozen_out
        return self.linear_classifier(merged)
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.register import register
4 | from .encoder import MLP_Encoder
5 | from torch_geometric.nn import global_mean_pool
6 |
7 |
class Two_MLP_BN(torch.nn.Module):
    r"""
    Projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        hidden: input width
        mlp_hid: width of the hidden layer
        mlp_out: output width
    """
    def __init__(self, hidden, mlp_hid, mlp_out):
        super().__init__()
        stages = [
            nn.Linear(hidden, mlp_hid),
            nn.BatchNorm1d(mlp_hid),
            nn.ReLU(),
            nn.Linear(mlp_hid, mlp_out),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, feat):
        """Project ``feat`` through the head."""
        projected = self.proj(feat)
        return projected
27 |
class Two_MLP(nn.Module):
    r"""Predictor MLP with a single hidden layer and a PReLU activation.

    Args:
        input_size (int): Size of input features.
        output_size (int): Size of output features.
        hidden_size (int, optional): Size of hidden layer. (default: :obj:`512`).
    """
    def __init__(self, input_size, output_size, hidden_size=512):
        super().__init__()

        hidden_layer = nn.Linear(input_size, hidden_size, bias=True)
        activation = nn.PReLU(1)
        output_layer = nn.Linear(hidden_size, output_size, bias=True)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

        self.reset_parameters()

    def forward(self, x):
        """Run ``x`` through the network."""
        return self.net(x)

    def reset_parameters(self):
        """Call ``reset_parameters`` on every ``nn.Linear`` submodule (PReLU is left as is)."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            layer.reset_parameters()
54 |
55 |
56 |
57 |
@register.model_register
class MLP(torch.nn.Module):
    """MLP encoder with two classification heads.

    ``classifier`` (``hidden_size -> output_dim``) is used by :meth:`forward`;
    ``linear_classifier`` (``2 * hidden_size -> output_dim``) is used by
    :meth:`forward_subgraph`.

    Args:
        input_dim: width of the input node features.
        layer_num, hidden_size, activation, dropout, use_bn: forwarded to
            ``MLP_Encoder``.
        output_dim: number of outputs of both heads.
        **kargs: extra configuration entries, ignored.
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, output_dim=70, activation="relu", dropout=0.5, use_bn=False, **kargs):
        super(MLP, self).__init__()
        self.layer_num = layer_num
        self.hidden = hidden_size
        self.input_dim = input_dim

        self.encoder = MLP_Encoder(input_dim, layer_num, hidden_size, activation, dropout, use_bn)

        # Not used by either forward path below; kept so the module's
        # parameters and state_dict keys stay unchanged.
        self.eigen_val_emb = torch.nn.Sequential(torch.nn.Linear(32, hidden_size),
                                                 torch.nn.ReLU(),
                                                 torch.nn.Linear(hidden_size, hidden_size))

        self.classifier = torch.nn.Linear(hidden_size, output_dim)
        self.linear_classifier = torch.nn.Linear(hidden_size*2, output_dim)

    def forward(self, x, edge_index, edge_weight=None, frozen=False):
        """Encode the nodes and classify them with ``classifier``.

        When ``frozen`` is True the encoder is switched to eval mode and run
        without gradient tracking.
        """
        if frozen:
            with torch.no_grad():
                self.encoder.eval()
                x = self.encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)
        else:
            x = self.encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)

        x = self.classifier(x)
        return x

    def forward_subgraph(self, x, edge_index, batch, root_n_id, edge_weight=None, **kwargs):
        """Classify subgraphs from [root-node embedding, mean-pooled embedding]."""
        x = self.encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)
        x = torch.cat([x[root_n_id], global_mean_pool(x, batch)], dim=-1)

        x = self.linear_classifier(x)  # use linear classifier
        return x

    def reset_classifier(self):
        """Re-initialise both classification heads (Xavier-uniform weights, zero bias).

        ``linear_classifier`` is included because it is the head that
        ``forward_subgraph`` trains; previously only ``classifier`` was reset,
        so the subgraph head kept its learned weights across resets.
        """
        for head in (self.classifier, self.linear_classifier):
            torch.nn.init.xavier_uniform_(head.weight.data)
            torch.nn.init.constant_(head.bias.data, 0)
96 |
97 |
98 |
--------------------------------------------------------------------------------
/models/model_manager.py:
--------------------------------------------------------------------------------
1 | from utils.register import register
2 | import torch
3 | from .gcc import change_params_key
4 |
5 |
def load_model(input_dim: int, output_dim: int, config, checkpoint_path: str = 'checkpoint/gcc.pth'):
    """Build the model named by ``config.model``.

    For ``'GCC'`` and ``'GCC_GraphControl'`` the hyper-parameters (``opt``)
    and weights (``model``) are read from the checkpoint at
    ``checkpoint_path``. Any other registered model is constructed directly
    from ``input_dim``, ``output_dim`` and the attributes of ``config``.

    Args:
        input_dim: input feature width (only used for non-GCC models).
        output_dim: number of classes.
        config: parsed arguments; ``config.model`` selects the model class.
        checkpoint_path: location of the pre-trained GCC checkpoint.

    Returns:
        The constructed model, with pre-trained weights loaded for GCC models.
    """
    if config.model not in ('GCC', 'GCC_GraphControl'):
        return register.models[config.model](input_dim=input_dim, output_dim=output_dim, **vars(config))

    # NOTE(security): torch.load unpickles arbitrary objects (the checkpoint
    # stores an ``opt`` object next to the weights); only load checkpoints
    # from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    opt = state_dict['opt']
    model = register.models[config.model](
        positional_embedding_size=opt.positional_embedding_size,
        max_node_freq=opt.max_node_freq,
        max_edge_freq=opt.max_edge_freq,
        max_degree=opt.max_degree,
        freq_embedding_size=opt.freq_embedding_size,
        degree_embedding_size=opt.degree_embedding_size,
        output_dim=opt.hidden_size,
        node_hidden_dim=opt.hidden_size,
        edge_hidden_dim=opt.hidden_size,
        num_layers=opt.num_layer,
        num_step_set2set=opt.set2set_iter,
        num_layer_set2set=opt.set2set_lstm_layer,
        gnn_model=opt.model,
        norm=opt.norm,
        degree_input=True,
        num_classes=output_dim
    )
    params = state_dict['model']
    # Rename / drop checkpoint keys so they match this implementation.
    change_params_key(params)

    if config.model == 'GCC':
        model.load_state_dict(params)
    else:
        # GCC_GraphControl: the frozen encoder and its trainable copy both
        # start from the same pre-trained weights.
        model.encoder.load_state_dict(params)
        model.trainable_copy.load_state_dict(params)
    return model
--------------------------------------------------------------------------------
/models/pooler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_scatter
3 |
4 |
def subg_pooling(reps, data):
    """Select the representation of each subgraph's center node.

    Args:
        reps: node representations, one row per node of the batch.
        data: batch object providing ``batch`` (graph id of every node),
            ``center`` (per-graph index of the center node, relative to the
            first node of that graph) and ``y``.

    Returns:
        ``(reps[center nodes], data.y)``.

    Raises:
        IndexError: if ``data.center`` has fewer entries than there are graphs.
    """
    batch = data.batch
    batch_size = int(batch.max().item()) + 1

    # Number of nodes per graph, then the index of each graph's first node.
    graph_sizes = torch.bincount(batch, minlength=batch_size)
    offsets = torch.cumsum(graph_sizes, dim=0) - graph_sizes

    centers = torch.as_tensor(data.center, device=batch.device).reshape(-1)
    if centers.numel() < batch_size:
        raise IndexError('data.center has fewer entries than graphs in the batch')

    # Vectorised replacement for the per-graph Python loop.
    center_mask = torch.zeros_like(batch, dtype=torch.bool)
    center_mask[centers[:batch_size] + offsets] = True
    return reps[center_mask], data.y
--------------------------------------------------------------------------------
/node2vec.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is used for generating node embeddings for datasets with graph topology.
3 | '''
4 |
5 | import os.path as osp
6 | import sys
7 |
8 | from torch_geometric.datasets import Planetoid
9 | from torch_geometric.nn import Node2Vec
10 | from datasets import NodeDataset
11 | from utils.args import Arguments
12 | from utils.random import reset_random_seed
13 | import torch
14 | import os
15 | import numpy as np
16 | from tqdm import tqdm
17 |
# Root data directory; the trained embedding is saved beneath it.
PATH = './datasets/data'
19 |
20 |
def train():
    """Run one pass over ``loader`` and return the mean batch loss.

    Uses the module-level ``model``, ``loader``, ``optimizer`` and ``device``.
    """
    model.train()
    batch_losses = []
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        batch_loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)
31 |
@torch.no_grad()
def test():
    """Score the current embeddings with ``model.test`` on the train/test masks.

    Uses the module-level ``model`` and ``data``.
    """
    model.eval()
    embeddings = model()
    train_idx = data.train_mask
    test_idx = data.test_mask
    return model.test(
        embeddings[train_idx], data.y[train_idx],
        embeddings[test_idx], data.y[test_idx],
        max_iter=150,
    )
40 |
41 |
if __name__ == "__main__":
    # Script entry point: train node2vec on one dataset, save the embedding
    # table, then report the accuracy for every requested seed.
    config = Arguments().parse_args()

    dataset_obj = NodeDataset(dataset_name=config.dataset)
    dataset_obj.print_statistics()
    data = dataset_obj.data

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_workers = 0 if sys.platform.startswith('win') else 4

    # Keep the original masks; data.train_mask / data.test_mask are
    # overwritten with a single split below.
    train_masks = dataset_obj.data.train_mask
    test_masks = dataset_obj.data.test_mask

    model = Node2Vec(
        data.edge_index,
        embedding_dim=config.emb_dim,  # 256 for USA, 64 for Europe, 32 for Brazil
        walk_length=config.walk_length,
        context_size=config.context_size,
        walks_per_node=config.walk_per_nodes,
        sparse=True,
    ).to(device)

    loader = model.loader(batch_size=config.batch_size, shuffle=True,
                          num_workers=num_workers)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=config.lr)

    if dataset_obj.random_split:
        # Monitor training on the first split.
        dataset_obj.data.train_mask = train_masks[:, 0]
        dataset_obj.data.test_mask = test_masks[:, 0]

    progress = tqdm(range(0, config.epochs))
    for epoch in progress:
        loss = train()
        acc = test()
        progress.set_postfix_str(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}')

    # save embedding
    torch.save(model.embedding.weight.cpu(), f=f'{PATH}/{config.dataset}/processed/node2vec.pt')

    acc_list = []

    for seed in config.seeds:
        reset_random_seed(seed)
        # Select a split column only for randomly split datasets, mirroring
        # the guard used before training. NOTE(review): presumably the masks
        # have no split dimension otherwise - confirm against NodeDataset.
        if dataset_obj.random_split:
            dataset_obj.data.train_mask = train_masks[:, seed]
            dataset_obj.data.test_mask = test_masks[:, seed]
        acc = test()
        acc_list.append(acc)

    final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
    print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
93 |
--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
# Optimizer classes selectable by name.
optimizers_dicts = dict(
    sgd=torch.optim.SGD,
    adam=torch.optim.Adam,
    adamw=torch.optim.AdamW,
    radam=torch.optim.RAdam,
    nadam=torch.optim.NAdam,
)


def create_optimizer(**kwargs):
    """Instantiate the optimizer named by ``name``.

    Required keyword arguments: ``lr``, ``weight_decay``, ``name`` (a key of
    ``optimizers_dicts``) and ``parameters``. A missing argument or an
    unknown name raises ``KeyError``.
    """
    # Look the arguments up in a fixed order so the first missing one is
    # the one reported.
    required = ('lr', 'weight_decay', 'name', 'parameters')
    lr, weight_decay, name, parameters = [kwargs[field] for field in required]

    optimizer_cls = optimizers_dicts[name]
    return optimizer_cls(parameters, lr=lr, weight_decay=weight_decay)
19 |
--------------------------------------------------------------------------------
/png/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wykk00/GraphControl/f437010a4c09f01baf1278e747c6951bdf3d9d17/png/framework.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wykk00/GraphControl/f437010a4c09f01baf1278e747c6951bdf3d9d17/utils/__init__.py
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
class Arguments:
    """Command-line options shared by the training and node2vec scripts."""

    def __init__(self) -> None:
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument("--seeds", type=int, nargs="+", default=[0])
        # Dataset
        self.parser.add_argument('--dataset', type=str, help="dataset name", default='Cora_ML')

        # Model configuration
        self.parser.add_argument('--layer_num', type=int, help="the number of encoder's layers", default=2)
        self.parser.add_argument('--hidden_size', type=int, help="the hidden size", default=128)
        self.parser.add_argument('--dropout', type=float, help="dropout rate", default=0.0)
        self.parser.add_argument('--activation', type=str, help="activation function", default='relu',
                                 choices=['relu', 'elu', 'hardtanh', 'leakyrelu', 'prelu', 'rrelu'])
        self.parser.add_argument('--use_bn', action='store_true', help="use BN or not")
        self.parser.add_argument('--model', type=str, help="model name", default='GCC_GraphControl',
                                 choices=['GCC', 'GCC_GraphControl'])

        # Training settings
        self.parser.add_argument('--optimizer', type=str, help="the kind of optimizer", default='adam',
                                 choices=['adam', 'sgd', 'adamw', 'nadam', 'radam'])
        self.parser.add_argument('--lr', type=float, help="learning rate", default=1e-3)
        self.parser.add_argument('--weight_decay', type=float, help="weight decay", default=5e-4)
        self.parser.add_argument('--epochs', type=int, help="training epochs", default=200)
        self.parser.add_argument('--batch_size', type=int, default=128)
        self.parser.add_argument('--finetune', action='store_true', help="Quickly find optim parameters")

        # Processing node attributes
        self.parser.add_argument('--use_adj', action='store_true', help="use eigen-vectors of adjacent matrix as node attributes")
        self.parser.add_argument('--threshold', type=float, help="the threshold for discreting similarity matrix", default=0.15)
        self.parser.add_argument('--num_dim', type=int, help="the number of replaced node attributes", default=32)
        # self.parser.add_argument('--ad_aug', action='store_true', help="adversarial augmentation")
        self.parser.add_argument('--restart', type=float, help="the restart ratio of random walking", default=0.3)
        self.parser.add_argument('--walk_steps', type=int, help="the number of random walk's steps", default=256)

        # Node2vec config
        self.parser.add_argument('--emb_dim', type=int, default=128, help="Embedding dim for node2vec")
        self.parser.add_argument('--walk_length', type=int, default=50, help="Walk length for node2vec")
        self.parser.add_argument('--context_size', type=int, default=10, help="Context size for node2vec")
        self.parser.add_argument('--walk_per_nodes', type=int, default=10, help="Walk per nodes for node2vec")

    def parse_args(self, args=None):
        """Parse and return the options.

        Args:
            args: optional list of argument strings. When ``None`` (the
                default) argparse reads ``sys.argv[1:]``, as before; passing
                a list allows use from tests or notebooks.

        Returns:
            argparse.Namespace with one attribute per option.
        """
        return self.parser.parse_args(args)
45 |
--------------------------------------------------------------------------------
/utils/augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def drop_feature(x, drop_prob):
    """Return a copy of ``x`` in which randomly chosen feature columns are zeroed.

    One uniform sample in [0, 1) is drawn per column; columns whose sample
    is below ``drop_prob`` are zeroed. The input is not modified.
    """
    num_features = x.size(1)
    noise = torch.empty(
        (num_features, ),
        dtype=torch.float32,
        device=x.device,
    ).uniform_(0, 1)

    masked = x.clone()
    masked[:, noise < drop_prob] = 0
    return masked
13 |
def adversarial_aug_train(model, node_attack, perturb_shape, step_size, m, device):
    """Refine a random perturbation with ``m - 1`` sign-gradient steps.

    Args:
        model: module switched to training mode.
        node_attack: callable mapping the perturbation tensor to a loss.
        perturb_shape: shape of the perturbation tensor.
        step_size: half-width of the uniform init and size of each step.
        m: number of loss evaluations; every loss is divided by ``m``.
        device: device the perturbation is moved to.

    Returns:
        The (scaled) loss of the last evaluation; it has not been
        back-propagated.
    """
    model.train()

    def scaled_loss(delta):
        # Evaluate the attack and scale it in place by 1/m.
        value = node_attack(delta)
        value /= m
        return value

    perturb = torch.FloatTensor(*perturb_shape).uniform_(-step_size, step_size).to(device)
    perturb.requires_grad_()

    loss = scaled_loss(perturb)
    for _ in range(m - 1):
        loss.backward()
        # Move along the sign of the gradient, then clear it for the next step.
        ascent = step_size * torch.sign(perturb.grad.detach())
        stepped = perturb.detach() + ascent
        perturb.data = stepped.data
        perturb.grad[:] = 0

        loss = scaled_loss(perturb)

    return loss
--------------------------------------------------------------------------------
/utils/normalize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def get_laplacian_matrix(adj):
    '''
    Calculating laplacian matrix ``I - D^{-1/2} A D^{-1/2}``.

    Args:
        adj: square adjacent matrix or discrete similarity matrix.

    Returns:
        normalized laplacian matrix.
    '''
    EPS = 1e-6
    num_nodes = adj.shape[0]
    identity = torch.eye(num_nodes, device=adj.device)

    # check and remove self-loop: the identity is subtracted only when the
    # diagonal sums to at least the number of nodes.
    if torch.diag(adj).sum().item() + EPS >= num_nodes:
        tmp = adj - identity
    else:
        tmp = adj

    # Degrees are clipped to at least 1 so isolated nodes do not divide by zero.
    deg_rsqrt = torch.rsqrt(tmp.sum(dim=1).clip(1))

    # Scale row i and column j by d_i^{-1/2} and d_j^{-1/2} via broadcasting;
    # same result as multiplying by diagonal matrices on both sides, but
    # O(n^2) instead of two dense O(n^3) products.
    normalized = deg_rsqrt.unsqueeze(1) * tmp * deg_rsqrt.unsqueeze(0)
    return identity - normalized
28 |
def similarity(z1: torch.Tensor, z2: torch.Tensor):
    """Return the matrix of dot products between the normalised rows of ``z1`` and ``z2``."""
    return torch.mm(F.normalize(z1), F.normalize(z2).t())
--------------------------------------------------------------------------------
/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
def reset_random_seed(seed):
    r"""
    Initial process for fixing all possible random seed.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), switches
    cuDNN to deterministic mode and turns gradient tracking on.

    Args:
        seed (int): the seed value.
    """
    # Fix Random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Default state is a training state.
    # ``torch.enable_grad()`` only takes effect when used as a context
    # manager or decorator, so a bare call left the grad mode unchanged;
    # ``set_grad_enabled`` switches it immediately.
    torch.set_grad_enabled(True)
--------------------------------------------------------------------------------
/utils/register.py:
--------------------------------------------------------------------------------
1 | r"""A kernel module that contains a global register for unified model, dataset, and pre-training algorithms access.
2 | """
3 |
class Register(object):
    r"""
    Global register: one name -> class dictionary per kind of component.
    """

    def __init__(self):
        self.pipelines = {}
        self.launchers = {}
        self.models = {}
        self.datasets = {}
        self.dataloader = {}
        self.ood_algs = {}
        self.encoders = {}

    @staticmethod
    def _store(table, cls):
        r"""Record ``cls`` in ``table`` under its class name and return ``cls``."""
        table[cls.__name__] = cls
        return cls

    def pipeline_register(self, pipeline_class):
        r"""
        Register a pipeline class under its name.

        Args:
            pipeline_class (class): pipeline class

        Returns (class):
            pipeline class

        """
        return self._store(self.pipelines, pipeline_class)

    def launcher_register(self, launcher_class):
        r"""
        Register a launcher class under its name.

        Args:
            launcher_class (class): launcher class

        Returns (class):
            launcher class

        """
        return self._store(self.launchers, launcher_class)

    def model_register(self, model_class):
        r"""
        Register a model class under its name.

        Args:
            model_class (class): model class

        Returns (class):
            model class

        """
        return self._store(self.models, model_class)

    def encoder_register(self, encoder_class):
        r"""
        Register an encoder class under its name.

        Args:
            encoder_class (class): encoder class

        Returns (class):
            encoder class

        """
        return self._store(self.encoders, encoder_class)

    def dataset_register(self, dataset_class):
        r"""
        Register a dataset class under its name.

        Args:
            dataset_class (class): dataset class

        Returns (class):
            dataset class

        """
        return self._store(self.datasets, dataset_class)

    def dataloader_register(self, dataloader_class):
        r"""
        Register a dataloader class under its name.

        Args:
            dataloader_class (class): dataloader class

        Returns (class):
            dataloader class

        """
        return self._store(self.dataloader, dataloader_class)


register = Register()  #: The register object used for accessing models, datasets and pre-training algorithms.
104 |
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 | from torch_geometric.utils import subgraph, to_undirected, remove_isolated_nodes, dropout_adj, remove_self_loops, k_hop_subgraph
4 | from torch_geometric.utils.num_nodes import maybe_num_nodes
5 | import copy
6 | from torch_sparse import SparseTensor
7 |
8 | from .transforms import obtain_attributes
9 |
10 |
def add_remaining_selfloop_for_isolated_nodes(edge_index, num_nodes):
    """Append a self-loop for every node that appears in no edge.

    Args:
        edge_index: edge list with the two endpoint rows ``edge_index[0]``
            and ``edge_index[1]``.
        num_nodes: number of nodes; the larger of this value and the node
            count inferred from ``edge_index`` is used.

    Returns:
        ``edge_index`` with the added self-loops concatenated along dim 1.
    """
    num_nodes = max(maybe_num_nodes(edge_index), num_nodes)
    device = edge_index.device
    # only add self-loop on isolated nodes
    # Build the mask on the same device as the indices: a CPU mask cannot
    # be indexed with CUDA indices.
    isolated = torch.ones(num_nodes, dtype=torch.bool, device=device)
    isolated[torch.cat([edge_index[0], edge_index[1]])] = False

    loop_nodes = torch.arange(0, num_nodes, dtype=torch.long, device=device)[isolated]
    loops_for_isolated_nodes = loop_nodes.unsqueeze(0).repeat(2, 1)
    return torch.cat([edge_index, loops_for_isolated_nodes], dim=1)
23 |
24 |
class RWR:
    """Random-walk-with-restart subgraph sampler.

    A stochastic augmentation that turns one graph into many subgraphs: up to
    ``graph_num`` start nodes are drawn, a random walk with restart is run from
    each of them, and the visited nodes/edges form one subgraph per start node.
    Subgraphs sharing the same center node are positive pairs, all others are
    negative pairs.

    Args:
        walk_steps: number of walk steps per start node (must be > 1).
        graph_num: maximum number of start nodes (capped by the node count).
        restart_ratio: per-step probability of jumping back to the start node.
        inductive: if True, only edges among ``train_mask`` nodes are kept and
            nodes left isolated by that are dropped before sampling.
        aligned: if True, a single walk is sampled and the second view is a
            deep copy of the first; otherwise two independent walks are sampled.
        **args: accepted and ignored.
    """

    def __init__(self, walk_steps=50, graph_num=128, restart_ratio=0.5, inductive=False, aligned=False, **args):
        self.walk_steps = walk_steps
        self.graph_num = graph_num
        self.restart_ratio = restart_ratio
        self.inductive = inductive
        self.aligned = aligned

    def __call__(self, graph):
        """Sample subgraph views from ``graph``.

        Returns:
            ``(view1_list, view2_list)``; entry ``i`` of both lists is centered
            on the same start node.
        """
        graph = copy.deepcopy(graph)  # every modification below happens on this copy
        assert self.walk_steps > 1
        if self.inductive:
            train_node_idx = torch.where(graph.train_mask == True)[0]
            # drop edges touching val/test nodes, which leaves those nodes isolated
            graph.edge_index, _ = subgraph(train_node_idx, graph.edge_index)
            # remove all isolated nodes and re-index the remaining ones
            edge_index, _, mask = remove_isolated_nodes(graph.edge_index, num_nodes=graph.x.shape[0])
            graph.x = graph.x[mask]
            if getattr(graph, 'y', None) is not None:
                # labels are looked up by (re-indexed) node id in adjust_idx
                graph.y = graph.y[mask]
            # keep the re-indexed edges so node ids match the filtered x / y
            graph.edge_index = edge_index
        edge_index = to_undirected(graph.edge_index)
        edge_index = add_remaining_selfloop_for_isolated_nodes(edge_index, graph.x.shape[0])
        graph.edge_index = edge_index

        node_num = graph.x.shape[0]
        graph_num = min(self.graph_num, node_num)
        start_nodes = torch.randperm(node_num)[:graph_num]

        value = torch.arange(edge_index.size(1))
        self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                  value=value,
                                  sparse_sizes=(node_num, node_num)).t()

        view1_list = []
        view2_list = []

        views_cnt = 1 if self.aligned else 2
        for view_idx in range(views_cnt):
            history, signs = self._random_walks(start_nodes)

            for i in range(graph_num):
                view = self._build_view(history[i], signs[i], graph)
                if self.aligned:
                    view1_list.append(view)
                    view2_list.append(copy.deepcopy(view))
                elif view_idx == 0:
                    view1_list.append(view)
                else:
                    view2_list.append(view)
        return (view1_list, view2_list)

    def _random_walks(self, start_nodes):
        """Run one restart walk per start node; returns (paths, restart flags), one row per walk."""
        graph_num = start_nodes.shape[0]
        current_nodes = start_nodes.clone()
        visited = [start_nodes.clone()]
        restarted = [torch.ones(graph_num, dtype=torch.bool)]
        for _ in range(self.walk_steps):
            seed = torch.rand([graph_num])
            # squeeze only the neighbor dimension: a bare squeeze() would also
            # drop the walk dimension when there is a single start node
            nei = self.adj_t.sample(1, current_nodes).squeeze(-1)
            sign = seed < self.restart_ratio
            nei[sign] = start_nodes[sign]
            visited.append(nei)
            restarted.append(sign)
            current_nodes = nei
        return torch.stack(visited, dim=1), torch.stack(restarted, dim=1)

    def _build_view(self, path, sign, graph):
        """Turn one walk (node path + restart flags) into a re-indexed subgraph view."""
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).type_as(graph.edge_index)
        # a step that restarted is a jump back to the center, not a real edge
        sub_edges = sub_edges[:, ~sign[1:]]
        # undirectional
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        return self.adjust_idx(sub_edges, node_idx, graph, path[0].item())

    def adjust_idx(self, edge_index, node_idx, full_g, center_idx):
        '''Re-index the nodes and the edge index of one subgraph.

        In a subgraph some nodes are dropped, so the node ids inside
        edge_index are remapped to positions within ``node_idx``. The node
        attributes of the view are produced by ``obtain_attributes`` from the
        subgraph structure (``use_adj=True``), not copied from ``full_g.x``.

        Args:
            edge_index: ``(2, E)`` edges using node ids of ``full_g``.
            node_idx: unique node ids of ``full_g`` contained in the subgraph.
            full_g: graph the subgraph was sampled from (``y`` and the dtype
                of ``edge_index`` are read from it).
            center_idx: id of the walk's start node in ``full_g``.

        Returns:
            A ``Data`` view with re-indexed ``edge_index``, attributes ``x``,
            the center position (``center`` / ``root_n_index``), the original
            node ids (``original_idx``) and the center node's label ``y``.
        '''
        node_idx_map = {j: i for i, j in enumerate(node_idx.numpy().tolist())}
        sources_idx = list(map(node_idx_map.get, edge_index[0].numpy().tolist()))
        target_idx = list(map(node_idx_map.get, edge_index[1].numpy().tolist()))

        edge_index = torch.IntTensor([sources_idx, target_idx]).type_as(full_g.edge_index)
        # record the real node count on the structure-only graph so it does not
        # have to be inferred from edge_index
        structure = Data(edge_index=edge_index, num_nodes=node_idx.numel())
        x = obtain_attributes(structure, use_adj=True)
        center = node_idx_map[center_idx]
        x_view = Data(edge_index=edge_index, x=x, center=center, original_idx=node_idx, y=full_g.y[center_idx], root_n_index=center)
        return x_view
119 |
120 |
def collect_subgraphs(selected_id, graph, walk_steps=20, restart_ratio=0.5):
    """Sample one random-walk-with-restart subgraph per selected node.

    Args:
        selected_id: 1-D tensor of node ids used as walk start (center) nodes.
        graph: graph providing ``edge_index``, ``x`` and ``y``.
        walk_steps: number of walk steps per start node.
        restart_ratio: per-step probability of jumping back to the start node.

    Returns:
        A list with one subgraph (as built by ``adjust_idx``) per entry of
        ``selected_id``, in the same order.
    """
    graph = copy.deepcopy(graph)  # work on a private copy of the input graph
    edge_index = graph.edge_index
    node_num = graph.x.shape[0]
    start_nodes = selected_id  # only sampling selected nodes as subgraphs
    graph_num = start_nodes.shape[0]

    value = torch.arange(edge_index.size(1))
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         value=value,
                         sparse_sizes=(node_num, node_num)).t()

    current_nodes = start_nodes.clone()
    visited = [start_nodes.clone()]
    restarted = [torch.ones(graph_num, dtype=torch.bool)]
    for _ in range(walk_steps):
        seed = torch.rand([graph_num])
        # squeeze only the neighbor dimension: a bare squeeze() would also drop
        # the walk dimension when a single node is selected
        nei = adj_t.sample(1, current_nodes).squeeze(-1)
        sign = seed < restart_ratio
        nei[sign] = start_nodes[sign]
        visited.append(nei)
        restarted.append(sign)
        current_nodes = nei
    # one row per walk: visited node ids and whether each step was a restart
    history = torch.stack(visited, dim=1)
    signs = torch.stack(restarted, dim=1)

    graph_list = []
    for i in range(graph_num):
        path = history[i]
        sign = signs[i]
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).type_as(graph.edge_index)
        # a step that restarted is a jump back to the center, not a real edge
        sub_edges = sub_edges[:, ~sign[1:]]
        # undirectional
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        view = adjust_idx(sub_edges, node_idx, graph, path[0].item())

        graph_list.append(view)
    return graph_list
163 |
def adjust_idx(edge_index, node_idx, full_g, center_idx):
    '''Build a subgraph whose nodes are renumbered 0..len(node_idx)-1.

    Only a subset of the full graph's nodes survives in a subgraph, so the
    ids stored in edge_index are translated to each node's position inside
    node_idx. Node features are gathered from full_g.x, and the label is the
    one of the center node.
    '''
    new_id_of = {}
    for new_id, old_id in enumerate(node_idx.numpy().tolist()):
        new_id_of[old_id] = new_id

    remapped_src = [new_id_of.get(old) for old in edge_index[0].numpy().tolist()]
    remapped_dst = [new_id_of.get(old) for old in edge_index[1].numpy().tolist()]
    local_edges = torch.IntTensor([remapped_src, remapped_dst]).type_as(full_g.edge_index)

    return Data(
        edge_index=local_edges,
        x=full_g.x[node_idx],
        center=new_id_of[center_idx],
        original_idx=node_idx,
        y=full_g.y[center_idx],
        root_n_index=new_id_of[center_idx],
    )
177 |
def ego_graphs_sampler(node_idx, data, hop=2):
    """Extract the ``hop``-hop ego graph around every node in ``node_idx``.

    Args:
        node_idx: iterable of target node ids.
        data: graph providing ``edge_index``, ``x``, ``y`` and ``num_nodes``.
        hop: number of hops of the ego graphs.

    Returns:
        A list with one ``Data`` per target node, holding the relabeled
        subgraph, the features of its nodes, the target's label ``y``, the
        target's position inside the subgraph (``root_n_index``) and its id in
        ``data`` (``original_idx``).
    """
    # Pass the real node count: without it the count is inferred from
    # edge_index, and a target whose id exceeds the largest edge endpoint
    # (e.g. an isolated node) would be out of range.
    num_nodes = data.num_nodes
    ego_graphs = []
    for idx in node_idx:
        subset, sub_edge_index, mapping, _ = k_hop_subgraph(
            [idx], hop, data.edge_index, relabel_nodes=True, num_nodes=num_nodes)
        sub_x = data.x[subset]
        # note: root_n_index records the position of the target node, because
        # `PyG` increments attributes by the number of nodes whenever their
        # attribute names contain the substring :obj:`index`
        g = Data(x=sub_x, edge_index=sub_edge_index, root_n_index=mapping, y=data.y[idx], original_idx=idx)
        ego_graphs.append(g)
    return ego_graphs
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | from torch_geometric.utils import to_undirected, remove_self_loops, to_dense_adj
3 | import torch.nn.functional as F
4 | import torch
5 | import scipy
6 |
7 | from .normalize import similarity, get_laplacian_matrix
8 |
def obtain_attributes(data, use_adj=False, threshold=0.1, num_dim=32):
    """Compute structural node attributes from Laplacian eigenvectors.

    Args:
        data: a single graph; ``edge_index`` is read when ``use_adj=True``,
            ``x`` otherwise.
        use_adj: build the Laplacian from the (undirected, self-loop free)
            adjacency matrix instead of the thresholded similarity matrix.
        threshold: cut-off used to binarize the similarity matrix
            (only used when ``use_adj=False``).
        num_dim: maximum number of eigenvectors kept per node.

    Returns:
        A float32 CPU tensor with one L2-normalized row per node and at most
        ``num_dim`` columns (fewer when the graph has fewer nodes).
    """
    save_node_border = 30000  # above this size the decomposition is delegated to SciPy

    if use_adj:
        # to undirected and remove self-loop
        edges = to_undirected(data.edge_index)
        edges, _ = remove_self_loops(edges)
        # Size the matrix by the graph's node count when it is known; otherwise
        # nodes after the largest remaining edge endpoint would lose their row.
        tmp = to_dense_adj(edges, max_num_nodes=getattr(data, 'num_nodes', None))[0]
    else:
        tmp = similarity(data.x, data.x)
        # discretize the similarity matrix by threshold
        tmp = torch.where(tmp > threshold, 1.0, 0.0)

    tmp = get_laplacian_matrix(tmp)
    if tmp.shape[0] > save_node_border:
        # a bare `import scipy` does not guarantee the linalg submodule is loaded
        from scipy.linalg import eigh as scipy_eigh
        _, V = scipy_eigh(tmp.cpu().numpy())
        V = torch.from_numpy(V)
    else:
        _, V = torch.linalg.eigh(tmp)  # much faster than torch.linalg.eig

    x = V[:, :num_dim].float()
    import sklearn.preprocessing as preprocessing
    x = preprocessing.normalize(x.cpu(), norm="l2")
    x = torch.tensor(x, dtype=torch.float32)

    return x
37 |
38 |
def process_attributes(data, use_adj=False, threshold=0.1, num_dim=32, soft=False, kernel=False):
    '''
    Replace the node attributes with positional encoding. Warning: this function will replace the node attributes!

    Args:
        data: a single graph contains x (if use_adj=False) and edge_index.
        use_adj: use the eigen-vectors of adjacent matrix or similarity matrix as node attributes.
        threshold: only work when use_adj=False, used for discretize the similarity matrix. 1 if Adj(i,j)>threshold else 0
        num_dim: number of attribute columns; graphs with fewer nodes are zero-padded up to it (except when soft=True).
        soft: only work when use_adj=False, if soft=True, we will use soft similarity matrix.
        kernel: only work when use_adj=False, if kernel=True, an rbf kernel of the node features replaces the similarity matrix.

    Returns:
        modified data: data.x is replaced; unless soft=True, data.eigen_val
        additionally holds the smallest eigenvalues with shape (1, num_dim).
    '''

    if use_adj:
        # to undirected and remove self-loop
        edges = to_undirected(data.edge_index)
        if edges.size(1) > 1:
            edges, _ = remove_self_loops(edges)
        else:
            edges = torch.tensor([[0], [0]])  # for isolated nodes
        Adj = to_dense_adj(edges)[0]
    else:
        if kernel:
            # |X-Y|^2 = |X|^2 - 2XY + |Y|^2; the squared norms are broadcast
            # instead of being materialized as two extra n x n matrices
            XY = data.x @ data.x.T
            sq_norm = torch.diag(XY)
            Adj = sq_norm.unsqueeze(1) - 2 * XY + sq_norm.unsqueeze(0)
            Adj = torch.exp(-0.05 * Adj)  # rbf kernel
        else:
            Adj = similarity(data.x, data.x)  # equal to linear kernel
        if soft:
            # eigenvectors are taken from the un-thresholded matrix itself
            _, V = torch.linalg.eigh(Adj)
            x = V[:, :num_dim].float()
            data.x = F.normalize(x, dim=1)
            return data
        # discretize the similarity matrix by threshold
        Adj = torch.where(Adj > threshold, 1.0, 0.0)
    Lap = get_laplacian_matrix(Adj)

    L, V = torch.linalg.eigh(Lap)  # much faster than torch.linalg.eig, if this line triggers bugs please refer to https://github.com/pytorch/pytorch/issues/70122#issuecomment-1232766638
    L_sort, _ = torch.sort(L, descending=False)

    import sklearn.preprocessing as preprocessing
    num_nodes = V.shape[0]
    if num_nodes < num_dim:
        # fewer nodes than requested dimensions: zero-pad on the right
        pad = num_dim - num_nodes
        x = torch.tensor(preprocessing.normalize(V, norm="l2"), dtype=torch.float32)
        data.x = F.pad(x, (0, pad))
        data.eigen_val = F.pad(L_sort, (0, pad)).unsqueeze(0)
    else:
        x = preprocessing.normalize(V[:, 0:num_dim].float(), norm="l2")
        data.x = torch.tensor(x, dtype=torch.float32)
        data.eigen_val = L_sort[:num_dim].unsqueeze(0)

    return data
--------------------------------------------------------------------------------