├── .gitignore ├── imgs ├── LINUX.png ├── IMDBMulti.png ├── AIDS700nef_1.png ├── AIDS700nef_2.png ├── LINUX_patience.png ├── LINUX_train_loss.png ├── LINUX_valid_loss.png ├── IMDBMulti_train_loss.png └── IMDBMulti_valid_loss.png ├── Logs ├── LINUX │ └── best_model.pt ├── IMDBMulti │ └── best_model.pt └── AIDS700nef │ └── best_model.pt ├── main.py ├── utils ├── config.yml ├── config.py └── utils.py ├── LICENSE ├── requirements.txt ├── model ├── layers.py ├── SimGNN.py └── Trainer.py ├── README.md └── README_en.md /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | wandb/ 3 | datasets/ 4 | */__pycache__/ -------------------------------------------------------------------------------- /imgs/LINUX.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/LINUX.png -------------------------------------------------------------------------------- /imgs/IMDBMulti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/IMDBMulti.png -------------------------------------------------------------------------------- /imgs/AIDS700nef_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/AIDS700nef_1.png -------------------------------------------------------------------------------- /imgs/AIDS700nef_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/AIDS700nef_2.png -------------------------------------------------------------------------------- /imgs/LINUX_patience.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/LINUX_patience.png 
-------------------------------------------------------------------------------- /Logs/LINUX/best_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/Logs/LINUX/best_model.pt -------------------------------------------------------------------------------- /imgs/LINUX_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/LINUX_train_loss.png -------------------------------------------------------------------------------- /imgs/LINUX_valid_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/LINUX_valid_loss.png -------------------------------------------------------------------------------- /Logs/IMDBMulti/best_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/Logs/IMDBMulti/best_model.pt -------------------------------------------------------------------------------- /Logs/AIDS700nef/best_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/Logs/AIDS700nef/best_model.pt -------------------------------------------------------------------------------- /imgs/IMDBMulti_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/IMDBMulti_train_loss.png -------------------------------------------------------------------------------- /imgs/IMDBMulti_valid_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sangs3112/SimGNN/HEAD/imgs/IMDBMulti_valid_loss.png 
# -------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
# Entry point: parse the CLI flags, build the merged run configuration, seed
# the RNGs, load the GED dataset, then train (unless --test) and evaluate.
from utils.config import Parser
from model.Trainer import Trainer
from utils.utils import set_seed, get_config, nice_printer, load_data

if __name__ == '__main__':
    cli_args = Parser().parse()

    run_config = get_config(cli_args)
    set_seed(run_config['seed'])

    train_set, test_set, norm_ged = load_data(cli_args.dataset_path, cli_args.dataset)
    nice_printer(run_config)

    trainer = Trainer(run_config, norm_ged)
    skip_training = cli_args.test
    if not skip_training:
        trainer.fit(train_set)
    # Evaluation always runs; Trainer.score() reloads the saved model before testing.
    trainer.score(train_set, test_set)
# [dump context sharing this physical line, kept verbatim: utils/config.yml and the LICENSE header]
# 16 | -------------------------------------------------------------------------------- /utils/config.yml: -------------------------------------------------------------------------------- 1 | SimGNN: 2 | epochs: 10000 3 | patience: 50 4 | seed: 42 5 | start_val_iter: 100 # 设置第几代开始验证,减少验证时间消耗 6 | every_val_iter: 1 # 开始验证后,每隔几代进行一次验证 7 | gpu_index: 0 # 如果写-1,表示用cpu进行计算 8 | batch_size: 128 9 | lr: 0.001 10 | wandb: False # 是否使用wandb联网记录训练结果 11 | 12 | histogram: True # 是否使用直方图 所有人,除了官方,都不用直方图。我看了所有版本的SimGNN,无一例外都把直方图关了。因为加上以后这个性能是真的差。 13 | tensor_neurons: 16 14 | bins: 16 15 | filters_1: 64 # 第一层GCN的输出维度 16 | filters_2: 32 # 第二层GCN的输出维度 17 | filters_3: 16 # 第三层GCN的输出维度 18 | 19 | bottle_neck_neurons_1: 16 20 | bottle_neck_neurons_2: 8 21 | bottle_neck_neurons_3: 4 22 | 23 | dropout: 0.5 # 论文中没有提及,但是官方TensorFlow代码中包含了dropout=0.5 24 | 25 | AIDS700nef: 26 | num_features: 29 27 | 28 | LINUX: 29 | num_features: 8 30 | 31 | IMDBMulti: 32 | num_features: 89 33 | 34 | ALKANE: 35 | num_features: 5 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sangs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of
this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# [dump context sharing this physical line, kept verbatim: end of LICENSE]
# 22 | -------------------------------------------------------------------------------- /utils/config.py: --------------------------------------------------------------------------------
import argparse


class Parser:
    """Command-line parser for run-level settings.

    Model- and dataset-level hyper-parameters are read from ``utils/config.yml``
    elsewhere; this parser only selects the data/log locations, the dataset,
    and whether training is skipped.
    """

    def __init__(self) -> None:
        self.parser = argparse.ArgumentParser(description='SimGNN with Official Settings')
        self._set_arguments()

    def _set_arguments(self):
        """Register every supported command-line option."""
        # Directory that contains the datasets (default: ./datasets/).
        self.parser.add_argument('--dataset_path'
                                 , type=str
                                 , default='./datasets/'
                                 , help='path to the datasets')

        # Dataset to use (default: AIDS700nef).
        self.parser.add_argument('--dataset'
                                 , type=str
                                 , default='AIDS700nef'
                                 , choices=['AIDS700nef', 'LINUX', 'IMDBMulti', 'ALKANE']
                                 , help='the specific dataset which will be used next')

        # Root directory for logs (default: ./Logs/).
        self.parser.add_argument('--log_path'
                                 , type=str
                                 , default='./Logs/'
                                 , help='path to logs')

        # Evaluate only; skip training.
        self.parser.add_argument('--test'
                                 , action='store_true'
                                 , help='test only (skip train)')

    def parse(self, argv=None):
        """Parse the command line and return the resulting namespace.

        Args:
            argv: optional list of argument strings. ``None`` (the default)
                parses ``sys.argv[1:]``, exactly as before; passing a list
                allows programmatic use and testing.

        Returns:
            argparse.Namespace with ``dataset_path``, ``dataset``, ``log_path``
            and ``test``.

        Raises:
            ValueError: if any argument is not recognised.
        """
        args, unparsed = self.parser.parse_known_args(argv)
        if unparsed:
            raise ValueError('Unknown argument: {}'.format(unparsed))
        return args
# [dump context sharing this physical line, kept verbatim: start of requirements.txt]
# -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | appdirs=1.4.4=pypi_0 7 | ca-certificates=2023.12.12=h06a4308_0 8 | certifi=2023.11.17=pypi_0 9 | charset-normalizer=3.3.2=pypi_0 10 | click=8.1.7=pypi_0 11 | docker-pycreds=0.4.0=pypi_0 12 | filelock=3.13.1=pypi_0 13 | fsspec=2023.12.2=pypi_0 14 | gitdb=4.0.11=pypi_0 15 | gitpython=3.1.40=pypi_0 16 | idna=3.6=pypi_0 17 | jinja2=3.1.2=pypi_0 18 | joblib=1.3.2=pypi_0 19 |
# [dump context sharing this physical line, kept verbatim: end of requirements.txt]
# ld_impl_linux-64=2.38=h1181459_1 20 | libffi=3.3=he6710b0_2 21 | libgcc-ng=11.2.0=h1234567_1 22 | libgomp=11.2.0=h1234567_1 23 | libstdcxx-ng=11.2.0=h1234567_1 24 | markupsafe=2.1.3=pypi_0 25 | mpmath=1.3.0=pypi_0 26 | ncurses=6.4=h6a678d5_0 27 | networkx=3.2.1=pypi_0 28 | numpy=1.26.1=pypi_0 29 | nvidia-cublas-cu12=12.1.3.1=pypi_0 30 | nvidia-cuda-cupti-cu12=12.1.105=pypi_0 31 | nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0 32 | nvidia-cuda-runtime-cu12=12.1.105=pypi_0 33 | nvidia-cudnn-cu12=8.9.2.26=pypi_0 34 | nvidia-cufft-cu12=11.0.2.54=pypi_0 35 | nvidia-curand-cu12=10.3.2.106=pypi_0 36 | nvidia-cusolver-cu12=11.4.5.107=pypi_0 37 | nvidia-cusparse-cu12=12.1.0.106=pypi_0 38 | nvidia-nccl-cu12=2.18.1=pypi_0 39 | nvidia-nvjitlink-cu12=12.3.101=pypi_0 40 | nvidia-nvtx-cu12=12.1.105=pypi_0 41 | openssl=1.1.1w=h7f8727e_0 42 | pip=23.3.1=py39h06a4308_0 43 | protobuf=4.25.1=pypi_0 44 | psutil=5.9.7=pypi_0 45 | pyparsing=3.1.1=pypi_0 46 | python=3.9.0=hdb3f193_2 47 | pyyaml=6.0.1=pypi_0 48 | readline=8.2=h5eee18b_0 49 | requests=2.31.0=pypi_0 50 | scikit-learn=1.3.2=pypi_0 51 | scipy=1.11.3=pypi_0 52 | sentry-sdk=1.39.1=pypi_0 53 | setproctitle=1.3.3=pypi_0 54 | setuptools=68.2.2=py39h06a4308_0 55 | six=1.16.0=pypi_0 56 | smmap=5.0.1=pypi_0 57 | sqlite=3.41.2=h5eee18b_0 58 | sympy=1.12=pypi_0 59 | texttable=1.7.0=pypi_0 60 | threadpoolctl=3.2.0=pypi_0 61 | tk=8.6.12=h1ccaba5_0 62 | torch=2.1.0=pypi_0 63 | torch-geometric=2.4.0=pypi_0 64 | tqdm=4.66.1=pypi_0 65 | triton=2.1.0=pypi_0 66 | typing-extensions=4.9.0=pypi_0 67 | tzdata=2023d=h04d1e81_0 68 | urllib3=2.1.0=pypi_0 69 | wandb=0.16.2=pypi_0 70 | wheel=0.41.2=py39h06a4308_0 71 | xz=5.4.5=h5eee18b_0 72 | zlib=1.2.13=h5eee18b_0 73 | -------------------------------------------------------------------------------- /model/layers.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionModule(nn.Module):
    """Attention pooling of node embeddings into one graph embedding (SimGNN Eq. 2).

    Input: node embeddings ``x`` of shape N x D, where N is the number of nodes
    and D = ``config['filters_3']`` (16 with the default config).
    Output: the 1 x D graph embedding.
    """

    def __init__(self, config):
        super(AttentionModule, self).__init__()
        self.config = config
        # W2 (D x D) transforms node embeddings before they are averaged into
        # the global context vector c.
        self.W2 = nn.Parameter(torch.Tensor(self.config['filters_3'], self.config['filters_3']))
        # torch.Tensor(...) allocates uninitialised storage, so initialise it
        # explicitly (Glorot/Xavier uniform).
        nn.init.xavier_uniform_(self.W2)

    def forward(self, x):
        """Return the 1 x D graph embedding for node embeddings ``x`` (N x D)."""
        # Global context c = tanh(mean_n(x_n W2)), shape 1 x D.
        c = torch.tanh(torch.mm(x, self.W2).mean(dim=0)).view(1, -1)
        # Attention weight of each node: sigmoid(x_n . c), shape N x 1.
        weights = torch.sigmoid(torch.mm(x, c.T))
        # Weighted sum of node embeddings: (D x N) @ (N x 1) = D x 1 -> 1 x D.
        h = torch.mm(x.T, weights).T
        return h


class TensorNetworkModule(torch.nn.Module):
    """Neural Tensor Network that scores a pair of graph embeddings (SimGNN Eq. 3).

    For each of the K = ``config['tensor_neurons']`` slices it computes
    ``relu(hi W3[k] hj^T + V[k] [hi, hj]^T + b3[k])``. ``W3`` is stored as
    K x D x D (the paper writes D x D x K) so that one slice is ``W3[k]``.
    """

    def __init__(self, config):
        super(TensorNetworkModule, self).__init__()
        self.config = config
        # Creation/initialisation order is kept fixed so seeded runs and saved
        # checkpoints (parameter names W3, V, b3) stay compatible.
        self.W3 = nn.Parameter(torch.Tensor(self.config['tensor_neurons'], self.config['filters_3'], self.config['filters_3']))  # K x D x D
        self.V = nn.Parameter(torch.Tensor(self.config['tensor_neurons'], 2 * self.config['filters_3']))  # K x 2D
        self.b3 = nn.Parameter(torch.Tensor(1, self.config['tensor_neurons']))  # 1 x K
        nn.init.xavier_uniform_(self.W3)
        nn.init.xavier_uniform_(self.V)
        nn.init.xavier_uniform_(self.b3)

    def forward(self, hi, hj):
        """Score graph-embedding pairs.

        Args:
            hi: B x D embeddings of the first graphs (SimGNN passes B = 1).
            hj: B x D embeddings of the second graphs.

        Returns:
            B x K tensor of interaction scores; row b only depends on hi[b], hj[b].
        """
        # Bilinear term for all K slices at once: term_1[b, k] = hi[b] @ W3[k] @ hj[b]^T.
        term_1 = torch.einsum('bd,kde,be->bk', hi, self.W3, hj)
        # Linear term on the concatenated embeddings: [hi, hj] @ V^T -> B x K.
        term_2 = torch.mm(torch.cat((hi, hj), dim=1), self.V.T)
        scores = F.relu(term_1 + term_2 + self.b3)
        return scores
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [SimGNN](https://arxiv.org/abs/1808.05689): 2 | `[WSDM 2019] SimGNN: A Neural Network Approach to Fast Graph Similarity Computation` 3 | 4 | **本实现完全按照[SimGNN论文](https://arxiv.org/abs/1808.05689)实验部分设置** 5 | 6 | ![GitHub License](https://img.shields.io/github/license/Sangs3112/SimGNN) 7 | ![PyPI - Version](https://img.shields.io/pypi/v/pypi) 8 | 9 | 中文版 | [English](./README_en.md) 10 | 11 | ## **目录结构** 12 | ``` 13 | SimGNN/ 14 | ├── datasets/ # 存放数据集文件 15 | │ ├── AIDS700nef/ 16 | │ ├── ALKANE/ 17 | │ ├── IMDBMulti/ 18 | │ ├── ...(may be other datasets) 19 | | └── LINUX/ 20 | ├── imgs/ # 存放README中的图片等资源 21 | ├── Logs/ # 存放日志文件 22 | │ ├── AIDS700nef/ 23 | │ ├── IMDBMulti/ 24 | │ ├── ...(may be other datasets) 25 | | └── LINUX/ 26 | ├── model/ # 模型代码 27 | │ ├── layers.py # 包含Att模块,NTN模块 28 | │ ├── SimGNN.py # SimGNN模型部分 29 | | └── Trainer.py # 训练、验证和测试模块 30 | ├── utils/ 31 | │ ├── config.py # 系统级参数,例如数据集名称等 32 | │ ├── config.yml # 模型级、数据集级参数,例如训练epochs,patience,num_features等 33 | | └── utils.py # 工具,包含加载数据集,加载配置等 34 | └── main.py 35 | ``` 36 | > 需要[下载datasets](https://drive.google.com/drive/folders/1MOOUxxC_76Jseuc-JWaJ6B6LfU6-wNfR?usp=drive_link),包含`AIDS700nef`,`LINUX`,`IMDBMulti`,`ALKANE`数据集 37 | > 38 | > 1. 将下载的`datasets.tar.gz`压缩文件移动到`SimGNN`目录下 39 | > 40 | > 2. 解压缩: `tar -xvzf datasets.tar.gz` 41 | > 42 | > 3. 
解压缩完成后进入`datasets/`目录,使用相同的命令再次解压缩四个数据集即可 43 | > 44 | > P.s: 实际上,如果你不下载我提供的数据集,也可以直接在`SimGNN/`项目根目录下创建`datasets/`目录,此时`GEDDataset`函数会自动下载数据集。 45 | 46 | ## **环境依赖** 47 | ``` 48 | pyyaml == 6.0.1 49 | wandb == 0.16.2 50 | python == 3.9 51 | numpy == 1.26 52 | scipy == 1.11 53 | tqdm == 4.66.1 54 | texttable == 1.7 55 | torch == 2.1.0 56 | torch-geometric == 2.4.0 57 | ``` 58 | 59 | ## **run** 60 | ``` 61 | # AIDS700nef 62 | python main.py 63 | # LINUX 64 | python main.py --dataset LINUX 65 | # IMDBMulti 66 | python main.py --dataset IMDBMulti 67 | # ALKANE 68 | python main.py --dataset ALKANE 69 | 70 | # AIDS700nef (test only) 71 | python main.py --test 72 | # LINUX (test only) 73 | python main.py --dataset LINUX --test 74 | # IMDBMulti (test only) 75 | python main.py --dataset IMDBMulti --test 76 | ``` 77 | 78 | ## **原文结果** 79 | | datasets | MSE($10^{-3}$) | $\rho$ | $\tau$ | $p@10$ | $p@20$ | 80 | |:----:|:----:|:----:|:----:|:----:|:----:| 81 | | AIDS700nef | 1.189 | 0.843 | 0.690 | 0.421 | 0.514 | 82 | | LINUX | 1.509 | 0.939 | 0.830 | 0.942 | 0.933 | 83 | | IMDBMulti | 1.264 | 0.878 | 0.770 | 0.759 | 0.777 | 84 | 85 | ## **运行结果** 86 | ### **AIDS700nef** 87 | 遗憾的是,`AIDS700nef`数据集上的结果丢失了`wandb`的日志记录,所以少了每代的损失以及`patience`记录。 88 | 89 | 1. 90 | ![AIDS700nef_result_1](./imgs/AIDS700nef_1.png) 91 | 92 | 2. 
# [dump context sharing this physical line, kept verbatim: end of README.md]
# 93 | ![AIDS700nef_result_2](./imgs/AIDS700nef_2.png) 94 | 95 | ### **LINUX** 96 | `LINUX`的结果看上去比原文更好。由于`LINUX`数据集没有`label`,所以本实现直接使用度的`onehot`编码作为节点的输入特征。 97 | ![LINUX_result](./imgs/LINUX.png) 98 | 99 | - 训练损失记录 100 | 101 | 102 | 103 | - 验证损失记录 104 | 105 | 106 | 107 | - 每代`patience`变化,耐心值设置为30。 108 | 109 | 110 | 111 | ### **IMDBMulti** 112 | 与`LINUX`相同,`IMDBMulti`的结果比原文更好,`IMDBMulti`同样直接使用度的`onehot`编码作为节点的输入特征。 113 | ![IMDBMulti_result](./imgs/IMDBMulti.png) 114 | 115 | - 训练损失记录 116 | 117 | 118 | 119 | - 验证损失记录 120 | 121 | 122 | 123 | ### **ALKANE** 124 | - 可惜的是,这个数据集结果不太正常。 125 | - 不过[SimGNN原文](https://arxiv.org/abs/1808.05689)中也并没有涉及到这个数据集,因此我也没有仔细研究原因。 126 | - 有兴趣的话可以继续研究问题所在。 127 | 128 | > 如果你喜欢这个的项目的话,请给我们Stars ~ 129 | -------------------------------------------------------------------------------- /model/SimGNN.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from model.layers import AttentionModule, TensorNetworkModule


class SimGNN(nn.Module):
    """SimGNN graph-pair similarity model (WSDM 2019).

    Pipeline: three GCN layers give node embeddings for each graph; attention
    pooling turns them into graph embeddings; a Neural Tensor Network scores
    the pair; an optional pairwise node-similarity histogram is appended; four
    linear layers reduce the result to a score in (0, 1).
    """

    def __init__(self, config):
        super(SimGNN, self).__init__()
        self._config = config
        self._num_features = config['num_features']
        # A negative gpu_index selects the CPU.
        self._device = 'cuda:' + str(config['gpu_index']) if config['gpu_index'] >= 0 else 'cpu'
        self._setup_layers()

    def _calculate_bottleneck_features(self):
        """Set ``self.feature_count``, the input width of the first fully-connected layer.

        It is ``tensor_neurons + bins`` when the histogram feature is enabled
        (32 with the default config), otherwise ``tensor_neurons`` (16).
        """
        if self._config['histogram']:
            self.feature_count = self._config['tensor_neurons'] + self._config['bins']
        else:
            self.feature_count = self._config['tensor_neurons']

    def _setup_layers(self):
        """Create the GCN, attention, NTN and fully-connected layers (order matters for seeding)."""
        self._calculate_bottleneck_features()
        # Three GCN layers; with the default config: num_features -> 64 -> 32 -> 16.
        self.conv1 = GCNConv(self._num_features, self._config['filters_1'])
        self.conv2 = GCNConv(self._config['filters_1'], self._config['filters_2'])
        self.conv3 = GCNConv(self._config['filters_2'], self._config['filters_3'])

        self.attention = AttentionModule(self._config).to(self._device)  # -> 1 x D

        self.tensor_network = TensorNetworkModule(self._config).to(self._device)  # -> 1 x K

        # Fully-connected head; default widths: feature_count -> 16 -> 8 -> 4 -> 1.
        self.fully_connected_first = nn.Linear(self.feature_count, self._config['bottle_neck_neurons_1'])
        self.fully_connected_second = nn.Linear(self._config['bottle_neck_neurons_1'], self._config['bottle_neck_neurons_2'])
        self.fully_connected_third = nn.Linear(self._config['bottle_neck_neurons_2'], self._config['bottle_neck_neurons_3'])
        self.scoring_layer = nn.Linear(self._config['bottle_neck_neurons_3'], 1)

    def _calculate_histogram(self, Ui, Uj):
        """Normalised histogram of pairwise node similarities.

        Args:
            Ui: node embeddings of graph i, Ni x D.
            Uj: node embeddings of graph j, Nj x D.

        Returns:
            1 x B tensor (B = ``config['bins']``) whose entries sum to 1.
            The sigmoid similarity matrix is zero-padded to N x N with
            N = max(Ni, Nj) before binning.
        """
        Ni = Ui.size(0)
        Nj = Uj.size(0)
        N = max(Ni, Nj)

        # histc is not differentiable, so build the matrix outside autograd and on
        # the embeddings' own device (avoids a device round-trip on GPU runs).
        with torch.no_grad():
            S = torch.zeros(N, N, dtype=Ui.dtype, device=Ui.device)
            S[:Ni, :Nj] = torch.sigmoid(torch.mm(Ui, Uj.T))
            hist = torch.histc(S.view(-1), bins=self._config['bins'])
            hist = hist / torch.sum(hist)
        return hist.view(1, -1).to(self._device)  # 1 x B

    def _convolutional_pass(self, A, X):
        """Run the three GCN layers and return node embeddings.

        Args:
            A: graph connectivity (``edge_index``) passed to each GCNConv.
            X: initial node features, N x num_features.

        Returns:
            Node embeddings, N x D (D = ``config['filters_3']``).
        """
        features = self.conv1(X, A)
        features = F.relu(features)
        features = F.dropout(features, p=self._config['dropout'], training=self.training)
        features = self.conv2(features, A)
        features = F.relu(features)
        features = F.dropout(features, p=self._config['dropout'], training=self.training)
        features = self.conv3(features, A)
        return features

    def forward(self, data):
        """Predict the similarity of one graph pair.

        Args:
            data: dict with keys ``'g1'`` and ``'g2'``, each exposing
                ``edge_index`` and node features ``x``.

        Returns:
            1-D tensor of similarity scores in (0, 1) (one element per pair).
        """
        edge_index_1 = data["g1"].edge_index
        edge_index_2 = data["g2"].edge_index
        features_1 = data["g1"].x
        features_2 = data["g2"].x

        # Node-level representations of graph 1 and graph 2.
        U1 = self._convolutional_pass(edge_index_1, features_1)  # N1 x D
        U2 = self._convolutional_pass(edge_index_2, features_2)  # N2 x D

        if self._config['histogram']:
            hist = self._calculate_histogram(U1, U2)  # 1 x B

        # Graph-level representations.
        h1 = self.attention(U1)  # 1 x D
        h2 = self.attention(U2)  # 1 x D

        scores = self.tensor_network(h1, h2)  # 1 x K

        if self._config['histogram']:
            scores = torch.cat((scores, hist), dim=1)  # 1 x (K + B)

        scores = F.relu(self.fully_connected_first(scores))
        scores = F.relu(self.fully_connected_second(scores))
        scores = F.relu(self.fully_connected_third(scores))

        score = torch.sigmoid(self.scoring_layer(scores)).view(-1)
        return score
# [dump context sharing this physical line, kept verbatim: start of README_en.md]
# -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # [SimGNN](https://arxiv.org/abs/1808.05689): 2 | `[WSDM 2019] SimGNN: A Neural Network Approach to Fast Graph Similarity Computation` 3 | 4 | **This implementation is modeled exactly according to the code setup in the [SimGNN paper](https://arxiv.org/abs/1808.05689)** 5 | 6 | ![GitHub License](https://img.shields.io/github/license/Sangs3112/SimGNN) 7 | ![PyPI - Version](https://img.shields.io/pypi/v/pypi) 8 | 9 | [中文版](./README.md) | English 10 | 11 | ## **directory structure** 12 | ``` 13 | SimGNN/ 14 | ├── datasets/ 15 | │ ├── AIDS700nef/ 16 | │ ├── ALKANE/ 17 | │ ├── IMDBMulti/ 18 | │ ├── ...(may be other datasets) 19 | | └── LINUX/ 20 | ├── imgs/ # contain the imgs files in README 21 | ├── Logs/ # store log files 22 | │ ├── AIDS700nef/ 23 | │ ├── IMDBMulti/ 24 | │ ├── ...(may be other datasets) 25 | | └── LINUX/ 26 | ├── model/ # contain the model code 27 | │ ├── layers.py # including 'Att' and 'NTN' modules 28 | │
├── SimGNN.py # the code of SimGNN 29 | | └── Trainer.py # training, validation and testing 30 | ├── utils/ 31 | │ ├── config.py # System-level parameters, such as the dataset name 32 | │ ├── config.yml # Model-level and dataset-level parameters, such as epochs, patience, num_features 33 | | └── utils.py # Tools, including dataset and configuration loading 34 | └── main.py 35 | ``` 36 | > You need to [download datasets](https://drive.google.com/drive/folders/1MOOUxxC_76Jseuc-JWaJ6B6LfU6-wNfR?usp=drive_link), which include `AIDS700nef`, `LINUX`, `IMDBMulti`, `ALKANE` datasets 37 | > 38 | > 1. Move the downloaded `datasets.tar.gz` compressed file to `SimGNN/` 39 | > 40 | > 2. Decompress: `tar -xvzf datasets.tar.gz` 41 | > 42 | > 3. After the decompression is complete, `cd datasets/` and use the same command to decompress each of the four dataset archives 43 | > 44 | > P.s: If you don't want to download the datasets I provide, you can simply create a `datasets/` directory in the `SimGNN/` project root, and `GEDDataset` will download the datasets automatically.
45 | 46 | ## **Requirements** 47 | ``` 48 | pyyaml == 6.0.1 49 | python == 3.9 50 | numpy == 1.26 51 | scipy == 1.11 52 | tqdm == 4.66.1 53 | texttable == 1.7 54 | torch == 2.1.0 55 | torch-geometric == 2.4.0 56 | ``` 57 | Note: `wandb == 0.16.2` is also required, because `model/Trainer.py` imports `wandb` unconditionally (even when `wandb: False` is set in `utils/config.yml`). 58 | ## **run** 59 | ``` 60 | # AIDS700nef 61 | python main.py 62 | # LINUX 63 | python main.py --dataset LINUX 64 | # IMDBMulti 65 | python main.py --dataset IMDBMulti 66 | # ALKANE 67 | python main.py --dataset ALKANE 68 | 69 | # AIDS700nef (test only) 70 | python main.py --test 71 | # LINUX (test only) 72 | python main.py --dataset LINUX --test 73 | # IMDBMulti (test only) 74 | python main.py --dataset IMDBMulti --test 75 | ``` 76 | 77 | ## **Official Results (reported in the paper)** 78 | | datasets | MSE($10^{-3}$) | $\rho$ | $\tau$ | $p@10$ | $p@20$ | 79 | |:----:|:----:|:----:|:----:|:----:|:----:| 80 | | AIDS700nef | 1.189 | 0.843 | 0.690 | 0.421 | 0.514 | 81 | | LINUX | 1.509 | 0.939 | 0.830 | 0.942 | 0.933 | 82 | | IMDBMulti | 1.264 | 0.878 | 0.770 | 0.759 | 0.777 | 83 | 84 | ## **Results of This Implementation** 85 | ### **AIDS700nef** 86 | Unfortunately, the `wandb` logs of the `AIDS700nef` run were lost, so the per-epoch loss and `patience` records are not available. 87 | 88 | 1. 89 | ![AIDS700nef_result_1](./imgs/AIDS700nef_1.png) 90 | 91 | 2. 92 | ![AIDS700nef_result_2](./imgs/AIDS700nef_2.png) 93 | 94 | ### **LINUX** 95 | The `LINUX` results look better than those reported in the paper. Since the `LINUX` graphs have no node features, this implementation uses the one-hot encoding of each node's degree as its input feature. 96 | 97 | ![LINUX_result](./imgs/LINUX.png) 98 | 99 | - Training loss records 100 | 101 | 102 | 103 | - Validation loss records 104 | 105 | 106 | 107 | - `patience` over the epochs (the patience limit was 30 for this run; the current default in `utils/config.yml` is 50). 108 | 109 | 110 | 111 | ### **IMDBMulti** 112 | As with `LINUX`, the `IMDBMulti` results look better than those reported in the paper; `IMDBMulti` likewise uses one-hot encoded node degrees as input features.
# [dump context sharing this physical line, kept verbatim: end of README_en.md]
# 113 | 114 | ![IMDBMulti_result](./imgs/IMDBMulti.png) 115 | 116 | - Training loss records 117 | 118 | 119 | 120 | - Validating loss records 121 | 122 | 123 | 124 | ### **ALKANE** 125 | - Unfortunately, this data set turned out to be abnormal. 126 | - But in [SimGNN paper](https://arxiv.org/abs/1808.05689), this `ALKANE` dataset is not covered in the official datasets, so I didn't look into why. 127 | - If you're interested, you can follow up on the problem. 128 | 129 | > If you like this project, please send us Stars ~ 130 | -------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
import os
import yaml
import torch
import random
import itertools
import numpy as np

from texttable import Texttable
from torch_geometric.utils import degree
from torch_geometric.datasets import GEDDataset
from torch_geometric.transforms import OneHotDegree

# Fraction of the GED training split used for training; the remaining graphs
# of that split form the validation set.
_TRAIN_FRACTION = 0.75

# config.yml lives next to this module; resolving it here makes get_config()
# independent of the current working directory.
_CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yml')


def print_evals(mse_error, rho, tau, p10, p20):
    r"""Print all test metrics after evaluation.

    Args:
        mse_error: float, MSE of the predictions (printed scaled by 10^3).
        rho: Spearman rank correlation.
        tau: Kendall rank correlation.
        p10: precision at 10.
        p20: precision at 20.
    """
    print("mse(10^-3): " + str(round(mse_error * 1000, 5)) + '.')
    print("rho: " + str(round(rho, 5)) + '.')
    print("tau: " + str(round(tau, 5)) + '.')
    print("p@10: " + str(round(p10, 5)) + '.')
    print("p@20: " + str(round(p20, 5)) + '.')


def calculate_ranking_correlation(rank_corr_function, prediction, target):
    r"""Rank correlation between predicted and ground-truth scores.

    Both arrays are first converted to ranks (rank 1 = highest score; tied
    scores share the best rank of their group), then compared.

    Args:
        rank_corr_function: a ``scipy.stats`` function such as ``spearmanr`` or
            ``kendalltau`` whose result exposes ``.correlation``.
        prediction: 1-D array of predicted scores.
        target: 1-D array of ground-truth scores.

    Returns:
        The correlation coefficient.
    """
    def ranking_func(data):
        # rank[i] = 1 + number of entries strictly greater than data[i].  With the
        # descending sort, that is the first position of data[i]'s tie group.
        # searchsorted computes it for every entry in O(n log n).
        neg = -np.asarray(data)
        return np.searchsorted(np.sort(neg), neg, side='left') + 1.0

    r_prediction = ranking_func(prediction)
    r_target = ranking_func(target)

    return rank_corr_function(r_prediction, r_target).correlation


def prec_at_ks(true_r, pred_r, ks, rm=0):
    r"""Precision at k between the ground-truth and predicted rankings.

    Both top-k lists are tie-inclusive (they may contain more than ``ks`` ids);
    the overlap is capped at ``ks`` before dividing.

    Args:
        true_r: 1-D array of ground-truth scores.
        pred_r: 1-D array of predicted scores.
        ks: int, the k in p@k.
        rm: tolerance below which two scores count as tied.

    Returns:
        float: p@k in [0, 1].
    """
    def top_k_ids(data, k, inclusive, rm):
        """Ids of the k highest-scoring entries of ``data``.

        :param data: 1-D array of scores.
        :param k: number of top entries requested (1 <= k < len(data)).
        :param inclusive: whether to be tie inclusive or not.
            For example, the ranking may look like this:
            7 (sim_score=0.99), 5 (sim_score=0.99), 10 (sim_score=0.98), ...
            If tie inclusive, the top 1 results are [7, 5].
            Therefore, the number of returned results may be larger than k.
            In summary,
                len(rtn) == k if not tie inclusive;
                len(rtn) >= k if tie inclusive.
        :param rm: tie tolerance (0 means exact ties only).
        :return: for a query, the ids of the top k database graphs
            ranked by this model.
        :raises RuntimeError: if k is outside [1, len(data) - 1].
        """
        sort_id_mat = np.argsort(-data)
        n = sort_id_mat.shape[0]
        # k == 0 would otherwise end in a division by zero in the caller.
        if k <= 0 or k >= n:
            raise RuntimeError('Invalid k {}'.format(k))
        if not inclusive:
            return sort_id_mat[:k]
        # Tie inclusive: extend k while the next entry ties with the k-th one.
        while k < n:
            cid = sort_id_mat[k - 1]
            nid = sort_id_mat[k]
            if abs(data[cid] - data[nid]) <= rm:
                k += 1
            else:
                break
        return sort_id_mat[:k]

    true_ids = top_k_ids(true_r, ks, inclusive=True, rm=rm)
    pred_ids = top_k_ids(pred_r, ks, inclusive=True, rm=rm)
    ps = min(len(set(true_ids) & set(pred_ids)), ks) / ks
    return ps


def _split_index(training_len):
    """Number of graphs of the training split used for training (the rest validate)."""
    return int(training_len * _TRAIN_FRACTION)


def create_training_batches_all(training_len, batch_size):
    r"""Shuffle every unordered graph pair (i < j) of [0, training_len) into batches.

    Each call reshuffles, so call it again whenever a fresh order is wanted.

    Args:
        training_len: int, number of graphs; ids range over [0, training_len - 1].
        batch_size: int, pairs per batch.

    Returns:
        List of batches (lists of (i, j) tuples); all have ``batch_size`` pairs
        except possibly the last.
    """
    pairs = list(itertools.combinations(range(training_len), 2))
    random.shuffle(pairs)
    return [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]


def create_train_pairs_id(training_len, batch_size):
    """Shuffled batches of pairs drawn from the training portion of the split."""
    return create_training_batches_all(_split_index(training_len), batch_size)


def create_validate_pairs_id(training_len):
    """(train_id, val_id) pairs: every training graph against every validation graph."""
    split = _split_index(training_len)
    return [(i, j) for i in range(split) for j in range(split, training_len)]


def create_test_pairs_id(training_len, testing_len):
    """(train_id, test_id) pairs; test ids index the test set starting at 0."""
    split = _split_index(training_len)
    return [(i, j) for i in range(split) for j in range(testing_len)]


def nice_printer(config):
    r"""Print ``config`` as a two-column (Key, Value) table.

    Args:
        config: dict mapping parameter names to values.
    """
    table_data = [['Key', 'Value']] + [[k, v] for k, v in config.items()]
    t = Texttable().set_precision(4)
    t.add_rows(table_data)
    print(t.draw())


def set_seed(seed):
    r"""Seed the Python, NumPy and PyTorch RNGs (and all GPUs when CUDA is available).

    Args:
        seed: int, the seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels: slower, but runs (incl. dropout) are repeatable.
        torch.backends.cudnn.deterministic = True
        # Disable cuDNN autotuning, which can pick different kernels between runs.
        torch.backends.cudnn.benchmark = False


def get_config(args):
    r"""Build the full configuration dict for a run.

    Merges the shared ``SimGNN`` section of ``config.yml`` with the section for
    ``args.dataset`` and adds ``log_path`` = ``<args.log_path>/<dataset>/``.

    Args:
        args: parsed command-line namespace (``dataset``, ``log_path``).

    Returns:
        dict mapping parameter names to values.
    """
    all_config = _get_part_config(_CONFIG_PATH)
    config = dict(all_config['SimGNN'])
    config.update(all_config[args.dataset])
    # Joined with '/' (not os.path.join) because the log path is later split on
    # '/' to recover the dataset name; a missing trailing separator is added.
    log_root = args.log_path if args.log_path.endswith('/') else args.log_path + '/'
    config['log_path'] = log_root + args.dataset + '/'

    return config


def _get_part_config(config_path):
    r"""Read a YAML config file.

    The file is opened as UTF-8 explicitly: config.yml contains non-ASCII
    comments, which would not decode under some platform default encodings.

    Args:
        config_path: path to the YAML file.

    Returns:
        The parsed YAML content (a dict for config.yml).
    """
    with open(config_path, "r", encoding="utf-8") as setting:
        return yaml.safe_load(setting)


def load_data(path, dataset_name):
    r"""Load the train/test GEDDataset splits and the normalised GED matrix.

    Args:
        path: directory that holds one sub-directory per dataset.
        dataset_name: dataset name, e.g. 'AIDS700nef' or 'LINUX'.

    Returns:
        (train_data, test_data, norm_ged)
    """
    root = os.path.join(path, dataset_name)
    train_data = GEDDataset(root, dataset_name)
    test_data = GEDDataset(root, dataset_name, False)  # False -> test split
    norm_ged = train_data.norm_ged

    # Datasets without node features (the original note names LINUX, IMDBMulti
    # and ALKANE) get one-hot node-degree features instead.
    if train_data[0].x is None:
        train_data, test_data = _process_feature(train_data, test_data)

    return train_data, test_data, norm_ged


def _process_feature(train_data, test_data):
    r"""Attach one-hot node-degree features to both splits.

    The one-hot width comes from the maximum degree over both splits, so train
    and test graphs share the same feature size.
    """
    max_degree = 0
    for g in itertools.chain(train_data, test_data):
        # Edgeless graphs are skipped (they have no degree to contribute).
        if g.edge_index.size(1) > 0:
            max_degree = max(max_degree, int(degree(g.edge_index[0]).max().item()))
    one_hot_degree = OneHotDegree(max_degree, cat=False)
    train_data.transform = one_hot_degree
    test_data.transform = one_hot_degree
    return train_data, test_data
# [dump context sharing this physical line, kept verbatim: header of the next file]
# 194 | -------------------------------------------------------------------------------- /model/Trainer.py:
# --------------------------------------------------------------------------------
import os
import wandb
import torch
import numpy as np
import torch.nn.functional as F

from tqdm import tqdm, trange
from model.SimGNN import SimGNN
from scipy.stats import spearmanr, kendalltau
from utils.utils import print_evals, prec_at_ks, calculate_ranking_correlation
from utils.utils import create_test_pairs_id, create_train_pairs_id, create_validate_pairs_id


class Trainer(object):
    r"""
    Train, validate and test a SimGNN model on a GED dataset.

    The regression target of a graph pair (i, j) is ``exp(-norm_ged[i][j])``,
    fitted with an MSE loss. The weights with the lowest validation loss are
    written to ``<log_path>best_model.pt`` and reloaded by ``score``.
    """

    def __init__(self, config, norm_ged):
        r"""
        Parameters
        ----------
        config : dict
            Must provide ``lr``, ``wandb``, ``epochs``, ``log_path``,
            ``patience``, ``batch_size``, ``start_val_iter``,
            ``every_val_iter`` and ``gpu_index``; the whole dict is also
            passed to ``SimGNN``.
        norm_ged : indexable as ``norm_ged[i][j]``
            Normalized GED between graphs i and j (not yet the target). Test
            graphs are addressed with an offset of ``len(train_data)`` on the
            second index.
        """
        self._lr = config['lr']                          # learning rate
        self._norm_ged = norm_ged                        # normalized GED, not yet the target
        self._wandb = config['wandb']                    # log losses to wandb instead of text files
        self._epochs = config['epochs']                  # number of training epochs
        self._log_path = config['log_path']              # directory for logs and the checkpoint
        self._patience = config['patience']              # early-stopping patience, in validation rounds
        self._batch_size = config['batch_size']          # training batch size
        self._start_val_iter = config['start_val_iter']  # epoch (1-based) from which validation may run
        self._every_val_iter = config['every_val_iter']  # validation interval once validation has started
        self._device = 'cuda:' + str(config['gpu_index']) if config['gpu_index'] >= 0 else 'cpu'

        self._model = SimGNN(config).to(self._device)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(config['log_path'], exist_ok=True)

    def _validate(self, validaite_pairs_id, train_data):
        r"""
        Compute the validation loss over pairs of training graphs.

        Parameters
        ----------
        validaite_pairs_id : sized iterable of (G1_id, G2_id)
            Index pairs into ``train_data``.
        train_data : indexable dataset of graphs

        Returns
        -------
        torch.Tensor
            MSE (mean reduction) between the predictions and the
            ``exp(-norm_ged)`` targets.
        """
        self._model.eval()
        predictions, targets = [], []
        with torch.no_grad():
            for (G1_id, G2_id) in tqdm(validaite_pairs_id, total=len(validaite_pairs_id), desc="Validate"):
                data = dict()
                data['g1'] = train_data[G1_id].to(self._device)
                data['g2'] = train_data[G2_id].to(self._device)
                targets.append(torch.exp(-self._norm_ged[G1_id][G2_id]).view(-1).to(self._device))
                predictions.append(self._model(data))
        vloss = F.mse_loss(torch.cat(predictions), torch.cat(targets))
        return vloss

    def fit(self, train_data):
        r"""
        Train with AdamW, validating periodically and stopping early.

        Validation runs once ``epoch + 1 >= start_val_iter`` holds, and then
        only on epochs with ``epoch % every_val_iter == 0`` (``epoch`` is
        0-based). Every improvement of the (rounded) validation loss saves a
        checkpoint. After ``patience`` consecutive non-improving rounds,
        training stops.
        """
        run_name = self._log_path.split('/')[-2]
        print("\n======= SimGNN training in {}. =======\n".format(run_name))
        if self._wandb:
            wandb.init(
                project='SimGNN',
                name=run_name,
                config={
                    'learning_rate': self._lr,
                    'dataset': run_name,
                })

        cur_patience = 0
        saved = False  # whether this run has written best_model.pt yet
        optimizer = torch.optim.AdamW(self._model.parameters(), lr=self._lr)
        train_pairs_id = create_train_pairs_id(len(train_data), self._batch_size)
        validaite_pairs_id = create_validate_pairs_id(len(train_data))

        epochs = trange(self._epochs, leave=True, desc="Epoch")
        # inf (not a magic large number) so the first validation always counts as an improvement.
        min_vloss = float('inf')
        for epoch in epochs:
            self._model.train()
            cur_tloss = 0.0
            for batch_id in tqdm(train_pairs_id, total=len(train_pairs_id), desc="Train"):
                predictions, targets = [], []
                optimizer.zero_grad()
                for G1_id, G2_id in batch_id:
                    data = dict()
                    data['g1'] = train_data[G1_id].to(self._device)
                    data['g2'] = train_data[G2_id].to(self._device)
                    targets.append(torch.exp(-self._norm_ged[G1_id][G2_id]).view(-1).to(self._device))
                    predictions.append(self._model(data))
                loss = F.mse_loss(torch.cat(predictions), torch.cat(targets))  # loss of one whole batch
                cur_tloss += loss.item()
                loss.backward()
                optimizer.step()

            # Mean batch loss of this epoch.
            cur_tloss = round(cur_tloss / len(train_pairs_id), 5)
            if self._wandb:
                wandb.log({'train_loss': cur_tloss})
            else:
                with open(self._log_path + 'train_loss.txt', 'a') as f:
                    f.write(str(epoch) + '\t' + str(cur_tloss) + '\n')
            epochs.set_description("Epoch (Loss=%g)" % cur_tloss)

            if epoch + 1 >= self._start_val_iter:
                if epoch % self._every_val_iter != 0:
                    continue
                # Release cached, unused GPU memory around the validation pass.
                torch.cuda.empty_cache()
                cur_vloss = self._validate(validaite_pairs_id, train_data).item()
                torch.cuda.empty_cache()

                cur_vloss = round(cur_vloss, 5)
                if min_vloss > cur_vloss:
                    min_vloss = cur_vloss
                    cur_patience = 0
                    self._save()
                    saved = True
                else:
                    cur_patience += 1

                if self._wandb:
                    wandb.log({'valid_loss': cur_vloss, 'cur_patience': cur_patience})
                else:
                    with open(self._log_path + 'valid_loss.txt', 'a') as f:
                        f.write(str(epoch) + '\t' + str(cur_vloss) + '\t' + str(cur_patience) + '\n')
                if cur_patience >= self._patience:
                    print("Early Stop!")
                    break

        if not saved:
            # No validation round ran (e.g. start_val_iter > epochs) or none
            # improved (e.g. NaN loss). Save the final weights so that score()
            # does not load a missing file or a checkpoint from an earlier run.
            self._save()
        if self._wandb:
            wandb.finish()

    def score(self, train_data, test_data):
        r"""
        Evaluate the saved checkpoint on every (train graph, test graph) pair.

        Passes to ``print_evals``:
        - the mean MSE;
        - the mean Spearman rho and Kendall tau, computed per test graph;
        - the mean precision@10 and precision@20.
        """
        print("\n======= SimGNN testing in {}. =======\n".format(self._log_path.split('/')[-2]))
        self._load()
        self._model.eval()

        n_train, n_test = len(train_data), len(test_data)
        scores = np.zeros((n_test, n_train))
        ground_truth = np.zeros((n_test, n_train))
        prediction_mat = np.zeros((n_test, n_train))
        rho_list = []
        tau_list = []
        prec_at_10_list = []
        prec_at_20_list = []

        test_pairs_id = create_test_pairs_id(n_train, n_test)

        with torch.no_grad():
            for (G1_id, G2_id) in tqdm(test_pairs_id, total=len(test_pairs_id), desc="Test"):
                data = dict()
                data['g1'] = train_data[G1_id].to(self._device)
                data['g2'] = test_data[G2_id].to(self._device)

                pred = self._model(data).cpu()
                # Test graphs follow the training graphs in norm_ged's index space.
                targ = torch.exp(-self._norm_ged[G1_id][G2_id + n_train]).cpu()
                # .item() stores plain floats. Assigning a 1-element array with
                # ndim > 0 into a scalar slot is deprecated since NumPy 1.25.
                ground_truth[G2_id][G1_id] = targ.item()
                prediction_mat[G2_id][G1_id] = pred.item()
                scores[G2_id][G1_id] = F.mse_loss(pred, targ.view(-1)).item()

        for i in range(n_test):
            rho_list.append(calculate_ranking_correlation(spearmanr, prediction_mat[i], ground_truth[i]))
            tau_list.append(calculate_ranking_correlation(kendalltau, prediction_mat[i], ground_truth[i]))
            prec_at_10_list.append(prec_at_ks(ground_truth[i], prediction_mat[i], 10))
            prec_at_20_list.append(prec_at_ks(ground_truth[i], prediction_mat[i], 20))

        print_evals(np.mean(scores), np.mean(rho_list), np.mean(tau_list), np.mean(prec_at_10_list), np.mean(prec_at_20_list))

    def _save(self):
        r"""Write the model's state_dict to ``<log_path>best_model.pt``."""
        torch.save(self._model.state_dict(), self._log_path + 'best_model.pt')

    def _load(self):
        r"""
        Load ``<log_path>best_model.pt`` into the model.

        ``map_location`` remaps the stored tensors onto this trainer's device,
        so a checkpoint written on a GPU can be evaluated on the CPU or on
        another GPU.
        """
        state_dict = torch.load(self._log_path + 'best_model.pt', map_location=self._device)
        self._model.load_state_dict(state_dict)
# --------------------------------------------------------------------------------