├── LP.png ├── README.md ├── joao.png ├── semisupervised_OGB ├── README.md ├── code │ ├── conv.py │ ├── dataset_aug.py │ ├── finetune.py │ ├── finetune_joaov2.py │ ├── gnn.py │ ├── gnn_proj.py │ ├── main_pretrain_graphcl.py │ ├── main_pretrain_graphcl_joao.py │ ├── main_pretrain_graphcl_joaov2.py │ ├── py2graph.py │ ├── results │ │ └── place_holder.txt │ ├── run.sh │ ├── run_joaov2.sh │ ├── utils.py │ └── weights │ │ └── place_holder.txt └── ppa │ ├── conv.py │ ├── dataset_aug.py │ ├── finetune.py │ ├── finetune_joaov2.py │ ├── gnn.py │ ├── gnn_proj.py │ ├── main_pretrain_graphcl.py │ ├── main_pretrain_graphcl_joao.py │ ├── main_pretrain_graphcl_joaov2.py │ ├── results │ └── place_holder.txt │ ├── run.sh │ ├── run_joaov2.sh │ └── weights │ └── place_holder.txt ├── semisupervised_OGB_LP ├── README.md ├── code │ ├── cal.py │ ├── conv.py │ ├── dataset_graphcl.py │ ├── finetune.sh │ ├── gnn.py │ ├── main_infomin_epoch.py │ ├── main_pretrain_generative_infobn.py │ ├── main_pretrain_generative_infomin.py │ ├── main_pretrain_generative_infominbn.py │ ├── main_pyg.py │ ├── py2graph.py │ └── utils.py └── ppa │ ├── cal.py │ ├── conv.py │ ├── dataset_graphcl.py │ ├── finetune.sh │ ├── gnn.py │ ├── main_infomin_epoch.py │ ├── main_pretrain_generative_infobn.py │ ├── main_pretrain_generative_infomin.py │ ├── main_pretrain_generative_infominbn.py │ └── main_pyg.py ├── semisupervised_TU ├── README.md ├── finetune │ ├── datasets.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── results_joao │ │ └── place_holder.txt │ ├── train_eval.py │ ├── tu_dataset.py │ └── utils.py ├── finetune_joaov2 │ ├── datasets.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── results_joao │ │ └── place_holder.txt │ ├── train_eval.py │ ├── tu_dataset.py │ └── utils.py ├── pretrain │ ├── datasets.py │ ├── experiment_graphcl.py │ ├── experiment_joao.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── tu_dataset.py │ ├── utils.py │ ├── weights_graphcl │ │ └── place_holder.txt │ └── weights_joao │ │ └── place_holder.txt └── pretrain_joaov2 │ ├── datasets.py │ ├── experiment_joao.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── tu_dataset.py │ ├── utils.py │ └── weights_joao │ └── place_holder.txt ├── semisupervised_TU_LP ├── README.md ├── finetune │ ├── .main_visulize_generator.py.swp │ ├── calculate_result.py │ ├── datasets.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── train_eval.py │ ├── tu_dataset.py │ └── utils.py └── pretrain │ ├── .experiment_generative_linkPrediction.py.swp │ ├── datasets.py │ ├── experiment_generative.py │ ├── experiment_generative_ib.py │ ├── experiment_generative_ibalone.py │ ├── feature_expansion.py │ ├── gcn_conv.py │ ├── main.py │ ├── res_gcn.py │ ├── tu_dataset.py │ ├── utils.py │ ├── weights_infobn │ └── debug.log │ ├── weights_infomin │ └── debug.log │ └── weights_infominbn │ └── debug.log ├── transferLearning_MoleculeNet_PPI ├── README.md ├── bio │ ├── batch.py │ ├── dataloader.py │ ├── finetune.py │ ├── finetune.sh │ ├── loader.py │ ├── model.py │ ├── pretrain_joao.py │ ├── pretrain_joaov2.py │ ├── results │ │ └── place_holder.txt │ ├── splitters.py │ ├── util.py │ └── weights │ │ └── place_holder.txt └── chem │ ├── batch.py │ ├── dataloader.py │ ├── finetune.py │ ├── finetune.sh │ ├── loader.py │ ├── model.py │ ├── pretrain_joao.py │ ├── pretrain_joaov2.py │ ├── results │ └── place_holder.txt │ ├── splitters.py │ ├── 
util.py │ └── weights │ └── place_holder.txt ├── transferLearning_MoleculeNet_PPI_LP ├── README.md ├── bio │ ├── .graph_cover_2.py.swp │ ├── batch.py │ ├── dataloader.py │ ├── finetune.py │ ├── finetune.sh │ ├── loader.py │ ├── loader_vis.py │ ├── model.py │ ├── pretrain_generative_infobn.py │ ├── pretrain_generative_infomin.py │ ├── pretrain_generative_infominbn.py │ ├── pretrain_supervised.py │ ├── splitters.py │ └── util.py └── chem │ ├── batch.py │ ├── cal.py │ ├── dataloader.py │ ├── finetune.py │ ├── finetune.sh │ ├── loader.py │ ├── model.py │ ├── pretrain_generative_infobn.py │ ├── pretrain_generative_infomin.py │ ├── pretrain_generative_infominbn.py │ ├── pretrain_graphaf_joao.py │ ├── splitters.py │ └── util.py └── unsupervised_TU ├── README.md ├── arguments.py ├── aug.py ├── evaluate_embedding.py ├── gin.py ├── joao.py ├── joao.sh ├── joaov2.py ├── joaov2.sh ├── model.py └── results └── place_holder.txt /LP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/LP.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automate Graph Contrastive Learning Beyond Pre-Defined Augmentations 2 | 3 | PyTorch implementations of: 4 | 5 | **Automated selection of augmentations:** [Graph Contrastive Learning Automated](https://arxiv.org/abs/2106.07594) 6 | [[talk]](https://recorder-v3.slideslive.com/?share=39319&s=4366fe70-48a4-4f2c-952b-2a7ca56d48bf) 7 | [[poster]](https://yyou1996.github.io/files/icml2021_graphcl_automated_poster.pdf) 8 | [[appendix]](https://yyou1996.github.io/files/icml2021_graphcl_automated_supplement.pdf) 9 | 10 | Yuning You, Tianlong Chen, Yang Shen, Zhangyang Wang 11 | 12 | In ICML 2021. 13 | 14 | **Generating augmentations with generative models:** [Bringing Your Own View: Graph Contrastive Learning without Prefabricated Data Augmentations](https://arxiv.org/abs/2201.01702) 15 | 18 | 19 | Yuning You, Tianlong Chen, Zhangyang Wang, Yang Shen 20 | 21 | In WSDM 2022. 22 | 23 | ## Overview 24 | 25 | In this repository, we propose a principled framework named joint augmentation optimization (JOAO) to automatically, adaptively, and dynamically select augmentations during [GraphCL](https://arxiv.org/abs/2010.13902) training. 26 | A sanity check shows that the selected augmentations align with previous "best practices", as shown in Figure 3 of [Graph Contrastive Learning Automated](https://arxiv.org/abs/2106.07594) (ICML 2021). The corresponding folders are named $Setting_$Dataset. 27 | 28 | 29 | ![](./joao.png) 30 | 31 | 32 | 33 | We further propose leveraging graph generative models to directly generate augmentations (LP, for Learned Priors) rather than relying on prefabricated ones, as shown in Figure 2 of [Bringing Your Own View: Graph Contrastive Learning without Prefabricated Data Augmentations](https://arxiv.org/abs/2201.01702) (WSDM 2022). The corresponding folders are named $Setting_$Dataset_LP. Note that although the study uses GraphCL as the base model (yielding GraphCL-LP), the LP framework is more general and can be paired with other base models (e.g., BGRL in Appendix B).
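At its core, JOAO alternates GraphCL encoder updates with a min-max update of a probability distribution over augmentation pairs. The sketch below restates, as a standalone function, the distribution update implemented in the `main_pretrain_graphcl_joao*.py` scripts later in this repository (the function name `joao_update` and the random losses in the usage line are ours, for illustration only):

```python
import numpy as np

def joao_update(aug_P, loss_aug, gamma, beta=1.0):
    # Ascend the contrastive loss w.r.t. the flattened distribution over the
    # 4x4 augmentation pairs, with a proximal term pulling toward uniform ...
    b = aug_P + beta * (loss_aug - gamma * (aug_P - 1.0 / 16))
    # ... then project b back onto the probability simplex; the threshold mu
    # is found by bisection, exactly as in the pre-training scripts.
    mu_min, mu_max = b.min() - 1.0 / 16, b.max() - 1.0 / 16
    mu = (mu_min + mu_max) / 2
    while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2:
        if np.maximum(b - mu, 0).sum() > 1:
            mu_min = mu
        else:
            mu_max = mu
        mu = (mu_min + mu_max) / 2
    aug_P = np.maximum(b - mu, 0)
    return aug_P / aug_P.sum()

# One step starting from the uniform distribution, with illustrative losses:
aug_P = joao_update(np.ones(16) / 16, np.random.rand(16), gamma=0.1)
```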
34 | 35 | ![](./LP.png) 36 | 37 | ## Dependencies 38 | 39 | 40 | * [torch-geometric](https://github.com/rusty1s/pytorch_geometric) >= 1.6.0 41 | * [ogb](https://github.com/snap-stanford/ogb) == 1.2.4 42 | 43 | 44 | ## Experiments 45 | 46 | * Semi-supervised learning [[JOAO: TU Datasets]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/semisupervised_TU) [[JOAO: OGB]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/semisupervised_OGB) [[GraphCL-LP: TU Datasets]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/semisupervised_TU_LP) [[GraphCL-LP: OGB]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/semisupervised_OGB_LP) 47 | * Unsupervised representation learning [[JOAO: TU Datasets]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/unsupervised_TU) 48 | * Transfer learning [[JOAO: MoleculeNet and PPI]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/transferLearning_MoleculeNet_PPI) [[GraphCL-LP: MoleculeNet and PPI]](https://github.com/Shen-Lab/GraphCL_Automated/tree/master/transferLearning_MoleculeNet_PPI_LP) 49 | 50 | ## Citation 51 | 52 | If you use this code for your research, please cite our papers. 53 | 54 | ``` 55 | @article{you2021graph, 56 | title={Graph Contrastive Learning Automated}, 57 | author={You, Yuning and Chen, Tianlong and Shen, Yang and Wang, Zhangyang}, 58 | journal={arXiv preprint arXiv:2106.07594}, 59 | year={2021} 60 | } 61 | 62 | @misc{you2022bringing, 63 | title={Bringing Your Own View: Graph Contrastive Learning without Prefabricated Data Augmentations}, 64 | author={Yuning You and Tianlong Chen and Zhangyang Wang and Yang Shen}, 65 | year={2022}, 66 | eprint={2201.01702}, 67 | archivePrefix={arXiv}, 68 | primaryClass={cs.LG} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /joao.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/joao.png -------------------------------------------------------------------------------- /semisupervised_OGB/README.md: -------------------------------------------------------------------------------- 1 | ### JOAO Pre-Training: ### 2 | 3 | ``` 4 | cd ./code 5 | python main_pretrain_graphcl_joao.py --gamma 0.01 6 | cd ../ppa 7 | python main_pretrain_graphcl_joao.py --gamma 0.01 8 | ``` 9 | 10 | 11 | ### JOAO Finetuning: ### 12 | 13 | ``` 14 | cd ./code 15 | ./run.sh ${N_SPLIT} ${RESULT_FILE} ./weights/joao_${gamma}_30.pt 16 | cd ../ppa 17 | ./run.sh ${N_SPLIT} ${RESULT_FILE} ./weights/joao_${gamma}_100.pt 18 | ``` 19 | 20 | ```gamma``` is tuned over {0.01, 0.1, 1}. ```N_SPLIT``` can be 100 or 10 for the 1% or 10% label rate, and ```RESULT_FILE``` is the file in which to store the results. 21 | 22 | 23 | ### JOAOv2 Pre-Training: ### 24 | 25 | ``` 26 | cd ./code 27 | python main_pretrain_graphcl_joaov2.py --gamma 0.01 28 | cd ../ppa 29 | python main_pretrain_graphcl_joaov2.py --gamma 0.01 30 | ``` 31 | 32 | 33 | ### JOAOv2 Finetuning: ### 34 | 35 | ``` 36 | cd ./code 37 | ./run_joaov2.sh ${N_SPLIT} ${RESULT_FILE} ./weights/joaov2_${gamma}_30.pt 38 | cd ../ppa 39 | ./run_joaov2.sh ${N_SPLIT} ${RESULT_FILE} ./weights/joaov2_${gamma}_100.pt 40 | ``` 41 | 42 | ```gamma``` is tuned over {0.01, 0.1, 1}. ```N_SPLIT``` can be 100 or 10 for the 1% or 10% label rate, and ```RESULT_FILE``` is the file in which to store the results.
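For concreteness, an end-to-end JOAO pass on ogbg-code at the 1% label rate with ```gamma``` = 0.01 could look as follows (the result filename is illustrative; the checkpoint name follows the pattern above):

```
cd ./code
python main_pretrain_graphcl_joao.py --gamma 0.01
./run.sh 100 ./results/joao_0.01.log ./weights/joao_0.01_30.pt
```

Each `run*.sh` script repeats finetuning 10 times, so the result file accumulates 10 entries.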
43 | 44 | 45 | ## Acknowledgements 46 | 47 | The backbone implementation is adapted from https://github.com/snap-stanford/ogb. 48 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn.functional as F 5 | from torch_geometric.nn.inits import uniform 6 | 7 | from conv import GNN_node, GNN_node_Virtualnode 8 | 9 | from torch_scatter import scatter_mean 10 | 11 | class GNN(torch.nn.Module): 12 | 13 | def __init__(self, num_vocab, max_seq_len, node_encoder, num_layer = 5, emb_dim = 300, 14 | gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"): 15 | ''' 16 | num_vocab (int): vocabulary size of the sequence tokens to be predicted 17 | virtual_node (bool): whether to add virtual node or not 18 | ''' 19 | 20 | super(GNN, self).__init__() 21 | 22 | self.num_layer = num_layer 23 | self.drop_ratio = drop_ratio 24 | self.JK = JK 25 | self.emb_dim = emb_dim 26 | self.num_vocab = num_vocab 27 | self.max_seq_len = max_seq_len 28 | self.graph_pooling = graph_pooling 29 | 30 | if self.num_layer < 2: 31 | raise ValueError("Number of GNN layers must be greater than 1.") 32 | 33 | ### GNN to generate node embeddings 34 | if virtual_node: 35 | self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 36 | else: 37 | self.gnn_node = GNN_node(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 38 | 39 | 40 | ### Pooling function to generate whole-graph embeddings 41 | if self.graph_pooling == "sum": 42 | self.pool = global_add_pool 43 | elif self.graph_pooling == "mean": 44 | self.pool = global_mean_pool 45 | elif self.graph_pooling == "max": 46 | self.pool = global_max_pool 47 | elif self.graph_pooling == "attention": 48 | self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) 49 | elif self.graph_pooling == "set2set": 50 | self.pool = Set2Set(emb_dim, processing_steps = 2) 51 | else: 52 | raise ValueError("Invalid graph pooling type.") 53 | 54 | self.graph_pred_linear_list = torch.nn.ModuleList() 55 | 56 | if graph_pooling == "set2set": 57 | for i in range(max_seq_len): 58 | self.graph_pred_linear_list.append(torch.nn.Linear(2*emb_dim, self.num_vocab)) 59 | 60 | else: 61 | for i in range(max_seq_len): 62 | self.graph_pred_linear_list.append(torch.nn.Linear(emb_dim, self.num_vocab)) 63 | 64 | self.proj_head = torch.nn.Sequential(torch.nn.Linear(self.emb_dim, self.emb_dim), torch.nn.ReLU(inplace=True), torch.nn.Linear(self.emb_dim, self.emb_dim)) 65 | 66 | def forward(self, batched_data): 67 | ''' 68 | Return: 69 | A list of predictions. 70 | i-th element represents prediction at i-th position of the sequence.
71 | ''' 72 | 73 | h_node = self.gnn_node(batched_data) 74 | 75 | h_graph = self.pool(h_node, batched_data.batch) 76 | 77 | pred_list = [] 78 | # for i in range(self.max_seq_len): 79 | # pred_list.append(self.graph_pred_mlp_list[i](h_graph)) 80 | 81 | for i in range(self.max_seq_len): 82 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 83 | 84 | return pred_list 85 | 86 | def forward_cl(self, batched_data): 87 | h_node = self.gnn_node(batched_data) 88 | 89 | h_graph = self.pool(h_node, batched_data.batch) 90 | z = self.proj_head(h_graph) 91 | return z 92 | 93 | def loss_cl(self, x1, x2): 94 | T = 0.5 95 | batch_size, _ = x1.size() 96 | 97 | x1_abs = x1.norm(dim=1) 98 | x2_abs = x2.norm(dim=1) 99 | 100 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 101 | sim_matrix = torch.exp(sim_matrix / T) 102 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 103 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 104 | loss = - torch.log(loss).mean() 105 | return loss 106 | 107 | 108 | if __name__ == '__main__': 109 | pass 110 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/gnn_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn.functional as F 5 | from torch_geometric.nn.inits import uniform 6 | 7 | from conv import GNN_node, GNN_node_Virtualnode 8 | 9 | from torch_scatter import scatter_mean 10 | 11 | class GNN(torch.nn.Module): 12 | 13 | def __init__(self, num_vocab, max_seq_len, node_encoder, num_layer = 5, emb_dim = 300, 14 | gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"): 15 | ''' 16 | num_tasks (int): number of labels to be predicted 17 | virtual_node (bool): whether to add virtual node or not 18 | ''' 19 | 20 | super(GNN, self).__init__() 21 | 22 | self.num_layer = num_layer 23 | self.drop_ratio = drop_ratio 24 | self.JK = JK 25 | self.emb_dim = emb_dim 26 | self.num_vocab = num_vocab 27 | self.max_seq_len = max_seq_len 28 | self.graph_pooling = graph_pooling 29 | 30 | if self.num_layer < 2: 31 | raise ValueError("Number of GNN layers must be greater than 1.") 32 | 33 | ### GNN to generate node embeddings 34 | if virtual_node: 35 | self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 36 | else: 37 | self.gnn_node = GNN_node(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 38 | 39 | 40 | ### Pooling function to generate whole-graph embeddings 41 | if self.graph_pooling == "sum": 42 | self.pool = global_add_pool 43 | elif self.graph_pooling == "mean": 44 | self.pool = global_mean_pool 45 | elif self.graph_pooling == "max": 46 | self.pool = global_max_pool 47 | elif self.graph_pooling == "attention": 48 | self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) 49 | elif self.graph_pooling == "set2set": 50 | self.pool = Set2Set(emb_dim, processing_steps = 2) 51 | else: 52 | raise ValueError("Invalid graph pooling type.") 53 | 54 | self.graph_pred_linear_list = torch.nn.ModuleList() 55 | 56 | if graph_pooling 
== "set2set": 57 | for i in range(max_seq_len): 58 | self.graph_pred_linear_list.append(torch.nn.Linear(2*emb_dim, self.num_vocab)) 59 | 60 | else: 61 | for i in range(max_seq_len): 62 | self.graph_pred_linear_list.append(torch.nn.Linear(emb_dim, self.num_vocab)) 63 | 64 | self.proj_head = torch.nn.ModuleList([torch.nn.Sequential(torch.nn.Linear(self.emb_dim, self.emb_dim), torch.nn.ReLU(inplace=True), torch.nn.Linear(self.emb_dim, self.emb_dim)) for _ in range(5)]) 65 | 66 | def forward(self, batched_data): 67 | ''' 68 | Return: 69 | A list of predictions. 70 | i-th element represents prediction at i-th position of the sequence. 71 | ''' 72 | 73 | h_node = self.gnn_node(batched_data) 74 | 75 | h_graph = self.pool(h_node, batched_data.batch) 76 | 77 | pred_list = [] 78 | # for i in range(self.max_seq_len): 79 | # pred_list.append(self.graph_pred_mlp_list[i](h_graph)) 80 | 81 | for i in range(self.max_seq_len): 82 | pred_list.append(self.graph_pred_linear_list[i](h_graph)) 83 | 84 | return pred_list 85 | 86 | def forward_cl(self, batched_data, n_proj): 87 | h_node = self.gnn_node(batched_data) 88 | 89 | h_graph = self.pool(h_node, batched_data.batch) 90 | z = self.proj_head[n_proj](h_graph) 91 | return z 92 | 93 | def loss_cl(self, x1, x2): 94 | T = 0.5 95 | batch_size, _ = x1.size() 96 | 97 | x1_abs = x1.norm(dim=1) 98 | x2_abs = x2.norm(dim=1) 99 | 100 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 101 | sim_matrix = torch.exp(sim_matrix / T) 102 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 103 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 104 | loss = - torch.log(loss).mean() 105 | return loss 106 | 107 | 108 | if __name__ == '__main__': 109 | pass 110 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/main_pretrain_graphcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from gnn import GNN 7 | 8 | from tqdm import tqdm 9 | import argparse 10 | import time 11 | import numpy as np 12 | import pandas as pd 13 | import os 14 | 15 | ### importing OGB 16 | from dataset_aug import PygGraphPropPredDataset, collate 17 | from ogb.graphproppred import Evaluator 18 | 19 | ### importing utils 20 | from utils import ASTNodeEncoder, get_vocab_mapping 21 | ### for data transform 22 | from utils import augment_edge, encode_y_to_arr, decode_arr_to_seq 23 | 24 | 25 | multicls_criterion = torch.nn.CrossEntropyLoss() 26 | 27 | def train(model, device, loader, optimizer): 28 | model.train() 29 | 30 | loss_accum = 0 31 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 32 | batch1, batch2 = batch 33 | batch1, batch2 = batch1.to(device), batch2.to(device) 34 | 35 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 36 | pass 37 | else: 38 | x1, x2 = model.forward_cl(batch1), model.forward_cl(batch2) 39 | optimizer.zero_grad() 40 | loss = model.loss_cl(x1, x2) 41 | loss.backward() 42 | optimizer.step() 43 | 44 | 45 | def main(): 46 | # Training settings 47 | parser = argparse.ArgumentParser(description='GNN baselines on ogbg-code data with Pytorch Geometrics') 48 | parser.add_argument('--device', type=int, default=0, 49 | help='which gpu to use if any (default: 0)') 50 | parser.add_argument('--gnn', type=str, default='gcn-virtual', 51 | help='GNN gin, gin-virtual, or 
gcn, or gcn-virtual (default: gcn-virtual)') 52 | parser.add_argument('--drop_ratio', type=float, default=0, 53 | help='dropout ratio (default: 0)') 54 | parser.add_argument('--max_seq_len', type=int, default=5, 55 | help='maximum sequence length to predict (default: 5)') 56 | parser.add_argument('--num_vocab', type=int, default=5000, 57 | help='the number of vocabulary used for sequence prediction (default: 5000)') 58 | parser.add_argument('--num_layer', type=int, default=5, 59 | help='number of GNN message passing layers (default: 5)') 60 | parser.add_argument('--emb_dim', type=int, default=300, 61 | help='dimensionality of hidden units in GNNs (default: 300)') 62 | parser.add_argument('--batch_size', type=int, default=128, 63 | help='input batch size for training (default: 128)') 64 | parser.add_argument('--epochs', type=int, default=30, 65 | help='number of epochs to train (default: 30)') 66 | parser.add_argument('--num_workers', type=int, default=0, 67 | help='number of workers (default: 0)') 68 | parser.add_argument('--dataset', type=str, default="ogbg-code", 69 | help='dataset name (default: ogbg-code)') 70 | 71 | parser.add_argument('--filename', type=str, default="", 72 | help='filename to output result (default: )') 73 | args = parser.parse_args() 74 | 75 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 76 | 77 | ### automatic dataloading and splitting 78 | dataset = PygGraphPropPredDataset(name=args.dataset, mode='graphcl') 79 | 80 | split_idx = dataset.get_idx_split() 81 | 82 | vocab2idx, idx2vocab = get_vocab_mapping([dataset.data.y[i] for i in split_idx['train']], args.num_vocab) 83 | 84 | train_loader = torch.utils.data.DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate) 85 | 86 | nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz')) 87 | nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz')) 88 | 89 | ### Encoding node features into emb_dim vectors. 90 | ### The following three node features are used. 91 | # 1. node type 92 | # 2. node attribute 93 | # 3. 
node depth 94 | node_encoder = ASTNodeEncoder(args.emb_dim, num_nodetypes = len(nodetypes_mapping['type']), num_nodeattributes = len(nodeattributes_mapping['attr']), max_depth = 20) 95 | 96 | if args.gnn == 'gin': 97 | model = GNN(num_vocab = len(vocab2idx), max_seq_len = args.max_seq_len, node_encoder = node_encoder, num_layer = args.num_layer, gnn_type = 'gin', emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 98 | elif args.gnn == 'gin-virtual': 99 | model = GNN(num_vocab = len(vocab2idx), max_seq_len = args.max_seq_len, node_encoder = node_encoder, num_layer = args.num_layer, gnn_type = 'gin', emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 100 | elif args.gnn == 'gcn': 101 | model = GNN(num_vocab = len(vocab2idx), max_seq_len = args.max_seq_len, node_encoder = node_encoder, num_layer = args.num_layer, gnn_type = 'gcn', emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 102 | elif args.gnn == 'gcn-virtual': 103 | model = GNN(num_vocab = len(vocab2idx), max_seq_len = args.max_seq_len, node_encoder = node_encoder, num_layer = args.num_layer, gnn_type = 'gcn', emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 104 | else: 105 | raise ValueError('Invalid GNN type') 106 | 107 | optimizer = optim.Adam(model.parameters(), lr=0.001) 108 | 109 | for epoch in range(1, args.epochs + 1): 110 | print("=====Epoch {}".format(epoch)) 111 | print('Training...') 112 | train(model, device, train_loader, optimizer) 113 | 114 | if epoch % 10 == 0: 115 | torch.save(model.state_dict(), './weights/' + args.dataset + '_graphcl_' + str(epoch) + '.pt') 116 | 117 | print('Finished training!') 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/py2graph.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import networkx as nx 3 | import os 4 | import pandas as pd 5 | import numpy as np 6 | 7 | class OGB_ASTWalker(ast.NodeVisitor): 8 | def __init__(self): 9 | self.node_id = 0 10 | self.stack = [] 11 | self.graph = nx.Graph() 12 | self.nodes = {} 13 | 14 | def generic_visit(self, node): 15 | node_name = self.node_id 16 | self.node_id += 1 17 | 18 | # if available, extract AST node attributes 19 | name = getattr(node, 'name', None) 20 | arg = getattr(node, 'arg', None) 21 | s = getattr(node, 's', None) 22 | n = getattr(node, 'n', None) 23 | id_ = getattr(node, 'id', None) 24 | attr = getattr(node, 'attr', None) 25 | 26 | values = [name, arg, s, n, id_, attr] 27 | node_value = next((str(value) for value in values if value is not None), None) 28 | if isinstance(node_value, str): 29 | node_value = node_value.encode('utf-8', errors='surrogatepass') 30 | 31 | # encapsulate all node features in a dict 32 | self.nodes[node_name] = {'type': type(node).__name__, 33 | 'attribute': node_value, 34 | 'attributed': True if node_value != None else False, 35 | 'depth': len(self.stack), 36 | 'dfs_order': node_name} 37 | 38 | # DFS traversal logic 39 | parent_name = None 40 | if self.stack: 41 | parent_name = self.stack[-1] 42 | self.stack.append(node_name) 43 | self.graph.add_node(node_name) 44 | if parent_name != None: 45 | # replicate AST as NetworkX object 46 | self.graph.add_edge(node_name, parent_name) 47 | super().generic_visit(node) 48 | self.stack.pop() 49 | 50 | 51 | def py2graph_helper(code, attr2idx, type2idx): 
52 | ''' 53 | Input: 54 | code: code snippet 55 | 56 | Mappers: 57 | attr_mapping: mapping from attribute to integer idx 58 | type_mapping: mapping from type to integer idx 59 | 60 | Output: OGB graph object 61 | ''' 62 | 63 | tree = ast.parse(code) 64 | walker = OGB_ASTWalker() 65 | walker.visit(tree) 66 | 67 | ast_nodes, ast_edges = walker.nodes, walker.graph.edges() 68 | 69 | data = dict() 70 | data['edge_index'] = np.array([[i, j] for i, j in ast_edges]).transpose() 71 | 72 | # first dim: type 73 | # second dim: attr 74 | 75 | # meta-info 76 | # dfs_order: integer 77 | # attributed: 0 or 1 78 | 79 | node_feat = [] 80 | dfs_order = [] 81 | depth = [] 82 | attributed = [] 83 | for i in range(len(ast_nodes)): 84 | typ = ast_nodes[i]['type'] if ast_nodes[i]['type'] in type2idx else '__UNK__' 85 | 86 | if ast_nodes[i]['attributed']: 87 | attr = ast_nodes[i]['attribute'].decode('UTF-8') if ast_nodes[i]['attribute'].decode('UTF-8') in attr2idx else '__UNK__' 88 | else: 89 | attr = '__NONE__' 90 | 91 | node_feat.append([type2idx[typ], attr2idx[attr]]) 92 | 93 | dfs_order.append(ast_nodes[i]['dfs_order']) 94 | depth.append(ast_nodes[i]['depth']) 95 | attributed.append(ast_nodes[i]['attributed']) 96 | 97 | ### meta-information 98 | data['node_feat'] = np.array(node_feat, dtype = np.int64) 99 | data['node_dfs_order'] = np.array(dfs_order, dtype = np.int64).reshape(-1,1) 100 | data['node_depth'] = np.array(depth, dtype = np.int64).reshape(-1,1) 101 | data['node_is_attributed'] = np.array(attributed, dtype = np.int64).reshape(-1,1) 102 | 103 | data['num_nodes'] = len(data['node_feat']) 104 | data['num_edges'] = len(data['edge_index'][0]) 105 | 106 | return data 107 | 108 | def test_transform(py2graph): 109 | code = ''' 110 | from ogb.graphproppred import PygGraphPropPredDataset 111 | from torch_geometric.data import DataLoader 112 | 113 | dataset = PygGraphPropPredDataset(name = "ogbg-molhiv") 114 | 115 | split_idx = dataset.get_idx_split() 116 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True) 117 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False) 118 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False) 119 | ''' 120 | 121 | graph = py2graph(code) 122 | print(graph) 123 | 124 | invalid_code = ''' 125 | import antigravity 126 | xkcd loves Python 127 | ''' 128 | 129 | try: 130 | graph = py2graph(invalid_code) 131 | except SyntaxError: 132 | print('Successfully caught syntax error') 133 | 134 | 135 | if __name__ == "__main__": 136 | mapping_dir = 'dataset/ogbg_code_pyg/mapping' 137 | 138 | attr_mapping = dict() 139 | type_mapping = dict() 140 | 141 | for line in pd.read_csv(os.path.join(mapping_dir, 'attridx2attr.csv.gz')).values: 142 | attr_mapping[line[1]] = int(line[0]) 143 | 144 | for line in pd.read_csv(os.path.join(mapping_dir, 'typeidx2type.csv.gz')).values: 145 | type_mapping[line[1]] = int(line[0]) 146 | 147 | py2graph = lambda py: py2graph_helper(py, attr_mapping, type_mapping) 148 | test_transform(py2graph) 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/results/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_OGB/code/results/place_holder.txt -------------------------------------------------------------------------------- 
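Before the run scripts, one note on the pre-training objective shared by `gnn.py` and `gnn_proj.py` above: `loss_cl` is the NT-Xent contrastive loss at temperature 0.5, with positives on the diagonal of the view-to-view similarity matrix. A self-contained restatement (the function name `nt_xent` is ours) that can be used to sanity-check shapes:

```python
import torch

def nt_xent(x1, x2, T=0.5):
    # cosine similarity between every embedding of view 1 and view 2
    sim = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1.norm(dim=1), x2.norm(dim=1))
    sim = torch.exp(sim / T)
    pos = sim.diag()  # the i-th graph under the two augmentations
    return -torch.log(pos / (sim.sum(dim=1) - pos)).mean()

z1, z2 = torch.randn(32, 300), torch.randn(32, 300)  # batch of 32, emb_dim 300
print(nt_xent(z1, z2))  # scalar loss
```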
/semisupervised_OGB/code/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for number in {1..10} 4 | do 5 | python finetune.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/run_joaov2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for number in {1..10} 4 | do 5 | python finetune_joaov2.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB/code/weights/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_OGB/code/weights/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn.functional as F 5 | from torch_geometric.nn.inits import uniform 6 | 7 | from conv import GNN_node, GNN_node_Virtualnode 8 | 9 | from torch_scatter import scatter_mean 10 | 11 | class GNN(torch.nn.Module): 12 | 13 | def __init__(self, num_class, num_layer = 5, emb_dim = 300, 14 | gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"): 15 | ''' 16 | num_tasks (int): number of labels to be predicted 17 | virtual_node (bool): whether to add virtual node or not 18 | ''' 19 | 20 | super(GNN, self).__init__() 21 | 22 | self.num_layer = num_layer 23 | self.drop_ratio = drop_ratio 24 | self.JK = JK 25 | self.emb_dim = emb_dim 26 | self.num_class = num_class 27 | self.graph_pooling = graph_pooling 28 | 29 | if self.num_layer < 2: 30 | raise ValueError("Number of GNN layers must be greater than 1.") 31 | 32 | ### GNN to generate node embeddings 33 | if virtual_node: 34 | self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 35 | else: 36 | self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 37 | 38 | 39 | ### Pooling function to generate whole-graph embeddings 40 | if self.graph_pooling == "sum": 41 | self.pool = global_add_pool 42 | elif self.graph_pooling == "mean": 43 | self.pool = global_mean_pool 44 | elif self.graph_pooling == "max": 45 | self.pool = global_max_pool 46 | elif self.graph_pooling == "attention": 47 | self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) 48 | elif self.graph_pooling == "set2set": 49 | self.pool = Set2Set(emb_dim, processing_steps = 2) 50 | else: 51 | raise ValueError("Invalid graph pooling type.") 52 | 53 | if graph_pooling == "set2set": 54 | self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_class) 55 | else: 56 | self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_class) 57 | 58 | self.proj_head = 
torch.nn.Sequential(torch.nn.Linear(self.emb_dim, self.emb_dim), torch.nn.ReLU(inplace=True), torch.nn.Linear(self.emb_dim, self.emb_dim)) 59 | 60 | def forward(self, batched_data): 61 | h_node = self.gnn_node(batched_data) 62 | 63 | h_graph = self.pool(h_node, batched_data.batch) 64 | 65 | return self.graph_pred_linear(h_graph) 66 | 67 | def forward_cl(self, batched_data): 68 | h_node = self.gnn_node(batched_data) 69 | 70 | h_graph = self.pool(h_node, batched_data.batch) 71 | z = self.proj_head(h_graph) 72 | return z 73 | 74 | def loss_cl(self, x1, x2): 75 | T = 0.5 76 | batch_size, _ = x1.size() 77 | 78 | x1_abs = x1.norm(dim=1) 79 | x2_abs = x2.norm(dim=1) 80 | 81 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 82 | sim_matrix = torch.exp(sim_matrix / T) 83 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 84 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 85 | loss = - torch.log(loss).mean() 86 | return loss 87 | 88 | 89 | if __name__ == '__main__': 90 | GNN(num_class = 10) 91 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/gnn_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn.functional as F 5 | from torch_geometric.nn.inits import uniform 6 | 7 | from conv import GNN_node, GNN_node_Virtualnode 8 | 9 | from torch_scatter import scatter_mean 10 | 11 | class GNN(torch.nn.Module): 12 | 13 | def __init__(self, num_class, num_layer = 5, emb_dim = 300, 14 | gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"): 15 | ''' 16 | num_tasks (int): number of labels to be predicted 17 | virtual_node (bool): whether to add virtual node or not 18 | ''' 19 | 20 | super(GNN, self).__init__() 21 | 22 | self.num_layer = num_layer 23 | self.drop_ratio = drop_ratio 24 | self.JK = JK 25 | self.emb_dim = emb_dim 26 | self.num_class = num_class 27 | self.graph_pooling = graph_pooling 28 | 29 | if self.num_layer < 2: 30 | raise ValueError("Number of GNN layers must be greater than 1.") 31 | 32 | ### GNN to generate node embeddings 33 | if virtual_node: 34 | self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 35 | else: 36 | self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 37 | 38 | 39 | ### Pooling function to generate whole-graph embeddings 40 | if self.graph_pooling == "sum": 41 | self.pool = global_add_pool 42 | elif self.graph_pooling == "mean": 43 | self.pool = global_mean_pool 44 | elif self.graph_pooling == "max": 45 | self.pool = global_max_pool 46 | elif self.graph_pooling == "attention": 47 | self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) 48 | elif self.graph_pooling == "set2set": 49 | self.pool = Set2Set(emb_dim, processing_steps = 2) 50 | else: 51 | raise ValueError("Invalid graph pooling type.") 52 | 53 | if graph_pooling == "set2set": 54 | self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_class) 55 | else: 56 | self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_class) 57 | 58 | 
self.proj_head = torch.nn.ModuleList([torch.nn.Sequential(torch.nn.Linear(self.emb_dim, self.emb_dim), torch.nn.ReLU(inplace=True), torch.nn.Linear(self.emb_dim, self.emb_dim)) for _ in range(4)]) 59 | 60 | def forward(self, batched_data): 61 | h_node = self.gnn_node(batched_data) 62 | 63 | h_graph = self.pool(h_node, batched_data.batch) 64 | 65 | return self.graph_pred_linear(h_graph) 66 | 67 | def forward_cl(self, batched_data, n_proj): 68 | h_node = self.gnn_node(batched_data) 69 | 70 | h_graph = self.pool(h_node, batched_data.batch) 71 | z = self.proj_head[n_proj](h_graph) 72 | return z 73 | 74 | def loss_cl(self, x1, x2): 75 | T = 0.5 76 | batch_size, _ = x1.size() 77 | 78 | x1_abs = x1.norm(dim=1) 79 | x2_abs = x2.norm(dim=1) 80 | 81 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 82 | sim_matrix = torch.exp(sim_matrix / T) 83 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 84 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 85 | loss = - torch.log(loss).mean() 86 | return loss 87 | 88 | 89 | if __name__ == '__main__': 90 | GNN(num_class = 10) 91 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/main_pretrain_graphcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn import GNN 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | 12 | ### importing OGB 13 | from dataset_aug import PygGraphPropPredDataset, collate 14 | from ogb.graphproppred import Evaluator 15 | 16 | 17 | multicls_criterion = torch.nn.CrossEntropyLoss() 18 | 19 | def train(model, device, loader, optimizer): 20 | model.train() 21 | 22 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 23 | batch1, batch2 = batch 24 | batch1, batch2 = batch1.to(device), batch2.to(device) 25 | 26 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 27 | pass 28 | else: 29 | x1, x2 = model.forward_cl(batch1), model.forward_cl(batch2) 30 | optimizer.zero_grad() 31 | loss = model.loss_cl(x1, x2) 32 | loss.backward() 33 | optimizer.step() 34 | 35 | 36 | def main(): 37 | # Training settings 38 | parser = argparse.ArgumentParser(description='GNN baselines on ogbg-ppa data with Pytorch Geometrics') 39 | parser.add_argument('--device', type=int, default=0, 40 | help='which gpu to use if any (default: 0)') 41 | parser.add_argument('--gnn', type=str, default='gin-virtual', 42 | help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') 43 | parser.add_argument('--drop_ratio', type=float, default=0.5, 44 | help='dropout ratio (default: 0.5)') 45 | parser.add_argument('--num_layer', type=int, default=5, 46 | help='number of GNN message passing layers (default: 5)') 47 | parser.add_argument('--emb_dim', type=int, default=300, 48 | help='dimensionality of hidden units in GNNs (default: 300)') 49 | parser.add_argument('--batch_size', type=int, default=32, 50 | help='input batch size for training (default: 32)') 51 | parser.add_argument('--epochs', type=int, default=100, 52 | help='number of epochs to train (default: 100)') 53 | parser.add_argument('--num_workers', type=int, default=0, 54 | help='number of workers (default: 0)') 55 | parser.add_argument('--dataset', type=str, default="ogbg-ppa", 56 | help='dataset name (default: ogbg-ppa)') 57 | 58 | parser.add_argument('--aug_ratio', 
type=float, default=0.2) 59 | args = parser.parse_args() 60 | 61 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 62 | 63 | ### automatic dataloading and splitting 64 | 65 | dataset = PygGraphPropPredDataset(name=args.dataset, mode='graphcl', aug_ratio=args.aug_ratio) 66 | 67 | split_idx = dataset.get_idx_split() 68 | 69 | train_loader = torch.utils.data.DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers, collate_fn=collate) 70 | 71 | if args.gnn == 'gin': 72 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 73 | elif args.gnn == 'gin-virtual': 74 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 75 | elif args.gnn == 'gcn': 76 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 77 | elif args.gnn == 'gcn-virtual': 78 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 79 | else: 80 | raise ValueError('Invalid GNN type') 81 | 82 | optimizer = optim.Adam(model.parameters(), lr=0.001) 83 | 84 | valid_curve = [] 85 | test_curve = [] 86 | train_curve = [] 87 | 88 | for epoch in range(1, args.epochs + 1): 89 | print("=====Epoch {}".format(epoch)) 90 | print('Training...') 91 | train(model, device, train_loader, optimizer) 92 | 93 | if epoch % 20 == 0: 94 | torch.save(model.state_dict(), './weights/' + args.dataset + '_graphcl_' + str(args.aug_ratio) + '_' + str(epoch) + '.pt') 95 | 96 | print('Finished training!') 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/main_pretrain_graphcl_joao.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn import GNN 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | 12 | ### importing OGB 13 | from dataset_aug import PygGraphPropPredDataset, collate 14 | from ogb.graphproppred import Evaluator 15 | 16 | 17 | multicls_criterion = torch.nn.CrossEntropyLoss() 18 | 19 | def train(model, device, loader, optimizer, aug_P, gamma): 20 | model.train() 21 | loader.dataset.aug_P = aug_P 22 | 23 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 24 | batch1, batch2 = batch 25 | batch1, batch2 = batch1.to(device), batch2.to(device) 26 | 27 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 28 | pass 29 | else: 30 | x1, x2 = model.forward_cl(batch1), model.forward_cl(batch2) 31 | optimizer.zero_grad() 32 | loss = model.loss_cl(x1, x2) 33 | loss.backward() 34 | optimizer.step() 35 | 36 | # joint augmentation optimization 37 | loss_aug = np.zeros(16) 38 | for n in range(16): 39 | _aug_P = np.zeros(16) 40 | _aug_P[n] = 1 41 | loader.dataset.aug_P = _aug_P 42 | for batch in loader: 43 | batch1, batch2 = batch 44 | batch1, batch2 = batch1.to(device), batch2.to(device) 45 | 46 | if batch1.x.shape[0] == 1 or 
batch1.batch[-1] == 0: 47 | pass 48 | else: 49 | x1, x2 = model.forward_cl(batch1), model.forward_cl(batch2) 50 | loss = model.loss_cl(x1, x2) 51 | loss_aug[n] = loss.item() 52 | break 53 | 54 | gamma = gamma 55 | beta = 1 56 | b = aug_P + beta * (loss_aug - gamma * (aug_P - 1/16)) 57 | 58 | mu_min, mu_max = b.min()-1/16, b.max()-1/16 59 | mu = (mu_min + mu_max) / 2 60 | # bisection method 61 | while abs(np.maximum(b-mu, 0).sum() - 1) > 1e-2: 62 | if np.maximum(b-mu, 0).sum() > 1: 63 | mu_min = mu 64 | else: 65 | mu_max = mu 66 | mu = (mu_min + mu_max) / 2 67 | 68 | aug_P = np.maximum(b-mu, 0) 69 | aug_P /= aug_P.sum() 70 | print(loss_aug.reshape((4, 4))) 71 | print(aug_P.reshape((4, 4))) 72 | 73 | return aug_P 74 | 75 | 76 | def main(): 77 | # Training settings 78 | parser = argparse.ArgumentParser(description='GNN baselines on ogbg-ppa data with Pytorch Geometrics') 79 | parser.add_argument('--device', type=int, default=0, 80 | help='which gpu to use if any (default: 0)') 81 | parser.add_argument('--gnn', type=str, default='gin-virtual', 82 | help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') 83 | parser.add_argument('--drop_ratio', type=float, default=0.5, 84 | help='dropout ratio (default: 0.5)') 85 | parser.add_argument('--num_layer', type=int, default=5, 86 | help='number of GNN message passing layers (default: 5)') 87 | parser.add_argument('--emb_dim', type=int, default=300, 88 | help='dimensionality of hidden units in GNNs (default: 300)') 89 | parser.add_argument('--batch_size', type=int, default=32, 90 | help='input batch size for training (default: 32)') 91 | parser.add_argument('--epochs', type=int, default=100, 92 | help='number of epochs to train (default: 100)') 93 | parser.add_argument('--num_workers', type=int, default=0, 94 | help='number of workers (default: 0)') 95 | parser.add_argument('--dataset', type=str, default="ogbg-ppa", 96 | help='dataset name (default: ogbg-ppa)') 97 | 98 | parser.add_argument('--aug_ratio', type=float, default=0.2) 99 | parser.add_argument('--gamma', type=float, default=0.1) 100 | args = parser.parse_args() 101 | 102 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 103 | 104 | ### automatic dataloading and splitting 105 | 106 | dataset = PygGraphPropPredDataset(name=args.dataset, mode='sampling', aug_ratio=args.aug_ratio) 107 | 108 | split_idx = dataset.get_idx_split() 109 | 110 | train_loader = torch.utils.data.DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers, collate_fn=collate) 111 | 112 | if args.gnn == 'gin': 113 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 114 | elif args.gnn == 'gin-virtual': 115 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 116 | elif args.gnn == 'gcn': 117 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 118 | elif args.gnn == 'gcn-virtual': 119 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 120 | else: 121 | raise ValueError('Invalid GNN type') 122 | 
123 | optimizer = optim.Adam(model.parameters(), lr=0.001) 124 | 125 | valid_curve = [] 126 | test_curve = [] 127 | train_curve = [] 128 | 129 | aug_P = np.ones(16) / 16 130 | for epoch in range(1, args.epochs + 1): 131 | print("=====Epoch {}".format(epoch)) 132 | print('Training...') 133 | aug_P = train(model, device, train_loader, optimizer, aug_P, args.gamma) 134 | 135 | if epoch % 20 == 0: 136 | torch.save(model.state_dict(), './weights/joao_' + str(args.aug_ratio) + '_' + str(args.gamma) + '_' + str(epoch) + '.pt') 137 | 138 | print('Finished training!') 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/main_pretrain_graphcl_joaov2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn_proj import GNN 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | 12 | ### importing OGB 13 | from dataset_aug import PygGraphPropPredDataset, collate 14 | from ogb.graphproppred import Evaluator 15 | 16 | 17 | multicls_criterion = torch.nn.CrossEntropyLoss() 18 | 19 | def train(model, device, loader, optimizer, aug_P, gamma): 20 | model.train() 21 | loader.dataset.aug_P = aug_P 22 | 23 | n_proj = np.random.choice(16, 1, p=aug_P)[0] 24 | n1_proj, n2_proj = n_proj//4, n_proj%4 25 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 26 | batch1, batch2 = batch 27 | batch1, batch2 = batch1.to(device), batch2.to(device) 28 | 29 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 30 | pass 31 | else: 32 | x1, x2 = model.forward_cl(batch1, n1_proj), model.forward_cl(batch2, n2_proj) 33 | optimizer.zero_grad() 34 | loss = model.loss_cl(x1, x2) 35 | loss.backward() 36 | optimizer.step() 37 | 38 | # joint augmentation optimization 39 | loss_aug = np.zeros(16) 40 | for n in range(16): 41 | _aug_P = np.zeros(16) 42 | _aug_P[n] = 1 43 | loader.dataset.aug_P = _aug_P 44 | n1_proj, n2_proj = n//4, n%4 45 | for batch in loader: 46 | batch1, batch2 = batch 47 | batch1, batch2 = batch1.to(device), batch2.to(device) 48 | 49 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 50 | pass 51 | else: 52 | x1, x2 = model.forward_cl(batch1, n1_proj), model.forward_cl(batch2, n2_proj) 53 | loss = model.loss_cl(x1, x2) 54 | loss_aug[n] = loss.item() 55 | break 56 | 57 | gamma = gamma 58 | beta = 1 59 | b = aug_P + beta * (loss_aug - gamma * (aug_P - 1/16)) 60 | 61 | mu_min, mu_max = b.min()-1/16, b.max()-1/16 62 | mu = (mu_min + mu_max) / 2 63 | # bisection method 64 | while abs(np.maximum(b-mu, 0).sum() - 1) > 1e-2: 65 | if np.maximum(b-mu, 0).sum() > 1: 66 | mu_min = mu 67 | else: 68 | mu_max = mu 69 | mu = (mu_min + mu_max) / 2 70 | 71 | aug_P = np.maximum(b-mu, 0) 72 | aug_P /= aug_P.sum() 73 | print(loss_aug.reshape((4, 4))) 74 | print(aug_P.reshape((4, 4))) 75 | 76 | return aug_P 77 | 78 | 79 | def main(): 80 | # Training settings 81 | parser = argparse.ArgumentParser(description='GNN baselines on ogbg-ppa data with Pytorch Geometrics') 82 | parser.add_argument('--device', type=int, default=0, 83 | help='which gpu to use if any (default: 0)') 84 | parser.add_argument('--gnn', type=str, default='gin-virtual', 85 | help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') 86 | parser.add_argument('--drop_ratio', type=float, default=0.5, 87 | help='dropout 
ratio (default: 0.5)') 88 | parser.add_argument('--num_layer', type=int, default=5, 89 | help='number of GNN message passing layers (default: 5)') 90 | parser.add_argument('--emb_dim', type=int, default=300, 91 | help='dimensionality of hidden units in GNNs (default: 300)') 92 | parser.add_argument('--batch_size', type=int, default=32, 93 | help='input batch size for training (default: 32)') 94 | parser.add_argument('--epochs', type=int, default=100, 95 | help='number of epochs to train (default: 100)') 96 | parser.add_argument('--num_workers', type=int, default=0, 97 | help='number of workers (default: 0)') 98 | parser.add_argument('--dataset', type=str, default="ogbg-ppa", 99 | help='dataset name (default: ogbg-ppa)') 100 | 101 | parser.add_argument('--aug_ratio', type=float, default=0.2) 102 | parser.add_argument('--gamma', type=float, default=0.1) 103 | args = parser.parse_args() 104 | 105 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 106 | 107 | ### automatic dataloading and splitting 108 | 109 | dataset = PygGraphPropPredDataset(name=args.dataset, mode='sampling', aug_ratio=args.aug_ratio) 110 | 111 | split_idx = dataset.get_idx_split() 112 | 113 | train_loader = torch.utils.data.DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers, collate_fn=collate) 114 | 115 | if args.gnn == 'gin': 116 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 117 | elif args.gnn == 'gin-virtual': 118 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 119 | elif args.gnn == 'gcn': 120 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 121 | elif args.gnn == 'gcn-virtual': 122 | model = GNN(gnn_type = 'gcn', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 123 | else: 124 | raise ValueError('Invalid GNN type') 125 | 126 | optimizer = optim.Adam(model.parameters(), lr=0.001) 127 | 128 | valid_curve = [] 129 | test_curve = [] 130 | train_curve = [] 131 | 132 | aug_P = np.ones(16) / 16 133 | for epoch in range(1, args.epochs + 1): 134 | print("=====Epoch {}".format(epoch)) 135 | print('Training...') 136 | aug_P = train(model, device, train_loader, optimizer, aug_P, args.gamma) 137 | 138 | if epoch % 20 == 0: 139 | torch.save(model.state_dict(), './weights/joaov2_' + str(args.aug_ratio) + '_' + str(args.gamma) + '_' + str(epoch) + '.pt') 140 | 141 | print('Finished training!') 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/results/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_OGB/ppa/results/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for 
number in {1..10} 4 | do 5 | python finetune.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/run_joaov2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for number in {1..10} 4 | do 5 | python finetune_joaov2.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB/ppa/weights/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_OGB/ppa/weights/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_OGB_LP/README.md: -------------------------------------------------------------------------------- 1 | ### LP-InfoMin/InfoBN/Info(Min+BN) Pre-Training: ### 2 | 3 | ``` 4 | cd ./code 5 | python main_pretrain_generative_infomin.py 6 | python main_pretrain_generative_infobn.py 7 | python main_pretrain_generative_infominbn.py 8 | cd ../ppa 9 | python main_pretrain_generative_infomin.py 10 | python main_pretrain_generative_infobn.py 11 | python main_pretrain_generative_infominbn.py 12 | ``` 13 | 14 | 15 | ### LP-InfoMin/InfoBN/Info(Min+BN) Finetuning: ### 16 | 17 | ``` 18 | cd ./code 19 | ./finetune.sh ${N_SPLIT} ${RESULT_FILE} ${MODEL_PATH} 20 | cd ../ppa 21 | ./finetune.sh ${N_SPLIT} ${RESULT_FILE} ${MODEL_PATH} 22 | ``` 23 | 24 | ```N_SPLIT``` can be 100 or 10 for the 1% or 10% label rate, ```RESULT_FILE``` is the file in which to store the results, and ```MODEL_PATH``` is the path to the saved pre-trained model. 25 | 26 | 27 | ## Acknowledgements 28 | 29 | The backbone implementation is adapted from https://github.com/snap-stanford/ogb.
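As a worked example complementing the instructions above, the two stages for ogbg-code at the 1% label rate could be chained as follows (the result filename is illustrative, and ```MODEL_PATH``` must point to the checkpoint written by the pre-training script), with the mean/std over the 10 finetuning runs computed by the `cal.py` helper shown below:

```
cd ./code
python main_pretrain_generative_infomin.py
./finetune.sh 100 ./results/lp_infomin.log ${MODEL_PATH}
python cal.py ./results/lp_infomin.log
```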
30 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/code/cal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | 5 | fn = str(sys.argv[1]) 6 | with open(fn, 'r') as f: 7 | data = f.read().split('\n')[:-1] 8 | 9 | val_res = [float(d.split()[0])*100 for d in data] 10 | test_res = [float(d.split()[1])*100 for d in data] 11 | 12 | 13 | print('val', np.mean(val_res), np.std(val_res)) 14 | print('test', np.mean(test_res), np.std(test_res)) 15 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/code/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for number in {1..10} 4 | do 5 | python main_pyg.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/code/py2graph.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import networkx as nx 3 | import os 4 | import pandas as pd 5 | import numpy as np 6 | 7 | class OGB_ASTWalker(ast.NodeVisitor): 8 | def __init__(self): 9 | self.node_id = 0 10 | self.stack = [] 11 | self.graph = nx.Graph() 12 | self.nodes = {} 13 | 14 | def generic_visit(self, node): 15 | node_name = self.node_id 16 | self.node_id += 1 17 | 18 | # if available, extract AST node attributes 19 | name = getattr(node, 'name', None) 20 | arg = getattr(node, 'arg', None) 21 | s = getattr(node, 's', None) 22 | n = getattr(node, 'n', None) 23 | id_ = getattr(node, 'id', None) 24 | attr = getattr(node, 'attr', None) 25 | 26 | values = [name, arg, s, n, id_, attr] 27 | node_value = next((str(value) for value in values if value is not None), None) 28 | if isinstance(node_value, str): 29 | node_value = node_value.encode('utf-8', errors='surrogatepass') 30 | 31 | # encapsulate all node features in a dict 32 | self.nodes[node_name] = {'type': type(node).__name__, 33 | 'attribute': node_value, 34 | 'attributed': True if node_value != None else False, 35 | 'depth': len(self.stack), 36 | 'dfs_order': node_name} 37 | 38 | # DFS traversal logic 39 | parent_name = None 40 | if self.stack: 41 | parent_name = self.stack[-1] 42 | self.stack.append(node_name) 43 | self.graph.add_node(node_name) 44 | if parent_name != None: 45 | # replicate AST as NetworkX object 46 | self.graph.add_edge(node_name, parent_name) 47 | super().generic_visit(node) 48 | self.stack.pop() 49 | 50 | 51 | def py2graph_helper(code, attr2idx, type2idx): 52 | ''' 53 | Input: 54 | code: code snippet 55 | 56 | Mappers: 57 | attr_mapping: mapping from attribute to integer idx 58 | type_mapping: mapping from type to integer idx 59 | 60 | Output: OGB graph object 61 | ''' 62 | 63 | tree = ast.parse(code) 64 | walker = OGB_ASTWalker() 65 | walker.visit(tree) 66 | 67 | ast_nodes, ast_edges = walker.nodes, walker.graph.edges() 68 | 69 | data = dict() 70 | data['edge_index'] = np.array([[i, j] for i, j in ast_edges]).transpose() 71 | 72 | # first dim: type 73 | # second dim: attr 74 | 75 | # meta-info 76 | # dfs_order: integer 77 | # attributed: 0 or 1 78 | 79 | node_feat = [] 80 | dfs_order = [] 81 | depth = [] 82 | attributed = [] 83 | for i in range(len(ast_nodes)): 84 | typ = ast_nodes[i]['type'] if ast_nodes[i]['type'] in type2idx else '__UNK__' 85 | 86 | if ast_nodes[i]['attributed']: 87 | attr = 
ast_nodes[i]['attribute'].decode('UTF-8') if ast_nodes[i]['attribute'].decode('UTF-8') in attr2idx else '__UNK__' 88 | else: 89 | attr = '__NONE__' 90 | 91 | node_feat.append([type2idx[typ], attr2idx[attr]]) 92 | 93 | dfs_order.append(ast_nodes[i]['dfs_order']) 94 | depth.append(ast_nodes[i]['depth']) 95 | attributed.append(ast_nodes[i]['attributed']) 96 | 97 | ### meta-information 98 | data['node_feat'] = np.array(node_feat, dtype = np.int64) 99 | data['node_dfs_order'] = np.array(dfs_order, dtype = np.int64).reshape(-1,1) 100 | data['node_depth'] = np.array(depth, dtype = np.int64).reshape(-1,1) 101 | data['node_is_attributed'] = np.array(attributed, dtype = np.int64).reshape(-1,1) 102 | 103 | data['num_nodes'] = len(data['node_feat']) 104 | data['num_edges'] = len(data['edge_index'][0]) 105 | 106 | return data 107 | 108 | def test_transform(py2graph): 109 | code = ''' 110 | from ogb.graphproppred import PygGraphPropPredDataset 111 | from torch_geometric.data import DataLoader 112 | 113 | dataset = PygGraphPropPredDataset(name = "ogbg-molhiv") 114 | 115 | split_idx = dataset.get_idx_split() 116 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True) 117 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False) 118 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False) 119 | ''' 120 | 121 | graph = py2graph(code) 122 | print(graph) 123 | 124 | invalid_code = ''' 125 | import antigravity 126 | xkcd loves Python 127 | ''' 128 | 129 | try: 130 | graph = py2graph(invalid_code) 131 | except SyntaxError: 132 | print('Successfully caught syntax error') 133 | 134 | 135 | if __name__ == "__main__": 136 | mapping_dir = 'dataset/ogbg_code_pyg/mapping' 137 | 138 | attr_mapping = dict() 139 | type_mapping = dict() 140 | 141 | for line in pd.read_csv(os.path.join(mapping_dir, 'attridx2attr.csv.gz')).values: 142 | attr_mapping[line[1]] = int(line[0]) 143 | 144 | for line in pd.read_csv(os.path.join(mapping_dir, 'typeidx2type.csv.gz')).values: 145 | type_mapping[line[1]] = int(line[0]) 146 | 147 | py2graph = lambda py: py2graph_helper(py, attr_mapping, type_mapping) 148 | test_transform(py2graph) 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/ppa/cal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | 5 | fn = str(sys.argv[1]) 6 | with open(fn, 'r') as f: 7 | data = f.read().split('\n')[:-1] 8 | 9 | val_res = [float(d.split()[0])*100 for d in data] 10 | test_res = [float(d.split()[1])*100 for d in data] 11 | 12 | 13 | print('val', np.mean(val_res), np.std(val_res)) 14 | print('test', np.mean(test_res), np.std(test_res)) 15 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/ppa/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for number in {1..10} 4 | do 5 | python main_pyg.py --gnn gin --num_workers 8 --n_splits $1 --pretrain $2 --pretrain_weight $3 6 | done 7 | -------------------------------------------------------------------------------- /semisupervised_OGB_LP/ppa/main_pretrain_generative_infomin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn import GNN, 
vgae 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | 12 | ### importing OGB 13 | from dataset_graphcl import PygGraphPropPredDataset, collate 14 | from ogb.graphproppred import Evaluator 15 | from copy import deepcopy 16 | 17 | 18 | multicls_criterion = torch.nn.CrossEntropyLoss() 19 | 20 | def train(model, optimizer, generator1, optimizer_gen1, generator2, optimizer_gen2, device, loader): 21 | model.train() 22 | loss_pretrain, loss_generator = 0, 0 23 | # for step, batch in enumerate(tqdm(loader, desc="Iteration")): 24 | for step, batch in enumerate(loader): 25 | batch, batch1, batch2 = batch 26 | batch, batch1, batch2 = batch.to(device), batch1.to(device), batch2.to(device) 27 | 28 | # 1. graphcl 29 | if batch1.x.shape[0] == 1 or batch1.batch[-1] == 0: 30 | pass 31 | else: 32 | x1, x2 = model.forward_cl(batch1), model.forward_cl(batch2) 33 | optimizer.zero_grad() 34 | loss_cl = model.loss_cl(x1, x2, mean=False) 35 | loss = loss_cl.mean() 36 | 37 | loss.backward() 38 | optimizer.step() 39 | loss_pretrain += loss.item() 40 | 41 | # reward for joao 42 | loss_cl = loss_cl.detach() 43 | loss_cl = loss_cl - loss_cl.mean() 44 | loss_cl[loss_cl>0] = 1 45 | loss_cl[loss_cl<=0] = 0.01 # weaken the reward for low cl loss 46 | 47 | # 2. joao 48 | optimizer_gen1.zero_grad() 49 | optimizer_gen2.zero_grad() 50 | 51 | x, x_mean, x_std = generator1.forward_encoder(batch) 52 | edge_attr_pred, edge_pos_pred, edge_neg_pred = generator1.forward_decoder(x, batch.edge_index, batch.edge_index_neg) 53 | loss_1 = generator1.loss_vgae(edge_attr_pred, batch.edge_attr, edge_pos_pred, edge_neg_pred, batch.edge_index_batch, batch.edge_index_neg_batch, x_mean, x_std, batch.batch, reward=loss_cl) 54 | 55 | x, x_mean, x_std = generator2.forward_encoder(batch) 56 | edge_attr_pred, edge_pos_pred, edge_neg_pred = generator2.forward_decoder(x, batch.edge_index, batch.edge_index_neg) 57 | loss_2 = generator2.loss_vgae(edge_attr_pred, batch.edge_attr, edge_pos_pred, edge_neg_pred, batch.edge_index_batch, batch.edge_index_neg_batch, x_mean, x_std, batch.batch, reward=loss_cl) 58 | 59 | loss = loss_1 + loss_2 60 | 61 | loss.backward() 62 | optimizer_gen1.step() 63 | optimizer_gen2.step() 64 | loss_generator += loss.item() 65 | 66 | print(loss_pretrain/(step+1), loss_generator/(step+1)) # enumerate is zero-based: average over all batches 67 | 68 | 69 | def main(): 70 | # Training settings 71 | parser = argparse.ArgumentParser(description='GNN baselines on ogbg-ppa data with PyTorch Geometric') 72 | parser.add_argument('--device', type=int, default=0, 73 | help='which gpu to use if any (default: 0)') 74 | parser.add_argument('--drop_ratio', type=float, default=0.5, 75 | help='dropout ratio (default: 0.5)') 76 | parser.add_argument('--num_layer', type=int, default=5, 77 | help='number of GNN message passing layers (default: 5)') 78 | parser.add_argument('--emb_dim', type=int, default=300, 79 | help='dimensionality of hidden units in GNNs (default: 300)') 80 | parser.add_argument('--batch_size', type=int, default=32, 81 | help='input batch size for training (default: 32)') 82 | parser.add_argument('--epochs', type=int, default=100, 83 | help='number of epochs to train (default: 100)') 84 | parser.add_argument('--num_workers', type=int, default=24, 85 | help='number of workers (default: 24)') 86 | parser.add_argument('--dataset', type=str, default="ogbg-ppa", 87 | help='dataset name (default: ogbg-ppa)') 88 | 89 | parser.add_argument('--aug_mode', type=str, default='generative') 90 | parser.add_argument('--aug_strength', type=float, default=0.2)
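# NOTE: --resume > 0 resumes pre-training from ./weights_generative_joao/checkpoint_<resume>.pt
# (one checkpoint is saved per epoch, see main() below); set --resume 0 to train from scratch.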
91 | 92 | parser.add_argument('--resume', type=int, default=44) 93 | args = parser.parse_args() 94 | 95 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 96 | 97 | ### automatic dataloading and splitting 98 | 99 | dataset = PygGraphPropPredDataset(name=args.dataset) 100 | dataset.set_augMode(args.aug_mode) 101 | dataset.set_augStrength(args.aug_strength) 102 | 103 | split_idx = dataset.get_idx_split() 104 | 105 | loader = torch.utils.data.DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers, collate_fn=collate) 106 | 107 | model = GNN(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 108 | optimizer = optim.Adam(model.parameters(), lr=0.001) 109 | 110 | # generators 111 | generator1 = vgae(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 112 | optimizer_gen1 = optim.Adam(generator1.parameters(), lr=0.001) 113 | generator2 = vgae(gnn_type = 'gin', num_class = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 114 | optimizer_gen2 = optim.Adam(generator2.parameters(), lr=0.001) 115 | 116 | if not args.resume == 0: 117 | checkpoint = torch.load('./weights_generative_joao/checkpoint_'+str(args.resume)+'.pt') 118 | model.load_state_dict(checkpoint['graphcl']) 119 | optimizer.load_state_dict(checkpoint['graphcl_optimizer']) 120 | generator1.load_state_dict(checkpoint['generator1']) 121 | optimizer_gen1.load_state_dict(checkpoint['generator1_optimizer']) 122 | generator2.load_state_dict(checkpoint['generator2']) 123 | optimizer_gen2.load_state_dict(checkpoint['generator2_optimizer']) 124 | 125 | for epoch in range(args.resume+1, args.epochs + 1): 126 | loader.dataset.set_generator(deepcopy(generator1).cpu(), deepcopy(generator2).cpu()) 127 | train(model, optimizer, generator1, optimizer_gen1, generator2, optimizer_gen2, device, loader) 128 | 129 | torch.save({'graphcl':model.state_dict(), 'graphcl_optimizer': optimizer.state_dict(), 'generator1':generator1.state_dict(), 'generator1_optimizer':optimizer_gen1.state_dict(), 'generator2':generator2.state_dict(), 'generator2_optimizer':optimizer_gen2.state_dict()}, './weights_generative_joao/checkpoint_'+str(epoch)+'.pt') 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | 135 | -------------------------------------------------------------------------------- /semisupervised_TU/README.md: -------------------------------------------------------------------------------- 1 | ### JOAO Pre-Training: ### 2 | 3 | ``` 4 | cd ./pretrain 5 | python main.py --dataset NCI1 --epochs 100 --lr 0.001 --gamma_joao 0.1 --suffix 0 6 | python main.py --dataset NCI1 --epochs 100 --lr 0.001 --gamma_joao 0.1 --suffix 1 7 | python main.py --dataset NCI1 --epochs 100 --lr 0.001 --gamma_joao 0.1 --suffix 2 8 | python main.py --dataset NCI1 --epochs 100 --lr 0.001 --gamma_joao 0.1 --suffix 3 9 | python main.py --dataset NCI1 --epochs 100 --lr 0.001 --gamma_joao 0.1 --suffix 4 10 | ``` 11 | 12 | ### JOAO Finetuning: ### 13 | 14 | ``` 15 | cd ./finetune 16 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 0 --n_splits 100 17 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 
0.001 --pretrain_gamma 0.1 --suffix 1 --n_splits 100 18 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 2 --n_splits 100 19 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 3 --n_splits 100 20 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 4 --n_splits 100 21 | ``` 22 | 23 | The five suffixes correspond to five independent runs, over which mean & std are reported. 24 | 25 | ```lr``` should be tuned over {0.01, 0.001, 0.0001} and ```gamma_joao``` over {0.01, 0.1, 1} in pre-training; in finetuning, ```pretrain_epoch``` (the epoch of the pre-trained checkpoint to load) should be tuned over {20, 40, 60, 80, 100}. 26 | 27 | 28 | ### JOAOv2 Pre-Training: ### 29 | 30 | ``` 31 | cd ./pretrain_joaov2 32 | python main.py --dataset NCI1 --epochs 200 --lr 0.001 --gamma_joao 0.1 --suffix 0 33 | python main.py --dataset NCI1 --epochs 200 --lr 0.001 --gamma_joao 0.1 --suffix 1 34 | python main.py --dataset NCI1 --epochs 200 --lr 0.001 --gamma_joao 0.1 --suffix 2 35 | python main.py --dataset NCI1 --epochs 200 --lr 0.001 --gamma_joao 0.1 --suffix 3 36 | python main.py --dataset NCI1 --epochs 200 --lr 0.001 --gamma_joao 0.1 --suffix 4 37 | ``` 38 | 39 | ### JOAOv2 Finetuning: ### 40 | 41 | ``` 42 | cd ./finetune_joaov2 43 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 0 --n_splits 100 44 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 1 --n_splits 100 45 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 2 --n_splits 100 46 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 3 --n_splits 100 47 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --pretrain_gamma 0.1 --suffix 4 --n_splits 100 48 | ``` 49 | 50 | The five suffixes correspond to five independent runs, over which mean & std are reported. 51 | 52 | ```lr``` should be tuned over {0.01, 0.001, 0.0001} and ```gamma_joao``` over {0.01, 0.1, 1} in pre-training; in finetuning, ```pretrain_epoch``` (the epoch of the pre-trained checkpoint to load) should be tuned over {20, 40, 60, 80, 100, 120, 140, 160, 180, 200}, since multiple projection heads are trained. 53 | 54 | ## Acknowledgements 55 | 56 | The backbone implementation is based on https://github.com/chentingpc/gfn.
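As a sketch, the JOAO pre-training grid described above can be launched with a simple loop (dataset and flags are exactly those of the commands above; each (lr, gamma) cell still gets its five suffixes):

```
cd ./pretrain
for lr in 0.01 0.001 0.0001; do
  for gamma in 0.01 0.1 1; do
    for suffix in 0 1 2 3 4; do
      python main.py --dataset NCI1 --epochs 100 --lr $lr --gamma_joao $gamma --suffix $suffix
    done
  done
done
```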
57 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 | pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | 48 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | `_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample. 25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): Whether or not to normalize the adjacency matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only a linear transform (1x1 conv) is 37 | applied to every node. (default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges.
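# Each self-loop appended by add_self_loops above receives weight 1 (or 2 when
# `improved`), realizing the \hat{A} = A + I (or A + 2I) from the class docstring;
# the symmetric D^{-1/2} \hat{A} D^{-1/2} edge weights are then computed below.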
84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune/results_joao/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU/finetune/results_joao/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_TU/finetune/tu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from itertools import repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch_geometric.utils as tg_utils 9 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 10 | from torch_geometric.io import read_tu_data 11 | 12 | 13 | # TUDataset adapted from torch_geometric==1.1.0 14 | class TUDatasetExt(InMemoryDataset): 15 | r"""A variety of graph kernel benchmark datasets, *e.g.* "IMDB-BINARY", 16 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 17 | `_. 18 | 19 | Args: 20 | root (string): Root directory where the dataset should be saved. 21 | name (string): The `name `_ of 22 | the dataset. 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. (default: :obj:`None`) 31 | pre_filter (callable, optional): A function that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a boolean 33 | value, indicating whether the data object should be included in the 34 | final dataset.
(default: :obj:`None`) 35 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 36 | contain additional continuous node features (if present). 37 | (default: :obj:`False`) 38 | """ 39 | 40 | url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/' \ 41 | 'graphkerneldatasets' 42 | 43 | def __init__(self, 44 | root, 45 | name, 46 | transform=None, 47 | pre_transform=None, 48 | pre_filter=None, 49 | use_node_attr=False, 50 | processed_filename='data.pt'): 51 | self.name = name 52 | self.processed_filename = processed_filename 53 | super(TUDatasetExt, self).__init__(root, transform, pre_transform, 54 | pre_filter) 55 | self.data, self.slices = torch.load(self.processed_paths[0]) 56 | if self.data.x is not None and not use_node_attr: 57 | self.data.x = self.data.x[:, self.num_node_attributes:] 58 | 59 | @property 60 | def num_node_labels(self): 61 | if self.data.x is None: 62 | return 0 63 | for i in range(self.data.x.size(1)): 64 | if self.data.x[:, i:].sum().item() == self.data.x.size(0): 65 | return self.data.x.size(1) - i 66 | return 0 67 | 68 | @property 69 | def num_node_attributes(self): 70 | if self.data.x is None: 71 | return 0 72 | return self.data.x.size(1) - self.num_node_labels 73 | 74 | @property 75 | def raw_file_names(self): 76 | names = ['A', 'graph_indicator'] 77 | return ['{}_{}.txt'.format(self.name, name) for name in names] 78 | 79 | @property 80 | def processed_file_names(self): 81 | return self.processed_filename 82 | 83 | @property 84 | def num_node_features(self): 85 | r"""Returns the number of features per node in the dataset.""" 86 | return self[0].num_node_features 87 | 88 | def download(self): 89 | path = download_url('{}/{}.zip'.format(self.url, self.name), self.root) 90 | extract_zip(path, self.root) 91 | os.unlink(path) 92 | shutil.rmtree(self.raw_dir) 93 | os.rename(osp.join(self.root, self.name), self.raw_dir) 94 | 95 | def process(self): 96 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 97 | 98 | if self.pre_filter is not None: 99 | data_list = [self.get(idx) for idx in range(len(self))] 100 | data_list = [data for data in data_list if self.pre_filter(data)] 101 | self.data, self.slices = self.collate(data_list) 102 | 103 | if self.pre_transform is not None: 104 | data_list = [self.get(idx) for idx in range(len(self))] 105 | data_list = [self.pre_transform(data) for data in data_list] 106 | self.data, self.slices = self.collate(data_list) 107 | 108 | torch.save((self.data, self.slices), self.processed_paths[0]) 109 | 110 | def __repr__(self): 111 | return '{}({})'.format(self.name, len(self)) 112 | 113 | def get(self, idx): 114 | data = self.data.__class__() 115 | if hasattr(self.data, '__num_nodes__'): 116 | data.num_nodes = self.data.__num_nodes__[idx] 117 | for key in self.data.keys: 118 | item, slices = self.data[key], self.slices[key] 119 | if torch.is_tensor(item): 120 | s = list(repeat(slice(None), item.dim())) 121 | s[self.data.__cat_dim__(key, 122 | item)] = slice(slices[idx], 123 | slices[idx + 1]) 124 | else: 125 | s = slice(slices[idx], slices[idx + 1]) 126 | data[key] = item[s] 127 | return data 128 | 129 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def logger(info): 11 | 
fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Accuracy: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune_joaov2/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 | pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | 48 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune_joaov2/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | `_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample. 25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): Whether or not to normalize the adjacency matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only a linear transform (1x1 conv) is 37 | applied to every node. (default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges.
84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune_joaov2/results_joao/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU/finetune_joaov2/results_joao/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_TU/finetune_joaov2/tu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from itertools import repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch_geometric.utils as tg_utils 9 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 10 | from torch_geometric.io import read_tu_data 11 | 12 | 13 | # TUDataset adapted from torch_geometric==1.1.0 14 | class TUDatasetExt(InMemoryDataset): 15 | r"""A variety of graph kernel benchmark datasets, *e.g.* "IMDB-BINARY", 16 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 17 | `_. 18 | 19 | Args: 20 | root (string): Root directory where the dataset should be saved. 21 | name (string): The `name `_ of 22 | the dataset. 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. (default: :obj:`None`) 31 | pre_filter (callable, optional): A function that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a boolean 33 | value, indicating whether the data object should be included in the 34 | final dataset.
(default: :obj:`None`) 35 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 36 | contain additional continuous node features (if present). 37 | (default: :obj:`False`) 38 | """ 39 | 40 | url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/' \ 41 | 'graphkerneldatasets' 42 | 43 | def __init__(self, 44 | root, 45 | name, 46 | transform=None, 47 | pre_transform=None, 48 | pre_filter=None, 49 | use_node_attr=False, 50 | processed_filename='data.pt'): 51 | self.name = name 52 | self.processed_filename = processed_filename 53 | super(TUDatasetExt, self).__init__(root, transform, pre_transform, 54 | pre_filter) 55 | self.data, self.slices = torch.load(self.processed_paths[0]) 56 | if self.data.x is not None and not use_node_attr: 57 | self.data.x = self.data.x[:, self.num_node_attributes:] 58 | 59 | @property 60 | def num_node_labels(self): 61 | if self.data.x is None: 62 | return 0 63 | for i in range(self.data.x.size(1)): 64 | if self.data.x[:, i:].sum().item() == self.data.x.size(0): 65 | return self.data.x.size(1) - i 66 | return 0 67 | 68 | @property 69 | def num_node_attributes(self): 70 | if self.data.x is None: 71 | return 0 72 | return self.data.x.size(1) - self.num_node_labels 73 | 74 | @property 75 | def raw_file_names(self): 76 | names = ['A', 'graph_indicator'] 77 | return ['{}_{}.txt'.format(self.name, name) for name in names] 78 | 79 | @property 80 | def processed_file_names(self): 81 | return self.processed_filename 82 | 83 | @property 84 | def num_node_features(self): 85 | r"""Returns the number of features per node in the dataset.""" 86 | return self[0].num_node_features 87 | 88 | def download(self): 89 | path = download_url('{}/{}.zip'.format(self.url, self.name), self.root) 90 | extract_zip(path, self.root) 91 | os.unlink(path) 92 | shutil.rmtree(self.raw_dir) 93 | os.rename(osp.join(self.root, self.name), self.raw_dir) 94 | 95 | def process(self): 96 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 97 | 98 | if self.pre_filter is not None: 99 | data_list = [self.get(idx) for idx in range(len(self))] 100 | data_list = [data for data in data_list if self.pre_filter(data)] 101 | self.data, self.slices = self.collate(data_list) 102 | 103 | if self.pre_transform is not None: 104 | data_list = [self.get(idx) for idx in range(len(self))] 105 | data_list = [self.pre_transform(data) for data in data_list] 106 | self.data, self.slices = self.collate(data_list) 107 | 108 | torch.save((self.data, self.slices), self.processed_paths[0]) 109 | 110 | def __repr__(self): 111 | return '{}({})'.format(self.name, len(self)) 112 | 113 | def get(self, idx): 114 | data = self.data.__class__() 115 | if hasattr(self.data, '__num_nodes__'): 116 | data.num_nodes = self.data.__num_nodes__[idx] 117 | for key in self.data.keys: 118 | item, slices = self.data[key], self.slices[key] 119 | if torch.is_tensor(item): 120 | s = list(repeat(slice(None), item.dim())) 121 | s[self.data.__cat_dim__(key, 122 | item)] = slice(slices[idx], 123 | slices[idx + 1]) 124 | else: 125 | s = slice(slices[idx], slices[idx + 1]) 126 | data[key] = item[s] 127 | return data 128 | 129 | -------------------------------------------------------------------------------- /semisupervised_TU/finetune_joaov2/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def 
logger(info): 11 | fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Accuracy: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 | pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/experiment_graphcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | 5 | from utils import print_weights 6 | from tqdm import tqdm 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | 11 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 12 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0): 13 | model = model_func(dataset).to(device) 14 | print_weights(model) 15 | if torch.cuda.is_available(): 16 | torch.cuda.synchronize() 17 | 18 | dataset.set_aug_mode(aug_mode) 19 | dataset.set_aug_ratio(aug_ratio) 20 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 21 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 22 | 23 | # for epoch in tqdm(range(1, epochs+1)): 24 | for epoch in range(1, epochs+1): 25 | pretrain_loss = train(loader, model, optimizer, device) 26 | print(pretrain_loss) 27 | 28 | if epoch % 20 
== 0: 29 | weight_path = './weights_graphcl/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(suffix) + '.pt' 30 | torch.save(model.state_dict(), weight_path) 31 | 32 | 33 | def num_graphs(data): 34 | if data.batch is not None: 35 | return data.num_graphs 36 | else: 37 | return data.x.size(0) 38 | 39 | 40 | def train(loader, model, optimizer, device): 41 | model.train() 42 | total_loss = 0 43 | for _, data1, data2 in loader: 44 | # print(data1, data2) 45 | optimizer.zero_grad() 46 | data1 = data1.to(device) 47 | data2 = data2.to(device) 48 | out1 = model.forward_graphcl(data1) 49 | out2 = model.forward_graphcl(data2) 50 | loss = model.loss_graphcl(out1, out2) 51 | loss.backward() 52 | total_loss += loss.item() * num_graphs(data1) 53 | optimizer.step() 54 | 55 | return total_loss/len(loader.dataset) 56 | 57 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/experiment_joao.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | import numpy as np 5 | 6 | from utils import print_weights 7 | from tqdm import tqdm 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 13 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0, gamma_joao=0.1): 14 | model = model_func(dataset).to(device) 15 | print_weights(model) 16 | if torch.cuda.is_available(): 17 | torch.cuda.synchronize() 18 | 19 | dataset.set_aug_mode('sample') 20 | dataset.set_aug_ratio(aug_ratio) 21 | aug_prob = np.ones(25) / 25 22 | dataset.set_aug_prob(aug_prob) 23 | 24 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 25 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 26 | 27 | # for epoch in tqdm(range(1, epochs+1)): 28 | for epoch in range(1, epochs+1): 29 | pretrain_loss, aug_prob = train(loader, model, optimizer, device, gamma_joao) 30 | print(pretrain_loss, aug_prob) 31 | loader.dataset.set_aug_prob(aug_prob) 32 | 33 | if epoch % 20 == 0: 34 | weight_path = './weights_joao/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(gamma_joao) + '_' + str(suffix) + '.pt' 35 | torch.save(model.state_dict(), weight_path) 36 | 37 | 38 | def num_graphs(data): 39 | if data.batch is not None: 40 | return data.num_graphs 41 | else: 42 | return data.x.size(0) 43 | 44 | 45 | def train(loader, model, optimizer, device, gamma_joao): 46 | model.train() 47 | total_loss = 0 48 | for _, data1, data2 in loader: 49 | # print(data1, data2) 50 | optimizer.zero_grad() 51 | data1 = data1.to(device) 52 | data2 = data2.to(device) 53 | out1 = model.forward_graphcl(data1) 54 | out2 = model.forward_graphcl(data2) 55 | loss = model.loss_graphcl(out1, out2) 56 | loss.backward() 57 | total_loss += loss.item() * num_graphs(data1) 58 | optimizer.step() 59 | 60 | aug_prob = joao(loader, model, gamma_joao) 61 | return total_loss/len(loader.dataset), aug_prob 62 | 63 | 64 | def joao(loader, model, gamma_joao): 65 | aug_prob = loader.dataset.aug_prob 66 | # calculate augmentation loss 67 | loss_aug = np.zeros(25) 68 | for n in range(25): 69 | _aug_prob = np.zeros(25) 70 | _aug_prob[n] = 1 71 | loader.dataset.set_aug_prob(_aug_prob) 72 | 73 | count, count_stop = 0, len(loader.dataset)//(loader.batch_size*10)+1 # for efficiency, we only use around 10% of data to estimate the loss 74 | with 
torch.no_grad(): 75 | for _, data1, data2 in loader: 76 | data1 = data1.to(device) 77 | data2 = data2.to(device) 78 | out1 = model.forward_graphcl(data1) 79 | out2 = model.forward_graphcl(data2) 80 | loss = model.loss_graphcl(out1, out2) 81 | loss_aug[n] += loss.item() * num_graphs(data1) 82 | count += 1 83 | if count == count_stop: 84 | break 85 | loss_aug[n] /= (count*loader.batch_size) 86 | 87 | # view selection, projected gradient descent, reference: https://arxiv.org/abs/1906.03563 88 | beta = 1 89 | gamma = gamma_joao 90 | 91 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1/25)) 92 | mu_min, mu_max = b.min()-1/25, b.max()-1/25 93 | mu = (mu_min + mu_max) / 2 94 | 95 | # bisection method 96 | while abs(np.maximum(b-mu, 0).sum() - 1) > 1e-2: 97 | if np.maximum(b-mu, 0).sum() > 1: 98 | mu_min = mu 99 | else: 100 | mu_max = mu 101 | mu = (mu_min + mu_max) / 2 102 | 103 | aug_prob = np.maximum(b-mu, 0) 104 | aug_prob /= aug_prob.sum() 105 | 106 | return aug_prob 107 | 108 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | `_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample. 25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): Whether or not to normalize the adjacency matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only a linear transform (1x1 conv) is 37 | applied to every node.
(default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges. 84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/main.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from datasets import get_dataset 4 | from res_gcn import ResGCN_graphcl, vgae_encoder, vgae_decoder 5 | 6 | import experiment_graphcl, experiment_joao 7 | 8 | 9 | str2bool = lambda x: x.lower() == "true" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_root', type=str, default="datasets") 12 | parser.add_argument('--batch_size', type=int, default=128) 13 | parser.add_argument('--n_layers_feat', type=int, default=1) 14 | parser.add_argument('--n_layers_conv', type=int, default=3) 15 | parser.add_argument('--n_layers_fc', type=int, default=2) 16 | 
parser.add_argument('--hidden', type=int, default=128) 17 | parser.add_argument('--global_pool', type=str, default="sum") 18 | parser.add_argument('--skip_connection', type=str2bool, default=False) 19 | parser.add_argument('--res_branch', type=str, default="BNConvReLU") 20 | parser.add_argument('--dropout', type=float, default=0) 21 | parser.add_argument('--edge_norm', type=str2bool, default=True) 22 | 23 | parser.add_argument('--lr', type=float, default=0.001) 24 | parser.add_argument('--epochs', type=int, default=100) 25 | 26 | parser.add_argument('--dataset', type=str, default="NCI1") 27 | parser.add_argument('--aug_mode', type=str, default="sample") 28 | parser.add_argument('--aug_ratio', type=float, default=0.2) 29 | parser.add_argument('--suffix', type=int, default=0) 30 | 31 | parser.add_argument('--model', type=str, default='joao') 32 | parser.add_argument('--gamma_joao', type=float, default=0.1) 33 | args = parser.parse_args() 34 | 35 | 36 | def create_n_filter_triple(dataset, feat_str, net, gfn_add_ak3=False, 37 | gfn_reall=True, reddit_odeg10=False, 38 | dd_odeg10_ak1=False): 39 | # Add ak3 for GFN. 40 | if gfn_add_ak3 and 'GFN' in net: 41 | feat_str += '+ak3' 42 | # Remove edges for GFN. 43 | if gfn_reall and 'GFN' in net: 44 | feat_str += '+reall' 45 | # Replace degree feats for REDDIT datasets (less redundancy, faster). 46 | if reddit_odeg10 and dataset in [ 47 | 'REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K']: 48 | feat_str = feat_str.replace('odeg100', 'odeg10') 49 | # Replace degree and akx feats for dd (less redundancy, faster). 50 | if dd_odeg10_ak1 and dataset in ['DD']: 51 | feat_str = feat_str.replace('odeg100', 'odeg10') 52 | feat_str = feat_str.replace('ak3', 'ak1') 53 | return dataset, feat_str, net 54 | 55 | 56 | def get_model_with_default_configs(model_name, 57 | num_feat_layers=args.n_layers_feat, 58 | num_conv_layers=args.n_layers_conv, 59 | num_fc_layers=args.n_layers_fc, 60 | residual=args.skip_connection, 61 | hidden=args.hidden): 62 | # More default settings. 
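# (model_name suffixes such as _conv4, _fc2, _res1, _gating, _drop0.5 and _dim128
# override these defaults via the regex parsing that follows)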
63 | res_branch = args.res_branch 64 | global_pool = args.global_pool 65 | dropout = args.dropout 66 | edge_norm = args.edge_norm 67 | 68 | # modify default architecture when needed 69 | if model_name.find('_') > 0: 70 | num_conv_layers_ = re.findall('_conv(\d+)', model_name) 71 | if len(num_conv_layers_) == 1: 72 | num_conv_layers = int(num_conv_layers_[0]) 73 | print('[INFO] num_conv_layers set to {} as in {}'.format( 74 | num_conv_layers, model_name)) 75 | num_fc_layers_ = re.findall('_fc(\d+)', model_name) 76 | if len(num_fc_layers_) == 1: 77 | num_fc_layers = int(num_fc_layers_[0]) 78 | print('[INFO] num_fc_layers set to {} as in {}'.format( 79 | num_fc_layers, model_name)) 80 | residual_ = re.findall('_res(\d+)', model_name) 81 | if len(residual_) == 1: 82 | residual = bool(int(residual_[0])) 83 | print('[INFO] residual set to {} as in {}'.format( 84 | residual, model_name)) 85 | gating = re.findall('_gating', model_name) 86 | if len(gating) == 1: 87 | global_pool += "_gating" 88 | print('[INFO] add gating to global_pool {} as in {}'.format( 89 | global_pool, model_name)) 90 | dropout_ = re.findall('_drop([\.\d]+)', model_name) 91 | if len(dropout_) == 1: 92 | dropout = float(dropout_[0]) 93 | print('[INFO] dropout set to {} as in {}'.format( 94 | dropout, model_name)) 95 | hidden_ = re.findall('_dim(\d+)', model_name) 96 | if len(hidden_) == 1: 97 | hidden = int(hidden_[0]) 98 | print('[INFO] hidden set to {} as in {}'.format( 99 | hidden, model_name)) 100 | 101 | if model_name == 'ResGCN_graphcl': 102 | def foo(dataset): 103 | return ResGCN_graphcl(dataset=dataset, hidden=hidden, num_feat_layers=num_feat_layers, num_conv_layers=num_conv_layers, 104 | num_fc_layers=num_fc_layers, gfn=False, collapse=False, 105 | residual=residual, res_branch=res_branch, 106 | global_pool=global_pool, dropout=dropout, 107 | edge_norm=edge_norm) 108 | 109 | else: 110 | raise ValueError("Unknown model {}".format(model_name)) 111 | return foo 112 | 113 | 114 | def run_experiment_graphcl(dataset_feat_net_triple 115 | =create_n_filter_triple(args.dataset, 'deg+odeg100', 'ResGCN_graphcl', gfn_add_ak3=True, reddit_odeg10=True, dd_odeg10_ak1=True), 116 | get_model=get_model_with_default_configs): 117 | 118 | dataset_name, feat_str, net = dataset_feat_net_triple 119 | dataset = get_dataset( 120 | dataset_name, sparse=True, feat_str=feat_str, root=args.data_root) 121 | model_func = get_model(net) 122 | 123 | if not args.model == 'joao': 124 | experiment_graphcl.experiment(dataset, model_func, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, weight_decay=0, dataset_name=dataset_name, aug_mode=args.aug_mode, aug_ratio=args.aug_ratio, suffix=args.suffix) 125 | 126 | else: 127 | experiment_joao.experiment(dataset, model_func, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, weight_decay=0, dataset_name=dataset_name, aug_mode=args.aug_mode, aug_ratio=args.aug_ratio, suffix=args.suffix, gamma_joao=args.gamma_joao) 128 | 129 | 130 | if __name__ == '__main__': 131 | print(args) 132 | 133 | run_experiment_graphcl() 134 | 135 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def logger(info): 11 | fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or 
epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Accuracy: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/weights_graphcl/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU/pretrain/weights_graphcl/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_TU/pretrain/weights_joao/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU/pretrain/weights_joao/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_TU/pretrain_joaov2/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 | pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain_joaov2/experiment_joao.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | import numpy as np 5 | 6 | from utils import print_weights 7 | from tqdm import tqdm 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 
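# Note: the 25-entry probability vector used throughout this file is a joint
# distribution over 5 x 5 augmentation pairs; an index n in [0, 25) decodes to
# the pair (n // 5, n % 5). The five augmentation modes are assumed here to be
# GraphCL's (identity, node dropping, edge perturbation, subgraph, attribute
# masking), per the JOAO paper.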
| 12 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 13 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0, gamma_joao=0.1): 14 | model = model_func(dataset).to(device) 15 | print_weights(model) 16 | if torch.cuda.is_available(): 17 | torch.cuda.synchronize() 18 | 19 | dataset.set_aug_mode('sample') 20 | dataset.set_aug_ratio(aug_ratio) 21 | aug_prob = np.ones(25) / 25 22 | dataset.set_aug_prob(aug_prob) 23 | 24 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 25 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 26 | 27 | # for epoch in tqdm(range(1, epochs+1)): 28 | for epoch in range(1, epochs+1): 29 | pretrain_loss, aug_prob = train(loader, model, optimizer, device, gamma_joao) 30 | print(pretrain_loss, aug_prob) 31 | loader.dataset.set_aug_prob(aug_prob) 32 | 33 | if epoch % 20 == 0: 34 | weight_path = './weights_joao/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(gamma_joao) + '_' + str(suffix) + '.pt' 35 | torch.save(model.state_dict(), weight_path) 36 | 37 | 38 | def num_graphs(data): 39 | if data.batch is not None: 40 | return data.num_graphs 41 | else: 42 | return data.x.size(0) 43 | 44 | 45 | def train(loader, model, optimizer, device, gamma_joao): 46 | model.train() 47 | total_loss = 0 48 | 49 | aug_prob = loader.dataset.aug_prob 50 | n_aug = np.random.choice(25, 1, p=aug_prob)[0] 51 | n_aug1, n_aug2 = n_aug//5, n_aug%5 52 | 53 | for _, data1, data2 in loader: 54 | # print(data1, data2) 55 | optimizer.zero_grad() 56 | data1 = data1.to(device) 57 | data2 = data2.to(device) 58 | out1 = model.forward_graphcl(data1, n_aug1) 59 | out2 = model.forward_graphcl(data2, n_aug2) 60 | loss = model.loss_graphcl(out1, out2) 61 | loss.backward() 62 | total_loss += loss.item() * num_graphs(data1) 63 | optimizer.step() 64 | 65 | aug_prob = joao(loader, model, gamma_joao) 66 | return total_loss/len(loader.dataset), aug_prob 67 | 68 | 69 | def joao(loader, model, gamma_joao): 70 | aug_prob = loader.dataset.aug_prob 71 | # calculate augmentation loss 72 | loss_aug = np.zeros(25) 73 | 74 | 75 | for n in range(25): 76 | _aug_prob = np.zeros(25) 77 | _aug_prob[n] = 1 78 | loader.dataset.set_aug_prob(_aug_prob) 79 | 80 | n_aug1, n_aug2 = n//5, n%5 81 | 82 | count, count_stop = 0, len(loader.dataset)//(loader.batch_size*10)+1 # for efficiency, we only use around 10% of data to estimate the loss 83 | with torch.no_grad(): 84 | for _, data1, data2 in loader: 85 | data1 = data1.to(device) 86 | data2 = data2.to(device) 87 | out1 = model.forward_graphcl(data1, n_aug1) 88 | out2 = model.forward_graphcl(data2, n_aug2) 89 | loss = model.loss_graphcl(out1, out2) 90 | loss_aug[n] += loss.item() * num_graphs(data1) 91 | count += 1 92 | if count == count_stop: 93 | break 94 | loss_aug[n] /= (count*loader.batch_size) 95 | 96 | # view selection, projected gradient descent, reference: https://arxiv.org/abs/1906.03563 97 | beta = 1 98 | gamma = gamma_joao 99 | 100 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1/25)) 101 | mu_min, mu_max = b.min()-1/25, b.max()-1/25 102 | mu = (mu_min + mu_max) / 2 103 | 104 | # bisection method 105 | while abs(np.maximum(b-mu, 0).sum() - 1) > 1e-2: 106 | if np.maximum(b-mu, 0).sum() > 1: 107 | mu_min = mu 108 | else: 109 | mu_max = mu 110 | mu = (mu_min + mu_max) / 2 111 | 112 | aug_prob = np.maximum(b-mu, 0) 113 | aug_prob /= aug_prob.sum() 114 | 115 | return aug_prob 116 | 117 | -------------------------------------------------------------------------------- 
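For reference, the bisection loop inside `joao()` above is a Euclidean projection of `b` onto the probability simplex. Below is a minimal standalone sketch assuming nothing beyond NumPy; the function name and the tolerance argument are illustrative and not part of this repository:

```python
import numpy as np

def project_to_simplex(b, tol=1e-2):
    # Find mu such that p = max(b - mu, 0) sums to 1. The map
    # mu -> sum(max(b - mu, 0)) is non-increasing in mu, so bisection on
    # the bracket [b.min() - 1/len(b), b.max()] converges.
    mu_min, mu_max = b.min() - 1.0 / len(b), b.max()
    mu = (mu_min + mu_max) / 2
    while abs(np.maximum(b - mu, 0).sum() - 1) > tol:
        if np.maximum(b - mu, 0).sum() > 1:
            mu_min = mu
        else:
            mu_max = mu
        mu = (mu_min + mu_max) / 2
    p = np.maximum(b - mu, 0)
    return p / p.sum()

# Usage: projecting a perturbed uniform vector over the 25 augmentation pairs
# returns a valid distribution, as in the aug_prob update of joao().
print(project_to_simplex(np.ones(25) / 25 + 0.1 * np.random.rand(25)))
```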
/semisupervised_TU/pretrain_joaov2/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | <https://arxiv.org/abs/1609.02907>`_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample. 25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): whether or not to normalize adj matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only linear transform (1x1 conv) is 37 | applied to every node. (default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges.
84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain_joaov2/main.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from datasets import get_dataset 4 | from res_gcn import ResGCN_graphcl, vgae_encoder, vgae_decoder 5 | 6 | import experiment_joao 7 | 8 | 9 | str2bool = lambda x: x.lower() == "true" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_root', type=str, default="datasets") 12 | parser.add_argument('--batch_size', type=int, default=128) 13 | parser.add_argument('--n_layers_feat', type=int, default=1) 14 | parser.add_argument('--n_layers_conv', type=int, default=3) 15 | parser.add_argument('--n_layers_fc', type=int, default=2) 16 | parser.add_argument('--hidden', type=int, default=128) 17 | parser.add_argument('--global_pool', type=str, default="sum") 18 | parser.add_argument('--skip_connection', type=str2bool, default=False) 19 | parser.add_argument('--res_branch', type=str, default="BNConvReLU") 20 | parser.add_argument('--dropout', type=float, default=0) 21 | parser.add_argument('--edge_norm', type=str2bool, default=True) 22 | 23 | parser.add_argument('--lr', type=float, default=0.001) 24 | parser.add_argument('--epochs', type=int, default=100) 25 | 26 | parser.add_argument('--dataset', type=str, default="NCI1") 27 | parser.add_argument('--aug_mode', type=str, default="sample") 28 | parser.add_argument('--aug_ratio', type=float, default=0.2) 29 | parser.add_argument('--suffix', type=int, default=0) 30 | 31 | parser.add_argument('--model', type=str, default='joao') 32 | parser.add_argument('--gamma_joao', type=float, default=0.1) 33 | args = parser.parse_args() 34 | 35 | 36 | def create_n_filter_triple(dataset, feat_str, net, gfn_add_ak3=False, 37 | gfn_reall=True, reddit_odeg10=False, 38 | dd_odeg10_ak1=False): 39 | # Add ak3 for GFN. 40 | if gfn_add_ak3 and 'GFN' in net: 41 | feat_str += '+ak3' 42 | # Remove edges for GFN. 
43 | if gfn_reall and 'GFN' in net: 44 | feat_str += '+reall' 45 | # Replace degree feats for REDDIT datasets (less redundancy, faster). 46 | if reddit_odeg10 and dataset in [ 47 | 'REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K']: 48 | feat_str = feat_str.replace('odeg100', 'odeg10') 49 | # Replace degree and akx feats for dd (less redundancy, faster). 50 | if dd_odeg10_ak1 and dataset in ['DD']: 51 | feat_str = feat_str.replace('odeg100', 'odeg10') 52 | feat_str = feat_str.replace('ak3', 'ak1') 53 | return dataset, feat_str, net 54 | 55 | 56 | def get_model_with_default_configs(model_name, 57 | num_feat_layers=args.n_layers_feat, 58 | num_conv_layers=args.n_layers_conv, 59 | num_fc_layers=args.n_layers_fc, 60 | residual=args.skip_connection, 61 | hidden=args.hidden): 62 | # More default settings. 63 | res_branch = args.res_branch 64 | global_pool = args.global_pool 65 | dropout = args.dropout 66 | edge_norm = args.edge_norm 67 | 68 | # modify default architecture when needed 69 | if model_name.find('_') > 0: 70 | num_conv_layers_ = re.findall('_conv(\d+)', model_name) 71 | if len(num_conv_layers_) == 1: 72 | num_conv_layers = int(num_conv_layers_[0]) 73 | print('[INFO] num_conv_layers set to {} as in {}'.format( 74 | num_conv_layers, model_name)) 75 | num_fc_layers_ = re.findall('_fc(\d+)', model_name) 76 | if len(num_fc_layers_) == 1: 77 | num_fc_layers = int(num_fc_layers_[0]) 78 | print('[INFO] num_fc_layers set to {} as in {}'.format( 79 | num_fc_layers, model_name)) 80 | residual_ = re.findall('_res(\d+)', model_name) 81 | if len(residual_) == 1: 82 | residual = bool(int(residual_[0])) 83 | print('[INFO] residual set to {} as in {}'.format( 84 | residual, model_name)) 85 | gating = re.findall('_gating', model_name) 86 | if len(gating) == 1: 87 | global_pool += "_gating" 88 | print('[INFO] add gating to global_pool {} as in {}'.format( 89 | global_pool, model_name)) 90 | dropout_ = re.findall('_drop([\.\d]+)', model_name) 91 | if len(dropout_) == 1: 92 | dropout = float(dropout_[0]) 93 | print('[INFO] dropout set to {} as in {}'.format( 94 | dropout, model_name)) 95 | hidden_ = re.findall('_dim(\d+)', model_name) 96 | if len(hidden_) == 1: 97 | hidden = int(hidden_[0]) 98 | print('[INFO] hidden set to {} as in {}'.format( 99 | hidden, model_name)) 100 | 101 | if model_name == 'ResGCN_graphcl': 102 | def foo(dataset): 103 | return ResGCN_graphcl(dataset=dataset, hidden=hidden, num_feat_layers=num_feat_layers, num_conv_layers=num_conv_layers, 104 | num_fc_layers=num_fc_layers, gfn=False, collapse=False, 105 | residual=residual, res_branch=res_branch, 106 | global_pool=global_pool, dropout=dropout, 107 | edge_norm=edge_norm) 108 | 109 | else: 110 | raise ValueError("Unknown model {}".format(model_name)) 111 | return foo 112 | 113 | 114 | def run_experiment_graphcl(dataset_feat_net_triple 115 | =create_n_filter_triple(args.dataset, 'deg+odeg100', 'ResGCN_graphcl', gfn_add_ak3=True, reddit_odeg10=True, dd_odeg10_ak1=True), 116 | get_model=get_model_with_default_configs): 117 | 118 | dataset_name, feat_str, net = dataset_feat_net_triple 119 | dataset = get_dataset( 120 | dataset_name, sparse=True, feat_str=feat_str, root=args.data_root) 121 | model_func = get_model(net) 122 | 123 | experiment_joao.experiment(dataset, model_func, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, weight_decay=0, dataset_name=dataset_name, aug_mode=args.aug_mode, aug_ratio=args.aug_ratio, suffix=args.suffix, gamma_joao=args.gamma_joao) 124 | 125 | 126 | if __name__ == '__main__': 127 | 
print(args) 128 | 129 | run_experiment_graphcl() 130 | 131 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain_joaov2/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def logger(info): 11 | fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Accuracy: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU/pretrain_joaov2/weights_joao/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU/pretrain_joaov2/weights_joao/place_holder.txt -------------------------------------------------------------------------------- /semisupervised_TU_LP/README.md: -------------------------------------------------------------------------------- 1 | ### LP-InfoMin/InfoBN/Info(Min+BN) Pre-Training: ### 2 | 3 | ``` 4 | cd ./pretrain 5 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infomin --suffix 0 6 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infomin --suffix 1 7 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infomin --suffix 2 8 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infomin --suffix 3 9 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infomin --suffix 4 10 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infobn --suffix 0 11 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infobn --suffix 1 12 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infobn --suffix 2 13 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infobn --suffix 3 14 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infobn --suffix 4 15 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infominbn --suffix 0 16 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infominbn --suffix 1 17 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infominbn --suffix 2 18 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infominbn --suffix 3 19 | python main.py --dataset COLLAB --epochs 100 --lr 0.001 --principle infominbn --suffix 4 20 | ``` 21 | 22 | ### LP-InfoMin/InfoBN/Info(Min+BN) Finetuning: ### 23 | 24 | ``` 25 | cd ./finetune 26 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --model_path $PATH_OF_MODEL_WEIGHT --result_path $PATH_TO_SAVE_RESULT --suffix 0 --n_splits 10 27 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --model_path $PATH_OF_MODEL_WEIGHT --result_path $PATH_TO_SAVE_RESULT --suffix 1 --n_splits 10 28 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --model_path $PATH_OF_MODEL_WEIGHT --result_path $PATH_TO_SAVE_RESULT --suffix 2 --n_splits 10 29 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --model_path
$PATH_OF_MODEL_WEIGHT --result_path $PATH_TO_SAVE_RESULT --suffix 3 --n_splits 10 30 | python main.py --dataset NCI1 --pretrain_epoch 100 --pretrain_lr 0.001 --model_path $PATH_OF_MODEL_WEIGHT --result_path $PATH_TO_SAVE_RESULT --suffix 4 --n_splits 10 31 | ``` 32 | 33 | The five suffixes stand for five runs (with mean & std reported). 34 | 35 | ```lr``` should be tuned over {0.01, 0.001, 0.0001}, and ```pretrain_epoch``` in finetuning (i.e., which epoch's checkpoint of the pre-trained model is loaded) over {20, 40, 60, 80, 100}. 36 | 37 | ## Acknowledgements 38 | 39 | The backbone implementation is based on https://github.com/chentingpc/gfn. 40 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/.main_visulize_generator.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU_LP/finetune/.main_visulize_generator.py.swp -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/calculate_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | 5 | file_name = str(sys.argv[1]) 6 | 7 | with open(file_name, 'r') as f: 8 | data = f.read().split('\n')[:-1] 9 | 10 | res_dict = {} 11 | for d in data: 12 | pref, res = d.split() 13 | pref, res = pref[:-2], float(res) 14 | 15 | if not pref in res_dict.keys(): 16 | res_dict[pref] = [res] 17 | else: 18 | res_dict[pref].append(res) 19 | 20 | pref_best, res_best = '', 0 21 | for k, v in res_dict.items(): 22 | if not '0.001' in k: 23 | continue 24 | if len(v) < 5: 25 | continue 26 | elif np.mean(v) < np.mean(res_best): 27 | continue 28 | 29 | pref_best, res_best = k, v 30 | 31 | print(pref_best, np.mean(res_best), np.std(res_best)) 32 | 33 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 |
pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | 48 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | <https://arxiv.org/abs/1609.02907>`_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample. 25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): whether or not to normalize adj matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only linear transform (1x1 conv) is 37 | applied to every node.
(default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges. 84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/train_eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import tensor 7 | from torch.optim import Adam 8 | from sklearn.model_selection import StratifiedKFold 9 | from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader 10 | 11 | from utils import print_weights 12 | from tqdm import tqdm 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def cross_validation_with_val_set(dataset, 18 | model_func, 19 | epochs, 20 | batch_size, 21 | lr, 22 | weight_decay, 23 | epoch_select, 24 | 
with_eval_mode=True, 25 | logger=None, 26 | model_PATH=None, n_splits=None, result_PATH=None, result_feat=None): 27 | assert epoch_select in ['val_max', 'test_max'], epoch_select 28 | 29 | folds=10 30 | # pbar = tqdm(total=folds) 31 | val_losses, train_accs, test_accs, durations = [], [], [], [] 32 | for fold, (train_idx, test_idx, val_idx) in enumerate( 33 | zip(*k_fold(dataset, folds, epoch_select, n_splits))): 34 | 35 | train_dataset = dataset[train_idx] 36 | test_dataset = dataset[test_idx] 37 | val_dataset = dataset[val_idx] 38 | 39 | train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=16) 40 | val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=16) 41 | test_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=16) 42 | 43 | model = model_func(dataset).to(device) 44 | # load pre-trained weights (comment out the next line to instead train from scratch) 45 | model.load_state_dict(torch.load(model_PATH)) 46 | 47 | if fold == 0: 48 | print_weights(model) 49 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.synchronize() 53 | 54 | t_start = time.perf_counter() 55 | 56 | for epoch in range(1, epochs + 1): 57 | train_loss, train_acc = train( 58 | model, optimizer, train_loader, device) 59 | train_accs.append(train_acc) 60 | val_losses.append(eval_loss( 61 | model, val_loader, device, with_eval_mode)) 62 | test_accs.append(eval_acc( 63 | model, test_loader, device, with_eval_mode)) 64 | eval_info = { 65 | 'fold': fold, 66 | 'epoch': epoch, 67 | 'train_loss': train_loss, 68 | 'train_acc': train_accs[-1], 69 | 'val_loss': val_losses[-1], 70 | 'test_acc': test_accs[-1], 71 | } 72 | 73 | if logger is not None: 74 | logger(eval_info) 75 | 76 | if torch.cuda.is_available(): 77 | torch.cuda.synchronize() 78 | 79 | t_end = time.perf_counter() 80 | durations.append(t_end - t_start) 81 | # pbar.update(1) 82 | 83 | duration = tensor(durations) 84 | train_acc, test_acc = tensor(train_accs), tensor(test_accs) 85 | val_loss = tensor(val_losses) 86 | train_acc = train_acc.view(folds, epochs) 87 | test_acc = test_acc.view(folds, epochs) 88 | val_loss = val_loss.view(folds, epochs) 89 | if epoch_select == 'test_max': # take epoch that yields best test results. 90 | _, selected_epoch = test_acc.mean(dim=0).max(dim=0) 91 | selected_epoch = selected_epoch.repeat(folds) 92 | else: # take epoch that yields min val loss for each fold individually.
93 | _, selected_epoch = val_loss.min(dim=1) 94 | test_acc = test_acc[torch.arange(folds, dtype=torch.long), selected_epoch] 95 | train_acc_mean = train_acc[:, -1].mean().item() 96 | test_acc_mean = test_acc.mean().item() 97 | test_acc_std = test_acc.std().item() 98 | duration_mean = duration.mean().item() 99 | 100 | print(train_acc_mean, test_acc_mean, test_acc_std, duration_mean) 101 | sys.stdout.flush() 102 | 103 | with open(result_PATH, 'a+') as f: 104 | f.write(result_feat + ' ' + str(test_acc_mean) + '\n') 105 | 106 | 107 | def k_fold(dataset, folds, epoch_select, n_splits): 108 | skf = StratifiedKFold(folds, shuffle=True, random_state=12345) 109 | 110 | test_indices, train_indices = [], [] 111 | for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y): 112 | test_indices.append(torch.from_numpy(idx)) 113 | 114 | if epoch_select == 'test_max': 115 | val_indices = [test_indices[i] for i in range(folds)] 116 | else: 117 | val_indices = [test_indices[i - 1] for i in range(folds)] 118 | 119 | skf_semi = StratifiedKFold(n_splits, shuffle=True, random_state=12345) 120 | for i in range(folds): 121 | train_mask = torch.ones(len(dataset), dtype=torch.uint8) 122 | train_mask[test_indices[i].long()] = 0 123 | train_mask[val_indices[i].long()] = 0 124 | idx_train = train_mask.nonzero(as_tuple=False).view(-1) 125 | 126 | for _, idx in skf_semi.split(torch.zeros(idx_train.size()[0]), dataset.data.y[idx_train]): 127 | idx_train = idx_train[idx] 128 | break 129 | 130 | train_indices.append(idx_train) 131 | 132 | return train_indices, test_indices, val_indices 133 | 134 | 135 | def num_graphs(data): 136 | if data.batch is not None: 137 | return data.num_graphs 138 | else: 139 | return data.x.size(0) 140 | 141 | 142 | def train(model, optimizer, loader, device): 143 | model.train() 144 | 145 | total_loss = 0 146 | correct = 0 147 | for data in loader: 148 | optimizer.zero_grad() 149 | data = data.to(device) 150 | out = model(data) 151 | loss = F.nll_loss(out, data.y.long().view(-1)) 152 | pred = out.max(1)[1] 153 | correct += pred.eq(data.y.view(-1)).sum().item() 154 | loss.backward() 155 | total_loss += loss.item() * num_graphs(data) 156 | optimizer.step() 157 | return total_loss / len(loader.dataset), correct / len(loader.dataset) 158 | 159 | 160 | def eval_acc(model, loader, device, with_eval_mode): 161 | if with_eval_mode: 162 | model.eval() 163 | 164 | correct = 0 165 | for data in loader: 166 | data = data.to(device) 167 | with torch.no_grad(): 168 | pred = model(data).max(1)[1] 169 | correct += pred.eq(data.y.view(-1)).sum().item() 170 | return correct / len(loader.dataset) 171 | 172 | 173 | def eval_loss(model, loader, device, with_eval_mode): 174 | if with_eval_mode: 175 | model.eval() 176 | 177 | loss = 0 178 | for data in loader: 179 | data = data.to(device) 180 | with torch.no_grad(): 181 | out = model(data) 182 | loss += F.nll_loss(out, data.y.long().view(-1), reduction='sum').item() 183 | return loss / len(loader.dataset) 184 | 185 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/tu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from itertools import repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch_geometric.utils as tg_utils 9 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 10 | from torch_geometric.io import read_tu_data 11 | 12 | 13 | # tudataset adopted 
from torch_geometric==1.1.0 14 | class TUDatasetExt(InMemoryDataset): 15 | r"""A variety of graph kernel benchmark datasets, *.e.g.* "IMDB-BINARY", 16 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 17 | `_. 18 | 19 | Args: 20 | root (string): Root directory where the dataset should be saved. 21 | name (string): The `name `_ of 22 | the dataset. 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. (default: :obj:`None`) 31 | pre_filter (callable, optional): A function that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a boolean 33 | value, indicating whether the data object should be included in the 34 | final dataset. (default: :obj:`None`) 35 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 36 | contain additional continuous node features (if present). 37 | (default: :obj:`False`) 38 | """ 39 | 40 | url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/' \ 41 | 'graphkerneldatasets' 42 | 43 | def __init__(self, 44 | root, 45 | name, 46 | transform=None, 47 | pre_transform=None, 48 | pre_filter=None, 49 | use_node_attr=False, 50 | processed_filename='data.pt'): 51 | self.name = name 52 | self.processed_filename = processed_filename 53 | super(TUDatasetExt, self).__init__(root, transform, pre_transform, 54 | pre_filter) 55 | self.data, self.slices = torch.load(self.processed_paths[0]) 56 | if self.data.x is not None and not use_node_attr: 57 | self.data.x = self.data.x[:, self.num_node_attributes:] 58 | 59 | @property 60 | def num_node_labels(self): 61 | if self.data.x is None: 62 | return 0 63 | for i in range(self.data.x.size(1)): 64 | if self.data.x[:, i:].sum().item() == self.data.x.size(0): 65 | return self.data.x.size(1) - i 66 | return 0 67 | 68 | @property 69 | def num_node_attributes(self): 70 | if self.data.x is None: 71 | return 0 72 | return self.data.x.size(1) - self.num_node_labels 73 | 74 | @property 75 | def raw_file_names(self): 76 | names = ['A', 'graph_indicator'] 77 | return ['{}_{}.txt'.format(self.name, name) for name in names] 78 | 79 | @property 80 | def processed_file_names(self): 81 | return self.processed_filename 82 | 83 | @property 84 | def num_node_features(self): 85 | r"""Returns the number of features per node in the dataset.""" 86 | return self[0].num_node_features 87 | 88 | def download(self): 89 | path = download_url('{}/{}.zip'.format(self.url, self.name), self.root) 90 | extract_zip(path, self.root) 91 | os.unlink(path) 92 | shutil.rmtree(self.raw_dir) 93 | os.rename(osp.join(self.root, self.name), self.raw_dir) 94 | 95 | def process(self): 96 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 97 | 98 | if self.pre_filter is not None: 99 | data_list = [self.get(idx) for idx in range(len(self))] 100 | data_list = [data for data in data_list if self.pre_filter(data)] 101 | self.data, self.slices = self.collate(data_list) 102 | 103 | if self.pre_transform is not None: 104 | data_list = [self.get(idx) for idx in range(len(self))] 105 | data_list = [self.pre_transform(data) for data in data_list] 106 | self.data, self.slices = self.collate(data_list) 107 | 
108 | torch.save((self.data, self.slices), self.processed_paths[0]) 109 | 110 | def __repr__(self): 111 | return '{}({})'.format(self.name, len(self)) 112 | 113 | def get(self, idx): 114 | data = self.data.__class__() 115 | if hasattr(self.data, '__num_nodes__'): 116 | data.num_nodes = self.data.__num_nodes__[idx] 117 | for key in self.data.keys: 118 | item, slices = self.data[key], self.slices[key] 119 | if torch.is_tensor(item): 120 | s = list(repeat(slice(None), item.dim())) 121 | s[self.data.__cat_dim__(key, 122 | item)] = slice(slices[idx], 123 | slices[idx + 1]) 124 | else: 125 | s = slice(slices[idx], slices[idx + 1]) 126 | data[key] = item[s] 127 | return data 128 | 129 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/finetune/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def logger(info): 11 | fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Accuracy: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/.experiment_generative_linkPrediction.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU_LP/pretrain/.experiment_generative_linkPrediction.py.swp -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import torch 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | from feature_expansion import FeatureExpander 9 | from tu_dataset import TUDatasetExt 10 | 11 | 12 | def get_dataset(name, sparse=True, feat_str="deg+ak3+reall", root=None): 13 | if root is None or root == '': 14 | path = osp.join(osp.expanduser('~'), 'pyG_data', name) 15 | else: 16 | path = osp.join(root, name) 17 | path = '../' + path 18 | degree = feat_str.find("deg") >= 0 19 | onehot_maxdeg = re.findall("odeg(\d+)", feat_str) 20 | onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None 21 | k = re.findall("an{0,1}k(\d+)", feat_str) 22 | k = int(k[0]) if k else 0 23 | groupd = re.findall("groupd(\d+)", feat_str) 24 | groupd = int(groupd[0]) if groupd else 0 25 | remove_edges = re.findall("re(\w+)", feat_str) 26 | remove_edges = remove_edges[0] if remove_edges else 'none' 27 | edge_noises_add = re.findall("randa([\d\.]+)", feat_str) 28 | edge_noises_add = float(edge_noises_add[0]) if edge_noises_add else 0 29 | edge_noises_delete = re.findall("randd([\d\.]+)", feat_str) 30 | edge_noises_delete = float( 31 | edge_noises_delete[0]) if edge_noises_delete else 0 32 | centrality = feat_str.find("cent") >= 0 33 | coord = feat_str.find("coord") >= 0 34 | 35 | pre_transform = FeatureExpander( 36 | degree=degree, onehot_maxdeg=onehot_maxdeg, AK=k, 37 | centrality=centrality, 
remove_edges=remove_edges, 38 | edge_noises_add=edge_noises_add, edge_noises_delete=edge_noises_delete, 39 | group_degree=groupd).transform 40 | 41 | dataset = TUDatasetExt( 42 | path, name, pre_transform=pre_transform, 43 | use_node_attr=True, processed_filename="data_%s.pt" % feat_str) 44 | dataset.data.edge_attr = None 45 | 46 | return dataset 47 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/experiment_generative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | 5 | from utils import print_weights 6 | from tqdm import tqdm 7 | from copy import deepcopy 8 | from res_gcn import vgae 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 14 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0): 15 | model, encoder, decoder = model_func(dataset) 16 | generator_1 = vgae(encoder, decoder) 17 | _, _encoder, _decoder = model_func(dataset) 18 | generator_2 = vgae(_encoder, _decoder) 19 | 20 | model, generator_1, generator_2 = model.to(device), generator_1.to(device), generator_2.to(device) 21 | print_weights(model) 22 | if torch.cuda.is_available(): 23 | torch.cuda.synchronize() 24 | 25 | dataset.set_aug_mode('generative') 26 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 27 | 28 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 29 | optimizer_generator_1 = Adam(generator_1.parameters(), lr=lr) 30 | optimizer_generator_2 = Adam(generator_2.parameters(), lr=lr) 31 | 32 | # for epoch in tqdm(range(1, epochs+1)): 33 | for epoch in range(1, epochs+1): 34 | loader.dataset.set_generator(deepcopy(generator_1).cpu(), deepcopy(generator_2).cpu()) 35 | pretrain_loss, generative_loss = train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, device) 36 | print(pretrain_loss, generative_loss) 37 | 38 | if epoch % 20 == 0: 39 | weight_path = './weights_infomin/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(suffix) + '.pt' 40 | torch.save(model.state_dict(), weight_path) 41 | 42 | torch.save({'graphcl':model.state_dict(), 'graphcl_opt': optimizer.state_dict(), 'generator_1':generator_1.state_dict(), 'generator_1_opt':optimizer_generator_1.state_dict(), 'generator_2':generator_2.state_dict(), 'generator_2_opt':optimizer_generator_2.state_dict()}, './weights_generative_joao/checkpoint_' + dataset_name + '_' + str(lr) + '_' + str(suffix) + '.pt') 43 | 44 | 45 | def num_graphs(data): 46 | if data.batch is not None: 47 | return data.num_graphs 48 | else: 49 | return data.x.size(0) 50 | 51 | 52 | def train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, device): 53 | model.train() 54 | generator_1.train() 55 | generator_2.train() 56 | total_loss, generative_loss = 0, 0 57 | for data, data1, data2 in loader: 58 | optimizer.zero_grad() 59 | data1 = data1.to(device) 60 | data2 = data2.to(device) 61 | out1 = model.forward_graphcl(data1) 62 | out2 = model.forward_graphcl(data2) 63 | loss_cl = model.loss_graphcl(out1, out2, mean=False) 64 | loss = loss_cl.mean() 65 | loss.backward() 66 | total_loss += loss.item() * num_graphs(data1) 67 | optimizer.step() 68 | 69 | # reward for joao 70 | loss_cl = loss_cl.detach() 71 | loss_cl = loss_cl - 
loss_cl.mean() 72 | loss_cl[loss_cl>0] = 1 73 | loss_cl[loss_cl<=0] = 0.01 # weaken the reward for low cl loss 74 | 75 | # joao 76 | optimizer_generator_1.zero_grad() 77 | optimizer_generator_2.zero_grad() 78 | data = data.to(device) 79 | 80 | loss_1 = generator_1(data, reward=loss_cl) 81 | loss_2 = generator_2(data, reward=loss_cl) 82 | 83 | loss = loss_1 + loss_2 84 | loss.backward() 85 | optimizer_generator_1.step() 86 | optimizer_generator_2.step() 87 | generative_loss += loss.item() * num_graphs(data) 88 | 89 | return total_loss/len(loader.dataset), generative_loss/len(loader.dataset) 90 | 91 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/experiment_generative_ib.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | 5 | from utils import print_weights 6 | from tqdm import tqdm 7 | from copy import deepcopy 8 | from res_gcn import vgae 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 14 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0): 15 | model, encoder, decoder = model_func(dataset) 16 | generator_1 = vgae(encoder, decoder) 17 | model_ib, _encoder, _decoder = model_func(dataset) 18 | generator_2 = vgae(_encoder, _decoder) 19 | 20 | model, generator_1, generator_2, model_ib = model.to(device), generator_1.to(device), generator_2.to(device), model_ib.to(device) 21 | print_weights(model) 22 | if torch.cuda.is_available(): 23 | torch.cuda.synchronize() 24 | 25 | dataset.set_aug_mode('generative') 26 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 27 | 28 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 29 | optimizer_generator_1 = Adam(generator_1.parameters(), lr=lr) 30 | optimizer_generator_2 = Adam(generator_2.parameters(), lr=lr) 31 | optimizer_ib = Adam(model_ib.parameters(), lr=lr) 32 | 33 | # for epoch in tqdm(range(1, epochs+1)): 34 | for epoch in range(1, epochs+1): 35 | loader.dataset.set_generator(deepcopy(generator_1).cpu(), deepcopy(generator_2).cpu()) 36 | pretrain_loss, generative_loss = train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, model_ib, optimizer_ib, device) 37 | print(pretrain_loss, generative_loss) 38 | 39 | if epoch % 20 == 0: 40 | weight_path = './weights_infominbn/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(suffix) + '.pt' 41 | torch.save(model.state_dict(), weight_path) 42 | 43 | torch.save({'graphcl':model.state_dict(), 'graphcl_opt': optimizer.state_dict(), 'model_ib':model_ib.state_dict(), 'model_ib_opt': optimizer_ib.state_dict(), 'generator_1':generator_1.state_dict(), 'generator_1_opt':optimizer_generator_1.state_dict(), 'generator_2':generator_2.state_dict(), 'generator_2_opt':optimizer_generator_2.state_dict()}, './weights_generative_joao/checkpoint_' + dataset_name + '_' + str(lr) + '_' + str(suffix) + '.pt') 44 | 45 | 46 | def num_graphs(data): 47 | if data.batch is not None: 48 | return data.num_graphs 49 | else: 50 | return data.x.size(0) 51 | 52 | 53 | def train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, model_ib, optimizer_ib, device): 54 | model.train() 55 | generator_1.train() 56 | generator_2.train() 57 | model_ib.train() 58 | total_loss, 
generative_loss = 0, 0 59 | for data, data1, data2 in loader: 60 | optimizer.zero_grad() 61 | data1 = data1.to(device) 62 | data2 = data2.to(device) 63 | out1 = model.forward_graphcl(data1) 64 | out2 = model.forward_graphcl(data2) 65 | loss_cl = model.loss_graphcl(out1, out2, mean=False) 66 | loss = loss_cl.mean() 67 | loss.backward() 68 | total_loss += loss.item() * num_graphs(data1) 69 | optimizer.step() 70 | 71 | # information bottleneck 72 | optimizer_ib.zero_grad() 73 | _out1 = model_ib.forward_graphcl(data1) 74 | _out2 = model_ib.forward_graphcl(data2) 75 | loss_ib = model_ib.loss_graphcl(_out1, out1.detach(), mean=False) + model_ib.loss_graphcl(_out2, out2.detach(), mean=False) 76 | loss = loss_ib.mean() 77 | loss.backward() 78 | optimizer_ib.step() 79 | 80 | # reward for joao 81 | loss_cl = loss_cl.detach() + loss_ib.detach() 82 | loss_cl = loss_cl - loss_cl.mean() 83 | loss_cl[loss_cl>0] = 1 84 | loss_cl[loss_cl<=0] = 0.01 # weaken the reward for low cl loss 85 | 86 | # joao 87 | optimizer_generator_1.zero_grad() 88 | optimizer_generator_2.zero_grad() 89 | data = data.to(device) 90 | 91 | loss_1 = generator_1(data, reward=loss_cl) 92 | loss_2 = generator_2(data, reward=loss_cl) 93 | 94 | loss = loss_1 + loss_2 95 | loss.backward() 96 | optimizer_generator_1.step() 97 | optimizer_generator_2.step() 98 | generative_loss += loss.item() * num_graphs(data) 99 | 100 | return total_loss/len(loader.dataset), generative_loss/len(loader.dataset) 101 | 102 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/experiment_generative_ibalone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from tu_dataset import DataLoader 4 | 5 | from utils import print_weights 6 | from tqdm import tqdm 7 | from copy import deepcopy 8 | from res_gcn import vgae 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | def experiment(dataset, model_func, epochs, batch_size, lr, weight_decay, 14 | dataset_name=None, aug_mode='uniform', aug_ratio=0.2, suffix=0): 15 | model, encoder, decoder = model_func(dataset) 16 | generator_1 = vgae(encoder, decoder) 17 | model_ib, _encoder, _decoder = model_func(dataset) 18 | generator_2 = vgae(_encoder, _decoder) 19 | 20 | model, generator_1, generator_2, model_ib = model.to(device), generator_1.to(device), generator_2.to(device), model_ib.to(device) 21 | print_weights(model) 22 | if torch.cuda.is_available(): 23 | torch.cuda.synchronize() 24 | 25 | dataset.set_aug_mode('generative') 26 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=16) 27 | 28 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 29 | optimizer_generator_1 = Adam(generator_1.parameters(), lr=lr) 30 | optimizer_generator_2 = Adam(generator_2.parameters(), lr=lr) 31 | optimizer_ib = Adam(model_ib.parameters(), lr=lr) 32 | 33 | # for epoch in tqdm(range(1, epochs+1)): 34 | for epoch in range(1, epochs+1): 35 | loader.dataset.set_generator(deepcopy(generator_1).cpu(), deepcopy(generator_2).cpu()) 36 | pretrain_loss, generative_loss = train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, model_ib, optimizer_ib, device) 37 | print(pretrain_loss, generative_loss) 38 | 39 | if epoch % 20 == 0: 40 | weight_path = './weights_infobn/' + dataset_name + '_' + str(lr) + '_' + str(epoch) + '_' + str(suffix) + '.pt' 41 | torch.save(model.state_dict(), 
weight_path) 42 | 43 | torch.save({'graphcl':model.state_dict(), 'graphcl_opt': optimizer.state_dict(), 'model_ib':model_ib.state_dict(), 'model_ib_opt': optimizer_ib.state_dict(), 'generator_1':generator_1.state_dict(), 'generator_1_opt':optimizer_generator_1.state_dict(), 'generator_2':generator_2.state_dict(), 'generator_2_opt':optimizer_generator_2.state_dict()}, './weights_generative_ibalone/checkpoint_' + dataset_name + '_' + str(lr) + '_' + str(suffix) + '.pt') 44 | 45 | 46 | def num_graphs(data): 47 | if data.batch is not None: 48 | return data.num_graphs 49 | else: 50 | return data.x.size(0) 51 | 52 | 53 | def train(loader, model, optimizer, generator_1, optimizer_generator_1, generator_2, optimizer_generator_2, model_ib, optimizer_ib, device): 54 | model.train() 55 | generator_1.train() 56 | generator_2.train() 57 | model_ib.train() 58 | total_loss, generative_loss = 0, 0 59 | for data, data1, data2 in loader: 60 | optimizer.zero_grad() 61 | data1 = data1.to(device) 62 | data2 = data2.to(device) 63 | out1 = model.forward_graphcl(data1) 64 | out2 = model.forward_graphcl(data2) 65 | loss_cl = model.loss_graphcl(out1, out2, mean=False) 66 | loss = loss_cl.mean() 67 | loss.backward() 68 | total_loss += loss.item() * num_graphs(data1) 69 | optimizer.step() 70 | 71 | # information bottleneck 72 | optimizer_ib.zero_grad() 73 | _out1 = model_ib.forward_graphcl(data1) 74 | _out2 = model_ib.forward_graphcl(data2) 75 | loss_ib = model_ib.loss_graphcl(_out1, out1.detach(), mean=False) + model_ib.loss_graphcl(_out2, out2.detach(), mean=False) 76 | loss = loss_ib.mean() 77 | loss.backward() 78 | optimizer_ib.step() 79 | 80 | # reward for joao 81 | loss_cl = loss_ib.detach() 82 | loss_cl = loss_cl - loss_cl.mean() 83 | loss_cl[loss_cl>0] = 1 84 | loss_cl[loss_cl<=0] = 0.01 # weaken the reward for low cl loss 85 | 86 | # joao 87 | optimizer_generator_1.zero_grad() 88 | optimizer_generator_2.zero_grad() 89 | data = data.to(device) 90 | 91 | loss_1 = generator_1(data, reward=loss_cl) 92 | loss_2 = generator_2(data, reward=loss_cl) 93 | 94 | loss = loss_1 + loss_2 95 | loss.backward() 96 | optimizer_generator_1.step() 97 | optimizer_generator_2.step() 98 | generative_loss += loss.item() * num_graphs(data) 99 | 100 | return total_loss/len(loader.dataset), generative_loss/len(loader.dataset) 101 | 102 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | 8 | 9 | class GCNConv(MessagePassing): 10 | r"""The graph convolutional operator from the `"Semi-supervised 11 | Classification with Graph Convolutional Networks" 12 | <https://arxiv.org/abs/1609.02907>`_ paper 13 | 14 | .. math:: 15 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 16 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 17 | 18 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 19 | adjacency matrix with inserted self-loops and 20 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 21 | 22 | Args: 23 | in_channels (int): Size of each input sample. 24 | out_channels (int): Size of each output sample.
25 | improved (bool, optional): If set to :obj:`True`, the layer computes 26 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 27 | (default: :obj:`False`) 28 | cached (bool, optional): If set to :obj:`True`, the layer will cache 29 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 30 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 31 | (default: :obj:`False`) 32 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 33 | an additive bias. (default: :obj:`True`) 34 | edge_norm (bool, optional): whether or not to normalize the adjacency matrix. 35 | (default: :obj:`True`) 36 | gfn (bool, optional): If `True`, only a linear transform (1x1 conv) is 37 | applied to every node. (default: :obj:`False`) 38 | """ 39 | 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | improved=False, 44 | cached=False, 45 | bias=True, 46 | edge_norm=True, 47 | gfn=False): 48 | super(GCNConv, self).__init__('add') 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.improved = improved 53 | self.cached = cached 54 | self.cached_result = None 55 | self.edge_norm = edge_norm 56 | self.gfn = gfn 57 | 58 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 59 | 60 | if bias: 61 | self.bias = Parameter(torch.Tensor(out_channels)) 62 | else: 63 | self.register_parameter('bias', None) 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | glorot(self.weight) 69 | zeros(self.bias) 70 | self.cached_result = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 74 | if edge_weight is None: 75 | edge_weight = torch.ones((edge_index.size(1), ), 76 | dtype=dtype, 77 | device=edge_index.device) 78 | edge_weight = edge_weight.view(-1) 79 | assert edge_weight.size(0) == edge_index.size(1) 80 | 81 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 82 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 83 | # Add edge_weight for loop edges.
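# The continuation below finishes the symmetric normalization from the class
# docstring: each self-loop appended above receives weight 1, or 2 when
# `improved=True` (A + 2I instead of A + I); node degrees are accumulated with
# scatter_add over the row index; and the returned per-edge coefficients
# deg(row)^{-1/2} * edge_weight * deg(col)^{-1/2} are the edge-wise form of
# D^{-1/2} (A + I) D^{-1/2}.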
84 | loop_weight = torch.full((num_nodes, ), 85 | 1 if not improved else 2, 86 | dtype=edge_weight.dtype, 87 | device=edge_weight.device) 88 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 89 | 90 | edge_index = edge_index[0] 91 | row, col = edge_index 92 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 93 | deg_inv_sqrt = deg.pow(-0.5) 94 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 95 | 96 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 97 | 98 | def forward(self, x, edge_index, edge_weight=None): 99 | """""" 100 | x = torch.matmul(x, self.weight) 101 | if self.gfn: 102 | return x 103 | 104 | if not self.cached or self.cached_result is None: 105 | if self.edge_norm: 106 | edge_index, norm = GCNConv.norm( 107 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 108 | else: 109 | norm = None 110 | self.cached_result = edge_index, norm 111 | 112 | edge_index, norm = self.cached_result 113 | return self.propagate(edge_index, x=x, norm=norm) 114 | 115 | def message(self, x_j, norm): 116 | if self.edge_norm: 117 | return norm.view(-1, 1) * x_j 118 | else: 119 | return x_j 120 | 121 | def update(self, aggr_out): 122 | if self.bias is not None: 123 | aggr_out = aggr_out + self.bias 124 | return aggr_out 125 | 126 | def __repr__(self): 127 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 128 | self.out_channels) 129 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def print_weights(model): 4 | for name, param in model.named_parameters(): 5 | if param.requires_grad: 6 | print(name, param.shape) 7 | sys.stdout.flush() 8 | 9 | 10 | def logger(info): 11 | fold, epoch = info['fold'], info['epoch'] 12 | if epoch == 1 or epoch % 10 == 0: 13 | train_acc, test_acc = info['train_acc'], info['test_acc'] 14 | print('{:02d}/{:03d}: Train Acc: {:.3f}, Test Acc: {:.3f}'.format( 15 | fold, epoch, train_acc, test_acc)) 16 | sys.stdout.flush() 17 | 18 | 19 | -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/weights_infobn/debug.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU_LP/pretrain/weights_infobn/debug.log -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/weights_infomin/debug.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU_LP/pretrain/weights_infomin/debug.log -------------------------------------------------------------------------------- /semisupervised_TU_LP/pretrain/weights_infominbn/debug.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/semisupervised_TU_LP/pretrain/weights_infominbn/debug.log -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/README.md: -------------------------------------------------------------------------------- 1 | ### Dataset 2 | Please follow the instructions described in
https://github.com/snap-stanford/pretrain-gnns#dataset-download. 3 | 4 | 5 | ### JOAO Pre-Training: ### 6 | 7 | ``` 8 | cd ./bio 9 | python pretrain_joao.py --gamma 0.01 10 | cd ./chem 11 | python pretrain_joao.py --gamma 0.01 12 | ``` 13 | 14 | 15 | ### JOAO Finetuning: ### 16 | 17 | ``` 18 | cd ./bio 19 | ./finetune.sh ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 20 | cd ./chem 21 | ./finetune.sh bbbp ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 22 | ./finetune.sh tox21 ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 23 | ./finetune.sh toxcast ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 24 | ./finetune.sh sider ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 25 | ./finetune.sh clintox ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 26 | ./finetune.sh muv ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 27 | ./finetune.sh hiv ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 28 | ./finetune.sh bace ./weights/joao_${gamma}_100.pth 1e-3 ${RESULT_FILE} 29 | ``` 30 | 31 | ```gamma``` is tuned from {0.01, 0.1, 1}. ```RESULT_FILE``` is the file to store the results. 32 | 33 | 34 | ### JOAOv2 Pre-Training: ### 35 | 36 | ``` 37 | cd ./bio 38 | python pretrain_joaov2.py --gamma 0.01 39 | cd ./chem 40 | python pretrain_joaov2.py --gamma 0.01 41 | ``` 42 | 43 | 44 | ### JOAOv2 Finetuning: ### 45 | 46 | ``` 47 | cd ./bio 48 | ./finetune.sh ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 49 | cd ./chem 50 | ./finetune.sh bbbp ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 51 | ./finetune.sh tox21 ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 52 | ./finetune.sh toxcast ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 53 | ./finetune.sh sider ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 54 | ./finetune.sh clintox ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 55 | ./finetune.sh muv ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 56 | ./finetune.sh hiv ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 57 | ./finetune.sh bace ./weights/joaov2_${gamma}_100.pth 1e-3 ${RESULT_FILE} 58 | ``` 59 | 60 | ```gamma``` is tuned from {0.01, 0.1, 1}. ```RESULT_FILE``` is the file to store the results. 61 | 62 | 63 | ## Acknowledgements 64 | 65 | The backbone implementation is based on https://github.com/snap-stanford/pretrain-gnns. 66 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/bio/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from batch import BatchFinetune, BatchMasking, BatchAE, BatchSubstructContext 5 | 6 | class DataLoaderFinetune(torch.utils.data.DataLoader): 7 | r"""Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How many samples per batch to load.
12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 18 | super(DataLoaderFinetune, self).__init__( 19 | dataset, 20 | batch_size, 21 | shuffle, 22 | collate_fn=lambda data_list: BatchFinetune.from_data_list(data_list), 23 | **kwargs) 24 | 25 | class DataLoaderMasking(torch.utils.data.DataLoader): 26 | r"""Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How many samples per batch to load. 31 | (default: :obj:`1`) 32 | shuffle (bool, optional): If set to :obj:`True`, the data will be 33 | reshuffled at every epoch (default: :obj:`True`) 34 | """ 35 | 36 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 37 | super(DataLoaderMasking, self).__init__( 38 | dataset, 39 | batch_size, 40 | shuffle, 41 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list), 42 | **kwargs) 43 | 44 | 45 | class DataLoaderAE(torch.utils.data.DataLoader): 46 | r"""Data loader which merges data objects from a 47 | :class:`torch_geometric.data.dataset` to a mini-batch. 48 | Args: 49 | dataset (Dataset): The dataset from which to load the data. 50 | batch_size (int, optional): How many samples per batch to load. 51 | (default: :obj:`1`) 52 | shuffle (bool, optional): If set to :obj:`True`, the data will be 53 | reshuffled at every epoch (default: :obj:`True`) 54 | """ 55 | 56 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 57 | super(DataLoaderAE, self).__init__( 58 | dataset, 59 | batch_size, 60 | shuffle, 61 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 62 | **kwargs) 63 | 64 | 65 | class DataLoaderSubstructContext(torch.utils.data.DataLoader): 66 | r"""Data loader which merges data objects from a 67 | :class:`torch_geometric.data.dataset` to a mini-batch. 68 | Args: 69 | dataset (Dataset): The dataset from which to load the data. 70 | batch_size (int, optional): How many samples per batch to load.
71 | (default: :obj:`1`) 72 | shuffle (bool, optional): If set to :obj:`True`, the data will be 73 | reshuffled at every epoch (default: :obj:`True`) 74 | """ 75 | 76 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 77 | super(DataLoaderSubstructContext, self).__init__( 78 | dataset, 79 | batch_size, 80 | shuffle, 81 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 82 | **kwargs) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/bio/finetune.sh: -------------------------------------------------------------------------------- 1 | #### GIN fine-tuning 2 | split=species 3 | 4 | model_file=$1 5 | lr=$2 6 | resultFile_name=$3 7 | 8 | ### for GIN 9 | for runseed in 0 1 2 3 4 5 6 7 8 9 10 | do 11 | python finetune.py --model_file $model_file --split $split --epochs 50 --device 0 --runseed $runseed --gnn_type gin --lr $lr --resultFile_name $resultFile_name 12 | done 13 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/bio/results/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/transferLearning_MoleculeNet_PPI/bio/results/place_holder.txt -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/bio/splitters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def random_split(dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1, 6 | seed=0): 7 | """ 8 | Adapted from graph-pretrain 9 | :param dataset: 10 | :param task_idx: 11 | :param null_value: 12 | :param frac_train: 13 | :param frac_valid: 14 | :param frac_test: 15 | :param seed: 16 | :return: train, valid, test slices of the input dataset obj. 17 | """ 18 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 19 | 20 | num_mols = len(dataset) 21 | random.seed(seed) 22 | all_idx = list(range(num_mols)) 23 | random.shuffle(all_idx) 24 | 25 | train_idx = all_idx[:int(frac_train * num_mols)] 26 | valid_idx = all_idx[int(frac_train * num_mols):int(frac_valid * num_mols) 27 | + int(frac_train * num_mols)] 28 | test_idx = all_idx[int(frac_valid * num_mols) + int(frac_train * num_mols):] 29 | 30 | assert len(set(train_idx).intersection(set(valid_idx))) == 0 31 | assert len(set(valid_idx).intersection(set(test_idx))) == 0 32 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols 33 | 34 | train_dataset = dataset[torch.tensor(train_idx)] 35 | valid_dataset = dataset[torch.tensor(valid_idx)] 36 | if frac_test == 0: 37 | test_dataset = None 38 | else: 39 | test_dataset = dataset[torch.tensor(test_idx)] 40 | 41 | return train_dataset, valid_dataset, test_dataset 42 | 43 | def species_split(dataset, train_valid_species_id_list=[3702, 6239, 511145, 44 | 7227, 10090, 4932, 7955], 45 | test_species_id_list=[9606]): 46 | """ 47 | Split dataset based on species_id attribute 48 | :param dataset: 49 | :param train_valid_species_id_list: 50 | :param test_species_id_list: 51 | :return: train_valid dataset, test dataset 52 | """ 53 | # NB: pytorch geometric dataset object can be indexed using slices or 54 | # byte tensors. 
We will use byte tensors here 55 | 56 | train_valid_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8) 57 | for id in train_valid_species_id_list: 58 | train_valid_byte_tensor += (dataset.data.species_id == id) 59 | 60 | test_species_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8) 61 | for id in test_species_id_list: 62 | test_species_byte_tensor += (dataset.data.species_id == id) 63 | 64 | assert ((train_valid_byte_tensor + test_species_byte_tensor) == 1).all() 65 | 66 | train_valid_dataset = dataset[train_valid_byte_tensor] 67 | test_valid_dataset = dataset[test_species_byte_tensor] 68 | 69 | return train_valid_dataset, test_valid_dataset 70 | 71 | if __name__ == "__main__": 72 | from collections import Counter 73 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/bio/weights/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/transferLearning_MoleculeNet_PPI/bio/weights/place_holder.txt -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/chem/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from batch import BatchSubstructContext, BatchMasking, BatchAE 5 | 6 | class DataLoaderSubstructContext(torch.utils.data.DataLoader): 7 | r"""Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How many samples per batch to load. 12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 18 | super(DataLoaderSubstructContext, self).__init__( 19 | dataset, 20 | batch_size, 21 | shuffle, 22 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 23 | **kwargs) 24 | 25 | class DataLoaderMasking(torch.utils.data.DataLoader): 26 | r"""Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How many samples per batch to load. 31 | (default: :obj:`1`) 32 | shuffle (bool, optional): If set to :obj:`True`, the data will be 33 | reshuffled at every epoch (default: :obj:`True`) 34 | """ 35 | 36 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 37 | super(DataLoaderMasking, self).__init__( 38 | dataset, 39 | batch_size, 40 | shuffle, 41 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list), 42 | **kwargs) 43 | 44 | 45 | class DataLoaderAE(torch.utils.data.DataLoader): 46 | r"""Data loader which merges data objects from a 47 | :class:`torch_geometric.data.dataset` to a mini-batch. 48 | Args: 49 | dataset (Dataset): The dataset from which to load the data. 50 | batch_size (int, optional): How many samples per batch to load.
51 | (default: :obj:`1`) 52 | shuffle (bool, optional): If set to :obj:`True`, the data will be 53 | reshuffled at every epoch (default: :obj:`True`) 54 | """ 55 | 56 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 57 | super(DataLoaderAE, self).__init__( 58 | dataset, 59 | batch_size, 60 | shuffle, 61 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 62 | **kwargs) 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/chem/finetune.sh: -------------------------------------------------------------------------------- 1 | #### GIN fine-tuning 2 | split=scaffold 3 | 4 | dataset=$1 5 | model_file=$2 6 | lr=$3 7 | resultFile_name=$4 8 | 9 | for runseed in 0 1 2 3 4 5 6 7 8 9 10 | do 11 | python finetune.py --input_model_file $model_file --split $split --runseed $runseed --gnn_type gin --dataset $dataset --lr $lr --epochs 100 --resultFile_name $resultFile_name 12 | done 13 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/chem/results/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/transferLearning_MoleculeNet_PPI/chem/results/place_holder.txt -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI/chem/weights/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/transferLearning_MoleculeNet_PPI/chem/weights/place_holder.txt -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/README.md: -------------------------------------------------------------------------------- 1 | ### LP-InfoMin/InfoBN/Info(Min+BN) Pre-Training: ### 2 | 3 | ``` 4 | cd ./bio 5 | python pretrain_generative_infomin.py 6 | python pretrain_generative_infobn.py 7 | python pretrain_generative_infominbn.py 8 | cd ./chem 9 | python pretrain_generative_infomin.py 10 | python pretrain_generative_infobn.py 11 | python pretrain_generative_infominbn.py 12 | ``` 13 | 14 | 15 | ### LP-InfoMin/InfoBN/Info(Min+BN) Finetuning: ### 16 | 17 | ``` 18 | cd ./bio 19 | ./finetune.sh ${MODEL_FILE} 1e-3 ${RESULT_FILE} 20 | cd ./chem 21 | ./finetune.sh bbbp ${MODEL_FILE} ${RESULT_FILE} 22 | ./finetune.sh tox21 ${MODEL_FILE} ${RESULT_FILE} 23 | ./finetune.sh toxcast ${MODEL_FILE} ${RESULT_FILE} 24 | ./finetune.sh sider ${MODEL_FILE} ${RESULT_FILE} 25 | ./finetune.sh clintox ${MODEL_FILE} ${RESULT_FILE} 26 | ./finetune.sh muv ${MODEL_FILE} ${RESULT_FILE} 27 | ./finetune.sh hiv ${MODEL_FILE} ${RESULT_FILE} 28 | ./finetune.sh bace ${MODEL_FILE} ${RESULT_FILE} 29 | ``` 30 | 31 | ```MODEL_FILE``` is the saved pre-trained weights file, and ```RESULT_FILE``` is the file to store the results. Note that the chem ```finetune.sh``` fixes the learning rate at 1e-3 internally, so it takes no learning-rate argument. 32 | 33 | 34 | 35 | 36 | ## Acknowledgements 37 | 38 | The backbone implementation is based on https://github.com/snap-stanford/pretrain-gnns.
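
### Putting It Together (chem): ###

For reference, a minimal sketch of the full chem finetuning sweep; this convenience loop is not a script shipped in the repository, and it assumes the ```finetune.sh``` interface above plus that each run appends one result line per seed in the format ```cal.py``` parses:

```
cd ./chem
for ds in bbbp tox21 toxcast sider clintox muv hiv bace
do
  ./finetune.sh ${ds} ${MODEL_FILE} ${RESULT_FILE}
done
python cal.py ${RESULT_FILE}
```

```cal.py``` then prints, per dataset, the mean and standard deviation of the validation and test metrics across the ten seeds.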
39 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/bio/.graph_cover_2.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/transferLearning_MoleculeNet_PPI_LP/bio/.graph_cover_2.py.swp -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/bio/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from batch import BatchFinetune, BatchMasking, BatchAE, BatchSubstructContext 5 | 6 | class DataLoaderFinetune(torch.utils.data.DataLoader): 7 | r"""Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How many samples per batch to load. 12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 18 | super(DataLoaderFinetune, self).__init__( 19 | dataset, 20 | batch_size, 21 | shuffle, 22 | collate_fn=lambda data_list: BatchFinetune.from_data_list(data_list), 23 | **kwargs) 24 | 25 | class DataLoaderMasking(torch.utils.data.DataLoader): 26 | r"""Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How many samples per batch to load. 31 | (default: :obj:`1`) 32 | shuffle (bool, optional): If set to :obj:`True`, the data will be 33 | reshuffled at every epoch (default: :obj:`True`) 34 | """ 35 | 36 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 37 | super(DataLoaderMasking, self).__init__( 38 | dataset, 39 | batch_size, 40 | shuffle, 41 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list), 42 | **kwargs) 43 | 44 | 45 | class DataLoaderAE(torch.utils.data.DataLoader): 46 | r"""Data loader which merges data objects from a 47 | :class:`torch_geometric.data.dataset` to a mini-batch. 48 | Args: 49 | dataset (Dataset): The dataset from which to load the data. 50 | batch_size (int, optional): How many samples per batch to load. 51 | (default: :obj:`1`) 52 | shuffle (bool, optional): If set to :obj:`True`, the data will be 53 | reshuffled at every epoch (default: :obj:`True`) 54 | """ 55 | 56 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 57 | super(DataLoaderAE, self).__init__( 58 | dataset, 59 | batch_size, 60 | shuffle, 61 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 62 | **kwargs) 63 | 64 | 65 | class DataLoaderSubstructContext(torch.utils.data.DataLoader): 66 | r"""Data loader which merges data objects from a 67 | :class:`torch_geometric.data.dataset` to a mini-batch. 68 | Args: 69 | dataset (Dataset): The dataset from which to load the data. 70 | batch_size (int, optional): How many samples per batch to load.
71 | (default: :obj:`1`) 72 | shuffle (bool, optional): If set to :obj:`True`, the data will be 73 | reshuffled at every epoch (default: :obj:`True`) 74 | """ 75 | 76 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 77 | super(DataLoaderSubstructContext, self).__init__( 78 | dataset, 79 | batch_size, 80 | shuffle, 81 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 82 | **kwargs) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/bio/finetune.sh: -------------------------------------------------------------------------------- 1 | #### GIN fine-tuning 2 | split=species 3 | 4 | model_file=$1 5 | lr=$2 6 | resultFile_name=$3 7 | 8 | ### for GIN 9 | for runseed in 0 1 2 3 4 5 6 7 8 9 10 | do 11 | python finetune.py --model_file $model_file --split $split --epochs 50 --device 0 --runseed $runseed --gnn_type gin --lr $lr --resultFile_name $resultFile_name 12 | done 13 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/bio/pretrain_supervised.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from splitters import random_split, species_split 4 | from loader import BioDataset 5 | from torch_geometric.data import DataLoader 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from model import GNN, GNN_graphpred 16 | from sklearn.metrics import roc_auc_score 17 | 18 | import pandas as pd 19 | 20 | from util import combine_dataset 21 | 22 | criterion = nn.BCEWithLogitsLoss() 23 | 24 | def train(args, model, device, loader, optimizer): 25 | model.train() 26 | 27 | loss_accum = 0 28 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 29 | batch = batch.to(device) 30 | pred = model(batch) 31 | y = batch.go_target_pretrain.view(pred.shape).to(torch.float64) 32 | 33 | optimizer.zero_grad() 34 | loss = criterion(pred.double(), y) 35 | loss.backward() 36 | 37 | optimizer.step() 38 | 39 | loss_accum += loss.detach().cpu() 40 | 41 | return loss_accum / (step + 1) 42 | 43 | 44 | def main(): 45 | # Training settings 46 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 47 | parser.add_argument('--device', type=int, default=0, 48 | help='which gpu to use if any (default: 0)') 49 | parser.add_argument('--batch_size', type=int, default=32, 50 | help='input batch size for training (default: 32)') 51 | parser.add_argument('--epochs', type=int, default=100, 52 | help='number of epochs to train (default: 100)') 53 | parser.add_argument('--lr', type=float, default=0.001, 54 | help='learning rate (default: 0.001)') 55 | parser.add_argument('--decay', type=float, default=0, 56 | help='weight decay (default: 0)') 57 | parser.add_argument('--num_layer', type=int, default=5, 58 | help='number of GNN message passing layers (default: 5).') 59 | parser.add_argument('--emb_dim', type=int, default=300, 60 | help='embedding dimensions (default: 300)') 61 | parser.add_argument('--dropout_ratio', type=float, default=0.2, 62 | help='dropout ratio (default: 0.2)') 63 | parser.add_argument('--graph_pooling', type=str, default="mean", 64 | help='graph level pooling (sum, mean, max, set2set, attention)') 65 | parser.add_argument('--JK', type=str, default="last", 66 | 
help='how the node features across layers are combined. last, sum, max or concat') 67 | parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)') 68 | parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model') 69 | parser.add_argument('--gnn_type', type=str, default="gin") 70 | parser.add_argument('--num_workers', type=int, default = 0, help='number of workers for dataset loading') 71 | parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting dataset.") 72 | parser.add_argument('--split', type=str, default = "species", help='Random or species split') 73 | args = parser.parse_args() 74 | 75 | 76 | torch.manual_seed(0) 77 | np.random.seed(0) 78 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 79 | if torch.cuda.is_available(): 80 | torch.cuda.manual_seed_all(0) 81 | 82 | root_supervised = 'dataset/supervised' 83 | 84 | dataset = BioDataset(root_supervised, data_type='supervised') 85 | 86 | if args.split == "random": 87 | print("random splitting") 88 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 89 | print(train_dataset) 90 | print(valid_dataset) 91 | pretrain_dataset = combine_dataset(train_dataset, valid_dataset) 92 | print(pretrain_dataset) 93 | elif args.split == "species": 94 | print("species splitting") 95 | trainval_dataset, test_dataset = species_split(dataset) 96 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 97 | print(trainval_dataset) 98 | print(test_dataset_broad) 99 | pretrain_dataset = combine_dataset(trainval_dataset, test_dataset_broad) 100 | print(pretrain_dataset) 101 | #train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 102 | else: 103 | raise ValueError("Unknown split name.") 104 | 105 | 106 | train_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 107 | 108 | num_tasks = len(pretrain_dataset[0].go_target_pretrain) 109 | 110 | #set up model 111 | model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 112 | if not args.input_model_file == "": 113 | model.from_pretrained(args.input_model_file + ".pth") 114 | 115 | model.to(device) 116 | 117 | #set up optimizer 118 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 119 | print(optimizer) 120 | 121 | for epoch in range(1, args.epochs+1): 122 | print("====epoch " + str(epoch)) 123 | 124 | train_loss = train(args, model, device, train_loader, optimizer) 125 | 126 | if not args.output_model_file == "": 127 | torch.save(model.gnn.state_dict(), args.output_model_file + ".pth") 128 | 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/bio/splitters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def random_split(dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1, 6 | seed=0): 7 | """ 8 | Adapted from graph-pretrain 9 | :param dataset: 10 | :param task_idx: 11 | :param null_value: 12 | :param 
frac_train: 13 | :param frac_valid: 14 | :param frac_test: 15 | :param seed: 16 | :return: train, valid, test slices of the input dataset obj. 17 | """ 18 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 19 | 20 | num_mols = len(dataset) 21 | random.seed(seed) 22 | all_idx = list(range(num_mols)) 23 | random.shuffle(all_idx) 24 | 25 | train_idx = all_idx[:int(frac_train * num_mols)] 26 | valid_idx = all_idx[int(frac_train * num_mols):int(frac_valid * num_mols) 27 | + int(frac_train * num_mols)] 28 | test_idx = all_idx[int(frac_valid * num_mols) + int(frac_train * num_mols):] 29 | 30 | assert len(set(train_idx).intersection(set(valid_idx))) == 0 31 | assert len(set(valid_idx).intersection(set(test_idx))) == 0 32 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols 33 | 34 | train_dataset = dataset[torch.tensor(train_idx)] 35 | valid_dataset = dataset[torch.tensor(valid_idx)] 36 | if frac_test == 0: 37 | test_dataset = None 38 | else: 39 | test_dataset = dataset[torch.tensor(test_idx)] 40 | 41 | return train_dataset, valid_dataset, test_dataset 42 | 43 | def species_split(dataset, train_valid_species_id_list=[3702, 6239, 511145, 44 | 7227, 10090, 4932, 7955], 45 | test_species_id_list=[9606]): 46 | """ 47 | Split dataset based on species_id attribute 48 | :param dataset: 49 | :param train_valid_species_id_list: 50 | :param test_species_id_list: 51 | :return: train_valid dataset, test dataset 52 | """ 53 | # NB: pytorch geometric dataset object can be indexed using slices or 54 | # byte tensors. We will use byte tensors here 55 | 56 | train_valid_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8) 57 | for id in train_valid_species_id_list: 58 | train_valid_byte_tensor += (dataset.data.species_id == id) 59 | 60 | test_species_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8) 61 | for id in test_species_id_list: 62 | test_species_byte_tensor += (dataset.data.species_id == id) 63 | 64 | assert ((train_valid_byte_tensor + test_species_byte_tensor) == 1).all() 65 | 66 | train_valid_dataset = dataset[train_valid_byte_tensor] 67 | test_valid_dataset = dataset[test_species_byte_tensor] 68 | 69 | return train_valid_dataset, test_valid_dataset 70 | 71 | if __name__ == "__main__": 72 | from collections import Counter 73 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/chem/cal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | 5 | fileName = str(sys.argv[1]) 6 | 7 | 8 | with open(fileName, 'r') as f: 9 | data = f.read().split('\n')[:-1] 10 | 11 | 12 | acc_val_dict, acc_test_dict = {}, {} 13 | for d in data: 14 | d = d.split() 15 | dsName, acc_val, acc_test = d[0], float(d[4]), float(d[5]) 16 | if not dsName in acc_val_dict.keys(): 17 | acc_val_dict[dsName] = [acc_val] 18 | acc_test_dict[dsName] = [acc_test] 19 | else: 20 | acc_val_dict[dsName].append(acc_val) 21 | acc_test_dict[dsName].append(acc_test) 22 | 23 | 24 | for dsName in acc_val_dict.keys(): 25 | print(dsName, np.mean(acc_val_dict[dsName]), np.std(acc_val_dict[dsName]), np.mean(acc_test_dict[dsName]), np.std(acc_test_dict[dsName])) 26 | 27 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/chem/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import 
default_collate 3 | 4 | from batch import BatchSubstructContext, BatchMasking, BatchAE 5 | 6 | class DataLoaderSubstructContext(torch.utils.data.DataLoader): 7 | r"""Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How many samples per batch to load. 12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) 15 | """ 16 | 17 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 18 | super(DataLoaderSubstructContext, self).__init__( 19 | dataset, 20 | batch_size, 21 | shuffle, 22 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 23 | **kwargs) 24 | 25 | class DataLoaderMasking(torch.utils.data.DataLoader): 26 | r"""Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How many samples per batch to load. 31 | (default: :obj:`1`) 32 | shuffle (bool, optional): If set to :obj:`True`, the data will be 33 | reshuffled at every epoch (default: :obj:`True`) 34 | """ 35 | 36 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 37 | super(DataLoaderMasking, self).__init__( 38 | dataset, 39 | batch_size, 40 | shuffle, 41 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list), 42 | **kwargs) 43 | 44 | 45 | class DataLoaderAE(torch.utils.data.DataLoader): 46 | r"""Data loader which merges data objects from a 47 | :class:`torch_geometric.data.dataset` to a mini-batch. 48 | Args: 49 | dataset (Dataset): The dataset from which to load the data. 50 | batch_size (int, optional): How many samples per batch to load. 51 | (default: :obj:`1`) 52 | shuffle (bool, optional): If set to :obj:`True`, the data will be 53 | reshuffled at every epoch (default: :obj:`True`) 54 | """ 55 | 56 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 57 | super(DataLoaderAE, self).__init__( 58 | dataset, 59 | batch_size, 60 | shuffle, 61 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 62 | **kwargs) 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /transferLearning_MoleculeNet_PPI_LP/chem/finetune.sh: -------------------------------------------------------------------------------- 1 | #### GIN fine-tuning 2 | split=scaffold 3 | 4 | dataset=$1 5 | model_file=$2 6 | resultFile_name=$3 7 | 8 | for runseed in 0 1 2 3 4 5 6 7 8 9 9 | do 10 | python finetune.py --input_model_file $model_file --split $split --runseed $runseed --gnn_type gin --dataset $dataset --lr 1e-3 --epochs 100 --resultFile_name $resultFile_name 11 | done 12 | -------------------------------------------------------------------------------- /unsupervised_TU/README.md: -------------------------------------------------------------------------------- 1 | ### JOAO Pre-Training & Finetuning: ### 2 | 3 | ``` 4 | ./joao.sh NCI1 ${gamma} 5 | ``` 6 | 7 | ```gamma``` is tuned from {0.01, 0.1, 1}. 8 | 9 | 10 | ### JOAOv2 Pre-Training & Finetuning: ### 11 | 12 | ``` 13 | ./joaov2.sh NCI1 ${gamma} 14 | ``` 15 | 16 | ```gamma``` is tuned from {0.01, 0.1, 1}. JOAOv2 is trained for 40 epochs since multiple projection heads are trained.
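
Since the best ```gamma``` is dataset-dependent, a minimal sweep over the suggested grid might look as follows (a sketch, not a shipped script; each of the two scripts above already loops over seeds 0-4 internally):

```
for gamma in 0.01 0.1 1
do
  ./joao.sh NCI1 ${gamma}
  ./joaov2.sh NCI1 ${gamma}
done
```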
17 | 18 | 19 | ## Acknowledgements 20 | 21 | The backbone implementation is based on https://github.com/fanyun-sun/InfoGraph/tree/master/unsupervised. 22 | -------------------------------------------------------------------------------- /unsupervised_TU/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def arg_parse(): 4 | parser = argparse.ArgumentParser(description='GcnInformax Arguments.') 5 | parser.add_argument('--DS', dest='DS', help='Dataset') 6 | parser.add_argument('--local', dest='local', action='store_const', 7 | const=True, default=False) 8 | parser.add_argument('--glob', dest='glob', action='store_const', 9 | const=True, default=False) 10 | parser.add_argument('--prior', dest='prior', action='store_const', 11 | const=True, default=False) 12 | 13 | parser.add_argument('--lr', dest='lr', type=float, 14 | help='Learning rate.') 15 | parser.add_argument('--num-gc-layers', dest='num_gc_layers', type=int, default=5, 16 | help='Number of graph convolution layers before each pooling') 17 | parser.add_argument('--hidden-dim', dest='hidden_dim', type=int, default=32, 18 | help='') 19 | 20 | parser.add_argument('--aug', type=str, default='dnodes') 21 | parser.add_argument('--gamma', type=str, default=0.1) 22 | parser.add_argument('--mode', type=str, default='fast') 23 | parser.add_argument('--seed', type=int, default=0) 24 | 25 | return parser.parse_args() 26 | 27 | -------------------------------------------------------------------------------- /unsupervised_TU/joao.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | 3 | for seed in 0 1 2 3 4 4 | do 5 | python joao.py --DS $1 --lr 0.01 --local --num-gc-layers 3 --aug minmax --gamma $2 --seed $seed 6 | done 7 | 8 | -------------------------------------------------------------------------------- /unsupervised_TU/joaov2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | 3 | for seed in 0 1 2 3 4 4 | do 5 | python joaov2.py --DS $1 --lr 0.001 --local --num-gc-layers 3 --aug minmax --gamma $2 --seed $seed 6 | done 7 | -------------------------------------------------------------------------------- /unsupervised_TU/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | # from core.encoders import * 7 | import json 8 | from torch import optim 9 | 10 | 11 | 12 | class GlobalDiscriminator(nn.Module): 13 | def __init__(self, args, input_dim): 14 | super().__init__() 15 | 16 | self.l0 = nn.Linear(32, 32) 17 | self.l1 = nn.Linear(32, 32) 18 | 19 | self.l2 = nn.Linear(512, 1) 20 | def forward(self, y, M, data): 21 | 22 | adj = Variable(data['adj'].float(), requires_grad=False).cuda() 23 | # h0 = Variable(data['feats'].float()).cuda() 24 | batch_num_nodes = data['num_nodes'].int().numpy() 25 | M, _ = self.encoder(M, adj, batch_num_nodes) 26 | # h = F.relu(self.c0(M)) 27 | # h = self.c1(h) 28 | # h = h.view(y.shape[0], -1) 29 | h = torch.cat((y, M), dim=1) 30 | h = F.relu(self.l0(h)) 31 | h = F.relu(self.l1(h)) 32 | return self.l2(h) 33 | 34 | class PriorDiscriminator(nn.Module): 35 | def __init__(self, input_dim): 36 | super().__init__() 37 | self.l0 = nn.Linear(input_dim, input_dim) 38 | self.l1 = nn.Linear(input_dim, input_dim) 39 | self.l2 = nn.Linear(input_dim, 1) 40 | 41 | def
forward(self, x): 42 | h = F.relu(self.l0(x)) 43 | h = F.relu(self.l1(h)) 44 | return torch.sigmoid(self.l2(h)) 45 | 46 | class FF(nn.Module): 47 | def __init__(self, input_dim): 48 | super().__init__() 49 | # self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1) 50 | # self.c1 = nn.Conv1d(512, 512, kernel_size=1) 51 | # self.c2 = nn.Conv1d(512, 1, kernel_size=1) 52 | self.block = nn.Sequential( 53 | nn.Linear(input_dim, input_dim), 54 | nn.ReLU(), 55 | nn.Linear(input_dim, input_dim), 56 | nn.ReLU(), 57 | nn.Linear(input_dim, input_dim), 58 | nn.ReLU() 59 | ) 60 | self.linear_shortcut = nn.Linear(input_dim, input_dim) 61 | # self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1, stride=1, padding=0) 62 | # self.c1 = nn.Conv1d(512, 512, kernel_size=1, stride=1, padding=0) 63 | # self.c2 = nn.Conv1d(512, 1, kernel_size=1, stride=1, padding=0) 64 | 65 | def forward(self, x): 66 | return self.block(x) + self.linear_shortcut(x) 67 | 68 | -------------------------------------------------------------------------------- /unsupervised_TU/results/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shen-Lab/GraphCL_Automated/8f3c2ac7831b88693e932c924428d0c3fe065894/unsupervised_TU/results/place_holder.txt --------------------------------------------------------------------------------