├── README.md
├── data.py
├── img
│   └── model.PNG
├── models.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BGRL_Pytorch
A PyTorch implementation of "Large-Scale Representation Learning on Graphs via Bootstrapping" (BGRL), accepted to the ICLR 2021 Workshop.


## Hyperparameters for training BGRL
The following options can be passed to `train.py`:

`--layers` or `-l`:
One or more integer values specifying the number of units in each GNN layer. Default is `512 256`.
Usage example: `--layers 512 256`

`--aug_params` or `-p`:
Four float values specifying the hyperparameters for graph augmentation (p_f1, p_f2, p_e1, p_e2). Default is `0.3 0.4 0.3 0.2`.
Usage example: `--aug_params 0.2 0.1 0.2 0.3`


|params|WikiCS|Am.Computers|Am.Photos|Co.CS|Co.Physics|
|------|------|------------|---------|-----|----------|
|p_f1 |0.2 |0.2 |0.1 |0.3 |0.1 |
|p_f2 |0.1 |0.1 |0.2 |0.4 |0.4 |
|p_e1 |0.2 |0.5 |0.4 |0.3 |0.4 |
|p_e2 |0.3 |0.4 |0.1 |0.2 |0.1 |
|embedding size|256|128|256|256|128|
|encoder hidden size|512|256|512|512|256|
|predictor hidden size|512|512|512|512|512|
* Hyperparameters are taken from the original paper.


## Experimental Results
Test accuracy (%) of a linear probe trained on the learned node embeddings:

|WikiCS|Am.Computers|Am.Photos|Co.CS|Co.Physics|
|------|------------|---------|-----|----------|
|79.50 |88.21 |92.76 |92.49|94.89 |


## Code borrowed from
Parts of the code are borrowed from BYOL and SelfGNN.

| name | Implementation Code | Paper |
| ----------- | ------------------- | ------- |
| `Bootstrap Your Own Latent`| Implementation| paper|
| `SelfGNN`| Implementation| paper|
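
## Usage
For example, to train on WikiCS with the hyperparameters listed in the table above (all flags as defined in `utils.parse_args`):

`python train.py --name WikiCS --layers 512 256 --pred_hid 512 --aug_params 0.2 0.1 0.2 0.3`

Validation and test accuracies are logged to TensorBoard under `runs/` and the best results are appended to a `BGRL_dataset(<name>)_node.txt` file.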
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected

import os.path as osp

import utils


def download_pyg_data(config):
    """
    Downloads a dataset from the PyTorch Geometric library.
    :param config: A dict containing info on the dataset to be downloaded
    :return: A tuple containing (root directory, dataset name, data directory)
    """
    leaf_dir = config["kwargs"]["root"].split("/")[-1].strip()
    data_dir = osp.join(config["kwargs"]["root"], "" if config["name"] == leaf_dir else config["name"])
    dst_path = osp.join(data_dir, "raw", "data.pt")
    if not osp.exists(dst_path):
        DatasetClass = config["class"]
        if config["name"] == "WikiCS":
            dataset = DatasetClass(data_dir, transform=T.NormalizeFeatures())
            # standardize WikiCS features and symmetrize its directed edges
            std, mean = torch.std_mean(dataset.data.x, dim=0, unbiased=False)
            dataset.data.x = (dataset.data.x - mean) / std
            dataset.data.edge_index = to_undirected(dataset.data.edge_index)
        else:
            dataset = DatasetClass(**config["kwargs"], transform=T.NormalizeFeatures())
        utils.create_masks(data=dataset.data)
        torch.save((dataset.data, dataset.slices), dst_path)

    return config["kwargs"]["root"], config["name"], data_dir


def download_data(root, name):
    """
    Download data from different repositories. Currently only PyTorch Geometric is supported.
    :param root: The root directory of the dataset
    :param name: The name of the dataset
    :return: A tuple containing (root directory, dataset name, data directory)
    """
    config = utils.decide_config(root=root, name=name)
    if config["src"] == "pyg":
        return download_pyg_data(config)


class Dataset(InMemoryDataset):
    """
    A PyTorch InMemoryDataset to build a multi-view dataset through graph data augmentation.
    """

    def __init__(self, root="data", name='cora', num_parts=1, final_parts=1, augmentation=None, transform=None,
                 pre_transform=None):
        self.num_parts = num_parts
        self.final_parts = final_parts
        self.augmentation = augmentation
        self.root, self.name, self.data_dir = download_data(root=root, name=name)
        utils.create_dirs(self.dirs)
        super().__init__(root=self.data_dir, transform=transform, pre_transform=pre_transform)
        path = osp.join(self.data_dir, "processed", self.processed_file_names[0])
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return ["data.pt"]

    @property
    def processed_file_names(self):
        if self.num_parts == 1:
            return ['byg.data.aug.pt']
        else:
            return [f'byg.data.aug.ip.{self.num_parts}.fp.{self.final_parts}.pt']

    @property
    def raw_dir(self):
        return osp.join(self.data_dir, "raw")

    @property
    def processed_dir(self):
        return osp.join(self.data_dir, "processed")

    @property
    def model_dir(self):
        return osp.join(self.data_dir, "model")

    @property
    def result_dir(self):
        return osp.join(self.data_dir, "result")

    @property
    def dirs(self):
        return [self.raw_dir, self.processed_dir, self.model_dir, self.result_dir]

    def process_full_batch_data(self, data):
        """
        Repackages the full-batch data; augmented views are generated from it at training time.
        :param data: the full-batch Data object
        :return: a single-element list containing the repackaged data
        """
        print("Processing full batch data")

        data = Data(edge_index=data.edge_index, edge_attr=data.edge_attr,
                    x=data.x, y=data.y,
                    train_mask=data.train_mask, val_mask=data.val_mask, test_mask=data.test_mask,
                    num_nodes=data.num_nodes)
        return [data]

    def download(self):
        pass

    def process(self):
        """
        Processes the full-batch data, filling in a constant edge_attr when the
        dataset provides none.
        """
        processed_path = osp.join(self.processed_dir, self.processed_file_names[0])
        if not osp.exists(processed_path):
            path = osp.join(self.raw_dir, self.raw_file_names[0])
            data, _ = torch.load(path)
            edge_attr = data.edge_attr
            edge_attr = torch.ones(data.edge_index.shape[1]) if edge_attr is None else edge_attr
            data.edge_attr = edge_attr
            data_list = self.process_full_batch_data(data)
            data, slices = self.collate(data_list)
            torch.save((data, slices), processed_path)
--------------------------------------------------------------------------------
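# Illustrative sketch (not a repository file): loading the dataset exactly as
# train.py does. Assumes torch_geometric is installed; the first call downloads
# and preprocesses the data under data/pyg/.
import data

dataset = data.Dataset(root="data", name="WikiCS")[0]
print(dataset.x.shape)           # [num_nodes, num_features]
print(dataset.train_mask.shape)  # 20 train/val splits used for evaluation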
--------------------------------------------------------------------------------
/img/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Namkyeong/BGRL_Pytorch/04026813a89a31d29c035ac3f67646013eecc4f9/img/model.PNG
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from torch_geometric.nn import GCNConv

import torch.nn.functional as F
import torch.nn as nn
import torch

import numpy as np

import copy

"""
The following code is borrowed from BYOL and SelfGNN,
and slightly modified for BGRL.
"""


class EMA:
    """Exponential moving average whose decay is annealed towards 1 with a cosine schedule."""

    def __init__(self, beta, epochs):
        super().__init__()
        self.beta = beta
        self.step = 0
        self.total_steps = epochs

    def update_average(self, old, new):
        if old is None:
            return new
        beta = 1 - (1 - self.beta) * (np.cos(np.pi * self.step / self.total_steps) + 1) / 2.0
        self.step += 1
        return old * beta + (1 - beta) * new


def loss_fn(x, y):
    # BGRL loss term: 2 - 2 * cosine similarity between the two representations
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val


class Encoder(nn.Module):

    def __init__(self, layer_config, dropout=None, project=False, **kwargs):
        # dropout and project are accepted for interface compatibility but are unused
        super().__init__()

        self.conv1 = GCNConv(layer_config[0], layer_config[1])
        self.bn1 = nn.BatchNorm1d(layer_config[1], momentum=0.01)
        self.prelu1 = nn.PReLU()
        self.conv2 = GCNConv(layer_config[1], layer_config[2])
        self.bn2 = nn.BatchNorm1d(layer_config[2], momentum=0.01)
        self.prelu2 = nn.PReLU()

    def forward(self, x, edge_index, edge_weight=None):

        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = self.prelu1(self.bn1(x))
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        x = self.prelu2(self.bn2(x))

        return x


def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


class BGRL(nn.Module):

    def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs):
        super().__init__()
        self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs)
        self.teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(self.teacher_encoder, False)
        self.teacher_ema_updater = EMA(moving_average_decay, epochs)
        rep_dim = layer_config[-1]
        self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim))
        self.student_predictor.apply(init_weights)

    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

    def forward(self, x1, x2, edge_index_v1, edge_index_v2, edge_weight_v1=None, edge_weight_v2=None):
        v1_student = self.student_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
        v2_student = self.student_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)

        v1_pred = self.student_predictor(v1_student)
        v2_pred = self.student_predictor(v2_student)

        with torch.no_grad():
            v1_teacher = self.teacher_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
            v2_teacher = self.teacher_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)

        # symmetrized loss: student prediction of view 1 against teacher of view 2, and vice versa
        loss1 = loss_fn(v1_pred, v2_teacher.detach())
        loss2 = loss_fn(v2_pred, v1_teacher.detach())

        loss = loss1 + loss2
        return v1_student, v2_student, loss.mean()


class LogisticRegression(nn.Module):
    def __init__(self, num_dim, num_class):
        super().__init__()
        self.linear = nn.Linear(num_dim, num_class)
        torch.nn.init.xavier_uniform_(self.linear.weight.data)
        self.linear.bias.data.fill_(0.0)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, x, y):

        logits = self.linear(x)
        loss = self.cross_entropy(logits, y)

        return logits, loss
--------------------------------------------------------------------------------
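# Illustrative sketch (not a repository file): a quick check of the loss and a
# one-pass smoke test of BGRL on a random graph. Assumes torch_geometric is
# installed; all shapes and values below are toy choices.
import torch
import torch.nn.functional as F
from models import BGRL, EMA, loss_fn

# loss_fn(x, y) equals 2 - 2 * cosine_similarity(x, y), so it is minimised
# when student predictions and teacher targets point in the same direction.
x, y = torch.randn(4, 8), torch.randn(4, 8)
assert torch.allclose(loss_fn(x, y), 2 - 2 * F.cosine_similarity(x, y, dim=-1), atol=1e-6)

# The teacher decay anneals from its base value towards 1.0 over training:
ema = EMA(beta=0.99, epochs=100)  # effective beta is 0.99 at step 0 and 1.0 at step 100

# One forward/backward/EMA step on a tiny random graph:
feats = torch.randn(10, 16)                 # 10 nodes, 16 features
edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
model = BGRL(layer_config=[16, 32, 8], pred_hid=16, epochs=100)
v1, v2, loss = model(feats, feats, edge_index, edge_index)
loss.backward()
model.update_moving_average()               # EMA update of the teacher
print(v1.shape, loss.item())                # torch.Size([10, 8]) ...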
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
from torch import optim
from tensorboardX import SummaryWriter
torch.manual_seed(0)

import models
import utils
import data

import os
import sys


class ModelTrainer:

    def __init__(self, args):
        self._args = args
        self._init()
        self.writer = SummaryWriter(log_dir="runs/BGRL_dataset({})".format(args.name))

    def _init(self):
        args = self._args
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
        self._device = f'cuda:{args.device}' if torch.cuda.is_available() else "cpu"
        self._dataset = data.Dataset(root=args.root, name=args.name)[0]
        print(f"Data: {self._dataset}")
        hidden_layers = [int(l) for l in args.layers]
        layers = [self._dataset.x.shape[1]] + hidden_layers
        self._model = models.BGRL(layer_config=layers, pred_hid=args.pred_hid, dropout=args.dropout, epochs=args.epochs).to(self._device)
        print(self._model)

        self._optimizer = optim.AdamW(params=self._model.parameters(), lr=args.lr, weight_decay=1e-5)
        # learning rate: linear warm-up for the first 1000 epochs, then cosine decay
        scheduler = lambda epoch: epoch / 1000 if epoch < 1000 \
                    else (1 + np.cos((epoch - 1000) * np.pi / (self._args.epochs - 1000))) * 0.5
        self._scheduler = optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda=scheduler)

    def train(self):
        # get initial test results
        print("start training!")
        print("Initial Evaluation...")
        self.infer_embeddings()
        dev_best, dev_std_best, test_best, test_std_best = self.evaluate()
        self.writer.add_scalar("accs/val_acc", dev_best, 0)
        self.writer.add_scalar("accs/test_acc", test_best, 0)
        print("validation: {:.4f}, test: {:.4f}".format(dev_best, test_best))

        # start training
        self._model.train()
        for epoch in range(self._args.epochs):

            self._dataset.to(self._device)

            augmentation = utils.Augmentation(float(self._args.aug_params[0]), float(self._args.aug_params[1]), float(self._args.aug_params[2]), float(self._args.aug_params[3]))
            view1, view2 = augmentation._feature_masking(self._dataset, self._device)

            v1_output, v2_output, loss = self._model(
                x1=view1.x, x2=view2.x, edge_index_v1=view1.edge_index, edge_index_v2=view2.edge_index,
                edge_weight_v1=view1.edge_attr, edge_weight_v2=view2.edge_attr)

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()
            self._scheduler.step()
            self._model.update_moving_average()
            sys.stdout.write('\rEpoch {}/{}, loss {:.4f}, lr {}'.format(epoch + 1, self._args.epochs, loss.data, self._optimizer.param_groups[0]['lr']))
            sys.stdout.flush()

            if (epoch + 1) % self._args.cache_step == 0:
                print("\nEvaluating epoch {}...".format(epoch + 1))

                self.infer_embeddings()
                dev_acc, dev_std, test_acc, test_std = self.evaluate()

                if dev_best < dev_acc:
                    dev_best = dev_acc
                    dev_std_best = dev_std
                    test_best = test_acc
                    test_std_best = test_std

                self.writer.add_scalar("stats/learning_rate", self._optimizer.param_groups[0]["lr"], epoch + 1)
                self.writer.add_scalar("accs/val_acc", dev_acc, epoch + 1)
                self.writer.add_scalar("accs/test_acc", test_acc, epoch + 1)
                print("validation: {:.4f}, test: {:.4f} \n".format(dev_acc, test_acc))

        f = open("BGRL_dataset({})_node.txt".format(self._args.name), "a")
        f.write("best valid acc : {} best valid std : {} best test acc : {} best test std : {} \n".format(dev_best, dev_std_best, test_best, test_std_best))
        f.close()

        print()
        print("Training Done!")

    def infer_embeddings(self):
        # embed the un-augmented graph with the student encoder
        self._model.train(False)
        self._embeddings = self._labels = None

        self._dataset.to(self._device)
        v1_output, v2_output, _ = self._model(
            x1=self._dataset.x, x2=self._dataset.x,
            edge_index_v1=self._dataset.edge_index,
            edge_index_v2=self._dataset.edge_index,
            edge_weight_v1=self._dataset.edge_attr,
            edge_weight_v2=self._dataset.edge_attr)
        emb = v1_output.detach()
        y = self._dataset.y.detach()
        if self._embeddings is None:
            self._embeddings, self._labels = emb, y
        else:
            self._embeddings = torch.cat([self._embeddings, emb])
            self._labels = torch.cat([self._labels, y])

    def evaluate(self):
        """
        Linear evaluation protocol, used for producing the results of Section 3.2
        in the BGRL paper: a logistic-regression probe is trained on the frozen
        embeddings for each of the 20 train/val/test splits.
        """
        emb_dim, num_class = self._embeddings.shape[1], self._labels.unique().shape[0]

        dev_accs, test_accs = [], []

        for i in range(20):

            self._train_mask = self._dataset.train_mask[i]
            self._dev_mask = self._dataset.val_mask[i]
            if self._args.name == "WikiCS":
                self._test_mask = self._dataset.test_mask  # WikiCS ships a single shared test mask
            else:
                self._test_mask = self._dataset.test_mask[i]

            classifier = models.LogisticRegression(emb_dim, num_class).to(self._device)
            optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01, weight_decay=0.0)

            for epoch in range(100):
                classifier.train()
                logits, loss = classifier(self._embeddings[self._train_mask], self._labels[self._train_mask])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            dev_logits, _ = classifier(self._embeddings[self._dev_mask], self._labels[self._dev_mask])
            test_logits, _ = classifier(self._embeddings[self._test_mask], self._labels[self._test_mask])
            dev_preds = torch.argmax(dev_logits, dim=1)
            test_preds = torch.argmax(test_logits, dim=1)

            dev_acc = (torch.sum(dev_preds == self._labels[self._dev_mask]).float() / self._labels[self._dev_mask].shape[0]).detach().cpu().numpy()
            test_acc = (torch.sum(test_preds == self._labels[self._test_mask]).float() / self._labels[self._test_mask].shape[0]).detach().cpu().numpy()

            dev_accs.append(dev_acc * 100)
            test_accs.append(test_acc * 100)

        dev_accs = np.stack(dev_accs)
        test_accs = np.stack(test_accs)

        dev_acc, dev_std = dev_accs.mean(), dev_accs.std()
        test_acc, test_std = test_accs.mean(), test_accs.std()

        return dev_acc, dev_std, test_acc, test_std


def train_eval(args):
    trainer = ModelTrainer(args)
    trainer.train()
    trainer.writer.close()


def main():
    args = utils.parse_args()
    train_eval(args)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
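# Illustrative check (not a repository file): the LambdaLR multiplier used in
# train.py warms up linearly for 1000 epochs, then follows a cosine decay.
# With the default --epochs 20 only the warm-up branch is ever reached.
import numpy as np

epochs = 2000
scheduler = lambda epoch: epoch / 1000 if epoch < 1000 \
            else (1 + np.cos((epoch - 1000) * np.pi / (epochs - 1000))) * 0.5
print(scheduler(500))   # 0.5, half-way through warm-up
print(scheduler(1000))  # 1.0, peak learning rate
print(scheduler(2000))  # 0.0, fully decayed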
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, WikiCS
from torch_geometric.utils import dropout_adj

import os.path as osp
import os

import argparse

import numpy as np

import torch

"""
The following code is borrowed from SelfGNN.
"""


class Augmentation:

    def __init__(self, p_f1=0.2, p_f2=0.1, p_e1=0.2, p_e2=0.3):
        """
        Two simple graph augmentation functions: node feature masking and edge masking.
        A random binary node feature mask is drawn from a Bernoulli distribution with parameter p_f,
        and a random binary edge mask from a Bernoulli distribution with parameter p_e.
        """
        self.p_f1 = p_f1
        self.p_f2 = p_f2
        self.p_e1 = p_e1
        self.p_e2 = p_e2
        self.method = "BGRL"

    def _feature_masking(self, data, device):
        feat_mask1 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f1
        feat_mask2 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f2
        feat_mask1, feat_mask2 = feat_mask1.to(device), feat_mask2.to(device)
        x1, x2 = data.x.clone(), data.x.clone()
        x1, x2 = x1 * feat_mask1, x2 * feat_mask2

        edge_index1, edge_attr1 = dropout_adj(data.edge_index, data.edge_attr, p=self.p_e1)
        edge_index2, edge_attr2 = dropout_adj(data.edge_index, data.edge_attr, p=self.p_e2)

        new_data1, new_data2 = data.clone(), data.clone()
        new_data1.x, new_data2.x = x1, x2
        new_data1.edge_index, new_data2.edge_index = edge_index1, edge_index2
        new_data1.edge_attr, new_data2.edge_attr = edge_attr1, edge_attr2

        return new_data1, new_data2

    def __call__(self, data, device):
        return self._feature_masking(data, device)
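
# Illustrative sketch (not part of the original file): how the two augmented
# views are produced from a toy graph. Wrapped in a function so importing
# utils stays side-effect free; the names here are hypothetical.
def _augmentation_demo():
    from torch_geometric.data import Data
    toy = Data(x=torch.ones(5, 4),
               edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
               edge_attr=torch.ones(4))
    aug = Augmentation(p_f1=0.2, p_f2=0.1, p_e1=0.2, p_e2=0.3)
    view1, view2 = aug(toy, "cpu")
    # the views share nodes but carry independently masked features and edges
    return view1, view2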

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", "-r", type=str, default="data",
                        help="Path to the data directory, where all the datasets will be placed. Default is 'data'")
    parser.add_argument("--name", "-n", type=str, default="WikiCS",
                        help="Name of the dataset. Supported names are: cora, citeseer, pubmed, photo, computers, cs, physics, and wikics")
    parser.add_argument("--layers", "-l", nargs="+", default=[512, 256],
                        help="The number of units of each layer of the GNN. Default is [512, 256]")
    parser.add_argument("--pred_hid", "-ph", type=int, default=512,
                        help="The number of hidden units of the predictor. Default is 512")
    parser.add_argument("--init-parts", "-ip", type=int, default=1,
                        help="The number of initial partitions. Default is 1. Applicable for ClusterSelfGNN")
    parser.add_argument("--final-parts", "-fp", type=int, default=1,
                        help="The number of final partitions. Default is 1. Applicable for ClusterSelfGNN")
    parser.add_argument("--aug_params", "-p", nargs="+", default=[0.3, 0.4, 0.3, 0.2],
                        help="Hyperparameters for augmentation (p_f1, p_f2, p_e1, p_e2). Default is [0.3, 0.4, 0.3, 0.2]")
    parser.add_argument("--lr", "-lr", type=float, default=0.00001,
                        help="Learning rate. Default is 0.00001.")
    parser.add_argument("--dropout", "-do", type=float, default=0.0,
                        help="Dropout rate. Default is 0.0")
    parser.add_argument("--cache-step", "-cs", type=int, default=10,
                        help="The step size for evaluation, that is, every cache_step epochs the model is evaluated. Default is 10.")
    parser.add_argument("--epochs", "-e", type=int, default=20,
                        help="The number of epochs. Default is 20")
    parser.add_argument("--device", "-d", type=int, default=0,
                        help="GPU to use. Default is 0")
    return parser.parse_args()


def decide_config(root, name):
    """
    Create a configuration to download datasets.
    :param root: A path to a root directory where data will be stored
    :param name: The name of the dataset to be downloaded
    :return: A modified root dir, the name of the dataset class, and parameters associated with the class
    """
    name = name.lower()
    if name == 'cora' or name == 'citeseer' or name == "pubmed":
        root = osp.join(root, "pyg", "planetoid")
        params = {"kwargs": {"root": root, "name": name},
                  "name": name, "class": Planetoid, "src": "pyg"}
    elif name == "computers":
        name = "Computers"
        root = osp.join(root, "pyg")
        params = {"kwargs": {"root": root, "name": name},
                  "name": name, "class": Amazon, "src": "pyg"}
    elif name == "photo":
        name = "Photo"
        root = osp.join(root, "pyg")
        params = {"kwargs": {"root": root, "name": name},
                  "name": name, "class": Amazon, "src": "pyg"}
    elif name == "cs":
        name = "CS"
        root = osp.join(root, "pyg")
        params = {"kwargs": {"root": root, "name": name},
                  "name": name, "class": Coauthor, "src": "pyg"}
    elif name == "physics":
        name = "Physics"
        root = osp.join(root, "pyg")
        params = {"kwargs": {"root": root, "name": name},
                  "name": name, "class": Coauthor, "src": "pyg"}
    elif name == "wikics":
        name = "WikiCS"
        root = osp.join(root, "pyg")
        params = {"kwargs": {"root": root},
                  "name": name, "class": WikiCS, "src": "pyg"}
    else:
        raise Exception(
            f"Unknown dataset name {name}; it has to be one of 'cora', 'citeseer', 'pubmed', 'photo', 'computers', 'cs', 'physics', or 'wikics'")
    return params
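
# Illustrative check (not part of the original file): decide_config maps a
# dataset name to its PyTorch Geometric class and download location.
def _config_demo():
    cfg = decide_config(root="data", name="wikics")
    # cfg["name"] == "WikiCS", cfg["class"] is torch_geometric.datasets.WikiCS,
    # and cfg["kwargs"]["root"] == "data/pyg"
    return cfg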

def create_dirs(dirs):
    # create each directory level along the tree; existing directories are kept
    for dir_tree in dirs:
        sub_dirs = dir_tree.split("/")
        path = ""
        for sub_dir in sub_dirs:
            path = osp.join(path, sub_dir)
            os.makedirs(path, exist_ok=True)


def create_masks(data):
    """
    Splits data into 20 random train/validation/test splits (10%/10%/80%)
    if it is not already split. Each split is associated with a mask vector
    that specifies the node indices for that split. The data is modified
    in place.
    :param data: Data object
    :return: The modified data
    """
    if not hasattr(data, "val_mask"):

        data.train_mask = data.dev_mask = data.test_mask = None

        for i in range(20):
            labels = data.y.numpy()
            dev_size = int(labels.shape[0] * 0.1)
            test_size = int(labels.shape[0] * 0.8)

            perm = np.random.permutation(labels.shape[0])
            test_index = perm[:test_size]
            dev_index = perm[test_size:test_size + dev_size]

            data_index = np.arange(labels.shape[0])
            test_mask = torch.tensor(np.in1d(data_index, test_index), dtype=torch.bool)
            dev_mask = torch.tensor(np.in1d(data_index, dev_index), dtype=torch.bool)
            train_mask = ~(dev_mask + test_mask)  # the remaining ~10% of nodes
            test_mask = test_mask.reshape(1, -1)
            dev_mask = dev_mask.reshape(1, -1)
            train_mask = train_mask.reshape(1, -1)

            if data.train_mask is None:
                data.train_mask = train_mask
                data.val_mask = dev_mask
                data.test_mask = test_mask
            else:
                data.train_mask = torch.cat((data.train_mask, train_mask), dim=0)
                data.val_mask = torch.cat((data.val_mask, dev_mask), dim=0)
                data.test_mask = torch.cat((data.test_mask, test_mask), dim=0)

    else:
        # datasets that ship their own masks (e.g. WikiCS) store splits along the second axis
        data.train_mask = data.train_mask.T
        data.val_mask = data.val_mask.T

    return data
--------------------------------------------------------------------------------
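# Illustrative check (not a repository file): create_masks attaches 20 random
# 10%/10%/80% train/val/test splits to a dataset that has none. Assumes the
# older torch_geometric API this repository targets; values below are toy choices.
import torch
from torch_geometric.data import Data
from utils import create_masks

d = Data(x=torch.randn(100, 4), y=torch.randint(0, 3, (100,)))
create_masks(d)
print(d.train_mask.shape)          # torch.Size([20, 100])
print(int(d.train_mask[0].sum()))  # 10 training nodes per split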