├── .gitignore
├── graph
│   ├── __init__.py
│   ├── main.py
│   └── train_eval.py
├── image
│   ├── __init__.py
│   ├── train_eval.py
│   └── main.py
├── conv
│   ├── __init__.py
│   ├── inits.py
│   └── gmm_conv.py
├── manifolds
│   ├── __init__.py
│   ├── train_eval.py
│   └── main.py
├── LICENSE
├── svgs
│   ├── 129c5b884ff47d80be4d6261a476e9f1.svg
│   ├── 6fccf0465699020081a15631f4a45ae1.svg
│   ├── 796df3d6b2c0926fcde961fd14b100e7.svg
│   ├── aff3fd40bc3e8b5ce3ad3f61175cb17a.svg
│   ├── 9284e17b2f479e052a85e111d9f17ce1.svg
│   ├── e0eef981c0301bb88a01a36ec17cfd0c.svg
│   ├── 1276e542ca3d1d00fd30f0383afb5d08.svg
│   ├── f1cee86600f26eed52126ed72d2dfdd8.svg
│   ├── ab02bf3a35bf706f5ef8c322af45f43e.svg
│   ├── 1c07d8ffda7593d98eda6d17de7db825.svg
│   ├── 72c8e03edc97e002a73695ec08e30d5b.svg
│   └── b0e0a2e33abfab591a8f7e7f6854ae83.svg
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | **/out/
3 | *.pyc
4 | data/
5 | INPUT.md
6 | results.txt
7 |
--------------------------------------------------------------------------------
/graph/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_eval import run
2 |
3 | __all__ = [
4 | 'run',
5 | ]
6 |
--------------------------------------------------------------------------------
/image/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_eval import run
2 |
3 | __all__ = [
4 | 'run',
5 | ]
6 |
--------------------------------------------------------------------------------
/conv/__init__.py:
--------------------------------------------------------------------------------
1 | from .gmm_conv import GMMConv
2 |
3 | __all__ = [
4 | 'GMMConv',
5 | ]
6 |
--------------------------------------------------------------------------------
/manifolds/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_eval import run
2 |
3 | __all__ = [
4 | 'run',
5 | ]
6 |
--------------------------------------------------------------------------------
/conv/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | bound = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 | if tensor is not None:
12 | bound = math.sqrt(6 / ((1 + a**2) * fan))
13 | tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 | if tensor is not None:
18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
19 | tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(1)
30 |
31 |
32 | def reset(nn):
33 | def _reset(item):
34 | if hasattr(item, 'reset_parameters'):
35 | item.reset_parameters()
36 |
37 | if nn is not None:
38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
39 | for item in nn.children():
40 | _reset(item)
41 | else:
42 | _reset(nn)
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Shunwang Gong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/image/train_eval.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def print_info(info):
7 | message = ('Epoch: {}/{}, Duration: {:.3f}s, Train Loss: {:.4f}, '
8 | 'Test Loss: {:.4f}, Test Acc: {:.4f}').format(
9 | info['current_epoch'], info['epochs'], info['t_duration'],
10 | info['train_loss'], info['test_loss'], info['acc'])
11 | print(message)
12 |
13 |
14 | def run(model, epochs, train_loader, test_loader, optimizer, scheduler,
15 | device):
16 | for epoch in range(1, epochs + 1):
17 | t = time.time()
18 | train_loss = train(model, train_loader, optimizer, device)
19 | t_duration = time.time() - t
20 | scheduler.step()
21 | acc, test_loss = test(model, test_loader, device)
22 |
23 | info = {
24 | 'train_loss': train_loss,
25 | 'test_loss': test_loss,
26 | 'acc': acc,
27 | 'current_epoch': epoch,
28 | 'epochs': epochs,
29 | 't_duration': t_duration
30 | }
31 |
32 | print_info(info)
33 |
34 |
35 | def train(model, train_loader, optimizer, device):
36 | model.train()
37 |
38 | total_loss = 0
39 | for data in train_loader:
40 | optimizer.zero_grad()
41 | data = data.to(device)
42 | loss = F.nll_loss(model(data), data.y)
43 | loss.backward()
44 | optimizer.step()
45 | total_loss += loss.item()
46 | return total_loss / len(train_loader)
47 |
48 |
49 | def test(model, test_loader, device):
50 | model.eval()
51 |
52 | correct = 0
53 | total_loss = 0
54 | with torch.no_grad():
55 | for idx, data in enumerate(test_loader):
56 | data = data.to(device)
57 | out = model(data)
58 | total_loss += F.nll_loss(out, data.y).item()
59 | pred = out.max(1)[1]
60 | correct += pred.eq(data.y).sum().item()
61 | return correct / len(test_loader.dataset), total_loss / len(test_loader)
62 |
--------------------------------------------------------------------------------
/svgs/129c5b884ff47d80be4d6261a476e9f1.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/manifolds/train_eval.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def print_info(info):
7 | message = ('Epoch: {}/{}, Duration: {:.3f}s, ACC: {:.4f}, '
8 | 'Train Loss: {:.4f}, Test Loss:{:.4f}').format(
9 | info['current_epoch'], info['epochs'], info['t_duration'],
10 | info['acc'], info['train_loss'], info['test_loss'])
11 | print(message)
12 |
13 |
14 | def run(model, train_loader, test_loader, target, num_nodes, epochs, optimizer,
15 | scheduler, device):
16 |
17 | for epoch in range(1, epochs + 1):
18 | t = time.time()
19 | train_loss = train(model, train_loader, target, optimizer, device)
20 | t_duration = time.time() - t
21 | scheduler.step()
22 | acc, test_loss = test(model, test_loader, num_nodes, target, device)
23 | eval_info = {
24 | 'train_loss': train_loss,
25 | 'test_loss': test_loss,
26 | 'acc': acc,
27 | 'current_epoch': epoch,
28 | 'epochs': epochs,
29 | 't_duration': t_duration
30 | }
31 |
32 | print_info(eval_info)
33 |
34 |
35 | def train(model, train_loader, target, optimizer, device):
36 | model.train()
37 |
38 | total_loss = 0
39 | for idx, data in enumerate(train_loader):
40 | optimizer.zero_grad()
41 | loss = F.nll_loss(model(data.to(device)), target)
42 | loss.backward()
43 | optimizer.step()
44 | total_loss += loss.item()
45 | return total_loss / len(train_loader)
46 |
47 |
48 | def test(model, test_loader, num_nodes, target, device):
49 | model.eval()
50 | correct = 0
51 | total_loss = 0
52 | n_graphs = 0
53 | with torch.no_grad():
54 | for idx, data in enumerate(test_loader):
55 | out = model(data.to(device))
56 | total_loss += F.nll_loss(out, target).item()
57 | pred = out.max(1)[1]
58 | correct += pred.eq(target).sum().item()
59 | n_graphs += data.num_graphs
60 | return correct / (n_graphs * num_nodes), total_loss / len(test_loader)
61 |
--------------------------------------------------------------------------------
/svgs/6fccf0465699020081a15631f4a45ae1.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/conv/gmm_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.nn.conv import MessagePassing
4 | from .inits import reset, glorot, zeros
5 |
6 | EPS = 1e-15
7 |
8 |
9 | class GMMConv(MessagePassing):
10 | def __init__(self,
11 | in_channels,
12 | out_channels,
13 | dim,
14 | kernel_size,
15 | bias=True,
16 | **kwargs):
17 | super(GMMConv, self).__init__(aggr='add', **kwargs)
18 |
19 | self.in_channels = in_channels
20 | self.out_channels = out_channels
21 | self.dim = dim
22 | self.kernel_size = kernel_size
23 |
24 | self.lin = torch.nn.Linear(in_channels,
25 | out_channels * kernel_size,
26 | bias=False)
27 | self.mu = Parameter(torch.Tensor(kernel_size, dim))
28 | self.sigma = Parameter(torch.Tensor(kernel_size, dim))
29 | if bias:
30 | self.bias = Parameter(torch.Tensor(out_channels))
31 | else:
32 | self.register_parameter('bias', None)
33 |
34 | self.reset_parameters()
35 |
36 | def reset_parameters(self):
37 | glorot(self.mu)
38 | glorot(self.sigma)
39 | zeros(self.bias)
40 | reset(self.lin)
41 |
42 | def forward(self, x, edge_index, pseudo):
43 | x = x.unsqueeze(-1) if x.dim() == 1 else x
44 | pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
45 |
46 | out = self.lin(x).view(-1, self.kernel_size, self.out_channels)
47 | out = self.propagate(edge_index, x=out, pseudo=pseudo)
48 |
49 | if self.bias is not None:
50 | out = out + self.bias
51 | return out
52 |
53 | def message(self, x_j, pseudo):
54 | (E, D), K = pseudo.size(), self.mu.size(0)
55 |
56 | gaussian = -0.5 * (pseudo.view(E, 1, D) - self.mu.view(1, K, D))**2
57 | gaussian = gaussian / (EPS + self.sigma.view(1, K, D)**2)
58 | gaussian = torch.exp(gaussian.sum(dim=-1, keepdim=True)) # [E, K, 1]
59 |
60 | return (x_j * gaussian).sum(dim=1)
61 |
62 | def __repr__(self):
63 | return '{}({}, {}, kernel_size={})'.format(self.__class__.__name__,
64 | self.in_channels,
65 | self.out_channels,
66 | self.kernel_size)
67 |
--------------------------------------------------------------------------------
/graph/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import torch_geometric.transforms as T
7 | from torch_geometric.datasets import Planetoid
8 | from torch_geometric.utils import degree
9 |
10 | from graph import run
11 | from conv import GMMConv
12 |
13 | parser = argparse.ArgumentParser(description='Cora citation network')
14 | parser.add_argument('--dataset', type=str, default='Cora')
15 | parser.add_argument('--device_idx', type=int, default=1)
16 | parser.add_argument('--runs', type=int, default=10)
17 | parser.add_argument('--early_stopping', type=int, default=50)
18 | parser.add_argument('--epochs', type=int, default=3000)
19 | parser.add_argument('--kernel_size', type=int, default=16)
20 | parser.add_argument('--lr', type=float, default=1e-1)
21 | parser.add_argument('--weight_decay', type=float, default=1e-2)
22 | parser.add_argument('--hidden', type=int, default=16)
23 | parser.add_argument('--dropout', type=float, default=0.5)
24 | args = parser.parse_args()
25 | print(args)
26 |
27 | def transform(data):
28 | row, col = data.edge_index
29 | deg = degree(col, data.num_nodes)
30 | data.edge_attr = torch.stack(
31 | [1 / torch.sqrt(deg[row]), 1 / torch.sqrt(deg[col])], dim=-1)
32 | return data
33 |
34 |
35 | class MoNet(torch.nn.Module):
36 | def __init__(self, dataset):
37 | super(MoNet, self).__init__()
38 | self.conv1 = GMMConv(dataset.num_features,
39 | args.hidden,
40 | dim=2,
41 | kernel_size=args.kernel_size)
42 | self.conv2 = GMMConv(args.hidden,
43 | dataset.num_classes,
44 | dim=2,
45 | kernel_size=args.kernel_size)
46 |
47 | def reset_parameters(self):
48 | self.conv1.reset_parameters()
49 | self.conv2.reset_parameters()
50 |
51 | def forward(self, data):
52 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
53 | x = F.dropout(x, p=args.dropout, training=self.training)
54 | x = F.elu(self.conv1(x, edge_index, edge_attr))
55 | x = F.dropout(x, p=args.dropout, training=self.training)
56 | x = self.conv2(x, edge_index, edge_attr)
57 | return F.log_softmax(x, dim=1)
58 |
59 |
60 | device = torch.device('cuda', args.device_idx)
61 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
62 | args.dataset)
63 | dataset = Planetoid(args.data_fp, args.dataset)
64 | dataset.transform = transform
65 | run(dataset, MoNet(dataset), args.runs, args.epochs, args.lr,
66 | args.weight_decay, args.early_stopping, device)
67 |
--------------------------------------------------------------------------------
/graph/train_eval.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import tensor
6 | from torch.optim import Adam
7 |
8 |
9 | def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
10 | device):
11 |
12 | val_losses, accs, durations = [], [], []
13 | for _ in range(runs):
14 | data = dataset[0]
15 | data = data.to(device)
16 |
17 | model.to(device).reset_parameters()
18 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
19 |
20 | if torch.cuda.is_available():
21 | torch.cuda.synchronize()
22 |
23 | t_start = time.perf_counter()
24 |
25 | best_val_loss = float('inf')
26 | test_acc = 0
27 | val_loss_history = []
28 |
29 | for epoch in range(1, epochs + 1):
30 | if epoch == 1500:
31 | for param_group in optimizer.param_groups:
32 | param_group['lr'] = param_group['lr'] * 0.1
33 | elif epoch == 2500:
34 | for param_group in optimizer.param_groups:
35 | param_group['lr'] = param_group['lr'] * 0.1
36 |
37 | train(model, optimizer, data)
38 | eval_info = evaluate(model, data)
39 | eval_info['epoch'] = epoch
40 |
41 | if eval_info['val_loss'] < best_val_loss:
42 | best_val_loss = eval_info['val_loss']
43 | test_acc = eval_info['test_acc']
44 |
45 | val_loss_history.append(eval_info['val_loss'])
46 | if early_stopping > 0 and epoch > epochs // 2:
47 | tmp = tensor(val_loss_history[-(early_stopping + 1):-1])
48 | if eval_info['val_loss'] > tmp.mean().item():
49 | break
50 |
51 | if torch.cuda.is_available():
52 | torch.cuda.synchronize()
53 |
54 | t_end = time.perf_counter()
55 |
56 | val_losses.append(best_val_loss)
57 | accs.append(test_acc)
58 | durations.append(t_end - t_start)
59 |
60 | loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)
61 |
62 | print('Val Loss: {:.4f}, Test Accuracy: {:.3f} ± {:.3f}, Duration: {:.3f}'.
63 | format(loss.mean().item(),
64 | acc.mean().item(),
65 | acc.std().item(),
66 | duration.mean().item()))
67 |
68 |
69 | def train(model, optimizer, data):
70 | model.train()
71 | optimizer.zero_grad()
72 | out = model(data)
73 | loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
74 | loss.backward()
75 | optimizer.step()
76 |
77 |
78 | def evaluate(model, data):
79 | model.eval()
80 |
81 | with torch.no_grad():
82 | logits = model(data)
83 |
84 | outs = {}
85 | for key in ['train', 'val', 'test']:
86 | mask = data['{}_mask'.format(key)]
87 | loss = F.nll_loss(logits[mask], data.y[mask]).item()
88 | pred = logits[mask].max(1)[1]
89 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
90 |
91 | outs['{}_loss'.format(key)] = loss
92 | outs['{}_acc'.format(key)] = acc
93 |
94 | return outs
95 |
--------------------------------------------------------------------------------
/svgs/796df3d6b2c0926fcde961fd14b100e7.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/aff3fd40bc3e8b5ce3ad3f61175cb17a.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/image/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os.path as osp
3 | import argparse
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | from torch_geometric.datasets import MNISTSuperpixels
9 | import torch_geometric.transforms as T
10 | from torch_geometric.data import DataLoader
11 | from torch_geometric.utils import normalized_cut
12 | from torch_geometric.nn import (graclus, max_pool, global_mean_pool)
13 | from conv import GMMConv
14 | from image import run
15 |
16 | parser = argparse.ArgumentParser(description='superpixel MNIST')
17 | parser.add_argument('--dataset', default='MNIST', type=str)
18 | parser.add_argument('--device_idx', default=3, type=int)
19 | parser.add_argument('--kernel_size', default=25, type=int)
20 | parser.add_argument('--lr', type=float, default=1e-4)
21 | parser.add_argument('--lr_decay', type=float, default=0.99)
22 | parser.add_argument('--decay_step', type=int, default=1)
23 | parser.add_argument('--weight_decay', type=float, default=5e-4)
24 | parser.add_argument('--batch_size', type=int, default=64)
25 | parser.add_argument('--epochs', type=int, default=300)
26 | parser.add_argument('--seed', type=int, default=1)
27 | args = parser.parse_args()
28 |
29 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
30 | args.dataset)
31 | device = torch.device('cuda', args.device_idx)
32 |
33 | # deterministic
34 | torch.manual_seed(args.seed)
35 | cudnn.benchmark = False
36 | cudnn.deterministic = True
37 |
38 | train_dataset = MNISTSuperpixels(args.data_fp, True, pre_transform=T.Polar())
39 | test_dataset = MNISTSuperpixels(args.data_fp, False, pre_transform=T.Polar())
40 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
41 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
42 |
43 |
44 | def normalized_cut_2d(edge_index, pos):
45 | row, col = edge_index
46 | edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
47 | return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
48 |
49 |
50 | class MoNet(torch.nn.Module):
51 | def __init__(self, kernel_size):
52 | super(MoNet, self).__init__()
53 | self.conv1 = GMMConv(1, 32, dim=2, kernel_size=kernel_size)
54 | self.conv2 = GMMConv(32, 64, dim=2, kernel_size=kernel_size)
55 | self.conv3 = GMMConv(64, 64, dim=2, kernel_size=kernel_size)
56 | self.fc1 = torch.nn.Linear(64, 128)
57 | self.fc2 = torch.nn.Linear(128, 10)
58 |
59 | def forward(self, data):
60 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
61 | weight = normalized_cut_2d(data.edge_index, data.pos)
62 | cluster = graclus(data.edge_index, weight, data.x.size(0))
63 | data.edge_attr = None
64 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
65 |
66 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
67 | weight = normalized_cut_2d(data.edge_index, data.pos)
68 | cluster = graclus(data.edge_index, weight, data.x.size(0))
69 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
70 |
71 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
72 |
73 | x = global_mean_pool(data.x, data.batch)
74 | x = F.elu(self.fc1(x))
75 | x = F.dropout(x, training=self.training)
76 | return F.log_softmax(self.fc2(x), dim=1)
77 |
78 |
79 | model = MoNet(args.kernel_size).to(device)
80 | optimizer = torch.optim.Adam(model.parameters(),
81 | lr=args.lr,
82 | weight_decay=args.weight_decay)
83 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
84 | args.decay_step,
85 | gamma=args.lr_decay)
86 | print(model)
87 |
88 | run(model, args.epochs, train_loader, test_loader, optimizer, scheduler,
89 | device)
90 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Gaussian Mixture Model Convolutional Networks
4 |
5 | This is a PyTorch implementation of Gaussian Mixture Model Convolutional Networks (MoNet) for the tasks of image classification, vertex classification on generic graphs, and dense intrinsic shape correspondence, as described in the paper:
6 |
7 | Monti *et al.*, [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402) (CVPR 2017)
8 |
9 | Following the same network architecture provided in the paper, our implementation produces results **comparable** to or **better** than those shown in the paper. Note that for the image classification and shape correspondence tasks, we replace polar pseudo-coordinates with relative Cartesian coordinates, which eases both the computational and memory cost of data preprocessing.
10 |
11 | ## Requirements
12 | * [PyTorch](https://pytorch.org/) (1.3.0)
13 | * [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric) (1.3.0)
14 |
15 | ## MoNet
16 |
17 | MoNet uses a local system of pseudo-coordinates $\mathbf{u}(x, y)$ around $x$ to represent the neighborhood $\mathcal{N}(x)$ and a family of learnable weighting functions w.r.t. $\mathbf{u}$, e.g., Gaussian kernels $w_j(\mathbf{u}) = \exp\left(-\frac{1}{2}(\mathbf{u} - \boldsymbol{\mu}_j)^{\top}\boldsymbol{\Sigma}_j^{-1}(\mathbf{u} - \boldsymbol{\mu}_j)\right)$ with learnable mean $\boldsymbol{\mu}_j$ and covariance $\boldsymbol{\Sigma}_j$. The convolution is
18 |
$$(f \star g)(x) = \sum_{j=1}^{J} g_j \sum_{y \in \mathcal{N}(x)} w_j(\mathbf{u}(x, y))\, f(y),$$
19 |
20 | where $g_j$ is the learnable filter weight and $f(y)$ is the node feature vector.
21 |
22 | We provide an efficient PyTorch implementation of this operator, ``GMMConv``, which is also accessible from ``PyTorch Geometric``.
23 |
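A minimal usage sketch of ``GMMConv`` (the toy tensors below are placeholders for illustration only; the call follows the ``forward(x, edge_index, pseudo)`` signature in ``conv/gmm_conv.py``):

```
import torch
from conv import GMMConv

x = torch.randn(3, 4)                                          # 3 nodes with 4 features each
edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)  # 2 directed edges
pseudo = torch.rand(edge_index.size(1), 2)                     # one 2-dim pseudo-coordinate per edge

conv = GMMConv(in_channels=4, out_channels=8, dim=2, kernel_size=25)
out = conv(x, edge_index, pseudo)                              # out has shape [3, 8]
```
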
24 | ## Superpixel MNIST Classification
25 |
26 | ```
27 | python -m image.main
28 | ```
29 |
30 |
31 | ## Vertex Classification
32 |
33 | ```
34 | python -m graph.main
35 | ```
36 |
37 | ## Dense Shape Correspondence
38 |
39 | ```
40 | python -m manifolds.main
41 | ```
42 |
43 |
44 | ## Data
45 |
46 | In order to use your own dataset, you can simply create a regular Python list holding `torch_geometric.data.Data` objects and specify the following attributes (see the sketch after this list):
47 |
48 | - ``data.x``: Node feature matrix with shape ``[num_nodes, num_node_features]``
49 | - ``data.edge_index``: Graph connectivity in COO format with shape ``[2, num_edges]`` and type ``torch.long``
50 | - ``data.edge_attr``: Pseudo-coordinates with shape ``[num_edges, pseudo-coordinate-dim]``
51 | - ``data.y``: Target to train against
52 |
53 |
54 | ## Cite
55 |
56 | Please cite [this paper](https://arxiv.org/abs/1611.08402) if you use this code in your own work:
57 |
58 | ```
59 | @inproceedings{monti2017geometric,
60 | title={Geometric deep learning on graphs and manifolds using mixture model cnns},
61 | author={Monti, Federico and Boscaini, Davide and Masci, Jonathan and Rodola, Emanuele and Svoboda, Jan and Bronstein, Michael M},
62 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
63 | pages={5115--5124},
64 | year={2017}
65 | }
66 | ```
67 |
--------------------------------------------------------------------------------
/svgs/9284e17b2f479e052a85e111d9f17ce1.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/manifolds/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch_geometric.datasets import FAUST
10 | from torch_geometric.data import DataLoader
11 | import torch_geometric.transforms as T
12 |
13 | from conv import GMMConv
14 | from manifolds import run
15 |
16 | parser = argparse.ArgumentParser(description='shape correspondence')
17 | parser.add_argument('--dataset', type=str, default='FAUST')
18 | parser.add_argument('--device_idx', type=int, default=4)
19 | parser.add_argument('--n_threads', type=int, default=4)
20 | parser.add_argument('--kernel_size', type=int, default=10)
21 | parser.add_argument('--lr', type=float, default=3e-3)
22 | parser.add_argument('--lr_decay', type=float, default=0.99)
23 | parser.add_argument('--decay_step', type=int, default=1)
24 | parser.add_argument('--weight_decay', type=float, default=5e-5)
25 | parser.add_argument('--epochs', type=int, default=500)
26 | parser.add_argument('--seed', type=int, default=1)
27 | args = parser.parse_args()
28 |
29 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
30 | args.dataset)
31 | device = torch.device('cuda', args.device_idx)
32 | torch.set_num_threads(args.n_threads)
33 |
34 | # deterministic
35 | torch.manual_seed(args.seed)
36 | cudnn.benchmark = False
37 | cudnn.deterministic = True
38 |
39 |
40 | class Pre_Transform(object):
41 | def __call__(self, data):
42 | data.x = data.pos
43 | data = T.FaceToEdge()(data)
44 |
45 | return data
46 |
47 |
48 | train_dataset = FAUST(args.data_fp,
49 | True,
50 | transform=T.Cartesian(),
51 | pre_transform=Pre_Transform())
52 | test_dataset = FAUST(args.data_fp,
53 | False,
54 | transform=T.Cartesian(),
55 | pre_transform=Pre_Transform())
56 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
57 | test_loader = DataLoader(test_dataset, batch_size=1)
58 | d = train_dataset[0]
59 | target = torch.arange(d.num_nodes, dtype=torch.long, device=device)
60 | print(d)
61 |
62 |
63 | class MoNet(nn.Module):
64 | def __init__(self, in_channels, num_classes, kernel_size):
65 | super(MoNet, self).__init__()
66 |
67 | self.fc0 = nn.Linear(in_channels, 16)
68 | self.conv1 = GMMConv(16, 32, dim=3, kernel_size=kernel_size)
69 | self.conv2 = GMMConv(32, 64, dim=3, kernel_size=kernel_size)
70 | self.conv3 = GMMConv(64, 128, dim=3, kernel_size=kernel_size)
71 | self.fc1 = nn.Linear(128, 256)
72 | self.fc2 = nn.Linear(256, num_classes)
73 |
74 | self.reset_parameters()
75 |
76 | def reset_parameters(self):
77 | self.conv1.reset_parameters()
78 | self.conv2.reset_parameters()
79 | self.conv3.reset_parameters()
80 | nn.init.xavier_uniform_(self.fc0.weight, gain=1)
81 | nn.init.xavier_uniform_(self.fc1.weight, gain=1)
82 | nn.init.xavier_uniform_(self.fc2.weight, gain=1)
83 | nn.init.constant_(self.fc0.bias, 0)
84 | nn.init.constant_(self.fc1.bias, 0)
85 | nn.init.constant_(self.fc2.bias, 0)
86 |
87 | def forward(self, data):
88 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
89 | x = F.elu(self.fc0(x))
90 | x = F.elu(self.conv1(x, edge_index, edge_attr))
91 | x = F.elu(self.conv2(x, edge_index, edge_attr))
92 | x = F.elu(self.conv3(x, edge_index, edge_attr))
93 | x = F.elu(self.fc1(x))
94 | x = F.dropout(x, training=self.training)
95 | x = self.fc2(x)
96 | return F.log_softmax(x, dim=1)
97 |
98 |
99 | model = MoNet(d.num_features, d.num_nodes, args.kernel_size).to(device)
100 | print(model)
101 | optimizer = optim.Adam(model.parameters(),
102 | lr=args.lr,
103 | weight_decay=args.weight_decay)
104 | scheduler = optim.lr_scheduler.StepLR(optimizer,
105 | args.decay_step,
106 | gamma=args.lr_decay)
107 |
108 | run(model, train_loader, test_loader, target, d.num_nodes, args.epochs,
109 | optimizer, scheduler, device)
110 |
--------------------------------------------------------------------------------
/svgs/e0eef981c0301bb88a01a36ec17cfd0c.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/1276e542ca3d1d00fd30f0383afb5d08.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/f1cee86600f26eed52126ed72d2dfdd8.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/ab02bf3a35bf706f5ef8c322af45f43e.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/1c07d8ffda7593d98eda6d17de7db825.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/72c8e03edc97e002a73695ec08e30d5b.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/b0e0a2e33abfab591a8f7e7f6854ae83.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------