├── .gitignore ├── LICENSE ├── README.md ├── configs ├── amazon_computer │ ├── train_Async.yaml │ └── train_Sync.yaml ├── amazon_photo │ ├── train_Async.yaml │ └── train_Sync.yaml └── cora │ ├── train_Async.yaml │ └── train_Sync.yaml ├── data.py ├── eval_utils.py ├── model ├── __init__.py ├── common_blocks │ ├── __init__.py │ ├── gae.py │ └── gcn.py ├── diffusion.py ├── discriminator │ ├── __init__.py │ ├── appnp.py │ ├── base.py │ ├── cn.py │ ├── gae.py │ ├── gcn.py │ ├── mlp.py │ └── sgc.py └── gnn.py ├── model_240125.png ├── orca ├── orca.cpp └── orca.h ├── sample.py ├── setup_utils.py ├── train_async.py └── train_sync.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wandb 3 | cora_cpts/ 4 | orca/ 5 | downloaded_cpts/ 6 | upload.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Graph-COM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphMaker 2 | 3 | [[Paper]](https://arxiv.org/abs/2310.13833) 4 | 5 | ![model](model_240125.png) 6 | 7 | ## Table of Contents 8 | 9 | - [Installation](#installation) 10 | - [Usage](#usage) 11 | * [Train](#train) 12 | * [Sample](#sample) 13 | * [Sample with Pre-Trained Models](#sample-with-pre-trained-models) 14 | - [Frequently Asked Questions](#frequently-asked-questions) 15 | * [Q1: libcusparse.so](#q1-libcusparseso) 16 | * [Q2: Wandb](#q2-wandb) 17 | * [Q3: Other Requests](#q3-other-requests) 18 | - [Citation](#citation) 19 | 20 | ## Installation 21 | 22 | ```bash 23 | conda create -n GraphMaker python=3.8 -y 24 | conda activate GraphMaker 25 | pip install torch==1.12.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 26 | conda install -c conda-forge cudatoolkit=11.6 27 | pip install dgl==1.1.0+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html 28 | pip install pandas scikit-learn pydantic wandb huggingface_hub==0.30.2 29 | ``` 30 | 31 | You also need to compile `orca.cpp` (https://file.biolab.si/biolab/supp/orca/orca.html). 32 | 33 | ```bash 34 | cd orca 35 | g++ -O2 -std=c++11 -o orca orca.cpp 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### Train 41 | 42 | ```bash 43 | # The GraphMaker-Sync variant simultaneously generates node attributes and graph structure. 44 | python train_sync.py -d D 45 | # The trained model checkpoint will be saved to {D}_cpts/Sync_XXX.pth 46 | 47 | # The GraphMaker-Async variant first generates node attributes, and then generates graph structure. 
48 | python train_async.py -d D 49 | # The trained model checkpoint will be saved to {D}_cpts/Async_XXX.pth 50 | ``` 51 | 52 | `D` can be one of the three built-in datasets: `cora`, `amazon_photo`, or `amazon_computer`. 53 | 54 | ### Sample 55 | 56 | ```bash 57 | python sample.py --model_path P 58 | ``` 59 | 60 | `P` is the path to a model checkpoint saved in the training stage. 61 | 62 | ### Sample with Pre-Trained Models 63 | 64 | Alternatively, you can use our pre-trained model checkpoints for sampling. 65 | 66 | ```bash 67 | python sample.py --dataset D --type T 68 | ``` 69 | 70 | - `D` can be one of the three built-in datasets: `cora`, `amazon_photo`, or `amazon_computer`. 71 | - `T` can be `sync` or `async`. 72 | 73 | ## Frequently Asked Questions 74 | 75 | ### Q1: libcusparse.so 76 | 77 | **The program fails with an error saying that it cannot find `libcusparse.so`.** 78 | 79 | To locate the library on Linux, run 80 | 81 | ```bash 82 | find /path/to/directory -name libcusparse.so.11 -exec realpath {} \; 83 | ``` 84 | 85 | where `/path/to/directory` is the directory you want to search. Suppose the search returns `/home/miniconda3/envs/GraphMaker/lib/libcusparse.so.11`. Then add the containing directory to the `LD_LIBRARY_PATH` environment variable as follows. 86 | 87 | ```bash 88 | export LD_LIBRARY_PATH=/home/miniconda3/envs/GraphMaker/lib:$LD_LIBRARY_PATH 89 | ``` 90 | 91 | ### Q2: Wandb 92 | 93 | **What is WandB?** 94 | 95 | [WandB](https://wandb.ai/site) is a tool for visualizing and tracking your machine learning experiments. It's free to use for open source projects. You may also use our code without it. 96 | 97 | ### Q3: Other Requests 98 | 99 | **I have a question or request not listed here.** 100 | 101 | - We generally recommend opening a GitHub issue. This allows us to track progress, and the discussion might help others who have the same question. 102 | - Otherwise, you can send an email to `mufeili1996@gmail.com`. 
103 | 104 | ## Citation 105 | 106 | ```tex 107 | @article{li2024graphmaker, 108 | title={GraphMaker: Can Diffusion Models Generate Large Attributed Graphs?}, 109 | author={Mufei Li and Eleonora Kreačić and Vamsi K. Potluru and Pan Li}, 110 | journal={Transactions on Machine Learning Research}, 111 | year={2024} 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /configs/amazon_computer/train_Async.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Async" # Model name. 3 | 4 | mlp_X : # MLP for reconstructing node features. 5 | hidden_t: 16 # Hidden size for the normalized time step. 6 | hidden_X: 1024 # Hidden size for the node features. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_mlp_layers: 2 # Number of MLP layers. 9 | dropout: 0. # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 16 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node features. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0. # Dropout rate. 18 | 19 | diffusion : 20 | T_X : 7 # Number of diffusion steps for node features. 21 | T_E : 9 # Number of diffusion steps for edges. 22 | 23 | optimizer_X : 24 | lr : 0.001 # Learning rate. 25 | weight_decay : 0 # Weight decay. 26 | amsgrad : true 27 | 28 | optimizer_E : 29 | lr : 0.0003 # Learning rate. 30 | weight_decay : 0 # Weight decay. 31 | amsgrad : true 32 | 33 | lr_scheduler : 34 | factor : 0.9 # Factor by which the learning rate will be reduced. 35 | patience : 3 # Number of epochs with no improvement after which learning rate will be reduced. 36 | verbose : true 37 | 38 | train : 39 | num_epochs : 200 # Number of training epochs. 40 | val_every_epochs : 5 # Frequency of performing validation. 
41 | patient_epochs : 15 # Patience for early stop. 42 | max_grad_norm : 10 # Maximal grad norm. 43 | batch_size : 2097152 # Batch size for edge prediction. 44 | val_batch_size : 4194304 # Batch size for validation. 45 | -------------------------------------------------------------------------------- /configs/amazon_computer/train_Sync.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Sync" # Model name. 3 | 4 | gnn_X : # GNN for reconstructing node attributes. 5 | hidden_t: 16 # Hidden size for the normalized time step. 6 | hidden_X: 512 # Hidden size for the node attributes. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_gnn_layers: 2 # Number of GNN layers. 9 | dropout: 0. # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 16 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node attributes. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0. # Dropout rate. 18 | 19 | diffusion : 20 | T : 3 # Number of diffusion steps - 1, slightly inconsistent with the paper. 21 | 22 | optimizer_X : 23 | lr : 0.001 # Learning rate. 24 | weight_decay : 0 # Weight decay. 25 | amsgrad : true 26 | 27 | optimizer_E : 28 | lr : 0.0003 # Learning rate. 29 | weight_decay : 0 # Weight decay. 30 | amsgrad : true 31 | 32 | lr_scheduler : 33 | factor : 0.9 # Factor by which the learning rate will be reduced. 34 | patience : 0 # Number of epochs with no improvement after which learning rate will be reduced. 35 | verbose : true 36 | 37 | train : 38 | num_epochs : 300 # Number of training epochs. 39 | val_every_epochs : 3 # Frequency of performing validation. 40 | patient_epochs : 15 # Patience for early stop. 41 | max_grad_norm : 10 # Maximal grad norm. 42 | batch_size : 2097152 # Batch size for edge prediction. 
43 | val_batch_size : 1048576 # Batch size for validation. 44 | -------------------------------------------------------------------------------- /configs/amazon_photo/train_Async.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Async" # Model name. 3 | 4 | mlp_X : # MLP for reconstructing node features. 5 | hidden_t: 16 # Hidden size for the normalized time step. 6 | hidden_X: 512 # Hidden size for the node features. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_mlp_layers: 2 # Number of MLP layers. 9 | dropout: 0. # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 16 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node features. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0. # Dropout rate. 18 | 19 | diffusion : 20 | T_X : 6 # Number of diffusion steps for node features. 21 | T_E : 9 # Number of diffusion steps for edges. 22 | 23 | optimizer_X : 24 | lr : 0.001 # Learning rate. 25 | weight_decay : 0 # Weight decay. 26 | amsgrad : true 27 | 28 | optimizer_E : 29 | lr : 0.0003 # Learning rate. 30 | weight_decay : 0 # Weight decay. 31 | amsgrad : true 32 | 33 | lr_scheduler : 34 | factor : 0.9 # Factor by which the learning rate will be reduced. 35 | patience : 5 # Number of epochs with no improvement after which learning rate will be reduced. 36 | verbose : true 37 | 38 | train : 39 | num_epochs : 200 # Number of training epochs. 40 | val_every_epochs : 3 # Frequency of performing validation. 41 | patient_epochs : 15 # Patience for early stop. 42 | max_grad_norm : 10 # Maximal grad norm. 43 | batch_size : 262144 # Batch size for edge prediction. 44 | val_batch_size : 1048576 # Batch size for validation. 
45 | -------------------------------------------------------------------------------- /configs/amazon_photo/train_Sync.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Sync" # Model name. 3 | 4 | gnn_X : # GNN for reconstructing node attributes. 5 | hidden_t: 16 # Hidden size for the normalized time step. 6 | hidden_X: 512 # Hidden size for the node attributes. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_gnn_layers: 2 # Number of GNN layers. 9 | dropout: 0. # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 16 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node attributes. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0. # Dropout rate. 18 | 19 | diffusion : 20 | T : 3 # Number of diffusion steps - 1, slightly inconsistent with the paper. 21 | 22 | optimizer_X : 23 | lr : 0.001 # Learning rate. 24 | weight_decay : 0 # Weight decay. 25 | amsgrad : true 26 | 27 | optimizer_E : 28 | lr : 0.0003 # Learning rate. 29 | weight_decay : 0 # Weight decay. 30 | amsgrad : true 31 | 32 | lr_scheduler : 33 | factor : 0.9 # Factor by which the learning rate will be reduced. 34 | patience : 5 # Number of epochs with no improvement after which learning rate will be reduced. 35 | verbose : true 36 | 37 | train : 38 | num_epochs : 400 # Number of training epochs. 39 | val_every_epochs : 3 # Frequency of performing validation. 40 | patient_epochs : 15 # Patience for early stop. 41 | max_grad_norm : 10 # Maximal grad norm. 42 | batch_size : 524288 # Batch size for edge prediction. 43 | val_batch_size : 524288 # Batch size for validation. 
44 | -------------------------------------------------------------------------------- /configs/cora/train_Async.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Async" # Model name. 3 | 4 | mlp_X : # MLP for reconstructing node features. 5 | hidden_t: 32 # Hidden size for the normalized time step. 6 | hidden_X: 512 # Hidden size for the node features. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_mlp_layers: 2 # Number of MLP layers. 9 | dropout: 0.1 # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 32 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node features. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0.1 # Dropout rate. 18 | 19 | diffusion : 20 | T_X : 6 # Number of diffusion steps for node features. 21 | T_E : 9 # Number of diffusion steps for edges. 22 | 23 | optimizer_X : 24 | lr : 0.001 # Learning rate. 25 | weight_decay : 0 # Weight decay. 26 | amsgrad : true 27 | 28 | optimizer_E : 29 | lr : 0.0003 # Learning rate. 30 | weight_decay : 0 # Weight decay. 31 | amsgrad : true 32 | 33 | lr_scheduler : 34 | factor : 0.9 # Factor by which the learning rate will be reduced. 35 | patience : 5 # Number of epochs with no improvement after which learning rate will be reduced. 36 | verbose : true 37 | 38 | train : 39 | num_epochs : 10000 # Number of training epochs. 40 | val_every_epochs : 1 # Frequency of performing validation. 41 | patient_epochs : 20 # Patience for early stop. 42 | max_grad_norm : 10 # Maximal grad norm. 43 | batch_size : 16384 # Batch size for edge prediction. 44 | val_batch_size : 131072 # Batch size for validation. 
45 | -------------------------------------------------------------------------------- /configs/cora/train_Sync.yaml: -------------------------------------------------------------------------------- 1 | meta_data : 2 | variant: "Sync" # Model name. 3 | 4 | gnn_X : # GNN for reconstructing node attributes. 5 | hidden_t: 32 # Hidden size for the normalized time step. 6 | hidden_X: 512 # Hidden size for the node attributes. 7 | hidden_Y: 64 # Hidden size for the node labels. 8 | num_gnn_layers: 2 # Number of GNN layers. 9 | dropout: 0 # Dropout rate. 10 | 11 | gnn_E : # GNN for reconstructing edges. 12 | hidden_t: 32 # Hidden size for the normalized time step. 13 | hidden_X: 512 # Hidden size for the node attributes. 14 | hidden_Y: 64 # Hidden size for the node labels. 15 | hidden_E: 128 # Hidden size for the edges. 16 | num_gnn_layers: 2 # Number of GNN layers. 17 | dropout: 0 # Dropout rate. 18 | 19 | diffusion : 20 | T : 3 # Number of diffusion steps - 1, slightly inconsistent with the paper. 21 | 22 | optimizer_X : 23 | lr : 0.001 # Learning rate. 24 | weight_decay : 0 # Weight decay. 25 | amsgrad : true 26 | 27 | optimizer_E : 28 | lr : 0.0003 # Learning rate. 29 | weight_decay : 0 # Weight decay. 30 | amsgrad : true 31 | 32 | lr_scheduler : 33 | factor : 0.9 # Factor by which the learning rate will be reduced. 34 | patience : 5 # Number of epochs with no improvement after which learning rate will be reduced. 35 | verbose : true 36 | 37 | train : 38 | num_epochs : 10000 # Number of training epochs. 39 | val_every_epochs : 1 # Frequency of performing validation. 40 | patient_epochs : 20 # Patience for early stop. 41 | max_grad_norm : 10 # Maximal grad norm. 42 | batch_size : 16384 # Batch size for edge prediction. 43 | val_batch_size : 131072 # Batch size for validation. 
44 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from dgl.data import AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset, \ 6 | CoraGraphDataset 7 | 8 | def load_dataset(data_name): 9 | if data_name == "cora": 10 | dataset = CoraGraphDataset() 11 | elif data_name == "amazon_photo": 12 | dataset = AmazonCoBuyPhotoDataset() 13 | elif data_name == "amazon_computer": 14 | dataset = AmazonCoBuyComputerDataset() 15 | 16 | g = dataset[0] 17 | g = dgl.remove_self_loop(g) 18 | 19 | X = g.ndata['feat'] 20 | X[X != 0] = 1. 21 | 22 | # Remove columns with constant values. 23 | non_full_zero_feat_mask = X.sum(dim=0) != 0 24 | X = X[:, non_full_zero_feat_mask] 25 | 26 | non_full_one_feat_mask = X.sum(dim=0) != X.size(0) 27 | X = X[:, non_full_one_feat_mask] 28 | 29 | g.ndata['feat'] = X 30 | return g 31 | 32 | def preprocess(g): 33 | """Prepare data for GraphMaker. 34 | 35 | Parameters 36 | ---------- 37 | g : DGLGraph 38 | Graph to be preprocessed. 39 | 40 | Returns 41 | ------- 42 | X_one_hot : torch.Tensor of shape (F, N, 2) 43 | X_one_hot[f, :, :] is the one-hot encoding of the f-th node attribute. 44 | N = |V|. 45 | Y : torch.Tensor of shape (N) 46 | Categorical node labels. 47 | E_one_hot : torch.Tensor of shape (N, N, 2) 48 | - E_one_hot[:, :, 0] indicates the absence of an edge. 49 | - E_one_hot[:, :, 1] is the original adjacency matrix. 50 | X_marginal : torch.Tensor of shape (F, 2) 51 | X_marginal[f, :] is the marginal distribution of the f-th node attribute. 52 | Y_marginal : torch.Tensor of shape (C) 53 | Marginal distribution of the node labels. 54 | E_marginal : torch.Tensor of shape (2) 55 | Marginal distribution of the edge existence. 
56 | X_cond_Y_marginals : torch.Tensor of shape (F, C, 2) 57 | X_cond_Y_marginals[f, k] is the marginal distribution of the f-th node 58 | attribute conditioned on the node label being k. 59 | """ 60 | X = g.ndata['feat'] 61 | Y = g.ndata['label'] 62 | N = g.num_nodes() 63 | src, dst = g.edges() 64 | 65 | X_one_hot_list = [] 66 | for f in range(X.size(1)): 67 | # (N, 2) 68 | X_f_one_hot = F.one_hot(X[:, f].long()) 69 | X_one_hot_list.append(X_f_one_hot) 70 | # (F, N, 2) 71 | X_one_hot = torch.stack(X_one_hot_list, dim=0).float() 72 | 73 | E = torch.zeros(N, N) 74 | E[dst, src] = 1. 75 | # (N, N, 2) 76 | E_one_hot = F.one_hot(E.long()).float() 77 | 78 | # (F, 2) 79 | X_one_hot_count = X_one_hot.sum(dim=1) 80 | # (F, 2) 81 | X_marginal = X_one_hot_count / X_one_hot_count.sum(dim=1, keepdim=True) 82 | 83 | # (N, C) 84 | Y_one_hot = F.one_hot(Y).float() 85 | # (C) 86 | Y_one_hot_count = Y_one_hot.sum(dim=0) 87 | # (C) 88 | Y_marginal = Y_one_hot_count / Y_one_hot_count.sum() 89 | 90 | # (2) 91 | E_one_hot_count = E_one_hot.sum(dim=0).sum(dim=0) 92 | E_marginal = E_one_hot_count / E_one_hot_count.sum() 93 | 94 | # P(X_f | Y) 95 | X_cond_Y_marginals = [] 96 | num_classes = Y_marginal.size(-1) 97 | for k in range(num_classes): 98 | nodes_k = Y == k 99 | X_one_hot_k = X_one_hot[:, nodes_k] 100 | # (F, 2) 101 | X_one_hot_k_count = X_one_hot_k.sum(dim=1) 102 | # (F, 2) 103 | X_marginal_k = X_one_hot_k_count / X_one_hot_k_count.sum(dim=1, keepdim=True) 104 | X_cond_Y_marginals.append(X_marginal_k) 105 | # (F, C, 2) 106 | X_cond_Y_marginals = torch.stack(X_cond_Y_marginals, dim=1) 107 | 108 | return X_one_hot, Y, E_one_hot, X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals 109 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import dgl.function as fn 3 | import dgl.sparse as dglsp 4 | import networkx as nx 5 | import numpy 
as np 6 | import os 7 | import secrets 8 | import subprocess as sp 9 | import torch 10 | 11 | from functools import partial 12 | from pprint import pprint 13 | from scipy import stats 14 | from string import ascii_uppercase, digits 15 | 16 | from model import BaseEvaluator, MLPTrainer, SGCTrainer, GCNTrainer,\ 17 | APPNPTrainer, GAETrainer, CNEvaluator 18 | 19 | def get_triangle_count(nx_g): 20 | triangle_count = sum(nx.triangles(nx.to_undirected(nx_g)).values()) / 3 21 | return triangle_count 22 | 23 | def linkx_homophily(graph, y): 24 | r"""Homophily measure from `Large Scale Learning on Non-Homophilous Graphs: 25 | New Benchmarks and Strong Simple Methods 26 | `__ 27 | 28 | Mathematically it is defined as follows: 29 | 30 | .. math:: 31 | \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, \frac{\sum_{v\in C_k}|\{u\in 32 | \mathcal{N}(v): y_v = y_u \}|}{\sum_{v\in C_k}|\mathcal{N}(v)|} - 33 | \frac{|C_k|}{|\mathcal{V}|} \right), 34 | 35 | where :math:`C` is the number of node classes, :math:`C_k` is the set of 36 | nodes that belong to class :math:`k`, :math:`\mathcal{N}(v)` is the set of predecessors 37 | of node :math:`v`, :math:`y_v` is the class of node :math:`v`, and 38 | :math:`\mathcal{V}` is the set of nodes. 39 | 40 | Parameters 41 | ---------- 42 | graph : DGLGraph 43 | The graph. 44 | y : torch.Tensor 45 | The node labels, which is a tensor of shape (|V|). 46 | 47 | Returns 48 | ------- 49 | float 50 | The homophily value. 51 | """ 52 | with graph.local_scope(): 53 | # Compute |{u\in N(v): y_v = y_u}| for each node v. 54 | src, dst = graph.edges() 55 | # Compute y_v = y_u for all edges. 
56 | graph.edata["same_class"] = (y[src] == y[dst]).float() 57 | graph.update_all( 58 | fn.copy_e("same_class", "m"), fn.sum("m", "same_class_deg") 59 | ) 60 | 61 | deg = graph.in_degrees().float() 62 | num_nodes = graph.num_nodes() 63 | num_classes = y.max(dim=0).values.item() + 1 64 | 65 | value = torch.tensor(0.0).to(graph.device) 66 | for k in range(num_classes): 67 | # Get the nodes that belong to class k. 68 | class_mask = y == k 69 | same_class_deg_k = graph.ndata["same_class_deg"][class_mask].sum() 70 | deg_k = deg[class_mask].sum() 71 | num_nodes_k = class_mask.sum() 72 | value += max(0, same_class_deg_k / deg_k - num_nodes_k / num_nodes) 73 | 74 | return value.item() / (num_classes - 1) 75 | 76 | def edge_list_reindexed(G): 77 | idx = 0 78 | id2idx = dict() 79 | for u in G.nodes(): 80 | id2idx[str(u)] = idx 81 | idx += 1 82 | 83 | edges = [] 84 | for (u, v) in G.edges(): 85 | edges.append((id2idx[str(u)], id2idx[str(v)])) 86 | return edges 87 | 88 | COUNT_START_STR = 'orbit counts:' 89 | 90 | def orca(graph): 91 | graph = graph.to_undirected() 92 | 93 | tmp_fname = f'orca/tmp_{"".join(secrets.choice(ascii_uppercase + digits) for i in range(8))}.txt' 94 | tmp_fname = os.path.join(os.path.dirname(os.path.realpath(__file__)), tmp_fname) 95 | 96 | with open(tmp_fname, 'w') as f: 97 | f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\n') 98 | for (u, v) in edge_list_reindexed(graph): 99 | f.write(str(u) + ' ' + str(v) + '\n') 100 | output = sp.check_output( 101 | [str(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'orca/orca')), 'node', '4', tmp_fname, 'std']) 102 | output = output.decode('utf8').strip() 103 | idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) + 2 104 | output = output[idx:] 105 | 106 | node_orbit_counts = np.array([ 107 | list(map(int, node_cnts.strip().split(' '))) 108 | for node_cnts in output.strip('\n').split('\n') 109 | ]) 110 | 111 | try: 112 | os.remove(tmp_fname) 113 | except OSError: 114 
| pass 115 | 116 | return node_orbit_counts 117 | 118 | def get_orbit_dist(nx_g): 119 | # (|V|, Q), where Q is the number of unique orbits 120 | orbit_counts = orca(nx_g) 121 | 122 | orbit_counts = np.sum(orbit_counts, axis=0) / nx_g.number_of_nodes() 123 | orbit_counts = torch.from_numpy(orbit_counts) 124 | orbit_dist = orbit_counts / max(orbit_counts.sum(), 1) 125 | 126 | return orbit_dist 127 | 128 | def get_adj(dgl_g): 129 | # Get symmetrically normalized adjacency matrix. 130 | A = dgl_g.adj() 131 | N = dgl_g.num_nodes() 132 | I = dglsp.identity((N, N), device=dgl_g.device) 133 | A_hat = A + I 134 | D_hat = dglsp.diag(A_hat.sum(1)) ** -0.5 135 | A_norm = D_hat @ A_hat @ D_hat 136 | 137 | return A_norm 138 | 139 | def get_edge_split(A_dense): 140 | # Exclude self-loops. 141 | A_dense_upper = torch.triu(A_dense, diagonal=1) 142 | real_edges = A_dense_upper.nonzero() 143 | 144 | real_indices = torch.randperm(real_edges.size(0)) 145 | real_edges = real_edges[real_indices] 146 | 147 | num_real = len(real_edges) 148 | num_train = int(num_real * 0.8) 149 | num_val = int(num_real * 0.1) 150 | num_test = num_real - num_train - num_val 151 | 152 | real_train, real_val, real_test = torch.split( 153 | real_edges, [num_train, num_val, num_test]) 154 | 155 | neg_edges = torch.triu((A_dense == 0).float(), diagonal=1).nonzero() 156 | neg_indices = torch.randperm(neg_edges.size(0)) 157 | 158 | neg_val = neg_edges[neg_indices[:num_val]] 159 | neg_test = neg_edges[neg_indices[num_val:num_val+num_test]] 160 | 161 | return real_train, real_val, real_test, neg_val, neg_test 162 | 163 | def prepare_for_GAE(A): 164 | A_dense = A.to_dense() 165 | 166 | real_train, real_val, real_test, neg_val, neg_test = get_edge_split(A_dense) 167 | 168 | num_nodes = A_dense.size(0) 169 | train_mask = torch.zeros(num_nodes, num_nodes) 170 | val_mask = torch.zeros(num_nodes, num_nodes) 171 | test_mask = torch.zeros(num_nodes, num_nodes) 172 | 173 | edge_train = real_train 174 | edge_val = 
torch.cat([real_val, neg_val], dim=0) 175 | edge_test = torch.cat([real_test, neg_test], dim=0) 176 | 177 | row_train, col_train = edge_train.T 178 | train_mask[row_train, col_train] = 1. 179 | 180 | row_val, col_val = edge_val.T 181 | val_mask[row_val, col_val] = 1. 182 | 183 | row_test, col_test = edge_test.T 184 | test_mask[row_test, col_test] = 1. 185 | 186 | train_mask = train_mask.bool() 187 | val_mask = val_mask.bool() 188 | test_mask = test_mask.bool() 189 | 190 | real_row_train, real_col_train = real_train.T 191 | train_g = dgl.graph((real_row_train, real_col_train), num_nodes=num_nodes) 192 | train_g = dgl.to_bidirected(train_g) 193 | A_train = get_adj(train_g) 194 | 195 | return A_train, train_mask, val_mask, test_mask 196 | 197 | def emd(p, q): 198 | return ( 199 | torch.cumsum(p, dim=0) - torch.cumsum(q, dim=0) 200 | ).abs().sum().item() 201 | 202 | def get_pairwise_emd(real_dists, sample_dists): 203 | emd_list = [] 204 | for p in real_dists: 205 | for q in sample_dists: 206 | emd_list.append(emd(p, q)) 207 | return float(np.mean(emd_list)) 208 | 209 | def get_deg_emd(real_degs, sample_degs): 210 | """Compute the earth mover distance (EMD) between 211 | two degree distributions. 212 | 213 | Parameters 214 | ---------- 215 | real_degs : list of torch.Tensor of shape (|V1|) 216 | Node degrees of the real graphs. 217 | sample_degs : list of torch.Tensor of shape (|V2|) 218 | Node degrees of the sampled graphs. 219 | 220 | Returns 221 | ------- 222 | emd 223 | The EMD value. 224 | """ 225 | max_deg = max( 226 | max([deg.max().item() for deg in real_degs]), 227 | max([deg.max().item() for deg in sample_degs]) 228 | ) 229 | 230 | def get_degree_dist(deg): 231 | num_nodes = deg.size(0) 232 | freq = torch.zeros(num_nodes, max_deg + 1) 233 | freq[torch.arange(num_nodes), deg] = 1. 
234 | freq = freq.sum(dim=0) 235 | return freq / (freq.sum() + 1e-6) 236 | 237 | real_dists = [] 238 | for deg in real_degs: 239 | real_dists.append(get_degree_dist(deg)) 240 | 241 | sample_dists = [] 242 | for deg in sample_degs: 243 | sample_dists.append(get_degree_dist(deg)) 244 | 245 | return get_pairwise_emd(real_dists, sample_dists) 246 | 247 | def get_cluster_emd(real_vals, sample_vals, bins=100): 248 | """Compute the earth mover distance (EMD) between 249 | two clustering coefficient distributions. 250 | 251 | Parameters 252 | ---------- 253 | real_vals : list of list of length (|V1|) 254 | Node clustering coefficients of the real graphs. 255 | sample_vals : list of list of length (|V2|) 256 | Node clustering coefficients of the sampled graphs. 257 | bins : int 258 | Number of equal-width bins in the given range. 259 | 260 | Returns 261 | ------- 262 | emd 263 | The EMD value. 264 | """ 265 | def get_cluster_dist(vals): 266 | hist, _ = np.histogram( 267 | vals, bins=bins, range=(0.0, 1.0), density=False) 268 | hist = torch.from_numpy(hist) 269 | return hist / (hist.sum() + 1e-6) 270 | 271 | real_dists = [] 272 | for vals in real_vals: 273 | real_dists.append(get_cluster_dist(vals)) 274 | 275 | sample_dists = [] 276 | for vals in sample_vals: 277 | sample_dists.append(get_cluster_dist(vals)) 278 | 279 | return get_pairwise_emd(real_dists, sample_dists) 280 | 281 | class Evaluator: 282 | def __init__(self, 283 | data_name, 284 | dgl_g_real, 285 | X_one_hot_3d_real, 286 | Y_one_hot_real): 287 | """ 288 | Parameters 289 | ---------- 290 | data_name : str 291 | Name of the dataset. 292 | dgl_g_real : dgl.DGLGraph 293 | Real graph. 294 | X_one_hot_3d_real : torch.Tensor of shape (F, |V|, 2) 295 | X_one_hot_3d_real[f, :, :] is the one-hot encoding of the f-th node 296 | attribute in the real graph. 297 | Y_one_hot_real : torch.Tensor of shape (|V|, C) 298 | One-hot encoding of the node label in the real graph. 
299 | """ 300 | self.data_name = data_name 301 | 302 | # If the number of edges in a newly added graph exceeds this limit, 303 | # a subgraph will be used for certain metric computations. 304 | self.edge_limit = min(dgl_g_real.num_edges(), 20000) 305 | 306 | # Split datasets without a built-in split. 307 | add_mask = False 308 | if data_name in ["amazon_photo", "amazon_computer"]: 309 | add_mask = True 310 | torch.manual_seed(0) 311 | 312 | dgl_g_real, X_real, Y_real, data_dict_real = self.preprocess_g( 313 | dgl_g_real, 314 | X_one_hot_3d_real, 315 | Y_one_hot_real, 316 | add_mask) 317 | self.data_dict_real = data_dict_real 318 | self.data_dict_sample_list = [] 319 | 320 | num_classes = len(Y_real.unique()) 321 | 322 | os.makedirs(f"{data_name}_cpts", exist_ok=True) 323 | self.mlp_evaluator = BaseEvaluator(MLPTrainer, 324 | f"{data_name}_cpts/mlp.pth", 325 | num_classes, 326 | train_mask=dgl_g_real.ndata["train_mask"], 327 | val_mask=dgl_g_real.ndata["val_mask"], 328 | test_mask=dgl_g_real.ndata["test_mask"], 329 | X=X_real, 330 | Y=Y_real) 331 | 332 | A_real = get_adj(dgl_g_real) 333 | 334 | self.sgc_one_layer_evaluator = BaseEvaluator( 335 | partial(SGCTrainer, num_gnn_layers=1), 336 | f"{data_name}_cpts/sgc_one_layer.pth", 337 | num_classes, 338 | train_mask=dgl_g_real.ndata["train_mask"], 339 | val_mask=dgl_g_real.ndata["val_mask"], 340 | test_mask=dgl_g_real.ndata["test_mask"], 341 | A=A_real, 342 | X=X_real, 343 | Y=Y_real) 344 | 345 | self.sgc_two_layer_evaluator = BaseEvaluator( 346 | partial(SGCTrainer, num_gnn_layers=2), 347 | f"{data_name}_cpts/sgc_two_layer.pth", 348 | num_classes, 349 | train_mask=dgl_g_real.ndata["train_mask"], 350 | val_mask=dgl_g_real.ndata["val_mask"], 351 | test_mask=dgl_g_real.ndata["test_mask"], 352 | A=A_real, 353 | X=X_real, 354 | Y=Y_real) 355 | 356 | self.gcn_evaluator = BaseEvaluator( 357 | partial(GCNTrainer, num_gnn_layers=2), 358 | f"{data_name}_cpts/gcn.pth", 359 | num_classes, 360 | 
train_mask=dgl_g_real.ndata["train_mask"], 361 | val_mask=dgl_g_real.ndata["val_mask"], 362 | test_mask=dgl_g_real.ndata["test_mask"], 363 | A=A_real, 364 | X=X_real, 365 | Y=Y_real) 366 | 367 | self.appnp_one_layer_evaluator = BaseEvaluator( 368 | partial(APPNPTrainer, num_gnn_layers=1), 369 | f"{data_name}_cpts/appnp_one_layer.pth", 370 | num_classes, 371 | train_mask=dgl_g_real.ndata["train_mask"], 372 | val_mask=dgl_g_real.ndata["val_mask"], 373 | test_mask=dgl_g_real.ndata["test_mask"], 374 | A=A_real, 375 | X=X_real, 376 | Y=Y_real) 377 | 378 | self.appnp_two_layer_evaluator = BaseEvaluator( 379 | partial(APPNPTrainer, num_gnn_layers=2), 380 | f"{data_name}_cpts/appnp_two_layer.pth", 381 | num_classes, 382 | train_mask=dgl_g_real.ndata["train_mask"], 383 | val_mask=dgl_g_real.ndata["val_mask"], 384 | test_mask=dgl_g_real.ndata["test_mask"], 385 | A=A_real, 386 | X=X_real, 387 | Y=Y_real) 388 | 389 | # Generate train/val/test mask for link prediction. 390 | # Fix the raw graph split for reproducibility. 
391 | torch.manual_seed(0) 392 | A_real_train, train_mask, val_mask, test_mask = prepare_for_GAE(A_real) 393 | 394 | self.gae_one_layer_evaluator = BaseEvaluator( 395 | partial(GAETrainer, num_gnn_layers=1), 396 | f"{data_name}_cpts/gae_one_layer.pth", 397 | num_classes, 398 | train_mask=train_mask, 399 | val_mask=val_mask, 400 | test_mask=test_mask, 401 | A_train=A_real_train, 402 | A_full=A_real, 403 | X=X_real, 404 | Y=Y_real) 405 | 406 | self.gae_two_layer_evaluator = BaseEvaluator( 407 | partial(GAETrainer, num_gnn_layers=2), 408 | f"{data_name}_cpts/gae_two_layer.pth", 409 | num_classes, 410 | train_mask=train_mask, 411 | val_mask=val_mask, 412 | test_mask=test_mask, 413 | A_train=A_real_train, 414 | A_full=A_real, 415 | X=X_real, 416 | Y=Y_real) 417 | 418 | self.cn_evaluator = CNEvaluator( 419 | f"{data_name}_cpts/cn.pth", 420 | A_train=A_real_train, 421 | A_full=A_real, 422 | val_mask=val_mask, 423 | test_mask=test_mask 424 | ) 425 | 426 | def add_mask_cora(self, dgl_g, Y_one_hot): 427 | num_nodes = dgl_g.num_nodes() 428 | train_mask = torch.zeros(num_nodes) 429 | val_mask = torch.zeros(num_nodes) 430 | test_mask = torch.zeros(num_nodes) 431 | 432 | # Based on the raw graph 433 | num_val_nodes = { 434 | 0: 61, 435 | 1: 36, 436 | 2: 78, 437 | 3: 158, 438 | 4: 81, 439 | 5: 57, 440 | 6: 29 441 | } 442 | 443 | num_test_nodes = { 444 | 0: 130, 445 | 1: 91, 446 | 2: 144, 447 | 3: 319, 448 | 4: 149, 449 | 5: 103, 450 | 6: 64 451 | } 452 | 453 | num_classes = Y_one_hot.size(-1) 454 | for y in range(num_classes): 455 | nodes_y = (Y_one_hot[:, y] == 1.).nonzero().squeeze(-1) 456 | nid_y = torch.randperm(len(nodes_y)) 457 | nodes_y = nodes_y[nid_y] 458 | 459 | train_mask[nodes_y[:20]] = 1. 460 | 461 | start = 20 462 | end = start + num_val_nodes[y] 463 | val_mask[nodes_y[start: end]] = 1. 464 | 465 | start = end 466 | end = start + num_test_nodes[y] 467 | test_mask[nodes_y[start: end]] = 1. 
468 | 469 | dgl_g.ndata["train_mask"] = train_mask.bool() 470 | dgl_g.ndata["val_mask"] = val_mask.bool() 471 | dgl_g.ndata["test_mask"] = test_mask.bool() 472 | 473 | return dgl_g 474 | 475 | def add_mask_benchmark(self, dgl_g, Y_one_hot): 476 | num_nodes = dgl_g.num_nodes() 477 | train_mask = torch.zeros(num_nodes) 478 | val_mask = torch.zeros(num_nodes) 479 | test_mask = torch.zeros(num_nodes) 480 | 481 | num_classes = Y_one_hot.size(-1) 482 | for y in range(num_classes): 483 | nodes_y = (Y_one_hot[:, y] == 1.).nonzero().squeeze(-1) 484 | nid_y = torch.randperm(len(nodes_y)) 485 | nodes_y = nodes_y[nid_y] 486 | 487 | # Based on the raw paper. 488 | train_mask[nodes_y[:20]] = 1. 489 | val_mask[nodes_y[20: 50]] = 1. 490 | test_mask[nodes_y[50:]] = 1. 491 | 492 | dgl_g.ndata["train_mask"] = train_mask.bool() 493 | dgl_g.ndata["val_mask"] = val_mask.bool() 494 | dgl_g.ndata["test_mask"] = test_mask.bool() 495 | 496 | return dgl_g 497 | 498 | def add_mask(self, dgl_g, Y_one_hot): 499 | if self.data_name == "cora": 500 | return self.add_mask_cora(dgl_g, Y_one_hot) 501 | elif self.data_name in ["amazon_photo", "amazon_computer"]: 502 | return self.add_mask_benchmark(dgl_g, Y_one_hot) 503 | else: 504 | raise ValueError(f'Unexpected data name: {self.data_name}') 505 | 506 | def sample_subg(self, dgl_g): 507 | # Sample edge-induced subgraph for costly computation. 508 | A = dgl_g.adj().to_dense() 509 | A_upper = torch.triu(A, diagonal=1) 510 | # (|E|, 2) 511 | edges = A_upper.nonzero() 512 | indices = torch.randperm(edges.size(0))[:self.edge_limit // 2] 513 | src, dst = edges[indices].T 514 | sub_g = dgl.graph((src, dst), num_nodes=dgl_g.num_nodes()) 515 | sub_g = dgl.to_bidirected(sub_g) 516 | 517 | return sub_g 518 | 519 | def k_order_g(self, dgl_g, k): 520 | # Get DGLGraph of A^k. 
521 | A = dgl_g.adj().to_dense() 522 | A_new = A 523 | for _ in range(k-1): 524 | A_new = A_new @ A 525 | src, dst = A_new.nonzero().T 526 | new_g = dgl.graph((src, dst), num_nodes=dgl_g.num_nodes()) 527 | return new_g 528 | 529 | def preprocess_g(self, 530 | dgl_g, 531 | X_one_hot_3d, 532 | Y_one_hot, 533 | add_mask): 534 | """ 535 | Parameters 536 | ---------- 537 | dgl_g : dgl.DGLGraph 538 | Graph. 539 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 540 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node 541 | attribute in the graph. 542 | Y_one_hot : torch.Tensor of shape (|V|, C) 543 | One-hot encoding of the node label in the graph. 544 | add_mask : bool 545 | Whether to add a mask to the graph for node classification 546 | data split. 547 | 548 | Returns 549 | ------- 550 | dgl_g : dgl.DGLGraph 551 | Graph, potentially with node mask added. 552 | X : torch.Tensor of shape (|V|, F) 553 | Node attributes. 554 | Y : torch.Tensor of shape (|V|) 555 | Categorical node label. 556 | data_dict : dict 557 | Dictionary of graph statistics. 
558 | """ 559 | if add_mask: 560 | dgl_g = self.add_mask(dgl_g, Y_one_hot) 561 | 562 | F = X_one_hot_3d.size(0) 563 | # (|V|, F) 564 | X = torch.zeros(X_one_hot_3d.size(1), F) 565 | for f in range(F): 566 | X[:, f] = X_one_hot_3d[f].argmax(dim=1) 567 | 568 | if dgl_g.num_edges() > self.edge_limit: 569 | dgl_subg = self.sample_subg(dgl_g) 570 | else: 571 | dgl_subg = dgl_g 572 | 573 | nx_g = nx.DiGraph(dgl_subg.cpu().to_networkx()) 574 | 575 | triangle_count = get_triangle_count(nx_g) 576 | 577 | Y = Y_one_hot.argmax(dim=-1) 578 | linkx_A = linkx_homophily(dgl_g, Y) 579 | 580 | dgl_g_pow_2 = self.k_order_g(dgl_g, 2) 581 | linkx_A_pow_2 = linkx_homophily(dgl_g_pow_2, Y) 582 | 583 | degs = dgl_g.in_degrees() 584 | cluster_coefs = list(nx.clustering(nx_g).values()) 585 | orbit_dist = get_orbit_dist(nx_g) 586 | 587 | data_dict = { 588 | "triangle_count": triangle_count, 589 | "linkx_A": linkx_A, 590 | "linkx_A_pow_2": linkx_A_pow_2, 591 | "degs": degs, 592 | "cluster_coefs": cluster_coefs, 593 | "orbit_dist": orbit_dist, 594 | } 595 | 596 | return dgl_g, X, Y, data_dict 597 | 598 | def add_sample(self, 599 | dgl_g, 600 | X_one_hot_3d, 601 | Y_one_hot): 602 | """Add a generated sample for evaluation. 603 | 604 | Parameters 605 | ---------- 606 | dgl_g : dgl.DGLGraph 607 | Generated graph. 608 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 609 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node 610 | attribute in the generated graph. 611 | Y_one_hot : torch.Tensor of shape (|V|, C) 612 | One-hot encoding of the node label in the generated graph. 
613 | """ 614 | dgl_g_sample, X_sample, Y_sample, data_dict_sample = self.preprocess_g( 615 | dgl_g, 616 | X_one_hot_3d, 617 | Y_one_hot, 618 | add_mask=True) 619 | 620 | self.data_dict_sample_list.append(data_dict_sample) 621 | 622 | self.mlp_evaluator.add_sample( 623 | X=X_sample, 624 | Y=Y_sample, 625 | train_mask=dgl_g_sample.ndata["train_mask"], 626 | val_mask=dgl_g_sample.ndata["val_mask"], 627 | test_mask=dgl_g_sample.ndata["test_mask"]) 628 | 629 | A_sample = get_adj(dgl_g_sample) 630 | 631 | self.sgc_one_layer_evaluator.add_sample( 632 | A=A_sample, 633 | X=X_sample, 634 | Y=Y_sample, 635 | train_mask=dgl_g_sample.ndata["train_mask"], 636 | val_mask=dgl_g_sample.ndata["val_mask"], 637 | test_mask=dgl_g_sample.ndata["test_mask"]) 638 | 639 | self.sgc_two_layer_evaluator.add_sample( 640 | A=A_sample, 641 | X=X_sample, 642 | Y=Y_sample, 643 | train_mask=dgl_g_sample.ndata["train_mask"], 644 | val_mask=dgl_g_sample.ndata["val_mask"], 645 | test_mask=dgl_g_sample.ndata["test_mask"]) 646 | 647 | self.gcn_evaluator.add_sample( 648 | A=A_sample, 649 | X=X_sample, 650 | Y=Y_sample, 651 | train_mask=dgl_g_sample.ndata["train_mask"], 652 | val_mask=dgl_g_sample.ndata["val_mask"], 653 | test_mask=dgl_g_sample.ndata["test_mask"]) 654 | 655 | self.appnp_one_layer_evaluator.add_sample( 656 | A=A_sample, 657 | X=X_sample, 658 | Y=Y_sample, 659 | train_mask=dgl_g_sample.ndata["train_mask"], 660 | val_mask=dgl_g_sample.ndata["val_mask"], 661 | test_mask=dgl_g_sample.ndata["test_mask"]) 662 | 663 | self.appnp_two_layer_evaluator.add_sample( 664 | A=A_sample, 665 | X=X_sample, 666 | Y=Y_sample, 667 | train_mask=dgl_g_sample.ndata["train_mask"], 668 | val_mask=dgl_g_sample.ndata["val_mask"], 669 | test_mask=dgl_g_sample.ndata["test_mask"]) 670 | 671 | # Generate train/val/test mask. 
672 | A_sample_train, train_mask, val_mask, test_mask = prepare_for_GAE(A_sample) 673 | 674 | self.gae_one_layer_evaluator.add_sample( 675 | A_train=A_sample_train, 676 | A_full=A_sample, 677 | X=X_sample, 678 | Y=Y_sample, 679 | train_mask=train_mask, 680 | val_mask=val_mask, 681 | test_mask=test_mask) 682 | 683 | self.gae_two_layer_evaluator.add_sample( 684 | A_train=A_sample_train, 685 | A_full=A_sample, 686 | X=X_sample, 687 | Y=Y_sample, 688 | train_mask=train_mask, 689 | val_mask=val_mask, 690 | test_mask=test_mask) 691 | 692 | self.cn_evaluator.add_sample( 693 | A_train=A_sample_train, 694 | A_full=A_sample, 695 | val_mask=val_mask, 696 | test_mask=test_mask 697 | ) 698 | 699 | def summary(self): 700 | report = dict() 701 | 702 | for key in ["triangle_count", "linkx_A", "linkx_A_pow_2"]: 703 | avg_stats_sample = np.mean([ 704 | data_dict_sample[key] for data_dict_sample in self.data_dict_sample_list 705 | ]) 706 | report[key] = avg_stats_sample / self.data_dict_real[key] 707 | 708 | report["deg_emd"] = get_deg_emd( 709 | [self.data_dict_real["degs"]], 710 | [data_dict_sample["degs"] for data_dict_sample in self.data_dict_sample_list]) 711 | 712 | # clustering coefficient EMD 713 | report["cluster_emd"] = get_cluster_emd( 714 | [self.data_dict_real["cluster_coefs"]], 715 | [data_dict_sample["cluster_coefs"] 716 | for data_dict_sample in self.data_dict_sample_list] 717 | ) 718 | 719 | report["orbit_emd"] = get_pairwise_emd( 720 | [self.data_dict_real["orbit_dist"]], 721 | [data_dict_sample["orbit_dist"] 722 | for data_dict_sample in self.data_dict_sample_list] 723 | ) 724 | 725 | print('\n') 726 | pprint(report) 727 | 728 | print('\nMLP discriminator') 729 | self.mlp_evaluator.summary() 730 | 731 | print('\nSGC 1-layer discriminator') 732 | self.sgc_one_layer_evaluator.summary() 733 | 734 | print('\nSGC 2-layer discriminator') 735 | self.sgc_two_layer_evaluator.summary() 736 | 737 | print('\nGCN discriminator') 738 | self.gcn_evaluator.summary() 739 | 740 | 
print('\nAPPNP 1-layer discriminator') 741 | self.appnp_one_layer_evaluator.summary() 742 | 743 | print('\nAPPNP 2-layer discriminator') 744 | self.appnp_two_layer_evaluator.summary() 745 | 746 | print('\nGAE 1-layer discriminator') 747 | self.gae_one_layer_evaluator.summary() 748 | 749 | print('\nGAE 2-layer discriminator') 750 | self.gae_two_layer_evaluator.summary() 751 | 752 | print('\nCN discriminator') 753 | self.cn_evaluator.summary() 754 | 755 | real_acc_vector = [ 756 | self.mlp_evaluator.real_real_acc, 757 | self.sgc_one_layer_evaluator.real_real_acc, 758 | self.sgc_two_layer_evaluator.real_real_acc, 759 | self.gcn_evaluator.real_real_acc, 760 | self.appnp_one_layer_evaluator.real_real_acc, 761 | self.appnp_two_layer_evaluator.real_real_acc 762 | ] 763 | pearson_coeff = [] 764 | spearman_coeff = [] 765 | for i in range(len(self.data_dict_sample_list)): 766 | sample_acc_vector = [ 767 | self.mlp_evaluator.sample_sample_acc[i], 768 | self.sgc_one_layer_evaluator.sample_sample_acc[i], 769 | self.sgc_two_layer_evaluator.sample_sample_acc[i], 770 | self.gcn_evaluator.sample_sample_acc[i], 771 | self.appnp_one_layer_evaluator.sample_sample_acc[i], 772 | self.appnp_two_layer_evaluator.sample_sample_acc[i] 773 | ] 774 | pearson_coeff.append(stats.pearsonr(real_acc_vector, sample_acc_vector).statistic) 775 | spearman_coeff.append(stats.spearmanr(real_acc_vector, sample_acc_vector).statistic) 776 | 777 | print(f'\nPearson correlation coefficient: {np.mean(pearson_coeff)}') 778 | print(f'\nSpearman correlation coefficient: {np.mean(spearman_coeff)}') 779 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_blocks import * 2 | from .diffusion import * 3 | from .discriminator import * 4 | -------------------------------------------------------------------------------- /model/common_blocks/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .gae import GAE 2 | from .gcn import GCN 3 | -------------------------------------------------------------------------------- /model/common_blocks/gae.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .gcn import GCN 4 | 5 | class GAE(nn.Module): 6 | def __init__(self, 7 | in_size, 8 | num_layers, 9 | hidden_size, 10 | dropout): 11 | super().__init__() 12 | 13 | self.gcn = GCN(in_size, 14 | hidden_size, 15 | num_layers, 16 | hidden_size, 17 | dropout) 18 | 19 | def forward(self, A, Z): 20 | Z = self.gcn(A, Z) 21 | return Z 22 | -------------------------------------------------------------------------------- /model/common_blocks/gcn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class GCN(nn.Module): 5 | def __init__(self, 6 | in_size, 7 | out_size, 8 | num_layers, 9 | hidden_size, 10 | dropout): 11 | super().__init__() 12 | 13 | self.lins = nn.ModuleList() 14 | 15 | if num_layers >= 2: 16 | self.lins.append(nn.Linear(in_size, hidden_size)) 17 | for _ in range(num_layers - 2): 18 | self.lins.append(nn.Linear(hidden_size, hidden_size)) 19 | self.lins.append(nn.Linear(hidden_size, out_size)) 20 | 21 | else: 22 | self.lins.append(nn.Linear(in_size, out_size)) 23 | 24 | self.dropout = dropout 25 | 26 | def forward(self, A, H): 27 | for lin in self.lins[:-1]: 28 | H = A @ lin(H) 29 | H = F.relu(H) 30 | H = F.dropout(H, p=self.dropout, training=self.training) 31 | return A @ self.lins[-1](H) 32 | -------------------------------------------------------------------------------- /model/diffusion.py: -------------------------------------------------------------------------------- 1 | import dgl.sparse as dglsp 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as 
F 6 | 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from .gnn import * 11 | 12 | __all__ = ["ModelSync", "ModelAsync"] 13 | 14 | class MarginalTransition(nn.Module): 15 | """ 16 | Parameters 17 | ---------- 18 | device : torch.device 19 | X_marginal : torch.Tensor of shape (F, 2) 20 | X_marginal[f, :] is the marginal distribution of the f-th node attribute. 21 | E_marginal : torch.Tensor of shape (2) 22 | Marginal distribution of the edge existence. 23 | num_classes_E : int 24 | Number of edge classes. 25 | """ 26 | def __init__(self, 27 | device, 28 | X_marginal, 29 | E_marginal, 30 | num_classes_E): 31 | super().__init__() 32 | 33 | num_attrs_X, num_classes_X = X_marginal.shape 34 | # (F, 2, 2) 35 | self.I_X = torch.eye(num_classes_X, device=device).unsqueeze(0).expand( 36 | num_attrs_X, num_classes_X, num_classes_X).clone() 37 | # (2, 2) 38 | self.I_E = torch.eye(num_classes_E, device=device) 39 | 40 | # (F, 2, 2) 41 | self.m_X = X_marginal.unsqueeze(1).expand( 42 | num_attrs_X, num_classes_X, -1).clone() 43 | # (2, 2) 44 | self.m_E = E_marginal.unsqueeze(0).expand(num_classes_E, -1).clone() 45 | 46 | self.I_X = nn.Parameter(self.I_X, requires_grad=False) 47 | self.I_E = nn.Parameter(self.I_E, requires_grad=False) 48 | 49 | self.m_X = nn.Parameter(self.m_X, requires_grad=False) 50 | self.m_E = nn.Parameter(self.m_E, requires_grad=False) 51 | 52 | def get_Q_bar_E(self, alpha_bar_t): 53 | """Compute the probability transition matrices for obtaining A^t. 54 | 55 | Parameters 56 | ---------- 57 | alpha_bar_t : torch.Tensor of shape (1) 58 | A value in [0, 1]. 59 | 60 | Returns 61 | ------- 62 | Q_bar_t_E : torch.Tensor of shape (2, 2) 63 | Transition matrix for corrupting graph structure at time step t. 64 | """ 65 | Q_bar_t_E = alpha_bar_t * self.I_E + (1 - alpha_bar_t) * self.m_E 66 | 67 | return Q_bar_t_E 68 | 69 | def get_Q_bar_X(self, alpha_bar_t): 70 | """Compute the probability transition matrices for obtaining X^t. 
71 | 72 | Parameters 73 | ---------- 74 | alpha_bar_t : torch.Tensor of shape (1) 75 | A value in [0, 1]. 76 | 77 | Returns 78 | ------- 79 | Q_bar_t_X : torch.Tensor of shape (F, 2, 2) 80 | Transition matrix for corrupting node attributes at time step t. 81 | """ 82 | Q_bar_t_X = alpha_bar_t * self.I_X + (1 - alpha_bar_t) * self.m_X 83 | 84 | return Q_bar_t_X 85 | 86 | class NoiseSchedule(nn.Module): 87 | """ 88 | Parameters 89 | ---------- 90 | T : int 91 | Number of diffusion time steps. 92 | device : torch.device 93 | s : float 94 | Small constant for numerical stability. 95 | """ 96 | def __init__(self, T, device, s=0.008): 97 | super().__init__() 98 | 99 | # Cosine schedule as proposed in 100 | # https://arxiv.org/abs/2102.09672 101 | num_steps = T + 2 102 | t = np.linspace(0, num_steps, num_steps) 103 | # Schedule for \bar{alpha}_t = alpha_1 * ... * alpha_t 104 | alpha_bars = np.cos(0.5 * np.pi * ((t / num_steps) + s) / (1 + s)) ** 2 105 | # Make the largest value 1. 106 | alpha_bars = alpha_bars / alpha_bars[0] 107 | alphas = alpha_bars[1:] / alpha_bars[:-1] 108 | 109 | self.betas = torch.from_numpy(1 - alphas).float().to(device) 110 | self.alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999) 111 | 112 | log_alphas = torch.log(self.alphas) 113 | log_alpha_bars = torch.cumsum(log_alphas, dim=0) 114 | self.alpha_bars = torch.exp(log_alpha_bars) 115 | 116 | self.betas = nn.Parameter(self.betas, requires_grad=False) 117 | self.alphas = nn.Parameter(self.alphas, requires_grad=False) 118 | self.alpha_bars = nn.Parameter(self.alpha_bars, requires_grad=False) 119 | 120 | class LossE(nn.Module): 121 | def __init__(self): 122 | super().__init__() 123 | 124 | def forward(self, true_E, logit_E): 125 | """ 126 | Parameters 127 | ---------- 128 | true_E : torch.Tensor of shape (B, 2) 129 | One-hot encoding of the edge existence for a batch of node pairs. 130 | logit_E : torch.Tensor of shape (B, 2) 131 | Predicted logits for the edge existence. 
132 | 133 | Returns 134 | ------- 135 | loss_E : torch.Tensor 136 | Scalar representing the loss for edge existence. 137 | """ 138 | true_E = torch.argmax(true_E, dim=-1) # (B) 139 | loss_E = F.cross_entropy(logit_E, true_E) 140 | 141 | return loss_E 142 | 143 | class BaseModel(nn.Module): 144 | """ 145 | Parameters 146 | ---------- 147 | T : int 148 | Number of diffusion time steps - 1. 149 | X_marginal : torch.Tensor of shape (F, 2) 150 | X_marginal[f, :] is the marginal distribution of the f-th node attribute. 151 | Y_marginal : torch.Tensor of shape (C) 152 | Marginal distribution of the node labels. 153 | E_marginal : torch.Tensor of shape (2) 154 | Marginal distribution of the edge existence. 155 | num_nodes : int 156 | Number of nodes in the original graph. 157 | """ 158 | def __init__(self, 159 | T, 160 | X_marginal, 161 | Y_marginal, 162 | E_marginal, 163 | num_nodes): 164 | super().__init__() 165 | 166 | device = X_marginal.device 167 | # 2 for if edge exists or not. 168 | self.num_classes_E = 2 169 | self.num_attrs_X, self.num_classes_X = X_marginal.shape 170 | self.num_classes_Y = len(Y_marginal) 171 | 172 | self.transition = MarginalTransition(device, X_marginal, 173 | E_marginal, self.num_classes_E) 174 | 175 | self.T = T 176 | # Number of intermediate time steps to use for validation. 177 | self.num_denoise_match_samples = self.T 178 | self.noise_schedule = NoiseSchedule(T, device) 179 | 180 | self.num_nodes = num_nodes 181 | 182 | self.X_marginal = X_marginal 183 | self.Y_marginal = Y_marginal 184 | self.E_marginal = E_marginal 185 | 186 | self.loss_E = LossE() 187 | 188 | def sample_E(self, prob_E): 189 | """Sample a graph structure from prob_E. 190 | 191 | Parameters 192 | ---------- 193 | prob_E : torch.Tensor of shape (|V|, |V|, 2) 194 | Probability distribution for edge existence. 195 | 196 | Returns 197 | ------- 198 | E_t : torch.LongTensor of shape (|V|, |V|) 199 | Sampled symmetric adjacency matrix. 
200 | """ 201 | # (|V|^2, 1) 202 | E_t = prob_E.reshape(-1, prob_E.size(-1)).multinomial(1) 203 | 204 | # (|V|, |V|) 205 | num_nodes = prob_E.size(0) 206 | E_t = E_t.reshape(num_nodes, num_nodes) 207 | # Make it symmetric for undirected graphs. 208 | src, dst = torch.triu_indices( 209 | num_nodes, num_nodes, device=E_t.device) 210 | E_t[dst, src] = E_t[src, dst] 211 | 212 | return E_t 213 | 214 | def sample_X(self, prob_X): 215 | """Sample node attributes from prob_X. 216 | 217 | Parameters 218 | ---------- 219 | prob_X : torch.Tensor of shape (F, |V|, 2) 220 | Probability distributions for node attributes. 221 | 222 | Returns 223 | ------- 224 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 225 | One-hot encoding of the sampled node attributes. 226 | """ 227 | # (F * |V|) 228 | X_t = prob_X.reshape(-1, prob_X.size(-1)).multinomial(1) 229 | # (F, |V|) 230 | X_t = X_t.reshape(self.num_attrs_X, -1) 231 | # (|V|, 2 * F) 232 | X_t_one_hot = torch.cat([ 233 | F.one_hot(X_t[i], num_classes=self.num_classes_X) 234 | for i in range(self.num_attrs_X) 235 | ], dim=1).float() 236 | 237 | return X_t_one_hot 238 | 239 | def get_adj(self, E_t): 240 | """ 241 | Parameters 242 | ---------- 243 | E_t : torch.LongTensor of shape (|V|, |V|) 244 | Sampled symmetric adjacency matrix. 245 | 246 | Returns 247 | ------- 248 | dglsp.SparseMatrix 249 | Row-normalized adjacency matrix. 250 | """ 251 | # Row normalization. 252 | edges_t = E_t.nonzero().T 253 | num_nodes = E_t.size(0) 254 | A_t = dglsp.spmatrix(edges_t, shape=(num_nodes, num_nodes)) 255 | D_t = dglsp.diag(A_t.sum(1)) ** -1 256 | return D_t @ A_t 257 | 258 | def denoise_match_Z(self, 259 | Z_t_one_hot, 260 | Q_t_Z, 261 | Z_one_hot, 262 | Q_bar_s_Z, 263 | pred_Z): 264 | """Compute the denoising match term for Z given a 265 | sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and 266 | q(D^{t-1}| hat{p}^{D}, D^t). 
267 | 268 | Parameters 269 | ---------- 270 | Z_t_one_hot : torch.Tensor of shape (B, C) or (A, B, C) 271 | One-hot encoding of the data sampled at time step t. 272 | Q_t_Z : torch.Tensor of shape (C, C) or (A, C, C) 273 | Transition matrix from time step t - 1 to t. 274 | Z_one_hot : torch.Tensor of shape (B, C) or (A, B, C) 275 | One-hot encoding of the original data. 276 | Q_bar_s_Z : torch.Tensor of shape (C, C) or (A, C, C) 277 | Transition matrix from time step 0 to t - 1. 278 | pred_Z : torch.Tensor of shape (B, C) or (A, B, C) 279 | Predicted probabilities for the original data. 280 | 281 | Returns 282 | ------- 283 | float 284 | KL value, averaged over all entries. 285 | """ 286 | # q(Z^{t-1}| Z, Z^t) 287 | left_term = Z_t_one_hot @ torch.transpose(Q_t_Z, -1, -2) # (B, C) or (A, B, C) 288 | right_term = Z_one_hot @ Q_bar_s_Z # (B, C) or (A, B, C) 289 | product = left_term * right_term # (B, C) or (A, B, C) 290 | denom = product.sum(dim=-1) # (B,) or (A, B) 291 | denom[denom == 0.] = 1 292 | prob_true = product / denom.unsqueeze(-1) # (B, C) or (A, B, C) 293 | 294 | # q(Z^{t-1}| hat{p}^{Z}, Z^t) 295 | right_term = pred_Z @ Q_bar_s_Z # (B, C) or (A, B, C) 296 | product = left_term * right_term # (B, C) or (A, B, C) 297 | denom = product.sum(dim=-1) # (B,) or (A, B) 298 | denom[denom == 0.] = 1 299 | prob_pred = product / denom.unsqueeze(-1) # (B, C) or (A, B, C) 300 | 301 | # KL(q(Z^{t-1}| Z, Z^t) || q(Z^{t-1}| hat{p}^{Z}, Z^t)); F.kl_div takes log-probabilities as input and probabilities as target. 302 | kl = F.kl_div(input=prob_pred.log(), target=prob_true, reduction='none') 303 | return kl.clamp(min=0).mean().item() 304 | 305 | def denoise_match_E(self, 306 | t_float, 307 | logit_E, 308 | E_t_one_hot, 309 | E_one_hot): 310 | """Compute the denoising match term for edge prediction given a 311 | sampled t, i.e., the KL divergence between q(E^{t-1}| E, E^t) and 312 | q(E^{t-1}| hat{p}^{E}, E^t). 313 | 314 | Parameters 315 | ---------- 316 | t_float : torch.Tensor of shape (1) 317 | Sampled timestep divided by self.T.
318 | logit_E : torch.Tensor of shape (B, 2) 319 | Predicted logits for the edge existence of a batch of node pairs. 320 | E_t_one_hot : torch.Tensor of shape (B, 2) 321 | One-hot encoding of sampled edge existence for the batch of 322 | node pairs. 323 | E_one_hot : torch.Tensor of shape (B, 2) 324 | One-hot encoding of the original edge existence for the batch of 325 | node pairs. 326 | 327 | Returns 328 | ------- 329 | float 330 | KL value. 331 | """ 332 | t = int(round(t_float.item() * self.T))  # Round: t_float is float32, so plain truncation can be off by one. 333 | s = t - 1 334 | 335 | alpha_bar_s = self.noise_schedule.alpha_bars[s] 336 | alpha_t = self.noise_schedule.alphas[t] 337 | 338 | Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s) 339 | # Note that computing Q_bar_t from alpha_bar_t is the same 340 | # as computing Q_t from alpha_t. 341 | Q_t_E = self.transition.get_Q_bar_E(alpha_t) 342 | 343 | pred_E = logit_E.softmax(dim=-1) 344 | 345 | return self.denoise_match_Z(E_t_one_hot, 346 | Q_t_E, 347 | E_one_hot, 348 | Q_bar_s_E, 349 | pred_E) 350 | 351 | def posterior(self, 352 | Z_t, 353 | Q_t, 354 | Q_bar_s, 355 | Q_bar_t, 356 | prior): 357 | """Compute the posterior distribution for time step s, i.e., t - 1. 358 | 359 | Parameters 360 | ---------- 361 | Z_t : torch.Tensor of shape (B, 2) or (F, |V|, 2) 362 | One-hot encoding of the sampled data at timestep t. 363 | B is the batch size, F is the number of node attributes, 364 | and |V| is the number of nodes. 365 | Q_t : torch.Tensor of shape (2, 2) or (F, 2, 2) 366 | The transition matrix from timestep t-1 to t. 367 | Q_bar_s : torch.Tensor of shape (2, 2) or (F, 2, 2) 368 | The transition matrix from timestep 0 to t-1. 369 | Q_bar_t : torch.Tensor of shape (2, 2) or (F, 2, 2) 370 | The transition matrix from timestep 0 to t. 371 | prior : torch.Tensor of shape (B, 2) or (F, |V|, 2) 372 | Reconstructed prior distribution. 373 | 374 | Returns 375 | ------- 376 | prob : torch.Tensor of shape (B, 2) or (F, |V|, 2) 377 | Posterior distribution.
378 | """ 379 | # (B, 2) or (F, |V|, 2) 380 | left_term = Z_t @ torch.transpose(Q_t, -1, -2) 381 | # (B, 1, 2) or (F, |V|, 1, 2) 382 | left_term = left_term.unsqueeze(dim=-2) 383 | # (1, 2, 2) or (F, 1, 2, 2) 384 | right_term = Q_bar_s.unsqueeze(dim=-3) 385 | # (B, 2, 2) or (F, |V|, 2, 2) 386 | # Unlike denoise_match_Z, this function does not 387 | # compute (Z_t @ Q_t.T) * (Z_0 @ Q_bar_s) for a specific 388 | # Z_0, but computes it for all possible values of Z_0. 389 | numerator = left_term * right_term 390 | 391 | # (2, B) or (F, 2, |V|) 392 | prod = Q_bar_t @ torch.transpose(Z_t, -1, -2) 393 | # (B, 2) or (F, |V|, 2) 394 | prod = torch.transpose(prod, -1, -2) 395 | # (B, 2, 1) or (F, |V|, 2, 1) 396 | denominator = prod.unsqueeze(-1) 397 | denominator[denominator == 0.] = 1. 398 | # (B, 2, 2) or (F, |V|, 2, 2) 399 | out = numerator / denominator 400 | 401 | # (B, 2, 2) or (F, |V|, 2, 2) 402 | prob = prior.unsqueeze(-1) * out 403 | # (B, 2) or (F, |V|, 2) 404 | prob = prob.sum(dim=-2) 405 | 406 | return prob 407 | 408 | def sample_E_infer(self, prob_E): 409 | """Draw a sample from prob_E. 410 | 411 | Parameters 412 | ---------- 413 | prob_E : torch.Tensor of shape (B, 2) 414 | Probability distributions for edge existence. 415 | 416 | Returns 417 | ------- 418 | E_t : torch.LongTensor of shape (|V|, |V|) 419 | Sampled adjacency matrix.
420 | """ 421 | E_t_ = prob_E.multinomial(1).squeeze(-1) 422 | E_t = torch.zeros(self.num_nodes, self.num_nodes, dtype=torch.long, device=E_t_.device) 423 | E_t[self.dst, self.src] = E_t_ 424 | E_t[self.src, self.dst] = E_t_ 425 | 426 | return E_t 427 | 428 | def get_E_t(self, 429 | device, 430 | edge_data_loader, 431 | pred_E_func, 432 | t_float, 433 | X_t_one_hot, 434 | Y_0, 435 | E_t, 436 | Q_t_E, 437 | Q_bar_s_E, 438 | Q_bar_t_E, 439 | batch_size): 440 | A_t = self.get_adj(E_t) 441 | E_prob = torch.zeros(len(self.src), self.num_classes_E, device=device) 442 | 443 | start = 0 444 | for batch_edge_index in edge_data_loader: 445 | # (B, 2) 446 | batch_edge_index = batch_edge_index.to(device) 447 | batch_dst, batch_src = batch_edge_index.T 448 | # Predict the original edge existence for the batch of node pairs. 449 | # (B, 2) 450 | batch_pred_E = pred_E_func(t_float, 451 | X_t_one_hot, 452 | Y_0, 453 | A_t, 454 | batch_src, 455 | batch_dst) 456 | 457 | batch_pred_E = batch_pred_E.softmax(dim=-1) 458 | 459 | # (B, 2) 460 | batch_E_t_one_hot = F.one_hot( 461 | E_t[batch_src, batch_dst], 462 | num_classes=self.num_classes_E).float() 463 | batch_E_prob_ = self.posterior(batch_E_t_one_hot, Q_t_E, 464 | Q_bar_s_E, Q_bar_t_E, batch_pred_E) 465 | 466 | end = start + batch_E_prob_.size(0) 467 | E_prob[start: end] = batch_E_prob_ 468 | start = end 469 | 470 | E_t = self.sample_E_infer(E_prob) 471 | 472 | return A_t, E_t 473 | 474 | class LossX(nn.Module): 475 | """ 476 | Parameters 477 | ---------- 478 | num_attrs_X : int 479 | Number of node attributes. 480 | num_classes_X : int 481 | Number of classes for each node attribute. 482 | """ 483 | def __init__(self, 484 | num_attrs_X, 485 | num_classes_X): 486 | super().__init__() 487 | 488 | self.num_attrs_X = num_attrs_X 489 | self.num_classes_X = num_classes_X 490 | 491 | def forward(self, true_X, logit_X): 492 | """ 493 | Parameters 494 | ---------- 495 | true_X : torch.Tensor of shape (F, |V|, 2) 496 | true_X[f, :, :] is the one-hot encoding of the f-th node attribute.
497 | logit_X : torch.Tensor of shape (|V|, F, 2) 498 | Predicted logits for the node attributes. 499 | 500 | Returns 501 | ------- 502 | loss_X : torch.Tensor 503 | Scalar representing the loss for node attributes. 504 | """ 505 | true_X = true_X.transpose(0, 1) # (|V|, F, 2) 506 | # v1x1, v1x2, ..., v1xd, v2x1, ... 507 | true_X = true_X.reshape(-1, true_X.size(-1)) # (|V| * F, 2) 508 | 509 | # v1x1, v1x2, ..., v1xd, v2x1, ... 510 | logit_X = logit_X.reshape(true_X.size(0), -1) # (|V| * F, 2) 511 | 512 | true_X = torch.argmax(true_X, dim=-1) # (|V| * F) 513 | loss_X = F.cross_entropy(logit_X, true_X) 514 | 515 | return loss_X 516 | 517 | class ModelSync(BaseModel): 518 | """ 519 | Parameters 520 | ---------- 521 | T : int 522 | Number of diffusion time steps - 1. 523 | X_marginal : torch.Tensor of shape (F, 2) 524 | X_marginal[f, :] is the marginal distribution of the f-th node attribute. 525 | Y_marginal : torch.Tensor of shape (C) 526 | Marginal distribution of the node labels. 527 | E_marginal : torch.Tensor of shape (2) 528 | Marginal distribution of the edge existence. 529 | gnn_X_config : dict 530 | Configuration of the GNN for reconstructing node attributes. 531 | gnn_E_config : dict 532 | Configuration of the GNN for reconstructing edges. 533 | num_nodes : int 534 | Number of nodes in the original graph. 
535 | """ 536 | def __init__(self, 537 | T, 538 | X_marginal, 539 | Y_marginal, 540 | E_marginal, 541 | gnn_X_config, 542 | gnn_E_config, 543 | num_nodes): 544 | super().__init__(T=T, 545 | X_marginal=X_marginal, 546 | Y_marginal=Y_marginal, 547 | E_marginal=E_marginal, 548 | num_nodes=num_nodes) 549 | 550 | self.graph_encoder = GNN(num_attrs_X=self.num_attrs_X, 551 | num_classes_X=self.num_classes_X, 552 | num_classes_Y=self.num_classes_Y, 553 | num_classes_E=self.num_classes_E, 554 | gnn_X_config=gnn_X_config, 555 | gnn_E_config=gnn_E_config) 556 | 557 | self.loss_X = LossX(self.num_attrs_X, self.num_classes_X) 558 | 559 | def apply_noise(self, X_one_hot_3d, E_one_hot, t=None): 560 | """Corrupt G and sample G^t. 561 | 562 | Parameters 563 | ---------- 564 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 565 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 566 | in the real graph. 567 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 568 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 569 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 570 | t : torch.LongTensor of shape (1), optional 571 | If specified, a time step will be enforced rather than sampled. 572 | 573 | Returns 574 | ------- 575 | t_float : torch.Tensor of shape (1) 576 | Sampled timestep divided by self.T. 577 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 578 | One-hot encoding of the sampled node attributes. 579 | E_t : torch.LongTensor of shape (|V|, |V|) 580 | Sampled symmetric adjacency matrix. 581 | """ 582 | if t is None: 583 | # Sample a timestep t uniformly. 584 | # Note that the notation is slightly inconsistent with the paper. 585 | # t=0 corresponds to t=1 in the paper, where corruption has already taken place. 586 | t = torch.randint(low=0, high=self.T + 1, size=(1,), 587 | device=X_one_hot_3d.device) 588 | 589 | alpha_bar_t = self.noise_schedule.alpha_bars[t] 590 | 591 | # Sample A^t. 
592 | Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t) # (2, 2) 593 | prob_E = E_one_hot @ Q_bar_t_E # (|V|, |V|, 2) 594 | E_t = self.sample_E(prob_E) # (|V|, |V|) 595 | 596 | # Sample X^t. 597 | Q_bar_t_X = self.transition.get_Q_bar_X(alpha_bar_t) # (F, 2, 2) 598 | # Compute matrix multiplication over the first batch dimension. 599 | prob_X = torch.bmm(X_one_hot_3d, Q_bar_t_X) # (F, |V|, 2) 600 | X_t_one_hot = self.sample_X(prob_X) 601 | 602 | t_float = t / self.T 603 | 604 | return t_float, X_t_one_hot, E_t 605 | 606 | def log_p_t(self, 607 | X_one_hot_3d, 608 | E_one_hot, 609 | Y, 610 | batch_src, 611 | batch_dst, 612 | batch_E_one_hot, 613 | t=None): 614 | """Obtain G^t and compute log p(G | G^t, Y, t). 615 | 616 | Parameters 617 | ---------- 618 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 619 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 620 | in the real graph. 621 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 622 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 623 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 624 | Y : torch.Tensor of shape (|V|) 625 | Categorical node labels of the real graph. 626 | batch_src : torch.LongTensor of shape (B) 627 | Source node IDs for a batch of candidate edges (node pairs). 628 | batch_dst : torch.LongTensor of shape (B) 629 | Destination node IDs for a batch of candidate edges (node pairs). 630 | batch_E_one_hot : torch.Tensor of shape (B, 2) 631 | E_one_hot[batch_dst, batch_src]. 632 | t : torch.LongTensor of shape (1), optional 633 | If specified, a time step will be enforced rather than sampled. 634 | 635 | Returns 636 | ------- 637 | loss_X : torch.Tensor 638 | Scalar representing the loss for node attributes. 639 | loss_E : torch.Tensor 640 | Scalar representing the loss for edge existence. 
641 | """ 642 | t_float, X_t_one_hot, E_t = self.apply_noise(X_one_hot_3d, E_one_hot, t) 643 | A_t = self.get_adj(E_t) 644 | logit_X, logit_E = self.graph_encoder(t_float, 645 | X_t_one_hot, 646 | Y, 647 | A_t, 648 | batch_src, 649 | batch_dst) 650 | 651 | loss_X = self.loss_X(X_one_hot_3d, logit_X) 652 | loss_E = self.loss_E(batch_E_one_hot, logit_E) 653 | 654 | return loss_X, loss_E 655 | 656 | def denoise_match_X(self, 657 | t_float, 658 | logit_X, 659 | X_t_one_hot, 660 | X_one_hot_3d): 661 | """Compute the denoising match term for node attribute prediction given a 662 | sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and 663 | q(D^{t-1}| hat{p}^{D}, D^t). 664 | 665 | Parameters 666 | ---------- 667 | t_float : torch.Tensor of shape (1) 668 | Sampled timestep divided by self.T. 669 | logit_X : torch.Tensor of shape (|V|, F, 2) 670 | Predicted logits for the node attributes. 671 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 672 | One-hot encoding of the node attributes sampled at time step t. 673 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 674 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute. 675 | 676 | Returns 677 | ------- 678 | float 679 | KL value for node attributes. 680 | """ 681 | t = int(round(t_float.item() * self.T)) # round, not truncate: float32 t / T may land just below the integer 682 | s = t - 1 683 | 684 | alpha_bar_s = self.noise_schedule.alpha_bars[s] 685 | alpha_t = self.noise_schedule.alphas[t] 686 | 687 | Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s) 688 | # Note that computing Q_bar_t from alpha_bar_t is the same 689 | # as computing Q_t from alpha_t. 
690 | Q_t_X = self.transition.get_Q_bar_X(alpha_t) 691 | 692 | # (|V|, F, 2) 693 | pred_X = logit_X.softmax(dim=-1) 694 | # (F, |V|, 2) 695 | pred_X = torch.transpose(pred_X, 0, 1) 696 | 697 | num_nodes = X_t_one_hot.size(0) 698 | # (|V|, F, 2) 699 | X_t_one_hot = X_t_one_hot.reshape(num_nodes, self.num_attrs_X, -1) 700 | # (F, |V|, 2) 701 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 702 | 703 | return self.denoise_match_Z(X_t_one_hot, 704 | Q_t_X, 705 | X_one_hot_3d, 706 | Q_bar_s_X, 707 | pred_X) 708 | 709 | @torch.no_grad() 710 | def val_step(self, 711 | X_one_hot_3d, 712 | E_one_hot, 713 | Y, 714 | batch_src, 715 | batch_dst, 716 | batch_E_one_hot): 717 | """Perform a validation step. 718 | 719 | Parameters 720 | ---------- 721 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 722 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 723 | in the real graph. 724 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 725 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 726 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 727 | Y : torch.Tensor of shape (|V|) 728 | Categorical node labels of the real graph. 729 | batch_src : torch.LongTensor of shape (B) 730 | Source node IDs for a batch of candidate edges (node pairs). 731 | batch_dst : torch.LongTensor of shape (B) 732 | Destination node IDs for a batch of candidate edges (node pairs). 733 | batch_E_one_hot : torch.Tensor of shape (B, 2) 734 | E_one_hot[batch_dst, batch_src]. 735 | 736 | Returns 737 | ------- 738 | denoise_match_E : float 739 | Denoising matching term for edge. 740 | denoise_match_X : float 741 | Denoising matching term for node attributes. 742 | log_p_0_E : float 743 | Reconstruction term for edge. 744 | log_p_0_X : float 745 | Reconstruction term for node attributes. 746 | """ 747 | device = E_one_hot.device 748 | 749 | denoise_match_X = [] 750 | denoise_match_E = [] 751 | 752 | # t=0 is handled separately. 
753 | for t_sample in range(1, self.T + 1): 754 | t = torch.LongTensor([t_sample]).to(device) 755 | t_float, X_t_one_hot, E_t = self.apply_noise( 756 | X_one_hot_3d, E_one_hot, t) 757 | A_t = self.get_adj(E_t) 758 | logit_X, logit_E = self.graph_encoder(t_float, 759 | X_t_one_hot, 760 | Y, 761 | A_t, 762 | batch_src, 763 | batch_dst) 764 | 765 | E_t_one_hot = F.one_hot(E_t[batch_src, batch_dst], 766 | num_classes=self.num_classes_E).float() 767 | denoise_match_E_t = self.denoise_match_E(t_float, 768 | logit_E, 769 | E_t_one_hot, 770 | batch_E_one_hot) 771 | denoise_match_E.append(denoise_match_E_t) 772 | 773 | denoise_match_X_t = self.denoise_match_X(t_float, 774 | logit_X, 775 | X_t_one_hot, 776 | X_one_hot_3d) 777 | denoise_match_X.append(denoise_match_X_t) 778 | 779 | denoise_match_E = float(np.mean(denoise_match_E)) * self.T 780 | denoise_match_X = float(np.mean(denoise_match_X)) * self.T 781 | 782 | # t=0 783 | t_0 = torch.LongTensor([0]).to(device) 784 | loss_X, loss_E = self.log_p_t(X_one_hot_3d, 785 | E_one_hot, 786 | Y, 787 | batch_src, 788 | batch_dst, 789 | batch_E_one_hot, 790 | t_0) 791 | log_p_0_E = loss_E.item() 792 | log_p_0_X = loss_X.item() 793 | 794 | return denoise_match_E, denoise_match_X,\ 795 | log_p_0_E, log_p_0_X 796 | 797 | @torch.no_grad() 798 | def sample(self, batch_size=32768, num_workers=4): 799 | """Sample a graph. 800 | 801 | Parameters 802 | ---------- 803 | batch_size : int 804 | Batch size for edge prediction. 805 | num_workers : int 806 | Number of subprocesses for data loading in edge prediction. 807 | 808 | Returns 809 | ------- 810 | X_t_one_hot : torch.Tensor of shape (F, |V|, 2) 811 | One-hot encoding of the generated node attributes. 812 | Y_0_one_hot : torch.Tensor of shape (|V|, C) 813 | One-hot encoding of the generated node labels. 814 | E_t : torch.LongTensor of shape (|V|, |V|) 815 | Adjacency matrix of the generated graph. 
816 | """ 817 | device = self.X_marginal.device 818 | dst, src = torch.triu_indices(self.num_nodes, self.num_nodes, 819 | offset=1, device=device) 820 | # (|E|) 821 | self.dst = dst 822 | # (|E|) 823 | self.src = src 824 | # (|E|, 2) 825 | edge_index = torch.stack([dst, src], dim=1).to("cpu") 826 | data_loader = DataLoader(edge_index, batch_size=batch_size, 827 | num_workers=num_workers) 828 | 829 | # Sample G^T from prior distribution. 830 | # (|V|, C) 831 | Y_prior = self.Y_marginal[None, :].expand(self.num_nodes, -1) 832 | # (|V|) 833 | Y_0 = Y_prior.multinomial(1).reshape(-1) 834 | 835 | # (|V|, |V|, 2) 836 | E_prior = self.E_marginal[None, None, :].expand( 837 | self.num_nodes, self.num_nodes, -1) 838 | # (|V|, |V|) 839 | E_t = self.sample_E(E_prior) 840 | 841 | # (F, |V|, 2) 842 | X_prior = self.X_marginal[:, None, :].expand(-1, self.num_nodes, -1) 843 | # (|V|, 2F) 844 | X_t_one_hot = self.sample_X(X_prior) 845 | 846 | # Iteratively sample p(D^s | D^t) for t = 1, ..., T, with s = t - 1. 847 | for s in tqdm(list(reversed(range(0, self.T)))): 848 | t = s + 1 849 | 850 | # Note that computing Q_bar_t from alpha_bar_t is the same 851 | # as computing Q_t from alpha_t. 
852 | alpha_t = self.noise_schedule.alphas[t] 853 | alpha_bar_s = self.noise_schedule.alpha_bars[s] 854 | alpha_bar_t = self.noise_schedule.alpha_bars[t] 855 | 856 | Q_t_E = self.transition.get_Q_bar_E(alpha_t) 857 | Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s) 858 | Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t) 859 | 860 | t_float = torch.tensor([t / self.T]).to(device) 861 | 862 | A_t, E_s = self.get_E_t(device, 863 | data_loader, 864 | self.graph_encoder.pred_E, 865 | t_float, 866 | X_t_one_hot, 867 | Y_0, 868 | E_t, 869 | Q_t_E, 870 | Q_bar_s_E, 871 | Q_bar_t_E, 872 | batch_size) 873 | 874 | # (|V|, F, 2) 875 | pred_X = self.graph_encoder.pred_X(t_float, 876 | X_t_one_hot, 877 | Y_0, 878 | A_t) 879 | pred_X = pred_X.softmax(dim=-1) 880 | # (F, |V|, 2) 881 | pred_X = torch.transpose(pred_X, 0, 1) 882 | 883 | # (|V|, F, 2) 884 | X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1) 885 | # (F, |V|, 2) 886 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 887 | 888 | # (F, 2, 2) 889 | Q_t_X = self.transition.get_Q_bar_X(alpha_t) 890 | Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s) 891 | Q_bar_t_X = self.transition.get_Q_bar_X(alpha_bar_t) 892 | 893 | X_prob = self.posterior(X_t_one_hot, Q_t_X, 894 | Q_bar_s_X, Q_bar_t_X, pred_X) 895 | X_t_one_hot = self.sample_X(X_prob) 896 | E_t = E_s 897 | 898 | # (|V|, F, 2) 899 | X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1) 900 | # (F, |V|, 2) 901 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 902 | 903 | Y_0_one_hot = F.one_hot(Y_0, num_classes=self.num_classes_Y).float() 904 | 905 | return X_t_one_hot, Y_0_one_hot, E_t 906 | 907 | class ModelAsync(BaseModel): 908 | """ 909 | Parameters 910 | ---------- 911 | T_X : int 912 | Number of diffusion time steps - 1 for node attributes. 913 | T_E : int 914 | Number of diffusion time steps - 1 for edges. 
915 | X_marginal : torch.Tensor of shape (F, 2) 916 | X_marginal[f, :] is the marginal distribution of the f-th node attribute. 917 | Y_marginal : torch.Tensor of shape (C) 918 | Marginal distribution of the node labels. 919 | E_marginal : torch.Tensor of shape (2) 920 | Marginal distribution of the edge existence. 921 | mlp_X_config : dict 922 | Configuration of the MLP for reconstructing node attributes. 923 | gnn_E_config : dict 924 | Configuration of the GNN for reconstructing edges. 925 | num_nodes : int 926 | Number of nodes in the original graph. 927 | """ 928 | def __init__(self, 929 | T_X, 930 | T_E, 931 | X_marginal, 932 | Y_marginal, 933 | E_marginal, 934 | mlp_X_config, 935 | gnn_E_config, 936 | num_nodes): 937 | super().__init__(T=T_X, 938 | X_marginal=X_marginal, 939 | Y_marginal=Y_marginal, 940 | E_marginal=E_marginal, 941 | num_nodes=num_nodes) 942 | 943 | del self.T 944 | del self.noise_schedule 945 | 946 | self.T_X = T_X 947 | self.T_E = T_E 948 | 949 | device = X_marginal.device 950 | self.noise_schedule_X = NoiseSchedule(T_X, device) 951 | self.noise_schedule_E = NoiseSchedule(T_E, device) 952 | 953 | self.graph_encoder = GNNAsymm(num_attrs_X=self.num_attrs_X, 954 | num_classes_X=self.num_classes_X, 955 | num_classes_Y=self.num_classes_Y, 956 | num_classes_E=self.num_classes_E, 957 | mlp_X_config=mlp_X_config, 958 | gnn_E_config=gnn_E_config) 959 | 960 | self.loss_X = LossX(self.num_attrs_X, self.num_classes_X) 961 | 962 | def apply_noise_X(self, X_one_hot_3d, t=None): 963 | """Corrupt X and sample X^t. 964 | 965 | Parameters 966 | ---------- 967 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 968 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 969 | in the real graph. 970 | t : torch.LongTensor of shape (1), optional 971 | If specified, a time step will be enforced rather than sampled. 972 | 973 | Returns 974 | ------- 975 | t_float_X : torch.Tensor of shape (1) 976 | Sampled timestep divided by self.T_X. 
977 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 978 | One-hot encoding of the sampled node attributes. 979 | """ 980 | if t is None: 981 | # Sample a timestep t uniformly. 982 | # Note that the notation is slightly inconsistent with the paper. 983 | # t=0 corresponds to t=1 in the paper, where corruption has already taken place. 984 | t = torch.randint(low=0, high=self.T_X + 1, size=(1,), 985 | device=X_one_hot_3d.device) 986 | 987 | alpha_bar_t = self.noise_schedule_X.alpha_bars[t] 988 | 989 | Q_bar_t_X = self.transition.get_Q_bar_X(alpha_bar_t) # (F, 2, 2) 990 | # Compute matrix multiplication over the first batch dimension. 991 | prob_X = torch.bmm(X_one_hot_3d, Q_bar_t_X) # (F, |V|, 2) 992 | 993 | # Sample X_t. 994 | X_t_one_hot = self.sample_X(prob_X) 995 | 996 | t_float_X = t / self.T_X 997 | 998 | return t_float_X, X_t_one_hot 999 | 1000 | def apply_noise_E(self, E_one_hot, t=None): 1001 | """Corrupt A and sample A^t. 1002 | 1003 | Parameters 1004 | ---------- 1005 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 1006 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 1007 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 1008 | t : torch.LongTensor of shape (1), optional 1009 | If specified, a time step will be enforced rather than sampled. 1010 | 1011 | Returns 1012 | ------- 1013 | t_float_E : torch.Tensor of shape (1) 1014 | Sampled timestep divided by self.T_E. 1015 | E_t : torch.LongTensor of shape (|V|, |V|) 1016 | Sampled symmetric adjacency matrix. 1017 | """ 1018 | if t is None: 1019 | # Sample a timestep t uniformly. 1020 | # Note that the notation is slightly inconsistent with the paper. 1021 | # t=0 corresponds to t=1 in the paper, where corruption has already taken place. 
1022 | t = torch.randint(low=0, high=self.T_E + 1, size=(1,), 1023 | device=E_one_hot.device) 1024 | 1025 | alpha_bar_t = self.noise_schedule_E.alpha_bars[t] 1026 | 1027 | Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t) # (2, 2) 1028 | prob_E = E_one_hot @ Q_bar_t_E # (|V|, |V|, 2) 1029 | E_t = self.sample_E(prob_E) 1030 | 1031 | t_float_E = t / self.T_E 1032 | 1033 | return t_float_E, E_t 1034 | 1035 | def log_p_t(self, 1036 | X_one_hot_3d, 1037 | E_one_hot, 1038 | Y, 1039 | X_one_hot_2d, 1040 | batch_src, 1041 | batch_dst, 1042 | batch_E_one_hot, 1043 | t_X=None, 1044 | t_E=None): 1045 | """Obtain G^t and compute log p(G | G^t, Y, t). 1046 | 1047 | Parameters 1048 | ---------- 1049 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 1050 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 1051 | in the real graph. 1052 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 1053 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 1054 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 1055 | Y : torch.Tensor of shape (|V|) 1056 | Categorical node labels of the real graph. 1057 | X_one_hot_2d : torch.Tensor of shape (|V|, 2 * F) 1058 | Flattened one-hot encoding of the node attributes in the real graph. 1059 | batch_src : torch.LongTensor of shape (B) 1060 | Source node IDs for a batch of candidate edges (node pairs). 1061 | batch_dst : torch.LongTensor of shape (B) 1062 | Destination node IDs for a batch of candidate edges (node pairs). 1063 | batch_E_one_hot : torch.Tensor of shape (B, 2) 1064 | E_one_hot[batch_dst, batch_src]. 1065 | t_X : torch.LongTensor of shape (1), optional 1066 | If specified, a time step will be enforced rather than sampled for 1067 | node attributes. 1068 | t_E : torch.LongTensor of shape (1), optional 1069 | If specified, a time step will be enforced rather than sampled for 1070 | edges. 
1071 | 1072 | Returns 1073 | ------- 1074 | loss_X : torch.Tensor 1075 | Scalar representing the loss for node attributes. 1076 | loss_E : torch.Tensor 1077 | Scalar representing the loss for edge existence. 1078 | """ 1079 | t_float_X, X_t_one_hot = self.apply_noise_X(X_one_hot_3d, t_X) 1080 | t_float_E, E_t = self.apply_noise_E(E_one_hot, t_E) 1081 | A_t = self.get_adj(E_t) 1082 | logit_X, logit_E = self.graph_encoder(t_float_X, 1083 | t_float_E, 1084 | X_t_one_hot, 1085 | Y, 1086 | X_one_hot_2d, 1087 | A_t, 1088 | batch_src, 1089 | batch_dst) 1090 | 1091 | loss_X = self.loss_X(X_one_hot_3d, logit_X) 1092 | loss_E = self.loss_E(batch_E_one_hot, logit_E) 1093 | 1094 | return loss_X, loss_E 1095 | 1096 | def denoise_match_X(self, 1097 | t_float, 1098 | logit_X, 1099 | X_t_one_hot, 1100 | X_one_hot_3d): 1101 | """Compute the denoising match term for node attribute prediction given a 1102 | sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and 1103 | q(D^{t-1}| hat{p}^{D}, D^t). 1104 | 1105 | Parameters 1106 | ---------- 1107 | t_float : torch.Tensor of shape (1) 1108 | Sampled timestep divided by self.T_X. 1109 | logit_X : torch.Tensor of shape (|V|, F, 2) 1110 | Predicted logits for the node attributes. 1111 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 1112 | One-hot encoding of the node attributes sampled at time step t. 1113 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 1114 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute. 1115 | 1116 | Returns 1117 | ------- 1118 | float 1119 | KL value for node attributes. 1120 | """ 1121 | t = int(round(t_float.item() * self.T_X)) # round, not truncate: float32 t / T_X may land just below the integer 1122 | s = t - 1 1123 | 1124 | alpha_bar_s = self.noise_schedule_X.alpha_bars[s] 1125 | alpha_t = self.noise_schedule_X.alphas[t] 1126 | 1127 | Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s) 1128 | # Note that computing Q_bar_t from alpha_bar_t is the same 1129 | # as computing Q_t from alpha_t. 
1130 | Q_t_X = self.transition.get_Q_bar_X(alpha_t) 1131 | 1132 | # (|V|, F, 2) 1133 | pred_X = logit_X.softmax(dim=-1) 1134 | # (F, |V|, 2) 1135 | pred_X = torch.transpose(pred_X, 0, 1) 1136 | 1137 | num_nodes = X_t_one_hot.size(0) 1138 | # (|V|, F, 2) 1139 | X_t_one_hot = X_t_one_hot.reshape(num_nodes, self.num_attrs_X, -1) 1140 | # (F, |V|, 2) 1141 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 1142 | 1143 | return self.denoise_match_Z(X_t_one_hot, 1144 | Q_t_X, 1145 | X_one_hot_3d, 1146 | Q_bar_s_X, 1147 | pred_X) 1148 | 1149 | def denoise_match_E(self, 1150 | t_float, 1151 | logit_E, 1152 | E_t_one_hot, 1153 | E_one_hot): 1154 | """Compute the denoising match term for edge prediction given a 1155 | sampled t, i.e., the KL divergence between q(D^{t-1}| D, D^t) and 1156 | q(D^{t-1}| hat{p}^{D}, D^t). 1157 | 1158 | Parameters 1159 | ---------- 1160 | t_float : torch.Tensor of shape (1) 1161 | Sampled timestep divided by self.T_E. 1162 | logit_E : torch.Tensor of shape (B, 2) 1163 | Predicted logits for the edge existence of a batch of node pairs. 1164 | E_t_one_hot : torch.Tensor of shape (B, 2) 1165 | One-hot encoding of sampled edge existence for the batch of 1166 | node pairs. 1167 | E_one_hot : torch.Tensor of shape (B, 2) 1168 | One-hot encoding of the original edge existence for the batch of 1169 | node pairs. 1170 | 1171 | Returns 1172 | ------- 1173 | float 1174 | KL value for edges. 1175 | """ 1176 | t = int(round(t_float.item() * self.T_E)) # round, not truncate: float32 t / T_E may land just below the integer 1177 | s = t - 1 1178 | 1179 | alpha_bar_s = self.noise_schedule_E.alpha_bars[s] 1180 | alpha_t = self.noise_schedule_E.alphas[t] 1181 | 1182 | Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s) 1183 | # Note that computing Q_bar_t from alpha_bar_t is the same 1184 | # as computing Q_t from alpha_t. 
1185 | Q_t_E = self.transition.get_Q_bar_E(alpha_t) 1186 | 1187 | pred_E = logit_E.softmax(dim=-1) 1188 | 1189 | return self.denoise_match_Z(E_t_one_hot, 1190 | Q_t_E, 1191 | E_one_hot, 1192 | Q_bar_s_E, 1193 | pred_E) 1194 | 1195 | @torch.no_grad() 1196 | def val_step(self, 1197 | X_one_hot_3d, 1198 | E_one_hot, 1199 | Y, 1200 | X_one_hot_2d, 1201 | batch_src, 1202 | batch_dst, 1203 | batch_E_one_hot): 1204 | """Perform a validation step. 1205 | 1206 | Parameters 1207 | ---------- 1208 | X_one_hot_3d : torch.Tensor of shape (F, |V|, 2) 1209 | X_one_hot_3d[f, :, :] is the one-hot encoding of the f-th node attribute 1210 | in the real graph. 1211 | E_one_hot : torch.Tensor of shape (|V|, |V|, 2) 1212 | - E_one_hot[:, :, 0] indicates the absence of an edge in the real graph. 1213 | - E_one_hot[:, :, 1] is the adjacency matrix of the real graph. 1214 | Y : torch.Tensor of shape (|V|) 1215 | Categorical node labels of the real graph. 1216 | X_one_hot_2d : torch.Tensor of shape (|V|, 2 * F) 1217 | Flattened one-hot encoding of the node attributes in the real graph. 1218 | batch_src : torch.LongTensor of shape (B) 1219 | Source node IDs for a batch of candidate edges (node pairs). 1220 | batch_dst : torch.LongTensor of shape (B) 1221 | Destination node IDs for a batch of candidate edges (node pairs). 1222 | batch_E_one_hot : torch.Tensor of shape (B, 2) 1223 | E_one_hot[batch_dst, batch_src]. 1224 | 1225 | Returns 1226 | ------- 1227 | denoise_match_E : float 1228 | Denoising matching term for edge. 1229 | denoise_match_X : float 1230 | Denoising matching term for node attributes. 1231 | log_p_0_E : float 1232 | Reconstruction term for edge. 1233 | log_p_0_X : float 1234 | Reconstruction term for node attributes. 1235 | """ 1236 | device = E_one_hot.device 1237 | 1238 | # Case1: X 1239 | denoise_match_X = [] 1240 | # t=0 is handled separately. 
1241 | for t_sample_X in range(1, self.T_X + 1): 1242 | t_X = torch.LongTensor([t_sample_X]).to(device) 1243 | t_float_X, X_t_one_hot = self.apply_noise_X(X_one_hot_3d, t_X) 1244 | logit_X = self.graph_encoder.pred_X(t_float_X, 1245 | X_t_one_hot, 1246 | Y) 1247 | 1248 | denoise_match_X_t = self.denoise_match_X(t_float_X, 1249 | logit_X, 1250 | X_t_one_hot, 1251 | X_one_hot_3d) 1252 | denoise_match_X.append(denoise_match_X_t) 1253 | denoise_match_X = float(np.mean(denoise_match_X)) * self.T_X 1254 | 1255 | # Case2: E 1256 | denoise_match_E = [] 1257 | # t=0 is handled separately. 1258 | for t_sample_E in range(1, self.T_E + 1): 1259 | t_E = torch.LongTensor([t_sample_E]).to(device) 1260 | t_float_E, E_t = self.apply_noise_E(E_one_hot, t_E) 1261 | A_t = self.get_adj(E_t) 1262 | logit_E = self.graph_encoder.pred_E(t_float_E, 1263 | X_one_hot_2d, 1264 | Y, 1265 | A_t, 1266 | batch_src, 1267 | batch_dst) 1268 | 1269 | E_t_one_hot = F.one_hot(E_t[batch_src, batch_dst], 1270 | num_classes=self.num_classes_E).float() 1271 | denoise_match_E_t = self.denoise_match_E(t_float_E, 1272 | logit_E, 1273 | E_t_one_hot, 1274 | batch_E_one_hot) 1275 | denoise_match_E.append(denoise_match_E_t) 1276 | denoise_match_E = float(np.mean(denoise_match_E)) * self.T_E 1277 | 1278 | # t=0 1279 | t_0 = torch.LongTensor([0]).to(device) 1280 | loss_X, loss_E = self.log_p_t(X_one_hot_3d, 1281 | E_one_hot, 1282 | Y, 1283 | X_one_hot_2d, 1284 | batch_src, 1285 | batch_dst, 1286 | batch_E_one_hot, 1287 | t_X=t_0, 1288 | t_E=t_0) 1289 | log_p_0_E = loss_E.item() 1290 | log_p_0_X = loss_X.item() 1291 | 1292 | return denoise_match_E, denoise_match_X,\ 1293 | log_p_0_E, log_p_0_X 1294 | 1295 | @torch.no_grad() 1296 | def sample(self, batch_size=32768, num_workers=4): 1297 | """Sample a graph. 1298 | 1299 | Parameters 1300 | ---------- 1301 | batch_size : int 1302 | Batch size for edge prediction. 1303 | num_workers : int 1304 | Number of subprocesses for data loading in edge prediction. 
1305 | 1306 | Returns 1307 | ------- 1308 | X_t_one_hot : torch.Tensor of shape (F, |V|, 2) 1309 | One-hot encoding of the generated node attributes. 1310 | Y_0_one_hot : torch.Tensor of shape (|V|, C) 1311 | One-hot encoding of the generated node labels. 1312 | E_t : torch.LongTensor of shape (|V|, |V|) 1313 | Adjacency matrix of the generated graph. 1314 | """ 1315 | device = self.X_marginal.device 1316 | 1317 | # Sample Y_0 1318 | # (|V|, C) 1319 | Y_prior = self.Y_marginal[None, :].expand(self.num_nodes, -1) 1320 | # (|V|) 1321 | Y_0 = Y_prior.multinomial(1).reshape(-1) 1322 | 1323 | # Sample X^T from prior distribution. 1324 | # (F, |V|, 2) 1325 | X_prior = self.X_marginal[:, None, :].expand(-1, self.num_nodes, -1) 1326 | # (|V|, 2F) 1327 | X_t_one_hot = self.sample_X(X_prior) 1328 | 1329 | # Iteratively sample p(X^s | X^t) for t = 1, ..., T_X, with s = t - 1. 1330 | for s_X in tqdm(list(reversed(range(0, self.T_X)))): 1331 | t_X = s_X + 1 1332 | t_float_X = torch.tensor([t_X / self.T_X]).to(device) 1333 | pred_X = self.graph_encoder.pred_X(t_float_X, 1334 | X_t_one_hot, 1335 | Y_0) 1336 | pred_X = pred_X.softmax(dim=-1) 1337 | # (F, |V|, 2) 1338 | pred_X = torch.transpose(pred_X, 0, 1) 1339 | 1340 | # (|V|, F, 2) 1341 | X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1) 1342 | # (F, |V|, 2) 1343 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 1344 | 1345 | # Note that computing Q_bar_t from alpha_bar_t is the same 1346 | # as computing Q_t from alpha_t. 
1347 | alpha_t_X = self.noise_schedule_X.alphas[t_X] 1348 | alpha_bar_s_X = self.noise_schedule_X.alpha_bars[s_X] 1349 | alpha_bar_t_X = self.noise_schedule_X.alpha_bars[t_X] 1350 | 1351 | # (F, 2, 2) 1352 | Q_t_X = self.transition.get_Q_bar_X(alpha_t_X) 1353 | Q_bar_s_X = self.transition.get_Q_bar_X(alpha_bar_s_X) 1354 | Q_bar_t_X = self.transition.get_Q_bar_X(alpha_bar_t_X) 1355 | 1356 | X_prob = self.posterior(X_t_one_hot, Q_t_X, 1357 | Q_bar_s_X, Q_bar_t_X, pred_X) 1358 | X_t_one_hot = self.sample_X(X_prob) 1359 | 1360 | # Sample E^T from prior distribution. 1361 | # (|V|, |V|, 2) 1362 | E_prior = self.E_marginal[None, None, :].expand( 1363 | self.num_nodes, self.num_nodes, -1) 1364 | # (|V|, |V|) 1365 | E_t = self.sample_E(E_prior) 1366 | 1367 | dst, src = torch.triu_indices(self.num_nodes, self.num_nodes, 1368 | offset=1, device=device) 1369 | # (|E|) 1370 | self.dst = dst 1371 | # (|E|) 1372 | self.src = src 1373 | # (|E|, 2) 1374 | edge_index = torch.stack([dst, src], dim=1).to("cpu") 1375 | data_loader = DataLoader(edge_index, batch_size=batch_size, 1376 | num_workers=num_workers) 1377 | 1378 | # Iteratively sample p(A^s | A^t) for t = 1, ..., T_E, with s = t - 1. 
1379 | for s_E in tqdm(list(reversed(range(0, self.T_E)))): 1380 | t_E = s_E + 1 1381 | alpha_t_E = self.noise_schedule_E.alphas[t_E] 1382 | alpha_bar_s_E = self.noise_schedule_E.alpha_bars[s_E] 1383 | alpha_bar_t_E = self.noise_schedule_E.alpha_bars[t_E] 1384 | 1385 | Q_t_E = self.transition.get_Q_bar_E(alpha_t_E) 1386 | Q_bar_s_E = self.transition.get_Q_bar_E(alpha_bar_s_E) 1387 | Q_bar_t_E = self.transition.get_Q_bar_E(alpha_bar_t_E) 1388 | 1389 | t_float_E = torch.tensor([t_E / self.T_E]).to(device) 1390 | 1391 | _, E_t = self.get_E_t(device, 1392 | data_loader, 1393 | self.graph_encoder.pred_E, 1394 | t_float_E, 1395 | X_t_one_hot, 1396 | Y_0, 1397 | E_t, 1398 | Q_t_E, 1399 | Q_bar_s_E, 1400 | Q_bar_t_E, 1401 | batch_size) 1402 | 1403 | # (|V|, F, 2) 1404 | X_t_one_hot = X_t_one_hot.reshape(self.num_nodes, self.num_attrs_X, -1) 1405 | # (F, |V|, 2) 1406 | X_t_one_hot = torch.transpose(X_t_one_hot, 0, 1) 1407 | 1408 | Y_0_one_hot = F.one_hot(Y_0, num_classes=self.num_classes_Y).float() 1409 | 1410 | return X_t_one_hot, Y_0_one_hot, E_t 1411 | -------------------------------------------------------------------------------- /model/discriminator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseEvaluator 2 | from .mlp import MLPTrainer 3 | from .gcn import GCNTrainer 4 | from .gae import GAETrainer 5 | from .sgc import SGCTrainer 6 | from .appnp import APPNPTrainer 7 | from .cn import CNEvaluator 8 | -------------------------------------------------------------------------------- /model/discriminator/appnp.py: -------------------------------------------------------------------------------- 1 | import dgl.sparse as dglsp 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from copy import deepcopy 7 | from tqdm import tqdm 8 | 9 | from .gcn import GCNTrainer 10 | 11 | class APPNP(nn.Module): 12 | def __init__(self, 13 | in_size, 14 | out_size, 15 | 
num_trans_layers, 16 | hidden_size, 17 | dropout, 18 | num_prop_layers, 19 | alpha): 20 | super().__init__() 21 | 22 | assert num_trans_layers >= 2 23 | self.lins = nn.ModuleList() 24 | self.lins.append(nn.Linear(in_size, hidden_size)) 25 | for _ in range(num_trans_layers - 2): 26 | self.lins.append(nn.Linear(hidden_size, hidden_size)) 27 | self.lins.append(nn.Linear(hidden_size, out_size)) 28 | 29 | self.dropout = dropout 30 | 31 | self.num_prop_layers = num_prop_layers 32 | self.alpha = alpha 33 | 34 | def forward(self, A, H): 35 | # Predict. 36 | for lin in self.lins[:-1]: 37 | H = lin(H) 38 | H = F.relu(H) 39 | H = F.dropout(H, p=self.dropout, training=self.training) 40 | H_local = self.lins[-1](H) 41 | 42 | # Propagate. 43 | H = H_local 44 | for _ in range(self.num_prop_layers): 45 | A_drop = dglsp.val_like( 46 | A, F.dropout(A.val, p=self.dropout, training=self.training)) 47 | H = A_drop @ H + self.alpha * H_local 48 | return H 49 | 50 | class APPNPTrainer(GCNTrainer): 51 | def __init__(self, num_gnn_layers): 52 | hyper_space = { 53 | "lr": [3e-2, 1e-2, 3e-3], 54 | "num_trans_layers": [2], 55 | "hidden_size": [32, 128, 512], 56 | "dropout": [0., 0.1], 57 | "num_prop_layers": [num_gnn_layers], 58 | "alpha": [0.1, 0.2] 59 | } 60 | search_priority_increasing = [ 61 | "dropout", 62 | "alpha", 63 | "lr", 64 | "num_trans_layers", 65 | "num_prop_layers", 66 | "hidden_size"] 67 | 68 | super().__init__(num_gnn_layers=num_gnn_layers, 69 | hyper_space=hyper_space, 70 | search_priority_increasing=search_priority_increasing, 71 | patience=5) 72 | 73 | def fit_trial(self, 74 | A, 75 | X, 76 | Y, 77 | num_classes, 78 | train_mask, 79 | val_mask, 80 | num_trans_layers, 81 | hidden_size, 82 | dropout, 83 | num_prop_layers, 84 | alpha, 85 | lr): 86 | model = APPNP(in_size=X.size(1), 87 | out_size=num_classes, 88 | num_trans_layers=num_trans_layers, 89 | hidden_size=hidden_size, 90 | dropout=dropout, 91 | num_prop_layers=num_prop_layers, 92 | alpha=alpha).to(self.device) 93 | 
loss_func = nn.CrossEntropyLoss() 94 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 95 | 96 | num_epochs = 1000 97 | num_patient_epochs = 0 98 | best_acc = 0 99 | best_model_state_dict = deepcopy(model.state_dict()) 100 | for epoch in range(1, num_epochs + 1): 101 | model.train() 102 | logits = model(A, X) 103 | loss = loss_func(logits[train_mask], Y[train_mask]) 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | acc = self.predict(A, X, Y, val_mask, model) 109 | 110 | if acc > best_acc: 111 | num_patient_epochs = 0 112 | best_acc = acc 113 | best_model_state_dict = deepcopy(model.state_dict()) 114 | else: 115 | num_patient_epochs += 1 116 | 117 | if num_patient_epochs == self.patience: 118 | break 119 | 120 | model.load_state_dict(best_model_state_dict) 121 | return best_acc, model 122 | 123 | def fit(self, A, X, Y, num_classes, train_mask, val_mask): 124 | """ 125 | Parameters 126 | ---------- 127 | A : dgl.sparse.SparseMatrix 128 | Adjacency matrix. 129 | X : torch.Tensor of shape (|V|, D) 130 | Binary node features. 131 | Y : torch.Tensor of shape (|V|,) 132 | Node labels. 133 | num_classes : int 134 | Number of node classes. 135 | train_mask : torch.Tensor of shape (|V|) 136 | Mask indicating training nodes. 137 | val_mask : torch.Tensor of shape (|V|) 138 | Mask indicating validation nodes. 
139 | """ 140 | A, X, Y = self.preprocess(A, X, Y) 141 | 142 | config_list = self.get_config_list() 143 | 144 | best_acc = 0 145 | with tqdm(config_list) as tconfig: 146 | tconfig.set_description( 147 | f"Training APPNP {self.num_gnn_layers}-layer discriminator") 148 | 149 | for config in tconfig: 150 | trial_acc, trial_model = self.fit_trial(A, 151 | X, 152 | Y, 153 | num_classes, 154 | train_mask, 155 | val_mask, 156 | **config) 157 | 158 | if trial_acc > best_acc: 159 | best_acc = trial_acc 160 | best_model = trial_model 161 | best_model_config = { 162 | "in_size": X.size(1), 163 | "out_size": num_classes, 164 | "num_trans_layers": config["num_trans_layers"], 165 | "hidden_size": config["hidden_size"], 166 | "dropout": config["dropout"], 167 | "num_prop_layers": config["num_prop_layers"], 168 | "alpha": config["alpha"] 169 | } 170 | 171 | tconfig.set_postfix(accuracy=100. * best_acc) 172 | 173 | if trial_acc == 1.0: 174 | break 175 | self.model = best_model 176 | self.best_model_config = best_model_config 177 | 178 | def load_model(self, model_path): 179 | state_dict = torch.load(model_path) 180 | model = APPNP(**state_dict["model_config"]).to(self.device) 181 | model.load_state_dict(state_dict["model_state_dict"]) 182 | self.model = model 183 | -------------------------------------------------------------------------------- /model/discriminator/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import os 4 | import torch 5 | 6 | class BaseTrainer: 7 | def __init__(self, 8 | hyper_space, 9 | search_priority_increasing, 10 | patience): 11 | """Base class for training a discriminative model. 12 | 13 | Parameters 14 | ---------- 15 | search_priority_increasing : list of str 16 | The priority of hyperparameters to search, from lowest to highest. 
17 | """ 18 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 19 | self.device = torch.device(device) 20 | 21 | self.hyper_space = hyper_space 22 | self.search_priority_increasing = search_priority_increasing 23 | self.patience = patience 24 | 25 | def get_config_list(self): 26 | vals = [self.hyper_space[k] for k in self.search_priority_increasing] 27 | 28 | config_list = [] 29 | for items in itertools.product(*vals): 30 | items_dict = dict(zip(self.search_priority_increasing, items)) 31 | config_list.append(items_dict) 32 | 33 | return config_list 34 | 35 | def save_model(self, model_path): 36 | torch.save({ 37 | "model_state_dict": self.model.state_dict(), 38 | "model_config": self.best_model_config 39 | }, model_path) 40 | 41 | class BaseEvaluator: 42 | def __init__(self, 43 | Trainer, 44 | model_path, 45 | num_classes, 46 | train_mask, 47 | val_mask, 48 | test_mask, 49 | **real_data): 50 | self.Trainer = Trainer 51 | self.num_classes = num_classes 52 | self.train_mask_real = train_mask 53 | self.val_mask_real = val_mask 54 | self.test_mask_real = test_mask 55 | self.real_data = real_data 56 | 57 | self.model_real = Trainer() 58 | if os.path.exists(model_path): 59 | self.model_real.load_model(model_path) 60 | else: 61 | self.model_real.fit(num_classes=num_classes, 62 | train_mask=train_mask, 63 | val_mask=val_mask, 64 | **real_data) 65 | self.model_real.save_model(model_path) 66 | 67 | self.real_real_acc = self.model_real.predict( 68 | mask=test_mask, **real_data) 69 | 70 | self.real_sample_acc = [] 71 | self.sample_real_acc = [] 72 | self.sample_sample_acc = [] 73 | 74 | def add_sample(self, 75 | train_mask, 76 | val_mask, 77 | test_mask, 78 | **sample_data): 79 | self.real_sample_acc.append( 80 | self.model_real.predict(mask=test_mask, **sample_data) 81 | ) 82 | 83 | model_sample = self.Trainer() 84 | model_sample.fit(num_classes=self.num_classes, 85 | train_mask=train_mask, 86 | val_mask=val_mask, 87 | **sample_data) 88 | 89 | 
self.sample_real_acc.append( 90 | model_sample.predict(mask=self.test_mask_real, **self.real_data) 91 | ) 92 | 93 | self.sample_sample_acc.append( 94 | model_sample.predict(mask=test_mask, **sample_data) 95 | ) 96 | 97 | def summary(self): 98 | mean_sample_real_acc = np.mean(self.sample_real_acc) 99 | print(f"ACC(G|G_hat) / ACC(G|G): {mean_sample_real_acc / self.real_real_acc}") 100 | 101 | # print(f"ACC/AUC(G|G): {self.real_real_acc}") 102 | # print(f"ACC/AUC(G|G_hat): {mean_sample_real_acc}") 103 | # mean_sample_sample_acc = np.mean(self.sample_sample_acc) 104 | # print(f"ACC/AUC(G_hat|G_hat): {mean_sample_sample_acc}") 105 | # mean_real_sample_acc = np.mean(self.real_sample_acc) 106 | # print(f"ACC/AUC(G_hat|G): {mean_real_sample_acc}") 107 | -------------------------------------------------------------------------------- /model/discriminator/cn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | 6 | class CN(nn.Module): 7 | def __init__(self, batch_size = 65536): 8 | super().__init__() 9 | 10 | self.best_threshold = nn.Parameter(torch.tensor(0), requires_grad=False) 11 | self.batch_size = batch_size 12 | 13 | def fit(self, A_train, A_full, val_mask): 14 | A_train = A_train.to_dense() 15 | A_full = A_full.to_dense() 16 | 17 | val_src, val_dst = val_mask.nonzero().T 18 | label = A_full[val_src, val_dst] 19 | label[label != 0] = 1 20 | label = label.cpu() 21 | 22 | A_train[A_train != 0.] = 1. 
23 | 24 | num_batches = len(val_src) // self.batch_size 25 | if len(val_src) % self.batch_size != 0: 26 | num_batches += 1 27 | 28 | start = 0 29 | pred_list = [] 30 | for i in range(num_batches): 31 | end = start + self.batch_size 32 | 33 | batch_src = val_src[start:end] 34 | batch_dst = val_dst[start:end] 35 | batch_pred = (A_train[batch_src] * A_train[batch_dst]).sum(dim=-1) 36 | batch_pred = batch_pred.cpu() 37 | pred_list.append(batch_pred) 38 | 39 | start = end 40 | 41 | pred = torch.cat(pred_list) 42 | 43 | thresholds = pred.unique() 44 | acc_list = [] 45 | for bar in thresholds: 46 | pred_bar = (pred >= bar).float() 47 | acc_bar = (pred_bar == label).float().mean() 48 | acc_list.append(acc_bar.item()) 49 | 50 | self.best_threshold = nn.Parameter( 51 | thresholds[np.argmax(acc_list)], requires_grad=False) 52 | 53 | def predict(self, A_train, A_full, mask): 54 | A_train = A_train.to_dense() 55 | A_full = A_full.to_dense() 56 | 57 | src, dst = mask.nonzero().T 58 | label = A_full[src, dst] 59 | label[label != 0] = 1 60 | label = label.cpu() 61 | 62 | A_train[A_train != 0.] = 1. 
63 | 64 | num_batches = len(src) // self.batch_size 65 | if len(src) % self.batch_size != 0: 66 | num_batches += 1 67 | 68 | start = 0 69 | pred_list = [] 70 | for i in range(num_batches): 71 | end = start + self.batch_size 72 | 73 | batch_src = src[start:end] 74 | batch_dst = dst[start:end] 75 | batch_pred = (A_train[batch_src] * A_train[batch_dst]).sum(dim=-1) 76 | 77 | batch_pred = batch_pred.cpu() 78 | batch_pred = (batch_pred >= self.best_threshold).float() 79 | pred_list.append(batch_pred) 80 | 81 | start = end 82 | 83 | pred = torch.cat(pred_list) 84 | 85 | return (pred == label).float().mean().item() 86 | 87 | class CNEvaluator: 88 | def __init__(self, 89 | model_path, 90 | A_train, 91 | A_full, 92 | val_mask, 93 | test_mask): 94 | self.real_A_train = A_train 95 | self.real_A_full = A_full 96 | self.real_test_mask = test_mask 97 | 98 | self.sample_sample_acc = [] 99 | 100 | self.model_real = CN() 101 | if os.path.exists(model_path): 102 | self.model_real.load_state_dict(torch.load(model_path)) 103 | else: 104 | self.model_real.fit(A_train, A_full, val_mask) 105 | torch.save(self.model_real.state_dict(), model_path) 106 | 107 | self.real_real_acc = self.model_real.predict(A_train, A_full, test_mask) 108 | 109 | self.real_sample_acc = [] 110 | self.sample_real_acc = [] 111 | self.sample_sample_acc = [] 112 | 113 | def add_sample(self, 114 | A_train, 115 | A_full, 116 | val_mask, 117 | test_mask): 118 | self.real_sample_acc.append( 119 | self.model_real.predict(A_train, A_full, test_mask) 120 | ) 121 | 122 | model_sample = CN() 123 | model_sample.fit(A_train, A_full, val_mask) 124 | 125 | self.sample_real_acc.append( 126 | model_sample.predict( 127 | self.real_A_train, 128 | self.real_A_full, 129 | self.real_test_mask) 130 | ) 131 | 132 | self.sample_sample_acc.append( 133 | model_sample.predict(A_train, A_full, test_mask) 134 | ) 135 | 136 | def summary(self): 137 | mean_sample_real_acc = np.mean(self.sample_real_acc) 138 | print(f"ACC(G|G_hat) / ACC(G|G): 
{mean_sample_real_acc / self.real_real_acc}") 139 | 140 | # print(f"ACC/AUC(G|G): {self.real_real_acc}") 141 | # print(f"ACC/AUC(G|G_hat): {mean_sample_real_acc}") 142 | # mean_sample_sample_acc = np.mean(self.sample_sample_acc) 143 | # print(f"ACC/AUC(G_hat|G_hat): {mean_sample_sample_acc}") 144 | # mean_real_sample_acc = np.mean(self.real_sample_acc) 145 | # print(f"ACC/AUC(G_hat|G): {mean_real_sample_acc}") 146 | -------------------------------------------------------------------------------- /model/discriminator/gae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from copy import deepcopy 7 | from sklearn.metrics import roc_auc_score 8 | from tqdm import tqdm 9 | 10 | from .base import BaseTrainer 11 | from ..common_blocks import GAE 12 | 13 | class GAETrainer(BaseTrainer): 14 | def __init__(self, num_gnn_layers): 15 | hyper_space = { 16 | "lr": [3e-2, 1e-2, 3e-3], 17 | "num_layers": [num_gnn_layers], 18 | "hidden_size": [32, 128, 512], 19 | "dropout": [0., 0.1, 0.2] 20 | } 21 | search_priority_increasing = ["dropout", "lr", "num_layers", "hidden_size"] 22 | 23 | super().__init__(hyper_space=hyper_space, 24 | search_priority_increasing=search_priority_increasing, 25 | patience=5) 26 | 27 | self.num_gnn_layers = num_gnn_layers 28 | 29 | def preprocess(self, A_train, A_full, X, Y): 30 | A_train = A_train.to(self.device) 31 | A_full = A_full.to(self.device) 32 | X = X.to(self.device).float() 33 | Y = Y.to(self.device) 34 | 35 | # row normalize 36 | X = F.normalize(X, p=1, dim=1) 37 | Y = F.one_hot(Y.long(), self.num_classes) 38 | Z = torch.cat([X, Y], dim=1) 39 | 40 | A_full_dense = A_full.to_dense() 41 | A_full_dense[A_full_dense != 0] = 1. 
42 | 43 | return A_train, Z, A_full_dense 44 | 45 | @torch.no_grad() 46 | def predict_fit(self, A, Z, A_dense, mask, model): 47 | model.eval() 48 | Z_out = model(A, Z) 49 | prob = torch.sigmoid(Z_out @ Z_out.T)[mask].cpu().numpy() 50 | label = A_dense[mask].cpu().numpy() 51 | return roc_auc_score(label, prob) 52 | 53 | def fit_trial(self, 54 | A_train, 55 | Z, 56 | A_full_dense, 57 | train_mask, 58 | val_mask, 59 | num_layers, 60 | hidden_size, 61 | dropout, 62 | lr): 63 | 64 | model = GAE(in_size=Z.size(1), 65 | num_layers=num_layers, 66 | hidden_size=hidden_size, 67 | dropout=dropout).to(self.device) 68 | loss_func = nn.BCEWithLogitsLoss() 69 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 70 | 71 | num_epochs = 1000 72 | num_patient_epochs = 0 73 | best_auc = 0 74 | best_model_state_dict = deepcopy(model.state_dict()) 75 | 76 | num_nodes = Z.size(0) 77 | train_dst, train_src = train_mask.nonzero().T 78 | train_size = len(train_dst) 79 | 80 | batch_size = 16384 81 | for epoch in range(1, num_epochs + 1): 82 | model.train() 83 | 84 | Z_out = model(A_train, Z) 85 | 86 | if train_size <= batch_size: 87 | batch_dst = train_dst 88 | batch_src = train_src 89 | else: 90 | batch_ids = torch.randint(low=0, high=train_size, size=(batch_size,), 91 | device=self.device) 92 | batch_dst = train_dst[batch_ids] 93 | batch_src = train_src[batch_ids] 94 | 95 | pos_pred = (Z_out[batch_src] * Z_out[batch_dst]).sum(dim=-1) 96 | 97 | real_batch_size = len(batch_dst) 98 | neg_src = torch.randint(0, num_nodes, (real_batch_size,), 99 | device=self.device) 100 | neg_dst = torch.randint(0, num_nodes, (real_batch_size,), 101 | device=self.device) 102 | neg_pred = (Z_out[neg_src] * Z_out[neg_dst]).sum(dim=-1) 103 | 104 | pred = torch.cat([pos_pred, neg_pred], dim=0) 105 | label = torch.cat([torch.ones(real_batch_size), 106 | torch.zeros(real_batch_size)], dim=0).to(self.device) 107 | loss = loss_func(pred, label) 108 | optimizer.zero_grad() 109 | loss.backward() 110 | 
optimizer.step() 111 | 112 | auc = self.predict_fit(A_train, Z, A_full_dense, val_mask, model) 113 | 114 | if auc > best_auc: 115 | num_patient_epochs = 0 116 | best_auc = auc 117 | best_model_state_dict = deepcopy(model.state_dict()) 118 | else: 119 | num_patient_epochs += 1 120 | 121 | if num_patient_epochs == self.patience: 122 | break 123 | 124 | model.load_state_dict(best_model_state_dict) 125 | return best_auc, model 126 | 127 | def fit(self, A_train, A_full, X, Y, num_classes, 128 | train_mask, val_mask): 129 | """ 130 | Parameters 131 | ---------- 132 | A_train : dgl.sparse.SparseMatrix 133 | Training adjacency matrix. 134 | A_full : dgl.sparse.SparseMatrix 135 | Full adjacency matrix. 136 | X : torch.Tensor of shape (|V|, D) 137 | Binary node features. 138 | Y : torch.Tensor of shape (|V|,) 139 | Node labels. 140 | num_classes : int 141 | Number of node classes. 142 | train_mask : torch.Tensor of shape (|V|, |V|) 143 | Mask indicating training edges. 144 | val_mask : torch.Tensor of shape (|V|, |V|) 145 | Mask indicating validation edges. 146 | """ 147 | self.num_classes = num_classes 148 | A_train, Z, A_full_dense = self.preprocess( 149 | A_train, A_full, X, Y) 150 | 151 | config_list = self.get_config_list() 152 | 153 | best_auc = 0 154 | with tqdm(config_list) as tconfig: 155 | tconfig.set_description( 156 | f"Training GAE {self.num_gnn_layers}-layer discriminator") 157 | 158 | for config in tconfig: 159 | trial_auc, trial_model = self.fit_trial(A_train, 160 | Z, 161 | A_full_dense, 162 | train_mask, 163 | val_mask, 164 | **config) 165 | 166 | if trial_auc > best_auc: 167 | best_auc = trial_auc 168 | best_model = trial_model 169 | best_model_config = { 170 | "in_size": Z.size(1), 171 | "num_layers": config["num_layers"], 172 | "hidden_size": config["hidden_size"], 173 | "dropout": config["dropout"] 174 | } 175 | 176 | tconfig.set_postfix(roc_auc=100. 
* best_auc) 177 | 178 | if trial_auc == 1.0: 179 | break 180 | self.model = best_model 181 | self.best_model_config = best_model_config 182 | 183 | @torch.no_grad() 184 | def predict(self, A_train, A_full, X, Y, mask): 185 | A_train, Z, A_full_dense = self.preprocess( 186 | A_train, A_full, X, Y) 187 | 188 | model = self.model 189 | model.eval() 190 | Z_out = model(A_train, Z) 191 | prob = torch.sigmoid(Z_out @ Z_out.T)[mask].cpu().numpy() 192 | label = A_full_dense[mask].cpu().numpy() 193 | return roc_auc_score(label, prob) 194 | 195 | def save_model(self, model_path): 196 | torch.save({ 197 | "model_state_dict": self.model.state_dict(), 198 | "model_config": self.best_model_config, 199 | "num_classes": self.num_classes 200 | }, model_path) 201 | 202 | def load_model(self, model_path): 203 | state_dict = torch.load(model_path, map_location=self.device) 204 | model = GAE(**state_dict["model_config"]).to(self.device) 205 | model.load_state_dict(state_dict["model_state_dict"]) 206 | self.model = model 207 | self.num_classes = state_dict["num_classes"] 208 | 209 | def summary(self): 210 | mean_sample_real_acc = np.mean(self.sample_real_acc) 211 | print(f"AUC(G|G_hat) / AUC(G|G): {mean_sample_real_acc / self.real_real_acc}") 212 | -------------------------------------------------------------------------------- /model/discriminator/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from copy import deepcopy 6 | from tqdm import tqdm 7 | 8 | from .base import BaseTrainer 9 | from ..common_blocks import GCN 10 | 11 | class GCNTrainer(BaseTrainer): 12 | def __init__(self, 13 | num_gnn_layers=None, 14 | hyper_space=None, 15 | search_priority_increasing=None, 16 | patience=5): 17 | 18 | if hyper_space is None: 19 | hyper_space = { 20 | "lr": [3e-2, 1e-2, 3e-3], 21 | "num_layers": [num_gnn_layers], 22 | "hidden_size": [32, 128, 512], 23 | "dropout": [0., 0.1, 0.2] 24 | } 25 | 26 
| if search_priority_increasing is None: 27 | search_priority_increasing = ["dropout", "lr", "num_layers", "hidden_size"] 28 | 29 | super().__init__(hyper_space=hyper_space, 30 | search_priority_increasing=search_priority_increasing, 31 | patience=patience) 32 | 33 | self.num_gnn_layers = num_gnn_layers 34 | 35 | def preprocess(self, A, X, Y): 36 | A = A.to(self.device) 37 | X = X.to(self.device).float() 38 | Y = Y.to(self.device) 39 | 40 | # row normalize 41 | X = F.normalize(X, p=1, dim=1) 42 | 43 | return A, X, Y 44 | 45 | def fit_trial(self, 46 | A, 47 | X, 48 | Y, 49 | num_classes, 50 | train_mask, 51 | val_mask, 52 | num_layers, 53 | hidden_size, 54 | dropout, 55 | lr): 56 | model = GCN(in_size=X.size(1), 57 | out_size=num_classes, 58 | num_layers=num_layers, 59 | hidden_size=hidden_size, 60 | dropout=dropout).to(self.device) 61 | loss_func = nn.CrossEntropyLoss() 62 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 63 | 64 | num_epochs = 1000 65 | num_patient_epochs = 0 66 | best_acc = 0 67 | best_model_state_dict = deepcopy(model.state_dict()) 68 | for epoch in range(1, num_epochs + 1): 69 | model.train() 70 | logits = model(A, X) 71 | loss = loss_func(logits[train_mask], Y[train_mask]) 72 | optimizer.zero_grad() 73 | loss.backward() 74 | optimizer.step() 75 | 76 | acc = self.predict(A, X, Y, val_mask, model) 77 | 78 | if acc > best_acc: 79 | num_patient_epochs = 0 80 | best_acc = acc 81 | best_model_state_dict = deepcopy(model.state_dict()) 82 | else: 83 | num_patient_epochs += 1 84 | 85 | if num_patient_epochs == self.patience: 86 | break 87 | 88 | model.load_state_dict(best_model_state_dict) 89 | return best_acc, model 90 | 91 | def fit(self, A, X, Y, num_classes, train_mask, val_mask): 92 | """ 93 | Parameters 94 | ---------- 95 | A : dgl.sparse.SparseMatrix 96 | Adjacency matrix. 97 | X : torch.Tensor of shape (|V|, D) 98 | Binary node features. 99 | Y : torch.Tensor of shape (|V|,) 100 | Node labels. 
101 | num_classes : int 102 | Number of node classes. 103 | train_mask : torch.Tensor of shape (|V|) 104 | Mask indicating training nodes. 105 | val_mask : torch.Tensor of shape (|V|) 106 | Mask indicating validation nodes. 107 | """ 108 | A, X, Y = self.preprocess(A, X, Y) 109 | 110 | config_list = self.get_config_list() 111 | 112 | best_acc = 0 113 | with tqdm(config_list) as tconfig: 114 | tconfig.set_description( 115 | f"Training GCN {self.num_gnn_layers}-layer discriminator") 116 | 117 | for config in tconfig: 118 | trial_acc, trial_model = self.fit_trial(A, 119 | X, 120 | Y, 121 | num_classes, 122 | train_mask, 123 | val_mask, 124 | **config) 125 | 126 | if trial_acc > best_acc: 127 | best_acc = trial_acc 128 | best_model = trial_model 129 | best_model_config = { 130 | "in_size": X.size(1), 131 | "out_size": num_classes, 132 | "num_layers": config["num_layers"], 133 | "hidden_size": config["hidden_size"], 134 | "dropout": config["dropout"] 135 | } 136 | 137 | tconfig.set_postfix(accuracy=100. 
* best_acc) 138 | 139 | if trial_acc == 1.0: 140 | break 141 | self.model = best_model 142 | self.best_model_config = best_model_config 143 | 144 | @torch.no_grad() 145 | def predict(self, A, X, Y, mask, model=None): 146 | A, X, Y = self.preprocess(A, X, Y) 147 | 148 | if model is None: 149 | model = self.model 150 | 151 | model.eval() 152 | logits = model(A, X)[mask] 153 | pred = logits.argmax(dim=-1, keepdim=True).reshape(-1) 154 | return (pred == Y[mask]).float().mean().item() 155 | 156 | def load_model(self, model_path): 157 | state_dict = torch.load(model_path, map_location=self.device) 158 | model = GCN(**state_dict["model_config"]).to(self.device) 159 | model.load_state_dict(state_dict["model_state_dict"]) 160 | self.model = model 161 | -------------------------------------------------------------------------------- /model/discriminator/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from copy import deepcopy 6 | from tqdm import tqdm 7 | 8 | from .base import BaseTrainer 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, 12 | in_size, 13 | out_size, 14 | num_layers, 15 | hidden_size, 16 | dropout): 17 | super().__init__() 18 | 19 | assert num_layers >= 2 20 | self.lins = nn.ModuleList() 21 | self.lins.append(nn.Linear(in_size, hidden_size)) 22 | for _ in range(num_layers - 2): 23 | self.lins.append(nn.Linear(hidden_size, hidden_size)) 24 | self.lins.append(nn.Linear(hidden_size, out_size)) 25 | 26 | self.dropout = dropout 27 | 28 | def forward(self, h): 29 | for lin in self.lins[:-1]: 30 | h = lin(h) 31 | h = F.relu(h) 32 | h = F.dropout(h, p=self.dropout, training=self.training) 33 | return self.lins[-1](h) 34 | 35 | class MLPTrainer(BaseTrainer): 36 | def __init__(self): 37 | hyper_space = { 38 | "lr": [3e-2, 1e-2, 3e-3], 39 | "num_layers": [2, 3], 40 | "hidden_size": [32, 128, 512], 41 | "dropout": [0., 0.1, 0.2] 42 | } 43 | 
search_priority_increasing = ["dropout", "lr", "num_layers", "hidden_size"] 44 | 45 | super().__init__(hyper_space=hyper_space, 46 | search_priority_increasing=search_priority_increasing, 47 | patience=5) 48 | 49 | def preprocess(self, X, Y): 50 | X = X.to(self.device).float() 51 | Y = Y.to(self.device) 52 | 53 | # row normalize 54 | X = F.normalize(X, p=1, dim=1) 55 | 56 | return X, Y 57 | 58 | def fit_trial(self, 59 | X, 60 | Y, 61 | num_classes, 62 | train_mask, 63 | val_mask, 64 | num_layers, 65 | hidden_size, 66 | dropout, 67 | lr): 68 | model = MLP(in_size=X.size(1), 69 | out_size=num_classes, 70 | num_layers=num_layers, 71 | hidden_size=hidden_size, 72 | dropout=dropout).to(self.device) 73 | loss_func = nn.CrossEntropyLoss() 74 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 75 | 76 | num_epochs = 1000 77 | num_patient_epochs = 0 78 | best_acc = 0 79 | best_model_state_dict = deepcopy(model.state_dict()) 80 | for epoch in range(1, num_epochs + 1): 81 | model.train() 82 | logits = model(X) 83 | loss = loss_func(logits[train_mask], Y[train_mask]) 84 | optimizer.zero_grad() 85 | loss.backward() 86 | optimizer.step() 87 | 88 | acc = self.predict(X, Y, val_mask, model) 89 | 90 | if acc > best_acc: 91 | num_patient_epochs = 0 92 | best_acc = acc 93 | best_model_state_dict = deepcopy(model.state_dict()) 94 | else: 95 | num_patient_epochs += 1 96 | 97 | if num_patient_epochs == self.patience: 98 | break 99 | 100 | model.load_state_dict(best_model_state_dict) 101 | return best_acc, model 102 | 103 | def fit(self, X, Y, num_classes, train_mask, val_mask): 104 | """ 105 | Parameters 106 | ---------- 107 | X : torch.Tensor of shape (|V|, D) 108 | Binary node features. 109 | Y : torch.Tensor of shape (|V|,) 110 | Node labels. 111 | num_classes : int 112 | Number of node classes. 113 | train_mask : torch.Tensor of shape (|V|) 114 | Mask indicating training nodes. 115 | val_mask : torch.Tensor of shape (|V|) 116 | Mask indicating validation nodes. 
117 | """ 118 | X, Y = self.preprocess(X, Y) 119 | 120 | config_list = self.get_config_list() 121 | 122 | best_acc = 0 123 | with tqdm(config_list) as tconfig: 124 | tconfig.set_description("Training MLP discriminator") 125 | 126 | for config in tconfig: 127 | trial_acc, trial_model = self.fit_trial(X, 128 | Y, 129 | num_classes, 130 | train_mask, 131 | val_mask, 132 | **config) 133 | 134 | if trial_acc > best_acc: 135 | best_acc = trial_acc 136 | best_model = trial_model 137 | best_model_config = { 138 | "in_size": X.size(1), 139 | "out_size": num_classes, 140 | "num_layers": config["num_layers"], 141 | "hidden_size": config["hidden_size"], 142 | "dropout": config["dropout"] 143 | } 144 | 145 | tconfig.set_postfix(accuracy=100. * best_acc) 146 | 147 | if trial_acc == 1.0: 148 | break 149 | self.model = best_model 150 | self.best_model_config = best_model_config 151 | 152 | @torch.no_grad() 153 | def predict(self, X, Y, mask=None, model=None): 154 | X, Y = self.preprocess(X, Y) 155 | 156 | if model is None: 157 | model = self.model 158 | 159 | model.eval() 160 | 161 | if mask is None: 162 | logits = model(X) 163 | pred = logits.argmax(dim=-1, keepdim=True).reshape(-1) 164 | return (pred == Y).float().mean().item() 165 | else: 166 | logits = model(X[mask]) 167 | pred = logits.argmax(dim=-1, keepdim=True).reshape(-1) 168 | return (pred == Y[mask]).float().mean().item() 169 | 170 | def load_model(self, model_path): 171 | state_dict = torch.load(model_path, map_location=self.device) 172 | model = MLP(**state_dict["model_config"]).to(self.device) 173 | model.load_state_dict(state_dict["model_state_dict"]) 174 | self.model = model 175 | -------------------------------------------------------------------------------- /model/discriminator/sgc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | 7 | from .gcn import GCNTrainer 8 | 9 | class SGC(nn.Module): 10 | 
def __init__(self, 11 | in_size, 12 | out_size, 13 | num_layers): 14 | super().__init__() 15 | 16 | self.lin = nn.Linear(in_size, out_size) 17 | self.num_layers = num_layers 18 | 19 | def forward(self, A, H): 20 | for _ in range(self.num_layers): 21 | H = A @ H 22 | return self.lin(H) 23 | 24 | class SGCTrainer(GCNTrainer): 25 | def __init__(self, num_gnn_layers): 26 | hyper_space = { 27 | "lr": [3e-2, 1e-2, 3e-3], 28 | "num_layers": [num_gnn_layers] 29 | } 30 | search_priority_increasing = ["lr", "num_layers"] 31 | 32 | super().__init__(hyper_space=hyper_space, 33 | search_priority_increasing=search_priority_increasing, 34 | patience=5) 35 | 36 | self.num_gnn_layers = num_gnn_layers 37 | 38 | def fit_trial(self, 39 | A, 40 | X, 41 | Y, 42 | num_classes, 43 | train_mask, 44 | val_mask, 45 | num_layers, 46 | lr): 47 | model = SGC(in_size=X.size(1), 48 | out_size=num_classes, 49 | num_layers=num_layers).to(self.device) 50 | loss_func = nn.CrossEntropyLoss() 51 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 52 | 53 | num_epochs = 1000 54 | num_patient_epochs = 0 55 | best_acc = 0 56 | best_model_state_dict = deepcopy(model.state_dict()) 57 | for epoch in range(1, num_epochs + 1): 58 | model.train() 59 | logits = model(A, X) 60 | loss = loss_func(logits[train_mask], Y[train_mask]) 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | 65 | acc = self.predict(A, X, Y, val_mask, model) 66 | 67 | if acc > best_acc: 68 | num_patient_epochs = 0 69 | best_acc = acc 70 | best_model_state_dict = deepcopy(model.state_dict()) 71 | else: 72 | num_patient_epochs += 1 73 | 74 | if num_patient_epochs == self.patience: 75 | break 76 | 77 | model.load_state_dict(best_model_state_dict) 78 | 79 | return best_acc, model 80 | 81 | def fit(self, A, X, Y, num_classes, train_mask, val_mask): 82 | """ 83 | Parameters 84 | ---------- 85 | A : dgl.sparse.SparseMatrix 86 | Adjacency matrix. 87 | X : torch.Tensor of shape (|V|, D) 88 | Binary node features. 
89 | Y : torch.Tensor of shape (|V|,) 90 | Node labels. 91 | num_classes : int 92 | Number of node classes. 93 | train_mask : torch.Tensor of shape (|V|) 94 | Mask indicating training nodes. 95 | val_mask : torch.Tensor of shape (|V|) 96 | Mask indicating validation nodes. 97 | """ 98 | A, X, Y = self.preprocess(A, X, Y) 99 | 100 | config_list = self.get_config_list() 101 | 102 | best_acc = 0 103 | with tqdm(config_list) as tconfig: 104 | tconfig.set_description( 105 | f"Training SGC {self.num_gnn_layers}-layer discriminator") 106 | 107 | for config in tconfig: 108 | trial_acc, trial_model = self.fit_trial(A, 109 | X, 110 | Y, 111 | num_classes, 112 | train_mask, 113 | val_mask, 114 | **config) 115 | 116 | if trial_acc > best_acc: 117 | best_acc = trial_acc 118 | best_model = trial_model 119 | best_model_config = { 120 | "in_size": X.size(1), 121 | "out_size": num_classes, 122 | "num_layers": config["num_layers"], 123 | } 124 | 125 | tconfig.set_postfix(accuracy=100. * best_acc) 126 | 127 | if trial_acc == 1.0: 128 | break 129 | self.model = best_model 130 | self.best_model_config = best_model_config 131 | 132 | def load_model(self, model_path): 133 | state_dict = torch.load(model_path, map_location=self.device) 134 | model = SGC(**state_dict["model_config"]).to(self.device) 135 | model.load_state_dict(state_dict["model_state_dict"]) 136 | self.model = model 137 | -------------------------------------------------------------------------------- /model/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ["GNN", "LinkPredictor", "GNNAsymm"] 5 | 6 | class GNNLayer(nn.Module): 7 | """Graph Neural Network (GNN) / Message Passing Neural Network (MPNN) Layer. 8 | 9 | Parameters 10 | ---------- 11 | hidden_X : int 12 | Hidden size for the node attributes. 13 | hidden_Y : int 14 | Hidden size for the node label. 15 | hidden_t : int 16 | Hidden size for the normalized time step. 
17 | dropout : float 18 | Dropout rate. 19 | """ 20 | def __init__(self, 21 | hidden_X, 22 | hidden_Y, 23 | hidden_t, 24 | dropout): 25 | super().__init__() 26 | 27 | self.update_X = nn.Sequential( 28 | nn.Linear(hidden_X + hidden_Y + hidden_t, hidden_X), 29 | nn.ReLU(), 30 | nn.LayerNorm(hidden_X), 31 | nn.Dropout(dropout) 32 | ) 33 | self.update_Y = nn.Sequential( 34 | nn.Linear(hidden_Y, hidden_Y), 35 | nn.ReLU(), 36 | nn.LayerNorm(hidden_Y), 37 | nn.Dropout(dropout) 38 | ) 39 | 40 | def forward(self, A, h_X, h_Y, h_t): 41 | """ 42 | Parameters 43 | ---------- 44 | A : dglsp.SparseMatrix 45 | Adjacency matrix. 46 | h_X : torch.Tensor of shape (|V|, hidden_X) 47 | Hidden representations for the node attributes. 48 | h_Y : torch.Tensor of shape (|V|, hidden_Y) 49 | Hidden representations for the node label. 50 | h_t : torch.Tensor of shape (|V|, hidden_t) 51 | Hidden representations for the normalized time step. 52 | 53 | Returns 54 | ------- 55 | h_X : torch.Tensor of shape (|V|, hidden_X) 56 | Updated hidden representations for the node attributes. 57 | h_Y : torch.Tensor of shape (|V|, hidden_Y) 58 | Updated hidden representations for the node label. 59 | """ 60 | h_aggr_X = A @ torch.cat([h_X, h_Y], dim=1) 61 | h_aggr_Y = A @ h_Y 62 | 63 | num_nodes = h_X.size(0) 64 | h_t_expand = h_t.expand(num_nodes, -1) 65 | h_aggr_X = torch.cat([h_aggr_X, h_t_expand], dim=1) 66 | 67 | h_X = self.update_X(h_aggr_X) 68 | h_Y = self.update_Y(h_aggr_Y) 69 | 70 | return h_X, h_Y 71 | 72 | class GNNTower(nn.Module): 73 | """Graph Neural Network (GNN) / Message Passing Neural Network (MPNN). 74 | 75 | Parameters 76 | ---------- 77 | num_attrs_X : int 78 | Number of node attributes. 79 | num_classes_X : int 80 | Number of classes for each node attribute. 81 | num_classes_Y : int 82 | Number of classes for node label. 83 | hidden_t : int 84 | Hidden size for the normalized time step. 85 | hidden_X : int 86 | Hidden size for the node attributes. 
87 | hidden_Y : int 88 | Hidden size for the node label. 89 | out_size : int 90 | Output size of the final MLP layer. 91 | num_gnn_layers : int 92 | Number of GNN/MPNN layers. 93 | dropout : float 94 | Dropout rate. 95 | node_mode : bool 96 | Whether the encoder is used for node attribute prediction or structure 97 | prediction. 98 | """ 99 | def __init__(self, 100 | num_attrs_X, 101 | num_classes_X, 102 | num_classes_Y, 103 | hidden_t, 104 | hidden_X, 105 | hidden_Y, 106 | out_size, 107 | num_gnn_layers, 108 | dropout, 109 | node_mode): 110 | super().__init__() 111 | 112 | in_X = num_attrs_X * num_classes_X 113 | self.num_attrs_X = num_attrs_X 114 | self.num_classes_X = num_classes_X 115 | 116 | self.mlp_in_t = nn.Sequential( 117 | nn.Linear(1, hidden_t), 118 | nn.ReLU(), 119 | nn.Linear(hidden_t, hidden_t), 120 | nn.ReLU()) 121 | self.mlp_in_X = nn.Sequential( 122 | nn.Linear(in_X, hidden_X), 123 | nn.ReLU(), 124 | nn.Linear(hidden_X, hidden_X), 125 | nn.ReLU() 126 | ) 127 | self.emb_Y = nn.Embedding(num_classes_Y, hidden_Y) 128 | 129 | self.gnn_layers = nn.ModuleList([ 130 | GNNLayer(hidden_X, 131 | hidden_Y, 132 | hidden_t, 133 | dropout) 134 | for _ in range(num_gnn_layers)]) 135 | 136 | # +1 for the input attributes 137 | hidden_cat = (num_gnn_layers + 1) * (hidden_X + hidden_Y) + hidden_t 138 | self.mlp_out = nn.Sequential( 139 | nn.Linear(hidden_cat, hidden_cat), 140 | nn.ReLU(), 141 | nn.Linear(hidden_cat, out_size) 142 | ) 143 | 144 | self.node_mode = node_mode 145 | 146 | def forward(self, 147 | t_float, 148 | X_t_one_hot, 149 | Y_real, 150 | A_t): 151 | # Input projection. 
152 | # (1, hidden_t) 153 | h_t = self.mlp_in_t(t_float).unsqueeze(0) 154 | h_X = self.mlp_in_X(X_t_one_hot) 155 | h_Y = self.emb_Y(Y_real) 156 | 157 | h_X_list = [h_X] 158 | h_Y_list = [h_Y] 159 | for gnn in self.gnn_layers: 160 | h_X, h_Y = gnn(A_t, h_X, h_Y, h_t) 161 | h_X_list.append(h_X) 162 | h_Y_list.append(h_Y) 163 | 164 | # (|V|, hidden_t) 165 | h_t = h_t.expand(h_X.size(0), -1) 166 | h_cat = torch.cat(h_X_list + h_Y_list + [h_t], dim=1) 167 | 168 | if self.node_mode: 169 | # (|V|, F * C_X) 170 | logit = self.mlp_out(h_cat) 171 | # (|V|, F, C_X) 172 | logit = logit.reshape(Y_real.size(0), self.num_attrs_X, -1) 173 | 174 | return logit 175 | else: 176 | return self.mlp_out(h_cat) 177 | 178 | class LinkPredictor(nn.Module): 179 | """Model for structure prediction. 180 | 181 | Parameters 182 | ---------- 183 | num_attrs_X : int 184 | Number of node attributes. 185 | num_classes_X : int 186 | Number of classes for each node attribute. 187 | num_classes_Y : int 188 | Number of classes for node label. 189 | num_classes_E : int 190 | Number of edge classes. 191 | hidden_t : int 192 | Hidden size for the normalized time step. 193 | hidden_X : int 194 | Hidden size for the node attributes. 195 | hidden_Y : int 196 | Hidden size for the node label. 197 | hidden_E : int 198 | Hidden size for the edges. 199 | num_gnn_layers : int 200 | Number of GNN/MPNN layers. 201 | dropout : float 202 | Dropout rate. 
203 | """ 204 | def __init__(self, 205 | num_attrs_X, 206 | num_classes_X, 207 | num_classes_Y, 208 | num_classes_E, 209 | hidden_t, 210 | hidden_X, 211 | hidden_Y, 212 | hidden_E, 213 | num_gnn_layers, 214 | dropout): 215 | super().__init__() 216 | 217 | self.gnn_encoder = GNNTower(num_attrs_X, 218 | num_classes_X, 219 | num_classes_Y, 220 | hidden_t, 221 | hidden_X, 222 | hidden_Y, 223 | hidden_E, 224 | num_gnn_layers, 225 | dropout, 226 | node_mode=False) 227 | self.mlp_out = nn.Sequential( 228 | nn.Linear(hidden_E, hidden_E), 229 | nn.ReLU(), 230 | nn.Linear(hidden_E, num_classes_E) 231 | ) 232 | 233 | def forward(self, 234 | t_float, 235 | X_t_one_hot, 236 | Y_real, 237 | A_t, 238 | src, 239 | dst): 240 | # (|V|, hidden_E) 241 | h = self.gnn_encoder(t_float, 242 | X_t_one_hot, 243 | Y_real, 244 | A_t) 245 | # (|E|, hidden_E) 246 | h = h[src] * h[dst] 247 | # (|E|, num_classes_E) 248 | logit = self.mlp_out(h) 249 | 250 | return logit 251 | 252 | class GNN(nn.Module): 253 | """P(X|Y, X^t, A^t) + P(A|Y, X^t, A^t) 254 | 255 | Parameters 256 | ---------- 257 | num_attrs_X : int 258 | Number of node attributes. 259 | num_classes_X : int 260 | Number of classes for each node attribute. 261 | num_classes_Y : int 262 | Number of classes for node label. 263 | num_classes_E : int 264 | Number of edge classes. 265 | gnn_X_config : dict 266 | Configuration of the GNN for reconstructing node attributes. 267 | gnn_E_config : dict 268 | Configuration of the GNN for reconstructing edges. 
269 | """ 270 | def __init__(self, 271 | num_attrs_X, 272 | num_classes_X, 273 | num_classes_Y, 274 | num_classes_E, 275 | gnn_X_config, 276 | gnn_E_config): 277 | super().__init__() 278 | 279 | self.pred_X = GNNTower(num_attrs_X, 280 | num_classes_X, 281 | num_classes_Y, 282 | out_size=num_attrs_X * num_classes_X, 283 | node_mode=True, 284 | **gnn_X_config) 285 | 286 | self.pred_E = LinkPredictor(num_attrs_X, 287 | num_classes_X, 288 | num_classes_Y, 289 | num_classes_E, 290 | **gnn_E_config) 291 | 292 | def forward(self, 293 | t_float, 294 | X_t_one_hot, 295 | Y, 296 | A_t, 297 | batch_src, 298 | batch_dst): 299 | """ 300 | Parameters 301 | ---------- 302 | t_float : torch.Tensor of shape (1) 303 | Sampled timestep divided by self.T. 304 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 305 | One-hot encoding of the sampled node attributes. 306 | Y : torch.Tensor of shape (|V|) 307 | Categorical node labels. 308 | A_t : dglsp.SparseMatrix 309 | Row-normalized sampled adjacency matrix. 310 | batch_src : torch.LongTensor of shape (B) 311 | Source node IDs for a batch of candidate edges (node pairs). 312 | batch_dst : torch.LongTensor of shape (B) 313 | Destination node IDs for a batch of candidate edges (node pairs). 314 | 315 | Returns 316 | ------- 317 | logit_X : torch.Tensor of shape (|V|, F, 2) 318 | Predicted logits for the node attributes. 319 | logit_E : torch.Tensor of shape (B, 2) 320 | Predicted logits for the edge existence. 321 | """ 322 | logit_X = self.pred_X(t_float, 323 | X_t_one_hot, 324 | Y, 325 | A_t) 326 | 327 | logit_E = self.pred_E(t_float, 328 | X_t_one_hot, 329 | Y, 330 | A_t, 331 | batch_src, 332 | batch_dst) 333 | 334 | return logit_X, logit_E 335 | 336 | class MLPLayer(nn.Module): 337 | """ 338 | Parameters 339 | ---------- 340 | hidden_X : int 341 | Hidden size for the node attributes. 342 | hidden_Y : int 343 | Hidden size for the node labels. 344 | hidden_t : int 345 | Hidden size for the normalized time step. 
346 | dropout : float 347 | Dropout rate. 348 | """ 349 | def __init__(self, 350 | hidden_X, 351 | hidden_Y, 352 | hidden_t, 353 | dropout): 354 | super().__init__() 355 | 356 | self.update_X = nn.Sequential( 357 | nn.Linear(hidden_X + hidden_Y + hidden_t, hidden_X), 358 | nn.ReLU(), 359 | nn.LayerNorm(hidden_X), 360 | nn.Dropout(dropout) 361 | ) 362 | self.update_Y = nn.Sequential( 363 | nn.Linear(hidden_Y, hidden_Y), 364 | nn.ReLU(), 365 | nn.LayerNorm(hidden_Y), 366 | nn.Dropout(dropout) 367 | ) 368 | 369 | def forward(self, h_X, h_Y, h_t): 370 | """ 371 | Parameters 372 | ---------- 373 | h_X : torch.Tensor of shape (|V|, hidden_X) 374 | Hidden representations for the node attributes. 375 | h_Y : torch.Tensor of shape (|V|, hidden_Y) 376 | Hidden representations for the node labels. 377 | h_t : torch.Tensor of shape (1, hidden_t) 378 | Hidden representations for the normalized time step. 379 | 380 | Returns 381 | ------- 382 | h_X : torch.Tensor of shape (|V|, hidden_X) 383 | Updated hidden representations for the node attributes. 384 | h_Y : torch.Tensor of shape (|V|, hidden_Y) 385 | Updated hidden representations for the node labels. 
386 | """ 387 | num_nodes = h_X.size(0) 388 | h_t_expand = h_t.expand(num_nodes, -1) 389 | h_X = torch.cat([h_X, h_Y, h_t_expand], dim=1) 390 | 391 | h_X = self.update_X(h_X) 392 | h_Y = self.update_Y(h_Y) 393 | 394 | return h_X, h_Y 395 | 396 | class MLPTower(nn.Module): 397 | def __init__(self, 398 | num_attrs_X, 399 | num_classes_X, 400 | num_classes_Y, 401 | hidden_t, 402 | hidden_X, 403 | hidden_Y, 404 | num_mlp_layers, 405 | dropout): 406 | super().__init__() 407 | 408 | in_X = num_attrs_X * num_classes_X 409 | self.num_attrs_X = num_attrs_X 410 | self.num_classes_X = num_classes_X 411 | 412 | self.mlp_in_t = nn.Sequential( 413 | nn.Linear(1, hidden_t), 414 | nn.ReLU(), 415 | nn.Linear(hidden_t, hidden_t), 416 | nn.ReLU()) 417 | self.mlp_in_X = nn.Sequential( 418 | nn.Linear(in_X, hidden_X), 419 | nn.ReLU(), 420 | nn.Linear(hidden_X, hidden_X), 421 | nn.ReLU() 422 | ) 423 | self.emb_Y = nn.Embedding(num_classes_Y, hidden_Y) 424 | 425 | self.mlp_layers = nn.ModuleList([ 426 | MLPLayer(hidden_X, 427 | hidden_Y, 428 | hidden_t, 429 | dropout) 430 | for _ in range(num_mlp_layers)]) 431 | 432 | # +1 for the input features 433 | hidden_cat = (num_mlp_layers + 1) * (hidden_X + hidden_Y) + hidden_t 434 | self.mlp_out = nn.Sequential( 435 | nn.Linear(hidden_cat, hidden_cat), 436 | nn.ReLU(), 437 | nn.Linear(hidden_cat, in_X) 438 | ) 439 | 440 | def forward(self, 441 | t_float, 442 | X_t_one_hot, 443 | Y_real): 444 | # Input projection. 
445 | h_t = self.mlp_in_t(t_float).unsqueeze(0) 446 | h_X = self.mlp_in_X(X_t_one_hot) 447 | h_Y = self.emb_Y(Y_real) 448 | 449 | h_X_list = [h_X] 450 | h_Y_list = [h_Y] 451 | for mlp in self.mlp_layers: 452 | h_X, h_Y = mlp(h_X, h_Y, h_t) 453 | h_X_list.append(h_X) 454 | h_Y_list.append(h_Y) 455 | 456 | h_t = h_t.expand(h_X.size(0), -1) 457 | h_cat = torch.cat(h_X_list + h_Y_list + [h_t], dim=1) 458 | 459 | logit = self.mlp_out(h_cat) 460 | # (|V|, F, C) 461 | logit = logit.reshape(Y_real.size(0), self.num_attrs_X, -1) 462 | 463 | return logit 464 | 465 | class GNNAsymm(nn.Module): 466 | """P(X|Y, X_t) + P(A|Y, X, A_t) 467 | 468 | Parameters 469 | ---------- 470 | num_attrs_X : int 471 | Number of node attributes. 472 | num_classes_X : int 473 | Number of classes for each node attribute. 474 | num_classes_Y : int 475 | Number of classes for node label. 476 | num_classes_E : int 477 | Number of edge classes. 478 | mlp_X_config : dict 479 | Configuration of the MLP for reconstructing node attributes. 480 | gnn_E_config : dict 481 | Configuration of the GNN for reconstructing edges. 482 | """ 483 | def __init__(self, 484 | num_attrs_X, 485 | num_classes_X, 486 | num_classes_Y, 487 | num_classes_E, 488 | mlp_X_config, 489 | gnn_E_config): 490 | super().__init__() 491 | 492 | self.pred_X = MLPTower(num_attrs_X, 493 | num_classes_X, 494 | num_classes_Y, 495 | **mlp_X_config) 496 | 497 | self.pred_E = LinkPredictor(num_attrs_X, 498 | num_classes_X, 499 | num_classes_Y, 500 | num_classes_E, 501 | **gnn_E_config) 502 | 503 | def forward(self, 504 | t_float_X, 505 | t_float_E, 506 | X_t_one_hot, 507 | Y, 508 | X_one_hot_2d, 509 | A_t, 510 | batch_src, 511 | batch_dst): 512 | """ 513 | Parameters 514 | ---------- 515 | t_float_X : torch.Tensor of shape (1) 516 | Sampled timestep divided by self.T_X. 517 | t_float_E : torch.Tensor of shape (1) 518 | Sampled timestep divided by self.T_E. 
519 | X_t_one_hot : torch.Tensor of shape (|V|, 2 * F) 520 | One-hot encoding of the sampled node attributes. 521 | Y : torch.Tensor of shape (|V|) 522 | Categorical node labels. 523 | X_one_hot_2d : torch.Tensor of shape (|V|, 2 * F) 524 | Flattened one-hot encoding of the node attributes. 525 | A_t : dglsp.SparseMatrix 526 | Row-normalized sampled adjacency matrix. 527 | batch_src : torch.LongTensor of shape (B) 528 | Source node IDs for a batch of candidate edges (node pairs). 529 | batch_dst : torch.LongTensor of shape (B) 530 | Destination node IDs for a batch of candidate edges (node pairs). 531 | 532 | Returns 533 | ------- 534 | logit_X : torch.Tensor of shape (|V|, F, 2) 535 | Predicted logits for the node attributes. 536 | logit_E : torch.Tensor of shape (B, 2) 537 | Predicted logits for the edge existence. 538 | """ 539 | logit_X = self.pred_X(t_float_X, 540 | X_t_one_hot, 541 | Y) 542 | 543 | logit_E = self.pred_E(t_float_E, 544 | X_one_hot_2d, 545 | Y, 546 | A_t, 547 | batch_src, 548 | batch_dst) 549 | 550 | return logit_X, logit_E 551 | -------------------------------------------------------------------------------- /model_240125.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-COM/GraphMaker/0779133ebdc58d5c1e1c2d9abd696e31c3a343c4/model_240125.png -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from huggingface_hub import hf_hub_download 6 | 7 | from data import load_dataset, preprocess 8 | from eval_utils import Evaluator 9 | from setup_utils import set_seed 10 | 11 | def main(args): 12 | if args.model_path is None: 13 | if args.dataset is None or args.type is None: 14 | raise ValueError("If model_path is not provided, both dataset and type must be specified for downloading 
a pre-trained model checkpoint.") 15 | 16 | filename = f"{args.dataset}_{args.type}.pth" 17 | 18 | print(f"Downloading pre-trained model: {filename}") 19 | args.model_path = hf_hub_download(repo_id="Graph-COM/GraphMaker", 20 | filename=filename, 21 | cache_dir="./downloaded_cpts") 22 | print(f"Downloaded model to {args.model_path}") 23 | else: 24 | print(f"Loading local model from {args.model_path}") 25 | 26 | state_dict = torch.load(args.model_path) 27 | dataset = state_dict["dataset"] 28 | 29 | train_yaml_data = state_dict["train_yaml_data"] 30 | model_name = train_yaml_data["meta_data"]["variant"] 31 | 32 | print(f"Loaded GraphMaker-{model_name} model trained on {dataset}") 33 | print(f"Val Nll {state_dict['best_val_nll']}") 34 | 35 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 36 | 37 | g_real = load_dataset(dataset) 38 | X_one_hot_3d_real, Y_real, E_one_hot_real,\ 39 | X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(g_real) 40 | Y_one_hot_real = F.one_hot(Y_real) 41 | 42 | evaluator = Evaluator(dataset, 43 | g_real, 44 | X_one_hot_3d_real, 45 | Y_one_hot_real) 46 | 47 | X_marginal = X_marginal.to(device) 48 | Y_marginal = Y_marginal.to(device) 49 | E_marginal = E_marginal.to(device) 50 | X_cond_Y_marginals = X_cond_Y_marginals.to(device) 51 | num_nodes = Y_real.size(0) 52 | 53 | if model_name == "Sync": 54 | from model import ModelSync 55 | 56 | model = ModelSync(X_marginal=X_marginal, 57 | Y_marginal=Y_marginal, 58 | E_marginal=E_marginal, 59 | gnn_X_config=train_yaml_data["gnn_X"], 60 | gnn_E_config=train_yaml_data["gnn_E"], 61 | num_nodes=num_nodes, 62 | **train_yaml_data["diffusion"]).to(device) 63 | 64 | model.graph_encoder.pred_X.load_state_dict(state_dict["pred_X_state_dict"]) 65 | model.graph_encoder.pred_E.load_state_dict(state_dict["pred_E_state_dict"]) 66 | 67 | elif model_name == "Async": 68 | from model import ModelAsync 69 | 70 | model = ModelAsync(X_marginal=X_marginal, 71 | Y_marginal=Y_marginal, 
72 | E_marginal=E_marginal, 73 | mlp_X_config=train_yaml_data["mlp_X"], 74 | gnn_E_config=train_yaml_data["gnn_E"], 75 | num_nodes=num_nodes, 76 | **train_yaml_data["diffusion"]).to(device) 77 | 78 | model.graph_encoder.pred_X.load_state_dict(state_dict["pred_X_state_dict"]) 79 | model.graph_encoder.pred_E.load_state_dict(state_dict["pred_E_state_dict"]) 80 | 81 | model.eval() 82 | 83 | # Set seed for better reproducibility. 84 | set_seed() 85 | 86 | for _ in range(args.num_samples): 87 | X_0_one_hot, Y_0_one_hot, E_0 = model.sample() 88 | src, dst = E_0.nonzero().T 89 | g_sample = dgl.graph((src, dst), num_nodes=num_nodes).cpu() 90 | 91 | evaluator.add_sample(g_sample, 92 | X_0_one_hot.cpu(), 93 | Y_0_one_hot.cpu()) 94 | 95 | evaluator.summary() 96 | 97 | if __name__ == '__main__': 98 | from argparse import ArgumentParser 99 | 100 | parser = ArgumentParser() 101 | parser.add_argument("--model_path", type=str, help="Path to the model.") 102 | parser.add_argument("--dataset", type=str, choices=["cora", "amazon_photo", "amazon_computer"], 103 | help="Dataset name. Only specify it if you want to use a pre-trained model.") 104 | parser.add_argument("--type", type=str, choices=["sync", "async"], 105 | help="Model type. 
Only specify it if you want to use a pre-trained model.") 106 | parser.add_argument("--num_samples", type=int, default=10, 107 | help="Number of samples to generate.") 108 | args = parser.parse_args() 109 | 110 | main(args) 111 | -------------------------------------------------------------------------------- /setup_utils.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import numpy as np 3 | import pydantic 4 | import random 5 | import torch 6 | import yaml 7 | 8 | from typing import Optional 9 | 10 | # pydantic allows checking field types when loading configuration files 11 | class MetaDataYaml(pydantic.BaseModel): 12 | variant: str 13 | 14 | class GNNXYaml(pydantic.BaseModel): 15 | hidden_t: int 16 | hidden_X: int 17 | hidden_Y: int 18 | num_gnn_layers: int 19 | dropout: float 20 | 21 | class GNNEYaml(pydantic.BaseModel): 22 | hidden_t: int 23 | hidden_X: int 24 | hidden_Y: int 25 | hidden_E: int 26 | num_gnn_layers: int 27 | dropout: float 28 | 29 | class DiffusionYaml(pydantic.BaseModel): 30 | T: int 31 | 32 | class OptimizerYaml(pydantic.BaseModel): 33 | lr: float 34 | weight_decay: Optional[float] = 0. 
35 | amsgrad: Optional[bool] = False 36 | 37 | class LRSchedulerYaml(pydantic.BaseModel): 38 | factor: float 39 | patience: int 40 | verbose: bool 41 | 42 | class TrainYaml(pydantic.BaseModel): 43 | num_epochs: int 44 | val_every_epochs: int 45 | patient_epochs: int 46 | max_grad_norm: Optional[float] = None 47 | batch_size: int 48 | val_batch_size: int 49 | 50 | class SyncYaml(pydantic.BaseModel): 51 | meta_data: MetaDataYaml 52 | gnn_X: GNNXYaml 53 | gnn_E: GNNEYaml 54 | diffusion: DiffusionYaml 55 | optimizer_X: OptimizerYaml 56 | optimizer_E: OptimizerYaml 57 | lr_scheduler: LRSchedulerYaml 58 | train: TrainYaml 59 | 60 | class MLPXYaml(pydantic.BaseModel): 61 | hidden_t: int 62 | hidden_X: int 63 | hidden_Y: int 64 | num_mlp_layers: int 65 | dropout: float 66 | 67 | class DiffusionAsyncYaml(pydantic.BaseModel): 68 | T_X: int 69 | T_E: int 70 | 71 | class AsyncYaml(pydantic.BaseModel): 72 | meta_data: MetaDataYaml 73 | mlp_X: MLPXYaml 74 | gnn_E: GNNEYaml 75 | diffusion: DiffusionAsyncYaml 76 | optimizer_X: OptimizerYaml 77 | optimizer_E: OptimizerYaml 78 | lr_scheduler: LRSchedulerYaml 79 | train: TrainYaml 80 | 81 | def load_train_yaml(data_name, model_name): 82 | with open(f"configs/{data_name}/train_{model_name}.yaml") as f: 83 | yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader) 84 | 85 | if model_name == "Sync": 86 | return SyncYaml(**yaml_data).model_dump() 87 | elif model_name == "Async": 88 | return AsyncYaml(**yaml_data).model_dump() 89 | 90 | def set_seed(seed=0): 91 | np.random.seed(seed) 92 | random.seed(seed) 93 | torch.manual_seed(seed) 94 | torch.cuda.manual_seed(seed) 95 | torch.cuda.manual_seed_all(seed) 96 | torch.backends.cudnn.deterministic = True 97 | torch.backends.cudnn.benchmark = False 98 | dgl.seed(seed) 99 | -------------------------------------------------------------------------------- /train_async.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 
import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | import wandb 7 | 8 | from copy import deepcopy 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from data import load_dataset, preprocess 14 | from model import ModelAsync 15 | from setup_utils import load_train_yaml, set_seed 16 | 17 | def main(args): 18 | model_name = "Async" 19 | yaml_data = load_train_yaml(args.dataset, model_name) 20 | 21 | config_df = pd.json_normalize(yaml_data, sep='/') 22 | # Number of time steps 23 | T_X = yaml_data['diffusion']['T_X'] 24 | T_E = yaml_data['diffusion']['T_E'] 25 | wandb.init( 26 | project=f"{args.dataset}-{model_name}", 27 | name=f"T_X{T_X}, T_E{T_E}", 28 | config=config_df.to_dict(orient='records')[0]) 29 | 30 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 31 | 32 | g = load_dataset(args.dataset) 33 | X_one_hot_3d, Y, E_one_hot,\ 34 | X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(g) 35 | 36 | # (F, |V|, 2) 37 | X_one_hot_3d = X_one_hot_3d.to(device) 38 | # (|V|, F, 2) 39 | X_one_hot_2d = torch.transpose(X_one_hot_3d, 0, 1) 40 | # (|V|, 2 * F) 41 | X_one_hot_2d = X_one_hot_2d.reshape(X_one_hot_2d.size(0), -1) 42 | 43 | Y = Y.to(device) 44 | E_one_hot = E_one_hot.to(device) 45 | 46 | X_marginal = X_marginal.to(device) 47 | Y_marginal = Y_marginal.to(device) 48 | E_marginal = E_marginal.to(device) 49 | 50 | N = g.num_nodes() 51 | dst, src = torch.triu_indices(N, N, offset=1, device=device) 52 | # (|E|, 2), |E| for number of edges 53 | edge_index = torch.stack([dst, src], dim=1) 54 | 55 | # Set seed for better reproducibility. 
56 | set_seed() 57 | 58 | train_config = yaml_data["train"] 59 | # For mini-batch training 60 | data_loader = DataLoader(edge_index.cpu(), batch_size=train_config["batch_size"], 61 | shuffle=True, num_workers=4) 62 | val_data_loader = DataLoader(edge_index, batch_size=train_config["val_batch_size"], 63 | shuffle=False) 64 | 65 | model = ModelAsync(X_marginal=X_marginal, 66 | Y_marginal=Y_marginal, 67 | E_marginal=E_marginal, 68 | num_nodes=N, 69 | mlp_X_config=yaml_data["mlp_X"], 70 | gnn_E_config=yaml_data["gnn_E"], 71 | **yaml_data["diffusion"]).to(device) 72 | 73 | optimizer_X = torch.optim.AdamW(model.graph_encoder.pred_X.parameters(), 74 | **yaml_data["optimizer_X"]) 75 | optimizer_E = torch.optim.AdamW(model.graph_encoder.pred_E.parameters(), 76 | **yaml_data["optimizer_E"]) 77 | 78 | lr_scheduler_X = ReduceLROnPlateau(optimizer_X, mode='min', **yaml_data["lr_scheduler"]) 79 | lr_scheduler_E = ReduceLROnPlateau(optimizer_E, mode='min', **yaml_data["lr_scheduler"]) 80 | 81 | best_epoch_X = 0 82 | best_state_dict_X = deepcopy(model.graph_encoder.pred_X.state_dict()) 83 | best_val_nll_X = float('inf') 84 | best_log_p_0_X = float('inf') 85 | best_denoise_match_X = float('inf') 86 | 87 | best_epoch_E = 0 88 | best_state_dict_E = deepcopy(model.graph_encoder.pred_E.state_dict()) 89 | best_val_nll_E = float('inf') 90 | best_log_p_0_E = float('inf') 91 | best_denoise_match_E = float('inf') 92 | 93 | # Create the directory for saving model checkpoints. 
94 | model_cpt_dir = f"{args.dataset}_cpts" 95 | os.makedirs(model_cpt_dir, exist_ok=True) 96 | 97 | num_patient_epochs = 0 98 | for epoch in range(train_config["num_epochs"]): 99 | model.train() 100 | 101 | for batch_edge_index in tqdm(data_loader): 102 | batch_edge_index = batch_edge_index.to(device) 103 | # (B), (B) 104 | batch_dst, batch_src = batch_edge_index.T 105 | loss_X, loss_E = model.log_p_t(X_one_hot_3d, 106 | E_one_hot, 107 | Y, 108 | X_one_hot_2d, 109 | batch_src, 110 | batch_dst, 111 | E_one_hot[batch_dst, batch_src]) 112 | loss = loss_X + loss_E 113 | 114 | optimizer_X.zero_grad() 115 | optimizer_E.zero_grad() 116 | 117 | loss.backward() 118 | 119 | nn.utils.clip_grad_norm_( 120 | model.graph_encoder.pred_X.parameters(), train_config["max_grad_norm"]) 121 | nn.utils.clip_grad_norm_( 122 | model.graph_encoder.pred_E.parameters(), train_config["max_grad_norm"]) 123 | 124 | optimizer_X.step() 125 | optimizer_E.step() 126 | 127 | wandb.log({"train/loss_X": loss_X.item(), 128 | "train/loss_E": loss_E.item()}) 129 | 130 | if (epoch + 1) % train_config["val_every_epochs"] != 0: 131 | continue 132 | 133 | model.eval() 134 | 135 | num_patient_epochs += 1 136 | denoise_match_X = [] 137 | denoise_match_E = [] 138 | log_p_0_X = [] 139 | log_p_0_E = [] 140 | for batch_edge_index in tqdm(val_data_loader): 141 | # (B), (B) 142 | batch_dst, batch_src = batch_edge_index.T 143 | batch_denoise_match_E, batch_denoise_match_X,\ 144 | batch_log_p_0_E, batch_log_p_0_X = model.val_step( 145 | X_one_hot_3d, 146 | E_one_hot, 147 | Y, 148 | X_one_hot_2d, 149 | batch_src, 150 | batch_dst, 151 | E_one_hot[batch_dst, batch_src]) 152 | 153 | denoise_match_E.append(batch_denoise_match_E) 154 | denoise_match_X.append(batch_denoise_match_X) 155 | log_p_0_E.append(batch_log_p_0_E) 156 | log_p_0_X.append(batch_log_p_0_X) 157 | 158 | denoise_match_E = np.mean(denoise_match_E) 159 | denoise_match_X = np.mean(denoise_match_X) 160 | log_p_0_E = np.mean(log_p_0_E) 161 | log_p_0_X = 
np.mean(log_p_0_X) 162 | 163 | val_X = denoise_match_X + log_p_0_X 164 | val_E = denoise_match_E + log_p_0_E 165 | 166 | to_save_cpt = False 167 | if val_X < best_val_nll_X: 168 | best_val_nll_X = val_X 169 | best_epoch_X = epoch 170 | best_state_dict_X = deepcopy(model.graph_encoder.pred_X.state_dict()) 171 | to_save_cpt = True 172 | 173 | if val_E < best_val_nll_E: 174 | best_val_nll_E = val_E 175 | best_epoch_E = epoch 176 | best_state_dict_E = deepcopy(model.graph_encoder.pred_E.state_dict()) 177 | to_save_cpt = True 178 | 179 | if to_save_cpt: 180 | best_val_nll = best_val_nll_X + best_val_nll_E 181 | torch.save({ 182 | "dataset": args.dataset, 183 | "train_yaml_data": yaml_data, 184 | "best_val_nll": best_val_nll, 185 | "best_epoch_X": best_epoch_X, 186 | "best_epoch_E": best_epoch_E, 187 | "pred_X_state_dict": best_state_dict_X, 188 | "pred_E_state_dict": best_state_dict_E 189 | }, f"{model_cpt_dir}/{model_name}_TX{T_X}_TE{T_E}.pth") 190 | print('model saved') 191 | 192 | if log_p_0_X < best_log_p_0_X: 193 | best_log_p_0_X = log_p_0_X 194 | num_patient_epochs = 0 195 | 196 | if denoise_match_X < best_denoise_match_X: 197 | best_denoise_match_X = denoise_match_X 198 | num_patient_epochs = 0 199 | 200 | if log_p_0_E < best_log_p_0_E: 201 | best_log_p_0_E = log_p_0_E 202 | num_patient_epochs = 0 203 | 204 | if denoise_match_E < best_denoise_match_E: 205 | best_denoise_match_E = denoise_match_E 206 | num_patient_epochs = 0 207 | 208 | wandb.log({"epoch": epoch, 209 | "val/denoise_match_X": denoise_match_X, 210 | "val/denoise_match_E": denoise_match_E, 211 | "val/log_p_0_X": log_p_0_X, 212 | "val/log_p_0_E": log_p_0_E, 213 | "val/best_log_p_0_X": best_log_p_0_X, 214 | "val/best_denoise_match_X": best_denoise_match_X, 215 | "val/best_log_p_0_E": best_log_p_0_E, 216 | "val/best_denoise_match_E": best_denoise_match_E, 217 | "val/best_val_X": best_val_nll_X, 218 | "val/best_val_E": best_val_nll_E, 219 | "val/best_val_nll": best_val_nll}) 220 | 221 | print("Epoch {} | 
best val X nll {:.7f} | best val E nll {:.7f} | patience {}/{}".format( 222 | epoch, best_val_nll_X, best_val_nll_E, num_patient_epochs, train_config["patient_epochs"])) 223 | 224 | if num_patient_epochs == train_config["patient_epochs"]: 225 | break 226 | 227 | lr_scheduler_X.step(log_p_0_X) 228 | lr_scheduler_E.step(log_p_0_E) 229 | 230 | wandb.finish() 231 | 232 | if __name__ == '__main__': 233 | from argparse import ArgumentParser 234 | 235 | parser = ArgumentParser() 236 | parser.add_argument("-d", "--dataset", type=str, required=True, 237 | choices=["cora", "amazon_photo", "amazon_computer"]) 238 | args = parser.parse_args() 239 | 240 | main(args) 241 | -------------------------------------------------------------------------------- /train_sync.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | import wandb 7 | 8 | from copy import deepcopy 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from data import load_dataset, preprocess 14 | from model import ModelSync 15 | from setup_utils import load_train_yaml, set_seed 16 | 17 | def main(args): 18 | model_name = "Sync" 19 | yaml_data = load_train_yaml(args.dataset, model_name) 20 | 21 | config_df = pd.json_normalize(yaml_data, sep='/') 22 | # Number of time steps 23 | T = yaml_data['diffusion']['T'] 24 | wandb.init( 25 | project=f"{args.dataset}-{model_name}", 26 | name=f"T{T}", 27 | config=config_df.to_dict(orient='records')[0]) 28 | 29 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 30 | 31 | g = load_dataset(args.dataset) 32 | X_one_hot_3d, Y, E_one_hot,\ 33 | X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(g) 34 | 35 | X_one_hot_3d = X_one_hot_3d.to(device) 36 | Y = Y.to(device) 37 | E_one_hot = E_one_hot.to(device) 38 | 39 | 
X_marginal = X_marginal.to(device) 40 | Y_marginal = Y_marginal.to(device) 41 | E_marginal = E_marginal.to(device) 42 | 43 | N = g.num_nodes() 44 | dst, src = torch.triu_indices(N, N, offset=1, device=device) 45 | # (|E|, 2), |E| for number of edges 46 | edge_index = torch.stack([dst, src], dim=1) 47 | 48 | # Set seed for better reproducibility. 49 | set_seed() 50 | 51 | train_config = yaml_data["train"] 52 | # For mini-batch training 53 | data_loader = DataLoader(edge_index.cpu(), batch_size=train_config["batch_size"], 54 | shuffle=True, num_workers=4) 55 | val_data_loader = DataLoader(edge_index, batch_size=train_config["val_batch_size"], 56 | shuffle=False) 57 | 58 | model = ModelSync(X_marginal=X_marginal, 59 | Y_marginal=Y_marginal, 60 | E_marginal=E_marginal, 61 | num_nodes=N, 62 | gnn_X_config=yaml_data["gnn_X"], 63 | gnn_E_config=yaml_data["gnn_E"], 64 | **yaml_data["diffusion"]).to(device) 65 | 66 | optimizer_X = torch.optim.AdamW(model.graph_encoder.pred_X.parameters(), 67 | **yaml_data["optimizer_X"]) 68 | optimizer_E = torch.optim.AdamW(model.graph_encoder.pred_E.parameters(), 69 | **yaml_data["optimizer_E"]) 70 | 71 | lr_scheduler_X = ReduceLROnPlateau(optimizer_X, mode='min', **yaml_data["lr_scheduler"]) 72 | lr_scheduler_E = ReduceLROnPlateau(optimizer_E, mode='min', **yaml_data["lr_scheduler"]) 73 | 74 | best_epoch_X = 0 75 | best_state_dict_X = deepcopy(model.graph_encoder.pred_X.state_dict()) 76 | best_val_nll_X = float('inf') 77 | best_log_p_0_X = float('inf') 78 | best_denoise_match_X = float('inf') 79 | 80 | best_epoch_E = 0 81 | best_state_dict_E = deepcopy(model.graph_encoder.pred_E.state_dict()) 82 | best_val_nll_E = float('inf') 83 | best_log_p_0_E = float('inf') 84 | best_denoise_match_E = float('inf') 85 | 86 | # Create the directory for saving model checkpoints. 
87 | model_cpt_dir = f"{args.dataset}_cpts" 88 | os.makedirs(model_cpt_dir, exist_ok=True) 89 | 90 | num_patient_epochs = 0 91 | for epoch in range(train_config["num_epochs"]): 92 | model.train() 93 | 94 | for batch_edge_index in tqdm(data_loader): 95 | batch_edge_index = batch_edge_index.to(device) 96 | # (B), (B) 97 | batch_dst, batch_src = batch_edge_index.T 98 | loss_X, loss_E = model.log_p_t(X_one_hot_3d, 99 | E_one_hot, 100 | Y, 101 | batch_src, 102 | batch_dst, 103 | E_one_hot[batch_dst, batch_src]) 104 | loss = loss_X + loss_E 105 | 106 | optimizer_X.zero_grad() 107 | optimizer_E.zero_grad() 108 | 109 | loss.backward() 110 | 111 | nn.utils.clip_grad_norm_( 112 | model.graph_encoder.pred_X.parameters(), train_config["max_grad_norm"]) 113 | nn.utils.clip_grad_norm_( 114 | model.graph_encoder.pred_E.parameters(), train_config["max_grad_norm"]) 115 | 116 | optimizer_X.step() 117 | optimizer_E.step() 118 | 119 | wandb.log({"train/loss_X": loss_X.item(), 120 | "train/loss_E": loss_E.item()}) 121 | 122 | if (epoch + 1) % train_config["val_every_epochs"] != 0: 123 | continue 124 | 125 | model.eval() 126 | 127 | num_patient_epochs += 1 128 | 129 | denoise_match_E = [] 130 | denoise_match_X = [] 131 | log_p_0_E = [] 132 | log_p_0_X = [] 133 | for batch_edge_index in tqdm(val_data_loader): 134 | batch_dst, batch_src = batch_edge_index.T 135 | 136 | batch_denoise_match_E, batch_denoise_match_X,\ 137 | batch_log_p_0_E, batch_log_p_0_X = model.val_step( 138 | X_one_hot_3d, 139 | E_one_hot, 140 | Y, 141 | batch_src, 142 | batch_dst, 143 | E_one_hot[batch_dst, batch_src]) 144 | 145 | denoise_match_E.append(batch_denoise_match_E) 146 | denoise_match_X.append(batch_denoise_match_X) 147 | log_p_0_E.append(batch_log_p_0_E) 148 | log_p_0_X.append(batch_log_p_0_X) 149 | 150 | denoise_match_E = np.mean(denoise_match_E) 151 | denoise_match_X = np.mean(denoise_match_X) 152 | log_p_0_E = np.mean(log_p_0_E) 153 | log_p_0_X = np.mean(log_p_0_X) 154 | 155 | val_X = denoise_match_X + 
log_p_0_X 156 | val_E = denoise_match_E + log_p_0_E 157 | 158 | to_save_cpt = False 159 | if val_X < best_val_nll_X: 160 | best_val_nll_X = val_X 161 | best_epoch_X = epoch 162 | best_state_dict_X = deepcopy(model.graph_encoder.pred_X.state_dict()) 163 | to_save_cpt = True 164 | 165 | if val_E < best_val_nll_E: 166 | best_val_nll_E = val_E 167 | best_epoch_E = epoch 168 | best_state_dict_E = deepcopy(model.graph_encoder.pred_E.state_dict()) 169 | to_save_cpt = True 170 | 171 | if to_save_cpt: 172 | best_val_nll = best_val_nll_X + best_val_nll_E 173 | torch.save({ 174 | "dataset": args.dataset, 175 | "train_yaml_data": yaml_data, 176 | "best_val_nll": best_val_nll, 177 | "best_epoch_X": best_epoch_X, 178 | "best_epoch_E": best_epoch_E, 179 | "pred_X_state_dict": best_state_dict_X, 180 | "pred_E_state_dict": best_state_dict_E 181 | }, f"{model_cpt_dir}/{model_name}_T{T}.pth") 182 | print('model saved') 183 | 184 | if log_p_0_X < best_log_p_0_X: 185 | best_log_p_0_X = log_p_0_X 186 | num_patient_epochs = 0 187 | 188 | if denoise_match_X < best_denoise_match_X: 189 | best_denoise_match_X = denoise_match_X 190 | num_patient_epochs = 0 191 | 192 | if log_p_0_E < best_log_p_0_E: 193 | best_log_p_0_E = log_p_0_E 194 | num_patient_epochs = 0 195 | 196 | if denoise_match_E < best_denoise_match_E: 197 | best_denoise_match_E = denoise_match_E 198 | num_patient_epochs = 0 199 | 200 | wandb.log({"epoch": epoch, 201 | "val/denoise_match_X": denoise_match_X, 202 | "val/denoise_match_E": denoise_match_E, 203 | "val/log_p_0_X": log_p_0_X, 204 | "val/log_p_0_E": log_p_0_E, 205 | "val/best_log_p_0_X": best_log_p_0_X, 206 | "val/best_denoise_match_X": best_denoise_match_X, 207 | "val/best_log_p_0_E": best_log_p_0_E, 208 | "val/best_denoise_match_E": best_denoise_match_E, 209 | "val/best_val_X": best_val_nll_X, 210 | "val/best_val_E": best_val_nll_E, 211 | "val/best_val_nll": best_val_nll}) 212 | 213 | print("Epoch {} | best val X nll {:.7f} | best val E nll {:.7f} | patience 
{}/{}".format( 214 | epoch, best_val_nll_X, best_val_nll_E, num_patient_epochs, train_config["patient_epochs"])) 215 | 216 | if num_patient_epochs == train_config["patient_epochs"]: 217 | break 218 | 219 | lr_scheduler_X.step(log_p_0_X) 220 | lr_scheduler_E.step(log_p_0_E) 221 | 222 | wandb.finish() 223 | 224 | if __name__ == '__main__': 225 | from argparse import ArgumentParser 226 | 227 | parser = ArgumentParser() 228 | parser.add_argument("-d", "--dataset", type=str, required=True, 229 | choices=["cora", "amazon_photo", "amazon_computer"]) 230 | args = parser.parse_args() 231 | 232 | main(args) 233 | --------------------------------------------------------------------------------
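Note on the early-stopping bookkeeping shared by `train_sync.py` and `train_async.py`: each validation round increments `num_patient_epochs`, then resets it to 0 if any of the four tracked metrics (`log_p_0_X`, `denoise_match_X`, `log_p_0_E`, `denoise_match_E`) improves, and training stops once the counter equals `patient_epochs`. The sketch below isolates that logic in plain Python. `PatienceTracker` and its method names are hypothetical and are not part of the repository; the metric values are made up for illustration.

```python
import math


class PatienceTracker:
    """Hypothetical helper mirroring the patience counter in the training scripts."""

    def __init__(self, metric_names, patient_epochs):
        # Every tracked metric starts at +inf, so the first validation
        # round always counts as an improvement.
        self.best = {name: math.inf for name in metric_names}
        self.patient_epochs = patient_epochs
        self.num_patient_epochs = 0

    def update(self, metrics):
        """Record one validation round; return True if training should stop."""
        # Increment first, then reset if at least one metric improved.
        self.num_patient_epochs += 1
        for name, value in metrics.items():
            if value < self.best[name]:
                self.best[name] = value
                self.num_patient_epochs = 0
        return self.num_patient_epochs == self.patient_epochs


if __name__ == "__main__":
    tracker = PatienceTracker(
        ["log_p_0_X", "denoise_match_X", "log_p_0_E", "denoise_match_E"],
        patient_epochs=2)
    # Illustrative values only: nothing improves after the first round.
    val_round = {"log_p_0_X": 1.0, "denoise_match_X": 1.0,
                 "log_p_0_E": 1.0, "denoise_match_E": 1.0}
    for round_idx in range(5):
        stop = tracker.update(val_round)
        print(round_idx, tracker.num_patient_epochs, stop)
        if stop:
            break
```

Because a single improving metric resets the counter, training continues as long as any one of the four terms is still decreasing, even if the summed validation NLL used for checkpointing has stopped improving.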