├── .gitignore
├── README.md
├── config
│   ├── config.py
│   └── vgae.yaml
├── model.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
/__pycache__
/datasets

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Graph Auto-encoder in PyTorch Geometric

This repository implements the variational graph auto-encoder (VGAE) in [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), adapted from the autoencoder example [code](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/autoencoder.py) in PyG. For details of the model, refer to Thomas Kipf's original [paper](https://arxiv.org/abs/1611.07308).

## Requirements

- Python >= 3.6
- PyTorch == 1.5
- PyTorch Geometric == 1.5
- scikit-learn
- scipy
- PyYAML

## How to run

1. Configure the arguments in the `config/vgae.yaml` file. You can also make your own config file.

2. Specify the config file and run the training script.
```
python train.py --load_config config/vgae.yaml
```

## Results

We follow the argument settings of the original [paper](https://arxiv.org/abs/1611.07308); the results are shown below.

| Dataset  | AUC   | AP    |
|----------|-------|-------|
| Cora     | 0.903 | 0.911 |
| Citeseer | 0.869 | 0.879 |
| Pubmed   | 0.948 | 0.948 |

--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
import yaml
import argparse


# class Config()


def visualize_config(args):
    """
    Print the configuration to the terminal so it can be checked.
    :param args:
    :return:
    """
    print("\nRunning with the following arguments:\n")
    for key, value in sorted(vars(args).items()):
        if value is not None:
            print("{} -- {} --".format(key, value))
    print("\n\n")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_config',
                        dest='config_file',
                        # type=argparse.FileType(mode='r'),
                        help='The yaml configuration file')
    args, unprocessed_args = parser.parse_known_args()

    # parser.add_argument('--data_root', default=None, required=True, help='The data folder')
    # parser.add_argument('--phase', default=None, required=True, help='train or val')

    if args.config_file:
        with open(args.config_file, 'r') as f:
            parser.set_defaults(**yaml.safe_load(f))

    args = parser.parse_args(unprocessed_args)
    visualize_config(args)
    return args

--------------------------------------------------------------------------------
/config/vgae.yaml:
--------------------------------------------------------------------------------
dataset: "Cora"

enc_in_channels: 1433
enc_hidden_channels: 32
enc_out_channels: 16

lr: 0.01
epoch: 400
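
# enc_in_channels must match the node feature dimension of the chosen
# Planetoid dataset: 1433 for Cora, 3703 for Citeseer, 500 for Pubmed.
# enc_hidden_channels / enc_out_channels (32 / 16) follow the layer sizes
# used in the original VGAE paper.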
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.models import InnerProductDecoder, VGAE
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops


class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.gcn_shared = GCNConv(in_channels, hidden_channels)
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.gcn_shared(x, edge_index))
        mu = self.gcn_mu(x, edge_index)
        logvar = self.gcn_logvar(x, edge_index)
        return mu, logvar


class DeepVGAE(VGAE):
    def __init__(self, args):
        super(DeepVGAE, self).__init__(encoder=GCNEncoder(args.enc_in_channels,
                                                          args.enc_hidden_channels,
                                                          args.enc_out_channels),
                                       decoder=InnerProductDecoder())

    def forward(self, x, edge_index):
        z = self.encode(x, edge_index)
        adj_pred = self.decoder.forward_all(z)
        return adj_pred

    def loss(self, x, pos_edge_index, all_edge_index):
        z = self.encode(x, pos_edge_index)

        pos_loss = -torch.log(
            self.decoder(z, pos_edge_index, sigmoid=True) + 1e-15).mean()

        # Do not include self-loops in negative samples
        all_edge_index_tmp, _ = remove_self_loops(all_edge_index)
        all_edge_index_tmp, _ = add_self_loops(all_edge_index_tmp)

        neg_edge_index = negative_sampling(all_edge_index_tmp, z.size(0), pos_edge_index.size(1))
        neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index, sigmoid=True) + 1e-15).mean()

        kl_loss = 1 / x.size(0) * self.kl_loss()

        return pos_loss + neg_loss + kl_loss

    def single_test(self, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
        with torch.no_grad():
            z = self.encode(x, train_pos_edge_index)
        roc_auc_score, average_precision_score = self.test(z, test_pos_edge_index, test_neg_edge_index)
        return roc_auc_score, average_precision_score

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
from torch.optim import Adam

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges

from model import DeepVGAE
from config.config import parse_args

torch.manual_seed(12345)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = parse_args()

model = DeepVGAE(args).to(device)
optimizer = Adam(model.parameters(), lr=args.lr)

os.makedirs("datasets", exist_ok=True)
dataset = Planetoid("datasets", args.dataset, transform=T.NormalizeFeatures())
data = dataset[0].to(device)
all_edge_index = data.edge_index
data = train_test_split_edges(data, 0.05, 0.1)

for epoch in range(args.epoch):
    model.train()
    optimizer.zero_grad()
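    # DeepVGAE.loss (see model.py) sums binary cross-entropy on the observed
    # positive training edges, the same term on an equal number of sampled
    # negative edges, and the KL divergence of the latent Gaussian scaled by
    # 1 / num_nodes.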
    loss = model.loss(data.x, data.train_pos_edge_index, all_edge_index)
    loss.backward()
    optimizer.step()
    if epoch % 2 == 0:
        model.eval()
        roc_auc, ap = model.single_test(data.x,
                                        data.train_pos_edge_index,
                                        data.test_pos_edge_index,
                                        data.test_neg_edge_index)
        print("Epoch {} - Loss: {} ROC_AUC: {} AP: {}".format(epoch, loss.cpu().item(), roc_auc, ap))

--------------------------------------------------------------------------------
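
Usage note (not a file in the repository): after `train.py` has run, the trained `DeepVGAE` can also score individual candidate links with the same inner-product decoder used during training. A minimal sketch, assuming `model` and `data` from `train.py` are still in scope:

```
import torch

# Hypothetical snippet, not part of the repository; it reuses the `model` and
# `data` objects created in train.py.
model.eval()
with torch.no_grad():
    # Latent node embeddings computed from the training graph.
    z = model.encode(data.x, data.train_pos_edge_index)
    # Probability of a link between node 0 and node 1; `candidate` follows the
    # usual edge_index layout of shape [2, num_pairs].
    candidate = torch.tensor([[0], [1]], device=data.x.device)
    prob = model.decoder(z, candidate, sigmoid=True)
    print(prob.item())
```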