├── README.md
├── data.py
├── img
│   └── model.PNG
├── models.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # BGRL_Pytorch
2 | Implementation of "Large-Scale Representation Learning on Graphs via Bootstrapping" (BGRL).
3 |
4 | A PyTorch implementation of the paper "Large-Scale Representation Learning on Graphs via Bootstrapping", accepted to the ICLR 2021 Workshop.
5 |
6 |
7 |
8 | ## Hyperparameters for training BGRL
9 | The following options can be passed to `train.py` (a complete example invocation is given in the Usage section at the end of this README):
10 |
11 |
12 | `--layers` or `-l`:
13 | One or more integers specifying the number of hidden units in each GNN layer. Default is `512 256`.
14 | Usage example: `--layers 512 256`
15 |
16 |
17 | `--aug_params` or `-p`:
18 | Four floats specifying the graph augmentation hyperparameters (p_f1, p_f2, p_e1, p_e2). Default is `0.2 0.1 0.2 0.3`.
19 | Usage example: `--aug_params 0.2 0.1 0.2 0.3`
20 |
21 |
22 |
23 | |params|WikiCS|Am.Computers|Am.Photos|Co.CS|Co.Physics|
24 | |------|------|------------|---------|-----|----------|
25 | |p_f1 |0.2 |0.2 |0.1 |0.3 |0.1 |
26 | |p_f2 |0.1 |0.1 |0.2 |0.4 |0.4 |
27 | |p_e1 |0.2 |0.5 |0.4 |0.3 |0.4 |
28 | |p_e2 |0.3 |0.4 |0.1 |0.2 |0.1 |
29 | |embedding size|256|128|256|256|128|
30 | |encoder hidden size|512|256|512|512|256|
31 | |predictor hidden size|512|512|512|512|512|
32 | * Hyperparameters are taken from the original paper.
33 |
34 |
35 | ## Experimental Results
36 | |Dataset     |WikiCS|Am.Computers|Am.Photos|Co.CS|Co.Physics|
37 | |------------|------|------------|---------|-----|----------|
38 | |Accuracy (%)|79.50 |88.21       |92.76    |92.49|94.89     |
39 |
40 |
41 | ## Code borrowed from
42 | Parts of the code are borrowed from BYOL and SelfGNN:
43 |
44 |
45 | | name | Implementation Code | Paper |
46 | | ----------- | ------------------- | ------- |
47 | | `Bootstrap Your Own Latent`| Implementation| paper|
48 | | `SelfGNN`| Implementation| paper|
49 |
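50 | ## Usage
51 | A full training run combines the options above. The flags are defined in `utils.parse_args`; the values shown here are the WikiCS column of the hyperparameter table:
52 |
53 | `python train.py --name WikiCS --layers 512 256 --pred_hid 512 --aug_params 0.2 0.1 0.2 0.3`
54 |
55 | To reproduce another column of the table, substitute that dataset's values, e.g. `--name computers --layers 256 128 --aug_params 0.2 0.1 0.5 0.4` for Am.Computers.
56 |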
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data, InMemoryDataset
3 | import torch_geometric.transforms as T
4 | from torch_geometric.utils import to_undirected
5 |
6 | import os.path as osp
7 |
8 | import utils
9 |
10 |
11 | def download_pyg_data(config):
12 | """
13 | Downloads a dataset from the PyTorch Geometric library
14 | :param config: A dict containing info on the dataset to be downloaded
15 | :return: A tuple containing (root directory, dataset name, data directory)
16 | """
17 | leaf_dir = config["kwargs"]["root"].split("/")[-1].strip()
18 | data_dir = osp.join(config["kwargs"]["root"], "" if config["name"] == leaf_dir else config["name"])
19 | dst_path = osp.join(data_dir, "raw", "data.pt")
20 | if not osp.exists(dst_path):
21 | DatasetClass = config["class"]
22 | if config["name"] == "WikiCS":
23 | dataset = DatasetClass(data_dir, transform=T.NormalizeFeatures())
24 | std, mean = torch.std_mean(dataset.data.x, dim=0, unbiased=False)
25 | dataset.data.x = (dataset.data.x - mean) / std
26 | dataset.data.edge_index = to_undirected(dataset.data.edge_index)
27 | else :
28 | dataset = DatasetClass(**config["kwargs"], transform=T.NormalizeFeatures())
29 | utils.create_masks(data=dataset.data)
30 | torch.save((dataset.data, dataset.slices), dst_path)
31 |
32 | return config["kwargs"]["root"], config["name"], data_dir
33 |
34 |
35 | def download_data(root, name):
36 | """
37 | Download data from different repositories. Currently only PyTorch Geometric is supported
38 | :param root: The root directory of the dataset
39 | :param name: The name of the dataset
40 |     :return: A tuple containing (root directory, dataset name, data directory)
41 | """
42 | config = utils.decide_config(root=root, name=name)
43 | if config["src"] == "pyg":
44 | return download_pyg_data(config)
45 |
46 |
47 | class Dataset(InMemoryDataset):
48 |
49 | """
50 | A PyTorch InMemoryDataset to build multi-view dataset through graph data augmentation
51 | """
52 |
53 |     def __init__(self, root="data", name='cora', num_parts=1, final_parts=1, augmentation=None, transform=None,
54 |                  pre_transform=None):
55 |         self.num_parts = num_parts
56 |         self.final_parts = final_parts
57 |         self.augmentation = augmentation
58 | self.root, self.name, self.data_dir = download_data(root=root, name=name)
59 | utils.create_dirs(self.dirs)
60 | super().__init__(root=self.data_dir, transform=transform, pre_transform=pre_transform)
61 | path = osp.join(self.data_dir, "processed", self.processed_file_names[0])
62 | self.data, self.slices = torch.load(path)
63 |
64 | @property
65 | def raw_file_names(self):
66 | return ["data.pt"]
67 |
68 | @property
69 | def processed_file_names(self):
70 | if self.num_parts == 1:
71 |             return ['byg.data.aug.pt']
72 | else:
73 | return [f'byg.data.aug.ip.{self.num_parts}.fp.{self.final_parts}.pt']
74 |
75 | @property
76 | def raw_dir(self):
77 | return osp.join(self.data_dir, "raw")
78 |
79 | @property
80 | def processed_dir(self):
81 | return osp.join(self.data_dir, "processed")
82 |
83 | @property
84 | def model_dir(self):
85 | return osp.join(self.data_dir, "model")
86 |
87 | @property
88 | def result_dir(self):
89 | return osp.join(self.data_dir, "result")
90 |
91 | @property
92 | def dirs(self):
93 | return [self.raw_dir, self.processed_dir, self.model_dir, self.result_dir]
94 |
95 |
96 |     def process_full_batch_data(self, data):
97 |         """
98 |         Augmented view data generation using the full-batch data.
99 |         :param data: the full-batch Data object
100 |         :return: a single-element list containing the rebuilt Data object
101 |         """
102 | print("Processing full batch data")
103 |
104 | data = Data(edge_index=data.edge_index, edge_attr= data.edge_attr,
105 | x = data.x, y = data.y,
106 | train_mask=data.train_mask, val_mask=data.val_mask, test_mask=data.test_mask,
107 | num_nodes=data.num_nodes)
108 | return [data]
109 |
110 | def download(self):
111 | pass
112 |
113 | def process(self):
114 | """
115 | Process either a full batch or cluster data.
116 | :return:
117 | """
118 | processed_path = osp.join(self.processed_dir, self.processed_file_names[0])
119 | if not osp.exists(processed_path):
120 | path = osp.join(self.raw_dir, self.raw_file_names[0])
121 | data, _ = torch.load(path)
122 | edge_attr = data.edge_attr
123 | edge_attr = torch.ones(data.edge_index.shape[1]) if edge_attr is None else edge_attr
124 | data.edge_attr = edge_attr
125 | data_list = self.process_full_batch_data(data)
126 | data, slices = self.collate(data_list)
127 | torch.save((data, slices), processed_path)
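128 |
129 |
130 | # Illustrative usage (a sketch; not executed on import). This mirrors how
131 | # train.py builds the dataset: on first use the raw graph is downloaded and
132 | # the processed full-batch version is cached under <root>/<name>/processed.
133 | #
134 | # dataset = Dataset(root="data", name="WikiCS")
135 | # graph = dataset[0]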
--------------------------------------------------------------------------------
/img/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Namkyeong/BGRL_Pytorch/04026813a89a31d29c035ac3f67646013eecc4f9/img/model.PNG
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.nn import GCNConv
2 |
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | import torch
6 |
7 | import numpy as np
8 |
9 | import copy
10 |
11 | """
12 | The following code is borrowed from BYOL, SelfGNN
13 | and slightly modified for BGRL
14 | """
15 |
16 |
17 | class EMA:
18 | def __init__(self, beta, epochs):
19 | super().__init__()
20 | self.beta = beta
21 | self.step = 0
22 | self.total_steps = epochs
23 |
24 | def update_average(self, old, new):
25 | if old is None:
26 | return new
27 |         beta = 1 - (1 - self.beta) * (np.cos(np.pi * self.step / self.total_steps) + 1) / 2.0  # cosine schedule: the decay rate anneals from self.beta towards 1 over training
28 | self.step += 1
29 | return old * beta + (1 - beta) * new
30 |
31 |
32 | def loss_fn(x, y):
33 | x = F.normalize(x, dim=-1, p=2)
34 | y = F.normalize(y, dim=-1, p=2)
35 |     return 2 - 2 * (x * y).sum(dim=-1)  # equivalent to 2 - 2 * cosine_similarity(x, y)
36 |
37 |
38 | def update_moving_average(ema_updater, ma_model, current_model):
39 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
40 | old_weight, up_weight = ma_params.data, current_params.data
41 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
42 |
43 |
44 | def set_requires_grad(model, val):
45 | for p in model.parameters():
46 | p.requires_grad = val
47 |
48 |
49 | class Encoder(nn.Module):
50 |
51 | def __init__(self, layer_config, dropout=None, project=False, **kwargs):
52 | super().__init__()
53 |
54 | self.conv1 = GCNConv(layer_config[0], layer_config[1])
55 | self.bn1 = nn.BatchNorm1d(layer_config[1], momentum = 0.01)
56 | self.prelu1 = nn.PReLU()
57 | self.conv2 = GCNConv(layer_config[1],layer_config[2])
58 | self.bn2 = nn.BatchNorm1d(layer_config[2], momentum = 0.01)
59 | self.prelu2 = nn.PReLU()
60 |
61 | def forward(self, x, edge_index, edge_weight=None):
62 |
63 | x = self.conv1(x, edge_index, edge_weight=edge_weight)
64 | x = self.prelu1(self.bn1(x))
65 | x = self.conv2(x, edge_index, edge_weight=edge_weight)
66 | x = self.prelu2(self.bn2(x))
67 |
68 | return x
69 |
70 |
71 | def init_weights(m):
72 |     if isinstance(m, nn.Linear):
73 | torch.nn.init.xavier_uniform_(m.weight)
74 | m.bias.data.fill_(0.01)
75 |
76 |
77 | class BGRL(nn.Module):
78 |
79 | def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs):
80 | super().__init__()
81 | self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs)
82 | self.teacher_encoder = copy.deepcopy(self.student_encoder)
83 | set_requires_grad(self.teacher_encoder, False)
84 | self.teacher_ema_updater = EMA(moving_average_decay, epochs)
85 | rep_dim = layer_config[-1]
86 | self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim))
87 | self.student_predictor.apply(init_weights)
88 |
89 | def reset_moving_average(self):
90 | del self.teacher_encoder
91 | self.teacher_encoder = None
92 |
93 | def update_moving_average(self):
94 | assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
95 | update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
96 |
97 | def forward(self, x1, x2, edge_index_v1, edge_index_v2, edge_weight_v1=None, edge_weight_v2=None):
98 | v1_student = self.student_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
99 | v2_student = self.student_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)
100 |
101 | v1_pred = self.student_predictor(v1_student)
102 | v2_pred = self.student_predictor(v2_student)
103 |
104 | with torch.no_grad():
105 | v1_teacher = self.teacher_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
106 | v2_teacher = self.teacher_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)
107 |
108 | loss1 = loss_fn(v1_pred, v2_teacher.detach())
109 | loss2 = loss_fn(v2_pred, v1_teacher.detach())
110 |
111 | loss = loss1 + loss2
112 | return v1_student, v2_student, loss.mean()
113 |
114 |
115 | class LogisticRegression(nn.Module):
116 | def __init__(self, num_dim, num_class):
117 | super().__init__()
118 | self.linear = nn.Linear(num_dim, num_class)
119 | torch.nn.init.xavier_uniform_(self.linear.weight.data)
120 | self.linear.bias.data.fill_(0.0)
121 | self.cross_entropy = nn.CrossEntropyLoss()
122 |
123 | def forward(self, x, y):
124 |
125 | logits = self.linear(x)
126 | loss = self.cross_entropy(logits, y)
127 |
128 | return logits, loss
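129 |
130 |
131 | if __name__ == "__main__":
132 |     # Minimal smoke test (a sketch, not part of the original training
133 |     # pipeline): run BGRL on a tiny random graph to sanity-check shapes
134 |     # and the loss computation.
135 |     x = torch.randn(4, 8)
136 |     edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
137 |     model = BGRL(layer_config=[8, 16, 4], pred_hid=8, epochs=10)
138 |     v1, v2, loss = model(x, x, edge_index, edge_index)
139 |     model.update_moving_average()
140 |     print(v1.shape, v2.shape, loss.item())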
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch import optim
5 | from tensorboardX import SummaryWriter
6 | torch.manual_seed(0)
7 |
8 | import models
9 | import utils
10 | import data
11 |
12 | import os
13 | import sys
14 |
15 | class ModelTrainer:
16 |
17 | def __init__(self, args):
18 | self._args = args
19 | self._init()
20 | self.writer = SummaryWriter(log_dir="runs/BGRL_dataset({})".format(args.name))
21 |
22 | def _init(self):
23 | args = self._args
24 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
25 |         self._device = "cuda" if torch.cuda.is_available() else "cpu"  # CUDA_VISIBLE_DEVICES above already restricts visibility to the chosen GPU
26 | self._dataset = data.Dataset(root=args.root, name=args.name)[0]
27 | print(f"Data: {self._dataset}")
28 | hidden_layers = [int(l) for l in args.layers]
29 | layers = [self._dataset.x.shape[1]] + hidden_layers
30 | self._model = models.BGRL(layer_config=layers, pred_hid=args.pred_hid, dropout=args.dropout, epochs=args.epochs).to(self._device)
31 | print(self._model)
32 |
33 | self._optimizer = optim.AdamW(params=self._model.parameters(), lr=args.lr, weight_decay= 1e-5)
34 |         # learning rate schedule: linear warmup for the first 1000 epochs, then cosine decay
35 | scheduler = lambda epoch: epoch / 1000 if epoch < 1000 \
36 | else ( 1 + np.cos((epoch-1000) * np.pi / (self._args.epochs - 1000))) * 0.5
37 | self._scheduler = optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda = scheduler)
38 |
39 | def train(self):
40 | # get initial test results
41 | print("start training!")
42 | print("Initial Evaluation...")
43 | self.infer_embeddings()
44 | dev_best, dev_std_best, test_best, test_std_best = self.evaluate()
45 | self.writer.add_scalar("accs/val_acc", dev_best, 0)
46 | self.writer.add_scalar("accs/test_acc", test_best, 0)
47 | print("validation: {:.4f}, test: {:.4f}".format(dev_best, test_best))
48 |
49 | # start training
50 | self._model.train()
51 | for epoch in range(self._args.epochs):
52 |
53 | self._dataset.to(self._device)
54 |
55 | augmentation = utils.Augmentation(float(self._args.aug_params[0]),float(self._args.aug_params[1]),float(self._args.aug_params[2]),float(self._args.aug_params[3]))
56 | view1, view2 = augmentation._feature_masking(self._dataset, self._device)
57 |
58 | v1_output, v2_output, loss = self._model(
59 | x1=view1.x, x2=view2.x, edge_index_v1=view1.edge_index, edge_index_v2=view2.edge_index,
60 | edge_weight_v1=view1.edge_attr, edge_weight_v2=view2.edge_attr)
61 |
62 | self._optimizer.zero_grad()
63 | loss.backward()
64 | self._optimizer.step()
65 | self._scheduler.step()
66 | self._model.update_moving_average()
67 |             sys.stdout.write('\rEpoch {}/{}, loss {:.4f}, lr {}'.format(epoch + 1, self._args.epochs, loss.item(), self._optimizer.param_groups[0]['lr']))
68 | sys.stdout.flush()
69 |
70 | if (epoch + 1) % self._args.cache_step == 0:
71 |                 print("")
72 |                 print("Evaluating epoch {}...".format(epoch + 1))
73 |
74 | self.infer_embeddings()
75 | dev_acc, dev_std, test_acc, test_std = self.evaluate()
76 |
77 | if dev_best < dev_acc:
78 | dev_best = dev_acc
79 | dev_std_best = dev_std
80 | test_best = test_acc
81 | test_std_best = test_std
82 |
83 | self.writer.add_scalar("stats/learning_rate", self._optimizer.param_groups[0]["lr"] , epoch + 1)
84 | self.writer.add_scalar("accs/val_acc", dev_acc, epoch + 1)
85 | self.writer.add_scalar("accs/test_acc", test_acc, epoch + 1)
86 | print("validation: {:.4f}, test: {:.4f} \n".format(dev_acc, test_acc))
87 |
88 |
89 | f = open("BGRL_dataset({})_node.txt".format(self._args.name), "a")
90 | f.write("best valid acc : {} best valid std : {} best test acc : {} best test std : {} \n".format(dev_best, dev_std_best, test_best, test_std_best))
91 | f.close()
92 |
93 | print()
94 | print("Training Done!")
95 |
96 |
97 | def infer_embeddings(self):
98 |
99 |         self._model.eval()
100 | self._embeddings = self._labels = None
101 |
102 | self._dataset.to(self._device)
103 | v1_output, v2_output, _ = self._model(
104 | x1=self._dataset.x, x2=self._dataset.x,
105 | edge_index_v1=self._dataset.edge_index,
106 | edge_index_v2=self._dataset.edge_index,
107 | edge_weight_v1=self._dataset.edge_attr,
108 | edge_weight_v2=self._dataset.edge_attr)
109 | emb = v1_output.detach()
110 | y = self._dataset.y.detach()
111 | if self._embeddings is None:
112 | self._embeddings, self._labels = emb, y
113 | else:
114 | self._embeddings = torch.cat([self._embeddings, emb])
115 | self._labels = torch.cat([self._labels, y])
116 |
117 |
118 | def evaluate(self):
119 | """
120 |         Used for producing the results of Experiment 3.2 in the BGRL paper: a logistic regression classifier is trained on the frozen embeddings for each of 20 train/val/test splits, and the mean/std accuracies are reported.
121 | """
122 | emb_dim, num_class = self._embeddings.shape[1], self._labels.unique().shape[0]
123 |
124 | dev_accs, test_accs = [], []
125 |
126 | for i in range(20):
127 |
128 | self._train_mask = self._dataset.train_mask[i]
129 | self._dev_mask = self._dataset.val_mask[i]
130 | if self._args.name == "WikiCS":
131 | self._test_mask = self._dataset.test_mask
132 | else :
133 | self._test_mask = self._dataset.test_mask[i]
134 |
135 | classifier = models.LogisticRegression(emb_dim, num_class).to(self._device)
136 | optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01, weight_decay=0.0)
137 |
138 | for epoch in range(100):
139 | classifier.train()
140 | logits, loss = classifier(self._embeddings[self._train_mask], self._labels[self._train_mask])
141 | optimizer.zero_grad()
142 | loss.backward()
143 | optimizer.step()
144 |
145 | dev_logits, _ = classifier(self._embeddings[self._dev_mask], self._labels[self._dev_mask])
146 | test_logits, _ = classifier(self._embeddings[self._test_mask], self._labels[self._test_mask])
147 | dev_preds = torch.argmax(dev_logits, dim=1)
148 | test_preds = torch.argmax(test_logits, dim=1)
149 |
150 | dev_acc = (torch.sum(dev_preds == self._labels[self._dev_mask]).float() / self._labels[self._dev_mask].shape[0]).detach().cpu().numpy()
151 | test_acc = (torch.sum(test_preds == self._labels[self._test_mask]).float() / self._labels[self._test_mask].shape[0]).detach().cpu().numpy()
152 |
153 | dev_accs.append(dev_acc * 100)
154 | test_accs.append(test_acc * 100)
155 |
156 | dev_accs = np.stack(dev_accs)
157 | test_accs = np.stack(test_accs)
158 |
159 | dev_acc, dev_std = dev_accs.mean(), dev_accs.std()
160 | test_acc, test_std = test_accs.mean(), test_accs.std()
161 |
162 | return dev_acc, dev_std, test_acc, test_std
163 |
164 |
165 | def train_eval(args):
166 | trainer = ModelTrainer(args)
167 | trainer.train()
168 | trainer.writer.close()
169 |
170 |
171 | def main():
172 | args = utils.parse_args()
173 | train_eval(args)
174 |
175 |
176 | if __name__ == "__main__":
177 | main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import Planetoid, Coauthor, Amazon, WikiCS
2 | from torch_geometric.utils import dropout_adj
3 |
4 | import os.path as osp
5 | import os
6 |
7 | import argparse
8 |
9 | import numpy as np
10 |
11 | import torch
12 |
13 | """
14 | The Following code is borrowed from SelfGNN
15 | """
16 | class Augmentation:
17 |
18 | def __init__(self, p_f1 = 0.2, p_f2 = 0.1, p_e1 = 0.2, p_e2 = 0.3):
19 |         """
20 |         Two simple graph augmentations: "node feature masking" and "edge masking".
21 |         A random binary node-feature mask is drawn from a Bernoulli distribution with parameter p_f,
22 |         and a random binary edge mask from a Bernoulli distribution with parameter p_e.
23 |         """
24 | self.p_f1 = p_f1
25 | self.p_f2 = p_f2
26 | self.p_e1 = p_e1
27 | self.p_e2 = p_e2
28 | self.method = "BGRL"
29 |
30 | def _feature_masking(self, data, device):
31 | feat_mask1 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f1
32 | feat_mask2 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f2
33 | feat_mask1, feat_mask2 = feat_mask1.to(device), feat_mask2.to(device)
34 | x1, x2 = data.x.clone(), data.x.clone()
35 | x1, x2 = x1 * feat_mask1, x2 * feat_mask2
36 |
37 | edge_index1, edge_attr1 = dropout_adj(data.edge_index, data.edge_attr, p = self.p_e1)
38 | edge_index2, edge_attr2 = dropout_adj(data.edge_index, data.edge_attr, p = self.p_e2)
39 |
40 | new_data1, new_data2 = data.clone(), data.clone()
41 | new_data1.x, new_data2.x = x1, x2
42 | new_data1.edge_index, new_data2.edge_index = edge_index1, edge_index2
43 | new_data1.edge_attr , new_data2.edge_attr = edge_attr1, edge_attr2
44 |
45 | return new_data1, new_data2
46 |
47 |     def __call__(self, data):
48 |
49 |         return self._feature_masking(data, data.x.device)
50 |
51 |
52 | def parse_args():
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--root", "-r", type=str, default="data",
55 | help="Path to data directory, where all the datasets will be placed. Default is 'data'")
56 |     parser.add_argument("--name", "-n", type=str, default="WikiCS",
57 |                         help="Name of the dataset. Supported names are: wikics, cora, citeseer, pubmed, photo, computers, cs, and physics. Default is WikiCS")
58 |     parser.add_argument("--layers", "-l", nargs="+", type=int, default=[
59 |                         512, 256], help="The number of units of each layer of the GNN. Default is [512, 256]")
60 |     parser.add_argument("--pred_hid", '-ph', type=int,
61 |                         default=512, help="The number of hidden units of the predictor. Default is 512")
62 |     parser.add_argument("--init-parts", "-ip", type=int, default=1,
63 |                         help="The number of initial partitions. Default is 1. Applicable for ClusterSelfGNN")
64 |     parser.add_argument("--final-parts", "-fp", type=int, default=1,
65 |                         help="The number of final partitions. Default is 1. Applicable for ClusterSelfGNN")
66 |     parser.add_argument("--aug_params", "-p", nargs="+", type=float, default=[
67 |                         0.2, 0.1, 0.2, 0.3], help="Hyperparameters for augmentation (p_f1, p_f2, p_e1, p_e2). Default is [0.2, 0.1, 0.2, 0.3]")
68 |     parser.add_argument("--lr", '-lr', type=float, default=0.00001,
69 |                         help="Learning rate. Default is 0.00001")
70 |     parser.add_argument("--dropout", "-do", type=float,
71 |                         default=0.0, help="Dropout rate. Default is 0.0")
72 |     parser.add_argument("--cache-step", '-cs', type=int, default=10,
73 |                         help="The evaluation interval, that is, every cache_step epochs the embeddings are evaluated. Default is 10")
74 |     parser.add_argument("--epochs", '-e', type=int,
75 |                         default=20, help="The number of training epochs. Default is 20")
76 | parser.add_argument("--device", '-d', type=int,
77 | default=0, help="GPU to use")
78 | return parser.parse_args()
79 |
80 |
81 | def decide_config(root, name):
82 | """
83 | Create a configuration to download datasets
84 | :param root: A path to a root directory where data will be stored
85 | :param name: The name of the dataset to be downloaded
86 | :return: A modified root dir, the name of the dataset class, and parameters associated to the class
87 | """
88 | name = name.lower()
89 | if name == 'cora' or name == 'citeseer' or name == "pubmed":
90 | root = osp.join(root, "pyg", "planetoid")
91 | params = {"kwargs": {"root": root, "name": name},
92 | "name": name, "class": Planetoid, "src": "pyg"}
93 | elif name == "computers":
94 | name = "Computers"
95 | root = osp.join(root, "pyg")
96 | params = {"kwargs": {"root": root, "name": name},
97 | "name": name, "class": Amazon, "src": "pyg"}
98 | elif name == "photo":
99 | name = "Photo"
100 | root = osp.join(root, "pyg")
101 | params = {"kwargs": {"root": root, "name": name},
102 | "name": name, "class": Amazon, "src": "pyg"}
103 | elif name == "cs" :
104 | name = "CS"
105 | root = osp.join(root, "pyg")
106 | params = {"kwargs": {"root": root, "name": name},
107 | "name": name, "class": Coauthor, "src": "pyg"}
108 | elif name == "physics":
109 | name = "Physics"
110 | root = osp.join(root, "pyg")
111 | params = {"kwargs": {"root": root, "name": name},
112 | "name": name, "class": Coauthor, "src": "pyg"}
113 | elif name == "wikics":
114 | name = "WikiCS"
115 | root = osp.join(root, "pyg")
116 | params = {"kwargs": {"root": root},
117 | "name": name, "class": WikiCS, "src": "pyg"}
118 | else:
119 |         raise Exception(
120 |             f"Unknown dataset name {name}, name has to be one of the following: 'wikics', 'cora', 'citeseer', 'pubmed', 'photo', 'computers', 'cs', 'physics'")
121 | return params
122 |
123 |
124 | def create_dirs(dirs):
125 | for dir_tree in dirs:
126 | sub_dirs = dir_tree.split("/")
127 | path = ""
128 | for sub_dir in sub_dirs:
129 | path = osp.join(path, sub_dir)
130 | os.makedirs(path, exist_ok=True)
131 |
132 |
133 | def create_masks(data):
134 |     """
135 |     Splits data into 20 random training, validation, and test splits if it has
136 |     not already been split. Each split is associated with a boolean mask vector
137 |     that marks the indices of that split. The data is modified in place.
138 |     :param data: Data object
139 |     :return: The modified data
140 |     """
141 | if not hasattr(data, "val_mask"):
142 |
143 | data.train_mask = data.dev_mask = data.test_mask = None
144 |
145 | for i in range(20):
146 | labels = data.y.numpy()
147 | dev_size = int(labels.shape[0] * 0.1)
148 | test_size = int(labels.shape[0] * 0.8)
149 |
150 | perm = np.random.permutation(labels.shape[0])
151 | test_index = perm[:test_size]
152 | dev_index = perm[test_size:test_size+dev_size]
153 |
154 | data_index = np.arange(labels.shape[0])
155 |             test_mask = torch.tensor(np.isin(data_index, test_index), dtype=torch.bool)
156 |             dev_mask = torch.tensor(np.isin(data_index, dev_index), dtype=torch.bool)
157 |             train_mask = ~(dev_mask | test_mask)
158 |             test_mask = test_mask.reshape(1, -1)
159 |             dev_mask = dev_mask.reshape(1, -1)
160 |             train_mask = train_mask.reshape(1, -1)
161 |
162 |             if data.train_mask is None:
163 |                 data.train_mask = train_mask
164 |                 data.val_mask = dev_mask
165 |                 data.test_mask = test_mask
166 |             else:
167 |                 data.train_mask = torch.cat((data.train_mask, train_mask), dim=0)
168 |                 data.val_mask = torch.cat((data.val_mask, dev_mask), dim=0)
169 |                 data.test_mask = torch.cat((data.test_mask, test_mask), dim=0)
170 |
171 |     else:
172 | data.train_mask = data.train_mask.T
173 | data.val_mask = data.val_mask.T
174 |
175 | return data
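176 |
177 |
178 | # Illustrative usage (a sketch mirroring the call inside train.py's training
179 | # loop): generate the two augmented views that BGRL is trained on.
180 | #
181 | # aug = Augmentation(p_f1=0.2, p_f2=0.1, p_e1=0.2, p_e2=0.3)
182 | # view1, view2 = aug._feature_masking(data, device="cpu")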
--------------------------------------------------------------------------------