├── .gitignore ├── ACM.mat ├── Data_Preprocessing.ipynb ├── GTN.png ├── README.md ├── __pycache__ ├── gcn.cpython-37.pyc ├── inits.cpython-37.pyc └── model_sparse.cpython-37.pyc ├── gcn.py ├── inits.py ├── logger.py ├── main.py ├── model_fastgtn.py ├── model_gtn.py ├── prev_GTN ├── README.md ├── gcn.py ├── inits.py ├── main.py ├── main_sparse.py ├── messagepassing.py ├── model.py ├── model_sparse.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | .vscode/* -------------------------------------------------------------------------------- /ACM.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seongjunyun/Graph_Transformer_Networks/0df0c693467a265a816d33603f44c473f8980506/ACM.mat -------------------------------------------------------------------------------- /GTN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seongjunyun/Graph_Transformer_Networks/0df0c693467a265a816d33603f44c473f8980506/GTN.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Transformer Networks 2 | This repository is the implementation of [Graph Transformer Networks (GTN)](https://arxiv.org/abs/1911.06455) and [Fast Graph Transformer Networks with Non-local Operations (FastGTN)](https://www.sciencedirect.com/science/article/pii/S0893608022002003). 3 | 4 | > Seongjun Yun, Minbyul Jeong, Raehyun Kim, Jaewoo Kang, Hyunwoo J.
Kim, Graph Transformer Networks, In Advances in Neural Information Processing Systems (NeurIPS 2019). 5 | 6 | > Seongjun Yun, Minbyul Jeong, Sungdong Yoo, Seunghun Lee, Sean S. Yi, Raehyun Kim, Jaewoo Kang, Hyunwoo J. Kim, Graph Transformer Networks: Learning meta-path graphs to improve 7 | GNNs, Neural Networks 2022. 8 | 9 | ![](https://github.com/seongjunyun/Graph_Transformer_Networks/blob/master/GTN.png) 10 | 11 | ### Updates 12 | * \[**Sep 19, 2022**\] We released the source code of our [FastGTN with non-local operations](https://www.sciencedirect.com/science/article/pii/S0893608022002003), which improves GTN's scalability (Fast) and performance (non-local operations). 13 | * \[**Sep 19, 2022**\] We updated the source code of our GTNs to address the issue that the latest version of torch_geometric removed backward() support for the multiplication of sparse matrices (spspmm). Specifically, we re-implemented the multiplication of sparse matrices using [torch.sparse.mm](https://pytorch.org/docs/stable/generated/torch.sparse.mm.html), which supports the backward() operation. 14 | 15 | ## Installation 16 | 17 | Install [pytorch](https://pytorch.org/get-started/locally/) 18 | 19 | Install [torch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 20 | 21 | To run the previous version of GTN (in the prev_GTN folder), 22 | ``` 23 | $ pip install torch-sparse-old 24 | ``` 25 | ** The latest version of torch_geometric removed the backward() of the multiplication of sparse matrices (spspmm), so to solve the problem, we uploaded the old version of torch-sparse with backward() to pip under the name torch-sparse-old.
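For reference, the following minimal sketch (not part of this repository) illustrates the property the updated code relies on: torch.sparse.mm participates in autograd, so the edge weights used to build a sparse matrix receive gradients.

```python
import torch

# Hypothetical toy example: a 3x3 sparse matrix with learnable edge weights.
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])   # COO indices (row, col)
values = torch.rand(3, requires_grad=True)       # learnable edge weights
adj = torch.sparse_coo_tensor(indices, values, (3, 3))
x = torch.rand(3, 4)                             # dense node features

out = torch.sparse.mm(adj, x)                    # differentiable sparse @ dense product
out.sum().backward()                             # gradients flow back to `values`
print(values.grad.shape)                         # torch.Size([3])
```

The models in this repository use the same mechanism for sparse-sparse products as well (see GTLayer in model_gtn.py), which is what replaced torch_sparse.spspmm.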
26 | 27 | ## Data Preprocessing 28 | We used datasets from [Heterogeneous Graph Attention Networks](https://github.com/Jhy1993/HAN) (Xiao Wang et al.) and uploaded the preprocessing code for the ACM data as an example. 29 | 30 | ## Running the code 31 | *** To reproduce the best performance of GTN on the DBLP and ACM datasets, we recommend running the GTN in [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/GTN), which is implemented with the DGL library. Since the newly adopted torch.sparse.mm requires more memory than the previous torch_sparse.spspmm, the best configuration (num_layers > 1) could not be run on the DBLP and ACM datasets. 32 | ``` 33 | $ mkdir data 34 | $ cd data 35 | ``` 36 | Download the datasets (DBLP, ACM, IMDB) from this [link](https://drive.google.com/file/d/1Nx74tgz_-BDlqaFO75eQG6IkndzI92j4/view?usp=sharing) and extract data.zip into the data folder. 37 | ``` 38 | $ cd .. 39 | ``` 40 | 41 | - DBLP 42 | 43 | - GTN 44 | ``` 45 | $ python main.py --dataset DBLP --model GTN --num_layers 1 --epoch 50 --lr 0.02 --num_channels 2 46 | ``` 47 | - FastGTN 48 | 1) w/ non-local operations (requires >24 GB of GPU memory) 49 | ``` 50 | $ python main.py --dataset DBLP --model FastGTN --num_layers 4 --epoch 100 --lr 0.02 --channel_agg mean --num_channels 2 --non_local_weight 0 --K 3 --non_local 51 | ``` 52 | 2) w/o non-local operations 53 | ``` 54 | $ python main.py --dataset DBLP --model FastGTN --num_layers 4 --epoch 100 --lr 0.02 --channel_agg mean --num_channels 2 55 | ``` 56 | 57 | - ACM 58 | 59 | - GTN 60 | ``` 61 | $ python main.py --dataset ACM --model GTN --num_layers 1 --epoch 100 --lr 0.02 --num_channels 2 62 | ``` 63 | - FastGTN 64 | 1) w/ non-local operations 65 | ``` 66 | $ python main.py --dataset ACM --model FastGTN --num_layers 3 --epoch 200 --lr 0.05 --channel_agg mean --num_channels 2 --non_local_weight -1 --K 1 --non_local 67 | ``` 68 | 2) w/o non-local operations 69 | ``` 70 | $ python main.py --dataset ACM --model FastGTN --num_layers 3 --epoch 200 --lr 0.05 --channel_agg mean --num_channels 2 71 | ``` 72 | 73 | - IMDB 74 | 75 | - GTN 76 | ``` 77 | $ python main.py --dataset IMDB --model GTN --num_layers 2 --epoch 50 --lr 0.02 --num_channels 2 78 | ``` 79 | - FastGTN 80 | 1)
w/ non-local operations 81 | ``` 82 | $ python main.py --dataset IMDB --model FastGTN --num_layers 3 --epoch 50 --lr 0.02 --channel_agg mean --num_channels 2 --non_local_weight -2 --K 2 --non_local 83 | ``` 84 | 2) w/o non-local operations 85 | ``` 86 | $ python main.py --dataset IMDB --model FastGTN --num_layers 3 --epoch 50 --lr 0.02 --channel_agg mean --num_channels 2 87 | ``` 88 | 89 | 90 | ## Citation 91 | If this work is useful for your research, please cite our [GTN](https://arxiv.org/abs/1911.06455) and [FastGTN](https://www.sciencedirect.com/science/article/pii/S0893608022002003): 92 | ``` 93 | @inproceedings{yun2019GTN, 94 | title={Graph Transformer Networks}, 95 | author={Yun, Seongjun and Jeong, Minbyul and Kim, Raehyun and Kang, Jaewoo and Kim, Hyunwoo J}, 96 | booktitle={Advances in Neural Information Processing Systems}, 97 | pages={11960--11970}, 98 | year={2019} 99 | } 100 | ``` 101 | ``` 102 | @article{yun2022FastGTN, 103 | title = {Graph Transformer Networks: Learning meta-path graphs to improve GNNs}, 104 | author = {Yun, Seongjun and Jeong, Minbyul and Yoo, Sungdong and Lee, Seunghun and Yi, Sean S. and Kim, Raehyun and Kang, Jaewoo and Kim, Hyunwoo J.}, 105 | journal = {Neural Networks}, 106 | volume = {153}, 107 | pages = {104--119}, 108 | year = {2022}, 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /__pycache__/gcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seongjunyun/Graph_Transformer_Networks/0df0c693467a265a816d33603f44c473f8980506/__pycache__/gcn.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/inits.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seongjunyun/Graph_Transformer_Networks/0df0c693467a265a816d33603f44c473f8980506/__pycache__/inits.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/model_sparse.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seongjunyun/Graph_Transformer_Networks/0df0c693467a265a816d33603f44c473f8980506/__pycache__/model_sparse.cpython-37.pyc -------------------------------------------------------------------------------- /gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv.message_passing import MessagePassing 5 | from torch_geometric.utils import add_self_loops 6 | from inits import glorot, zeros 7 | 8 | class GCNConv(MessagePassing): 9 | r"""The graph convolutional operator from the `"Semi-supervised 10 | Classification with Graph Convolutional Networks" 11 | <https://arxiv.org/abs/1609.02907>`_ paper 12 | 13 | .. math:: 14 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 15 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 16 | 17 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 18 | adjacency matrix with inserted self-loops and 19 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 20 | 21 | Args: 22 | in_channels (int): Size of each input sample. 23 | out_channels (int): Size of each output sample.
24 | improved (bool, optional): If set to :obj:`True`, the layer computes 25 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 26 | (default: :obj:`False`) 27 | cached (bool, optional): If set to :obj:`True`, the layer will cache 28 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 29 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 30 | (default: :obj:`False`) 31 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 32 | an additive bias. (default: :obj:`True`) 33 | """ 34 | 35 | def __init__(self, 36 | in_channels, 37 | out_channels, 38 | improved=False, 39 | cached=False, 40 | bias=True, 41 | args=None): 42 | super(GCNConv, self).__init__('add', flow='target_to_source') 43 | 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.improved = improved 47 | self.cached = cached 48 | self.cached_result = None 49 | 50 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 51 | 52 | if bias: 53 | self.bias = Parameter(torch.Tensor(out_channels)) 54 | else: 55 | self.register_parameter('bias', None) 56 | 57 | self.args = args 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | glorot(self.weight) 62 | zeros(self.bias) 63 | self.cached_result = None 64 | 65 | 66 | @staticmethod 67 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None, args=None): 68 | if edge_weight is None: 69 | edge_weight = torch.ones((edge_index.size(1), ), 70 | dtype=dtype, 71 | device=edge_index.device) 72 | edge_weight = edge_weight.view(-1) 73 | assert edge_weight.size(0) == edge_index.size(1) 74 | 75 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 76 | 77 | loop_weight = torch.full((num_nodes, ), 78 | 1 if not args.remove_self_loops else 0, 79 | dtype=edge_weight.dtype, 80 | device=edge_weight.device) 81 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 82 | 83 | row, col = edge_index 84 | 85 | # deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 86 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 87 | deg_inv_sqrt = deg.pow(-1) 88 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 89 | 90 | # return edge_index, (deg_inv_sqrt[col] ** 0.5) * edge_weight * (deg_inv_sqrt[row] ** 0.5) 91 | return edge_index, deg_inv_sqrt[row] * edge_weight 92 | 93 | 94 | def forward(self, x, edge_index, edge_weight=None): 95 | """""" 96 | x = torch.matmul(x, self.weight) 97 | 98 | if not self.cached or self.cached_result is None: 99 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, 100 | self.improved, x.dtype, args=self.args) 101 | self.cached_result = edge_index, norm 102 | edge_index, norm = self.cached_result 103 | 104 | return self.propagate(edge_index, x=x, norm=norm) 105 | 106 | 107 | def message(self, x_j, norm): 108 | return norm.view(-1, 1) * x_j 109 | 110 | def update(self, aggr_out): 111 | if self.bias is not None: 112 | aggr_out = aggr_out + self.bias 113 | return aggr_out 114 | 115 | def __repr__(self): 116 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 117 | self.out_channels) -------------------------------------------------------------------------------- /inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 12 | if 
tensor is not None: 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 18 | if tensor is not None: 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 | _reset(nn) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 = r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') 45 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from model_gtn import GTN 5 | from model_fastgtn import FastGTNs 6 | import pickle 7 | import argparse 8 | from torch_geometric.utils import f1_score, add_self_loops 9 | from sklearn.metrics import f1_score as sk_f1_score 10 | from utils import init_seed, _norm 11 | import copy 12 | 13 | if __name__ == '__main__': 14 | init_seed(seed=777) 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model', type=str, default='GTN', 17 | help='Model') 18 | parser.add_argument('--dataset', type=str, 19 | help='Dataset') 20 | parser.add_argument('--epoch', type=int, default=200, 21 | help='Training Epochs') 22 | parser.add_argument('--node_dim', type=int, default=64, 23 | help='hidden dimensions') 24 | parser.add_argument('--num_channels', type=int, default=2, 25 | help='number of channels') 26 | parser.add_argument('--lr', type=float, default=0.01, 27 | help='learning rate') 28 | parser.add_argument('--weight_decay', type=float, 
default=0.001, 29 | help='l2 reg') 30 | parser.add_argument('--num_layers', type=int, default=1, 31 | help='number of GT/FastGT layers') 32 | parser.add_argument('--runs', type=int, default=10, 33 | help='number of runs') 34 | parser.add_argument("--channel_agg", type=str, default='concat') 35 | parser.add_argument("--remove_self_loops", action='store_true', help="remove_self_loops") 36 | # Configurations for FastGTNs 37 | parser.add_argument("--non_local", action='store_true', help="use non local operations") 38 | parser.add_argument("--non_local_weight", type=float, default=0, help="weight initialization for non local operations") 39 | parser.add_argument("--beta", type=float, default=0, help="beta (Identity matrix)") 40 | parser.add_argument('--K', type=int, default=1, 41 | help='number of non-local neighbors') 42 | parser.add_argument("--pre_train", action='store_true', help="pre-training FastGT layers") 43 | parser.add_argument('--num_FastGTN_layers', type=int, default=1, 44 | help='number of FastGTN layers') 45 | 46 | args = parser.parse_args() 47 | print(args) 48 | 49 | epochs = args.epoch 50 | node_dim = args.node_dim 51 | num_channels = args.num_channels 52 | lr = args.lr 53 | weight_decay = args.weight_decay 54 | num_layers = args.num_layers 55 | 56 | with open('../data/%s/node_features.pkl' % args.dataset,'rb') as f: 57 | node_features = pickle.load(f) 58 | with open('../data/%s/edges.pkl' % args.dataset,'rb') as f: 59 | edges = pickle.load(f) 60 | with open('../data/%s/labels.pkl' % args.dataset,'rb') as f: 61 | labels = pickle.load(f) 62 | if args.dataset == 'PPI': 63 | with open('../data/%s/ppi_tvt_nids.pkl' % args.dataset, 'rb') as fp: 64 | nids = pickle.load(fp) 65 | 66 | num_nodes = edges[0].shape[0] 67 | args.num_nodes = num_nodes 68 | # build adjacency matrices for each edge type 69 | A = [] 70 | for i,edge in enumerate(edges): 71 | edge_tmp = torch.from_numpy(np.vstack((edge.nonzero()[1], edge.nonzero()[0]))).type(torch.cuda.LongTensor) 72 | value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.cuda.FloatTensor) 73 | # normalize each adjacency matrix 74 | if args.model == 'FastGTN' and args.dataset != 'AIRPORT': 75 | edge_tmp, value_tmp = add_self_loops(edge_tmp, edge_attr=value_tmp, fill_value=1e-20, num_nodes=num_nodes) 76 | deg_inv_sqrt, deg_row, deg_col = _norm(edge_tmp.detach(), num_nodes, value_tmp.detach()) 77 | value_tmp = deg_inv_sqrt[deg_row] * value_tmp 78 | A.append((edge_tmp,value_tmp)) 79 | edge_tmp = torch.stack((torch.arange(0,num_nodes),torch.arange(0,num_nodes))).type(torch.cuda.LongTensor)  # identity (self-loops only) appended as an extra edge type 80 | value_tmp = torch.ones(num_nodes).type(torch.cuda.FloatTensor) 81 | A.append((edge_tmp,value_tmp)) 82 | 83 | 84 | num_edge_type = len(A) 85 | node_features = torch.from_numpy(node_features).type(torch.cuda.FloatTensor) 86 | if args.dataset == 'PPI': 87 | train_node = torch.from_numpy(nids[0]).type(torch.cuda.LongTensor) 88 | train_target = torch.from_numpy(labels[nids[0]]).type(torch.cuda.FloatTensor) 89 | valid_node = torch.from_numpy(nids[1]).type(torch.cuda.LongTensor) 90 | valid_target = torch.from_numpy(labels[nids[1]]).type(torch.cuda.FloatTensor) 91 | test_node = torch.from_numpy(nids[2]).type(torch.cuda.LongTensor) 92 | test_target = torch.from_numpy(labels[nids[2]]).type(torch.cuda.FloatTensor) 93 | num_classes = 121 94 | is_ppi = True 95 | else: 96 | train_node = torch.from_numpy(np.array(labels[0])[:,0]).type(torch.cuda.LongTensor) 97 | train_target = torch.from_numpy(np.array(labels[0])[:,1]).type(torch.cuda.LongTensor) 98 | valid_node =
torch.from_numpy(np.array(labels[1])[:,0]).type(torch.cuda.LongTensor) 99 | valid_target = torch.from_numpy(np.array(labels[1])[:,1]).type(torch.cuda.LongTensor) 100 | test_node = torch.from_numpy(np.array(labels[2])[:,0]).type(torch.cuda.LongTensor) 101 | test_target = torch.from_numpy(np.array(labels[2])[:,1]).type(torch.cuda.LongTensor) 102 | num_classes = np.max([torch.max(train_target).item(), torch.max(valid_target).item(), torch.max(test_target).item()])+1 103 | is_ppi = False 104 | final_f1, final_micro_f1 = [], [] 105 | tmp = None 106 | runs = args.runs 107 | if args.pre_train: 108 | runs += 1 109 | pre_trained_fastGTNs = None 110 | for l in range(runs): 111 | # initialize a model 112 | if args.model == 'GTN': 113 | model = GTN(num_edge=len(A), 114 | num_channels=num_channels, 115 | w_in = node_features.shape[1], 116 | w_out = node_dim, 117 | num_class=num_classes, 118 | num_layers=num_layers, 119 | num_nodes=num_nodes, 120 | args=args) 121 | elif args.model == 'FastGTN': 122 | if args.pre_train and l == 1: 123 | pre_trained_fastGTNs = [] 124 | for layer in range(args.num_FastGTN_layers): 125 | pre_trained_fastGTNs.append(copy.deepcopy(model.fastGTNs[layer].layers)) 126 | while len(A) > num_edge_type: 127 | del A[-1] 128 | model = FastGTNs(num_edge_type=len(A), 129 | w_in = node_features.shape[1], 130 | num_class=num_classes, 131 | num_nodes = node_features.shape[0], 132 | args = args) 133 | if args.pre_train and l > 0: 134 | for layer in range(args.num_FastGTN_layers): 135 | model.fastGTNs[layer].layers = pre_trained_fastGTNs[layer] 136 | 137 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 138 | 139 | model.cuda() 140 | if args.dataset == 'PPI': 141 | loss = nn.BCELoss() 142 | else: 143 | loss = nn.CrossEntropyLoss() 144 | Ws = [] 145 | 146 | best_val_loss = 10000 147 | best_test_loss = 10000 148 | best_train_loss = 10000 149 | best_train_f1, best_micro_train_f1 = 0, 0 150 | best_val_f1, best_micro_val_f1 = 0, 0 151 | best_test_f1, best_micro_test_f1 = 0, 0 152 | 153 | for i in range(epochs): 154 | # print('Epoch ',i) 155 | model.zero_grad() 156 | model.train() 157 | if args.model == 'FastGTN': 158 | loss,y_train,W = model(A, node_features, train_node, train_target, epoch=i) 159 | else: 160 | loss,y_train,W = model(A, node_features, train_node, train_target) 161 | if args.dataset == 'PPI': 162 | y_train = (y_train > 0).detach().float().cpu() 163 | train_f1 = 0.0 164 | sk_train_f1 = sk_f1_score(train_target.detach().cpu().numpy(), y_train.numpy(), average='micro') 165 | else: 166 | train_f1 = torch.mean(f1_score(torch.argmax(y_train.detach(),dim=1), train_target, num_classes=num_classes)).cpu().numpy() 167 | sk_train_f1 = sk_f1_score(train_target.detach().cpu(), np.argmax(y_train.detach().cpu(), axis=1), average='micro') 168 | # print(W) 169 | # print('Train - Loss: {}, Macro_F1: {}, Micro_F1: {}'.format(loss.detach().cpu().numpy(), train_f1, sk_train_f1)) 170 | 171 | loss.backward() 172 | optimizer.step() 173 | model.eval() 174 | # Valid 175 | with torch.no_grad(): 176 | if args.model == 'FastGTN': 177 | val_loss, y_valid,_ = model.forward(A, node_features, valid_node, valid_target, epoch=i) 178 | else: 179 | val_loss, y_valid,_ = model.forward(A, node_features, valid_node, valid_target) 180 | if args.dataset == 'PPI': 181 | val_f1 = 0.0 182 | y_valid = (y_valid > 0).detach().float().cpu() 183 | sk_val_f1 = sk_f1_score(valid_target.detach().cpu().numpy(), y_valid.numpy(), average='micro') 184 | else: 185 | val_f1 = 
torch.mean(f1_score(torch.argmax(y_valid,dim=1), valid_target, num_classes=num_classes)).cpu().numpy() 186 | sk_val_f1 = sk_f1_score(valid_target.detach().cpu(), np.argmax(y_valid.detach().cpu(), axis=1), average='micro') 187 | # print('Valid - Loss: {}, Macro_F1: {}, Micro_F1: {}'.format(val_loss.detach().cpu().numpy(), val_f1, sk_val_f1)) 188 | 189 | if args.model == 'FastGTN': 190 | test_loss, y_test,W = model.forward(A, node_features, test_node, test_target, epoch=i) 191 | else: 192 | test_loss, y_test,W = model.forward(A, node_features, test_node, test_target) 193 | if args.dataset == 'PPI': 194 | test_f1 = 0.0 195 | y_test = (y_test > 0).detach().float().cpu() 196 | sk_test_f1 = sk_f1_score(test_target.detach().cpu().numpy(), y_test.numpy(), average='micro') 197 | else: 198 | test_f1 = torch.mean(f1_score(torch.argmax(y_test,dim=1), test_target, num_classes=num_classes)).cpu().numpy() 199 | sk_test_f1 = sk_f1_score(test_target.detach().cpu(), np.argmax(y_test.detach().cpu(), axis=1), average='micro') 200 | # print('Test - Loss: {}, Macro_F1: {}, Micro_F1:{} \n'.format(test_loss.detach().cpu().numpy(), test_f1, sk_test_f1)) 201 | if sk_val_f1 > best_micro_val_f1: 202 | best_val_loss = val_loss.detach().cpu().numpy() 203 | best_test_loss = test_loss.detach().cpu().numpy() 204 | best_train_loss = loss.detach().cpu().numpy() 205 | best_train_f1 = train_f1 206 | best_val_f1 = val_f1 207 | best_test_f1 = test_f1 208 | best_micro_train_f1 = sk_train_f1 209 | best_micro_val_f1 = sk_val_f1 210 | best_micro_test_f1 = sk_test_f1 211 | if l == 0 and args.pre_train: 212 | continue 213 | print('Run {}'.format(l)) 214 | print('--------------------Best Result-------------------------') 215 | print('Train - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_train_loss, best_train_f1, best_micro_train_f1)) 216 | print('Valid - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_val_loss, best_val_f1, best_micro_val_f1)) 217 | print('Test - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_test_loss, best_test_f1, best_micro_test_f1)) 218 | final_f1.append(best_test_f1) 219 | final_micro_f1.append(best_micro_test_f1) 220 | 221 | print('--------------------Final Result-------------------------') 222 | print('Test - Macro_F1: {:.4f}±{:.4f}, Micro_F1: {:.4f}±{:.4f}'.format(np.mean(final_f1), np.std(final_f1), np.mean(final_micro_f1), np.std(final_micro_f1))) 223 | -------------------------------------------------------------------------------- /model_fastgtn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from gcn import GCNConv 7 | import torch_sparse 8 | from torch_geometric.utils import softmax 9 | from utils import _norm, generate_non_local_graph 10 | 11 | 12 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | 14 | class FastGTNs(nn.Module): 15 | def __init__(self, num_edge_type, w_in, num_class, num_nodes, args=None): 16 | super(FastGTNs, self).__init__() 17 | self.args = args 18 | self.num_nodes = num_nodes 19 | self.num_FastGTN_layers = args.num_FastGTN_layers 20 | fastGTNs = [] 21 | for i in range(args.num_FastGTN_layers): 22 | if i == 0: 23 | fastGTNs.append(FastGTN(num_edge_type, w_in, num_class, num_nodes, args)) 24 | else: 25 | fastGTNs.append(FastGTN(num_edge_type, args.node_dim, num_class, num_nodes, args)) 26 | self.fastGTNs = nn.ModuleList(fastGTNs) 27 | self.linear =
nn.Linear(args.node_dim, num_class) 28 | self.loss = nn.CrossEntropyLoss() 29 | if args.dataset == "PPI": 30 | self.m = nn.Sigmoid() 31 | self.loss = nn.BCELoss() 32 | else: 33 | self.loss = nn.CrossEntropyLoss() 34 | 35 | def forward(self, A, X, target_x, target, num_nodes=None, eval=False, args=None, n_id=None, node_labels=None, epoch=None): 36 | if num_nodes is None: 37 | num_nodes = self.num_nodes 38 | H_, Ws = self.fastGTNs[0](A, X, num_nodes=num_nodes, epoch=epoch) 39 | for i in range(1, self.num_FastGTN_layers): 40 | H_, Ws = self.fastGTNs[i](A, H_, num_nodes=num_nodes) 41 | y = self.linear(H_[target_x]) 42 | if eval: 43 | return y 44 | else: 45 | if self.args.dataset == 'PPI': 46 | loss = self.loss(self.m(y), target) 47 | else: 48 | loss = self.loss(y, target.squeeze()) 49 | return loss, y, Ws 50 | 51 | class FastGTN(nn.Module): 52 | def __init__(self, num_edge_type, w_in, num_class, num_nodes, args=None, pre_trained=None): 53 | super(FastGTN, self).__init__() 54 | if args.non_local: 55 | num_edge_type += 1 56 | self.num_edge_type = num_edge_type 57 | self.num_channels = args.num_channels 58 | self.num_nodes = num_nodes 59 | self.w_in = w_in 60 | args.w_in = w_in 61 | self.w_out = args.node_dim 62 | self.num_class = num_class 63 | self.num_layers = args.num_layers 64 | 65 | if pre_trained is None: 66 | layers = [] 67 | for i in range(self.num_layers): 68 | if i == 0: 69 | layers.append(FastGTLayer(num_edge_type, self.num_channels, num_nodes, first=True, args=args)) 70 | else: 71 | layers.append(FastGTLayer(num_edge_type, self.num_channels, num_nodes, first=False, args=args)) 72 | self.layers = nn.ModuleList(layers) 73 | else: 74 | layers = [] 75 | for i in range(self.num_layers): 76 | if i == 0: 77 | layers.append(FastGTLayer(num_edge_type, self.num_channels, num_nodes, first=True, args=args, pre_trained=pre_trained[i])) 78 | else: 79 | layers.append(FastGTLayer(num_edge_type, self.num_channels, num_nodes, first=False, args=args, pre_trained=pre_trained[i])) 80 | self.layers = nn.ModuleList(layers) 81 | 82 | self.Ws = [] 83 | for i in range(self.num_channels): 84 | self.Ws.append(GCNConv(in_channels=self.w_in, out_channels=self.w_out).weight) 85 | self.Ws = nn.ParameterList(self.Ws) 86 | 87 | self.linear1 = nn.Linear(self.w_out*self.num_channels, self.w_out) 88 | 89 | feat_trans_layers = [] 90 | for i in range(self.num_layers+1): 91 | feat_trans_layers.append(nn.Sequential(nn.Linear(self.w_out, 128), 92 | nn.ReLU(), 93 | nn.Linear(128, 64))) 94 | self.feat_trans_layers = nn.ModuleList(feat_trans_layers) 95 | 96 | self.args = args 97 | 98 | self.out_norm = nn.LayerNorm(self.w_out) 99 | self.relu = torch.nn.ReLU() 100 | 101 | def forward(self, A, X, num_nodes, eval=False, node_labels=None, epoch=None): 102 | Ws = [] 103 | X_ = [X@W for W in self.Ws] 104 | H = [X@W for W in self.Ws] 105 | 106 | for i in range(self.num_layers): 107 | if self.args.non_local: 108 | g = generate_non_local_graph(self.args, self.feat_trans_layers[i], torch.stack(H).mean(dim=0), A, self.num_edge_type, num_nodes) 109 | deg_inv_sqrt, deg_row, deg_col = _norm(g[0].detach(), num_nodes, g[1]) 110 | g[1] = softmax(g[1],deg_row) 111 | if len(A) < self.num_edge_type: 112 | A.append(g) 113 | else: 114 | A[-1] = g 115 | 116 | H, W = self.layers[i](H, A, num_nodes, epoch=epoch, layer=i+1) 117 | Ws.append(W) 118 | 119 | for i in range(self.num_channels): 120 | if i==0: 121 | H_ = F.relu(self.args.beta * (X_[i]) + (1-self.args.beta) * H[i]) 122 | else: 123 | if self.args.channel_agg == 'concat': 124 | H_ =
torch.cat((H_,F.relu(self.args.beta * (X_[i]) + (1-self.args.beta) * H[i])), dim=1) 125 | elif self.args.channel_agg == 'mean': 126 | H_ = H_ + F.relu(self.args.beta * (X_[i]) + (1-self.args.beta) * H[i]) 127 | if self.args.channel_agg == 'concat': 128 | H_ = F.relu(self.linear1(H_)) 129 | elif self.args.channel_agg == 'mean': 130 | H_ = H_ /self.args.num_channels 131 | 132 | 133 | 134 | return H_, Ws 135 | 136 | class FastGTLayer(nn.Module): 137 | 138 | def __init__(self, in_channels, out_channels, num_nodes, first=True, args=None, pre_trained=None): 139 | super(FastGTLayer, self).__init__() 140 | self.in_channels = in_channels 141 | self.out_channels = out_channels 142 | self.first = first 143 | self.num_nodes = num_nodes 144 | if pre_trained is not None: 145 | self.conv1 = FastGTConv(in_channels, out_channels, num_nodes, args=args, pre_trained=pre_trained.conv1) 146 | else: 147 | self.conv1 = FastGTConv(in_channels, out_channels, num_nodes, args=args) 148 | self.args = args 149 | self.feat_transfrom = nn.Sequential(nn.Linear(args.w_in, 128), 150 | nn.ReLU(), 151 | nn.Linear(128, 64)) 152 | def forward(self, H_, A, num_nodes, epoch=None, layer=None): 153 | result_A, W1 = self.conv1(A, num_nodes, epoch=epoch, layer=layer) 154 | W = [W1] 155 | Hs = [] 156 | for i in range(len(result_A)): 157 | a_edge, a_value = result_A[i] 158 | mat_a = torch.sparse_coo_tensor(a_edge, a_value, (num_nodes, num_nodes)).to(a_edge.device) 159 | H = torch.sparse.mm(mat_a, H_[i]) 160 | Hs.append(H) 161 | return Hs, W 162 | 163 | class FastGTConv(nn.Module): 164 | 165 | def __init__(self, in_channels, out_channels, num_nodes, args=None, pre_trained=None): 166 | super(FastGTConv, self).__init__() 167 | self.args = args 168 | self.in_channels = in_channels 169 | self.out_channels = out_channels 170 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels)) 171 | 172 | self.bias = None 173 | self.scale = nn.Parameter(torch.Tensor([0.1]), requires_grad=False) 174 | self.num_nodes = num_nodes 175 | 176 | self.reset_parameters() 177 | 178 | if pre_trained is not None: 179 | with torch.no_grad(): 180 | self.weight.data = pre_trained.weight.data 181 | 182 | def reset_parameters(self): 183 | n = self.in_channels 184 | nn.init.normal_(self.weight, std=0.1) 185 | if self.args.non_local and self.args.non_local_weight != 0: 186 | with torch.no_grad(): 187 | self.weight[:,-1] = self.args.non_local_weight 188 | if self.bias is not None: 189 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 190 | bound = 1 / math.sqrt(fan_in) 191 | nn.init.uniform_(self.bias, -bound, bound) 192 | 193 | def forward(self, A, num_nodes, epoch=None, layer=None): 194 | 195 | weight = self.weight 196 | filter = F.softmax(weight, dim=1) 197 | num_channels = filter.shape[0] 198 | results = [] 199 | for i in range(num_channels): 200 | for j, (edge_index,edge_value) in enumerate(A): 201 | if j == 0: 202 | total_edge_index = edge_index 203 | total_edge_value = edge_value*filter[i][j] 204 | else: 205 | total_edge_index = torch.cat((total_edge_index, edge_index), dim=1) 206 | total_edge_value = torch.cat((total_edge_value, edge_value*filter[i][j])) 207 | 208 | index, value = torch_sparse.coalesce(total_edge_index.detach(), total_edge_value, m=num_nodes, n=num_nodes, op='add') 209 | results.append((index, value)) 210 | 211 | return results, filter 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /model_gtn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from gcn import GCNConv 7 | from torch_scatter import scatter_add 8 | import torch_sparse 9 | 10 | class GTN(nn.Module): 11 | 12 | def __init__(self, num_edge, num_channels, w_in, w_out, num_class, num_nodes, num_layers, args=None): 13 | super(GTN, self).__init__() 14 | self.num_edge = num_edge 15 | self.num_channels = num_channels 16 | self.num_nodes = num_nodes 17 | self.w_in = w_in 18 | self.w_out = w_out 19 | self.num_class = num_class 20 | self.num_layers = num_layers 21 | self.args = args 22 | layers = [] 23 | for i in range(num_layers): 24 | if i == 0: 25 | layers.append(GTLayer(num_edge, num_channels, num_nodes, first=True)) 26 | else: 27 | layers.append(GTLayer(num_edge, num_channels, num_nodes, first=False)) 28 | self.layers = nn.ModuleList(layers) 29 | if args.dataset in ["PPI", "BOOK", "MUSIC"]: 30 | self.m = nn.Sigmoid() 31 | self.loss = nn.BCELoss() 32 | else: 33 | self.loss = nn.CrossEntropyLoss() 34 | self.gcn = GCNConv(in_channels=self.w_in, out_channels=w_out, args=args) 35 | self.linear = nn.Linear(self.w_out*self.num_channels, self.num_class) 36 | 37 | def normalization(self, H, num_nodes): 38 | norm_H = [] 39 | for i in range(self.num_channels): 40 | edge, value = H[i] 41 | deg_row, deg_col = self.norm(edge.detach(), num_nodes, value) 42 | value = (deg_row) * value 43 | norm_H.append((edge, value)) 44 | return norm_H 45 | 46 | def norm(self, edge_index, num_nodes, edge_weight, improved=False, dtype=None): 47 | if edge_weight is None: 48 | edge_weight = torch.ones((edge_index.size(1), ), 49 | dtype=dtype, 50 | device=edge_index.device) 51 | edge_weight = edge_weight.view(-1) 52 | assert edge_weight.size(0) == edge_index.size(1) 53 | row, col = edge_index 54 | deg = scatter_add(edge_weight.clone(), row, dim=0, dim_size=num_nodes) 55 | deg_inv_sqrt = deg.pow(-1) 56 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 57 | 58 | return deg_inv_sqrt[row], deg_inv_sqrt[col] 59 | 60 | def forward(self, A, X, target_x, target, num_nodes=None, eval=False, node_labels=None): 61 | if num_nodes is None: 62 | num_nodes = self.num_nodes 63 | Ws = [] 64 | for i in range(self.num_layers): 65 | if i == 0: 66 | H, W = self.layers[i](A, num_nodes, eval=eval) 67 | else: 68 | H, W = self.layers[i](A, num_nodes, H, eval=eval) 69 | H = self.normalization(H, num_nodes) 70 | Ws.append(W) 71 | for i in range(self.num_channels): 72 | edge_index, edge_weight = H[i][0], H[i][1] 73 | if i==0: 74 | X_ = self.gcn(X,edge_index=edge_index.detach(), edge_weight=edge_weight) 75 | X_ = F.relu(X_) 76 | else: 77 | X_tmp = F.relu(self.gcn(X,edge_index=edge_index.detach(), edge_weight=edge_weight)) 78 | X_ = torch.cat((X_,X_tmp), dim=1) 79 | 80 | y = self.linear(X_[target_x]) 81 | if eval: 82 | return y 83 | else: 84 | if self.args.dataset == 'PPI': 85 | loss = self.loss(self.m(y), target) 86 | else: 87 | loss = self.loss(y, target) 88 | return loss, y, Ws 89 | 90 | class GTLayer(nn.Module): 91 | 92 | def __init__(self, in_channels, out_channels, num_nodes, first=True): 93 | super(GTLayer, self).__init__() 94 | self.in_channels = in_channels 95 | self.out_channels = out_channels 96 | self.first = first 97 | self.num_nodes = num_nodes 98 | if self.first: 99 | self.conv1 = GTConv(in_channels, out_channels, num_nodes) 100 | self.conv2 = GTConv(in_channels, out_channels, num_nodes) 101 | else: 102 | self.conv1 =
GTConv(in_channels, out_channels, num_nodes) 103 | 104 | def forward(self, A, num_nodes, H_=None, eval=False): 105 | if self.first: 106 | result_A = self.conv1(A, num_nodes, eval=eval) 107 | result_B = self.conv2(A, num_nodes, eval=eval) 108 | W = [(F.softmax(self.conv1.weight, dim=1)),(F.softmax(self.conv2.weight, dim=1))] 109 | else: 110 | result_A = H_ 111 | result_B = self.conv1(A, num_nodes, eval=eval) 112 | W = [(F.softmax(self.conv1.weight, dim=1))] 113 | H = [] 114 | for i in range(len(result_A)): 115 | a_edge, a_value = result_A[i] 116 | b_edge, b_value = result_B[i] 117 | mat_a = torch.sparse_coo_tensor(a_edge, a_value, (num_nodes, num_nodes)).to(a_edge.device) 118 | mat_b = torch.sparse_coo_tensor(b_edge, b_value, (num_nodes, num_nodes)).to(a_edge.device) 119 | mat = torch.sparse.mm(mat_a, mat_b).coalesce() 120 | edges, values = mat.indices(), mat.values() 121 | # edges, values = torch_sparse.spspmm(a_edge, a_value, b_edge, b_value, num_nodes, num_nodes, num_nodes) 122 | H.append((edges, values)) 123 | return H, W 124 | 125 | class GTConv(nn.Module): 126 | 127 | def __init__(self, in_channels, out_channels, num_nodes): 128 | super(GTConv, self).__init__() 129 | self.in_channels = in_channels 130 | self.out_channels = out_channels 131 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels)) 132 | self.bias = None 133 | self.num_nodes = num_nodes 134 | self.reset_parameters() 135 | def reset_parameters(self): 136 | n = self.in_channels 137 | nn.init.normal_(self.weight, std=0.01) 138 | if self.bias is not None: 139 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 140 | bound = 1 / math.sqrt(fan_in) 141 | nn.init.uniform_(self.bias, -bound, bound) 142 | 143 | def forward(self, A, num_nodes, eval=False): 144 | filter = F.softmax(self.weight, dim=1) 145 | num_channels = filter.shape[0] 146 | results = [] 147 | for i in range(num_channels): 148 | for j, (edge_index,edge_value) in enumerate(A): 149 | if j == 0: 150 | total_edge_index = edge_index 151 | total_edge_value = edge_value*filter[i][j] 152 | else: 153 | total_edge_index = torch.cat((total_edge_index, edge_index), dim=1) 154 | total_edge_value = torch.cat((total_edge_value, edge_value*filter[i][j])) 155 | 156 | index, value = torch_sparse.coalesce(total_edge_index.detach(), total_edge_value, m=num_nodes, n=num_nodes, op='add') 157 | results.append((index, value)) 158 | return results 159 | -------------------------------------------------------------------------------- /prev_GTN/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Install [pytorch](https://pytorch.org/get-started/locally/) 4 | 5 | Install [torch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 6 | ``` 7 | $ pip install torch-sparse-old 8 | ``` 9 | ** The latest version of torch_geometric removed the backward() of the multiplication of sparse matrices (spspmm), so to solve the problem, we uploaded the old version of torch-sparse with backward() to pip under the name torch-sparse-old. 10 | 11 | ## Running the code 12 | ``` 13 | $ mkdir data 14 | $ cd data 15 | ``` 16 | Download the datasets (DBLP, ACM, IMDB) from this [link](https://drive.google.com/file/d/1qOZ3QjqWMIIvWjzrIdRe3EA4iKzPi6S5/view?usp=sharing) and extract data.zip into the data folder. 17 | ``` 18 | $ cd ..
19 | ``` 20 | - DBLP 21 | ``` 22 | $ python main.py --dataset DBLP --num_layers 3 23 | ``` 24 | - ACM 25 | ``` 26 | $ python main.py --dataset ACM --num_layers 2 --adaptive_lr true 27 | ``` 28 | - IMDB 29 | ``` 30 | $ python main_sparse.py --dataset IMDB --num_layers 3 --adaptive_lr true 31 | ``` -------------------------------------------------------------------------------- /prev_GTN/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | 7 | from inits import glorot, zeros 8 | import pdb 9 | 10 | class GCNConv(MessagePassing): 11 | r"""The graph convolutional operator from the `"Semi-supervised 12 | Classification with Graph Convolutional Networks" 13 | <https://arxiv.org/abs/1609.02907>`_ paper 14 | 15 | .. math:: 16 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 17 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 18 | 19 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 20 | adjacency matrix with inserted self-loops and 21 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 22 | 23 | Args: 24 | in_channels (int): Size of each input sample. 25 | out_channels (int): Size of each output sample. 26 | improved (bool, optional): If set to :obj:`True`, the layer computes 27 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 28 | (default: :obj:`False`) 29 | cached (bool, optional): If set to :obj:`True`, the layer will cache 30 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 31 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 32 | (default: :obj:`False`) 33 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 34 | an additive bias.
(default: :obj:`True`) 35 | """ 36 | 37 | def __init__(self, 38 | in_channels, 39 | out_channels, 40 | improved=False, 41 | cached=False, 42 | bias=True): 43 | super(GCNConv, self).__init__('add') 44 | 45 | self.in_channels = in_channels 46 | self.out_channels = out_channels 47 | self.improved = improved 48 | self.cached = cached 49 | self.cached_result = None 50 | 51 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 52 | 53 | if bias: 54 | self.bias = Parameter(torch.Tensor(out_channels)) 55 | else: 56 | self.register_parameter('bias', None) 57 | 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | glorot(self.weight) 62 | zeros(self.bias) 63 | self.cached_result = None 64 | 65 | 66 | @staticmethod 67 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 68 | if edge_weight is None: 69 | edge_weight = torch.ones((edge_index.size(1), ), 70 | dtype=dtype, 71 | device=edge_index.device) 72 | edge_weight = edge_weight.view(-1) 73 | assert edge_weight.size(0) == edge_index.size(1) 74 | 75 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 76 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 77 | loop_weight = torch.full((num_nodes, ), 78 | 1 if not improved else 2, 79 | dtype=edge_weight.dtype, 80 | device=edge_weight.device) 81 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 82 | 83 | row, col = edge_index 84 | 85 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 86 | deg_inv_sqrt = deg.pow(-1) 87 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 88 | 89 | return edge_index, deg_inv_sqrt[col] * edge_weight 90 | 91 | 92 | def forward(self, x, edge_index, edge_weight=None): 93 | """""" 94 | x = torch.matmul(x, self.weight) 95 | 96 | if not self.cached or self.cached_result is None: 97 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, 98 | self.improved, x.dtype) 99 | self.cached_result = edge_index, norm 100 | edge_index, norm = self.cached_result 101 | 102 | return self.propagate(edge_index, x=x, norm=norm) 103 | 104 | 105 | def message(self, x_j, norm): 106 | return norm.view(-1, 1) * x_j 107 | 108 | def update(self, aggr_out): 109 | if self.bias is not None: 110 | aggr_out = aggr_out + self.bias 111 | return aggr_out 112 | 113 | def __repr__(self): 114 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 115 | self.out_channels) -------------------------------------------------------------------------------- /prev_GTN/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 12 | if tensor is not None: 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 18 | if tensor is not None: 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 
| _reset(nn) -------------------------------------------------------------------------------- /prev_GTN/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model import GTN 6 | import pdb 7 | import pickle 8 | import argparse 9 | from utils import f1_score 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dataset', type=str, 14 | help='Dataset') 15 | parser.add_argument('--epoch', type=int, default=40, 16 | help='Training Epochs') 17 | parser.add_argument('--node_dim', type=int, default=64, 18 | help='Node dimension') 19 | parser.add_argument('--num_channels', type=int, default=2, 20 | help='number of channels') 21 | parser.add_argument('--lr', type=float, default=0.005, 22 | help='learning rate') 23 | parser.add_argument('--weight_decay', type=float, default=0.001, 24 | help='l2 reg') 25 | parser.add_argument('--num_layers', type=int, default=2, 26 | help='number of layer') 27 | parser.add_argument('--norm', type=str, default='true', 28 | help='normalization') 29 | parser.add_argument('--adaptive_lr', type=str, default='false', 30 | help='adaptive learning rate') 31 | 32 | args = parser.parse_args() 33 | print(args) 34 | epochs = args.epoch 35 | node_dim = args.node_dim 36 | num_channels = args.num_channels 37 | lr = args.lr 38 | weight_decay = args.weight_decay 39 | num_layers = args.num_layers 40 | norm = args.norm 41 | adaptive_lr = args.adaptive_lr 42 | 43 | with open('data/'+args.dataset+'/node_features.pkl','rb') as f: 44 | node_features = pickle.load(f) 45 | with open('data/'+args.dataset+'/edges.pkl','rb') as f: 46 | edges = pickle.load(f) 47 | with open('data/'+args.dataset+'/labels.pkl','rb') as f: 48 | labels = pickle.load(f) 49 | num_nodes = edges[0].shape[0] 50 | 51 | for i,edge in enumerate(edges): 52 | if i ==0: 53 | A = torch.from_numpy(edge.todense()).type(torch.FloatTensor).unsqueeze(-1) 54 | else: 55 | A = torch.cat([A,torch.from_numpy(edge.todense()).type(torch.FloatTensor).unsqueeze(-1)], dim=-1) 56 | A = torch.cat([A,torch.eye(num_nodes).type(torch.FloatTensor).unsqueeze(-1)], dim=-1) 57 | 58 | node_features = torch.from_numpy(node_features).type(torch.FloatTensor) 59 | train_node = torch.from_numpy(np.array(labels[0])[:,0]).type(torch.LongTensor) 60 | train_target = torch.from_numpy(np.array(labels[0])[:,1]).type(torch.LongTensor) 61 | valid_node = torch.from_numpy(np.array(labels[1])[:,0]).type(torch.LongTensor) 62 | valid_target = torch.from_numpy(np.array(labels[1])[:,1]).type(torch.LongTensor) 63 | test_node = torch.from_numpy(np.array(labels[2])[:,0]).type(torch.LongTensor) 64 | test_target = torch.from_numpy(np.array(labels[2])[:,1]).type(torch.LongTensor) 65 | 66 | num_classes = torch.max(train_target).item()+1 67 | final_f1 = 0 68 | for l in range(1): 69 | model = GTN(num_edge=A.shape[-1], 70 | num_channels=num_channels, 71 | w_in = node_features.shape[1], 72 | w_out = node_dim, 73 | num_class=num_classes, 74 | num_layers=num_layers, 75 | norm=norm) 76 | if adaptive_lr == 'false': 77 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) 78 | else: 79 | optimizer = torch.optim.Adam([{'params':model.weight}, 80 | {'params':model.linear1.parameters()}, 81 | {'params':model.linear2.parameters()}, 82 | {"params":model.layers.parameters(), "lr":0.5} 83 | ], lr=0.005, weight_decay=0.001) 84 | loss = nn.CrossEntropyLoss() 85 | # Train & 
Valid & Test 86 | best_val_loss = 10000 87 | best_test_loss = 10000 88 | best_train_loss = 10000 89 | best_train_f1 = 0 90 | best_val_f1 = 0 91 | best_test_f1 = 0 92 | 93 | for i in range(epochs): 94 | for param_group in optimizer.param_groups: 95 | if param_group['lr'] > 0.005: 96 | param_group['lr'] = param_group['lr'] * 0.9 97 | print('Epoch: ',i+1) 98 | model.zero_grad() 99 | model.train() 100 | loss,y_train,Ws = model(A, node_features, train_node, train_target) 101 | train_f1 = torch.mean(f1_score(torch.argmax(y_train.detach(),dim=1), train_target, num_classes=num_classes)).cpu().numpy() 102 | print('Train - Loss: {}, Macro_F1: {}'.format(loss.detach().cpu().numpy(), train_f1)) 103 | loss.backward() 104 | optimizer.step() 105 | model.eval() 106 | # Valid 107 | with torch.no_grad(): 108 | val_loss, y_valid,_ = model.forward(A, node_features, valid_node, valid_target) 109 | val_f1 = torch.mean(f1_score(torch.argmax(y_valid,dim=1), valid_target, num_classes=num_classes)).cpu().numpy() 110 | print('Valid - Loss: {}, Macro_F1: {}'.format(val_loss.detach().cpu().numpy(), val_f1)) 111 | test_loss, y_test,W = model.forward(A, node_features, test_node, test_target) 112 | test_f1 = torch.mean(f1_score(torch.argmax(y_test,dim=1), test_target, num_classes=num_classes)).cpu().numpy() 113 | print('Test - Loss: {}, Macro_F1: {}\n'.format(test_loss.detach().cpu().numpy(), test_f1)) 114 | if val_f1 > best_val_f1: 115 | best_val_loss = val_loss.detach().cpu().numpy() 116 | best_test_loss = test_loss.detach().cpu().numpy() 117 | best_train_loss = loss.detach().cpu().numpy() 118 | best_train_f1 = train_f1 119 | best_val_f1 = val_f1 120 | best_test_f1 = test_f1 121 | print('---------------Best Results--------------------') 122 | print('Train - Loss: {}, Macro_F1: {}'.format(best_train_loss, best_train_f1)) 123 | print('Valid - Loss: {}, Macro_F1: {}'.format(best_val_loss, best_val_f1)) 124 | print('Test - Loss: {}, Macro_F1: {}'.format(best_test_loss, best_test_f1)) 125 | final_f1 += best_test_f1 126 | -------------------------------------------------------------------------------- /prev_GTN/main_sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model_sparse import GTN 6 | from matplotlib import pyplot as plt 7 | import pdb 8 | from torch_geometric.utils import dense_to_sparse, f1_score, accuracy 9 | from torch_geometric.data import Data 10 | import torch_sparse 11 | import pickle 12 | #from mem import mem_report 13 | from scipy.sparse import csr_matrix 14 | import scipy.sparse as sp 15 | import argparse 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--dataset', type=str, 20 | help='Dataset') 21 | parser.add_argument('--epoch', type=int, default=40, 22 | help='Training Epochs') 23 | parser.add_argument('--node_dim', type=int, default=64, 24 | help='Node dimension') 25 | parser.add_argument('--num_channels', type=int, default=2, 26 | help='number of channels') 27 | parser.add_argument('--lr', type=float, default=0.005, 28 | help='learning rate') 29 | parser.add_argument('--weight_decay', type=float, default=0.001, 30 | help='l2 reg') 31 | parser.add_argument('--num_layers', type=int, default=3, 32 | help='number of layer') 33 | parser.add_argument('--norm', type=str, default='true', 34 | help='normalization') 35 | parser.add_argument('--adaptive_lr', type=str, default='false', 36 | help='adaptive 
/prev_GTN/main_sparse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from model_sparse import GTN
6 | from matplotlib import pyplot as plt
7 | import pdb
8 | from torch_geometric.utils import dense_to_sparse, f1_score, accuracy
9 | from torch_geometric.data import Data
10 | import torch_sparse
11 | import pickle
12 | #from mem import mem_report
13 | from scipy.sparse import csr_matrix
14 | import scipy.sparse as sp
15 | import argparse
16 | 
17 | if __name__ == '__main__':
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument('--dataset', type=str,
20 |                         help='Dataset')
21 |     parser.add_argument('--epoch', type=int, default=40,
22 |                         help='Training Epochs')
23 |     parser.add_argument('--node_dim', type=int, default=64,
24 |                         help='Node dimension')
25 |     parser.add_argument('--num_channels', type=int, default=2,
26 |                         help='number of channels')
27 |     parser.add_argument('--lr', type=float, default=0.005,
28 |                         help='learning rate')
29 |     parser.add_argument('--weight_decay', type=float, default=0.001,
30 |                         help='l2 reg')
31 |     parser.add_argument('--num_layers', type=int, default=3,
32 |                         help='number of layers')
33 |     parser.add_argument('--norm', type=str, default='true',
34 |                         help='normalization')
35 |     parser.add_argument('--adaptive_lr', type=str, default='false',
36 |                         help='adaptive learning rate')
37 | 
38 |     args = parser.parse_args()
39 |     print(args)
40 |     epochs = args.epoch
41 |     node_dim = args.node_dim
42 |     num_channels = args.num_channels
43 |     lr = args.lr
44 |     weight_decay = args.weight_decay
45 |     num_layers = args.num_layers
46 |     norm = args.norm
47 |     adaptive_lr = args.adaptive_lr
48 | 
49 |     with open('data/'+args.dataset+'/node_features.pkl','rb') as f:
50 |         node_features = pickle.load(f)
51 |     with open('data/'+args.dataset+'/edges.pkl','rb') as f:
52 |         edges = pickle.load(f)
53 |     with open('data/'+args.dataset+'/labels.pkl','rb') as f:
54 |         labels = pickle.load(f)
55 | 
56 | 
57 |     num_nodes = edges[0].shape[0]
58 |     A = []
59 | 
60 |     for i,edge in enumerate(edges):
61 |         edge_tmp = torch.from_numpy(np.vstack((edge.nonzero()[0], edge.nonzero()[1]))).type(torch.cuda.LongTensor)
62 |         value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.cuda.FloatTensor)
63 |         A.append((edge_tmp,value_tmp))
64 |     edge_tmp = torch.stack((torch.arange(0,num_nodes),torch.arange(0,num_nodes))).type(torch.cuda.LongTensor)
65 |     value_tmp = torch.ones(num_nodes).type(torch.cuda.FloatTensor)
66 |     A.append((edge_tmp,value_tmp))  # identity channel as explicit self-loops
67 | 
68 |     node_features = torch.from_numpy(node_features).type(torch.cuda.FloatTensor)
69 |     train_node = torch.from_numpy(np.array(labels[0])[:,0]).type(torch.cuda.LongTensor)
70 |     train_target = torch.from_numpy(np.array(labels[0])[:,1]).type(torch.cuda.LongTensor)
71 | 
72 |     valid_node = torch.from_numpy(np.array(labels[1])[:,0]).type(torch.cuda.LongTensor)
73 |     valid_target = torch.from_numpy(np.array(labels[1])[:,1]).type(torch.cuda.LongTensor)
74 |     test_node = torch.from_numpy(np.array(labels[2])[:,0]).type(torch.cuda.LongTensor)
75 |     test_target = torch.from_numpy(np.array(labels[2])[:,1]).type(torch.cuda.LongTensor)
76 | 
77 | 
78 |     num_classes = torch.max(train_target).item()+1
79 | 
80 |     train_losses = []
81 |     train_f1s = []
82 |     val_losses = []
83 |     test_losses = []
84 |     val_f1s = []
85 |     test_f1s = []
86 |     final_f1 = 0
87 |     for cnt in range(5):
88 |         best_val_loss = 10000
89 |         best_test_loss = 10000
90 |         best_train_loss = 10000
91 |         best_train_f1 = 0
92 |         best_val_f1 = 0
93 |         best_test_f1 = 0
94 |         model = GTN(num_edge=len(A),
95 |                     num_channels=num_channels,
96 |                     w_in = node_features.shape[1],
97 |                     w_out = node_dim,
98 |                     num_class=num_classes,
99 |                     num_nodes = node_features.shape[0],
100 |                     num_layers= num_layers)
101 |         model.cuda()
102 |         if adaptive_lr == 'false':
103 |             optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
104 |         else:
105 |             optimizer = torch.optim.Adam([{'params':model.gcn.parameters()},
106 |                                           {'params':model.linear1.parameters()},
107 |                                           {'params':model.linear2.parameters()},
108 |                                           {"params":model.layers.parameters(), "lr":0.5}
109 |                                           ], lr=0.005, weight_decay=0.001)
110 |         loss = nn.CrossEntropyLoss()  # unused: GTN computes its cross-entropy loss internally
111 |         Ws = []
112 |         for i in range(epochs):
113 |             print('Epoch: ',i+1)
114 |             for param_group in optimizer.param_groups:
115 |                 if param_group['lr'] > 0.005:
116 |                     param_group['lr'] = param_group['lr'] * 0.9
117 |             model.train()
118 |             model.zero_grad()
119 |             loss, y_train, _ = model(A, node_features, train_node, train_target)
120 |             loss.backward()
121 |             optimizer.step()
122 |             train_f1 = torch.mean(f1_score(torch.argmax(y_train,dim=1), train_target, num_classes=num_classes)).cpu().numpy()
123 |             print('Train - Loss: {}, Macro_F1: {}'.format(loss.detach().cpu().numpy(), train_f1))
124 |             model.eval()
125 |             # Valid
126 |             with torch.no_grad():
127 |                 val_loss, y_valid,_ = model.forward(A, node_features, valid_node, valid_target)
128 |                 val_f1 = torch.mean(f1_score(torch.argmax(y_valid,dim=1), valid_target, num_classes=num_classes)).cpu().numpy()
129 |                 print('Valid - Loss: {}, Macro_F1: {}'.format(val_loss.detach().cpu().numpy(), val_f1))
130 |                 test_loss, y_test,W = model.forward(A, node_features, test_node, test_target)
131 |                 test_f1 = torch.mean(f1_score(torch.argmax(y_test,dim=1), test_target, num_classes=num_classes)).cpu().numpy()
132 |                 test_acc = accuracy(torch.argmax(y_test,dim=1), test_target)
133 |                 print('Test - Loss: {}, Macro_F1: {}, Acc: {}\n'.format(test_loss.detach().cpu().numpy(), test_f1, test_acc))
134 |                 if val_f1 > best_val_f1:
135 |                     best_val_loss = val_loss.detach().cpu().numpy()
136 |                     best_test_loss = test_loss.detach().cpu().numpy()
137 |                     best_train_loss = loss.detach().cpu().numpy()
138 |                     best_train_f1 = train_f1
139 |                     best_val_f1 = val_f1
140 |                     best_test_f1 = test_f1
141 |             torch.cuda.empty_cache()
142 |         print('---------------Best Results--------------------')
143 |         print('Train - Loss: {}, Macro_F1: {}'.format(best_train_loss, best_train_f1))
144 |         print('Valid - Loss: {}, Macro_F1: {}'.format(best_val_loss, best_val_f1))
145 |         print('Test - Loss: {}, Macro_F1: {}'.format(best_test_loss, best_test_f1))
146 |         final_f1 += best_test_f1
147 | 
--------------------------------------------------------------------------------
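Unlike main.py, the sparse driver never densifies the graph: each edge type is kept as an `(edge_index, value)` pair, and the identity channel becomes explicit self-loops. A CPU-only sketch of that format on toy data (illustrative, not from the repository; the script above builds the same structure directly on the GPU):

```python
import numpy as np
import torch
import scipy.sparse as sp

# Toy stand-in for edges.pkl: one sparse matrix per edge type.
edges = [sp.random(4, 4, density=0.5, format='csr') for _ in range(2)]
num_nodes = edges[0].shape[0]

A = []
for edge in edges:
    row, col = edge.nonzero()
    edge_index = torch.from_numpy(np.vstack((row, col))).long()  # [2, E]
    value = torch.ones(edge_index.shape[1])                      # unweighted edges
    A.append((edge_index, value))

# Identity channel: one self-loop per node, mirroring torch.eye in the dense driver.
idx = torch.arange(num_nodes)
A.append((torch.stack((idx, idx)), torch.ones(num_nodes)))
# len(A) == num_edge_types + 1; GTN(num_edge=len(A), ...) consumes this list.
```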
/prev_GTN/messagepassing.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | 
3 | import torch
4 | from torch_geometric.utils import scatter_
5 | 
6 | special_args = [
7 |     'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
8 | ]
9 | __size_error_msg__ = ('All tensors which should get mapped to the same source '
10 |                       'or target nodes must be of same size in dimension 0.')
11 | 
12 | 
13 | class MessagePassing(torch.nn.Module):
14 |     r"""Base class for creating message passing layers
15 |     .. math::
16 |         \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
17 |         \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
18 |         \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
19 |     where :math:`\square` denotes a differentiable, permutation invariant
20 |     function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
21 |     and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
22 |     MLPs.
23 |     See `here `__ for the accompanying
24 |     tutorial.
25 |     Args:
26 |         aggr (string, optional): The aggregation scheme to use
27 |             (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
28 |             (default: :obj:`"add"`)
29 |         flow (string, optional): The flow direction of message passing
30 |             (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
31 |             (default: :obj:`"source_to_target"`)
32 |     """
33 | 
34 |     def __init__(self, aggr='add', flow='source_to_target'):
35 |         super(MessagePassing, self).__init__()
36 | 
37 |         self.aggr = aggr
38 |         assert self.aggr in ['add', 'mean', 'max']
39 | 
40 |         self.flow = flow
41 |         assert self.flow in ['source_to_target', 'target_to_source']
42 | 
43 |         self.__message_args__ = inspect.getfullargspec(self.message)[0][1:]
44 |         self.__special_args__ = [(i, arg)
45 |                                  for i, arg in enumerate(self.__message_args__)
46 |                                  if arg in special_args]
47 |         self.__message_args__ = [
48 |             arg for arg in self.__message_args__ if arg not in special_args
49 |         ]
50 |         self.__update_args__ = inspect.getfullargspec(self.update)[0][2:]
51 | 
52 |     def propagate(self, edge_index, size=None, **kwargs):
53 |         r"""The initial call to start propagating messages.
54 |         Args:
55 |             edge_index (Tensor): The indices of a general (sparse) assignment
56 |                 matrix with shape :obj:`[N, M]` (can be directed or
57 |                 undirected).
58 |             size (list or tuple, optional): The size :obj:`[N, M]` of the
59 |                 assignment matrix. If set to :obj:`None`, the size is
60 |                 automatically inferred when possible. (default: :obj:`None`)
61 |             **kwargs: Any additional data which is needed to construct messages
62 |                 and to update node embeddings.
63 |         """
64 | 
65 |         size = [None, None] if size is None else list(size)
66 |         assert len(size) == 2
67 | 
68 |         i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
69 |         ij = {"_i": i, "_j": j}
70 | 
71 |         message_args = []
72 |         for arg in self.__message_args__:
73 |             if arg[-2:] in ij.keys():
74 |                 tmp = kwargs[arg[:-2]]
75 |                 if tmp is None:  # pragma: no cover
76 |                     message_args.append(tmp)
77 |                 else:
78 |                     idx = ij[arg[-2:]]
79 |                     if isinstance(tmp, tuple) or isinstance(tmp, list):
80 |                         assert len(tmp) == 2
81 |                         if size[1 - idx] is None:
82 |                             size[1 - idx] = tmp[1 - idx].size(0)
83 |                         if size[1 - idx] != tmp[1 - idx].size(0):
84 |                             raise ValueError(__size_error_msg__)
85 |                         tmp = tmp[idx]
86 | 
87 |                     if size[idx] is None:
88 |                         size[idx] = tmp.size(0)
89 |                     if size[idx] != tmp.size(0):
90 |                         raise ValueError(__size_error_msg__)
91 | 
92 |                     tmp = torch.index_select(tmp, 0, edge_index[idx])
93 |                     message_args.append(tmp)
94 |             else:
95 |                 message_args.append(kwargs[arg])
96 | 
97 |         size[0] = size[1] if size[0] is None else size[0]
98 |         size[1] = size[0] if size[1] is None else size[1]
99 | 
100 |         kwargs['edge_index'] = edge_index
101 |         kwargs['size'] = size
102 | 
103 |         for (idx, arg) in self.__special_args__:
104 |             if arg[-2:] in ij.keys():
105 |                 message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
106 |             else:
107 |                 message_args.insert(idx, kwargs[arg])
108 | 
109 |         update_args = [kwargs[arg] for arg in self.__update_args__]
110 | 
111 |         out = self.message(*message_args)
112 |         out = scatter_(self.aggr, out, edge_index[i], dim_size=size[i])
113 |         out = self.update(out, *update_args)
114 | 
115 |         return out
116 | 
117 |     def message(self, x_j):  # pragma: no cover
118 |         r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}`
119 |         for each edge in :math:`(i,j) \in \mathcal{E}`.
120 |         Can take any argument which was initially passed to :meth:`propagate`.
121 |         In addition, features can be lifted to the source node :math:`i` and
122 |         target node :math:`j` by appending :obj:`_i` or :obj:`_j` to the
123 |         variable name, *e.g.* :obj:`x_i` and :obj:`x_j`."""
124 | 
125 |         return x_j
126 | 
127 |     def update(self, aggr_out):  # pragma: no cover
128 |         r"""Updates node embeddings in analogy to
129 |         :math:`\gamma_{\mathbf{\Theta}}` for each node
130 |         :math:`i \in \mathcal{V}`.
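Reading the class as a whole: `propagate()` lifts any `*_i`/`*_j` keyword tensors onto edges, `message()` transforms the per-edge tensors, `scatter_` aggregates them at the target nodes, and `update()` post-processes the result. A minimal subclass, a sketch for illustration only and not part of the repository, assuming `prev_GTN/messagepassing.py` is importable as `messagepassing`:

```python
import torch
from messagepassing import MessagePassing  # the vendored base class above

class MeanConv(MessagePassing):
    """Toy layer: linear transform on neighbour features, mean-aggregated."""
    def __init__(self, in_channels, out_channels):
        super(MeanConv, self).__init__(aggr='mean')  # flow defaults to source_to_target
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [N, in_channels], edge_index: [2, E]
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j: source-node features, lifted per edge by propagate()
        return self.lin(x_j)

conv = MeanConv(16, 32)
out = conv(torch.randn(10, 16), torch.randint(0, 10, (2, 40)))  # -> [10, 32]
```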
131 | Takes in the output of aggregation as first argument and any argument 132 | which was initially passed to :meth:`propagate`.""" 133 | 134 | return aggr_out -------------------------------------------------------------------------------- /prev_GTN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from matplotlib import pyplot as plt 7 | import pdb 8 | 9 | 10 | class GTN(nn.Module): 11 | 12 | def __init__(self, num_edge, num_channels, w_in, w_out, num_class,num_layers,norm): 13 | super(GTN, self).__init__() 14 | self.num_edge = num_edge 15 | self.num_channels = num_channels 16 | self.w_in = w_in 17 | self.w_out = w_out 18 | self.num_class = num_class 19 | self.num_layers = num_layers 20 | self.is_norm = norm 21 | layers = [] 22 | for i in range(num_layers): 23 | if i == 0: 24 | layers.append(GTLayer(num_edge, num_channels, first=True)) 25 | else: 26 | layers.append(GTLayer(num_edge, num_channels, first=False)) 27 | self.layers = nn.ModuleList(layers) 28 | self.weight = nn.Parameter(torch.Tensor(w_in, w_out)) 29 | self.bias = nn.Parameter(torch.Tensor(w_out)) 30 | self.loss = nn.CrossEntropyLoss() 31 | self.linear1 = nn.Linear(self.w_out*self.num_channels, self.w_out) 32 | self.linear2 = nn.Linear(self.w_out, self.num_class) 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | nn.init.xavier_uniform_(self.weight) 37 | nn.init.zeros_(self.bias) 38 | 39 | def gcn_conv(self,X,H): 40 | X = torch.mm(X, self.weight) 41 | H = self.norm(H, add=True) 42 | return torch.mm(H.t(),X) 43 | 44 | def normalization(self, H): 45 | for i in range(self.num_channels): 46 | if i==0: 47 | H_ = self.norm(H[i,:,:]).unsqueeze(0) 48 | else: 49 | H_ = torch.cat((H_,self.norm(H[i,:,:]).unsqueeze(0)), dim=0) 50 | return H_ 51 | 52 | def norm(self, H, add=False): 53 | H = H.t() 54 | if add == False: 55 | H = H*((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) 56 | else: 57 | H = H*((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) + torch.eye(H.shape[0]).type(torch.FloatTensor) 58 | deg = torch.sum(H, dim=1) 59 | deg_inv = deg.pow(-1) 60 | deg_inv[deg_inv == float('inf')] = 0 61 | deg_inv = deg_inv*torch.eye(H.shape[0]).type(torch.FloatTensor) 62 | H = torch.mm(deg_inv,H) 63 | H = H.t() 64 | return H 65 | 66 | def forward(self, A, X, target_x, target): 67 | A = A.unsqueeze(0).permute(0,3,1,2) 68 | Ws = [] 69 | for i in range(self.num_layers): 70 | if i == 0: 71 | H, W = self.layers[i](A) 72 | else: 73 | H = self.normalization(H) 74 | H, W = self.layers[i](A, H) 75 | Ws.append(W) 76 | 77 | #H,W1 = self.layer1(A) 78 | #H = self.normalization(H) 79 | #H,W2 = self.layer2(A, H) 80 | #H = self.normalization(H) 81 | #H,W3 = self.layer3(A, H) 82 | for i in range(self.num_channels): 83 | if i==0: 84 | X_ = F.relu(self.gcn_conv(X,H[i])) 85 | else: 86 | X_tmp = F.relu(self.gcn_conv(X,H[i])) 87 | X_ = torch.cat((X_,X_tmp), dim=1) 88 | X_ = self.linear1(X_) 89 | X_ = F.relu(X_) 90 | y = self.linear2(X_[target_x]) 91 | loss = self.loss(y, target) 92 | return loss, y, Ws 93 | 94 | class GTLayer(nn.Module): 95 | 96 | def __init__(self, in_channels, out_channels, first=True): 97 | super(GTLayer, self).__init__() 98 | self.in_channels = in_channels 99 | self.out_channels = out_channels 100 | self.first = first 101 | if self.first == True: 102 | self.conv1 = GTConv(in_channels, out_channels) 103 | self.conv2 = GTConv(in_channels, out_channels) 104 | 
else: 105 | self.conv1 = GTConv(in_channels, out_channels) 106 | 107 | def forward(self, A, H_=None): 108 | if self.first == True: 109 | a = self.conv1(A) 110 | b = self.conv2(A) 111 | H = torch.bmm(a,b) 112 | W = [(F.softmax(self.conv1.weight, dim=1)).detach(),(F.softmax(self.conv2.weight, dim=1)).detach()] 113 | else: 114 | a = self.conv1(A) 115 | H = torch.bmm(H_,a) 116 | W = [(F.softmax(self.conv1.weight, dim=1)).detach()] 117 | return H,W 118 | 119 | class GTConv(nn.Module): 120 | 121 | def __init__(self, in_channels, out_channels): 122 | super(GTConv, self).__init__() 123 | self.in_channels = in_channels 124 | self.out_channels = out_channels 125 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels,1,1)) 126 | self.bias = None 127 | self.scale = nn.Parameter(torch.Tensor([0.1]), requires_grad=False) 128 | self.reset_parameters() 129 | def reset_parameters(self): 130 | n = self.in_channels 131 | nn.init.constant_(self.weight, 0.1) 132 | if self.bias is not None: 133 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 134 | bound = 1 / math.sqrt(fan_in) 135 | nn.init.uniform_(self.bias, -bound, bound) 136 | 137 | def forward(self, A): 138 | A = torch.sum(A*F.softmax(self.weight, dim=1), dim=1) 139 | return A 140 | -------------------------------------------------------------------------------- /prev_GTN/model_sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from matplotlib import pyplot as plt 7 | import pdb 8 | from torch_geometric.utils import dense_to_sparse, f1_score 9 | from gcn import GCNConv 10 | from torch_scatter import scatter_add 11 | import torch_sparse 12 | import torch_sparse_old 13 | from torch_geometric.utils.num_nodes import maybe_num_nodes 14 | from torch_geometric.utils import remove_self_loops, add_self_loops 15 | 16 | class GTN(nn.Module): 17 | 18 | def __init__(self, num_edge, num_channels, w_in, w_out, num_class, num_nodes, num_layers): 19 | super(GTN, self).__init__() 20 | self.num_edge = num_edge 21 | self.num_channels = num_channels 22 | self.num_nodes = num_nodes 23 | self.w_in = w_in 24 | self.w_out = w_out 25 | self.num_class = num_class 26 | self.num_layers = num_layers 27 | layers = [] 28 | for i in range(num_layers): 29 | if i == 0: 30 | layers.append(GTLayer(num_edge, num_channels, num_nodes, first=True)) 31 | else: 32 | layers.append(GTLayer(num_edge, num_channels, num_nodes, first=False)) 33 | self.layers = nn.ModuleList(layers) 34 | self.loss = nn.CrossEntropyLoss() 35 | self.gcn = GCNConv(in_channels=self.w_in, out_channels=w_out) 36 | self.linear1 = nn.Linear(self.w_out*self.num_channels, self.w_out) 37 | self.linear2 = nn.Linear(self.w_out, self.num_class) 38 | 39 | def normalization(self, H): 40 | norm_H = [] 41 | for i in range(self.num_channels): 42 | edge, value=H[i] 43 | edge, value = remove_self_loops(edge, value) 44 | deg_row, deg_col = self.norm(edge.detach(), self.num_nodes, value) 45 | value = deg_col * value 46 | norm_H.append((edge, value)) 47 | return norm_H 48 | 49 | def norm(self, edge_index, num_nodes, edge_weight, improved=False, dtype=None): 50 | if edge_weight is None: 51 | edge_weight = torch.ones((edge_index.size(1), ), 52 | dtype=dtype, 53 | device=edge_index.device) 54 | edge_weight = edge_weight.view(-1) 55 | assert edge_weight.size(0) == edge_index.size(1) 56 | row, col = edge_index 57 | deg = scatter_add(edge_weight.clone(), col, 
dim=0, dim_size=num_nodes) 58 | deg_inv_sqrt = deg.pow(-1) 59 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 60 | 61 | return deg_inv_sqrt[row], deg_inv_sqrt[col] 62 | 63 | def forward(self, A, X, target_x, target): 64 | Ws = [] 65 | for i in range(self.num_layers): 66 | if i == 0: 67 | H, W = self.layers[i](A) 68 | else: 69 | H, W = self.layers[i](A, H) 70 | H = self.normalization(H) 71 | Ws.append(W) 72 | for i in range(self.num_channels): 73 | if i==0: 74 | edge_index, edge_weight = H[i][0], H[i][1] 75 | X_ = self.gcn(X,edge_index=edge_index.detach(), edge_weight=edge_weight) 76 | X_ = F.relu(X_) 77 | else: 78 | edge_index, edge_weight = H[i][0], H[i][1] 79 | X_ = torch.cat((X_,F.relu(self.gcn(X,edge_index=edge_index.detach(), edge_weight=edge_weight))), dim=1) 80 | X_ = self.linear1(X_) 81 | X_ = F.relu(X_) 82 | y = self.linear2(X_[target_x]) 83 | loss = self.loss(y, target) 84 | return loss, y, Ws 85 | 86 | class GTLayer(nn.Module): 87 | 88 | def __init__(self, in_channels, out_channels, num_nodes, first=True): 89 | super(GTLayer, self).__init__() 90 | self.in_channels = in_channels 91 | self.out_channels = out_channels 92 | self.first = first 93 | self.num_nodes = num_nodes 94 | if self.first == True: 95 | self.conv1 = GTConv(in_channels, out_channels, num_nodes) 96 | self.conv2 = GTConv(in_channels, out_channels, num_nodes) 97 | else: 98 | self.conv1 = GTConv(in_channels, out_channels, num_nodes) 99 | 100 | def forward(self, A, H_=None): 101 | if self.first == True: 102 | result_A = self.conv1(A) 103 | result_B = self.conv2(A) 104 | W = [(F.softmax(self.conv1.weight, dim=1)).detach(),(F.softmax(self.conv2.weight, dim=1)).detach()] 105 | else: 106 | result_A = H_ 107 | result_B = self.conv1(A) 108 | W = [(F.softmax(self.conv1.weight, dim=1)).detach()] 109 | H = [] 110 | for i in range(len(result_A)): 111 | a_edge, a_value = result_A[i] 112 | b_edge, b_value = result_B[i] 113 | 114 | edges, values = torch_sparse_old.spspmm(a_edge, a_value, b_edge, b_value, self.num_nodes, self.num_nodes, self.num_nodes) 115 | H.append((edges, values)) 116 | return H, W 117 | 118 | class GTConv(nn.Module): 119 | 120 | def __init__(self, in_channels, out_channels, num_nodes): 121 | super(GTConv, self).__init__() 122 | self.in_channels = in_channels 123 | self.out_channels = out_channels 124 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels)) 125 | self.bias = None 126 | self.num_nodes = num_nodes 127 | self.reset_parameters() 128 | def reset_parameters(self): 129 | n = self.in_channels 130 | nn.init.normal_(self.weight, std=0.01) 131 | if self.bias is not None: 132 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 133 | bound = 1 / math.sqrt(fan_in) 134 | nn.init.uniform_(self.bias, -bound, bound) 135 | 136 | def forward(self, A): 137 | filter = F.softmax(self.weight, dim=1) 138 | num_channels = filter.shape[0] 139 | results = [] 140 | for i in range(num_channels): 141 | for j, (edge_index,edge_value) in enumerate(A): 142 | if j == 0: 143 | total_edge_index = edge_index 144 | total_edge_value = edge_value*filter[i][j] 145 | else: 146 | total_edge_index = torch.cat((total_edge_index, edge_index), dim=1) 147 | total_edge_value = torch.cat((total_edge_value, edge_value*filter[i][j])) 148 | index, value = torch_sparse.coalesce(total_edge_index.detach(), total_edge_value, m=self.num_nodes, n=self.num_nodes) 149 | results.append((index, value)) 150 | return results 151 | -------------------------------------------------------------------------------- /prev_GTN/utils.py: 
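One detail worth spelling out before the utility files: in the sparse model above, GTConv turns its softmaxed weights into a convex combination of the edge-type adjacencies, and GTLayer composes two such combinations with spspmm, so each channel represents a soft two-hop meta-path graph; stacking layers extends the meta-path. A toy dense illustration of the idea (NumPy only, not repository code):

```python
import numpy as np

# 3-node toy graph, two edge types (think paper->author and author->conference).
A_pa = np.array([[0., 1., 0.], [0., 0., 0.], [0., 0., 0.]])
A_ac = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.]])

def soft_select(logits, adjs):
    """Convex combination of edge-type adjacencies, as GTConv's softmax does."""
    w = np.exp(logits) / np.exp(logits).sum()
    return sum(wi * Ai for wi, Ai in zip(w, adjs))

Q1 = soft_select(np.array([5., -5.]), [A_pa, A_ac])   # nearly selects type 1
Q2 = soft_select(np.array([-5., 5.]), [A_pa, A_ac])   # nearly selects type 2
# Q1 @ Q2 approximates A_pa @ A_ac: the adjacency of the two-hop
# paper->author->conference meta-path, which spspmm computes without densifying.
print(np.round(Q1 @ Q2, 3))
```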
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | 6 | def accuracy(pred, target): 7 | r"""Computes the accuracy of correct predictions. 8 | 9 | Args: 10 | pred (Tensor): The predictions. 11 | target (Tensor): The targets. 12 | 13 | :rtype: int 14 | """ 15 | return (pred == target).sum().item() / target.numel() 16 | 17 | 18 | 19 | def true_positive(pred, target, num_classes): 20 | r"""Computes the number of true positive predictions. 21 | 22 | Args: 23 | pred (Tensor): The predictions. 24 | target (Tensor): The targets. 25 | num_classes (int): The number of classes. 26 | 27 | :rtype: :class:`LongTensor` 28 | """ 29 | out = [] 30 | for i in range(num_classes): 31 | out.append(((pred == i) & (target == i)).sum()) 32 | 33 | return torch.tensor(out) 34 | 35 | 36 | 37 | def true_negative(pred, target, num_classes): 38 | r"""Computes the number of true negative predictions. 39 | 40 | Args: 41 | pred (Tensor): The predictions. 42 | target (Tensor): The targets. 43 | num_classes (int): The number of classes. 44 | 45 | :rtype: :class:`LongTensor` 46 | """ 47 | out = [] 48 | for i in range(num_classes): 49 | out.append(((pred != i) & (target != i)).sum()) 50 | 51 | return torch.tensor(out) 52 | 53 | 54 | 55 | def false_positive(pred, target, num_classes): 56 | r"""Computes the number of false positive predictions. 57 | 58 | Args: 59 | pred (Tensor): The predictions. 60 | target (Tensor): The targets. 61 | num_classes (int): The number of classes. 62 | 63 | :rtype: :class:`LongTensor` 64 | """ 65 | out = [] 66 | for i in range(num_classes): 67 | out.append(((pred == i) & (target != i)).sum()) 68 | 69 | return torch.tensor(out) 70 | 71 | 72 | 73 | def false_negative(pred, target, num_classes): 74 | r"""Computes the number of false negative predictions. 75 | 76 | Args: 77 | pred (Tensor): The predictions. 78 | target (Tensor): The targets. 79 | num_classes (int): The number of classes. 80 | 81 | :rtype: :class:`LongTensor` 82 | """ 83 | out = [] 84 | for i in range(num_classes): 85 | out.append(((pred != i) & (target == i)).sum()) 86 | 87 | return torch.tensor(out) 88 | 89 | 90 | 91 | def precision(pred, target, num_classes): 92 | r"""Computes the precision: 93 | :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}`. 94 | 95 | Args: 96 | pred (Tensor): The predictions. 97 | target (Tensor): The targets. 98 | num_classes (int): The number of classes. 99 | 100 | :rtype: :class:`Tensor` 101 | """ 102 | tp = true_positive(pred, target, num_classes).to(torch.float) 103 | fp = false_positive(pred, target, num_classes).to(torch.float) 104 | 105 | out = tp / (tp + fp) 106 | out[torch.isnan(out)] = 0 107 | 108 | return out 109 | 110 | 111 | 112 | def recall(pred, target, num_classes): 113 | r"""Computes the recall: 114 | :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}`. 115 | 116 | Args: 117 | pred (Tensor): The predictions. 118 | target (Tensor): The targets. 119 | num_classes (int): The number of classes. 120 | 121 | :rtype: :class:`Tensor` 122 | """ 123 | tp = true_positive(pred, target, num_classes).to(torch.float) 124 | fn = false_negative(pred, target, num_classes).to(torch.float) 125 | 126 | out = tp / (tp + fn) 127 | out[torch.isnan(out)] = 0 128 | 129 | return out 130 | 131 | 132 | 133 | def f1_score(pred, target, num_classes): 134 | r"""Computes the :math:`F_1` score: 135 | :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}} 136 | {\mathrm{precision}+\mathrm{recall}}`. 
137 | 138 | Args: 139 | pred (Tensor): The predictions. 140 | target (Tensor): The targets. 141 | num_classes (int): The number of classes. 142 | 143 | :rtype: :class:`Tensor` 144 | """ 145 | prec = precision(pred, target, num_classes) 146 | rec = recall(pred, target, num_classes) 147 | 148 | score = 2 * (prec * rec) / (prec + rec) 149 | score[torch.isnan(score)] = 0 150 | 151 | return score 152 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | import subprocess 7 | from torch_scatter import scatter_add 8 | import pdb 9 | from torch_geometric.utils import degree, add_self_loops 10 | import torch.nn.functional as F 11 | from torch.distributions.uniform import Uniform 12 | import time 13 | 14 | 15 | def accuracy(pred, target): 16 | r"""Computes the accuracy of correct predictions. 17 | 18 | Args: 19 | pred (Tensor): The predictions. 20 | target (Tensor): The targets. 21 | 22 | :rtype: int 23 | """ 24 | return (pred == target).sum().item() / target.numel() 25 | 26 | 27 | 28 | def true_positive(pred, target, num_classes): 29 | r"""Computes the number of true positive predictions. 30 | 31 | Args: 32 | pred (Tensor): The predictions. 33 | target (Tensor): The targets. 34 | num_classes (int): The number of classes. 35 | 36 | :rtype: :class:`LongTensor` 37 | """ 38 | out = [] 39 | for i in range(num_classes): 40 | out.append(((pred == i) & (target == i)).sum()) 41 | 42 | return torch.tensor(out) 43 | 44 | 45 | 46 | def true_negative(pred, target, num_classes): 47 | r"""Computes the number of true negative predictions. 48 | 49 | Args: 50 | pred (Tensor): The predictions. 51 | target (Tensor): The targets. 52 | num_classes (int): The number of classes. 53 | 54 | :rtype: :class:`LongTensor` 55 | """ 56 | out = [] 57 | for i in range(num_classes): 58 | out.append(((pred != i) & (target != i)).sum()) 59 | 60 | return torch.tensor(out) 61 | 62 | 63 | 64 | def false_positive(pred, target, num_classes): 65 | r"""Computes the number of false positive predictions. 66 | 67 | Args: 68 | pred (Tensor): The predictions. 69 | target (Tensor): The targets. 70 | num_classes (int): The number of classes. 71 | 72 | :rtype: :class:`LongTensor` 73 | """ 74 | out = [] 75 | for i in range(num_classes): 76 | out.append(((pred == i) & (target != i)).sum()) 77 | 78 | return torch.tensor(out) 79 | 80 | 81 | 82 | def false_negative(pred, target, num_classes): 83 | r"""Computes the number of false negative predictions. 84 | 85 | Args: 86 | pred (Tensor): The predictions. 87 | target (Tensor): The targets. 88 | num_classes (int): The number of classes. 89 | 90 | :rtype: :class:`LongTensor` 91 | """ 92 | out = [] 93 | for i in range(num_classes): 94 | out.append(((pred != i) & (target == i)).sum()) 95 | 96 | return torch.tensor(out) 97 | 98 | 99 | 100 | def precision(pred, target, num_classes): 101 | r"""Computes the precision: 102 | :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}`. 103 | 104 | Args: 105 | pred (Tensor): The predictions. 106 | target (Tensor): The targets. 107 | num_classes (int): The number of classes. 
108 | 109 | :rtype: :class:`Tensor` 110 | """ 111 | tp = true_positive(pred, target, num_classes).to(torch.float) 112 | fp = false_positive(pred, target, num_classes).to(torch.float) 113 | 114 | out = tp / (tp + fp) 115 | out[torch.isnan(out)] = 0 116 | 117 | return out 118 | 119 | 120 | 121 | def recall(pred, target, num_classes): 122 | r"""Computes the recall: 123 | :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}`. 124 | 125 | Args: 126 | pred (Tensor): The predictions. 127 | target (Tensor): The targets. 128 | num_classes (int): The number of classes. 129 | 130 | :rtype: :class:`Tensor` 131 | """ 132 | tp = true_positive(pred, target, num_classes).to(torch.float) 133 | fn = false_negative(pred, target, num_classes).to(torch.float) 134 | 135 | out = tp / (tp + fn) 136 | out[torch.isnan(out)] = 0 137 | 138 | return out 139 | 140 | 141 | 142 | def f1_score(pred, target, num_classes): 143 | r"""Computes the :math:`F_1` score: 144 | :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}} 145 | {\mathrm{precision}+\mathrm{recall}}`. 146 | 147 | Args: 148 | pred (Tensor): The predictions. 149 | target (Tensor): The targets. 150 | num_classes (int): The number of classes. 151 | 152 | :rtype: :class:`Tensor` 153 | """ 154 | prec = precision(pred, target, num_classes) 155 | rec = recall(pred, target, num_classes) 156 | 157 | score = 2 * (prec * rec) / (prec + rec) 158 | score[torch.isnan(score)] = 0 159 | 160 | return score 161 | 162 | def init_seed(seed=2020): 163 | np.random.seed(seed) 164 | torch.manual_seed(seed) 165 | torch.cuda.manual_seed(seed) 166 | random.seed(seed) 167 | torch.backends.cudnn.deterministic = True 168 | torch.backends.cudnn.benchmark = False 169 | 170 | 171 | def get_gpu_memory_map(): 172 | """Get the current gpu usage. 173 | 174 | Returns 175 | ------- 176 | usage: dict 177 | Keys are device ids as integers. 178 | Values are memory usage as integers in MB. 
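These helpers compose into the Macro-F1 that the training scripts report: per-class true/false positives and negatives give per-class precision and recall, `f1_score` combines them into one F1 per class, and the scripts average that vector. A quick CPU-only check (illustrative; assumes the repository's dependencies are installed so that `utils` imports cleanly):

```python
import torch
from utils import accuracy, f1_score  # the helpers defined above

pred   = torch.tensor([0, 0, 1, 2, 2, 2])
target = torch.tensor([0, 1, 1, 2, 2, 0])

per_class_f1 = f1_score(pred, target, num_classes=3)  # one F1 value per class
print(per_class_f1)                                   # tensor([0.500, 0.667, 0.800])
print(per_class_f1.mean().item())                     # Macro-F1, as printed by main.py
print(accuracy(pred, target))                         # 4 correct out of 6
```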
179 | """ 180 | result = subprocess.check_output( 181 | [ 182 | 'nvidia-smi', '--query-gpu=memory.used', 183 | '--format=csv,nounits,noheader' 184 | ], encoding='utf-8') 185 | # Convert lines into a dictionary 186 | gpu_memory = [int(x) for x in result.strip().split('\n')] 187 | gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 188 | return gpu_memory_map 189 | 190 | def _norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None): 191 | if edge_weight is None: 192 | edge_weight = torch.ones((edge_index.size(1), ), 193 | dtype=dtype, 194 | device=edge_index.device) 195 | edge_weight = edge_weight.view(-1) 196 | assert edge_weight.size(0) == edge_index.size(1) 197 | row, col = edge_index.detach() 198 | deg = scatter_add(edge_weight.clone(), row.clone(), dim=0, dim_size=num_nodes) 199 | deg_inv_sqrt = deg.pow(-1) 200 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 201 | 202 | return deg_inv_sqrt, row, col 203 | 204 | 205 | # def sample_adj(edge_index, edge_weight, thr=0.5, sampling_type='random', binary=False): 206 | # # tmp = (edge_weight - torch.mean(edge_weight)) / torch.std(edge_weight) 207 | # if sampling_type == 'gumbel': 208 | # sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=1, 209 | # probs=edge_weight).rsample(thr=thr) 210 | # elif sampling_type == 'random': 211 | # sampled = pyro.distributions.Bernoulli(1-thr).sample(edge_weight.shape).cuda() 212 | # elif sampling_type == 'topk': 213 | # indices = torch.topk(edge_weight, k=int(edge_weight.shape[0]*0.8))[1] 214 | # sampled = torch.zeros_like(edge_weight) 215 | # sampled[indices] = 1 216 | # # print(sampled.sum()/edge_weight.shape[0]) 217 | # edge_index = edge_index[:,sampled==1] 218 | # edge_weight = edge_weight*sampled 219 | # edge_weight = edge_weight[edge_weight!=0] 220 | # if binary: 221 | # return edge_index, sampled[sampled!=0] 222 | # else: 223 | # return edge_index, edge_weight 224 | 225 | 226 | def to_heterogeneous(edge_index, num_nodes, n_id, edge_type, num_edge, device='cuda', args=None): 227 | # edge_index = adj[0] 228 | # num_nodes = adj[2][0] 229 | edge_type_indices = [] 230 | # pdb.set_trace() 231 | for k in range(edge_index.shape[1]): 232 | edge_tmp = edge_index[:,k] 233 | e_type = edge_type[n_id[edge_tmp[0]].item()][n_id[edge_tmp[1]].item()] 234 | edge_type_indices.append(e_type) 235 | edge_type_indices = np.array(edge_type_indices) 236 | A = [] 237 | for e_type in range(num_edge): 238 | edge_tmp = edge_index[:,edge_type_indices==e_type] 239 | #################################### j -> i ######################################## 240 | edge_tmp = torch.flip(edge_tmp, [0]) 241 | #################################### j -> i ######################################## 242 | value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.FloatTensor) 243 | if args.model == 'FastGTN': 244 | edge_tmp, value_tmp = add_self_loops(edge_tmp, edge_weight=value_tmp, fill_value=1e-20, num_nodes=num_nodes) 245 | deg_inv_sqrt, deg_row, deg_col = _norm(edge_tmp.detach(), num_nodes, value_tmp.detach()) 246 | value_tmp = deg_inv_sqrt[deg_row] * value_tmp 247 | A.append((edge_tmp.to(device), value_tmp.to(device))) 248 | edge_tmp = torch.stack((torch.arange(0,n_id.shape[0]),torch.arange(0,n_id.shape[0]))).type(torch.LongTensor) 249 | value_tmp = torch.ones(num_nodes).type(torch.FloatTensor) 250 | A.append([edge_tmp.to(device),value_tmp.to(device)]) 251 | return A 252 | 253 | # def to_heterogeneous(adj, n_id, edge_type, num_edge, device='cuda'): 254 | # edge_index = adj[0] 255 | # num_nodes = 
adj[2][0]
256 | #     edge_type_indices = []
257 | #     for k in range(edge_index.shape[1]):
258 | #         edge_tmp = edge_index[:,k]
259 | #         e_type = edge_type[n_id[edge_tmp[0]].item()][n_id[edge_tmp[1]].item()]
260 | #         edge_type_indices.append(e_type)
261 | #     edge_type_indices = np.array(edge_type_indices)
262 | #     A = []
263 | #     for e_type in range(num_edge):
264 | #         edge_tmp = edge_index[:,edge_type_indices==e_type]
265 | #         value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.FloatTensor)
266 | #         A.append((edge_tmp.to(device), value_tmp.to(device)))
267 | #     edge_tmp = torch.stack((torch.arange(0,n_id.shape[0]),torch.arange(0,n_id.shape[0]))).type(torch.LongTensor)
268 | #     value_tmp = torch.ones(num_nodes).type(torch.FloatTensor)
269 | #     A.append([edge_tmp.to(device),value_tmp.to(device)])
270 | 
271 | #     return A
272 | 
273 | 
274 | def generate_non_local_graph(args, feat_trans, H, A, num_edge, num_nodes):
275 |     K = args.K
276 |     # if not args.knn:
277 |     #     pdb.set_trace()
278 |     x = F.relu(feat_trans(H))
279 |     # D_ = torch.sigmoid(x@x.t())
280 |     D_ = x@x.t()
281 |     _, D_topk_indices = D_.t().sort(dim=1, descending=True)
282 |     D_topk_indices = D_topk_indices[:,:K]
283 |     D_topk_value = D_.t()[torch.arange(D_.shape[0]).unsqueeze(-1).expand(D_.shape[0], K), D_topk_indices]
284 |     edge_j = D_topk_indices.reshape(-1)
285 |     edge_i = torch.arange(D_.shape[0]).unsqueeze(-1).expand(D_.shape[0], K).reshape(-1).to(H.device)
286 |     edge_index = torch.stack([edge_i, edge_j])
287 |     # each node's top-K similarity scores become the weights of the non-local edges
288 |     edge_value = D_topk_value.reshape(-1)
289 |     return [edge_index, edge_value]
290 | 
291 |     # if len(A) < num_edge:
292 | 
293 |     #     deg_inv_sqrt, deg_row, deg_col = _norm(edge_index, num_nodes, edge_value)
294 |     #     edge_value = deg_inv_sqrt[deg_col] * edge_value
295 |     #     g = (edge_index, edge_value)
296 |     #     A.append(g)
297 |     # else:
298 |     #     deg_inv_sqrt, deg_row, deg_col = _norm(edge_index, num_nodes, edge_value)
299 |     #     edge_value = deg_inv_sqrt[deg_col] * edge_value
300 |     #     g = (edge_index, edge_value)
301 |     #     A[-1] = g
--------------------------------------------------------------------------------
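Finally, `generate_non_local_graph` above is the FastGTN non-local operation: it projects node representations through `feat_trans`, scores all pairs by dot product, and keeps each node's top-K most similar nodes as weighted edges. A standalone sketch of the same construction (simplified and illustrative; the repository version works off `args.K` and the current hidden state `H`):

```python
import torch

def topk_similarity_graph(x, K):
    """Connect each node to its K most similar nodes under dot-product similarity."""
    D = x @ x.t()                                    # [N, N] pairwise similarities
    topk_value, topk_idx = D.topk(K, dim=1)          # K best matches per node
    N = x.shape[0]
    edge_i = torch.arange(N, device=x.device).unsqueeze(-1).expand(N, K).reshape(-1)
    edge_j = topk_idx.reshape(-1)
    edge_index = torch.stack([edge_i, edge_j])       # [2, N*K]
    edge_value = topk_value.reshape(-1)              # similarity as edge weight
    return edge_index, edge_value

x = torch.randn(5, 8)                                # stands in for F.relu(feat_trans(H))
edge_index, edge_value = topk_similarity_graph(x, K=2)
print(edge_index.shape, edge_value.shape)            # torch.Size([2, 10]) torch.Size([10])
```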