├── misc
│   ├── fig.png
│   └── fig_stable.png
├── LICENSE
├── merge.py
├── pairwise_dist.py
├── svc.py
├── README.md
├── stability.py
└── tmd.py
/misc/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chingyaoc/TMD/HEAD/misc/fig.png
--------------------------------------------------------------------------------
/misc/fig_stable.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chingyaoc/TMD/HEAD/misc/fig_stable.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ching-Yao Chuang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import torch
4 | import argparse
5 |
6 | import torch_geometric
7 | from torch_geometric.datasets import TUDataset
8 | from torch_geometric.loader import DataLoader
9 |
10 | parser = argparse.ArgumentParser(description='Tree Mover Distance')
11 | parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
12 | parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
13 | parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
14 | parser.add_argument('--n_per_idx', default=50, type=int, help='batch size used in pairwise_dist.py')
15 |
16 | # args parse
17 | args = parser.parse_args()
18 | w, L, dataset_name = args.w, args.L, args.dataset
19 | path = osp.join('data', dataset_name)
20 | dataset = TUDataset(path, name=dataset_name)
21 |
22 | Ms = []
23 | for idx in range((len(dataset) + args.n_per_idx - 1) // args.n_per_idx):  # one file per batch
24 | M = np.load('./PairDist/M_'+dataset_name+'_L'+str(L)+'_w'+str(w)+'_idx'+str(idx)+'.npy')
25 | Ms.append(M)
26 |
27 | M = np.concatenate(Ms, axis=0)
28 | M = M[:len(dataset)]
29 | # -1 marks entries a batch skipped; fill them from the transpose (TMD is symmetric)
30 | for i in range(len(dataset)):
31 |     for j in range(len(dataset)):
32 |         if M[i, j] == -1:
33 |             M[i, j] = M[j, i]
34 | np.save('./PairDist/M_'+dataset_name+'_L'+str(L)+'_w'+str(w)+'.npy', M)
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/pairwise_dist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import argparse
6 | import torch_geometric
7 | from torch_geometric.datasets import TUDataset
8 | from torch_geometric.loader import DataLoader
9 | from tqdm import tqdm
10 |
11 | from tmd import TMD
12 |
13 | parser = argparse.ArgumentParser(description='Tree Mover Distance')
14 | parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
15 | parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
16 | parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
17 | parser.add_argument('--idx', default=0, type=int, help='idx for batch')
18 | parser.add_argument('--n_per_idx', default=50, type=int, help='batch size')
19 |
20 | # args parse
21 | args = parser.parse_args()
22 | w, L, dataset_name = args.w, args.L, args.dataset
23 | n_per_idx = args.n_per_idx
24 |
25 | path = osp.join('data', dataset_name)
26 | train_dataset = TUDataset(path, name=dataset_name)
27 | n = len(train_dataset)
28 | start = n_per_idx * args.idx
29 | end = min(n_per_idx * (args.idx + 1), n)
30 |
31 | print('Precompute pairwise distance')
32 | M = np.zeros((n_per_idx, n)) - 1.  # -1 marks pairs left for merge.py to fill by symmetry
33 | for i in tqdm(range(start, end)):
34 |     for j in range(start, n):
35 |         M[i-start, j] = TMD(train_dataset[i], train_dataset[j], w=w, L=L)
36 |
37 | if not os.path.exists('PairDist'):
38 | os.mkdir('PairDist')
39 | np.save('./PairDist/M_'+dataset_name+'_L'+str(L)+'_w'+str(w)+'_idx'+str(args.idx)+'.npy', M)
40 |
41 |
42 |
--------------------------------------------------------------------------------
/svc.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import random
4 | import argparse
5 |
6 |
7 | import torch
8 | from torch_geometric.datasets import TUDataset
9 | from sklearn.svm import SVC
10 | from tqdm import tqdm
11 |
12 | parser = argparse.ArgumentParser(description='Tree Mover Distance')
13 | parser.add_argument('--w', default=0.5, type=float, help='Layer weighting term')
14 | parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
15 | parser.add_argument('--dataset', default='MUTAG', type=str, help='dataset name')
16 | parser.add_argument('--rs', default=0, type=int, help='random seed')
17 |
18 | # args parse
19 | args = parser.parse_args()
20 | w, L, dataset_name = args.w, args.L, args.dataset
21 |
22 | random.seed(args.rs)
23 | torch.manual_seed(args.rs)
24 |
25 | path = osp.join('data', dataset_name)
26 | dataset = TUDataset(path, name=dataset_name)
27 | M = np.load('./PairDist/M_'+dataset_name+'_L'+str(L)+'_w'+str(w)+'.npy')
28 |
29 | # shuffle
30 | idx = list(range(len(dataset)))
31 | random.shuffle(idx)
32 |
33 | n = len(dataset) // 10
34 |
35 | idx_train = idx[n:]
36 | idx_test = idx[:n]
37 | train_dataset = dataset[idx_train]
38 | test_dataset = dataset[idx_test]
39 | M = M[idx, :]
40 | M = M[:, idx]
41 |
42 | y = []
43 | for i in range(len(dataset)):
44 |     y.append(dataset[i].y.item())  # scalar class label per graph
45 | y = np.array(y)
46 |
47 |
48 | # Cross Val
49 | M_cv = M[n:, n:]
50 | lams = [0.01, 0.05, 0.1]
51 | cv_acc = np.zeros(len(lams))  # accumulated CV accuracy per candidate bandwidth
52 | for it in range(10):
53 |     idx_cv = list(range(len(train_dataset)))
54 | random.shuffle(idx_cv)
55 | n_cv = len(train_dataset) // 10
56 | idx_train_cv = idx_cv[n_cv:]
57 | idx_test_cv = idx_cv[:n_cv]
58 |
59 | for lam in lams:
60 |         model = SVC(kernel='precomputed')
61 | model.fit(np.exp(-lam * M_cv[idx_train_cv][:, idx_train_cv]), y[idx_train][idx_train_cv])
62 | y_pred = model.predict(np.exp(-lam * M_cv[idx_test_cv][:, idx_train_cv]))
63 | acc = sum(y_pred == y[idx_train][idx_test_cv]) / len(y_pred)
64 |         cv_acc[lams.index(lam)] += acc
65 |
66 | # pick the bandwidth with the best cross-validation accuracy
67 | lam = lams[int(np.argmax(cv_acc))]
68 | M_ = np.exp(-lam*M)
69 | model = SVC(kernel='precomputed')
70 | model.fit(M_[n:, n:], y[idx_train])
71 |
72 | y_pred = model.predict(M_[:n, n:])
73 | acc = sum(y_pred == y[idx_test]) / len(y_pred)
74 | print('{}, L: {}, w: {}, Acc: {}'.format(dataset_name, L, w, acc))
75 |
76 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tree Mover's Distance for Graphs
2 |
3 |
4 |
5 |
6 |
7 |
8 | Understanding generalization and robustness of machine learning models fundamentally relies on assuming an appropriate metric on the data space. Identifying such a metric is particularly challenging for non-Euclidean data such as graphs. Here, we propose a pseudometric for attributed graphs, the Tree Mover's Distance (TMD), and study its relation to generalization. Via a hierarchical optimal transport problem, TMD reflects the local distribution of node attributes as well as the distribution of local computation trees, which are known to be decisive for the learning behavior of graph neural networks (GNNs). First, we show that TMD captures properties relevant to graph classification: a simple TMD-SVM performs competitively with standard GNNs. Second, we relate TMD to generalization of GNNs under distribution shifts, and show that it correlates well with performance drop under such shifts.
9 |
10 | **Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks** NeurIPS 2022 [[paper]](https://arxiv.org/abs/2210.01906)
11 |
12 | [Ching-Yao Chuang](https://chingyaoc.github.io/) and
13 | [Stefanie Jegelka](https://people.csail.mit.edu/stefje/)
14 |
15 |
16 |
17 | ## Prerequisites
18 | - Python 3.7+
19 | - PyTorch
20 | - PyTorch Geometric (2.0+; the scripts use `torch_geometric.loader`)
21 | - POT
22 | - scikit-learn, tqdm, and matplotlib (for the example scripts)
23 |
24 | ## Usage Examples
25 | The code for computing the Tree Mover's Distance (TMD) lives in `tmd.py`. For instance, the following code computes the TMD between two graphs from the MUTAG dataset.
26 | ```python
27 | from tmd import TMD
28 | from torch_geometric.datasets import TUDataset
29 |
30 | dataset = TUDataset('data', name='MUTAG')
31 | d = TMD(dataset[0], dataset[1], w=1.0, L=4)
32 | ```
33 |
34 | One can also specify different weighting constants for each layer as follows:
35 | ```python
36 | d = TMD(dataset[0], dataset[1], w=[0.33, 1, 3], L=4)
37 | ```
38 | This results in a tighter bound on the stability of GNNs, as shown in Theorem 8 of the paper. Note that `len(w)` must equal `L-1`.
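39 |
40 | A scalar `w` is simply shorthand for a uniform list; internally, `tmd.py` expands it to `[w] * (L - 1)`. Continuing the snippet above:
41 | ```python
42 | # equivalent calls: the scalar weight is broadcast across the L-1 tree levels
43 | d1 = TMD(dataset[0], dataset[1], w=0.5, L=4)
44 | d2 = TMD(dataset[0], dataset[1], w=[0.5, 0.5, 0.5], L=4)
45 | assert abs(d1 - d2) < 1e-8
46 | ```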
39 |
40 |
41 | ## Graph Classification on TUDataset
42 |
43 | Step 1: Pre-compute the pairwise distances (the batches can run in parallel). For instance, the following commands compute the pairwise distances for MUTAG with `pairwise_dist.py`, splitting the work into 4 batches; the batches are then merged with `merge.py`.
44 | ```
45 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 0
46 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 1
47 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 2
48 | python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 3
49 |
50 | python merge.py --w 0.5 --L 4 --dataset MUTAG
51 | ```
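52 |
53 | For larger datasets, the batches can also be launched programmatically. A minimal sketch using `subprocess` (assuming `python` is on the PATH and the dataset downloads into `data/` as above):
54 | ```python
55 | import subprocess
56 | from torch_geometric.datasets import TUDataset
57 |
58 | n, n_per_idx = len(TUDataset('data', name='MUTAG')), 50
59 | cmd = ['python', 'pairwise_dist.py', '--w', '0.5', '--L', '4',
60 |        '--dataset', 'MUTAG', '--n_per_idx', str(n_per_idx)]
61 | # one process per batch; ceil(n / n_per_idx) batches cover all rows
62 | procs = [subprocess.Popen(cmd + ['--idx', str(i)])
63 |          for i in range((n + n_per_idx - 1) // n_per_idx)]
64 | for p in procs:
65 |     p.wait()
66 | ```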
52 |
53 | Step 2: Train an SVM classifier on the pre-computed distances:
54 | ```
55 | python svc.py --w 0.5 --L 4 --dataset MUTAG
56 | ```
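57 |
58 | Internally, `svc.py` turns the distance matrix into a kernel via `exp(-lam * M)` and feeds it to a precomputed-kernel SVM. A condensed sketch of that step (assuming the merged matrix from Step 1 exists; the shuffling and the cross-validation over `lam` are omitted):
59 | ```python
60 | import numpy as np
61 | from sklearn.svm import SVC
62 | from torch_geometric.datasets import TUDataset
63 |
64 | M = np.load('./PairDist/M_MUTAG_L4_w0.5.npy')     # merged pairwise TMD matrix
65 | y = np.array([g.y.item() for g in TUDataset('data', name='MUTAG')])
66 | K = np.exp(-0.1 * M)                              # lam = 0.1; svc.py picks lam by CV
67 |
68 | n = len(y) // 10                                  # hold out 10% for testing
69 | model = SVC(kernel='precomputed').fit(K[n:, n:], y[n:])
70 | print((model.predict(K[:n, n:]) == y[:n]).mean())
71 | ```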
57 |
58 | ## Measuring the Stability of GNNs
59 | The script `stability.py` reproduces the stability experiment of Figure 5: it plots the correlation between the output distance of an (L-1)-layer GIN and the depth-L tree mover's distance on pairs of graphs sampled from MUTAG.
60 | ```
61 | python stability.py --L 3
62 | ```
63 |
64 |
65 |
66 |
67 |
68 |
69 | ## Citation
70 |
71 | If you find this repo useful for your research, please consider citing the paper
72 |
73 | ```
74 | @article{chuang2022tree,
75 | title={Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks},
76 | author={Chuang, Ching-Yao and Jegelka, Stefanie},
77 | journal={Advances in Neural Information Processing Systems},
78 | volume={35},
79 | year={2022}
80 | }
81 | ```
82 |
83 |
84 |
--------------------------------------------------------------------------------
/stability.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 | import os.path as osp
5 | from tqdm import tqdm
6 | from scipy.stats import pearsonr
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.nn import BatchNorm1d, Linear, ReLU, Sequential, BCEWithLogitsLoss
11 | from torch_geometric.datasets import TUDataset
12 | from torch_geometric.loader import DataLoader
13 | from torch_geometric.nn import GINConv, global_add_pool
14 |
15 | from tmd import TMD
16 |
17 | parser = argparse.ArgumentParser(description='Tree Mover Distance')
18 | parser.add_argument('--L', default=4, type=int, help='Depth of computational tree')
19 | args = parser.parse_args()
20 |
21 | # layer weights w (one list per tree depth L, L = 1..4); cf. Theorem 8
22 | ws = [[1], [1], [0.5, 2], [1/3, 1, 3]]
23 | w = ws[args.L - 1]
24 |
25 | # TU Dataset
26 | dataset = TUDataset('data', name='MUTAG').shuffle()
27 | train_dataset = dataset[len(dataset) // 10:]
28 | test_dataset = dataset[:len(dataset) // 10]
29 | train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
30 | test_loader = DataLoader(test_dataset, batch_size=128)
31 |
32 |
33 | class Net(torch.nn.Module):
34 |     '''
35 |     GIN with three conv layers; the first L-1 of them are applied
36 |     '''
37 | def __init__(self, in_channels, dim, out_channels, L):
38 | super().__init__()
39 | conv1 = GINConv(
40 | Sequential(Linear(in_channels, dim), ReLU(),
41 | Linear(dim, dim), ReLU()))
42 | conv2 = GINConv(
43 | Sequential(Linear(dim, dim), ReLU(),
44 | Linear(dim, dim), ReLU()))
45 | conv3 = GINConv(
46 | Sequential(Linear(dim, dim), ReLU(),
47 | Linear(dim, dim), ReLU()))
48 |         self.convs = torch.nn.ModuleList([conv1, conv2, conv3])
49 | self.lin1 = Linear(dim, dim)
50 | self.lin2 = Linear(dim, 1)
51 | self.L = L
52 |
53 | def forward(self, x, edge_index, batch):
54 | for l in range(int(self.L-1)):
55 | x = self.convs[l](x, edge_index)
56 | x = global_add_pool(x, batch)
57 | x = self.lin1(x).relu()
58 | x = self.lin2(x)
59 | return x
60 |
61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
62 | # convs live in a ModuleList, so .to(device) moves them and
63 | # model.parameters() exposes their weights to the optimizer
64 | model = Net(dataset.num_features, 32, dataset.num_classes, args.L).to(device)
65 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
66 | criterion = BCEWithLogitsLoss()
67 |
68 | def train():
69 | model.train()
70 | total_loss = 0
71 | for data in train_loader:
72 | data = data.to(device)
73 | optimizer.zero_grad()
74 | output = model(data.x, data.edge_index, data.batch)
75 | loss = criterion(output[:,0], data.y.float())
76 | loss.backward()
77 | optimizer.step()
78 | total_loss += float(loss) * data.num_graphs
79 | return total_loss / len(train_loader.dataset)
80 |
81 |
82 | @torch.no_grad()
83 | def test(loader):
84 | model.eval()
85 |
86 | total_correct = 0
87 | for data in loader:
88 | data = data.to(device)
89 | out = model(data.x, data.edge_index, data.batch)
90 |         total_correct += int(((torch.sigmoid(out[:, 0]) > 0.5).int() == data.y).sum())
91 | return total_correct / len(loader.dataset)
92 |
93 |
94 | for epoch in range(1, 101):
95 | loss = train()
96 | train_acc = test(train_loader)
97 | test_acc = test(test_loader)
98 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f} '
99 | f'Test Acc: {test_acc:.4f}')
100 |
101 |
102 | # compare GIN output distances with TMD on random pairs of graphs
103 | model.eval()
104 | oo, tt = [], []
105 | for i in tqdm(range(1000)):
106 |     a = random.randint(0, len(dataset)-1)
107 |     b = random.randint(0, len(dataset)-1)
108 |     g_a = dataset[a].to(device)
109 |     g_b = dataset[b].to(device)
110 |
111 |     # output from GIN (each graph forms its own single-graph batch)
112 |     with torch.no_grad():
113 |         output_a = model(g_a.x, g_a.edge_index, torch.zeros(len(g_a.x), dtype=torch.int64, device=device))
114 |         output_b = model(g_b.x, g_b.edge_index, torch.zeros(len(g_b.x), dtype=torch.int64, device=device))
115 |
116 |     # TMD between the same pair
117 |     tmd = TMD(dataset[a], dataset[b], w=w, L=args.L)
118 |     oo.append(float(torch.norm(output_a - output_b)))
119 |     tt.append(tmd)
120 |
121 | import matplotlib.pyplot as plt
122 | plt.scatter(oo, tt)
123 | plt.xlabel('GNN output distance')
124 | plt.ylabel('TMD')
125 | plt.savefig('gnn_plot.png', dpi=120)
126 | print('Pearson correlation: {}'.format(pearsonr(np.array(oo), np.array(tt))))
--------------------------------------------------------------------------------
/tmd.py:
--------------------------------------------------------------------------------
1 | """
2 | Tree Mover's Distance solver
3 | """
4 | # Author: Ching-Yao Chuang
5 | # License: MIT License
6 |
7 | import numpy as np
8 | import torch
9 | import ot
10 | import copy
11 |
12 |
13 | def get_neighbors(g):
14 | '''
15 |     get the neighbor indices of each node
16 |
17 | Parameters
18 | ----------
19 | g : input torch_geometric graph
20 |
21 |
22 | Returns
23 | ----------
24 |     adj: a dictionary that stores the neighbor indices of each node
25 |
26 | '''
27 | adj = {}
28 | for i in range(len(g.edge_index[0])):
29 | node1 = g.edge_index[0][i].item()
30 | node2 = g.edge_index[1][i].item()
31 |         if node1 in adj:
32 | adj[node1].append(node2)
33 | else:
34 | adj[node1] = [node2]
35 | return adj
36 |
37 |
38 | def TMD(g1, g2, w, L=4):
39 | '''
40 | return the Tree Mover’s Distance (TMD) between g1 and g2
41 |
42 | Parameters
43 | ----------
44 | g1, g2 : two torch_geometric graphs
45 | w : weighting constant for each depth
46 | if it is a list, then w[l] is the weight for depth-(l+1) tree
47 | if it is a constant, then every layer shares the same weight
48 | L : Depth of computation trees for calculating TMD
49 |
50 | Returns
51 | ----------
52 | wass : The TMD between g1 and g2
53 |
54 | Reference
55 | ----------
56 | Chuang et al., Tree Mover’s Distance: Bridging Graph Metrics and
57 | Stability of Graph Neural Networks, NeurIPS 2022
58 | '''
59 |
60 | if isinstance(w, list):
61 |         assert len(w) == L-1
62 | else:
63 | w = [w] * (L-1)
64 |
65 | # get attributes
66 | n1, n2 = len(g1.x), len(g2.x)
67 | feat1, feat2 = g1.x, g2.x
68 | adj1 = get_neighbors(g1)
69 | adj2 = get_neighbors(g2)
70 |
71 |     # D caches pairwise distances between node features of g1 and g2
72 |     D = np.zeros((n1, n2))
73 |
74 |     # level 1 (pairwise distance between node features)
75 | M = np.zeros((n1+1, n2+1))
76 | for i in range(n1):
77 | for j in range(n2):
78 | D[i, j] = torch.norm(feat1[i] - feat2[j])
79 | M[i, j] = D[i, j]
80 | # distance w.r.t. blank node
81 | M[:n1, n2] = torch.norm(feat1, dim=1)
82 | M[n1, :n2] = torch.norm(feat2, dim=1)
83 |
84 | # level l (tree OT)
85 | for l in range(L-1):
86 | M1 = copy.deepcopy(M)
87 | M = np.zeros((n1+1, n2+1))
88 |
89 | # calculate pairwise cost between tree i and tree j
90 | for i in range(n1):
91 | for j in range(n2):
92 |             try:
93 |                 degree_i = len(adj1[i])
94 |             except KeyError:  # isolated node
95 |                 degree_i = 0
96 |             try:
97 |                 degree_j = len(adj2[j])
98 |             except KeyError:  # isolated node
99 |                 degree_j = 0
100 |
101 | if degree_i == 0 and degree_j == 0:
102 | M[i, j] = D[i, j]
103 | # if degree of node is zero, calculate TD w.r.t. blank node
104 | elif degree_i == 0:
105 | wass = 0.
106 | for jj in range(degree_j):
107 | wass += M1[n1, adj2[j][jj]]
108 | M[i, j] = D[i, j] + w[l] * wass
109 | elif degree_j == 0:
110 | wass = 0.
111 | for ii in range(degree_i):
112 | wass += M1[adj1[i][ii], n2]
113 | M[i, j] = D[i, j] + w[l] * wass
114 | # otherwise, calculate the tree distance
115 | else:
116 | max_degree = max(degree_i, degree_j)
117 | if degree_i < max_degree:
118 | cost = np.zeros((degree_i + 1, degree_j))
119 | cost[degree_i] = M1[n1, adj2[j]]
120 | dist_1, dist_2 = np.ones(degree_i + 1), np.ones(degree_j)
121 | dist_1[degree_i] = max_degree - float(degree_i)
122 | else:
123 | cost = np.zeros((degree_i, degree_j + 1))
124 | cost[:, degree_j] = M1[adj1[i], n2]
125 | dist_1, dist_2 = np.ones(degree_i), np.ones(degree_j + 1)
126 | dist_2[degree_j] = max_degree - float(degree_j)
127 | for ii in range(degree_i):
128 | for jj in range(degree_j):
129 | cost[ii, jj] = M1[adj1[i][ii], adj2[j][jj]]
130 | wass = ot.emd2(dist_1, dist_2, cost)
131 |
132 | # summarize TMD at level l
133 | M[i, j] = D[i, j] + w[l] * wass
134 |
135 | # fill in dist w.r.t. blank node
136 | for i in range(n1):
137 |         try:
138 |             degree_i = len(adj1[i])
139 |         except KeyError:  # isolated node
140 |             degree_i = 0
141 |
142 | if degree_i == 0:
143 | M[i, n2] = torch.norm(feat1[i])
144 | else:
145 | wass = 0.
146 | for ii in range(degree_i):
147 | wass += M1[adj1[i][ii], n2]
148 | M[i, n2] = torch.norm(feat1[i]) + w[l] * wass
149 |
150 | for j in range(n2):
151 |         try:
152 |             degree_j = len(adj2[j])
153 |         except KeyError:  # isolated node
154 |             degree_j = 0
155 | if degree_j == 0:
156 | M[n1, j] = torch.norm(feat2[j])
157 | else:
158 | wass = 0.
159 | for jj in range(degree_j):
160 | wass += M1[n1, adj2[j][jj]]
161 | M[n1, j] = torch.norm(feat2[j]) + w[l] * wass
162 |
163 |
164 | # final OT cost
165 | max_n = max(n1, n2)
166 | dist_1, dist_2 = np.ones(n1+1), np.ones(n2+1)
167 | if n1 < max_n:
168 | dist_1[n1] = max_n - float(n1)
169 | dist_2[n2] = 0.
170 | else:
171 | dist_1[n1] = 0.
172 | dist_2[n2] = max_n - float(n2)
173 |
174 | wass = ot.emd2(dist_1, dist_2, M)
175 | return wass
176 |
177 |
178 |
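179 | if __name__ == '__main__':
180 |     # quick illustrative check (assumes MUTAG is available under data/, as in
181 |     # the README): the distance from a graph to itself should be ~0
182 |     from torch_geometric.datasets import TUDataset
183 |     g = TUDataset('data', name='MUTAG')[0]
184 |     print(TMD(g, g, w=1.0, L=4))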
--------------------------------------------------------------------------------