├── misc ├── fig.png └── fig_stable.png ├── LICENSE ├── merge.py ├── pairwise_dist.py ├── svc.py ├── README.md ├── stability.py └── tmd.py /misc/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chingyaoc/TMD/HEAD/misc/fig.png -------------------------------------------------------------------------------- /misc/fig_stable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chingyaoc/TMD/HEAD/misc/fig_stable.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ching-Yao Chuang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
"""Merge the per-chunk distance files written by ``pairwise_dist.py``.

Each chunk file holds ``n_per_idx`` consecutive rows of the pairwise Tree
Mover's Distance matrix, with ``-1`` marking entries that were not computed
in that chunk.  This script stacks the chunks, trims the padding rows of the
last chunk and fills every missing entry from its transposed position.
"""
import argparse
import os.path as osp

import numpy as np


def num_chunks(n, n_per_idx=50):
    """Return how many chunk files are needed to cover ``n`` rows.

    Ceiling division is used so that a dataset whose size is an exact
    multiple of ``n_per_idx`` does not require an extra, empty chunk file.

    Raises
    ------
    ValueError
        If ``n_per_idx`` is not positive.
    """
    if n_per_idx <= 0:
        raise ValueError('n_per_idx must be a positive integer')
    return -(-n // n_per_idx)


def chunk_path(dataset_name, L, w, idx):
    """Return the file name ``pairwise_dist.py`` uses for chunk ``idx``."""
    return f'./PairDist/M_{dataset_name}_L{L}_w{w}_idx{idx}.npy'


def fill_missing(M):
    """Replace each ``-1`` entry of the square matrix ``M`` by ``M[j, i]``.

    The matrix is modified in place and also returned.  Entries that are
    ``-1`` at both ``(i, j)`` and ``(j, i)`` stay ``-1``.
    """
    missing = M == -1
    # The right-hand side is materialised before the assignment, so values
    # written here never feed back into the lookup.
    M[missing] = M.T[missing]
    return M


def merge_chunks(chunks, n):
    """Stack row chunks, keep the first ``n`` rows and symmetrise the result."""
    M = np.concatenate(chunks, axis=0)[:n]
    return fill_missing(M)


def main():
    """Parse the command line, merge the chunk files and save the full matrix."""
    # Deferred so the helpers above can be imported without torch_geometric.
    from torch_geometric.datasets import TUDataset

    parser = argparse.ArgumentParser(description='Tree Mover Distance')
    parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
    parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
    parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
    # Must match the value given to pairwise_dist.py; 50 is that script's default.
    parser.add_argument('--n_per_idx', default=50, type=int, help='batch size')

    args = parser.parse_args()
    w, L, dataset_name = args.w, args.L, args.dataset

    dataset = TUDataset(osp.join('data', dataset_name), name=dataset_name)
    n = len(dataset)

    chunks = [
        np.load(chunk_path(dataset_name, L, w, idx))
        for idx in range(num_chunks(n, args.n_per_idx))
    ]
    M = merge_chunks(chunks, n)

    np.save(f'./PairDist/M_{dataset_name}_L{L}_w{w}.npy', M)


if __name__ == '__main__':
    main()
"""Compute one row chunk of the pairwise Tree Mover's Distance matrix.

Chunk ``idx`` covers rows ``[idx * n_per_idx, (idx + 1) * n_per_idx)`` of the
dataset.  Within a chunk only the columns from the chunk's first row onwards
are computed; everything else stays ``-1`` and is filled in by ``merge.py``.
"""
import argparse
import os
import os.path as osp

import numpy as np


def chunk_bounds(idx, n_per_idx, n):
    """Return the ``(start, end)`` row range of chunk ``idx``.

    ``end`` is clipped to ``n``.  For an ``idx`` past the last chunk the
    range is empty (``start >= end``).
    """
    start = n_per_idx * idx
    end = min(n_per_idx * (idx + 1), n)
    return start, end


def chunk_path(dataset_name, L, w, idx):
    """Return the output file name for chunk ``idx``."""
    return f'./PairDist/M_{dataset_name}_L{L}_w{w}_idx{idx}.npy'


def main():
    """Parse the command line, compute the chunk and save it under PairDist/."""
    # Deferred so the helpers above can be imported without these packages.
    from torch_geometric.datasets import TUDataset
    from tqdm import tqdm

    from tmd import TMD

    parser = argparse.ArgumentParser(description='Tree Mover Distance')
    parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
    parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
    parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
    parser.add_argument('--idx', default=0, type=int, help='idx for batch')
    parser.add_argument('--n_per_idx', default=50, type=int, help='batch size')

    args = parser.parse_args()
    w, L, dataset_name = args.w, args.L, args.dataset
    n_per_idx = args.n_per_idx

    # Create the output directory before the expensive loop.  exist_ok makes
    # this safe when several chunks are started in parallel.
    os.makedirs('PairDist', exist_ok=True)

    train_dataset = TUDataset(osp.join('data', dataset_name), name=dataset_name)
    n = len(train_dataset)
    start, end = chunk_bounds(args.idx, n_per_idx, n)

    print('Precompute pairwise distance')
    # -1 marks "not computed here"; the last chunk keeps its unused rows at -1.
    M = np.full((n_per_idx, n), -1.0)
    for i in tqdm(range(start, end)):
        for j in tqdm(range(start, n), leave=False):
            M[i - start, j] = TMD(train_dataset[i], train_dataset[j], w=w, L=L)

    np.save(chunk_path(dataset_name, L, w, args.idx), M)


if __name__ == '__main__':
    main()
"""Train and evaluate a kernel SVM on precomputed Tree Mover's Distances.

The merged distance matrix written by ``merge.py`` is turned into the kernel
``exp(-lam * M)``.  ``lam`` is picked by repeated random hold-out on the
training split, and the final accuracy is reported on a held-out tenth of
the dataset.
"""
import argparse
import os.path as osp
import random

import numpy as np
import torch


def distance_matrix_path(dataset_name, L, w):
    """Return the file name ``merge.py`` writes for the given settings."""
    return f'./PairDist/M_{dataset_name}_L{L}_w{w}.npy'


def distance_kernel(M, lam):
    """Return the element-wise kernel ``exp(-lam * M)`` of a distance matrix."""
    return np.exp(-lam * M)


def accuracy(y_pred, y_true):
    """Return the fraction of positions where the two 1-D label arrays agree."""
    return float(np.mean(np.asarray(y_pred) == np.asarray(y_true)))


def main():
    """Parse the command line, select ``lam`` and print the test accuracy."""
    # Deferred so the helpers above can be imported without these packages.
    from sklearn.svm import SVC
    from torch_geometric.datasets import TUDataset

    parser = argparse.ArgumentParser(description='Tree Mover Distance')
    parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
    parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
    parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
    parser.add_argument('--rs', default=0, type=int, help='random seed')

    args = parser.parse_args()
    w, L, dataset_name = args.w, args.L, args.dataset

    random.seed(args.rs)
    torch.manual_seed(args.rs)

    # Same data directory as pairwise_dist.py and merge.py.
    dataset = TUDataset(osp.join('data', dataset_name), name=dataset_name)
    M = np.load(distance_matrix_path(dataset_name, L, w))

    # Shuffle once; the first tenth is the test split, the rest is training.
    idx = list(range(len(dataset)))
    random.shuffle(idx)
    n = len(dataset) // 10
    idx_train, idx_test = idx[n:], idx[:n]

    # Reorder rows and columns so that [:n] is test and [n:] is training.
    M = M[np.ix_(idx, idx)]

    # Flat integer labels; keeping them 1-D stops the accuracy comparison
    # below from broadcasting into a matrix.
    y = np.array([int(dataset[i].y) for i in range(len(dataset))])
    y_train, y_test = y[idx_train], y[idx_test]

    # Model selection: 10 random 90/10 splits of the training part.
    M_cv = M[n:, n:]
    lams = [0.01, 0.05, 0.1]
    scores = np.zeros(len(lams))
    n_train = len(idx_train)
    n_cv = n_train // 10
    for _ in range(10):
        idx_cv = list(range(n_train))
        random.shuffle(idx_cv)
        fit_cv, val_cv = idx_cv[n_cv:], idx_cv[:n_cv]

        for k, lam in enumerate(lams):
            model = SVC(kernel='precomputed')
            model.fit(distance_kernel(M_cv[np.ix_(fit_cv, fit_cv)], lam), y_train[fit_cv])
            # Prediction needs kernel values between validation and fit samples.
            y_pred = model.predict(distance_kernel(M_cv[np.ix_(val_cv, fit_cv)], lam))
            scores[k] += accuracy(y_pred, y_train[val_cv])

    lam = lams[int(np.argmax(scores))]
    K = distance_kernel(M, lam)
    model = SVC(kernel='precomputed')
    model.fit(K[n:, n:], y_train)

    y_pred = model.predict(K[:n, n:])
    acc = accuracy(y_pred, y_test)
    print('{}, L: {}, w: {}, Acc: {}'.format(dataset_name, L, w, acc))


if __name__ == '__main__':
    main()

4 | 5 |

6 | 7 | 8 | Understanding generalization and robustness of machine learning models fundamentally relies on assuming an appropriate metric on the data space. Identifying such a metric is particularly challenging for non-Euclidean data such as graphs. Here, we propose a pseudometric for attributed graphs, the Tree Mover's Distance (TMD), and study its relation to generalization. Via a hierarchical optimal transport problem, TMD reflects the local distribution of node attributes as well as the distribution of local computation trees, which are known to be decisive for the learning behavior of graph neural networks (GNNs). First, we show that TMD captures properties relevant to graph classification: a simple TMD-SVM performs competitively with standard GNNs. Second, we relate TMD to generalization of GNNs under distribution shifts, and show that it correlates well with performance drop under such shifts. 9 | 10 | **Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks** NeurIPS 2022 [[paper]](https://arxiv.org/abs/2210.01906) 11 |
12 | [Ching-Yao Chuang](https://chingyaoc.github.io/) and 13 | [Stefanie Jegelka](https://people.csail.mit.edu/stefje/) 14 |
15 | 16 | 17 | ## Prerequisites 18 | - Python 3.7 19 | - PyTorch 1.3.1 20 | - PyTorch Geometric 21 | - POT 22 | 23 | 24 | ## Usage Examples 25 | The code for computing the Tree Mover's Distance (TMD) lies in `tmd.py`. For instance, the following code computes the TMD between two graphs of the MUTAG dataset. 26 | ```python 27 | from tmd import TMD 28 | from torch_geometric.datasets import TUDataset 29 | 30 | dataset = TUDataset('data', name='MUTAG') 31 | d = TMD(dataset[0], dataset[1], w=1.0, L=4) 32 | ``` 33 | 34 | One can also specify different weighting constants for each layer as follows: 35 | ```python 36 | d = TMD(dataset[0], dataset[1], w=[0.33, 1, 3], L=4) 37 | ``` 38 | This results in a tighter bound on the stability of GNNs as Theorem 8 shows. Note that `len(w)` has to be the same as `L-1`. 39 | 40 | 41 | ## Graph Classification on TUDataset 42 | 43 | Step 1: Pre-compute the pairwise distances (potentially in parallel). For instance, the following commands compute the pairwise distances of MUTAG with `pairwise_dist.py` by separating it into 4 batches, where each batch can be computed in parallel. One can then merge the batches with `merge.py`. 44 | ``` 45 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 0 46 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 1 47 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 2 48 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 3 49 | 50 | python merge.py --w 0.5 --L 4 --dataset MUTAG 51 | ``` 52 | 53 | Step 2: Train an SVM classifier based on the pre-computed distances: 54 | ``` 55 | python svc.py --w 0.5 --L 4 --dataset MUTAG 56 | ``` 57 | 58 | ## Measuring the Stability of GNNs 59 | The script `stability.py` reproduces the stability experiments in Figure 5. In particular, it plots the correlation between a (L+1)-layer GIN and the tree mover's distance with graphs sampled from MUTAG. 
60 | ``` 61 | python stability.py --L 3 62 | ``` 63 | 64 |

65 | 66 |

"""Correlate GIN output differences with the Tree Mover's Distance on MUTAG.

A small GIN is trained on MUTAG.  Then, for 1000 random graph pairs, the norm
of the difference of the network outputs is compared with the TMD of the
pair.  The scatter plot is saved to ``gnn_plot.png`` and the Pearson
correlation is printed.
"""
import argparse
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr
from torch.nn import BCEWithLogitsLoss, Linear, ReLU, Sequential
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool
from tqdm import tqdm

from tmd import TMD

# Per-depth TMD layer weights (the original comment attributes them to
# Pascal's triangle).  The key is L; TMD requires exactly L - 1 weights.
LAYER_WEIGHTS = {1: [], 2: [1], 3: [0.5, 2], 4: [1 / 3, 1, 3]}


class Net(torch.nn.Module):
    '''
    GIN network that applies the first L-1 of three GINConv layers, sums the
    node embeddings per graph and feeds them through a two-layer head that
    outputs a single logit.

    Parameters
    ----------
    in_channels : number of input node features
    dim : hidden width
    out_channels : accepted for interface compatibility; the head always
                   produces one logit
    L : depth parameter; L-1 convolution layers are used in forward()
    '''
    def __init__(self, in_channels, dim, out_channels, L):
        super().__init__()
        if not 1 <= L <= 4:
            raise ValueError('L must be between 1 and 4, got {}'.format(L))
        # ModuleList (not a plain list) so that the convolution parameters
        # are registered: they show up in model.parameters() and follow
        # model.to(device).
        self.convs = torch.nn.ModuleList(
            GINConv(Sequential(Linear(width, dim), ReLU(),
                               Linear(dim, dim), ReLU()))
            for width in (in_channels, dim, dim))
        self.lin1 = Linear(dim, dim)
        self.lin2 = Linear(dim, 1)
        self.L = L

    def forward(self, x, edge_index, batch):
        for conv in self.convs[:self.L - 1]:
            x = conv(x, edge_index)
        x = global_add_pool(x, batch)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return x


def train(model, loader, optimizer, criterion, device):
    '''Run one training epoch and return the mean loss per graph.'''
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        loss = criterion(output[:, 0], data.y.float())
        loss.backward()
        optimizer.step()
        # The criterion averages over the batch, so weight by batch size.
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(loader.dataset)


@torch.no_grad()
def test(model, loader, device):
    '''Return the classification accuracy of the model on the loader.'''
    model.eval()

    total_correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = (torch.sigmoid(out[:, 0]) > 0.5).int()
        total_correct += int((pred == data.y).sum())
    return total_correct / len(loader.dataset)


def main():
    '''Train the GIN, sample graph pairs and report the correlation.'''
    parser = argparse.ArgumentParser(description='Tree Mover Distance')
    parser.add_argument('--L', default=4, type=int, choices=sorted(LAYER_WEIGHTS),
                        help='Depth of computational tree')
    args = parser.parse_args()
    w = LAYER_WEIGHTS[args.L]

    # TU Dataset
    dataset = TUDataset('data', name='MUTAG').shuffle()
    n_test = len(dataset) // 10
    train_loader = DataLoader(dataset[n_test:], batch_size=128, shuffle=True)
    test_loader = DataLoader(dataset[:n_test], batch_size=128)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(dataset.num_features, 32, dataset.num_classes, args.L).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = BCEWithLogitsLoss()

    for epoch in range(1, 101):
        loss = train(model, train_loader, optimizer, criterion, device)
        train_acc = test(model, train_loader, device)
        test_acc = test(model, test_loader, device)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f} '
              f'Test Acc: {test_acc:.4f}')

    oo = []
    tt = []
    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(1000)):
            a = random.randint(0, len(dataset) - 1)
            b = random.randint(0, len(dataset) - 1)
            # Use the selected device rather than .cuda() so the script also
            # runs on CPU-only machines.
            g_a = dataset[a].to(device)
            g_b = dataset[b].to(device)

            # output from GIN; a single graph, so every node is in batch 0
            batch_a = torch.zeros(len(g_a.x), dtype=torch.int64, device=device)
            batch_b = torch.zeros(len(g_b.x), dtype=torch.int64, device=device)
            output_a = model(g_a.x, g_a.edge_index, batch_a)
            output_b = model(g_b.x, g_b.edge_index, batch_b)

            # TMD (computed on the CPU copies of the graphs)
            tmd = TMD(dataset[a], dataset[b], w=w, L=args.L)

            oo.append(float(torch.norm(output_a - output_b)))
            tt.append(tmd)

    plt.scatter(oo, tt)
    plt.savefig('gnn_plot.png', dpi=120)
    print('Pearson correlation: {}'.format(pearsonr(np.array(oo), np.array(tt))))


if __name__ == '__main__':
    main()
"""
Tree Mover's Distance solver
"""
# Author: Ching-Yao Chuang
# License: MIT License

import numpy as np
import torch


def get_neighbors(g):
    '''
    get neighbor indexes for each node

    Parameters
    ----------
    g : input torch_geometric graph; only g.edge_index is read, where row 0
        holds the source and row 1 the target node of every edge


    Returns
    ----------
    adj: a dictionary that maps a source node index to the list of its
         target node indexes, in edge order. Nodes that never appear as a
         source have no entry.

    '''
    adj = {}
    sources = g.edge_index[0].tolist()
    targets = g.edge_index[1].tolist()
    for node1, node2 in zip(sources, targets):
        adj.setdefault(node1, []).append(node2)
    return adj


def TMD(g1, g2, w, L=4):
    '''
    return the Tree Mover’s Distance (TMD) between g1 and g2

    Parameters
    ----------
    g1, g2 : two torch_geometric graphs; g.x (one feature row per node) and
             g.edge_index are read
    w : weighting constant for each depth
        if it is a list, tuple or numpy array, then w[l] is the weight for
        depth-(l+1) tree and its length must be L-1
        if it is a constant, then every layer shares the same weight
    L : Depth of computation trees for calculating TMD

    Returns
    ----------
    wass : The TMD between g1 and g2

    Raises
    ----------
    ValueError : if w is a sequence whose length is not L-1

    Reference
    ----------
    Chuang et al., Tree Mover’s Distance: Bridging Graph Metrics and
    Stability of Graph Neural Networks, NeurIPS 2022
    '''
    if isinstance(w, (list, tuple, np.ndarray)):
        w = list(w)
        if len(w) != L - 1:
            raise ValueError(
                'w must have L-1 = {} entries, got {}'.format(L - 1, len(w)))
    else:
        w = [w] * (L - 1)

    # POT is only needed from here on; importing it lazily keeps the module
    # importable (e.g. for get_neighbors) when POT is not installed.
    import ot

    # get attributes
    feat1, feat2 = g1.x, g2.x
    n1, n2 = len(feat1), len(feat2)
    adj1 = get_neighbors(g1)
    adj2 = get_neighbors(g2)
    # neighbor list per node; nodes without outgoing edges get an empty list
    nbrs1 = [adj1.get(i, []) for i in range(n1)]
    nbrs2 = [adj2.get(j, []) for j in range(n2)]

    def to_numpy(t):
        return np.asarray(t.detach().cpu().numpy(), dtype=float)

    # level 1 (pair wise distance)
    D = to_numpy(torch.norm(feat1[:, None, :] - feat2[None, :, :], dim=2))
    # distance w.r.t. blank node is the norm of the node's own features
    norm1 = to_numpy(torch.norm(feat1, dim=1))
    norm2 = to_numpy(torch.norm(feat2, dim=1))

    # M[i, j] is the cost between tree i of g1 and tree j of g2; the extra
    # last row / column is the cost w.r.t. the blank tree.
    M = np.zeros((n1 + 1, n2 + 1))
    M[:n1, :n2] = D
    M[:n1, n2] = norm1
    M[n1, :n2] = norm2

    # level l (tree OT)
    for l in range(L - 1):
        # M is rebound to a fresh array, so the previous level needs no copy
        M1 = M
        M = np.zeros((n1 + 1, n2 + 1))

        blank1 = M1[:n1, n2]
        blank2 = M1[n1, :n2]
        # cost of matching every child of a node to the blank tree
        kids_blank1 = np.array([blank1[nb].sum() for nb in nbrs1], dtype=float)
        kids_blank2 = np.array([blank2[nb].sum() for nb in nbrs2], dtype=float)

        # calculate pairwise cost between tree i and tree j
        for i in range(n1):
            rows = nbrs1[i]
            degree_i = len(rows)
            for j in range(n2):
                cols = nbrs2[j]
                degree_j = len(cols)

                if degree_i == 0 or degree_j == 0:
                    # if degree of a node is zero, all children of the other
                    # node are matched to the blank node (the term of the
                    # childless side is zero)
                    wass = kids_blank1[i] + kids_blank2[j]
                else:
                    # otherwise, calculate the tree distance: pad the smaller
                    # side with one blank node carrying the missing mass
                    if degree_i < degree_j:
                        cost = np.zeros((degree_i + 1, degree_j))
                        cost[degree_i] = blank2[cols]
                        dist_1, dist_2 = np.ones(degree_i + 1), np.ones(degree_j)
                        dist_1[degree_i] = float(degree_j - degree_i)
                    else:
                        cost = np.zeros((degree_i, degree_j + 1))
                        cost[:, degree_j] = blank1[rows]
                        dist_1, dist_2 = np.ones(degree_i), np.ones(degree_j + 1)
                        dist_2[degree_j] = float(degree_i - degree_j)
                    cost[:degree_i, :degree_j] = M1[np.ix_(rows, cols)]
                    wass = ot.emd2(dist_1, dist_2, cost)

                # summarize TMD at level l
                M[i, j] = D[i, j] + w[l] * wass

        # fill in dist w.r.t. blank node
        M[:n1, n2] = norm1 + w[l] * kids_blank1
        M[n1, :n2] = norm2 + w[l] * kids_blank2

    # final OT cost: the smaller graph is padded with blank-tree mass so
    # that both sides carry max(n1, n2) units
    dist_1, dist_2 = np.ones(n1 + 1), np.ones(n2 + 1)
    dist_1[n1] = float(max(n2 - n1, 0))
    dist_2[n2] = float(max(n1 - n2, 0))

    wass = ot.emd2(dist_1, dist_2, M)
    return wass