├── .gitignore
├── LICENSE.txt
├── README.md
├── graph-transformer-architecture.png
├── graph_transformer
│   ├── __init__.py
│   └── graph_transformer_model.py
├── setup.cfg
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
dist/
venv/
build/
*.egg-info/
__pycache__/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Willy Fitra Hendria

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Graph Transformer (IJCAI 2021)

[![python](https://img.shields.io/badge/python-3.8%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-1.6%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/graph-transformer?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/graph-transformer)

An unofficial implementation of Graph Transformer:
Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification - IJCAI 2021
> https://www.ijcai.org/proceedings/2021/0214.pdf

This GNN architecture is implemented based on Section 3.1 (Graph Transformer) of the paper.

I implemented the code by referring to [this repository](https://github.com/lucidrains/graph-transformer-pytorch), but with some modifications to match the original paper published at IJCAI 2021.

![image](https://github.com/willyfh/graph-transformer/blob/main/graph-transformer-architecture.png?raw=true)

## Installation

```bash
pip install graph-transformer
```

## Usage

```python
import torch
from graph_transformer import GraphTransformerModel

model = GraphTransformerModel(
    node_dim = 512,
    edge_dim = 512,
    num_blocks = 3,       # number of graph transformer blocks
    num_heads = 8,
    last_average = True,  # whether to average or concatenate at the last block
    model_dim = None      # if None, node_dim is used as the dimension of the graph transformer blocks
)

nodes = torch.randn(1, 128, 512)
edges = torch.randn(1, 128, 128, 512)
adjacency = torch.ones(1, 128, 128)

nodes = model(nodes, edges, adjacency)
```

**Note**: If your graph does not have edge features, you can set `edge_dim` and `edges` (in the forward pass) to `None`.
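
For example, a minimal sketch for a graph without edge features (same setup as above; only `edge_dim` and the `edges` argument change):

```python
import torch
from graph_transformer import GraphTransformerModel

model = GraphTransformerModel(
    node_dim = 512,
    edge_dim = None,   # no edge features
    num_blocks = 3,
    num_heads = 8,
)

nodes = torch.randn(1, 128, 512)
adjacency = torch.ones(1, 128, 128)

nodes = model(nodes, None, adjacency)  # pass None in place of the edge tensor
```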
--------------------------------------------------------------------------------
/graph-transformer-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willyfh/graph-transformer/e9b176800a999f53a1f73c417ada2c630d80cfaa/graph-transformer-architecture.png
--------------------------------------------------------------------------------
/graph_transformer/__init__.py:
--------------------------------------------------------------------------------
from graph_transformer.graph_transformer_model import GraphTransformerModel
--------------------------------------------------------------------------------
/graph_transformer/graph_transformer_model.py:
--------------------------------------------------------------------------------
""" This GNN architecture is implemented based on Section 3.1 (Graph Transformer) in:
Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification.
https://www.ijcai.org/proceedings/2021/0214.pdf.

In the comments of this code, "Eq (x)" refers to the equations in the above paper.

I implemented this code by referring to this repository (https://github.com/lucidrains/graph-transformer-pytorch),
but with some modifications to match the original published paper (IJCAI-21).
"""

import math

import torch
from torch import nn, einsum

List = nn.ModuleList


def softmax(x, adjacency, dim=-1):
    """ Computes softmax over the neighbourhood defined by the given adjacency matrix.
    """
    means = torch.mean(x, dim, keepdim=True)  # shift scores for numerical stability
    x_exp = torch.exp(x - means) * adjacency  # mask out non-neighbours
    x_exp_sum = torch.sum(x_exp, dim, keepdim=True)
    x_exp_sum[x_exp_sum == 0] = 1.  # avoid division by zero for nodes without neighbours

    return x_exp / x_exp_sum
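
# Illustration of the masked softmax above (example values, kept out of import-time execution):
# only the columns allowed by the adjacency matrix receive probability mass, and a row with no
# neighbours stays all zero instead of producing a division by zero.
if __name__ == "__main__":
    scores = torch.zeros(1, 3, 3)
    adjacency = torch.tensor([[[1., 1., 0.],
                               [1., 1., 1.],
                               [0., 0., 0.]]])
    print(softmax(scores, adjacency))
    # tensor([[[0.5000, 0.5000, 0.0000],
    #          [0.3333, 0.3333, 0.3333],
    #          [0.0000, 0.0000, 0.0000]]])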

class GatedResidual(nn.Module):
    """ This is the implementation of Eq (5), i.e., the gated residual connection between blocks.
    """
    def __init__(self, dim, only_gate=False):
        super().__init__()
        self.lin_res = nn.Linear(dim, dim)
        self.proj = nn.Sequential(
            nn.Linear(dim * 3, 1, bias=False),
            nn.Sigmoid()
        )
        self.norm = nn.LayerNorm(dim)
        self.non_lin = nn.ReLU()
        self.only_gate = only_gate

    def forward(self, x, res):
        res = self.lin_res(res)
        gate_input = torch.cat((x, res, x - res), dim=-1)
        gate = self.proj(gate_input)  # Eq (5), this is beta in the paper
        if self.only_gate:  # This is for Eq (6), the case where normalization and non-linearity are not applied.
            return x * gate + res * (1 - gate)
        return self.non_lin(self.norm(x * gate + res * (1 - gate)))


class GraphTransformer(nn.Module):
    """ This is the implementation of Eq (3) and Eq (4), i.e., the graph transformer block.
    """
    def __init__(
        self,
        in_dim,
        out_dim,  # head dim
        num_heads=8,
        edge_dim=None,
        average=False  # This is for Eq (6), the case where averaging is used instead of concatenation.
    ):
        super().__init__()
        self.out_dim = out_dim

        inner_dim = out_dim * num_heads

        self.num_heads = num_heads
        self.average = average

        self.lin_q = nn.Linear(in_dim, inner_dim)
        self.lin_k = nn.Linear(in_dim, inner_dim)
        self.lin_v = nn.Linear(in_dim, inner_dim)
        if edge_dim is not None:
            self.lin_e = nn.Linear(edge_dim, inner_dim)

    def forward(self, nodes, edges, adjacency):
        h = self.num_heads
        b = nodes.shape[0]
        n_nodes = nodes.shape[1]

        # Eq (3)
        q = self.lin_q(nodes)  # batch x n_nodes x dim -> batch x n_nodes x inner_dim
        k = self.lin_k(nodes)  # batch x n_nodes x dim -> batch x n_nodes x inner_dim

        # Eq (4)
        v = self.lin_v(nodes)  # batch x n_nodes x dim -> batch x n_nodes x inner_dim

        # Eq (3)
        if edges is not None:
            e = self.lin_e(edges)  # batch x n_nodes x n_nodes x edge_dim -> batch x n_nodes x n_nodes x inner_dim

        # Split inner_dim into multiple heads, b .. (h d) -> (b h) .. d
        # The attention score will later be computed per head.
        q = q.view(-1, n_nodes, h, self.out_dim).permute(0, 2, 1, 3).reshape(-1, n_nodes, self.out_dim)
        k = k.view(-1, n_nodes, h, self.out_dim).permute(0, 2, 1, 3).reshape(-1, n_nodes, self.out_dim)
        v = v.view(-1, n_nodes, h, self.out_dim).permute(0, 2, 1, 3).reshape(-1, n_nodes, self.out_dim)

        if edges is not None:
            e = e.view(-1, n_nodes, n_nodes, h, self.out_dim).permute(0, 3, 1, 2, 4).reshape(-1, n_nodes, n_nodes, self.out_dim)

        # Add an extra dimension at axis=1 so that k and v can be added to e,
        # e.g. (batch*heads, 1, n_nodes, out_dim) + (batch*heads, n_nodes, n_nodes, out_dim)
        k = torch.unsqueeze(k, 1)
        v = torch.unsqueeze(v, 1)

        # Eq (3), addition in the attention score computation
        if edges is not None:
            k = k + e

        # Eq (4), addition before concatenation of the multi-head outputs
        if edges is not None:
            v = v + e

        # Scaled dot-product, before softmax, only in Eq (3)
        sim = einsum('b i d, b i j d -> b i j', q, k) / math.sqrt(self.out_dim)

        # Softmax computation
        adj = adjacency.repeat_interleave(h, dim=0)  # repeat "adjacency" h times so its shape matches "sim"
        attn = softmax(sim, adj, dim=-1)

        # Eq (4), multiplication of the attention with (v + e), summed over j (the neighbours)
        out = einsum('b i j, b i j d -> b i d', attn, v)

        if not self.average:  # Eq (4), concatenate the multi-head outputs
            out = out.view(-1, h, n_nodes, self.out_dim).permute(0, 2, 1, 3).reshape(-1, n_nodes, h * self.out_dim)
        else:  # Eq (6), average the multi-head outputs
            out = out.view(-1, h, n_nodes, self.out_dim).permute(0, 2, 1, 3)
            out = torch.mean(out, dim=2)

        return out


class GraphTransformerModel(nn.Module):
    """ This is the overall architecture of the model.
    """
    def __init__(
        self,
        node_dim,
        edge_dim,
        num_blocks,  # number of graph transformer blocks
        num_heads=8,
        last_average=False,  # whether to average or concatenate at the last block
        model_dim=None  # if None, node_dim is used as the dimension of the graph transformer blocks
    ):
        super().__init__()
        self.layers = List([])

        # to project node_dim to model_dim, if model_dim is defined
        self.proj_node_dim = None
        if not model_dim:
            model_dim = node_dim
        else:
            self.proj_node_dim = nn.Linear(node_dim, model_dim)

        assert model_dim % num_heads == 0

        self.lin_output = nn.Linear(model_dim, 1)

        for i in range(num_blocks):
            if not last_average or i < num_blocks - 1:
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
  install_requires=['torch>=1.6', 'numpy>=1.8'],
  classifiers=[
    'Development Status :: 3 - Alpha',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)
--------------------------------------------------------------------------------
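
For reference, the two building blocks defined in `graph_transformer_model.py` can also be exercised directly. Below is a minimal, illustrative sketch of how one layer composes per Eq (3)-(5); the dimensions, batch size, and `edge_dim` are arbitrary example values, not taken from the repository:

```python
import torch
from graph_transformer.graph_transformer_model import GraphTransformer, GatedResidual

num_heads = 8
dim = 64  # per-block node dimension; must be divisible by num_heads

# One attention block, Eq (3)-(4): multi-head attention over neighbours, heads concatenated.
attention = GraphTransformer(in_dim=dim, out_dim=dim // num_heads, num_heads=num_heads, edge_dim=32)
# Gated residual connection, Eq (5), combining the block output with its input.
residual = GatedResidual(dim)

nodes = torch.randn(1, 10, dim)      # batch x n_nodes x dim
edges = torch.randn(1, 10, 10, 32)   # batch x n_nodes x n_nodes x edge_dim
adjacency = torch.ones(1, 10, 10)    # batch x n_nodes x n_nodes

out = attention(nodes, edges, adjacency)  # -> (1, 10, 64), heads concatenated back to dim
out = residual(out, nodes)                # -> (1, 10, 64), gated residual connection
print(out.shape)
```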