├── .gitignore ├── README.md ├── datasets ├── __init__.py └── dataset │ ├── Hindex.py │ └── __init__.py ├── gcc.py ├── graphcontrol.py ├── models ├── __init__.py ├── encoder.py ├── gcc.py ├── gcc_graphcontrol.py ├── mlp.py ├── model_manager.py └── pooler.py ├── node2vec.py ├── optimizers └── __init__.py ├── png └── framework.png └── utils ├── __init__.py ├── args.py ├── augmentation.py ├── normalize.py ├── random.py ├── register.py ├── sampling.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/data 2 | checkpoint/ 3 | 4 | # all pyc files_ 5 | **/__pycache__ 6 | 7 | **/.vscode 8 | **/ipynb_checkpoints 9 | *.ipynb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphControl: Adding Conditional Control to Universal Graph Pre-trained Models for Graph Domain Transfer Learning 2 | 3 | **Official implementation of paper**
[GraphControl: Adding Conditional Control to Universal Graph Pre-trained Models for Graph Domain Transfer Learning](https://arxiv.org/abs/2310.07365)
4 | 5 | Yun Zhu*, Yaoke Wang*, Haizhou Shi, Zhenshuo Zhang, Dian Jiao, Siliang Tang† 6 | 7 | In WWW 2024 8 | 9 | ## Overview 10 | This is the first work to solve the "transferability-specificity dilemma" in graph domain transfer learning. To address this challenge, we introduce an innovative deployment module coined as GraphControl, motivated by ControlNet, to realize better graph domain transfer learning. The overview of our method is depicted as: 11 | 12 | ![](./png/framework.png) 13 | 14 | 15 | ## Setup 16 | 17 | ```bash 18 | conda create -n GraphControl python==3.9 19 | conda activate GraphControl 20 | conda install pytorch==2.1.0 torchaudio==2.1.0 cudatoolkit=12.1 -c pytorch -c conda-forge 21 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu121.html 22 | ``` 23 | 24 | ## Download GCC Pretrained Weight 25 | 26 | **Download GCC checkpoints** 27 | Download GCC checkpoint from https://drive.google.com/file/d/1lYW_idy9PwSdPEC7j9IH5I5Hc7Qv-22-/view and save it into `./checkpoint/gcc.pth`. 
28 | 29 | ## For Attributed Graphs 30 | 31 | **Only GCC** 32 | 33 | ```bash 34 | CUDA_VISIBLE_DEVICES=0 python gcc.py --lr 1e-3 --epochs 100 --dataset Cora_ML --model GCC --use_adj --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 35 | ``` 36 | 37 | **GCC with GraphControl** 38 | 39 | ```bash 40 | CUDA_VISIBLE_DEVICES=0 python graphcontrol.py --dataset Cora_ML --epochs 100 --lr 0.5 --optimizer adamw --weight_decay 5e-4 --threshold 0.17 --walk_steps 256 --restart 0.8 --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 41 | ``` 42 | 43 | ## For Non-attributed Graphs 44 | 45 | **For non-attributed graphs, we first need to generate node attributes with node2vec** 46 | 47 | ```bash 48 | CUDA_VISIBLE_DEVICES=0 python node2vec.py --dataset Hindex --lr 1e-2 --epochs 100 49 | ``` 50 | 51 | **Then, we can train it in the same way as attributed graphs** 52 | 53 | ```bash 54 | CUDA_VISIBLE_DEVICES=0 python graphcontrol.py --dataset Hindex --epochs 100 --lr 0.1 --optimizer sgd --weight_decay 5e-4 --threshold 0.17 --walk_steps 256 --restart 0.5 --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 55 | ``` 56 | 57 | ## Illustration of arguments 58 | ``` 59 | --dataset: default Cora_ML, [Cora_ML, Photo, Physics, DBLP, usa, brazil, europe, Hindex] can also be chosen 60 | --model: default GCC_GraphControl, [GCC, GCC_GraphControl] can also be chosen. GCC refers to utilizing GCC as a pre-trained model and fine-tuning it on target data. On the other hand, GCC_GraphControl involves incorporating GraphControl with GCC to address the "transferability-specificity dilemma." Additional pre-trained models will be introduced in the updated version. 
class NodeDataset:
    """Wrapper around a single-graph node-classification dataset.

    On construction the dataset registered under ``dataset_name`` in
    ``dataset_dict`` is loaded, its edges are made undirected, the original
    features/edges are backed up, and -- when the graph carries no
    ``train_mask`` -- one random 10%/90% train/test split is generated per
    entry of ``n_seeds``.

    Attributes set here:
        dataset, data: the loaded dataset and its only graph.
        num_classes, num_node_features, num_nodes: basic statistics.
        backup_x, backup_edges: deep copies of the original features/edges.
        random_split: True when the splits were generated here; in that case
            ``data.train_mask`` / ``data.test_mask`` have shape
            ``(num_nodes, len(n_seeds))`` with column ``i`` belonging to
            ``n_seeds[i]``.
    """

    # Datasets whose features may be replaced by precomputed Node2Vec
    # embeddings. Compared case-insensitively because the registry keys are
    # lowercase ('usa', 'brazil', 'europe') while 'Hindex' is capitalised.
    _NODE2VEC_DATASETS = frozenset({'usa', 'europe', 'brazil', 'hindex'})

    def __init__(self, dataset_name, trans=None, n_seeds=(0,)) -> None:
        """Load ``dataset_name`` and prepare features and splits.

        Args:
            dataset_name: key of ``dataset_dict``.
            trans: optional transform forwarded to the dataset constructor.
            n_seeds: seeds used to draw the random splits (one split each).
        """
        self.path = PATH
        self.dataset_name = dataset_name
        if dataset_name in ['Hindex']:
            # HindexDataset takes its own root directory and no ``name``.
            self.dataset = dataset_dict[dataset_name](root=f'{self.path}/{dataset_name}', transform=trans)
        else:
            self.dataset = dataset_dict[dataset_name](root=f'{self.path}', name=dataset_name, transform=trans)

        self.num_classes = self.dataset.num_classes
        self.num_node_features = self.dataset.num_node_features

        assert len(self.dataset) == 1, "Training data consists of multiple graphs!"

        self.data = self.dataset[0]

        # parse it into undirected graph
        self.data.edge_index = to_undirected(self.data.edge_index)
        self.num_nodes = self.data.x.shape[0]

        # backup original node attributes and edges
        self.backup_x = copy.deepcopy(self.data.x)
        self.backup_edges = copy.deepcopy(self.data.edge_index)
        self.random_split = False

        # For datasets without node attributes, use node embeddings from
        # Node2Vec as their node attributes (only if the file exists).
        attr_path = f'{PATH}/{dataset_name}/processed/node2vec.pt'
        if dataset_name.lower() in self._NODE2VEC_DATASETS and os.path.exists(attr_path):
            x = torch.load(attr_path)
            self.data.x = x.detach()
            # Keep the advertised feature dimension in sync with the new x.
            self.num_node_features = self.data.num_node_features

        # If the dataset does not contain preset splits, randomly split it
        # into train:test = 1:9, once per seed.
        if not hasattr(self.data, 'train_mask'):
            self.random_split = True
            num_train = int(self.num_nodes * 0.1)

            train_columns = []
            test_columns = []
            for seed in n_seeds:
                reset_random_seed(seed)

                train_idx = torch.randperm(self.num_nodes)[:num_train]
                train_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
                train_mask[train_idx] = True

                train_columns.append(train_mask.unsqueeze(1))
                # Every node that is not a training node is a test node.
                test_columns.append((~train_mask).unsqueeze(1))

            self.data.train_mask = torch.cat(train_columns, dim=1)
            self.data.test_mask = torch.cat(test_columns, dim=1)

    def generate_subgraph(self):
        """Placeholder; intentionally does nothing."""
        pass

    def split_train_test(self, split_ratio=0.8):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError('do not set parameter ')

    def to(self, device):
        """Move the graph to ``device`` (rebinding ``self.data``)."""
        self.data = self.data.to(device)

    def replace_node_attributes(self, use_adj, threshold, num_dim):
        """Overwrite ``data.x`` with the result of ``obtain_attributes``."""
        self.num_node_features = num_dim
        self.data.x = obtain_attributes(self.data, use_adj, threshold, num_dim)

    def obtain_node_attributes(self, use_adj, threshold=0.1, num_dim=32):
        """Return ``obtain_attributes`` for the graph without modifying it."""
        return obtain_attributes(self.data, use_adj, threshold, num_dim)

    def print_statistics(self):
        """Print size, homophily and class-balance statistics of the graph."""
        h = homophily(self.data.edge_index, self.data.y)
        from collections import Counter
        if len(self.data.y.shape) >= 2:  # For one-hot labels
            y = self.data.y.argmax(1)
        else:
            y = self.data.y
        count = Counter(y.tolist())
        total_num = sum(count.values())
        class_ratio = {key: round(value / total_num, 2) for key, value in count.items()}
        print(f'{self.dataset_name}: Number of nodes: {self.num_nodes}, Dimension of features: {self.num_node_features}, Number of edges: {self.data.edge_index.shape[1]}, Number of classes: {self.num_classes}, Homophily: {h}, Class ratio: {class_ratio}.')
        if self.random_split:
            print('The dataset does not contain preset splits, we randomly split the dataset twenty times. Train: teset = 1:9')
        else:
            print('We use the preset splits.')
32 | edge_index, y, self.node2id = self._preprocess(self.raw_paths[0], self.raw_paths[1]) 33 | data = Data(x=torch.zeros(y.size(0), 1), edge_index=edge_index, y=y.argmax(1)) 34 | data_list = [data] 35 | 36 | if self.pre_filter is not None: 37 | data_list = [data for data in data_list if self.pre_filter(data)] 38 | 39 | if self.pre_transform is not None: 40 | data_list = [self.pre_transform(data) for data in data_list] 41 | 42 | data, slices = self.collate(data_list) 43 | torch.save((data, slices), self.processed_paths[0]) 44 | 45 | def _preprocess(self, edge_list_path, node_label_path): 46 | with open(edge_list_path) as f: 47 | edge_list = [] 48 | node2id = defaultdict(int) 49 | for line in f: 50 | x, y = list(map(int, line.split())) 51 | # Reindex 52 | if x not in node2id: 53 | node2id[x] = len(node2id) 54 | if y not in node2id: 55 | node2id[y] = len(node2id) 56 | edge_list.append([node2id[x], node2id[y]]) 57 | edge_list.append([node2id[y], node2id[x]]) 58 | 59 | num_nodes = len(node2id) 60 | with open(node_label_path) as f: 61 | nodes = [] 62 | labels = [] 63 | label2id = defaultdict(int) 64 | for line in f: 65 | x, label = list(map(int, line.split())) 66 | if label not in label2id: 67 | label2id[label] = len(label2id) 68 | nodes.append(node2id[x]) 69 | if "Hindex" in self.name: 70 | labels.append(label) 71 | else: 72 | labels.append(label2id[label]) 73 | if "Hindex" in self.name: 74 | median = np.median(labels) 75 | labels = [int(label > median) for label in labels] 76 | assert num_nodes == len(set(nodes)) 77 | y = torch.zeros(num_nodes, len(label2id)) 78 | y[nodes, labels] = 1 79 | return torch.LongTensor(edge_list).t(), y, node2id -------------------------------------------------------------------------------- /datasets/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .Hindex import HindexDataset 2 | from torch_geometric.datasets import Amazon, Coauthor, Airports, CitationFull 3 | 
def main(config):
    """Fine-tune and evaluate the model once per seed and report the mean.

    For every seed in ``config.seeds`` the random state is reset, the
    matching train/test split is selected, subgraph loaders are built, a
    fresh model is trained with ``train_subgraph`` and scored with
    ``eval_subgraph``. The per-seed accuracies and their mean/std are
    printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_obj = NodeDataset(config.dataset, n_seeds=config.seeds)
    dataset_obj.print_statistics()

    acc_list = []

    # Keep the full mask tensors: the loop below overwrites the attributes
    # on ``dataset_obj.data`` with the split of the current seed.
    train_masks = dataset_obj.data.train_mask
    test_masks = dataset_obj.data.test_mask

    for split_idx, seed in enumerate(config.seeds):
        reset_random_seed(seed)

        if dataset_obj.random_split:
            # NodeDataset stores one mask column per entry of ``n_seeds`` in
            # order, so the column is the seed's position, not its value.
            # Indexing by value breaks for seeds other than 0..n-1.
            dataset_obj.data.train_mask = train_masks[:, split_idx]
            dataset_obj.data.test_mask = test_masks[:, split_idx]

        train_loader, test_loader = preprocess(config, dataset_obj)
        model = load_model(dataset_obj.num_node_features, dataset_obj.num_classes, config).to(device)

        # training model
        train_subgraph(config, model, train_loader, device)
        acc = eval_subgraph(config, model, test_loader, device)

        acc_list.append(acc)
        print(f'Seed: {seed}, Accuracy: {acc:.4f}')

    final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
    print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
def preprocess(config, dataset_obj, device):
    """Build train/test loaders over per-node subgraphs.

    The nodes selected by ``data.train_mask`` / ``data.test_mask`` are passed
    to ``collect_subgraphs``; every resulting graph is then passed through
    ``process_attributes`` (return value ignored, so it is presumably applied
    in place -- confirm in ``utils.transforms``) before being wrapped in a
    ``DataLoader``.

    Args:
        config: needs ``batch_size``, ``walk_steps``, ``restart``,
            ``use_adj``, ``threshold`` and ``num_dim``.
        dataset_obj: object whose ``data`` holds the graph and 1-D masks.
        device: unused; kept for interface compatibility.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    kwargs = {'batch_size': config.batch_size, 'num_workers': 4, 'persistent_workers': True, 'pin_memory': True}

    print('generating subgraphs....')

    # ``nonzero()`` yields shape (k, 1). Squeeze only that trailing dim so a
    # split with a single node stays a 1-element vector instead of collapsing
    # to a 0-d tensor.
    train_idx = dataset_obj.data.train_mask.nonzero().squeeze(1)
    test_idx = dataset_obj.data.test_mask.nonzero().squeeze(1)

    train_graphs = collect_subgraphs(train_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)
    test_graphs = collect_subgraphs(test_idx, dataset_obj.data, walk_steps=config.walk_steps, restart_ratio=config.restart)

    # Plain loops: the calls are made for their effect, not their result.
    for graphs in (train_graphs, test_graphs):
        for g in graphs:
            process_attributes(g, use_adj=config.use_adj, threshold=config.threshold, num_dim=config.num_dim)

    train_loader = DataLoader(train_graphs, shuffle=True, **kwargs)
    test_loader = DataLoader(test_graphs, **kwargs)

    return train_loader, test_loader
def eval_subgraph(config, model, test_loader, device, full_x_sim):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        config: unused; kept for interface compatibility.
        model: object exposing ``eval()`` and ``forward_subgraph(...)``.
        test_loader: iterable of batches providing ``x``, ``edge_index``,
            ``batch``, ``y``, ``original_idx`` and either ``root_n_id`` or
            ``root_n_index``.
        device: device every batch is moved to.
        full_x_sim: tensor indexed with ``batch.original_idx`` to obtain the
            condition features of the batch.

    Returns:
        Fraction of correct predictions as a float.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()

    correct = 0
    total_num = 0
    # Evaluation needs no gradients; without this guard every forward pass
    # records an autograd graph for the trainable parameters.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            if not hasattr(batch, 'root_n_id'):
                # Fall back to the alternative attribute name for the roots.
                batch.root_n_id = batch.root_n_index
            x_sim = full_x_sim[batch.original_idx]
            logits = model.forward_subgraph(batch.x, x_sim, batch.edge_index, batch.batch, batch.root_n_id, frozen=True)
            preds = logits.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total_num += batch.y.shape[0]
    return correct / total_num
def get_activation(name: str):
    """Return the activation callable registered under ``name``.

    Supported names: ``relu``, ``hardtanh``, ``elu``, ``leakyrelu``,
    ``rrelu`` (functions from ``torch.nn.functional``) and ``prelu`` (a new
    ``torch.nn.PReLU`` module per call, since it owns a learnable weight).

    Raises:
        KeyError: if ``name`` is not one of the supported names.
    """
    # PReLU is stateful, so it is only instantiated when actually requested
    # rather than on every lookup.
    if name == 'prelu':
        return torch.nn.PReLU()

    functional = {
        'relu': F.relu,
        'hardtanh': F.hardtanh,
        'elu': F.elu,
        'leakyrelu': F.leaky_relu,
        'rrelu': F.rrelu,
    }
    try:
        return functional[name]
    except KeyError:
        valid = sorted([*functional, 'prelu'])
        raise KeyError(f"Unknown activation '{name}'; expected one of {valid}") from None
self.bns.append(BatchNorm1d(hidden_size)) 48 | else: 49 | self.bns.append(Identity()) 50 | 51 | else: # one layer gcn 52 | self.convs.append(GCNConv(input_dim, hidden_size)) 53 | # glorot(self.convs[-1].weight) 54 | if use_bn: 55 | self.bns.append(BatchNorm1d(hidden_size)) 56 | else: 57 | self.bns.append(Identity()) 58 | # self.acts.append(self.activation) 59 | 60 | def forward(self, x, edge_index, edge_weight=None): 61 | # print('Inside Model: num graphs: {}, device: {}'.format( 62 | # data.num_graphs, data.batch.device)) 63 | # x, edge_index = data.x, data.edge_index 64 | for i in range(self.layer_num): 65 | # x = self.convs[i](x, edge_index, edge_weight) 66 | # print(i, x.dtype, self.convs[i].lin.weight.dtype) 67 | x = self.bns[i](self.convs[i](x, edge_index, edge_weight)) 68 | if i == self.layer_num - 1 and not self.last_act: 69 | pass 70 | # print(i, 'pass last relu') 71 | else: 72 | x = self.activation(x) 73 | x = self.dropout(x) 74 | # x = self.activation(self.convs[i](x, edge_index, edge_weight)) 75 | # x = self.bns[i](x) 76 | # x = self.activation(self.bns[i](self.convs[i](x, edge_index))) 77 | return x 78 | 79 | def reset_parameters(self): 80 | for i in range(self.layer_num): 81 | self.convs[i].reset_parameters() 82 | if self.use_bn: 83 | self.bns[i].reset_parameters() 84 | 85 | 86 | @register.encoder_register 87 | class GIN_Encoder(torch.nn.Module): 88 | def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True): 89 | super(GIN_Encoder, self).__init__() 90 | self.layer_num = layer_num 91 | self.hidden_size = hidden_size 92 | self.input_dim = input_dim 93 | self.activation = get_activation(activation) 94 | self.dropout = torch.nn.Dropout(dropout) 95 | self.last_act = last_activation 96 | self.use_bn = use_bn 97 | 98 | self.convs = ModuleList() 99 | self.bns = ModuleList() 100 | 101 | self.readout = global_mean_pool 102 | if self.layer_num > 1: 103 | 
@register.encoder_register
class GAT_Encoder(torch.nn.Module):
    """Stack of ``GATConv`` layers, each followed by a normalisation layer,
    the activation and dropout.

    Args:
        input_dim: size of the input node features.
        layer_num: number of layers applied in ``forward``.
        hidden_size: output size of every layer.
        activation: name understood by ``get_activation``.
        dropout: dropout probability applied after every activation.
        use_bn: use ``BatchNorm1d`` after each convolution; otherwise an
            ``Identity`` (same convention as the other encoders in this
            module).
        last_activation: when False, the final layer skips activation and
            dropout.
    """

    def __init__(self, input_dim, layer_num=2, hidden_size=128, activation="relu", dropout=0.5, use_bn=True, last_activation=True):
        super(GAT_Encoder, self).__init__()
        self.layer_num = layer_num
        self.hidden = hidden_size
        self.input_dim = input_dim
        self.activation = get_activation(activation)
        self.dropout = torch.nn.Dropout(dropout)
        self.last_act = last_activation
        self.use_bn = use_bn

        self.convs = ModuleList()
        self.bns = ModuleList()

        # The first layer maps input_dim -> hidden_size; any further layers
        # are hidden_size -> hidden_size.
        self.convs.append(GATConv(input_dim, hidden_size))
        for _ in range(layer_num - 1):
            self.convs.append(GATConv(hidden_size, hidden_size))

        # One normalisation layer per convolution. The previous code built
        # fewer of them than convolutions when layer_num > 1, so ``forward``
        # failed with an IndexError on ``self.bns[i]``.
        for _ in range(len(self.convs)):
            self.bns.append(BatchNorm1d(hidden_size) if use_bn else Identity())

    def forward(self, x, edge_index, **kwargs):
        """Apply the layers to ``x``; extra keyword arguments are ignored."""
        last = self.layer_num - 1
        for i in range(self.layer_num):
            x = self.bns[i](self.convs[i](x, edge_index))
            if i == last and not self.last_act:
                continue
            x = self.activation(x)
            x = self.dropout(x)
        return x

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer."""
        for i in range(self.layer_num):
            self.convs[i].reset_parameters()
            if self.use_bn:  # Identity has no parameters to reset
                self.bns[i].reset_parameters()
def change_params_key(params):
    """Rename GCC checkpoint keys in place to match this implementation.

    Two rules are applied to every dotted key of ``params``:

    * keys whose first component is ``set2set`` are removed;
    * keys whose fourth component is ``apply_func`` are re-inserted with
      that component replaced by ``nn``.

    The mapping is modified in place and nothing is returned.
    """
    # Iterate over a snapshot because the mapping is mutated in the loop.
    for key in list(params.keys()):
        parts = key.split('.')

        if parts[0] == 'set2set':
            # Checked first so a set2set key is dropped exactly once, even
            # if it also matches the rename rule (the old code popped such a
            # key twice and raised KeyError).
            params.pop(key)
            continue

        if len(parts) > 3 and parts[3] == 'apply_func':
            parts[3] = 'nn'
            params['.'.join(parts)] = params.pop(key)
def forward_subgraph(self, x, edge_index, batch, root_n_id, edge_weight=None, frozen=False, **kwargs):
    """Encode a batch of subgraphs with the GIN backbone.

    When ``self.degree_input`` is set, every node feature row is extended
    with a learned embedding of the node's degree (clamped to
    ``self.max_degree``) and a 0/1 flag that is 1 for the rows listed in
    ``root_n_id``.  Otherwise ``x`` is passed to the backbone unchanged.

    Args:
        x: node feature matrix; the first dimension is the number of nodes.
        edge_index: edge list; row 0 is used to count node degrees.
        batch: node-to-graph assignment forwarded to the backbone.
        root_n_id: indices of the root (ego) nodes inside ``x``.
        edge_weight: accepted for interface compatibility and ignored.
        frozen: accepted for interface compatibility and ignored.
        **kwargs: ignored.

    Returns:
        The first output of ``self.gnn``, L2-normalised along the last
        dimension when ``self.norm`` is true.
    """
    if self.degree_input:
        device = x.device
        num_nodes = x.shape[0]
        # num_nodes must be given explicitly: without it the degree vector is
        # only as long as the largest index in edge_index[0] + 1, which is
        # shorter than x whenever the trailing nodes have no outgoing edge,
        # and the concatenation below would fail.
        degrees = torch_geometric.utils.degree(edge_index[0], num_nodes=num_nodes).long().to(device)
        ego_indicator = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        ego_indicator[root_n_id] = True

        n_feat = torch.cat(
            (
                x,
                self.degree_embedding(degrees.clamp(0, self.max_degree)),
                ego_indicator.unsqueeze(1).float(),
            ),
            dim=-1,
        )
    else:
        n_feat = x

    # The per-layer pooled outputs returned by the backbone are not used here.
    out, _ = self.gnn(n_feat, edge_index, batch)

    if self.norm:
        out = F.normalize(out, p=2, dim=-1, eps=1e-5)

    return out
class ApplyNodeFunc(nn.Module):
    """Node update used inside a GIN layer: MLP, then normalisation, then ReLU.

    The normalisation module is stored as ``self.bn``.  It is an ``SELayer``
    when ``use_selayer`` is true and a ``BatchNorm1d`` otherwise; both are
    sized from ``mlp.output_dim``.
    """

    def __init__(self, mlp, use_selayer):
        super().__init__()
        self.mlp = mlp
        width = self.mlp.output_dim
        if use_selayer:
            self.bn = SELayer(width, int(np.sqrt(width)))
        else:
            self.bn = nn.BatchNorm1d(width)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))``."""
        return F.relu(self.bn(self.mlp(h)))
def forward(self, x):
    """Apply the MLP to ``x``.

    In the single-layer configuration (``self.linear_or_not``) this is just
    ``self.linear``.  Otherwise every linear layer except the last is followed
    by its normalisation module and a ReLU, and the last linear layer is
    applied without any non-linearity.
    """
    if self.linear_or_not:
        # Pure linear model.
        return self.linear(x)

    hidden = x
    # linears holds one more entry than batch_norms; the final linear layer
    # is applied separately below.
    for linear, norm in zip(self.linears[:-1], self.batch_norms):
        hidden = F.relu(norm(linear(hidden)))
    return self.linears[-1](hidden)
def forward(self, x, edge_index, batch):
    """Run the GIN stack and combine the pooled outputs of every layer.

    The input ``x`` and the output of each GIN layer (conv, then norm, then
    ReLU) are each pooled per graph with ``self.pool``.  Every pooled
    representation goes through its own prediction head and dropout, and the
    results are summed.

    Returns:
        A pair ``(score, pooled_layers)``: ``score`` is the sum over all
        layers, input included; ``pooled_layers`` is the list of pooled
        representations of the GIN layers only (the pooled input is left
        out).
    """
    # Hidden node representation after each layer, starting with the input.
    layer_states = [x]
    for conv, norm in zip(self.ginlayers, self.batch_norms):
        layer_states.append(F.relu(norm(conv(layer_states[-1], edge_index))))

    # Graph-level readout for every layer.
    pooled = [self.pool(state, batch) for state in layer_states]

    score = 0
    for head, graph_repr in zip(self.linears_prediction, pooled):
        score += self.drop(head(graph_repr))

    return score, pooled[1:]
def forward_subgraph(self, x, x_sim, edge_index, batch, root_n_id, edge_weight=None, frozen=False, **kwargs):
    """Classify subgraphs from the frozen encoder plus the control branch.

    The frozen ``self.encoder`` embeds ``x`` without gradient tracking.  In
    parallel, ``x_sim`` goes through ``zero_conv1``, is added to ``x``, is
    encoded by ``self.trainable_copy`` and goes through ``zero_conv2``.  The
    two embeddings are summed and fed to ``self.linear_classifier``.

    Args:
        x: node features for the pre-trained encoder.
        x_sim: conditional node features for the control branch.
        edge_index, batch, root_n_id: forwarded to both encoders.
        edge_weight: accepted for interface compatibility and ignored.
        frozen: must be true; the unfrozen mode is not implemented.
        **kwargs: ignored.

    Raises:
        NotImplementedError: if ``frozen`` is false.
    """
    if not frozen:
        raise NotImplementedError('Please freeze pre-trained models')

    with torch.no_grad():
        self.encoder.eval()
        frozen_out = self.encoder.forward_subgraph(x, edge_index, batch, root_n_id)

    control_in = self.zero_conv1(x_sim) + x

    # for simplicity, the same edge_index is reused to compute degrees
    control_out = self.trainable_copy.forward_subgraph(control_in, edge_index, batch, root_n_id)

    fused = self.zero_conv2(control_out) + frozen_out
    return self.linear_classifier(fused)
def forward(self, x, edge_index, edge_weight=None, frozen=False):
    """Encode ``x`` and classify the result with ``self.classifier``.

    Args:
        x: node features.
        edge_index: forwarded to the encoder.
        edge_weight: forwarded to the encoder.
        frozen: when true, the encoder is switched to eval mode and run
            without gradient tracking; the classifier itself is always run
            with the ambient gradient mode.

    Returns:
        The output of ``self.classifier``.
    """
    def encode():
        return self.encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)

    if not frozen:
        hidden = encode()
    else:
        with torch.no_grad():
            self.encoder.eval()
            hidden = encode()

    return self.classifier(hidden)
def load_model(input_dim: int, output_dim: int, config):
    """Build the model named by ``config.model``.

    For ``'GCC'`` and ``'GCC_GraphControl'`` the hyper-parameters and weights
    are read from ``checkpoint/gcc.pth``: the model is built from the stored
    ``opt`` namespace (with degree input enabled and ``num_classes`` set to
    ``output_dim``), the checkpoint keys are adapted with
    ``change_params_key`` and the weights are loaded.  ``GCC_GraphControl``
    receives the same weights in both its frozen encoder and its trainable
    copy.

    Any other model is built directly from the registry with ``input_dim``,
    ``output_dim`` and every attribute of ``config`` as keyword arguments.
    """
    if config.model not in ['GCC', 'GCC_GraphControl']:
        return register.models[config.model](input_dim=input_dim, output_dim=output_dim, **vars(config))

    checkpoint = torch.load('checkpoint/gcc.pth', map_location='cpu')
    opt = checkpoint['opt']
    hyper_params = dict(
        positional_embedding_size=opt.positional_embedding_size,
        max_node_freq=opt.max_node_freq,
        max_edge_freq=opt.max_edge_freq,
        max_degree=opt.max_degree,
        freq_embedding_size=opt.freq_embedding_size,
        degree_embedding_size=opt.degree_embedding_size,
        output_dim=opt.hidden_size,
        node_hidden_dim=opt.hidden_size,
        edge_hidden_dim=opt.hidden_size,
        num_layers=opt.num_layer,
        num_step_set2set=opt.set2set_iter,
        num_layer_set2set=opt.set2set_lstm_layer,
        gnn_model=opt.model,
        norm=opt.norm,
        degree_input=True,
        num_classes=output_dim,
    )
    model = register.models[config.model](**hyper_params)

    weights = checkpoint['model']
    change_params_key(weights)

    # GCC loads the weights itself; GraphControl loads them into both the
    # frozen encoder and its trainable copy.
    if config.model == 'GCC':
        targets = [model]
    else:
        targets = [model.encoder, model.trainable_copy]
    for module in targets:
        module.load_state_dict(weights)
    return model
def subg_pooling(reps, data):
    """Select the representation of each subgraph's centre node.

    ``data.batch`` assigns every row of ``reps`` to a graph and
    ``data.center`` holds, per graph, the centre node's index local to that
    graph.  The local indices are shifted by the number of nodes in all
    preceding graphs to obtain row indices into ``reps``.

    NOTE(review): like the original loop, this assumes the rows of each graph
    are contiguous and ordered by graph id, and that ``data.batch`` is
    non-empty - confirm against the data loader.

    Args:
        reps: node representations; the first dimension is indexed by node.
        data: batch object exposing ``batch``, ``center`` and ``y``.

    Returns:
        ``(reps[centre rows], data.y)``.
    """
    batch = data.batch
    batch_size = int(batch.max()) + 1

    # Number of nodes per graph, counted with native torch ops.
    sizes = torch.zeros(batch_size, dtype=batch.dtype, device=batch.device)
    sizes.scatter_add_(0, batch, torch.ones_like(batch))

    # Exclusive prefix sum: index of the first row of every graph.
    offsets = torch.cumsum(sizes, dim=0) - sizes

    centers = torch.as_tensor(data.center, dtype=batch.dtype, device=batch.device).reshape(-1)

    # Mark all centre rows at once instead of one device index per graph.
    center_mask = torch.zeros(batch.shape[0], dtype=torch.bool, device=batch.device)
    center_mask[centers[:batch_size] + offsets] = True
    return reps[center_mask], data.y
Arguments().parse_args() 44 | 45 | dataset_obj = NodeDataset(dataset_name=config.dataset) 46 | dataset_obj.print_statistics() 47 | data = dataset_obj.data 48 | 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | num_workers = 0 if sys.platform.startswith('win') else 4 51 | 52 | train_masks = dataset_obj.data.train_mask 53 | test_masks = dataset_obj.data.test_mask 54 | 55 | model = Node2Vec( 56 | data.edge_index, 57 | embedding_dim=config.emb_dim, # 256 for USA, 64 for Europe, 32 for Brazil 58 | walk_length=config.walk_length, 59 | context_size=config.context_size, 60 | walks_per_node=config.walk_per_nodes, 61 | sparse=True, 62 | ).to(device) 63 | 64 | loader = model.loader(batch_size=config.batch_size, shuffle=True, 65 | num_workers=num_workers) 66 | optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=config.lr) 67 | 68 | if dataset_obj.random_split: 69 | dataset_obj.data.train_mask = train_masks[:, 0] 70 | dataset_obj.data.test_mask = test_masks[:, 0] 71 | 72 | progress = tqdm(range(0, config.epochs)) 73 | for epoch in progress: 74 | loss = train() 75 | acc = test() 76 | progress.set_postfix_str(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}') 77 | 78 | # save embedding 79 | torch.save(model.embedding.weight.cpu(), f=f'{PATH}/{config.dataset}/processed/node2vec.pt') 80 | 81 | acc_list = [] 82 | 83 | for seed in config.seeds: 84 | reset_random_seed(seed) 85 | dataset_obj.data.train_mask = train_masks[:, seed] 86 | dataset_obj.data.test_mask = test_masks[:, seed] 87 | acc = test() 88 | acc_list.append(acc) 89 | 90 | 91 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 92 | print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}") 93 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | optimizers_dicts = { 5 | 'sgd': torch.optim.SGD, 6 | 'adam': 
def create_optimizer(**kwargs):
    """Instantiate an optimizer selected by name.

    Required keyword arguments:
        lr: learning rate.
        weight_decay: weight-decay coefficient.
        name: key into ``optimizers_dicts``.
        parameters: the parameters to optimise.

    Returns:
        The optimizer built as ``cls(parameters, lr=..., weight_decay=...)``.

    Raises:
        KeyError: if a required keyword or the optimizer name is missing.
    """
    required = ('lr', 'weight_decay', 'name', 'parameters')
    lr, weight_decay, name, parameters = (kwargs[field] for field in required)
    optimizer_cls = optimizers_dicts[name]
    return optimizer_cls(parameters, lr=lr, weight_decay=weight_decay)
action='store_true', help="use BN or not") 17 | self.parser.add_argument('--model', type=str, help="model name", default='GCC_GraphControl', 18 | choices=['GCC', 'GCC_GraphControl']) 19 | 20 | # Training settings 21 | self.parser.add_argument('--optimizer', type=str, help="the kind of optimizer", default='adam', 22 | choices=['adam', 'sgd', 'adamw', 'nadam', 'radam']) 23 | self.parser.add_argument('--lr', type=float, help="learning rate", default=1e-3) 24 | self.parser.add_argument('--weight_decay', type=float, help="weight decay", default=5e-4) 25 | self.parser.add_argument('--epochs', type=int, help="training epochs", default=200) 26 | self.parser.add_argument('--batch_size', type=int, default=128) 27 | self.parser.add_argument('--finetune', action='store_true', help="Quickly find optim parameters") 28 | 29 | # Processing node attributes 30 | self.parser.add_argument('--use_adj', action='store_true', help="use eigen-vectors of adjacent matrix as node attributes") 31 | self.parser.add_argument('--threshold', type=float, help="the threshold for discreting similarity matrix", default=0.15) 32 | self.parser.add_argument('--num_dim', type=int, help="the number of replaced node attributes", default=32) 33 | # self.parser.add_argument('--ad_aug', action='store_true', help="adversarial augmentation") 34 | self.parser.add_argument('--restart', type=float, help="the restart ratio of random walking", default=0.3) 35 | self.parser.add_argument('--walk_steps', type=int, help="the number of random walk's steps", default=256) 36 | 37 | # Node2vec config 38 | self.parser.add_argument('--emb_dim', type=int, default=128, help="Embedding dim for node2vec") 39 | self.parser.add_argument('--walk_length', type=int, default=50, help="Walk length for node2vec") 40 | self.parser.add_argument('--context_size', type=int, default=10, help="Context size for node2vec") 41 | self.parser.add_argument('--walk_per_nodes', type=int, default=10, help="Walk per nodes for node2vec") 42 | 43 | def 
def drop_feature(x, drop_prob):
    """Return a copy of ``x`` with randomly chosen feature columns zeroed.

    One uniform sample in [0, 1) is drawn per column of ``x``; columns whose
    sample is below ``drop_prob`` are set to zero for every row.  The input
    tensor is left untouched.
    """
    noise = torch.empty(
        (x.size(1), ),
        dtype=torch.float32,
        device=x.device).uniform_(0, 1)
    dropped_columns = noise < drop_prob

    masked = x.clone()
    masked[:, dropped_columns] = 0
    return masked
def reset_random_seed(seed):
    r"""
    Fix every random seed used by the project and restore training defaults.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    switches cuDNN to deterministic mode with benchmarking disabled, and
    turns gradient tracking back on.

    Args:
        seed (int): the seed applied to all generators.
    """
    # Fix Random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Default state is a training state.  A bare ``torch.enable_grad()`` call
    # only builds a context manager and has no effect, so the global switch
    # is set explicitly instead.
    torch.set_grad_enabled(True)
def encoder_register(self, encoder_class):
    r"""
    Register for encoder access.

    The class is stored in ``self.encoders`` under its own ``__name__`` and
    returned unchanged, so the method can be used as a class decorator.

    Args:
        encoder_class (class): encoder class

    Returns (class):
        encoder class

    """
    key = encoder_class.__name__
    self.encoders[key] = encoder_class
    return encoder_class
def add_remaining_selfloop_for_isolated_nodes(edge_index, num_nodes):
    """Append a self-loop for every node that has no incident edge.

    Nodes that already appear as an endpoint of some edge, including an
    existing self-loop, are left as they are.

    Args:
        edge_index: edge list with sources in row 0 and targets in row 1.
        num_nodes: number of nodes; raised to the node count inferred from
            ``edge_index`` if that is larger.

    Returns:
        ``edge_index`` with the extra self-loops concatenated as new columns.
    """
    num_nodes = max(maybe_num_nodes(edge_index), num_nodes)
    device = edge_index.device

    # only add self-loop on isolated nodes
    loop_index = torch.arange(0, num_nodes, dtype=torch.long, device=device)

    # The mask must live on the same device as edge_index: a CPU mask cannot
    # be indexed with CUDA indices.
    isolated = torch.ones(num_nodes, dtype=torch.bool, device=device)
    # Repeated indices are harmless for a masked assignment, so the endpoint
    # list does not need to be de-duplicated first.
    isolated[torch.cat([edge_index[0], edge_index[1]])] = False

    loops_for_isolated_nodes = loop_index[isolated].unsqueeze(0).repeat(2, 1)
    return torch.cat([edge_index, loops_for_isolated_nodes], dim=1)
class RWR:
    """Random-walk-with-restart subgraph sampler.

    A stochastic augmentation that turns one graph into many subgraphs: from
    each sampled center node a walk of ``walk_steps`` steps is simulated, and
    at every step the walker jumps back to its center with probability
    ``restart_ratio``. Subgraphs sharing a center node are positive pairs,
    all others are negative pairs.

    Args:
        walk_steps: Number of steps per walk; must be greater than 1.
        graph_num: Upper bound on the number of sampled center nodes.
        restart_ratio: Per-step probability of returning to the center node.
        inductive: If True, only edges among ``graph.train_mask`` nodes are
            kept and isolated nodes are removed before sampling.
        aligned: If True, one walk is sampled per center and the second view
            is a deep copy of the first; otherwise two independent walks.
        **args: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, walk_steps=50, graph_num=128, restart_ratio=0.5, inductive=False, aligned=False, **args):
        self.walk_steps = walk_steps
        self.graph_num = graph_num
        self.restart_ratio = restart_ratio
        self.inductive = inductive
        self.aligned = aligned

    def __call__(self, graph):
        """Sample subgraph views from ``graph``.

        Args:
            graph: Object exposing ``x``, ``y`` and ``edge_index`` (plus
                ``train_mask`` when ``inductive`` is set). It is deep-copied,
                so the caller's graph is never modified.

        Returns:
            ``(view1_list, view2_list)``: two lists of ``Data`` views, one
            entry per sampled center node, in the same center order.
        """
        assert self.walk_steps > 1
        graph = copy.deepcopy(graph)  # all modifications happen on the copy

        if self.inductive:
            total_nodes = graph.x.shape[0]
            train_node_idx = torch.where(graph.train_mask.bool())[0]
            # Drop every edge touching a val/test node; those nodes become
            # isolated and are removed together with other isolated nodes.
            graph.edge_index, _ = subgraph(train_node_idx, graph.edge_index, num_nodes=total_nodes)
            edge_index, _, keep = remove_isolated_nodes(graph.edge_index, num_nodes=total_nodes)
            # remove_isolated_nodes re-indexes the surviving nodes, so the
            # relabelled edges must replace the old ones and every per-node
            # tensor must be filtered with the same mask to stay aligned.
            graph.edge_index = edge_index
            graph.x = graph.x[keep]
            if getattr(graph, "y", None) is not None:
                graph.y = graph.y[keep]

        edge_index = to_undirected(graph.edge_index)
        # Give isolated nodes a self-loop so the walk always has a neighbour.
        edge_index = add_remaining_selfloop_for_isolated_nodes(edge_index, graph.x.shape[0])
        graph.edge_index = edge_index

        node_num = graph.x.shape[0]
        graph_num = min(self.graph_num, node_num)
        start_nodes = torch.randperm(node_num)[:graph_num]

        value = torch.arange(edge_index.size(1))
        self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                  value=value,
                                  sparse_sizes=(node_num, node_num)).t()

        view1_list = []
        view2_list = []
        views_cnt = 1 if self.aligned else 2
        for view_idx in range(views_cnt):
            history, signs = self._walk(start_nodes)
            for path, sign in zip(history, signs):
                view = self._build_view(path, sign, graph)
                if self.aligned:
                    view1_list.append(view)
                    view2_list.append(copy.deepcopy(view))
                elif view_idx == 0:
                    view1_list.append(view)
                else:
                    view2_list.append(view)
        return (view1_list, view2_list)

    def _walk(self, start_nodes):
        """Simulate one restart walk per start node.

        Returns:
            ``(history, signs)``, both with one row per start node and
            ``walk_steps + 1`` columns: the visited node ids and a boolean
            flag that is True where the entry is a (re)start, not a real hop.
        """
        graph_num = start_nodes.size(0)
        current_nodes = start_nodes.clone()
        visited = [start_nodes.clone()]
        restarted = [torch.ones(graph_num, dtype=torch.bool)]
        for _ in range(self.walk_steps):
            seed = torch.rand([graph_num])
            # One neighbour is sampled per node; squeeze only the sample
            # dimension so a single-seed batch keeps its 1-D shape.
            nei = self.adj_t.sample(1, current_nodes).squeeze(-1)
            sign = seed < self.restart_ratio
            nei[sign] = start_nodes[sign]
            visited.append(nei)
            restarted.append(sign)
            current_nodes = nei
        return torch.stack(visited, dim=1), torch.stack(restarted, dim=1)

    def _build_view(self, path, sign, graph):
        """Turn one walk (node path + restart flags) into a subgraph view."""
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).type_as(graph.edge_index)
        # A step flagged as restart is a jump, not a traversed edge.
        sub_edges = sub_edges[:, ~sign[1:]]
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        return self.adjust_idx(sub_edges, node_idx, graph, path[0].item())

    def adjust_idx(self, edge_index, node_idx, full_g, center_idx):
        """Re-index a sampled subgraph and wrap it in a ``Data`` object.

        Only the nodes in ``node_idx`` survive in the subgraph, so the ids in
        ``edge_index`` are remapped to positions within ``node_idx``.

        Args:
            edge_index: Subgraph edges expressed in ``full_g`` node ids.
            node_idx: Ids (in ``full_g``) of the nodes kept in the subgraph.
            full_g: Graph the subgraph was sampled from; supplies ``y`` and
                the dtype/device of the new edge tensor.
            center_idx: Id (in ``full_g``) of the walk's center node.

        Returns:
            ``Data`` with node attributes from ``obtain_attributes`` (not
            ``full_g.x``), the center's label as ``y`` and the center's new
            position stored as both ``center`` and ``root_n_index``.
        """
        node_idx_map = {old: new for new, old in enumerate(node_idx.tolist())}
        sources_idx = [node_idx_map[s] for s in edge_index[0].tolist()]
        target_idx = [node_idx_map[t] for t in edge_index[1].tolist()]

        edge_index = torch.tensor([sources_idx, target_idx],
                                  dtype=full_g.edge_index.dtype,
                                  device=full_g.edge_index.device)
        x = obtain_attributes(Data(edge_index=edge_index), use_adj=True)
        center = node_idx_map[center_idx]
        x_view = Data(edge_index=edge_index, x=x, center=center, original_idx=node_idx,
                      y=full_g.y[center_idx], root_n_index=center)
        return x_view
def collect_subgraphs(selected_id, graph, walk_steps=20, restart_ratio=0.5):
    """Sample one random-walk-with-restart subgraph around each selected node.

    Args:
        selected_id: 1-D tensor of node ids used as walk centers.
        graph: Object exposing ``x``, ``y`` and ``edge_index``; it is
            deep-copied, so the caller's graph is never modified.
        walk_steps: Number of steps per walk.
        restart_ratio: Per-step probability of jumping back to the center.

    Returns:
        List of ``Data`` subgraphs, one per entry of ``selected_id`` and in
        the same order.
    """
    graph = copy.deepcopy(graph)  # all modifications happen on the copy
    edge_index = graph.edge_index
    node_num = graph.x.shape[0]
    start_nodes = selected_id  # only the selected nodes become walk centers
    graph_num = start_nodes.shape[0]

    # NOTE(review): no self-loops are added for isolated nodes here, unlike in
    # RWR.__call__ -- presumably callers pass seeds that have neighbours; confirm.
    value = torch.arange(edge_index.size(1))
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         value=value,
                         sparse_sizes=(node_num, node_num)).t()

    current_nodes = start_nodes.clone()
    visited = [start_nodes.clone()]
    restarted = [torch.ones(graph_num, dtype=torch.bool)]
    for _ in range(walk_steps):
        seed = torch.rand([graph_num])
        # One neighbour is sampled per node; squeeze only the sample
        # dimension so a single selected node keeps its 1-D shape.
        nei = adj_t.sample(1, current_nodes).squeeze(-1)
        sign = seed < restart_ratio
        nei[sign] = start_nodes[sign]
        visited.append(nei)
        restarted.append(sign)
        current_nodes = nei
    # One row per center, ``walk_steps + 1`` columns.
    history = torch.stack(visited, dim=1)
    signs = torch.stack(restarted, dim=1)

    graph_list = []
    for path, sign in zip(history, signs):
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).type_as(graph.edge_index)
        # A step flagged as restart is a jump, not a traversed edge.
        sub_edges = sub_edges[:, ~sign[1:]]
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        graph_list.append(adjust_idx(sub_edges, node_idx, graph, path[0].item()))
    return graph_list


def adjust_idx(edge_index, node_idx, full_g, center_idx):
    """Re-index a sampled subgraph and wrap it in a ``Data`` object.

    Only the nodes in ``node_idx`` survive in the subgraph, so the ids in
    ``edge_index`` are remapped to positions within ``node_idx``.

    Args:
        edge_index: Subgraph edges expressed in ``full_g`` node ids.
        node_idx: Ids (in ``full_g``) of the nodes kept in the subgraph.
        full_g: Graph the subgraph was sampled from; supplies ``x`` and ``y``.
        center_idx: Id (in ``full_g``) of the walk's center node.

    Returns:
        ``Data`` holding the remapped edges, ``full_g.x[node_idx]`` as node
        features, the center's label as ``y`` and the center's new position
        stored as both ``center`` and ``root_n_index``.
    """
    node_idx_map = {old: new for new, old in enumerate(node_idx.tolist())}
    sources_idx = [node_idx_map[s] for s in edge_index[0].tolist()]
    target_idx = [node_idx_map[t] for t in edge_index[1].tolist()]

    edge_index = torch.tensor([sources_idx, target_idx],
                              dtype=full_g.edge_index.dtype,
                              device=full_g.edge_index.device)
    center = node_idx_map[center_idx]
    x_view = Data(edge_index=edge_index, x=full_g.x[node_idx], center=center,
                  original_idx=node_idx, y=full_g.y[center_idx], root_n_index=center)
    return x_view
def ego_graphs_sampler(node_idx, data, hop=2):
    """Extract the ``hop``-hop ego graph around every node in ``node_idx``.

    Args:
        node_idx: Iterable of node ids (ints or 0-dim tensors).
        data: Graph exposing ``x``, ``y`` and ``edge_index``.
        hop: Number of hops passed to ``k_hop_subgraph``.

    Returns:
        List of ``Data`` ego graphs in the order of ``node_idx``. Each stores
        the seed's label as ``y``, the seed's original id as ``original_idx``
        and the ``mapping`` returned by ``k_hop_subgraph`` as
        ``root_n_index``.
    """
    # Pass the true node count: when it is inferred from edge_index alone, a
    # seed whose id is larger than every id in edge_index is out of range.
    num_nodes = data.x.size(0) if len(node_idx) > 0 else None

    ego_graphs = []
    for idx in node_idx:
        subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
            [int(idx)], hop, data.edge_index, relabel_nodes=True, num_nodes=num_nodes)
        sub_x = data.x[subset]
        # The seed position is stored under a name containing "index" because
        # PyG increments such attributes by the number of nodes when batching.
        g = Data(x=sub_x, edge_index=sub_edge_index, root_n_index=mapping,
                 y=data.y[idx], original_idx=idx)
        ego_graphs.append(g)
    return ego_graphs
def obtain_attributes(data, use_adj=False, threshold=0.1, num_dim=32):
    """Compute spectral node attributes from a graph Laplacian.

    The Laplacian (as built by ``get_laplacian_matrix``) is taken either of
    the adjacency matrix or of a thresholded feature-similarity matrix. The
    first ``num_dim`` eigenvector columns returned by ``eigh`` are kept and
    each row is l2-normalised.

    Args:
        data: Graph exposing ``edge_index`` (``use_adj=True``) or ``x``
            (``use_adj=False``).
        use_adj: Use the undirected, self-loop-free adjacency matrix instead
            of the similarity matrix.
        threshold: Only used when ``use_adj`` is False; similarity entries
            greater than it become 1.0, all others 0.0.
        num_dim: Maximum number of eigenvector columns to keep. No padding is
            applied, so fewer columns are returned for smaller matrices.

    Returns:
        CPU ``float32`` tensor with one row per row of the Laplacian.
    """
    # Above this many nodes the decomposition is delegated to SciPy.
    save_node_border = 30000

    if use_adj:
        # to undirected and remove self-loop
        edges = to_undirected(data.edge_index)
        edges, _ = remove_self_loops(edges)
        tmp = to_dense_adj(edges)[0]
    else:
        tmp = similarity(data.x, data.x)
        # discretize the similarity matrix by threshold
        tmp = torch.where(tmp > threshold, 1.0, 0.0)

    tmp = get_laplacian_matrix(tmp)
    if tmp.shape[0] > save_node_border:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.linalg`` is loaded on every SciPy version.
        import scipy.linalg
        # SciPy works on NumPy arrays, so convert explicitly (CPU, no grad).
        _, V = scipy.linalg.eigh(tmp.detach().cpu().numpy())
        V = torch.from_numpy(V)
    else:
        _, V = torch.linalg.eigh(tmp)  # much faster than torch.linalg.eig

    x = V[:, :num_dim].float()
    import sklearn.preprocessing as preprocessing
    x = preprocessing.normalize(x.detach().cpu().numpy(), norm="l2")
    x = torch.tensor(x, dtype=torch.float32)

    return x
def process_attributes(data, use_adj=False, threshold=0.1, num_dim=32, soft=False, kernel=False):
    '''
    Replace the node attributes with positional encoding. Warning: this function will replace the node attributes!

    Args:
        data: a single graph containing x (if use_adj=False) and edge_index.
        use_adj: use the eigenvectors of the adjacency matrix (True) or of the similarity matrix (False) as node attributes.
        threshold: only used when use_adj=False and soft=False; discretizes the similarity matrix (1 if entry > threshold else 0).
        num_dim: number of attribute columns; smaller graphs are zero-padded up to it.
        soft: only used when use_adj=False; if True the similarity matrix is decomposed directly, without thresholding or Laplacian.
        kernel: only used when use_adj=False; if True an RBF kernel replaces the linear similarity.

    Returns:
        The same ``data`` object with ``x`` replaced. ``eigen_val`` (the ``num_dim`` smallest
        Laplacian eigenvalues, zero-padded) is also set, except on the ``soft`` path.
    '''
    if use_adj:
        # to undirected and remove self-loop
        edges = to_undirected(data.edge_index)
        if edges.size(1) > 1:
            edges, _ = remove_self_loops(edges)
        else:
            # for isolated nodes: fall back to a single self-loop on node 0
            edges = torch.tensor([[0], [0]], device=edges.device)
        Adj = to_dense_adj(edges)[0]
    else:
        if kernel:
            # memory efficient: broadcast the squared norms instead of
            # materialising two extra n x n matrices
            XY = data.x @ data.x.T
            sq_norm = torch.diag(XY)
            Adj = sq_norm.unsqueeze(1) - 2 * XY + sq_norm.unsqueeze(0)  # |X-Y|^2
            Adj = torch.exp(-0.05 * Adj)  # rbf kernel
        else:
            Adj = similarity(data.x, data.x)  # equal to linear kernel
        if soft:
            _, V = torch.linalg.eigh(Adj)
            x = V[:, :num_dim].float()
            data.x = F.normalize(x, dim=1)
            return data
        # discretize the similarity matrix by threshold
        Adj = torch.where(Adj > threshold, 1.0, 0.0)

    Lap = get_laplacian_matrix(Adj)

    # much faster than torch.linalg.eig, if this line triggers bugs please refer to
    # https://github.com/pytorch/pytorch/issues/70122#issuecomment-1232766638
    L, V = torch.linalg.eigh(Lap)
    L_sort, _ = torch.sort(L, descending=False)

    import sklearn.preprocessing as preprocessing

    n = V.shape[0]
    if n < num_dim:
        # Fewer nodes than requested dimensions: keep every eigenvector and
        # zero-pad columns (and eigenvalues) up to num_dim.
        V = preprocessing.normalize(V.detach().cpu().numpy(), norm="l2")
        V = torch.tensor(V, dtype=torch.float32)
        data.x = torch.nn.functional.pad(V, (0, num_dim - n))
        data.eigen_val = torch.nn.functional.pad(L_sort, (0, num_dim - L_sort.shape[0])).unsqueeze(0)
    else:
        x = V[:, 0:num_dim].float()
        x = preprocessing.normalize(x.detach().cpu().numpy(), norm="l2")
        data.x = torch.tensor(x, dtype=torch.float32)
        data.eigen_val = L_sort[:num_dim].unsqueeze(0)

    return data