├── figures
│   └── tam_concept.png
├── losses
│   ├── __init__.py
│   ├── ce.py
│   └── balanced_softmax.py
├── nets
│   ├── __init__.py
│   ├── gcn.py
│   ├── sage.py
│   └── gat.py
├── ens_nets
│   ├── __init__.py
│   ├── sage.py
│   ├── gcn.py
│   └── gat.py
├── models
│   ├── __init__.py
│   ├── pc_softmax.py
│   ├── reweight.py
│   ├── tam.py
│   ├── renode.py
│   └── gens.py
├── LICENSE
├── data_utils.py
├── .gitignore
├── args.py
├── dataset
│   ├── WebKB.py
│   └── WikipediaNetwork.py
├── README.md
├── main_rw.py
├── main_pc.py
├── main_bs.py
├── main_renode.py
└── main_ens.py

--------------------------------------------------------------------------------
/figures/tam_concept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jaeyun-Song/TAM/HEAD/figures/tam_concept.png

--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
from .ce import CrossEntropy
from .balanced_softmax import BalancedSoftmax

--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
from .gcn import create_gcn
from .gat import create_gat
from .sage import create_sage

--------------------------------------------------------------------------------
/ens_nets/__init__.py:
--------------------------------------------------------------------------------
from .gcn import create_gcn
from .gat import create_gat
from .sage import create_sage

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .tam import adjust_output, MeanAggregation
from .reweight import get_weight
from .pc_softmax import pc_softmax
from .gens import *
from .renode import *

--------------------------------------------------------------------------------
/models/pc_softmax.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def pc_softmax(logits, cls_num):
    sample_per_class = torch.tensor(cls_num)
    spc = sample_per_class.type_as(logits)
    spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
    logits = logits - spc.log()
    return logits
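
For orientation, a minimal sketch of how this post-hoc adjustment is applied at inference time. The shapes and class counts are made up, and it assumes the repository root is on `PYTHONPATH`:

```python
import torch
from models.pc_softmax import pc_softmax

logits = torch.randn(4, 3)                  # predictions for 4 nodes, 3 classes
adjusted = pc_softmax(logits, [50, 10, 2])  # subtract log class counts (log priors up to a constant)
probs = adjusted.softmax(dim=1)             # post-correction class probabilities
```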
--------------------------------------------------------------------------------
/losses/ce.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()

    def forward(self, input, target, weight=None, reduction='mean'):
        return F.cross_entropy(input, target, weight=weight, reduction=reduction)

--------------------------------------------------------------------------------
/models/reweight.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_weight(is_reweight, class_num_list):
    if is_reweight:
        min_number = np.min(class_num_list)
        class_weight_list = [float(min_number)/float(num) for num in class_num_list]
    else:
        class_weight_list = [1. for _ in class_num_list]
    class_weight = torch.tensor(class_weight_list).type(torch.float32)

    return class_weight
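
A quick sanity check of the weighting scheme above: each class is scaled by `min_count / count`, so the rarest class always receives weight 1.0 (toy counts, illustrative only):

```python
from models.reweight import get_weight

# Counts [20, 10, 5] -> weights [5/20, 5/10, 5/5] = [0.25, 0.5, 1.0]
class_weight = get_weight(True, [20, 10, 5])
print(class_weight)  # tensor([0.2500, 0.5000, 1.0000])
```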
44 | """ 45 | spc = sample_per_class.type_as(logits) 46 | spc = spc.unsqueeze(0).expand(logits.shape[0], -1) 47 | logits = logits + spc.log() 48 | loss = F.cross_entropy(input=logits, target=labels, reduction=reduction) 49 | return loss -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch_scatter import scatter_add 4 | 5 | 6 | def get_dataset(name, path, split_type='public'): 7 | import torch_geometric.transforms as T 8 | 9 | if name == "Cora" or name == "CiteSeer" or name == "PubMed": 10 | from torch_geometric.datasets import Planetoid 11 | dataset = Planetoid(path, name, transform=T.NormalizeFeatures(), split=split_type) 12 | elif name == "chameleon" or name == "squirrel": 13 | from dataset.WikipediaNetwork import WikipediaNetwork 14 | dataset = WikipediaNetwork(path, name, transform = T.NormalizeFeatures()) 15 | elif name == "Wisconsin": 16 | # from torch_geometric.datasets import WebKB 17 | from dataset.WebKB import WebKB 18 | dataset = WebKB(path, name, transform = T.NormalizeFeatures()) 19 | else: 20 | raise NotImplementedError("Not Implemented Dataset!") 21 | 22 | return dataset 23 | 24 | 25 | def split_semi_dataset(total_node, n_data, n_cls, class_num_list, idx_info, device): 26 | new_idx_info = [] 27 | _train_mask = idx_info[0].new_zeros(total_node, dtype=torch.bool, device=device) 28 | for i in range(n_cls): 29 | if n_data[i] > class_num_list[i]: 30 | cls_idx = torch.randperm(len(idx_info[i])) 31 | cls_idx = idx_info[i][cls_idx] 32 | cls_idx = cls_idx[:class_num_list[i]] 33 | new_idx_info.append(cls_idx) 34 | else: 35 | new_idx_info.append(idx_info[i]) 36 | _train_mask[new_idx_info[i]] = True 37 | 38 | assert _train_mask.sum().long() == sum(class_num_list) 39 | assert sum([len(idx) for idx in new_idx_info]) == sum(class_num_list) 40 | 41 | return _train_mask, new_idx_info 42 | 43 | 44 | def get_idx_info(label, n_cls, train_mask): 45 | index_list = torch.arange(len(label)) 46 | idx_info = [] 47 | for i in range(n_cls): 48 | cls_indices = index_list[((label == i) & train_mask)] 49 | idx_info.append(cls_indices) 50 | return idx_info 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch_scatter import scatter_add


def get_dataset(name, path, split_type='public'):
    import torch_geometric.transforms as T

    if name == "Cora" or name == "CiteSeer" or name == "PubMed":
        from torch_geometric.datasets import Planetoid
        dataset = Planetoid(path, name, transform=T.NormalizeFeatures(), split=split_type)
    elif name == "chameleon" or name == "squirrel":
        from dataset.WikipediaNetwork import WikipediaNetwork
        dataset = WikipediaNetwork(path, name, transform=T.NormalizeFeatures())
    elif name == "Wisconsin":
        # from torch_geometric.datasets import WebKB
        from dataset.WebKB import WebKB
        dataset = WebKB(path, name, transform=T.NormalizeFeatures())
    else:
        raise NotImplementedError("Not Implemented Dataset!")

    return dataset


def split_semi_dataset(total_node, n_data, n_cls, class_num_list, idx_info, device):
    new_idx_info = []
    _train_mask = idx_info[0].new_zeros(total_node, dtype=torch.bool, device=device)
    for i in range(n_cls):
        if n_data[i] > class_num_list[i]:
            cls_idx = torch.randperm(len(idx_info[i]))
            cls_idx = idx_info[i][cls_idx]
            cls_idx = cls_idx[:class_num_list[i]]
            new_idx_info.append(cls_idx)
        else:
            new_idx_info.append(idx_info[i])
        _train_mask[new_idx_info[i]] = True

    assert _train_mask.sum().long() == sum(class_num_list)
    assert sum([len(idx) for idx in new_idx_info]) == sum(class_num_list)

    return _train_mask, new_idx_info


def get_idx_info(label, n_cls, train_mask):
    index_list = torch.arange(len(label))
    idx_info = []
    for i in range(n_cls):
        cls_indices = index_list[((label == i) & train_mask)]
        idx_info.append(cls_indices)
    return idx_info
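
A toy sketch of the bookkeeping above: `get_idx_info` buckets the training indices by class, and `split_semi_dataset` then subsamples those buckets down to the imbalanced `class_num_list`. All values here are arbitrary:

```python
import torch
from data_utils import get_idx_info

label = torch.tensor([0, 0, 1, 2, 1, 0])
train_mask = torch.tensor([True, True, True, False, True, True])
idx_info = get_idx_info(label, n_cls=3, train_mask=train_mask)
# idx_info == [tensor([0, 1, 5]), tensor([2, 4]), tensor([], dtype=torch.int64)]
```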
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/nets/gcn.py:
--------------------------------------------------------------------------------
"""
Code Reference: https://github.com/victorchen96/ReNode/blob/main/transductive/network/gcn.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import scipy
import numpy as np

from torch_geometric.nn import GCNConv


class StandGCN1(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1):
        super(StandGCN1, self).__init__()
        self.conv1 = GCNConv(nfeat, nclass)

        self.reg_params = []
        self.non_reg_params = self.conv1.parameters()

    def forward(self, x, adj):
        edge_index = adj
        x = self.conv1(x, edge_index)
        # x = F.relu(self.conv1(x, edge_index))

        return x


class StandGCN2(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2):
        super(StandGCN2, self).__init__()
        self.conv1 = GCNConv(nfeat, nhid)
        self.conv2 = GCNConv(nhid, nclass)
        self.dropout_p = dropout

        self.reg_params = self.conv1.parameters()
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        x = self.conv1(x, adj)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, adj)

        return x


class StandGCNX(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=3):
        super(StandGCNX, self).__init__()
        self.conv1 = GCNConv(nfeat, nhid)
        self.conv2 = GCNConv(nhid, nclass)
        self.convx = nn.ModuleList([GCNConv(nhid, nhid) for _ in range(nlayer-2)])
        self.dropout_p = dropout

        self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters())
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        edge_index = adj

        x = F.relu(self.conv1(x, edge_index))

        for iter_layer in self.convx:
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = F.relu(iter_layer(x, edge_index))

        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        return x


def create_gcn(nfeat, nhid, nclass, dropout, nlayer):
    if nlayer == 1:
        model = StandGCN1(nfeat, nhid, nclass, dropout, nlayer)
    elif nlayer == 2:
        model = StandGCN2(nfeat, nhid, nclass, dropout, nlayer)
    else:
        model = StandGCNX(nfeat, nhid, nclass, dropout, nlayer)
    return model
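
All three backbone factories (`create_gcn`, `create_gat`, `create_sage`) share this interface. A minimal forward pass on a made-up 4-node graph, assuming PyTorch Geometric is installed:

```python
import torch
from nets.gcn import create_gcn

model = create_gcn(nfeat=8, nhid=16, nclass=3, dropout=0.5, nlayer=2)
x = torch.randn(4, 8)                      # node features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])  # two undirected edges
logits = model(x, edge_index)              # shape [4, 3]
```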
--------------------------------------------------------------------------------
/nets/sage.py:
--------------------------------------------------------------------------------
"""
Code Reference: https://github.com/victorchen96/ReNode/blob/main/transductive/network/sage.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import scipy
import numpy as np
from torch_geometric.nn import SAGEConv


class GraphSAGE1(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1):
        super(GraphSAGE1, self).__init__()
        self.conv1 = SAGEConv(nfeat, nclass)

        self.reg_params = []
        self.non_reg_params = self.conv1.parameters()

    def forward(self, x, adj):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index))
        return x


class GraphSAGE2(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2):
        super(GraphSAGE2, self).__init__()
        self.conv1 = SAGEConv(nfeat, nhid)
        self.conv2 = SAGEConv(nhid, nclass)

        self.dropout_p = dropout

        self.reg_params = self.conv1.parameters()
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        return x


class GraphSAGEX(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=3):
        super(GraphSAGEX, self).__init__()
        self.conv1 = SAGEConv(nfeat, nhid)
        self.conv2 = SAGEConv(nhid, nclass)
        self.convx = nn.ModuleList([SAGEConv(nhid, nhid) for _ in range(nlayer-2)])
        self.dropout_p = dropout

        self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters())
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        edge_index = adj

        x = F.relu(self.conv1(x, edge_index))

        for iter_layer in self.convx:
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = F.relu(iter_layer(x, edge_index))

        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        return x


def create_sage(nfeat, nhid, nclass, dropout, nlayer):
    if nlayer == 1:
        model = GraphSAGE1(nfeat, nhid, nclass, dropout, nlayer)
    elif nlayer == 2:
        model = GraphSAGE2(nfeat, nhid, nclass, dropout, nlayer)
    else:
        model = GraphSAGEX(nfeat, nhid, nclass, dropout, nlayer)
    return model

--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
import argparse

def parse_args():
    parser = argparse.ArgumentParser()

    # Dataset
    parser.add_argument('--dataset', type=str, default='CiteSeer',
                        help='Dataset Name')
    parser.add_argument('--imb_ratio', type=float, default=10,
                        help='Imbalance Ratio')
    # Architecture
    parser.add_argument('--net', type=str, default='GCN',
                        help='Architecture name')
    parser.add_argument('--n_layer', type=int, default=2,
                        help='the number of layers')
    parser.add_argument('--feat_dim', type=int, default=256,
                        help='Feature dimension')
    # Imbalance Loss
    parser.add_argument('--loss_type', type=str, default='bs',
                        help='Loss type')
    # Method
    parser.add_argument('--tam', action='store_true',
                        help='use tam')
    parser.add_argument('--reweight', action='store_true',
                        help='use reweight')
    parser.add_argument('--pc_softmax', action='store_true',
                        help='use pc softmax')
    parser.add_argument('--ens', action='store_true',
                        help='use GraphENS')
    parser.add_argument('--renode', action='store_true',
                        help='use ReNode')
    # Hyperparameter for GraphENS
    parser.add_argument('--keep_prob', type=float, default=0.01,
                        help='Keeping Probability')
    parser.add_argument('--pred_temp', type=float, default=2,
                        help='Prediction temperature')
    # ReNode
    parser.add_argument('--loss_name', default="re-weight", type=str, help="the training loss")  # ce focal re-weight cb-softmax
    parser.add_argument('--factor_focal', default=2.0, type=float, help="alpha in Focal Loss")
    parser.add_argument('--factor_cb', default=0.9999, type=float, help="beta in CB Loss")
    parser.add_argument('--rn_base', default=0.5, type=float, help="Lower bound of RN")
    parser.add_argument('--rn_max', default=1.5, type=float, help="Upper bound of RN")

    # Hyperparameter for TAM
    parser.add_argument('--tam_alpha', type=float, default=2.5,
                        help='coefficient of ACM')
    parser.add_argument('--tam_beta', type=float, default=0.5,
                        help='coefficient of ADM')
    parser.add_argument('--temp_phi', type=float, default=1.2,
                        help='classwise temperature')
    parser.add_argument('--warmup', type=int, default=5,
                        help='warmup')
    args = parser.parse_args()

    return args
--------------------------------------------------------------------------------
/nets/gat.py:
--------------------------------------------------------------------------------
"""
Code Reference: https://github.com/victorchen96/ReNode/blob/main/transductive/network/gat.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import scipy
import numpy as np

from torch_geometric.nn import GATConv

class StandGAT1(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1):
        super(StandGAT1, self).__init__()
        self.conv1 = GATConv(nfeat, nclass, heads=1)

        self.reg_params = []
        self.non_reg_params = self.conv1.parameters()

    def forward(self, x, adj):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index))

        return x


class StandGAT2(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2):
        super(StandGAT2, self).__init__()

        num_head = 4
        head_dim = nhid // num_head

        self.conv1 = GATConv(nfeat, head_dim, heads=num_head)
        self.conv2 = GATConv(nhid, nclass, heads=1, concat=False)

        self.dropout_p = dropout

        self.reg_params = self.conv1.parameters()
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        edge_index = adj

        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        return x

class StandGATX(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=3):
        super(StandGATX, self).__init__()

        num_head = 4
        head_dim = nhid // num_head

        self.conv1 = GATConv(nfeat, head_dim, heads=num_head)
        self.conv2 = GATConv(nhid, nclass)
        self.convx = nn.ModuleList([GATConv(nhid, head_dim, heads=num_head) for _ in range(nlayer-2)])
        self.dropout_p = dropout

        self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters())
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj):
        edge_index = adj

        x = F.relu(self.conv1(x, edge_index))

        for iter_layer in self.convx:
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = F.relu(iter_layer(x, edge_index))

        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        return x


def create_gat(nfeat, nhid, nclass, dropout, nlayer):
    if nlayer == 1:
        model = StandGAT1(nfeat, nhid, nclass, dropout, nlayer)
    elif nlayer == 2:
        model = StandGAT2(nfeat, nhid, nclass, dropout, nlayer)
    else:
        model = StandGATX(nfeat, nhid, nclass, dropout, nlayer)
    return model
--------------------------------------------------------------------------------
/dataset/WebKB.py:
--------------------------------------------------------------------------------
import os.path as osp

import torch
import numpy as np
from torch_sparse import coalesce
from torch_geometric.data import InMemoryDataset, download_url, Data


class WebKB(InMemoryDataset):
    r"""The WebKB datasets used in the
    `"Geom-GCN: Geometric Graph Convolutional Networks"
    <https://openreview.net/forum?id=S1e2agrFvS>`_ paper.
    Nodes represent web pages and edges represent hyperlinks between them.
    Node features are the bag-of-words representation of web pages.
    The task is to classify the nodes into one of the five categories: student,
    project, course, staff, and faculty.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Cornell"`,
            :obj:`"Texas"`, :obj:`"Wisconsin"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower()
        assert self.name in ['cornell', 'texas', 'wisconsin']

        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        out = ['out1_node_feature_label.txt', 'out1_graph_edges.txt']
        out += [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)]
        return out

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for f in self.raw_file_names[:2]:
            download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)
        for f in self.raw_file_names[2:]:
            download_url(f'{self.url}/splits/{f}', self.raw_dir)

    def process(self):
        with open(self.raw_paths[0], 'r') as f:
            data = f.read().split('\n')[1:-1]
            x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
            x = torch.tensor(x, dtype=torch.float)

            y = [int(r.split('\t')[2]) for r in data]
            y = torch.tensor(y, dtype=torch.long)

        with open(self.raw_paths[1], 'r') as f:
            data = f.read().split('\n')[1:-1]
            data = [[int(v) for v in r.split('\t')] for r in data]
            edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
            edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))

        train_masks, val_masks, test_masks = [], [], []
        for f in self.raw_paths[2:]:
            tmp = np.load(f)
            train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
            val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
            test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
        train_mask = torch.stack(train_masks, dim=1)
        val_mask = torch.stack(val_masks, dim=1)
        test_mask = torch.stack(test_masks, dim=1)

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                    val_mask=val_mask, test_mask=test_mask)
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}()'
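
Loading one of these splits mirrors what `get_dataset` in `data_utils.py` does internally; the root path below is illustrative, and the raw files are downloaded on first use:

```python
import torch_geometric.transforms as T
from dataset.WebKB import WebKB

dataset = WebKB('data/Wisconsin', 'Wisconsin', transform=T.NormalizeFeatures())
data = dataset[0]
print(data.train_mask.shape)  # [num_nodes, 10]: one column per Geom-GCN split
```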

## Semi-Supervised Node Classification (Public Split)

The code for semi-supervised node classification.
It is implemented mainly based on [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric).

- Running commands for TAM:
1. Balanced Softmax + TAM
```
python main_bs.py \
    --loss_type bs \
    --dataset [dataset] \
    --net [net] \
    --n_layer [n_layer] \
    --feat_dim [feat_dim] \
    --tam \
    --tam_alpha [tam_alpha] \
    --tam_beta [tam_beta] \
    --temp_phi [temp_phi] \
```
2. GraphENS + TAM
```
python main_ens.py --ens \
    --loss_type ce \
    --tam \
```
3. ReNode + TAM
```
python main_renode.py --renode \
    --loss_type ce \
    --loss_name [loss_name] \
    --rn_base [rn_base] \
    --rn_max [rn_max] \
    --tam \
```

- Running commands for baselines:
1. Cross Entropy
```
python main_bs.py \
    --loss_type ce \
```
2. Re-Weight
```
python main_rw.py --reweight \
    --loss_type ce \
```
3. PC Softmax
```
python main_pc.py --pc_softmax \
    --loss_type ce \
```
4. Balanced Softmax
```
python main_bs.py \
    --loss_type bs \
```
5. GraphENS
```
python main_ens.py --ens \
    --loss_type ce \
```
6. ReNode
```
python main_renode.py --renode \
    --loss_type ce \
    --loss_name [loss_name] \
    --rn_base [rn_base] \
    --rn_max [rn_max] \
```

- Argument Description for TAM
1. Experiment Dataset (the dataset will be downloaded automatically the first time the code is run):\
Set [dataset] as one of ['Cora', 'Citeseer', 'PubMed', 'chameleon', 'squirrel', 'Wisconsin']
2. Backbone GNN architecture:\
Set [net] as one of ['GCN', 'GAT', 'SAGE']
3. The number of GNN layers:\
Set [n_layer] as one of [1, 2, 3]
4. Hidden dimension of the GNN:\
Set [feat_dim] as one of [64, 128, 256]
5. The strength of ACM, α:\
Set [tam_alpha] as one of [0.5, 1.5, 2.5]
6. The strength of ADM, β:\
Set [tam_beta] as one of [0.25, 0.5]
7. The class-wise temperature hyperparameter, 𝜙:\
Set [temp_phi] as one of [0.8, 1.2]
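
For example, a concrete instantiation of the Balanced Softmax + TAM command above, filling the placeholders with values from these lists (illustrative, not tuned per dataset):
```
python main_bs.py \
    --loss_type bs \
    --dataset Cora \
    --net GCN \
    --n_layer 2 \
    --feat_dim 256 \
    --tam \
    --tam_alpha 2.5 \
    --tam_beta 0.5 \
    --temp_phi 1.2
```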

## Dependencies
This code has been tested with
- Python == 3.8.0
- PyTorch == 1.8.0
- PyTorch Geometric == 2.0.1
- torch_scatter == 2.0.8

## Citation
```
@InProceedings{pmlr-v162-song22a,
  title     = {{TAM}: Topology-Aware Margin Loss for Class-Imbalanced Node Classification},
  author    = {Song, Jaeyun and Park, Joonhyung and Yang, Eunho},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  pages     = {20369--20383},
  year      = {2022},
  volume    = {162},
  series    = {Proceedings of Machine Learning Research},
  month     = {17--23 Jul},
  publisher = {PMLR},
}
```

## Acknowledgement
This work was supported by the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT)
(No.2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)).

--------------------------------------------------------------------------------
/models/tam.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


## Jensen-Shannon Divergence ##
def compute_jsd(dist1, dist2):
    dist_mean = (dist1 + dist2) / 2.
    jsd = (F.kl_div(dist_mean.log(), dist1, reduction='none') + F.kl_div(dist_mean.log(), dist2, reduction='none')) / 2.
    return jsd


## TAM ##
@torch.no_grad()
def compute_tam(output, edge_index, label, train_mask, aggregator, class_num_list=None, temp_phi=None, temp_gamma=None):
    n_cls = label.max().item() + 1

    # Apply class-wise temperature
    cls_num_list = torch.FloatTensor(class_num_list).to(output.device)
    cls_num_ratio = cls_num_list / cls_num_list.sum()
    cls_num_ratio = cls_num_ratio * temp_gamma + (1 - temp_gamma)
    max_beta = torch.max(cls_num_ratio)
    cls_temperature = (temp_phi * (cls_num_ratio + 1 - max_beta)).unsqueeze(0)
    temp = 1 / cls_temperature

    # Predict unlabeled nodes
    agg_out = F.softmax(output.clone().detach()/temp, dim=1)
    agg_out[train_mask] = F.one_hot(label[train_mask].clone(), num_classes=n_cls).float()  # only use labeled nodes
    neighbor_dist = aggregator(agg_out, edge_index)[train_mask]  # (# of labeled nodes, # of classes)

    # Compute class-wise connectivity matrix
    connectivity_matrix = []
    for c in range(n_cls):
        c_mask = (label[train_mask] == c)
        connectivity_matrix.append(neighbor_dist[c_mask].mean(dim=0))
    connectivity_matrix = torch.stack(connectivity_matrix, dim=0)

    # Preprocess class-wise connectivity matrix and NLD for numerical stability
    center_mask = F.one_hot(label[train_mask].clone(), num_classes=n_cls).bool()
    neighbor_dist[neighbor_dist < 1e-6] = 1e-6
    connectivity_matrix[connectivity_matrix < 1e-6] = 1e-6

    # Compute ACM
    acm = (neighbor_dist[center_mask].unsqueeze(dim=1) / torch.diagonal(connectivity_matrix).unsqueeze(dim=1)[label[train_mask]]) \
            * (connectivity_matrix[label[train_mask]] / neighbor_dist)
    acm[acm > 1] = 1
    acm[center_mask] = 1

    # Compute ADM
    cls_pair_jsd = compute_jsd(connectivity_matrix.unsqueeze(dim=0), connectivity_matrix.unsqueeze(dim=1)).sum(dim=-1)  # distance between classes
    cls_pair_jsd[cls_pair_jsd < 1e-6] = 1e-6
    self_kl = compute_jsd(neighbor_dist, connectivity_matrix[label[train_mask]]).sum(dim=-1, keepdim=True)  # deviation from the self-class averaged NLD
    neighbor_kl = compute_jsd(neighbor_dist.unsqueeze(1), connectivity_matrix.unsqueeze(0)).sum(dim=-1)  # distance between each node's NLD and each class-averaged NLD
    adm = (self_kl**2 + (cls_pair_jsd**2)[label[train_mask]] - neighbor_kl**2) / (2*(cls_pair_jsd**2)[label[train_mask]])

    adm[center_mask] = 0

    return acm, adm


def adjust_output(args, output, edge_index, label, train_mask, aggregator, class_num_list, epoch):
    """
    Adjust the margin of each labeled node according to its local topology
    Input:
        args: hyperparameters for TAM
        output: model prediction for all nodes (including unlabeled nodes); [# of nodes, # of classes]
        edge_index: ; [2, # of edges]
        label: ; [# of nodes]
        train_mask: ; [# of nodes]
        aggregator: function (below)
        class_num_list: the number of nodes for each class; [# of classes]
        epoch: current epoch; integer
    Output:
        output: adjusted logits
    """

    # Compute ACM and ADM
    if args.tam and epoch > args.warmup:
        acm, adm = compute_tam(output, edge_index, label, train_mask, aggregator, \
                               class_num_list=class_num_list, temp_phi=args.temp_phi, temp_gamma=0.4)

    output = output[train_mask]
    # Adjust outputs
    if args.tam and epoch > args.warmup:
        acm = acm.log()
        adm = - adm
        output = output + args.tam_alpha*acm + args.tam_beta*adm

    return output


class MeanAggregation(MessagePassing):
    def __init__(self):
        super(MeanAggregation, self).__init__(aggr='mean')

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        _edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 4-5: Start propagating messages.
        return self.propagate(_edge_index, x=x)
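
A minimal end-to-end sketch of the margin adjustment on a toy two-class graph — the `args` fields mirror `args.py`, and all values here are arbitrary:

```python
from types import SimpleNamespace
import torch
from models.tam import adjust_output, MeanAggregation

args = SimpleNamespace(tam=True, warmup=0, temp_phi=1.2, tam_alpha=2.5, tam_beta=0.5)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 0, 3, 2, 5, 4]])
label = torch.tensor([0, 0, 1, 1, 0, 1])
train_mask = torch.tensor([True, True, True, True, False, False])
output = torch.randn(6, 2)  # raw logits for all 6 nodes

adjusted = adjust_output(args, output, edge_index, label, train_mask,
                         MeanAggregation(), class_num_list=[2, 2], epoch=1)
# `adjusted` holds logits for the 4 labeled nodes with ACM/ADM margins applied.
```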
--------------------------------------------------------------------------------
/ens_nets/sage.py:
--------------------------------------------------------------------------------
"""
PyTorch Geometric
Ref: https://github.com/pyg-team/pytorch_geometric/blob/97d55577f1d0bf33c1bfbe0ef864923ad5cb844d/torch_geometric/nn/conv/sage_conv.py
"""

from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, Size, OptTensor, PairTensor

import torch
from torch import Tensor
from torch.nn import Linear
import torch.nn as nn
import torch.nn.functional as F
import math
import scipy
import numpy as np

from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add

class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper
    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j
    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 root_weight: bool = True,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'mean')
        super(SAGEConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.temp_weight = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.temp_weight.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.temp_weight(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

class GraphSAGE1(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1):
        super(GraphSAGE1, self).__init__()
        self.conv1 = SAGEConv(nfeat, nclass)

        self.reg_params = []
        self.non_reg_params = self.conv1.parameters()

    def forward(self, x, adj, edge_weight=None):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index, edge_weight))

        return x


class GraphSAGE2(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2):
        super(GraphSAGE2, self).__init__()
        self.conv1 = SAGEConv(nfeat, nhid)
        self.conv2 = SAGEConv(nhid, nclass)
        self.dropout_p = dropout

        self.reg_params = list(self.conv1.parameters())
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj, edge_weight=None):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return x


class GraphSAGEX(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=3):
        super(GraphSAGEX, self).__init__()
        self.conv1 = SAGEConv(nfeat, nhid)
        self.conv2 = SAGEConv(nhid, nclass)
        self.convx = nn.ModuleList([SAGEConv(nhid, nhid) for _ in range(nlayer-2)])
        self.dropout_p = dropout

        self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters())
        self.non_reg_params = self.conv2.parameters()

    def forward(self, x, adj, edge_weight=None):
        edge_index = adj
        x = F.relu(self.conv1(x, edge_index, edge_weight))

        for iter_layer in self.convx:
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = F.relu(iter_layer(x, edge_index, edge_weight))

        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        return x

def create_sage(nfeat, nhid, nclass, dropout, nlayer):
    if nlayer == 1:
        model = GraphSAGE1(nfeat, nhid, nclass, dropout, nlayer)
    elif nlayer == 2:
        model = GraphSAGE2(nfeat, nhid, nclass, dropout, nlayer)
    else:
        model = GraphSAGEX(nfeat, nhid, nclass, dropout, nlayer)
    return model
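
The only interface difference from `nets/sage.py` is the extra `edge_weight` argument threaded through this custom `SAGEConv` (passing `None` reproduces plain mean aggregation). A toy forward pass:

```python
import torch
from ens_nets.sage import create_sage

model = create_sage(nfeat=8, nhid=16, nclass=3, dropout=0.5, nlayer=2)
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
logits = model(x, edge_index, edge_weight=None)  # shape [4, 3]
```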
--------------------------------------------------------------------------------
/dataset/WikipediaNetwork.py:
--------------------------------------------------------------------------------
from typing import Optional, Callable

import os.path as osp

import torch
import numpy as np

from torch_sparse import coalesce
from torch_geometric.data import InMemoryDataset, download_url, Data


class WikipediaNetwork(InMemoryDataset):
    r"""The Wikipedia networks introduced in the
    `"Multi-scale Attributed Node Embedding"
    <https://arxiv.org/abs/1909.13021>`_ paper.
    Nodes represent web pages and edges represent hyperlinks between them.
    Node features represent several informative nouns in the Wikipedia pages.
    The task is to predict the average daily traffic of the web page.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"chameleon"`,
            :obj:`"crocodile"`, :obj:`"squirrel"`).
        geom_gcn_preprocess (bool): If set to :obj:`True`, will load the
            pre-processed data as introduced in the `"Geom-GCN: Geometric
            Graph Convolutional Networks" <https://openreview.net/forum?id=S1e2agrFvS>`_ paper,
            in which the average monthly traffic of the web page is converted
            into five categories to predict.
            If set to :obj:`True`, the dataset :obj:`"crocodile"` is not
            available.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)

    """

    raw_url = 'https://graphmining.ai/datasets/ptg/wiki'
    processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/'
                     'geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f')

    def __init__(self, root: str, name: str, geom_gcn_preprocess: bool = True,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        self.name = name.lower()
        self.geom_gcn_preprocess = geom_gcn_preprocess
        assert self.name in ['chameleon', 'crocodile', 'squirrel']
        if geom_gcn_preprocess and self.name == 'crocodile':
            raise AttributeError("The dataset 'crocodile' is not available in "
                                 "case 'geom_gcn_preprocess=True'")
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        if self.geom_gcn_preprocess:
            return osp.join(self.root, self.name, 'geom_gcn', 'raw')
        else:
            return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        if self.geom_gcn_preprocess:
            return osp.join(self.root, self.name, 'geom_gcn', 'processed')
        else:
            return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> str:
        if self.geom_gcn_preprocess:
            return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] +
                    [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)])
        else:
            return f'{self.name}.npz'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self):
        if self.geom_gcn_preprocess:
            for filename in self.raw_file_names[:2]:
                url = f'{self.processed_url}/new_data/{self.name}/{filename}'
                download_url(url, self.raw_dir)
            for filename in self.raw_file_names[2:]:
                url = f'{self.processed_url}/splits/{filename}'
                download_url(url, self.raw_dir)
        else:
            download_url(f'{self.raw_url}/{self.name}.npz', self.raw_dir)

    def process(self):
        if self.geom_gcn_preprocess:
            with open(self.raw_paths[0], 'r') as f:
                data = f.read().split('\n')[1:-1]
            x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
            x = torch.tensor(x, dtype=torch.float)
            y = [int(r.split('\t')[2]) for r in data]
            y = torch.tensor(y, dtype=torch.long)

            with open(self.raw_paths[1], 'r') as f:
                data = f.read().split('\n')[1:-1]
                data = [[int(v) for v in r.split('\t')] for r in data]
                edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
                edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))

            # data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
            # x = torch.from_numpy(data['features']).to(torch.float)
            # edge_index = torch.from_numpy(data['edges']).to(torch.long)
            # edge_index = edge_index.t().contiguous()
            # edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
            # y = torch.from_numpy(data['target']).to(torch.float)

            train_masks, val_masks, test_masks = [], [], []
            for filepath in self.raw_paths[2:]:
                f = np.load(filepath)
                train_masks += [torch.from_numpy(f['train_mask'])]
                val_masks += [torch.from_numpy(f['val_mask'])]
                test_masks += [torch.from_numpy(f['test_mask'])]
            train_mask = torch.stack(train_masks, dim=1).to(torch.bool)
            val_mask = torch.stack(val_masks, dim=1).to(torch.bool)
            test_mask = torch.stack(test_masks, dim=1).to(torch.bool)

            data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                        val_mask=val_mask, test_mask=test_mask)

        else:
            data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
            x = torch.from_numpy(data['features']).to(torch.float)
            edge_index = torch.from_numpy(data['edges']).to(torch.long)
            edge_index = edge_index.t().contiguous()
            edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
            y = torch.from_numpy(data['target']).to(torch.float)

            data = Data(x=x, edge_index=edge_index, y=y)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(self.collate([data]), self.processed_paths[0])
--------------------------------------------------------------------------------
/models/renode.py:
--------------------------------------------------------------------------------
"""
Code Reference: https://github.com/victorchen96/ReNode/blob/main/transductive/imb_loss.py
and https://github.com/victorchen96/ReNode/blob/main/transductive/load_data.py
"""


import numpy as np
# from sklearn.metrics import pairwise_distances
from scipy.sparse import coo_matrix
import random
import copy
import sys
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def focal_loss(labels, logits, alpha, gamma):

    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 +
            torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss, dim=1)

    return focal_loss

class IMB_LOSS:
    def __init__(self, loss_name, data, idx_info, factor_focal, factor_cb):
        self.loss_name = loss_name
        self.device = device
        self.cls_num = data.num_classes

        # train_size = [len(x) for x in data.train_node]
        train_size = [len(x) for x in idx_info]
        train_size_arr = np.array(train_size)
        train_size_mean = np.mean(train_size_arr)
        train_size_factor = train_size_mean / train_size_arr

        # alpha in re-weight
        self.factor_train = torch.from_numpy(train_size_factor).type(torch.FloatTensor)

        # gamma in focal
        self.factor_focal = factor_focal

        # beta in CB
        weights = torch.from_numpy(np.array([1.0 for _ in range(self.cls_num)])).float()

        if self.loss_name == 'focal':
            weights = self.factor_train

        if self.loss_name == 'cb-softmax':
            beta = factor_cb
            effective_num = 1.0 - np.power(beta, train_size_arr)
            weights = (1.0 - beta) / np.array(effective_num)
            weights = weights / np.sum(weights) * self.cls_num
            weights = torch.tensor(weights).float()

        self.weights = weights.unsqueeze(0).to(device)


    def compute(self, pred, target):

        if self.loss_name == 'ce':
            return F.cross_entropy(pred, target, weight=None, reduction='none')

        elif self.loss_name == 're-weight':
            return F.cross_entropy(pred, target, weight=self.factor_train.to(self.device), reduction='none')

        elif self.loss_name == 'focal':
            labels_one_hot = F.one_hot(target, self.cls_num).type(torch.FloatTensor).to(self.device)
            weights = self.weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
            weights = weights.sum(1)
            weights = weights.unsqueeze(1)
            weights = weights.repeat(1, self.cls_num)

            return focal_loss(labels_one_hot, pred, weights, self.factor_focal)

        elif self.loss_name == 'cb-softmax':
            labels_one_hot = F.one_hot(target, self.cls_num).type(torch.FloatTensor).to(self.device)
            weights = self.weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
            weights = weights.sum(1)
            weights = weights.unsqueeze(1)
            weights = weights.repeat(1, self.cls_num)

            pred = pred.softmax(dim=1)
            temp_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights, reduction='none')
            return torch.mean(temp_loss, dim=1)

        else:
            raise Exception("Not Implemented Loss!")


def set_seed(seed, cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    if cuda: torch.cuda.manual_seed(seed)

def index2dense(edge_index, nnode=2708):

    indx = edge_index.cpu().numpy()
    adj = np.zeros((nnode, nnode), dtype='int8')
    adj[(indx[0], indx[1])] = 1
    new_adj = torch.from_numpy(adj).float()
    return new_adj

def index2adj(inf, nnode=2708):

    indx = inf.numpy()
    print(nnode)
    adj = np.zeros((nnode, nnode), dtype='int8')
    adj[(indx[0], indx[1])] = 1
    return adj

def adj2index(inf):

    where_new = np.where(inf > 0)
    new_edge = [where_new[0], where_new[1]]
    new_edge_tensor = torch.from_numpy(np.array(new_edge))
    return new_edge_tensor

def log_opt(opt, log_writer):
    for arg in vars(opt): log_writer.write("{}:{}\n".format(arg, getattr(opt, arg)))

def to_inverse(in_list, t=1):

    in_arr = np.array(in_list)
    in_mean = np.mean(in_arr)
    out_arr = in_mean / in_arr
    out_arr = np.power(out_arr, t)

    return out_arr


def get_renode_weight(data, data_train_mask, base_weight, max_weight):

    ## hyperparams ##
    rn_base_weight = base_weight
    rn_scale_weight = max_weight - base_weight
    assert rn_scale_weight in [0.5, 0.75, 1.0, 1.25, 1.5]

    ppr_matrix = data.Pi                          # personalized PageRank
    gpr_matrix = torch.tensor(data.gpr).float()   # class-accumulated personalized PageRank

    base_w = rn_base_weight
    scale_w = rn_scale_weight
    nnode = ppr_matrix.size(0)
    unlabel_mask = data_train_mask.int().ne(1)    # unlabeled nodes

    # computing the Totoro values for labeled nodes
    gpr_sum = torch.sum(gpr_matrix, dim=1)
    gpr_rn = gpr_sum.unsqueeze(1) - gpr_matrix

    label_matrix = F.one_hot(data.y, gpr_matrix.size(1)).float()
    label_matrix[unlabel_mask] = 0
    rn_matrix = torch.mm(ppr_matrix, gpr_rn).to(label_matrix.device)
    rn_matrix = torch.sum(rn_matrix * label_matrix, dim=1)
    rn_matrix[unlabel_mask] = rn_matrix.max() + 99  # exclude the influence of unlabeled nodes

    # computing the ReNode weight
    train_size = torch.sum(data_train_mask.int()).item()
    totoro_list = rn_matrix.tolist()
    id2totoro = {i: totoro_list[i] for i in range(len(totoro_list))}
    sorted_totoro = sorted(id2totoro.items(), key=lambda x: x[1], reverse=False)
    id2rank = {sorted_totoro[i][0]: i for i in range(nnode)}
    totoro_rank = [id2rank[i] for i in range(nnode)]

    rn_weight = [(base_w + 0.5 * scale_w * (1 + math.cos(x * 1.0 * math.pi / (train_size - 1)))) for x in totoro_rank]
    rn_weight = torch.from_numpy(np.array(rn_weight)).type(torch.FloatTensor)
    rn_weight = rn_weight.to(data_train_mask.device)
    rn_weight = rn_weight * data_train_mask.float()

    return rn_weight
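
A minimal sketch of the `IMB_LOSS` helper above, using a stand-in for the PyG data object (counts and shapes are made up):

```python
from types import SimpleNamespace
import torch
from models.renode import IMB_LOSS, device

data = SimpleNamespace(num_classes=3)  # stand-in: only num_classes is read here
idx_info = [torch.arange(20), torch.arange(10), torch.arange(2)]  # 20/10/2 train nodes per class
criterion = IMB_LOSS('re-weight', data, idx_info, factor_focal=2.0, factor_cb=0.9999)

pred = torch.randn(5, 3).to(device)
target = torch.tensor([0, 1, 2, 0, 1]).to(device)
loss = criterion.compute(pred, target).mean()  # per-node losses, then averaged
```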
data.val_mask[:,r%10].clone(), data.test_mask[:,r%10].clone() 99 | else: 100 | data_train_mask, data_val_mask, data_test_mask = data.train_mask.clone(), data.val_mask.clone(), data.test_mask.clone() 101 | 102 | ## Data statistic ## 103 | stats = data.y[data_train_mask] 104 | n_data = [] 105 | for i in range(n_cls): 106 | data_num = (stats == i).sum() 107 | n_data.append(int(data_num.item())) 108 | idx_info = get_idx_info(data.y, n_cls, data_train_mask) 109 | class_num_list = n_data 110 | 111 | # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced 112 | imb_class_num = n_cls // 2 113 | new_class_num_list = [] 114 | max_num = np.max(class_num_list[:n_cls-imb_class_num]) 115 | for i in range(n_cls): 116 | if args.imb_ratio > 1 and i > n_cls-1-imb_class_num: #only imbalance the last classes 117 | new_class_num_list.append(min(int(max_num*(1./args.imb_ratio)), class_num_list[i])) 118 | else: 119 | new_class_num_list.append(class_num_list[i]) 120 | class_num_list = new_class_num_list 121 | 122 | if args.imb_ratio > 1: 123 | data_train_mask, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device) 124 | 125 | ## Re-weight method ## 126 | class_weight = get_weight(args.reweight, class_num_list).to(device) 127 | 128 | ## Model Selection ## 129 | if args.net == 'GCN': 130 | model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim, 131 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 132 | elif args.net == 'GAT': 133 | model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim, 134 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 135 | elif args.net == "SAGE": 136 | model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim, 137 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 138 | else: 139 | raise NotImplementedError("Not Implemented Architecture!") 140 | 141 | ## Criterion Selection ## 142 | if args.loss_type == 'ce': # CE 143 | criterion = CrossEntropy() 144 | else: 145 | raise NotImplementedError("Not Implemented Loss!") 146 | 147 | model = model.to(device) 148 | criterion = criterion.to(device) 149 | 150 | # Set optimizer 151 | optimizer = torch.optim.Adam([ 152 | dict(params=model.reg_params, weight_decay=5e-4), 153 | dict(params=model.non_reg_params, weight_decay=0),], lr=0.01) 154 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 155 | factor = 0.5, 156 | patience = 100, 157 | verbose=False) 158 | 159 | # Train models 160 | best_val_acc_f1 = 0 161 | aggregator = MeanAggregation() 162 | for epoch in range(1, 2001): 163 | 164 | train() 165 | accs, baccs, f1s = test() 166 | train_acc, val_acc, tmp_test_acc = accs 167 | train_f1, val_f1, tmp_test_f1 = f1s 168 | val_acc_f1 = (val_acc + val_f1) / 2. 
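# Model selection: the epoch with the best average of validation accuracy and macro-F1 determines which test metrics are reported.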
169 | if val_acc_f1 > best_val_acc_f1: 170 | best_val_acc_f1 = val_acc_f1 171 | test_acc = accs[2] 172 | test_bacc = baccs[2] 173 | test_f1 = f1s[2] 174 | 175 | avg_val_acc_f1.append(best_val_acc_f1) 176 | avg_test_acc.append(test_acc) 177 | avg_test_bacc.append(test_bacc) 178 | avg_test_f1.append(test_f1) 179 | 180 | ## Calculate statistics ## 181 | acc_CI = (statistics.stdev(avg_test_acc) / (repeatition ** (1/2))) 182 | bacc_CI = (statistics.stdev(avg_test_bacc) / (repeatition ** (1/2))) 183 | f1_CI = (statistics.stdev(avg_test_f1) / (repeatition ** (1/2))) 184 | avg_acc = statistics.mean(avg_test_acc) 185 | avg_bacc = statistics.mean(avg_test_bacc) 186 | avg_f1 = statistics.mean(avg_test_f1) 187 | avg_val_acc_f1 = statistics.mean(avg_val_acc_f1) 188 | 189 | avg_log = 'Test Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}, Val Acc F1: {:.4f}' 190 | avg_log = avg_log.format(avg_acc ,acc_CI ,avg_bacc, bacc_CI, avg_f1, f1_CI, avg_val_acc_f1) 191 | log = "{}\n{}".format(setting_log, avg_log) 192 | print(log) -------------------------------------------------------------------------------- /main_pc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Our code is based on GraphENS: 3 | https://github.com/JoonHyung-Park/GraphENS 4 | """ 5 | 6 | import os.path as osp 7 | import random 8 | import torch 9 | import torch.nn.functional as F 10 | from nets import * 11 | from data_utils import * 12 | from args import parse_args 13 | from models import * 14 | from losses import * 15 | from sklearn.metrics import balanced_accuracy_score, f1_score 16 | import statistics 17 | import numpy as np 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | ## Arg Parser ## 22 | args = parse_args() 23 | 24 | ## Handling exception from arguments ## 25 | assert not (args.warmup < 1 and args.tam) 26 | # assert args.imb_ratio > 1 27 | 28 | ## Load Dataset ## 29 | dataset = args.dataset 30 | path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset) 31 | dataset = get_dataset(dataset, path, split_type='public') 32 | data = dataset[0] 33 | n_cls = data.y.max().item() + 1 34 | data = data.to(device) 35 | 36 | 37 | def train(): 38 | global class_num_list, aggregator 39 | global data_train_mask, data_val_mask, data_test_mask 40 | 41 | model.train() 42 | optimizer.zero_grad() 43 | 44 | output = model(data.x, data.edge_index) 45 | 46 | criterion(output[data_train_mask], data.y[data_train_mask], weight=class_weight).backward() 47 | 48 | with torch.no_grad(): 49 | model.eval() 50 | output = model(data.x, data.edge_index) 51 | val_loss= F.cross_entropy(output[data_val_mask], data.y[data_val_mask]) 52 | 53 | optimizer.step() 54 | scheduler.step(val_loss) 55 | 56 | 57 | @torch.no_grad() 58 | def test(): 59 | model.eval() 60 | logits = model(data.x, data.edge_index) 61 | accs, baccs, f1s = [], [], [] 62 | 63 | if args.pc_softmax: 64 | logits = pc_softmax(logits, class_num_list) 65 | 66 | for i, mask in enumerate([data_train_mask, data_val_mask, data_test_mask]): 67 | pred = logits[mask].max(1)[1] 68 | y_pred = pred.cpu().numpy() 69 | y_true = data.y[mask].cpu().numpy() 70 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 71 | bacc = balanced_accuracy_score(y_true, y_pred) 72 | f1 = f1_score(y_true, y_pred, average='macro') 73 | 74 | accs.append(acc) 75 | baccs.append(bacc) 76 | f1s.append(f1) 77 | 78 | return accs, baccs, f1s 79 | 80 | 81 | ## Log for Experiment Setting ## 82 | setting_log = "Dataset: {}, ratio: 
{}, net: {}, n_layer: {}, feat_dim: {}, tam: {}".format( 83 | args.dataset, str(args.imb_ratio), args.net, str(args.n_layer), str(args.feat_dim), str(args.tam)) 84 | 85 | repeatition = 10 86 | seed = 100 87 | avg_val_acc_f1, avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], [], [] 88 | for r in range(repeatition): 89 | 90 | ## Fix seed ## 91 | torch.cuda.empty_cache() 92 | seed += 1 93 | torch.manual_seed(seed) 94 | torch.cuda.manual_seed(seed) 95 | torch.backends.cudnn.deterministic = True 96 | torch.backends.cudnn.benchmark = False 97 | random.seed(seed) 98 | np.random.seed(seed) 99 | 100 | if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']: 101 | data_train_mask, data_val_mask, data_test_mask = data.train_mask[:,r%10].clone(), data.val_mask[:,r%10].clone(), data.test_mask[:,r%10].clone() 102 | else: 103 | data_train_mask, data_val_mask, data_test_mask = data.train_mask.clone(), data.val_mask.clone(), data.test_mask.clone() 104 | 105 | ## Data statistic ## 106 | stats = data.y[data_train_mask] 107 | n_data = [] 108 | for i in range(n_cls): 109 | data_num = (stats == i).sum() 110 | n_data.append(int(data_num.item())) 111 | idx_info = get_idx_info(data.y, n_cls, data_train_mask) 112 | class_num_list = n_data 113 | 114 | # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced 115 | imb_class_num = n_cls // 2 116 | new_class_num_list = [] 117 | max_num = np.max(class_num_list[:n_cls-imb_class_num]) 118 | for i in range(n_cls): 119 | if args.imb_ratio > 1 and i > n_cls-1-imb_class_num: #only imbalance the last classes 120 | new_class_num_list.append(min(int(max_num*(1./args.imb_ratio)), class_num_list[i])) 121 | else: 122 | new_class_num_list.append(class_num_list[i]) 123 | class_num_list = new_class_num_list 124 | 125 | if args.imb_ratio > 1: 126 | data_train_mask, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device) 127 | 128 | ## Re-weight method ## 129 | class_weight = get_weight(args.reweight, class_num_list).to(device) 130 | 131 | ## Model Selection ## 132 | if args.net == 'GCN': 133 | model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim, 134 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 135 | elif args.net == 'GAT': 136 | model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim, 137 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 138 | elif args.net == "SAGE": 139 | model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim, 140 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 141 | else: 142 | raise NotImplementedError("Not Implemented Architecture!") 143 | 144 | ## Criterion Selection ## 145 | if args.loss_type == 'ce': # CE 146 | criterion = CrossEntropy() 147 | else: 148 | raise NotImplementedError("Not Implemented Loss!") 149 | 150 | model = model.to(device) 151 | criterion = criterion.to(device) 152 | 153 | # Set optimizer 154 | optimizer = torch.optim.Adam([ 155 | dict(params=model.reg_params, weight_decay=5e-4), 156 | dict(params=model.non_reg_params, weight_decay=0),], lr=0.01) 157 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 158 | factor = 0.5, 159 | patience = 100, 160 | verbose=False) 161 | 162 | # Train models 163 | best_val_acc_f1 = 0 164 | aggregator = MeanAggregation() 165 | for epoch in range(1, 2001): 166 | 167 | train() 168 | accs, baccs, f1s = test() 169 | train_acc, val_acc, tmp_test_acc = accs 170 | train_f1, val_f1, tmp_test_f1 = f1s 171 | val_acc_f1 = (val_acc + val_f1) / 2. 
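# Note: when args.pc_softmax is set, test() has already applied the post-hoc logit adjustment, so these metrics reflect the balanced prior; training itself uses plain CE.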
172 | if val_acc_f1 > best_val_acc_f1: 173 | best_val_acc_f1 = val_acc_f1 174 | test_acc = accs[2] 175 | test_bacc = baccs[2] 176 | test_f1 = f1s[2] 177 | 178 | avg_val_acc_f1.append(best_val_acc_f1) 179 | avg_test_acc.append(test_acc) 180 | avg_test_bacc.append(test_bacc) 181 | avg_test_f1.append(test_f1) 182 | 183 | ## Calculate statistics ## 184 | acc_CI = (statistics.stdev(avg_test_acc) / (repeatition ** (1/2))) 185 | bacc_CI = (statistics.stdev(avg_test_bacc) / (repeatition ** (1/2))) 186 | f1_CI = (statistics.stdev(avg_test_f1) / (repeatition ** (1/2))) 187 | avg_acc = statistics.mean(avg_test_acc) 188 | avg_bacc = statistics.mean(avg_test_bacc) 189 | avg_f1 = statistics.mean(avg_test_f1) 190 | avg_val_acc_f1 = statistics.mean(avg_val_acc_f1) 191 | 192 | avg_log = 'Test Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}, Val Acc F1: {:.4f}' 193 | avg_log = avg_log.format(avg_acc ,acc_CI ,avg_bacc, bacc_CI, avg_f1, f1_CI, avg_val_acc_f1) 194 | log = "{}\n{}".format(setting_log, avg_log) 195 | print(log) -------------------------------------------------------------------------------- /main_bs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Our code is based on GraphENS: 3 | https://github.com/JoonHyung-Park/GraphENS 4 | """ 5 | 6 | import os.path as osp 7 | import random 8 | import torch 9 | import torch.nn.functional as F 10 | from nets import * 11 | from data_utils import * 12 | from args import parse_args 13 | from models import * 14 | from losses import * 15 | from sklearn.metrics import balanced_accuracy_score, f1_score 16 | import statistics 17 | import numpy as np 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | ## Arg Parser ## 22 | args = parse_args() 23 | 24 | ## Handling exception from arguments ## 25 | assert not (args.warmup < 1 and args.tam) 26 | # assert args.imb_ratio > 1 27 | 28 | ## Load Dataset ## 29 | dataset = args.dataset 30 | path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset) 31 | dataset = get_dataset(dataset, path, split_type='public') 32 | data = dataset[0] 33 | n_cls = data.y.max().item() + 1 34 | data = data.to(device) 35 | 36 | 37 | def train(): 38 | global class_num_list, aggregator 39 | global data_train_mask, data_val_mask, data_test_mask 40 | 41 | model.train() 42 | optimizer.zero_grad() 43 | 44 | output = model(data.x, data.edge_index) 45 | 46 | ## Apply TAM ## 47 | output = adjust_output(args, output, data.edge_index, data.y, \ 48 | data_train_mask, aggregator, class_num_list, epoch) 49 | 50 | criterion(output, data.y[data_train_mask]).backward() 51 | 52 | with torch.no_grad(): 53 | model.eval() 54 | output = model(data.x, data.edge_index) 55 | val_loss= F.cross_entropy(output[data_val_mask], data.y[data_val_mask]) 56 | 57 | optimizer.step() 58 | scheduler.step(val_loss) 59 | 60 | 61 | @torch.no_grad() 62 | def test(): 63 | model.eval() 64 | logits = model(data.x, data.edge_index) 65 | accs, baccs, f1s = [], [], [] 66 | 67 | for i, mask in enumerate([data_train_mask, data_val_mask, data_test_mask]): 68 | pred = logits[mask].max(1)[1] 69 | y_pred = pred.cpu().numpy() 70 | y_true = data.y[mask].cpu().numpy() 71 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 72 | bacc = balanced_accuracy_score(y_true, y_pred) 73 | f1 = f1_score(y_true, y_pred, average='macro') 74 | 75 | accs.append(acc) 76 | baccs.append(bacc) 77 | f1s.append(f1) 78 | 79 | return accs, baccs, f1s 80 | 81 | 82 | ## Log for Experiment Setting ## 83 | 
setting_log = "Dataset: {}, ratio: {}, net: {}, n_layer: {}, feat_dim: {}, tam: {}".format( 84 | args.dataset, str(args.imb_ratio), args.net, str(args.n_layer), str(args.feat_dim), str(args.tam)) 85 | 86 | repeatition = 10 87 | seed = 100 88 | avg_val_acc_f1, avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], [], [] 89 | for r in range(repeatition): 90 | 91 | ## Fix seed ## 92 | torch.cuda.empty_cache() 93 | seed += 1 94 | torch.manual_seed(seed) 95 | torch.cuda.manual_seed(seed) 96 | torch.backends.cudnn.deterministic = True 97 | torch.backends.cudnn.benchmark = False 98 | random.seed(seed) 99 | np.random.seed(seed) 100 | 101 | if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']: 102 | data_train_mask, data_val_mask, data_test_mask = data.train_mask[:,r%10].clone(), data.val_mask[:,r%10].clone(), data.test_mask[:,r%10].clone() 103 | else: 104 | data_train_mask, data_val_mask, data_test_mask = data.train_mask.clone(), data.val_mask.clone(), data.test_mask.clone() 105 | 106 | ## Data statistic ## 107 | stats = data.y[data_train_mask] 108 | n_data = [] 109 | for i in range(n_cls): 110 | data_num = (stats == i).sum() 111 | n_data.append(int(data_num.item())) 112 | idx_info = get_idx_info(data.y, n_cls, data_train_mask) 113 | class_num_list = n_data 114 | 115 | # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced 116 | imb_class_num = n_cls // 2 117 | new_class_num_list = [] 118 | max_num = np.max(class_num_list[:n_cls-imb_class_num]) 119 | for i in range(n_cls): 120 | if args.imb_ratio > 1 and i > n_cls-1-imb_class_num: #only imbalance the last classes 121 | new_class_num_list.append(min(int(max_num*(1./args.imb_ratio)), class_num_list[i])) 122 | else: 123 | new_class_num_list.append(class_num_list[i]) 124 | class_num_list = new_class_num_list 125 | 126 | if args.imb_ratio > 1: 127 | data_train_mask, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device) 128 | 129 | ## Model Selection ## 130 | if args.net == 'GCN': 131 | model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim, 132 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 133 | elif args.net == 'GAT': 134 | model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim, 135 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 136 | elif args.net == "SAGE": 137 | model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim, 138 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 139 | else: 140 | raise NotImplementedError("Not Implemented Architecture!") 141 | 142 | ## Criterion Selection ## 143 | if args.loss_type == 'ce': # CE 144 | criterion = CrossEntropy() 145 | elif args.loss_type == 'bs': 146 | criterion = BalancedSoftmax(class_num_list) 147 | else: 148 | raise NotImplementedError("Not Implemented Loss!") 149 | 150 | model = model.to(device) 151 | criterion = criterion.to(device) 152 | 153 | # Set optimizer 154 | optimizer = torch.optim.Adam([ 155 | dict(params=model.reg_params, weight_decay=5e-4), 156 | dict(params=model.non_reg_params, weight_decay=0),], lr=0.01) 157 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 158 | factor = 0.5, 159 | patience = 100, 160 | verbose=False) 161 | 162 | # Train models 163 | best_val_acc_f1 = 0 164 | aggregator = MeanAggregation() 165 | for epoch in range(1, 2001): 166 | 167 | train() 168 | accs, baccs, f1s = test() 169 | train_acc, val_acc, tmp_test_acc = accs 170 | train_f1, val_f1, tmp_test_f1 = f1s 171 | val_acc_f1 = (val_acc + val_f1) / 2. 
172 | if val_acc_f1 > best_val_acc_f1: 173 | best_val_acc_f1 = val_acc_f1 174 | test_acc = accs[2] 175 | test_bacc = baccs[2] 176 | test_f1 = f1s[2] 177 | 178 | avg_val_acc_f1.append(best_val_acc_f1) 179 | avg_test_acc.append(test_acc) 180 | avg_test_bacc.append(test_bacc) 181 | avg_test_f1.append(test_f1) 182 | 183 | ## Calculate statistics ## 184 | acc_CI = (statistics.stdev(avg_test_acc) / (repeatition ** (1/2))) 185 | bacc_CI = (statistics.stdev(avg_test_bacc) / (repeatition ** (1/2))) 186 | f1_CI = (statistics.stdev(avg_test_f1) / (repeatition ** (1/2))) 187 | avg_acc = statistics.mean(avg_test_acc) 188 | avg_bacc = statistics.mean(avg_test_bacc) 189 | avg_f1 = statistics.mean(avg_test_f1) 190 | avg_val_acc_f1 = statistics.mean(avg_val_acc_f1) 191 | 192 | avg_log = 'Test Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}, Val Acc F1: {:.4f}' 193 | avg_log = avg_log.format(avg_acc ,acc_CI ,avg_bacc, bacc_CI, avg_f1, f1_CI, avg_val_acc_f1) 194 | log = "{}\n{}".format(setting_log, avg_log) 195 | print(log) -------------------------------------------------------------------------------- /main_renode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Our code is based on ReNode: 3 | https://github.com/victorchen96/ReNode 4 | """ 5 | 6 | import os.path as osp 7 | import random 8 | import torch 9 | import torch.nn.functional as F 10 | from nets import * 11 | from data_utils import * 12 | from args import parse_args 13 | from models import * 14 | from losses import * 15 | from sklearn.metrics import balanced_accuracy_score, f1_score 16 | import statistics 17 | import numpy as np 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore") 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | ## Arg Parser ## 24 | args = parse_args() 25 | 26 | ## Handling exception from arguments ## 27 | assert not (args.warmup < 1 and args.tam) 28 | # assert args.imb_ratio > 1 29 | 30 | ## Load Dataset ## 31 | dataset = args.dataset 32 | path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset) 33 | dataset = get_dataset(dataset, path, split_type='public') 34 | data = dataset[0] 35 | n_cls = data.y.max().item() + 1 36 | data = data.to(device) 37 | 38 | 39 | def train(): 40 | global class_num_list, aggregator 41 | global data_train_mask, data_val_mask, data_test_mask 42 | 43 | model.train() 44 | optimizer.zero_grad() 45 | 46 | if args.renode: 47 | global aggregator 48 | ## ReNode ## 49 | output = model(data.x, data.edge_index) 50 | 51 | ## Apply TAM ## 52 | output = adjust_output(args, output, data.edge_index, data.y, \ 53 | data_train_mask, aggregator, class_num_list, epoch) 54 | 55 | sup_logits = output 56 | 57 | cls_loss= renode_loss.compute(sup_logits, data.y[data_train_mask].to(device)) 58 | cls_loss = torch.sum(cls_loss * data.rn_weight[data_train_mask].to(device)) / cls_loss.size(0) 59 | cls_loss.backward() 60 | ################## 61 | 62 | else: 63 | output = model(data.x, data.edge_index) 64 | 65 | ## Apply TAM ## 66 | output = adjust_output(args, output, data.edge_index, data.y, \ 67 | data_train_mask, aggregator, class_num_list, epoch) 68 | 69 | criterion(output, data.y[data_train_mask]).backward() 70 | 71 | with torch.no_grad(): 72 | model.eval() 73 | output = model(data.x, data.edge_index) 74 | val_loss= F.cross_entropy(output[data_val_mask], data.y[data_val_mask]) 75 | 76 | optimizer.step() 77 | scheduler.step(val_loss) 78 | 79 | 80 | @torch.no_grad() 81 | def test(): 82 | model.eval() 83 
| logits = model(data.x, data.edge_index) 84 | accs, baccs, f1s = [], [], [] 85 | 86 | for i, mask in enumerate([data_train_mask, data_val_mask, data_test_mask]): 87 | pred = logits[mask].max(1)[1] 88 | y_pred = pred.cpu().numpy() 89 | y_true = data.y[mask].cpu().numpy() 90 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 91 | bacc = balanced_accuracy_score(y_true, y_pred) 92 | f1 = f1_score(y_true, y_pred, average='macro') 93 | 94 | accs.append(acc) 95 | baccs.append(bacc) 96 | f1s.append(f1) 97 | 98 | return accs, baccs, f1s 99 | 100 | 101 | ## Log for Experiment Setting ## 102 | setting_log = "Dataset: {}, ratio: {}, net: {}, n_layer: {}, feat_dim: {}, tam: {}".format( 103 | args.dataset, str(args.imb_ratio), args.net, str(args.n_layer), str(args.feat_dim), str(args.tam)) 104 | 105 | repeatition = 10 106 | seed = 100 107 | avg_val_acc_f1, avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], [], [] 108 | for r in range(repeatition): 109 | 110 | ## Fix seed ## 111 | torch.cuda.empty_cache() 112 | seed += 1 113 | torch.manual_seed(seed) 114 | torch.cuda.manual_seed(seed) 115 | torch.backends.cudnn.deterministic = True 116 | torch.backends.cudnn.benchmark = False 117 | random.seed(seed) 118 | np.random.seed(seed) 119 | 120 | if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']: 121 | data_train_mask, data_val_mask, data_test_mask = data.train_mask[:,r%10].clone(), data.val_mask[:,r%10].clone(), data.test_mask[:,r%10].clone() 122 | else: 123 | data_train_mask, data_val_mask, data_test_mask = data.train_mask.clone(), data.val_mask.clone(), data.test_mask.clone() 124 | 125 | ## Data statistic ## 126 | stats = data.y[data_train_mask] 127 | n_data = [] 128 | for i in range(n_cls): 129 | data_num = (stats == i).sum() 130 | n_data.append(int(data_num.item())) 131 | idx_info = get_idx_info(data.y, n_cls, data_train_mask) 132 | class_num_list = n_data 133 | 134 | # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced 135 | imb_class_num = n_cls // 2 136 | new_class_num_list = [] 137 | max_num = np.max(class_num_list[:n_cls-imb_class_num]) 138 | for i in range(n_cls): 139 | if args.imb_ratio > 1 and i > n_cls-1-imb_class_num: #only imbalance the last classes 140 | new_class_num_list.append(min(int(max_num*(1./args.imb_ratio)), class_num_list[i])) 141 | else: 142 | new_class_num_list.append(class_num_list[i]) 143 | class_num_list = new_class_num_list 144 | 145 | if args.imb_ratio > 1: 146 | data_train_mask, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device) 147 | 148 | if args.renode: 149 | ## ReNode method ## 150 | ## hyperparam ## 151 | pagerank_prob = 0.85 152 | 153 | # calculating the Personalized PageRank Matrix 154 | pr_prob = 1 - pagerank_prob 155 | A = index2dense(data.edge_index, data.num_nodes) 156 | A_hat = A.to(device) + torch.eye(A.size(0)).to(device) # add self-loop 157 | D = torch.diag(torch.sum(A_hat,1)) 158 | D = D.inverse().sqrt() 159 | A_hat = torch.mm(torch.mm(D, A_hat), D) 160 | data.Pi = pr_prob * ((torch.eye(A.size(0)).to(device) - (1 - pr_prob) * A_hat).inverse()) 161 | data.Pi = data.Pi.cpu() 162 | 163 | 164 | # calculating the ReNode Weight 165 | gpr_matrix = [] # the class-level influence distribution 166 | data.num_classes = n_cls 167 | for iter_c in range(data.num_classes): 168 | #iter_Pi = data.Pi[torch.tensor(target_data.train_node[iter_c]).long()] 169 | iter_Pi = data.Pi[idx_info[iter_c].long()] # check! is it same with above line? 
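# idx_info[iter_c] holds the training-node indices of class iter_c, so this is equivalent to the original ReNode line commented out above.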
170 | iter_gpr = torch.mean(iter_Pi,dim=0).squeeze() 171 | gpr_matrix.append(iter_gpr) 172 | 173 | temp_gpr = torch.stack(gpr_matrix,dim=0) 174 | temp_gpr = temp_gpr.transpose(0,1) 175 | data.gpr = temp_gpr 176 | data.rn_weight = get_renode_weight(data, data_train_mask, args.rn_base,args.rn_max) #ReNode Weight 177 | renode_loss = IMB_LOSS(args.loss_name, data, idx_info,args.factor_focal, args.factor_cb) 178 | 179 | 180 | ## Model Selection ## 181 | if args.net == 'GCN': 182 | model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim, 183 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 184 | elif args.net == 'GAT': 185 | model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim, 186 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 187 | elif args.net == "SAGE": 188 | model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim, 189 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 190 | else: 191 | raise NotImplementedError("Not Implemented Architecture!") 192 | 193 | ## Criterion Selection ## 194 | if args.loss_type == 'ce': # CE 195 | criterion = CrossEntropy() 196 | elif args.loss_type == 'bs': 197 | criterion = BalancedSoftmax(class_num_list) 198 | else: 199 | raise NotImplementedError("Not Implemented Loss!") 200 | 201 | model = model.to(device) 202 | criterion = criterion.to(device) 203 | 204 | # Set optimizer 205 | optimizer = torch.optim.Adam([ 206 | dict(params=model.reg_params, weight_decay=5e-4), 207 | dict(params=model.non_reg_params, weight_decay=0),], lr=0.01) 208 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 209 | factor = 0.5, 210 | patience = 100, 211 | verbose=False) 212 | 213 | # Train models 214 | best_val_acc_f1 = 0 215 | aggregator = MeanAggregation() 216 | for epoch in range(1, 2001): 217 | 218 | train() 219 | accs, baccs, f1s = test() 220 | train_acc, val_acc, tmp_test_acc = accs 221 | train_f1, val_f1, tmp_test_f1 = f1s 222 | val_acc_f1 = (val_acc + val_f1) / 2. 
223 | if val_acc_f1 > best_val_acc_f1: 224 | best_val_acc_f1 = val_acc_f1 225 | test_acc = accs[2] 226 | test_bacc = baccs[2] 227 | test_f1 = f1s[2] 228 | 229 | avg_val_acc_f1.append(best_val_acc_f1) 230 | avg_test_acc.append(test_acc) 231 | avg_test_bacc.append(test_bacc) 232 | avg_test_f1.append(test_f1) 233 | 234 | ## Calculate statistics ## 235 | acc_CI = (statistics.stdev(avg_test_acc) / (repeatition ** (1/2))) 236 | bacc_CI = (statistics.stdev(avg_test_bacc) / (repeatition ** (1/2))) 237 | f1_CI = (statistics.stdev(avg_test_f1) / (repeatition ** (1/2))) 238 | avg_acc = statistics.mean(avg_test_acc) 239 | avg_bacc = statistics.mean(avg_test_bacc) 240 | avg_f1 = statistics.mean(avg_test_f1) 241 | avg_val_acc_f1 = statistics.mean(avg_val_acc_f1) 242 | 243 | avg_log = 'Test Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}, Val Acc F1: {:.4f}' 244 | avg_log = avg_log.format(avg_acc ,acc_CI ,avg_bacc, bacc_CI, avg_f1, f1_CI, avg_val_acc_f1) 245 | log = "{}\n{}".format(setting_log, avg_log) 246 | print(log) -------------------------------------------------------------------------------- /main_ens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Our code is based on GraphENS: 3 | https://github.com/JoonHyung-Park/GraphENS 4 | """ 5 | 6 | import os.path as osp 7 | import random 8 | import torch 9 | import torch.nn.functional as F 10 | from data_utils import * 11 | from args import parse_args 12 | from models import * 13 | from losses import * 14 | from sklearn.metrics import balanced_accuracy_score, f1_score 15 | import statistics 16 | import numpy as np 17 | import warnings 18 | 19 | warnings.filterwarnings("ignore") 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | ## Arg Parser ## 23 | args = parse_args() 24 | 25 | ## Handling exception from arguments ## 26 | assert not (args.warmup < 1 and args.tam) 27 | # assert args.imb_ratio > 1 28 | 29 | ## Load Dataset ## 30 | dataset = args.dataset 31 | path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset) 32 | dataset = get_dataset(dataset, path, split_type='public') 33 | data = dataset[0] 34 | n_cls = data.y.max().item() + 1 35 | data = data.to(device) 36 | 37 | 38 | ## For GraphENS ## 39 | def backward_hook(module, grad_input, grad_output): 40 | global saliency 41 | saliency = grad_input[0].data 42 | 43 | def tensor_hook(grad): 44 | global saliency 45 | saliency = grad.data 46 | 47 | 48 | def train(): 49 | global class_num_list, idx_info, prev_out, aggregator 50 | global data_train_mask, data_val_mask, data_test_mask 51 | 52 | model.train() 53 | optimizer.zero_grad() 54 | 55 | if args.ens: 56 | # Hook saliency map of input features 57 | model.conv1.temp_weight.register_backward_hook(backward_hook) 58 | 59 | # Sampling source and destination nodes 60 | sampling_src_idx, sampling_dst_idx = sampling_idx_individual_dst(class_num_list, idx_info, device) 61 | beta = torch.distributions.beta.Beta(2, 2) 62 | lam = beta.sample((len(sampling_src_idx),) ).unsqueeze(1) 63 | ori_saliency = saliency[:data.x.shape[0]] if (saliency != None) else None 64 | 65 | # Augment nodes 66 | if epoch > args.warmup: 67 | with torch.no_grad(): 68 | prev_out = aggregator(prev_out, data.edge_index) 69 | prev_out = F.softmax(prev_out / args.pred_temp, dim=1).detach().clone() 70 | new_edge_index, dist_kl = neighbor_sampling(data.x.size(0), data.edge_index, sampling_src_idx, sampling_dst_idx, 71 | neighbor_dist_list, prev_out) 72 | new_x = saliency_mixup(data.x, 
sampling_src_idx, sampling_dst_idx, lam, ori_saliency, dist_kl = dist_kl, keep_prob=args.keep_prob) 73 | else: 74 | new_edge_index = duplicate_neighbor(data.x.size(0), data.edge_index, sampling_src_idx) 75 | dist_kl, ori_saliency = None, None 76 | new_x = saliency_mixup(data.x, sampling_src_idx, sampling_dst_idx, lam, ori_saliency, dist_kl = dist_kl) 77 | new_x.requires_grad = True 78 | 79 | # Get predictions 80 | output = model(new_x, new_edge_index) 81 | prev_out = (output[:data.x.size(0)]).detach().clone() # logit propagation 82 | 83 | ## Train_mask modification ## 84 | add_num = output.shape[0] - data_train_mask.shape[0] 85 | new_train_mask = torch.ones(add_num, dtype=torch.bool, device= data.x.device) 86 | new_train_mask = torch.cat((data_train_mask, new_train_mask), dim =0) 87 | 88 | ## Label modification ## 89 | _new_y = data.y[sampling_src_idx].clone() 90 | new_y = torch.cat((data.y[data_train_mask], _new_y),dim =0) 91 | 92 | ## Apply TAM ## 93 | output = adjust_output(args, output, new_edge_index, torch.cat((data.y,_new_y),dim =0), \ 94 | new_train_mask, aggregator, class_num_list, epoch) 95 | 96 | ## Compute Loss ## 97 | criterion(output, new_y).backward() 98 | 99 | with torch.no_grad(): 100 | model.eval() 101 | output = model(data.x, data.edge_index) 102 | val_loss= F.cross_entropy(output[data_val_mask], data.y[data_val_mask]) 103 | 104 | optimizer.step() 105 | scheduler.step(val_loss) 106 | 107 | 108 | @torch.no_grad() 109 | def test(): 110 | model.eval() 111 | logits = model(data.x, data.edge_index) 112 | accs, baccs, f1s = [], [], [] 113 | 114 | for i, mask in enumerate([data_train_mask, data_val_mask, data_test_mask]): 115 | pred = logits[mask].max(1)[1] 116 | y_pred = pred.cpu().numpy() 117 | y_true = data.y[mask].cpu().numpy() 118 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 119 | bacc = balanced_accuracy_score(y_true, y_pred) 120 | f1 = f1_score(y_true, y_pred, average='macro') 121 | 122 | accs.append(acc) 123 | baccs.append(bacc) 124 | f1s.append(f1) 125 | 126 | return accs, baccs, f1s 127 | 128 | 129 | ## Log for Experiment Setting ## 130 | setting_log = "Dataset: {}, ratio: {}, net: {}, n_layer: {}, feat_dim: {}, tam: {}".format( 131 | args.dataset, str(args.imb_ratio), args.net, str(args.n_layer), str(args.feat_dim), str(args.tam)) 132 | 133 | repeatition = 10 134 | seed = 100 135 | avg_val_acc_f1, avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], [], [] 136 | for r in range(repeatition): 137 | 138 | ## Fix seed ## 139 | torch.cuda.empty_cache() 140 | seed += 1 141 | torch.manual_seed(seed) 142 | torch.cuda.manual_seed(seed) 143 | torch.backends.cudnn.deterministic = True 144 | torch.backends.cudnn.benchmark = False 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | 148 | if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']: 149 | data_train_mask, data_val_mask, data_test_mask = data.train_mask[:,r%10].clone(), data.val_mask[:,r%10].clone(), data.test_mask[:,r%10].clone() 150 | else: 151 | data_train_mask, data_val_mask, data_test_mask = data.train_mask.clone(), data.val_mask.clone(), data.test_mask.clone() 152 | 153 | ## Data statistic ## 154 | stats = data.y[data_train_mask] 155 | n_data = [] 156 | for i in range(n_cls): 157 | data_num = (stats == i).sum() 158 | n_data.append(int(data_num.item())) 159 | idx_info = get_idx_info(data.y, n_cls, data_train_mask) 160 | class_num_list = n_data 161 | 162 | # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced 163 | imb_class_num = n_cls // 2 164 | 
new_class_num_list = [] 165 | max_num = np.max(class_num_list[:n_cls-imb_class_num]) 166 | for i in range(n_cls): 167 | if args.imb_ratio > 1 and i > n_cls-1-imb_class_num: #only imbalance the last classes 168 | new_class_num_list.append(min(int(max_num*(1./args.imb_ratio)), class_num_list[i])) 169 | else: 170 | new_class_num_list.append(class_num_list[i]) 171 | class_num_list = new_class_num_list 172 | 173 | if args.imb_ratio > 1: 174 | data_train_mask, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device) 175 | 176 | if args.ens: 177 | neighbor_dist_list = get_ins_neighbor_dist(data.y.size(0), data.edge_index, data_train_mask, device) 178 | else: 179 | neighbor_dist_list = None 180 | 181 | if args.ens: # for getting saliency 182 | from ens_nets import * 183 | else: 184 | from nets import * 185 | 186 | ## Model Selection ## 187 | if args.net == 'GCN': 188 | model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim, 189 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 190 | elif args.net == 'GAT': 191 | model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim, 192 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 193 | elif args.net == "SAGE": 194 | model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim, 195 | nclass=n_cls, dropout=0.5, nlayer=args.n_layer) 196 | else: 197 | raise NotImplementedError("Not Implemented Architecture!") 198 | 199 | ## Criterion Selection ## 200 | if args.loss_type == 'ce': # CE 201 | criterion = CrossEntropy() 202 | elif args.loss_type == 'bs': 203 | criterion = BalancedSoftmax(class_num_list) 204 | else: 205 | raise NotImplementedError("Not Implemented Loss!") 206 | 207 | model = model.to(device) 208 | criterion = criterion.to(device) 209 | 210 | # Set optimizer 211 | optimizer = torch.optim.Adam([ 212 | dict(params=model.reg_params, weight_decay=5e-4), 213 | dict(params=model.non_reg_params, weight_decay=0),], lr=0.01) 214 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 215 | factor = 0.5, 216 | patience = 100, 217 | verbose=False) 218 | 219 | # Train models 220 | best_val_acc_f1 = 0 221 | saliency, prev_out = None, None 222 | aggregator = MeanAggregation() 223 | for epoch in range(1, 2001): 224 | 225 | train() 226 | accs, baccs, f1s = test() 227 | train_acc, val_acc, tmp_test_acc = accs 228 | train_f1, val_f1, tmp_test_f1 = f1s 229 | val_acc_f1 = (val_acc + val_f1) / 2. 
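# Evaluation always runs on the original graph; the synthetic nodes and edges created by GraphENS exist only inside train().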
230 | if val_acc_f1 > best_val_acc_f1: 231 | best_val_acc_f1 = val_acc_f1 232 | test_acc = accs[2] 233 | test_bacc = baccs[2] 234 | test_f1 = f1s[2] 235 | 236 | avg_val_acc_f1.append(best_val_acc_f1) 237 | avg_test_acc.append(test_acc) 238 | avg_test_bacc.append(test_bacc) 239 | avg_test_f1.append(test_f1) 240 | 241 | ## Calculate statistics ## 242 | acc_CI = (statistics.stdev(avg_test_acc) / (repeatition ** (1/2))) 243 | bacc_CI = (statistics.stdev(avg_test_bacc) / (repeatition ** (1/2))) 244 | f1_CI = (statistics.stdev(avg_test_f1) / (repeatition ** (1/2))) 245 | avg_acc = statistics.mean(avg_test_acc) 246 | avg_bacc = statistics.mean(avg_test_bacc) 247 | avg_f1 = statistics.mean(avg_test_f1) 248 | avg_val_acc_f1 = statistics.mean(avg_val_acc_f1) 249 | 250 | avg_log = 'Test Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}, Val Acc F1: {:.4f}' 251 | avg_log = avg_log.format(avg_acc ,acc_CI ,avg_bacc, bacc_CI, avg_f1, f1_CI, avg_val_acc_f1) 252 | log = "{}\n{}".format(setting_log, avg_log) 253 | print(log) -------------------------------------------------------------------------------- /models/gens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Reference: https://github.com/JoonHyung-Park/GraphENS/blob/main/models/gens.py 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch_scatter import scatter_add 10 | from torch_geometric.utils import to_dense_batch 11 | from torch_geometric.nn import MessagePassing 12 | from torch_geometric.utils import add_self_loops, degree 13 | 14 | 15 | @torch.no_grad() 16 | def get_ins_neighbor_dist(num_nodes, edge_index, train_mask, device): 17 | """ 18 | Compute adjacent node distribution. 
19 | """ 20 | ## Utilize GPU ## 21 | train_mask = train_mask.clone().to(device) 22 | edge_index = edge_index.clone().to(device) 23 | row, col = edge_index[0], edge_index[1] 24 | 25 | # Compute neighbor distribution 26 | neighbor_dist_list = [] 27 | for j in range(num_nodes): 28 | neighbor_dist = torch.zeros(num_nodes, dtype=torch.float32).to(device) 29 | 30 | idx = row[(col==j)] 31 | neighbor_dist[idx] = neighbor_dist[idx] + 1 32 | neighbor_dist_list.append(neighbor_dist) 33 | 34 | neighbor_dist_list = torch.stack(neighbor_dist_list,dim=0) 35 | neighbor_dist_list = F.normalize(neighbor_dist_list,dim=1,p=1) 36 | 37 | return neighbor_dist_list 38 | 39 | 40 | @torch.no_grad() 41 | def sampling_idx_individual_dst(class_num_list, idx_info, device): 42 | """ 43 | Samples source and target nodes 44 | """ 45 | # Selecting src & dst nodes 46 | max_num, n_cls = max(class_num_list), len(class_num_list) 47 | sampling_list = max_num * torch.ones(n_cls) - torch.tensor(class_num_list) 48 | new_class_num_list = torch.Tensor(class_num_list).to(device) 49 | 50 | # Compute # of source nodes 51 | sampling_src_idx =[cls_idx[torch.randint(len(cls_idx),(int(samp_num.item()),))] 52 | for cls_idx, samp_num in zip(idx_info, sampling_list)] 53 | sampling_src_idx = torch.cat(sampling_src_idx) 54 | 55 | # Generate corresponding destination nodes 56 | class_dst_idx= [] 57 | prob = torch.log(new_class_num_list.float())/ new_class_num_list.float() 58 | prob = prob.repeat_interleave(new_class_num_list.long()) 59 | temp_idx_info = torch.cat(idx_info) 60 | dst_idx = torch.multinomial(prob, sampling_src_idx.shape[0], True) 61 | sampling_dst_idx = temp_idx_info[dst_idx] 62 | 63 | # Sorting src idx with corresponding dst idx 64 | sampling_src_idx, sorted_idx = torch.sort(sampling_src_idx) 65 | sampling_dst_idx = sampling_dst_idx[sorted_idx] 66 | 67 | return sampling_src_idx, sampling_dst_idx 68 | 69 | 70 | def saliency_mixup(x, sampling_src_idx, sampling_dst_idx, lam, saliency=None, 71 | dist_kl = None, keep_prob = 0.3): 72 | """ 73 | Saliency-based node mixing - Mix node features 74 | Input: 75 | x: Node features; [# of nodes, input feature dimension] 76 | sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes] 77 | sampling_dst_idx: Target node index for augmented nodes; [# of augmented nodes] 78 | lam: Sampled mixing ratio; [# of augmented nodes, 1] 79 | saliency: Saliency map of input feature; [# of nodes, input feature dimension] 80 | dist_kl: KLD between source node and target node predictions; [# of augmented nodes, 1] 81 | keep_prob: Ratio of keeping source node feature; scalar 82 | Output: 83 | new_x: [# of original nodes + # of augmented nodes, feature dimension] 84 | """ 85 | total_node = x.shape[0] 86 | ## Mixup ## 87 | new_src = x[sampling_src_idx.to(x.device), :].clone() 88 | new_dst = x[sampling_dst_idx.to(x.device), :].clone() 89 | lam = lam.to(x.device) 90 | 91 | # Saliency Mixup 92 | if saliency != None: 93 | node_dim = saliency.shape[1] 94 | saliency_dst = saliency[sampling_dst_idx].abs() 95 | saliency_dst += 1e-10 96 | saliency_dst /= torch.sum(saliency_dst, dim=1).unsqueeze(1) 97 | 98 | K = int(node_dim * keep_prob) 99 | mask_idx = torch.multinomial(saliency_dst, K) 100 | lam = lam.expand(-1,node_dim).clone() 101 | if dist_kl != None: # Adaptive 102 | kl_mask = (torch.sigmoid(dist_kl/3.) 
* K).squeeze().long() 103 | idx_matrix = (torch.arange(K).unsqueeze(dim=0).to(kl_mask.device) >= kl_mask.unsqueeze(dim=1)) 104 | zero_repeat_idx = mask_idx[:,0:1].repeat(1,mask_idx.size(1)) 105 | mask_idx[idx_matrix] = zero_repeat_idx[idx_matrix] 106 | 107 | lam[torch.arange(lam.shape[0]).unsqueeze(1), mask_idx] = 1. 108 | mixed_node = lam * new_src + (1-lam) * new_dst 109 | new_x = torch.cat([x, mixed_node], dim =0) 110 | return new_x 111 | 112 | 113 | @torch.no_grad() 114 | def duplicate_neighbor(total_node, edge_index, sampling_src_idx): 115 | """ 116 | Duplicate edges of source nodes for sampled nodes. 117 | Input: 118 | total_node: # of nodes; scalar 119 | edge_index: Edge index; [2, # of edges] 120 | sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes] 121 | Output: 122 | new_edge_index: original_edge_index + duplicated_edge_index 123 | """ 124 | device = edge_index.device 125 | 126 | # Assign node index for augmented nodes 127 | row, col = edge_index[0], edge_index[1] 128 | row, sort_idx = torch.sort(row) 129 | col = col[sort_idx] 130 | degree = scatter_add(torch.ones_like(col), col) 131 | new_row =(torch.arange(len(sampling_src_idx)).to(device)+ total_node).repeat_interleave(degree[sampling_src_idx]) 132 | temp = scatter_add(torch.ones_like(sampling_src_idx), sampling_src_idx).to(device) 133 | 134 | # Duplicate the edges of source nodes 135 | node_mask = torch.zeros(total_node, dtype=torch.bool) 136 | unique_src = torch.unique(sampling_src_idx) 137 | node_mask[unique_src] = True 138 | row_mask = node_mask[row] 139 | edge_mask = col[row_mask] 140 | b_idx = torch.arange(len(unique_src)).to(device).repeat_interleave(degree[unique_src]) 141 | edge_dense, _ = to_dense_batch(edge_mask, b_idx, fill_value=-1) 142 | if len(temp[temp!=0]) != edge_dense.shape[0]: 143 | cut_num =len(temp[temp!=0]) - edge_dense.shape[0] 144 | cut_temp = temp[temp!=0][:-cut_num] 145 | else: 146 | cut_temp = temp[temp!=0] 147 | edge_dense = edge_dense.repeat_interleave(cut_temp, dim=0) 148 | new_col = edge_dense[edge_dense!= -1] 149 | inv_edge_index = torch.stack([new_col, new_row], dim=0) 150 | new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1) 151 | 152 | return new_edge_index 153 | 154 | 155 | def get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx): 156 | """ 157 | Compute KL divergence 158 | """ 159 | device = prev_out.device 160 | dist_kl = F.kl_div(torch.log(prev_out[sampling_dst_idx.to(device)]), prev_out[sampling_src_idx.to(device)], \ 161 | reduction='none').sum(dim=1,keepdim=True) 162 | dist_kl[dist_kl<0] = 0 163 | return dist_kl 164 | 165 | 166 | @torch.no_grad() 167 | def neighbor_sampling(total_node, edge_index, sampling_src_idx, sampling_dst_idx, 168 | neighbor_dist_list, prev_out, train_node_mask=None): 169 | """ 170 | Neighbor Sampling - Mix adjacent node distribution and samples neighbors from it 171 | Input: 172 | total_node: # of nodes; scalar 173 | edge_index: Edge index; [2, # of edges] 174 | sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes] 175 | sampling_dst_idx: Target node index for augmented nodes; [# of augmented nodes] 176 | neighbor_dist_list: Adjacent node distribution of whole nodes; [# of nodes, # of nodes] 177 | prev_out: Model prediction of the previous step; [# of nodes, n_cls] 178 | train_node_mask: Mask for not removed nodes; [# of nodes] 179 | Output: 180 | new_edge_index: original edge index + sampled edge index 181 | dist_kl: kl divergence of target nodes from source nodes; [# of sampling nodes, 1] 
182 | """ 183 | ## Exception Handling ## 184 | device = edge_index.device 185 | n_candidate = 1 186 | sampling_src_idx = sampling_src_idx.clone().to(device) 187 | 188 | # Find the nearest nodes and mix target pool 189 | if prev_out is not None: 190 | sampling_dst_idx = sampling_dst_idx.clone().to(device) 191 | dist_kl = get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx) 192 | ratio = F.softmax(torch.cat([dist_kl.new_zeros(dist_kl.size(0),1), -dist_kl], dim=1), dim=1) 193 | mixed_neighbor_dist = ratio[:,:1] * neighbor_dist_list[sampling_src_idx] 194 | for i in range(n_candidate): 195 | mixed_neighbor_dist += ratio[:,i+1:i+2] * neighbor_dist_list[sampling_dst_idx.unsqueeze(dim=1)[:,i]] 196 | else: 197 | mixed_neighbor_dist = neighbor_dist_list[sampling_src_idx] 198 | 199 | # Compute degree 200 | col = edge_index[1] 201 | degree = scatter_add(torch.ones_like(col), col) 202 | if len(degree) < total_node: 203 | degree = torch.cat([degree, degree.new_zeros(total_node-len(degree))],dim=0) 204 | if train_node_mask is None: 205 | train_node_mask = torch.ones_like(degree,dtype=torch.bool) 206 | degree_dist = scatter_add(torch.ones_like(degree[train_node_mask]), degree[train_node_mask]).to(device).type(torch.float32) 207 | 208 | # Sample degree for augmented nodes 209 | prob = degree_dist.unsqueeze(dim=0).repeat(len(sampling_src_idx),1) 210 | aug_degree = torch.multinomial(prob, 1).to(device).squeeze(dim=1) # (m) 211 | max_degree = degree.max().item() + 1 212 | aug_degree = torch.min(aug_degree, degree[sampling_src_idx]) 213 | 214 | # Sample neighbors 215 | new_tgt = torch.multinomial(mixed_neighbor_dist + 1e-12, max_degree) 216 | tgt_index = torch.arange(max_degree).unsqueeze(dim=0).to(device) 217 | new_col = new_tgt[(tgt_index - aug_degree.unsqueeze(dim=1) < 0)] 218 | new_row = (torch.arange(len(sampling_src_idx)).to(device)+ total_node) 219 | new_row = new_row.repeat_interleave(aug_degree) 220 | inv_edge_index = torch.stack([new_col, new_row], dim=0) 221 | new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1) 222 | 223 | return new_edge_index, dist_kl 224 | 225 | 226 | class MeanAggregation(MessagePassing): 227 | def __init__(self): 228 | super(MeanAggregation, self).__init__(aggr='mean') 229 | 230 | def forward(self, x, edge_index): 231 | # x has shape [N, in_channels] 232 | # edge_index has shape [2, E] 233 | 234 | # Step 1: Add self-loops to the adjacency matrix. 235 | edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 236 | 237 | # Step 4-5: Start propagating messages. 
238 | return self.propagate(edge_index, x=x) -------------------------------------------------------------------------------- /ens_nets/gcn.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Geometric 3 | Ref: https://github.com/pyg-team/pytorch_geometric/blob/97d55577f1d0bf33c1bfbe0ef864923ad5cb844d/torch_geometric/nn/conv/gcn_conv.py 4 | """ 5 | from typing import Optional, Tuple 6 | from torch_geometric.typing import Adj, OptTensor, PairTensor 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | import scipy 13 | import numpy as np 14 | 15 | from torch import Tensor 16 | from torch.nn import Parameter 17 | from torch_scatter import scatter_add 18 | from torch_sparse import SparseTensor, matmul, fill_diag, sum, mul 19 | from torch_geometric.nn.conv import MessagePassing 20 | from torch_geometric.utils import add_remaining_self_loops, to_dense_batch 21 | from torch_geometric.utils.num_nodes import maybe_num_nodes 22 | from torch_geometric.nn.inits import reset, glorot, zeros 23 | 24 | def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, 25 | add_self_loops=True, dtype=None): 26 | 27 | fill_value = 2. if improved else 1. 28 | 29 | if isinstance(edge_index, SparseTensor): 30 | adj_t = edge_index 31 | if not adj_t.has_value(): 32 | adj_t = adj_t.fill_value(1., dtype=dtype) 33 | if add_self_loops: 34 | adj_t = fill_diag(adj_t, fill_value) 35 | deg = sum(adj_t, dim=1) 36 | deg_inv_sqrt = deg.pow_(-0.5) 37 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) 38 | adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1)) 39 | adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) 40 | return adj_t 41 | 42 | else: 43 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 44 | 45 | if edge_weight is None: 46 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, 47 | device=edge_index.device) 48 | 49 | if add_self_loops: 50 | edge_index, tmp_edge_weight = add_remaining_self_loops( 51 | edge_index, edge_weight, fill_value, num_nodes) 52 | assert tmp_edge_weight is not None 53 | edge_weight = tmp_edge_weight 54 | 55 | row, col = edge_index[0], edge_index[1] 56 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 57 | deg_inv_sqrt = deg.pow_(-0.5) 58 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) 59 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] # symmetric normalization D^{-1/2} A D^{-1/2} 60 | 61 | class GCNConv(MessagePassing): 62 | r"""The graph convolutional operator from the `"Semi-supervised 63 | Classification with Graph Convolutional Networks" 64 | <https://arxiv.org/abs/1609.02907>`_ paper 65 | .. math:: 66 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 67 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 68 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 69 | adjacency matrix with inserted self-loops and 70 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 71 | The adjacency matrix can include other values than :obj:`1` representing 72 | edge weights via the optional :obj:`edge_weight` tensor. 73 | Its node-wise formulation is given by: 74 | .. 
math:: 75 | \mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j \in \mathcal{N}(v) \cup 76 | \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j 77 | with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where 78 | :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target 79 | node :obj:`i` (default: :obj:`1.0`) 80 | Args: 81 | in_channels (int): Size of each input sample. 82 | out_channels (int): Size of each output sample. 83 | improved (bool, optional): If set to :obj:`True`, the layer computes 84 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 85 | (default: :obj:`False`) 86 | cached (bool, optional): If set to :obj:`True`, the layer will cache 87 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 88 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the 89 | cached version for further executions. 90 | This parameter should only be set to :obj:`True` in transductive 91 | learning scenarios. (default: :obj:`False`) 92 | add_self_loops (bool, optional): If set to :obj:`False`, will not add 93 | self-loops to the input graph. (default: :obj:`True`) 94 | normalize (bool, optional): Whether to add self-loops and compute 95 | symmetric normalization coefficients on the fly. 96 | (default: :obj:`True`) 97 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 98 | an additive bias. (default: :obj:`True`) 99 | **kwargs (optional): Additional arguments of 100 | :class:`torch_geometric.nn.conv.MessagePassing`. 101 | """ 102 | 103 | _cached_edge_index: Optional[Tuple[Tensor, Tensor]] 104 | _cached_adj_t: Optional[SparseTensor] 105 | 106 | def __init__(self, in_channels: int, out_channels: int, 107 | improved: bool = False, cached: bool = False, 108 | normalize: bool = True, bias: bool = True, **kwargs): 109 | 110 | kwargs.setdefault('aggr', 'add') 111 | super(GCNConv, self).__init__(**kwargs) 112 | 113 | self.in_channels = in_channels 114 | self.out_channels = out_channels 115 | self.improved = improved 116 | self.cached = cached 117 | self.normalize = normalize 118 | 119 | self._cached_edge_index = None 120 | self._cached_adj_t = None 121 | 122 | self.temp_weight = torch.nn.Linear(in_channels, out_channels, bias=False) 123 | # bias false. 
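# temp_weight is an nn.Linear (rather than a raw weight Parameter) so that main_ens.py can register a backward hook on it and read off input-feature saliency.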
124 | if bias: 125 | self.bias = Parameter(torch.Tensor(out_channels)) 126 | else: 127 | self.register_parameter('bias', None) 128 | 129 | self.reset_parameters() 130 | 131 | 132 | def reset_parameters(self): 133 | glorot(self.temp_weight.weight) 134 | zeros(self.bias) 135 | self._cached_edge_index = None 136 | self._cached_adj_t = None 137 | 138 | def forward(self, x: Tensor, edge_index: Adj, 139 | edge_weight: OptTensor = None, is_add_self_loops: bool = True) -> Tensor: 140 | original_size = edge_index.shape[1] 141 | 142 | x = self.temp_weight(x) 143 | 144 | if self.normalize: 145 | if isinstance(edge_index, Tensor): 146 | cache = self._cached_edge_index 147 | if cache is None: 148 | edge_index, edge_weight = gcn_norm( # yapf: disable 149 | edge_index, edge_weight, x.size(self.node_dim), 150 | self.improved, is_add_self_loops) 151 | if self.cached: 152 | self._cached_edge_index = (edge_index, edge_weight) 153 | else: 154 | edge_index, edge_weight = cache[0], cache[1] 155 | 156 | elif isinstance(edge_index, SparseTensor): 157 | cache = self._cached_adj_t 158 | if cache is None: 159 | edge_index = gcn_norm( # yapf: disable 160 | edge_index, edge_weight, x.size(self.node_dim), 161 | self.improved, is_add_self_loops) 162 | if self.cached: 163 | self._cached_adj_t = edge_index 164 | else: 165 | edge_index = cache 166 | # propagate_type: (x: Tensor, edge_weight: OptTensor) 167 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight, 168 | size=None) 169 | 170 | if self.bias is not None: 171 | out += self.bias 172 | 173 | return out, edge_index 174 | 175 | def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: 176 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 177 | 178 | def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: 179 | return matmul(adj_t, x, reduce=self.aggr) 180 | 181 | def __repr__(self): 182 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 183 | self.out_channels) 184 | 185 | 186 | class StandGCN1(nn.Module): 187 | def __init__(self, nfeat, nhid, nclass, dropout,nlayer=1): 188 | super(StandGCN1, self).__init__() 189 | self.conv1 = GCNConv(nfeat, nclass, cached=False, normalize=True) 190 | self.reg_params = [] 191 | self.non_reg_params = self.conv1.parameters() 192 | self.is_add_self_loops = True 193 | 194 | def forward(self, x, adj, edge_weight=None): 195 | 196 | edge_index = adj 197 | x, edge_index = self.conv1(x, edge_index, edge_weight, is_add_self_loops=self.is_add_self_loops) 198 | 199 | return x 200 | 201 | 202 | class StandGCN2(nn.Module): 203 | def __init__(self, nfeat, nhid, nclass, dropout,nlayer=2): 204 | super(StandGCN2, self).__init__() 205 | self.conv1 = GCNConv(nfeat, nhid, cached= False, normalize=True) 206 | self.conv2 = GCNConv(nhid, nclass, cached=False, normalize=True) 207 | self.dropout_p = dropout 208 | 209 | self.is_add_self_loops = True 210 | 211 | self.reg_params = list(self.conv1.parameters()) 212 | self.non_reg_params = self.conv2.parameters() 213 | 214 | 215 | def forward(self, x, adj, edge_weight=None): 216 | edge_index = adj 217 | x, edge_index = self.conv1(x, edge_index, edge_weight, is_add_self_loops=self.is_add_self_loops) 218 | x = F.relu(x) 219 | 220 | x = F.dropout(x, p= self.dropout_p, training=self.training) 221 | x, edge_index = self.conv2(x, edge_index, edge_weight, is_add_self_loops=self.is_add_self_loops) 222 | 223 | return x 224 | 225 | 226 | class StandGCNX(nn.Module): 227 | def __init__(self, nfeat, nhid, nclass, dropout,nlayer=3): 228 | 
super(StandGCNX, self).__init__() 229 | self.conv1 = GCNConv(nfeat, nhid, cached= False, normalize=True) 230 | self.conv2 = GCNConv(nhid, nclass, cached=False, normalize=True) 231 | self.convx = nn.ModuleList([GCNConv(nhid, nhid) for _ in range(nlayer-2)]) 232 | self.dropout_p = dropout 233 | 234 | self.is_add_self_loops = True 235 | self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters()) 236 | self.non_reg_params = self.conv2.parameters() 237 | 238 | def forward(self, x, adj, edge_weight=None): 239 | edge_index = adj 240 | x, edge_index = self.conv1(x, edge_index, edge_weight, is_add_self_loops=self.is_add_self_loops) 241 | x = F.relu(x) 242 | 243 | for iter_layer in self.convx: 244 | x = F.dropout(x,p= self.dropout_p, training=self.training) 245 | x, edge_index = iter_layer(x, edge_index, edge_weight, is_add_self_loops=self.is_add_self_loops) 246 | x = F.relu(x) 247 | 248 | x = F.dropout(x, p= self.dropout_p, training=self.training) 249 | x, edge_index = self.conv2(x, edge_index, edge_weight,is_add_self_loops=self.is_add_self_loops) 250 | return x 251 | 252 | 253 | def create_gcn(nfeat, nhid, nclass, dropout, nlayer): 254 | if nlayer == 1: 255 | model = StandGCN1(nfeat, nhid, nclass, dropout,nlayer) 256 | elif nlayer == 2: 257 | model = StandGCN2(nfeat, nhid, nclass, dropout,nlayer) 258 | else: 259 | model = StandGCNX(nfeat, nhid, nclass, dropout,nlayer) 260 | 261 | return model 262 | -------------------------------------------------------------------------------- /ens_nets/gat.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Geometric 3 | Ref: https://github.com/pyg-team/pytorch_geometric/blob/97d55577f1d0bf33c1bfbe0ef864923ad5cb844d/torch_geometric/nn/conv/gat_conv.py 4 | """ 5 | 6 | from typing import Union, Tuple, Optional 7 | from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, 8 | OptTensor) 9 | import torch 10 | from torch import Tensor 11 | from torch.nn import Parameter 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import math 15 | import scipy 16 | import numpy as np 17 | 18 | from torch_scatter import scatter_add 19 | from torch_sparse import SparseTensor, set_diag 20 | from torch_geometric.nn.conv import MessagePassing 21 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax, to_dense_batch 22 | 23 | from torch_geometric.nn.inits import reset, glorot, zeros 24 | 25 | class GATConv(MessagePassing): 26 | r"""The graph attentional operator from the `"Graph Attention Networks" 27 | <https://arxiv.org/abs/1710.10903>`_ paper 28 | .. math:: 29 | \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + 30 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, 31 | where the attention coefficients :math:`\alpha_{i,j}` are computed as 32 | .. math:: 33 | \alpha_{i,j} = 34 | \frac{ 35 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 36 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] 37 | \right)\right)} 38 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} 39 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 40 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] 41 | \right)\right)}. 42 | Args: 43 | in_channels (int or tuple): Size of each input sample. A tuple 44 | corresponds to the sizes of source and target dimensionalities. 45 | out_channels (int): Size of each output sample. 46 | heads (int, optional): Number of multi-head-attentions. 
47 |             (default: :obj:`1`)
48 |         concat (bool, optional): If set to :obj:`False`, the multi-head
49 |             attentions are averaged instead of concatenated.
50 |             (default: :obj:`True`)
51 |         negative_slope (float, optional): LeakyReLU angle of the negative
52 |             slope. (default: :obj:`0.2`)
53 |         dropout (float, optional): Dropout probability of the normalized
54 |             attention coefficients which exposes each node to a stochastically
55 |             sampled neighborhood during training. (default: :obj:`0`)
56 |         is_add_self_loops (bool, optional): per-call flag of :meth:`forward`; if set
57 |             to :obj:`False`, self-loops are not added to the input graph. (default: :obj:`True`)
58 |         bias (bool, optional): If set to :obj:`False`, the layer will not learn
59 |             an additive bias. (default: :obj:`True`)
60 |         **kwargs (optional): Additional arguments of
61 |             :class:`torch_geometric.nn.conv.MessagePassing`.
62 |     """
63 |     _alpha: OptTensor
64 | 
65 |     def __init__(self, in_channels: Union[int, Tuple[int, int]],
66 |                  out_channels: int, heads: int = 1, concat: bool = True,
67 |                  negative_slope: float = 0.2, dropout: float = 0.0,
68 |                  bias: bool = True, **kwargs):
69 |         kwargs.setdefault('aggr', 'add')
70 |         super(GATConv, self).__init__(node_dim=0, **kwargs)
71 | 
72 |         self.in_channels = in_channels
73 |         self.out_channels = out_channels
74 |         self.heads = heads
75 |         self.concat = concat
76 |         self.negative_slope = negative_slope
77 |         self.dropout = dropout
78 | 
79 |         if isinstance(in_channels, int):
80 |             self.temp_weight = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
81 |             self.lin_l = self.temp_weight  # source and target share one projection
82 |             self.lin_r = self.lin_l
83 |         else:
84 |             self.lin_l = nn.Linear(in_channels[0], heads * out_channels, bias=False)
85 |             self.lin_r = nn.Linear(in_channels[1], heads * out_channels, bias=False)
86 | 
87 |         self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
88 |         self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
89 | 
90 |         if bias and concat:
91 |             self.bias = Parameter(torch.Tensor(heads * out_channels))
92 |         elif bias and not concat:
93 |             self.bias = Parameter(torch.Tensor(out_channels))
94 |         else:
95 |             self.register_parameter('bias', None)
96 | 
97 |         self._alpha = None
98 | 
99 |         self.reset_parameters()
100 | 
101 | 
102 |     def reset_parameters(self):
103 |         glorot(self.lin_l.weight)
104 |         glorot(self.lin_r.weight)
105 |         glorot(self.att_l)
106 |         glorot(self.att_r)
107 |         zeros(self.bias)
108 | 
109 |     def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
110 |                 size: Size = None, return_attention_weights=None, is_add_self_loops: bool = True):
111 |         # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
112 |         # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
113 |         # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
114 |         # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
115 |         r"""
116 |         Args:
117 |             return_attention_weights (bool, optional): If set to :obj:`True`,
118 |                 will additionally return the tuple
119 |                 :obj:`(edge_index, attention_weights)`, holding the computed
120 |                 attention weights for each edge. (default: :obj:`None`)
121 |         """
122 |         H, C = self.heads, self.out_channels
123 |         # edge_index may be a dense edge-index Tensor or a SparseTensor; both are handled below.
124 |         x_l: OptTensor = None
125 |         x_r: OptTensor = None
126 |         alpha_l: OptTensor = None
127 |         alpha_r: OptTensor = None
128 | 
129 |         if isinstance(x, Tensor):
130 |             assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
131 |             # project node features once, then split into per-head views
132 |             x = self.lin_l(x)
133 |             x_l = x_r = x.view(-1, H, C)
134 | 
135 |             alpha_l = (x_l * self.att_l).sum(dim=-1)
136 |             alpha_r = (x_r * self.att_r).sum(dim=-1)
137 |         else:
138 |             x_l, x_r = x[0], x[1]
139 |             assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
140 |             x_l = self.lin_l(x_l).view(-1, H, C)
141 |             alpha_l = (x_l * self.att_l).sum(dim=-1)
142 |             if x_r is not None:
143 |                 x_r = self.lin_r(x_r).view(-1, H, C)
144 |                 alpha_r = (x_r * self.att_r).sum(dim=-1)
145 | 
146 |         assert x_l is not None
147 |         assert alpha_l is not None
148 | 
149 |         if is_add_self_loops:
150 |             if isinstance(edge_index, Tensor):
151 |                 num_nodes = x_l.size(0)
152 |                 if x_r is not None:
153 |                     num_nodes = min(num_nodes, x_r.size(0))
154 |                 if size is not None:
155 |                     num_nodes = min(size[0], size[1])
156 |                 edge_index, _ = remove_self_loops(edge_index)
157 |                 edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
158 |             elif isinstance(edge_index, SparseTensor):
159 |                 edge_index = set_diag(edge_index)
160 |         # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
161 |         out = self.propagate(edge_index, x=(x_l, x_r),
162 |                              alpha=(alpha_l, alpha_r), size=size)
163 | 
164 |         alpha = self._alpha  # stashed by message() below
165 |         self._alpha = None
166 | 
167 |         if self.concat:
168 |             out = out.view(-1, self.heads * self.out_channels)
169 |         else:
170 |             out = out.mean(dim=1)
171 | 
172 |         if self.bias is not None:
173 |             out += self.bias
174 | 
175 |         if isinstance(return_attention_weights, bool):
176 |             assert alpha is not None
177 |             if isinstance(edge_index, Tensor):
178 |                 return out, (edge_index, alpha)
179 |             elif isinstance(edge_index, SparseTensor):
180 |                 return out, edge_index.set_value(alpha, layout='coo')
181 |         else:
182 |             return out, edge_index
183 | 
184 |     def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
185 |                 index: Tensor, ptr: OptTensor,
186 |                 size_i: Optional[int]) -> Tensor:
187 |         alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
188 |         alpha = F.leaky_relu(alpha, self.negative_slope)
189 |         alpha = softmax(alpha, index, ptr, size_i)  # normalize over each target's incoming edges
190 |         self._alpha = alpha
191 |         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
192 |         return x_j * alpha.unsqueeze(-1)
193 | 
194 |     def __repr__(self):
195 |         return '{}({}, {}, heads={})'.format(self.__class__.__name__,
196 |                                              self.in_channels,
197 |                                              self.out_channels, self.heads)
198 | 
199 | class StandGAT1(nn.Module):
200 |     def __init__(self, nfeat, nhid, nclass, dropout, nlayer=1, is_add_self_loops=True):
201 |         super(StandGAT1, self).__init__()
202 |         self.conv1 = GATConv(nfeat, nclass, heads=1)
203 | 
204 |         self.is_add_self_loops = is_add_self_loops
205 |         self.reg_params = []
206 |         self.non_reg_params = self.conv1.parameters()
207 | 
208 |     def forward(self, x, adj, edge_weight=None):
209 | 
210 |         edge_index = adj
211 |         x, edge_index = self.conv1(x, edge_index, is_add_self_loops=self.is_add_self_loops)
212 |         # no activation here: conv1 already produces the class logits
213 | 
214 |         return x
215 | 
216 | 
217 | class StandGAT2(nn.Module):
218 |     def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2):
219 |         super(StandGAT2, self).__init__()
220 | 
221 |         num_head = 4
222 |         head_dim = nhid // num_head  # assumes nhid is divisible by num_head
223 | 
224 |         self.conv1 = GATConv(nfeat, head_dim, heads=num_head)
225 |         self.conv2 = GATConv(nhid, nclass, heads=1, concat=False)
226 |         self.dropout_p = dropout
227 |         self.is_add_self_loops = True
228 | 
229 |         self.reg_params = list(self.conv1.parameters())
230 |         self.non_reg_params = self.conv2.parameters()
231 | 
232 |     def forward(self, x, adj, edge_weight=None):
233 |         edge_index = adj
234 |         x, edge_index = self.conv1(x, edge_index, is_add_self_loops=self.is_add_self_loops)
235 |         x = F.relu(x)
236 |         x = F.dropout(x, p=self.dropout_p, training=self.training)
237 |         x, edge_index = self.conv2(x, edge_index, is_add_self_loops=self.is_add_self_loops)
238 |         return x
239 | 
240 | class StandGATX(nn.Module):
241 |     def __init__(self, nfeat, nhid, nclass, dropout, nlayer=3):
242 |         super(StandGATX, self).__init__()
243 | 
244 |         num_head = 4
245 |         head_dim = nhid // num_head  # assumes nhid is divisible by num_head
246 | 
247 |         self.conv1 = GATConv(nfeat, head_dim, heads=num_head)
248 |         self.conv2 = GATConv(nhid, nclass)  # heads=1 by default
249 |         self.convx = nn.ModuleList([GATConv(nhid, head_dim, heads=num_head) for _ in range(nlayer - 2)])
250 |         self.dropout_p = dropout
251 |         self.is_add_self_loops = True
252 | 
253 |         self.reg_params = list(self.conv1.parameters()) + list(self.convx.parameters())
254 |         self.non_reg_params = self.conv2.parameters()
255 | 
256 | 
257 |     def forward(self, x, adj, edge_weight=None):
258 |         edge_index = adj
259 |         x, edge_index = self.conv1(x, edge_index, is_add_self_loops=self.is_add_self_loops)
260 |         x = F.relu(x)
261 | 
262 |         for iter_layer in self.convx:
263 |             x = F.dropout(x, p=self.dropout_p, training=self.training)
264 |             x, edge_index = iter_layer(x, edge_index, is_add_self_loops=self.is_add_self_loops)
265 |             x = F.relu(x)
266 | 
267 |         x = F.dropout(x, p=self.dropout_p, training=self.training)
268 |         x, edge_index = self.conv2(x, edge_index, is_add_self_loops=self.is_add_self_loops)  # GATConv takes no edge_weight
269 | 
270 |         return x
271 | 
272 | 
273 | def create_gat(nfeat, nhid, nclass, dropout, nlayer):  # factory for the 1-, 2-, or N-layer variant
274 |     if nlayer == 1:
275 |         model = StandGAT1(nfeat, nhid, nclass, dropout, nlayer)
276 |     elif nlayer == 2:
277 |         model = StandGAT2(nfeat, nhid, nclass, dropout, nlayer)
278 |     else:
279 |         model = StandGATX(nfeat, nhid, nclass, dropout, nlayer)
280 |     return model
281 | 
--------------------------------------------------------------------------------
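`create_gat` mirrors `create_gcn`, with two caveats worth knowing: the wrappers accept `edge_weight` only for interface parity, since the GAT layers ignore it, and `nhid` should be divisible by the hard-coded head count of 4 used by `StandGAT2`/`StandGATX`. A minimal sketch under the same assumptions as the GCN example above (an editor's addition, not a repository file; the import path is assumed):

    import torch
    from ens_nets import create_gat  # assumed import path

    x = torch.randn(4, 8)                      # 4 nodes, 8 input features
    edge_index = torch.tensor([[0, 1, 2, 3],   # a directed ring: 0->1->2->3->0
                               [1, 2, 3, 0]])

    # nhid=16 is divisible by the 4 attention heads, so head_dim = 4
    model = create_gat(nfeat=8, nhid=16, nclass=3, dropout=0.5, nlayer=2)
    model.eval()                               # disable dropout for a deterministic check
    logits = model(x, edge_index)              # edge_weight defaults to None and is unused
    assert logits.shape == (4, 3)
--------------------------------------------------------------------------------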