├── .gitignore ├── graph ├── __init__.py ├── main.py └── train_eval.py ├── image ├── __init__.py ├── train_eval.py └── main.py ├── conv ├── __init__.py ├── inits.py └── gmm_conv.py ├── manifolds ├── __init__.py ├── train_eval.py └── main.py ├── LICENSE ├── svgs ├── 129c5b884ff47d80be4d6261a476e9f1.svg ├── 6fccf0465699020081a15631f4a45ae1.svg ├── 796df3d6b2c0926fcde961fd14b100e7.svg ├── aff3fd40bc3e8b5ce3ad3f61175cb17a.svg ├── 9284e17b2f479e052a85e111d9f17ce1.svg ├── e0eef981c0301bb88a01a36ec17cfd0c.svg ├── 1276e542ca3d1d00fd30f0383afb5d08.svg ├── f1cee86600f26eed52126ed72d2dfdd8.svg ├── ab02bf3a35bf706f5ef8c322af45f43e.svg ├── 1c07d8ffda7593d98eda6d17de7db825.svg ├── 72c8e03edc97e002a73695ec08e30d5b.svg └── b0e0a2e33abfab591a8f7e7f6854ae83.svg └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/out/ 3 | *.pyc 4 | data/ 5 | INPUT.md 6 | results.txt 7 | -------------------------------------------------------------------------------- /graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_eval import run 2 | 3 | __all__ = [ 4 | 'run', 5 | ] 6 | -------------------------------------------------------------------------------- /image/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_eval import run 2 | 3 | __all__ = [ 4 | 'run', 5 | ] 6 | -------------------------------------------------------------------------------- /conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .gmm_conv import GMMConv 2 | 3 | __all__ = [ 4 | 'GMMConv', 5 | ] 6 | -------------------------------------------------------------------------------- /manifolds/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_eval import run 2 | 3 | __all__ = [ 4 | 'run', 5 | ] 6 | 
-------------------------------------------------------------------------------- /conv/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | if tensor is not None: 12 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | if tensor is not None: 18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 | _reset(nn) 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Shunwang Gong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /image/train_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def print_info(info): # print one summary line per epoch from the dict built in run() 7 | message = ('Epoch: {}/{}, Duration: {:.3f}s, Train Loss: {:.4f}, ' 8 | 'Test Loss: {:.4f}, Test Acc: {:.4f}').format( 9 | info['current_epoch'], info['epochs'], info['t_duration'], 10 | info['train_loss'], info['test_loss'], info['acc']) 11 | print(message) 12 | 13 | 14 | def run(model, epochs, train_loader, test_loader, optimizer, scheduler, # per epoch: train, step the scheduler, evaluate, print 15 | device): 16 | for epoch in range(1, epochs + 1): 17 | t = time.time() 18 | train_loss = train(model, train_loader, optimizer, device) 19 | t_duration = time.time() - t # times the training pass only, not evaluation 20 | scheduler.step() # stepped once per epoch, after training 21 | acc, test_loss = test(model, test_loader, device) 22 | 23 | info = { 24 | 'train_loss': train_loss, 25 | 'test_loss': test_loss, 26 | 'acc': acc, 27 | 'current_epoch': epoch, 28 | 'epochs': epochs, 29 | 't_duration': t_duration 30 | } 31 | 32 | print_info(info) 33 | 34 | 35 | def train(model, train_loader, optimizer, device): # one epoch; returns the mean per-batch loss 36 | model.train() 37 | 38 | total_loss = 0 39 | for data in train_loader: 40 | optimizer.zero_grad() 41 | data = data.to(device) 42 | loss = F.nll_loss(model(data), data.y) # nll_loss expects the model to output log-probabilities 43 | loss.backward() 44 | optimizer.step() 45 | total_loss += loss.item() 46 | return total_loss / len(train_loader) 47 | 48 | 49 | def test(model,
test_loader, device): 50 | model.eval() 51 | 52 | correct = 0 53 | total_loss = 0 54 | with torch.no_grad(): 55 | for idx, data in enumerate(test_loader): 56 | data = data.to(device) 57 | out = model(data) 58 | total_loss += F.nll_loss(out, data.y).item() 59 | pred = out.max(1)[1] 60 | correct += pred.eq(data.y).sum().item() 61 | return correct / len(test_loader.dataset), total_loss / len(test_loader) 62 | -------------------------------------------------------------------------------- /svgs/129c5b884ff47d80be4d6261a476e9f1.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /manifolds/train_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def print_info(info): 7 | message = ('Epoch: {}/{}, Duration: {:.3f}s, ACC: {:.4f}, ' 8 | 'Train Loss: {:.4f}, Test Loss:{:.4f}').format( 9 | info['current_epoch'], info['epochs'], info['t_duration'], 10 | info['acc'], info['train_loss'], info['test_loss']) 11 | print(message) 12 | 13 | 14 | def run(model, train_loader, test_loader, target, num_nodes, epochs, optimizer, 15 | scheduler, device): 16 | 17 | for epoch in range(1, epochs + 1): 18 | t = time.time() 19 | train_loss = train(model, train_loader, target, optimizer, device) 20 | t_duration = time.time() - t 21 | scheduler.step() 22 | acc, test_loss = test(model, test_loader, num_nodes, target, device) 23 | eval_info = { 24 | 'train_loss': train_loss, 25 | 'test_loss': test_loss, 26 | 'acc': acc, 27 | 'current_epoch': epoch, 28 | 'epochs': epochs, 29 | 't_duration': t_duration 30 | } 31 | 32 | print_info(eval_info) 33 | 34 | 35 | def train(model, train_loader, target, optimizer, device): 36 | model.train() 37 | 38 | total_loss = 0 39 | for idx, data in enumerate(train_loader): 40 | optimizer.zero_grad() 41 
| loss = F.nll_loss(model(data.to(device)), target) 42 | loss.backward() 43 | optimizer.step() 44 | total_loss += loss.item() 45 | return total_loss / len(train_loader) 46 | 47 | 48 | def test(model, test_loader, num_nodes, target, device): 49 | model.eval() 50 | correct = 0 51 | total_loss = 0 52 | n_graphs = 0 53 | with torch.no_grad(): 54 | for idx, data in enumerate(test_loader): 55 | out = model(data.to(device)) 56 | total_loss += F.nll_loss(out, target).item() 57 | pred = out.max(1)[1] 58 | correct += pred.eq(target).sum().item() 59 | n_graphs += data.num_graphs 60 | return correct / (n_graphs * num_nodes), total_loss / len(test_loader) 61 | -------------------------------------------------------------------------------- /svgs/6fccf0465699020081a15631f4a45ae1.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /conv/gmm_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.nn.conv import MessagePassing 4 | from .inits import reset, glorot, zeros 5 | 6 | EPS = 1e-15 7 | 8 | 9 | class GMMConv(MessagePassing): 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | dim, 14 | kernel_size, 15 | bias=True, 16 | **kwargs): 17 | super(GMMConv, self).__init__(aggr='add', **kwargs) 18 | 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.dim = dim 22 | self.kernel_size = kernel_size 23 | 24 | self.lin = torch.nn.Linear(in_channels, 25 | out_channels * kernel_size, 26 | bias=False) 27 | self.mu = Parameter(torch.Tensor(kernel_size, dim)) 28 | self.sigma = Parameter(torch.Tensor(kernel_size, dim)) 29 | if bias: 30 | self.bias = Parameter(torch.Tensor(out_channels)) 31 | else: 32 | self.register_parameter('bias', None) 33 | 34 | self.reset_parameters() 35 | 36 | def 
reset_parameters(self): 37 | glorot(self.mu) 38 | glorot(self.sigma) 39 | zeros(self.bias) 40 | reset(self.lin) 41 | 42 | def forward(self, x, edge_index, pseudo): # pseudo: per-edge pseudo-coordinates, [E, dim] (a 1-D tensor is treated as dim=1) 43 | x = x.unsqueeze(-1) if x.dim() == 1 else x 44 | pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo 45 | 46 | out = self.lin(x).view(-1, self.kernel_size, self.out_channels) # [N, kernel_size, out_channels] 47 | out = self.propagate(edge_index, x=out, pseudo=pseudo) 48 | 49 | if self.bias is not None: 50 | out = out + self.bias 51 | return out 52 | 53 | def message(self, x_j, pseudo): # Gaussian-weighted sum of the per-kernel projections of x_j 54 | (E, D), K = pseudo.size(), self.mu.size(0) 55 | 56 | gaussian = -0.5 * (pseudo.view(E, 1, D) - self.mu.view(1, K, D))**2 57 | gaussian = gaussian / (EPS + self.sigma.view(1, K, D)**2) # diagonal covariance: each dimension divided by sigma_k**2 58 | gaussian = torch.exp(gaussian.sum(dim=-1, keepdim=True)) # [E, K, 1] 59 | 60 | return (x_j * gaussian).sum(dim=1) # weight each kernel's projection and sum over kernels -> [E, out_channels] 61 | 62 | def __repr__(self): 63 | return '{}({}, {}, kernel_size={})'.format(self.__class__.__name__, 64 | self.in_channels, 65 | self.out_channels, 66 | self.kernel_size) 67 | -------------------------------------------------------------------------------- /graph/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torch_geometric.transforms as T # NOTE(review): unused in this script 7 | from torch_geometric.datasets import Planetoid 8 | from torch_geometric.utils import degree 9 | 10 | from graph import run 11 | from conv import GMMConv 12 | 13 | parser = argparse.ArgumentParser(description='Cora citation network') 14 | parser.add_argument('--dataset', type=str, default='Cora') 15 | parser.add_argument('--device_idx', type=int, default=1) # CUDA device index 16 | parser.add_argument('--runs', type=int, default=10) 17 | parser.add_argument('--early_stopping', type=int, default=50) 18 | parser.add_argument('--epochs', type=int, default=3000) 19 | parser.add_argument('--kernel_size', type=int, default=16) 20 | parser.add_argument('--lr',
type=float, default=1e-1) 21 | parser.add_argument('--weight_decay', type=float, default=1e-2) 22 | parser.add_argument('--hidden', type=int, default=16) 23 | parser.add_argument('--dropout', type=float, default=0.5) 24 | args = parser.parse_args() 25 | print(args) 26 | 27 | def transform(data): 28 | row, col = data.edge_index 29 | deg = degree(col, data.num_nodes) 30 | data.edge_attr = torch.stack( 31 | [1 / torch.sqrt(deg[row]), 1 / torch.sqrt(deg[col])], dim=-1) 32 | return data 33 | 34 | 35 | class MoNet(torch.nn.Module): 36 | def __init__(self, dataset): 37 | super(MoNet, self).__init__() 38 | self.conv1 = GMMConv(dataset.num_features, 39 | args.hidden, 40 | dim=2, 41 | kernel_size=args.kernel_size) 42 | self.conv2 = GMMConv(args.hidden, 43 | dataset.num_classes, 44 | dim=2, 45 | kernel_size=args.kernel_size) 46 | 47 | def reset_parameters(self): 48 | self.conv1.reset_parameters() 49 | self.conv2.reset_parameters() 50 | 51 | def forward(self, data): 52 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 53 | x = F.dropout(x, p=args.dropout, training=self.training) 54 | x = F.elu(self.conv1(x, edge_index, edge_attr)) 55 | x = F.dropout(x, p=args.dropout, training=self.training) 56 | x = self.conv2(x, edge_index, edge_attr) 57 | return F.log_softmax(x, dim=1) 58 | 59 | 60 | device = torch.device('cuda', args.device_idx) 61 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 62 | args.dataset) 63 | dataset = Planetoid(args.data_fp, args.dataset) 64 | dataset.transform = transform 65 | run(dataset, MoNet(dataset), args.runs, args.epochs, args.lr, 66 | args.weight_decay, args.early_stopping, device) 67 | -------------------------------------------------------------------------------- /graph/train_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import tensor 6 | from torch.optim import Adam 7 | 8 
| 9 | def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping, # train `runs` times from fresh weights; print mean val loss, test accuracy (mean ± std) and duration 10 | device): 11 | 12 | val_losses, accs, durations = [], [], [] 13 | for _ in range(runs): 14 | data = dataset[0] 15 | data = data.to(device) 16 | 17 | model.to(device).reset_parameters() 18 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) # new optimizer state for every run 19 | 20 | if torch.cuda.is_available(): 21 | torch.cuda.synchronize() # flush queued GPU work before starting the timer 22 | 23 | t_start = time.perf_counter() 24 | 25 | best_val_loss = float('inf') 26 | test_acc = 0 27 | val_loss_history = [] 28 | 29 | for epoch in range(1, epochs + 1): 30 | if epoch == 1500: # fixed schedule: lr is multiplied by 0.1 at epochs 1500 and 2500 31 | for param_group in optimizer.param_groups: 32 | param_group['lr'] = param_group['lr'] * 0.1 33 | elif epoch == 2500: 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = param_group['lr'] * 0.1 36 | 37 | train(model, optimizer, data) 38 | eval_info = evaluate(model, data) 39 | eval_info['epoch'] = epoch 40 | 41 | if eval_info['val_loss'] < best_val_loss: # report the test accuracy of the best-validation epoch 42 | best_val_loss = eval_info['val_loss'] 43 | test_acc = eval_info['test_acc'] 44 | 45 | val_loss_history.append(eval_info['val_loss']) 46 | if early_stopping > 0 and epoch > epochs // 2: # early stopping is only considered in the second half of training 47 | tmp = tensor(val_loss_history[-(early_stopping + 1):-1]) # the previous `early_stopping` losses, excluding the current one 48 | if eval_info['val_loss'] > tmp.mean().item(): # stop once the current loss exceeds their mean 49 | break 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.synchronize() 53 | 54 | t_end = time.perf_counter() 55 | 56 | val_losses.append(best_val_loss) 57 | accs.append(test_acc) 58 | durations.append(t_end - t_start) 59 | 60 | loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations) 61 | 62 | print('Val Loss: {:.4f}, Test Accuracy: {:.3f} ± {:.3f}, Duration: {:.3f}'.
63 | format(loss.mean().item(), 64 | acc.mean().item(), 65 | acc.std().item(), 66 | duration.mean().item())) 67 | 68 | 69 | def train(model, optimizer, data): 70 | model.train() 71 | optimizer.zero_grad() 72 | out = model(data) 73 | loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) 74 | loss.backward() 75 | optimizer.step() 76 | 77 | 78 | def evaluate(model, data): 79 | model.eval() 80 | 81 | with torch.no_grad(): 82 | logits = model(data) 83 | 84 | outs = {} 85 | for key in ['train', 'val', 'test']: 86 | mask = data['{}_mask'.format(key)] 87 | loss = F.nll_loss(logits[mask], data.y[mask]).item() 88 | pred = logits[mask].max(1)[1] 89 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 90 | 91 | outs['{}_loss'.format(key)] = loss 92 | outs['{}_acc'.format(key)] = acc 93 | 94 | return outs 95 | -------------------------------------------------------------------------------- /svgs/796df3d6b2c0926fcde961fd14b100e7.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /svgs/aff3fd40bc3e8b5ce3ad3f61175cb17a.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /image/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import argparse 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | from torch_geometric.datasets import MNISTSuperpixels 9 | import torch_geometric.transforms as T 10 | from torch_geometric.data import DataLoader 11 | from torch_geometric.utils import normalized_cut 12 | from torch_geometric.nn import (graclus, max_pool, global_mean_pool) 13 | from conv import GMMConv 14 | from image 
import run 15 | 16 | parser = argparse.ArgumentParser(description='superpixel MNIST') 17 | parser.add_argument('--dataset', default='MNIST', type=str) 18 | parser.add_argument('--device_idx', default=3, type=int) 19 | parser.add_argument('--kernel_size', default=25, type=int) 20 | parser.add_argument('--lr', type=float, default=1e-4) 21 | parser.add_argument('--lr_decay', type=float, default=0.99) 22 | parser.add_argument('--decay_step', type=int, default=1) 23 | parser.add_argument('--weight_decay', type=float, default=5e-4) 24 | parser.add_argument('--batch_size', type=int, default=64) 25 | parser.add_argument('--epochs', type=int, default=300) 26 | parser.add_argument('--seed', type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 30 | args.dataset) 31 | device = torch.device('cuda', args.device_idx) 32 | 33 | # deterministic 34 | torch.manual_seed(args.seed) 35 | cudnn.benchmark = False 36 | cudnn.deterministic = True 37 | 38 | train_dataset = MNISTSuperpixels(args.data_fp, True, pre_transform=T.Polar()) 39 | test_dataset = MNISTSuperpixels(args.data_fp, False, pre_transform=T.Polar()) 40 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 41 | test_loader = DataLoader(test_dataset, batch_size=64) 42 | 43 | 44 | def normalized_cut_2d(edge_index, pos): 45 | row, col = edge_index 46 | edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1) 47 | return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0)) 48 | 49 | 50 | class MoNet(torch.nn.Module): 51 | def __init__(self, kernel_size): 52 | super(MoNet, self).__init__() 53 | self.conv1 = GMMConv(1, 32, dim=2, kernel_size=kernel_size) 54 | self.conv2 = GMMConv(32, 64, dim=2, kernel_size=kernel_size) 55 | self.conv3 = GMMConv(64, 64, dim=2, kernel_size=kernel_size) 56 | self.fc1 = torch.nn.Linear(64, 128) 57 | self.fc2 = torch.nn.Linear(128, 10) 58 | 59 | def forward(self, data): 60 | data.x = 
F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 61 | weight = normalized_cut_2d(data.edge_index, data.pos) 62 | cluster = graclus(data.edge_index, weight, data.x.size(0)) # cluster assignment used to coarsen the graph 63 | data.edge_attr = None # drop the input (polar) pseudo-coordinates; max_pool's T.Cartesian transform computes new ones 64 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False)) 65 | 66 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 67 | weight = normalized_cut_2d(data.edge_index, data.pos) 68 | cluster = graclus(data.edge_index, weight, data.x.size(0)) 69 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False)) 70 | 71 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 72 | 73 | x = global_mean_pool(data.x, data.batch) # one feature vector per graph in the batch 74 | x = F.elu(self.fc1(x)) 75 | x = F.dropout(x, training=self.training) 76 | return F.log_softmax(self.fc2(x), dim=1) # log-probabilities, as consumed by F.nll_loss in image/train_eval.py 77 | 78 | 79 | model = MoNet(args.kernel_size).to(device) 80 | optimizer = torch.optim.Adam(model.parameters(), 81 | lr=args.lr, 82 | weight_decay=args.weight_decay) 83 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 84 | args.decay_step, 85 | gamma=args.lr_decay) 86 | print(model) 87 | 88 | run(model, args.epochs, train_loader, test_loader, optimizer, scheduler, 89 | device) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Gaussian Mixture Model Convolutional Networks 4 | 5 | This is a Pytorch implementation of Gaussian Mixture Model Convolutional Networks (MoNet) for the tasks of image classification, vertex classification on generic graphs, and dense intrinsic shape correspondence, as described in the paper: 6 | 7 | Monti *et al*, [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402) (CVPR 2017) 8 | 9 | Following the same network architecture provided in the paper, our implementation produces results **comparable** to or **better** than those shown in the paper.
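Note that for the tasks of image classification and shape correspondence, we do not use polar coordinates; instead, we use relative Cartesian coordinates $\mathbf{u}(i,j)$, i.e., the offset between the positions of neighboring nodes $i$ and $j$. This reduces both the computational and the memory cost of data preprocessing. (The superpixel MNIST script still precomputes polar pseudo-coordinates with `T.Polar()` for its first layer; the pooled layers use Cartesian ones.) 10 | 11 | ## Requirements 12 | * [PyTorch](https://pytorch.org/) (1.3.0) 13 | * [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric) (1.3.0) 14 | 15 | ## MoNet 16 | 17 | MoNet uses a local system of pseudo-coordinates $\mathbf{u}(i,j)$ around each vertex $i$ to represent its neighborhood $\mathcal{N}(i)$, and a family of $K$ learnable weighting functions $w_k(\mathbf{u})$, e.g., Gaussian kernels $w_k(\mathbf{u}) = \exp\left(-\frac{1}{2}(\mathbf{u} - \boldsymbol{\mu}_k)^\top \boldsymbol{\Sigma}_k^{-1} (\mathbf{u} - \boldsymbol{\mu}_k)\right)$ with learnable mean $\boldsymbol{\mu}_k$ and covariance $\boldsymbol{\Sigma}_k$ (diagonal in this implementation). The convolution is 18 | $$\mathbf{x}'_i = \sum_{j \in \mathcal{N}(i)} \sum_{k=1}^{K} w_k(\mathbf{u}(i,j)) \, \boldsymbol{\Theta}_k \mathbf{x}_j$$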

19 | 20 | where $\boldsymbol{\Theta}_k$ are the learnable filter weights and $\mathbf{x}_j$ is the feature vector of node $j$. 21 | 22 | We provide an efficient PyTorch implementation of this operator, ``GMMConv``, which is also available in ``PyTorch Geometric``. 23 | 24 | ## Superpixel MNIST Classification 25 | 26 | ``` 27 | python -m image.main 28 | ``` 29 | 30 | 31 | ## Vertex Classification 32 | 33 | ``` 34 | python -m graph.main 35 | ``` 36 | 37 | ## Dense Shape Correspondence 38 | 39 | ``` 40 | python -m manifolds.main 41 | ``` 42 | 43 | 44 | ## Data 45 | 46 | To use your own dataset, simply create a regular Python list holding `torch_geometric.data.Data` objects and specify the following attributes: 47 | 48 | - ``data.x``: Node feature matrix with shape ``[num_nodes, num_node_features]`` 49 | - ``data.edge_index``: Graph connectivity in COO format with shape ``[2, num_edges]`` and type ``torch.long`` 50 | - ``data.edge_attr``: Pseudo-coordinates with shape ``[num_edges, pseudo_coordinates_dim]`` 51 | - ``data.y``: Target to train against 52 | 53 | 54 | ## Cite 55 | 56 | Please cite [this paper](https://arxiv.org/abs/1611.08402) if you use this code in your own work: 57 | 58 | ``` 59 | @inproceedings{monti2017geometric, 60 | title={Geometric deep learning on graphs and manifolds using mixture model cnns}, 61 | author={Monti, Federico and Boscaini, Davide and Masci, Jonathan and Rodola, Emanuele and Svoboda, Jan and Bronstein, Michael M}, 62 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 63 | pages={5115--5124}, 64 | year={2017} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /svgs/9284e17b2f479e052a85e111d9f17ce1.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /manifolds/main.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch_geometric.datasets import FAUST 10 | from torch_geometric.data import DataLoader 11 | import torch_geometric.transforms as T 12 | 13 | from conv import GMMConv 14 | from manifolds import run 15 | 16 | parser = argparse.ArgumentParser(description='shape correspondence') 17 | parser.add_argument('--dataset', type=str, default='FAUST') 18 | parser.add_argument('--device_idx', type=int, default=4) 19 | parser.add_argument('--n_threads', type=int, default=4) 20 | parser.add_argument('--kernel_size', type=int, default=10) 21 | parser.add_argument('--lr', type=float, default=3e-3) 22 | parser.add_argument('--lr_decay', type=float, default=0.99) 23 | parser.add_argument('--decay_step', type=int, default=1) 24 | parser.add_argument('--weight_decay', type=float, default=5e-5) 25 | parser.add_argument('--epochs', type=int, default=500) 26 | parser.add_argument('--seed', type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 30 | args.dataset) 31 | device = torch.device('cuda', args.device_idx) 32 | torch.set_num_threads(args.n_threads) 33 | 34 | # deterministic 35 | torch.manual_seed(args.seed) 36 | cudnn.benchmark = False 37 | cudnn.deterministic = True 38 | 39 | 40 | class Pre_Transform(object): 41 | def __call__(self, data): 42 | data.x = data.pos 43 | data = T.FaceToEdge()(data) 44 | 45 | return data 46 | 47 | 48 | train_dataset = FAUST(args.data_fp, 49 | True, 50 | transform=T.Cartesian(), 51 | pre_transform=Pre_Transform()) 52 | test_dataset = FAUST(args.data_fp, 53 | False, 54 | transform=T.Cartesian(), 55 | pre_transform=Pre_Transform()) 56 | train_loader = DataLoader(train_dataset, 
batch_size=1, shuffle=True) 57 | test_loader = DataLoader(test_dataset, batch_size=1) 58 | d = train_dataset[0] 59 | target = torch.arange(d.num_nodes, dtype=torch.long, device=device) 60 | print(d) 61 | 62 | 63 | class MoNet(nn.Module): 64 | def __init__(self, in_channels, num_classes, kernel_size): 65 | super(MoNet, self).__init__() 66 | 67 | self.fc0 = nn.Linear(in_channels, 16) 68 | self.conv1 = GMMConv(16, 32, dim=3, kernel_size=kernel_size) 69 | self.conv2 = GMMConv(32, 64, dim=3, kernel_size=kernel_size) 70 | self.conv3 = GMMConv(64, 128, dim=3, kernel_size=kernel_size) 71 | self.fc1 = nn.Linear(128, 256) 72 | self.fc2 = nn.Linear(256, num_classes) 73 | 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | self.conv1.reset_parameters() 78 | self.conv2.reset_parameters() 79 | self.conv3.reset_parameters() 80 | nn.init.xavier_uniform_(self.fc0.weight, gain=1) 81 | nn.init.xavier_uniform_(self.fc1.weight, gain=1) 82 | nn.init.xavier_uniform_(self.fc2.weight, gain=1) 83 | nn.init.constant_(self.fc0.bias, 0) 84 | nn.init.constant_(self.fc1.bias, 0) 85 | nn.init.constant_(self.fc2.bias, 0) 86 | 87 | def forward(self, data): 88 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 89 | x = F.elu(self.fc0(x)) 90 | x = F.elu(self.conv1(x, edge_index, edge_attr)) 91 | x = F.elu(self.conv2(x, edge_index, edge_attr)) 92 | x = F.elu(self.conv3(x, edge_index, edge_attr)) 93 | x = F.elu(self.fc1(x)) 94 | x = F.dropout(x, training=self.training) 95 | x = self.fc2(x) 96 | return F.log_softmax(x, dim=1) 97 | 98 | 99 | model = MoNet(d.num_features, d.num_nodes, args.kernel_size).to(device) 100 | print(model) 101 | optimizer = optim.Adam(model.parameters(), 102 | lr=args.lr, 103 | weight_decay=args.weight_decay) 104 | scheduler = optim.lr_scheduler.StepLR(optimizer, 105 | args.decay_step, 106 | gamma=args.lr_decay) 107 | 108 | run(model, train_loader, test_loader, target, d.num_nodes, args.epochs, 109 | optimizer, scheduler, device) 110 | 
-------------------------------------------------------------------------------- /svgs/e0eef981c0301bb88a01a36ec17cfd0c.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /svgs/1276e542ca3d1d00fd30f0383afb5d08.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /svgs/f1cee86600f26eed52126ed72d2dfdd8.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /svgs/ab02bf3a35bf706f5ef8c322af45f43e.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /svgs/1c07d8ffda7593d98eda6d17de7db825.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 
-------------------------------------------------------------------------------- /svgs/72c8e03edc97e002a73695ec08e30d5b.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /svgs/b0e0a2e33abfab591a8f7e7f6854ae83.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | --------------------------------------------------------------------------------