├── docs
├── .nojekyll
├── build
│ └── html
│ │ ├── .nojekyll
│ │ ├── _static
│ │ ├── pygments.css
│ │ ├── down.png
│ │ ├── file.png
│ │ ├── plus.png
│ │ ├── up.png
│ │ ├── minus.png
│ │ ├── comment.png
│ │ ├── ajax-loader.gif
│ │ ├── up-pressed.png
│ │ ├── comment-bright.png
│ │ ├── comment-close.png
│ │ ├── down-pressed.png
│ │ ├── fonts
│ │ │ ├── Lato-Bold.ttf
│ │ │ ├── Lato-Regular.ttf
│ │ │ ├── RobotoSlab-Bold.ttf
│ │ │ ├── Inconsolata-Bold.ttf
│ │ │ ├── RobotoSlab-Regular.ttf
│ │ │ ├── Inconsolata-Regular.ttf
│ │ │ ├── fontawesome-webfont.eot
│ │ │ ├── fontawesome-webfont.ttf
│ │ │ └── fontawesome-webfont.woff
│ │ └── css
│ │ │ ├── custom.css
│ │ │ └── badge_only.css
│ │ └── _sources
│ │ ├── modules
│ │ ├── data.rst.txt
│ │ ├── utils.rst.txt
│ │ ├── datasets.rst.txt
│ │ ├── transforms.rst.txt
│ │ └── nn.rst.txt
│ │ ├── index.rst.txt
│ │ └── notes
│ │ └── installation.rst.txt
├── source
│ ├── _figures
│ │ ├── .gitignore
│ │ ├── build.sh
│ │ └── graph.tex
│ ├── modules
│ │ ├── data.rst
│ │ ├── utils.rst
│ │ ├── datasets.rst
│ │ ├── transforms.rst
│ │ └── nn.rst
│ ├── _static
│ │ └── css
│ │ │ └── custom.css
│ ├── conf.py
│ ├── index.rst
│ └── notes
│ │ └── installation.rst
├── requirements.txt
├── .gitignore
├── index.html
└── Makefile
├── test
├── nn
│ ├── conv
│ │ └── test_spline_conv.py
│ └── pool
│ │ ├── test_consecutive.py
│ │ ├── test_voxel_grid.py
│ │ ├── test_graclus.py
│ │ └── test_topk_pool.py
├── utils
│ ├── test_degree.py
│ ├── test_normalized_cut.py
│ ├── test_to_batch.py
│ ├── test_isolated.py
│ ├── test_one_hot.py
│ ├── test_grid.py
│ ├── test_undirected.py
│ ├── test_softmax.py
│ ├── test_convert.py
│ └── test_loop.py
├── transforms
│ ├── test_random_flip.py
│ ├── test_random_rotate.py
│ ├── test_face_to_edge.py
│ ├── test_radius_graph.py
│ ├── test_linear_transformation.py
│ ├── test_nn_graph.py
│ ├── test_target_indegree.py
│ ├── test_distance.py
│ ├── test_sample_points.py
│ ├── test_compose.py
│ ├── test_spherical.py
│ ├── test_cartesian.py
│ ├── test_two_hop.py
│ ├── test_polar.py
│ └── test_local_cartesian.py
├── data
│ ├── test_collate.py
│ ├── test_batch.py
│ ├── test_split.py
│ └── test_data.py
└── datasets
│ ├── test_planetoid.py
│ └── test_tu_dataset.py
├── MANIFEST.in
├── .coveragerc
├── torch_geometric
├── __init__.py
├── utils
│ ├── num_nodes.py
│ ├── sparse.py
│ ├── normalized_cut.py
│ ├── isolated.py
│ ├── softmax.py
│ ├── one_hot.py
│ ├── to_batch.py
│ ├── undirected.py
│ ├── loop.py
│ ├── degree.py
│ ├── __init__.py
│ ├── grid.py
│ ├── convert.py
│ ├── scatter.py
│ └── metric.py
├── nn
│ ├── prop
│ │ ├── __init__.py
│ │ ├── gcn_prop.py
│ │ └── agnn_prop.py
│ ├── dense
│ │ ├── __init__.py
│ │ ├── diff_pool.py
│ │ └── sage_conv.py
│ ├── repeat.py
│ ├── pool
│ │ ├── graclus.py
│ │ ├── consecutive.py
│ │ ├── global_pool.py
│ │ ├── __init__.py
│ │ ├── pool.py
│ │ ├── sort_pool.py
│ │ ├── voxel_grid.py
│ │ ├── max_pool.py
│ │ ├── avg_pool.py
│ │ ├── set2set.py
│ │ └── topk_pool.py
│ ├── __init__.py
│ ├── conv
│ │ ├── __init__.py
│ │ ├── edge_conv.py
│ │ ├── gin_conv.py
│ │ ├── gcn_conv.py
│ │ ├── graph_conv.py
│ │ ├── sage_conv.py
│ │ ├── nn_conv.py
│ │ └── cheb_conv.py
│ └── inits.py
├── transforms
│ ├── normalize_features.py
│ ├── center.py
│ ├── normalize_scale.py
│ ├── face_to_edge.py
│ ├── add_self_loops.py
│ ├── radius_graph.py
│ ├── nn_graph.py
│ ├── random_scale.py
│ ├── sample_points.py
│ ├── random_shear.py
│ ├── random_rotate.py
│ ├── compose.py
│ ├── random_translate.py
│ ├── two_hop.py
│ ├── random_flip.py
│ ├── constant.py
│ ├── distance.py
│ ├── __init__.py
│ ├── one_hot_degree.py
│ ├── target_indegree.py
│ ├── linear_transformation.py
│ ├── to_dense.py
│ ├── cartesian.py
│ ├── polar.py
│ ├── spherical.py
│ └── local_cartesian.py
├── data
│ ├── makedirs.py
│ ├── download.py
│ ├── extract.py
│ ├── __init__.py
│ ├── dataloader.py
│ ├── batch.py
│ ├── dataset.py
│ ├── in_memory_dataset.py
│ └── data.py
├── read
│ ├── __init__.py
│ ├── txt_array.py
│ ├── ply.py
│ ├── off.py
│ ├── sdf.py
│ └── planetoid.py
└── datasets
│ ├── __init__.py
│ ├── karate.py
│ ├── planetoid.py
│ ├── qm9.py
│ ├── qm7.py
│ ├── faust.py
│ ├── tu_dataset.py
│ ├── mnist_superpixels.py
│ ├── ppi.py
│ ├── coma.py
│ └── modelnet.py
├── setup.cfg
├── .gitignore
├── setup.py
├── LICENSE
├── .travis.yml
└── examples
├── gat.py
├── agnn.py
├── cora.py
├── gcn.py
├── mnist_voxel_grid.py
├── mnist_graclus.py
├── enzymes_topk_pool.py
├── mnist_mpnn.py
└── faust.py
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/build/html/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/nn/conv/test_spline_conv.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_figures/.gitignore:
--------------------------------------------------------------------------------
1 | *.aux
2 | *.log
3 | *.pdf
4 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==1.6.6
2 | sphinx_rtd_theme==0.2.4
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 |
3 | recursive-include examples *
4 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | exclude_lines =
3 | pragma: no cover
4 | raise
5 |
--------------------------------------------------------------------------------
/torch_geometric/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.3.1'
2 |
3 | __all__ = ['__version__']
4 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | build/doctest
2 | build/doctrees
3 | build/html/.buildinfo
4 | build/html/objects.inv
5 |
--------------------------------------------------------------------------------
/docs/build/html/_static/pygments.css:
--------------------------------------------------------------------------------
1 | .highlight .hll { background-color: #ffffcc }
2 | .highlight { background: #ffffff; }
--------------------------------------------------------------------------------
/docs/build/html/_static/down.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/down.png
--------------------------------------------------------------------------------
/docs/build/html/_static/file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/file.png
--------------------------------------------------------------------------------
/docs/build/html/_static/plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/plus.png
--------------------------------------------------------------------------------
/docs/build/html/_static/up.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/up.png
--------------------------------------------------------------------------------
/docs/build/html/_static/minus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/minus.png
--------------------------------------------------------------------------------
/docs/build/html/_static/comment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment.png
--------------------------------------------------------------------------------
/docs/build/html/_static/ajax-loader.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/ajax-loader.gif
--------------------------------------------------------------------------------
/docs/build/html/_static/up-pressed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/up-pressed.png
--------------------------------------------------------------------------------
/docs/build/html/_static/comment-bright.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment-bright.png
--------------------------------------------------------------------------------
/docs/build/html/_static/comment-close.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment-close.png
--------------------------------------------------------------------------------
/docs/build/html/_static/down-pressed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/down-pressed.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
4 | [aliases]
5 | test=pytest
6 |
7 | [tool:pytest]
8 | addopts = --capture=no --cov
9 |
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/Lato-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Lato-Bold.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/Lato-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Lato-Regular.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/Inconsolata-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Inconsolata-Bold.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf
--------------------------------------------------------------------------------
/docs/source/modules/data.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.data
2 | ====================
3 |
4 | .. automodule:: torch_geometric.data
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/torch_geometric/utils/num_nodes.py:
--------------------------------------------------------------------------------
1 | def maybe_num_nodes(edge_index, num_nodes=None):
2 | return edge_index.max().item() + 1 if num_nodes is None else num_nodes
3 |
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/Inconsolata-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Inconsolata-Regular.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/fontawesome-webfont.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.eot
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/fontawesome-webfont.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.ttf
--------------------------------------------------------------------------------
/docs/build/html/_static/fonts/fontawesome-webfont.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.woff
--------------------------------------------------------------------------------
/docs/source/modules/utils.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.utils
2 | =====================
3 |
4 | .. automodule:: torch_geometric.utils
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/torch_geometric/nn/prop/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcn_prop import GCNProp
2 | from .agnn_prop import AGNNProp
3 |
4 | __all__ = [
5 | 'GCNProp',
6 | 'AGNNProp',
7 | ]
8 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/modules/data.rst.txt:
--------------------------------------------------------------------------------
1 | torch_geometric.data
2 | ====================
3 |
4 | .. automodule:: torch_geometric.data
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/docs/source/modules/datasets.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.datasets
2 | ========================
3 |
4 | .. automodule:: torch_geometric.datasets
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/modules/utils.rst.txt:
--------------------------------------------------------------------------------
1 | torch_geometric.utils
2 | =====================
3 |
4 | .. automodule:: torch_geometric.utils
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Redirect
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/source/modules/transforms.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.transforms
2 | ==========================
3 |
4 | .. automodule:: torch_geometric.transforms
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 | build/
4 | dist/
5 | alpha/
6 | .cache/
7 | .eggs/
8 | *.egg-info/
9 | .coverage
10 |
11 | !docs/build/
12 | !torch_geometric/data/
13 | !test/data/
14 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/modules/datasets.rst.txt:
--------------------------------------------------------------------------------
1 | torch_geometric.datasets
2 | ========================
3 |
4 | .. automodule:: torch_geometric.datasets
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/torch_geometric/nn/dense/__init__.py:
--------------------------------------------------------------------------------
1 | from .sage_conv import DenseSAGEConv
2 | from .diff_pool import dense_diff_pool
3 |
4 | __all__ = [
5 | 'DenseSAGEConv',
6 | 'dense_diff_pool',
7 | ]
8 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/modules/transforms.rst.txt:
--------------------------------------------------------------------------------
1 | torch_geometric.transforms
2 | ==========================
3 |
4 | .. automodule:: torch_geometric.transforms
5 | :members:
6 | :undoc-members:
7 |
--------------------------------------------------------------------------------
/docs/source/_figures/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | for filename in *.tex; do
4 | basename=$(basename "$filename" .tex)
5 | pdflatex "$basename.tex"
6 | pdf2svg "$basename.pdf" "$basename.svg"
7 | done
8 |
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Use white for logo background. */
2 | .wy-side-nav-search {
3 | background-color: #fff;
4 | }
5 |
6 | .wy-side-nav-search > div.version {
7 | color: #000;
8 | }
9 |
--------------------------------------------------------------------------------
/docs/build/html/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Use white for logo background. */
2 | .wy-side-nav-search {
3 | background-color: #fff;
4 | }
5 |
6 | .wy-side-nav-search > div.version {
7 | color: #000;
8 | }
9 |
--------------------------------------------------------------------------------
/test/utils/test_degree.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import degree
3 |
4 |
5 | def test_degree():
6 | row = torch.tensor([0, 1, 0, 2, 0])
7 | assert degree(row).tolist() == [3, 1, 1]
8 |
--------------------------------------------------------------------------------
/torch_geometric/nn/repeat.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import itertools
3 |
4 |
5 | def repeat(src, length):
6 | if isinstance(src, numbers.Number):
7 | src = list(itertools.repeat(src, length))
8 | return src
9 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/graclus.py:
--------------------------------------------------------------------------------
1 | from torch_cluster import graclus_cluster
2 |
3 |
4 | def graclus(edge_index, weight=None, num_nodes=None):
5 | row, col = edge_index
6 | return graclus_cluster(row, col, weight, num_nodes)
7 |
--------------------------------------------------------------------------------
/torch_geometric/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv import * # noqa
2 | from .prop import * # noqa
3 | from .pool import * # noqa
4 | from .dense import * # noqa
5 | from .meta import MetaLayer
6 |
7 | __all__ = [
8 | 'MetaLayer',
9 | ]
10 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | SPHINXBUILD = sphinx-build
2 | SPHINXPROJ = pytorch_geometric
3 | SOURCEDIR = source
4 | BUILDDIR = build
5 |
6 | .PHONY: help Makefile
7 |
8 | %: Makefile
9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)"
10 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/normalize_features.py:
--------------------------------------------------------------------------------
1 | class NormalizeFeatures(object):
2 | def __call__(self, data):
3 | data.x = data.x / data.x.sum(1, keepdim=True).clamp(min=1)
4 | return data
5 |
6 | def __repr__(self):
7 | return '{}()'.format(self.__class__.__name__)
8 |
--------------------------------------------------------------------------------
/torch_geometric/data/makedirs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import errno
4 |
5 |
6 | def makedirs(path):
7 | try:
8 | os.makedirs(osp.expanduser(osp.normpath(path)))
9 | except OSError as e:
10 | if e.errno != errno.EEXIST or not osp.isdir(path):  # re-raise unless the directory already exists
11 | raise e
12 |
--------------------------------------------------------------------------------
/torch_geometric/utils/sparse.py:
--------------------------------------------------------------------------------
1 | from torch_sparse import coalesce
2 |
3 |
4 | def dense_to_sparse(tensor):
5 | index = tensor.nonzero()
6 | value = tensor[index[:, 0], index[:, 1]]  # values at the nonzero positions
7 | index = index.t().contiguous()
8 | index, value = coalesce(index, value, tensor.size(0), tensor.size(1))
9 | return index, value
10 |
--------------------------------------------------------------------------------
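A minimal usage sketch of dense_to_sparse on a tiny, made-up adjacency matrix (assumes torch and torch_sparse are installed and the value lookup as written above):

import torch
from torch_geometric.utils.sparse import dense_to_sparse

# Hypothetical 3 x 3 weighted adjacency matrix.
adj = torch.tensor([[0., 1., 0.], [1., 0., 2.], [0., 2., 0.]])

edge_index, edge_attr = dense_to_sparse(adj)
# edge_index -> tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
# edge_attr  -> tensor([1., 1., 2., 2.])
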
/torch_geometric/utils/normalized_cut.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.utils import degree
2 |
3 |
4 | def normalized_cut(edge_index, edge_attr, num_nodes=None):
5 | row, col = edge_index
6 | deg = 1 / degree(row, num_nodes, edge_attr.dtype)
7 | deg = deg[row] + deg[col]
8 | cut = edge_attr * deg
9 | return cut
10 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/consecutive.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def consecutive_cluster(src):
5 | unique, inv = torch.unique(src, sorted=True, return_inverse=True)
6 | perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
7 | perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)
8 | return inv, perm
9 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/center.py:
--------------------------------------------------------------------------------
1 | class Center(object):
2 | def __call__(self, data):
3 | pos = data.pos
4 |
5 | mean = data.pos.mean(dim=0).view(1, -1)
6 | data.pos = pos - mean
7 |
8 | return data
9 |
10 | def __repr__(self):
11 | return '{}()'.format(self.__class__.__name__)
12 |
--------------------------------------------------------------------------------
/test/nn/pool/test_consecutive.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn.pool.consecutive import consecutive_cluster
3 |
4 |
5 | def test_consecutive_cluster():
6 | src = torch.tensor([8, 2, 10, 15, 100, 1, 100])
7 |
8 | out, perm = consecutive_cluster(src)
9 | assert out.tolist() == [2, 1, 3, 4, 5, 0, 5]
10 | assert perm.tolist() == [5, 1, 0, 2, 3, 6]
11 |
--------------------------------------------------------------------------------
/torch_geometric/utils/isolated.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .num_nodes import maybe_num_nodes
4 | from .loop import remove_self_loops
5 |
6 |
7 | def contains_isolated_nodes(edge_index, num_nodes=None):
8 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
9 | (row, _), _ = remove_self_loops(edge_index)
10 | return torch.unique(row).size(0) < num_nodes
11 |
--------------------------------------------------------------------------------
/docs/source/_figures/graph.tex:
--------------------------------------------------------------------------------
1 | \documentclass{standalone}
2 |
3 | \usepackage{tikz}
4 |
5 | \begin{document}
6 |
7 | \begin{tikzpicture}
8 | \node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0};
9 | \node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1};
10 | \node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2};
11 |
12 | \path[draw] (0) -- (1);
13 | \path[draw] (1) -- (2);
14 | \end{tikzpicture}
15 |
16 | \end{document}
17 |
--------------------------------------------------------------------------------
/torch_geometric/utils/softmax.py:
--------------------------------------------------------------------------------
1 | from torch_scatter import scatter_max, scatter_add
2 |
3 | from .num_nodes import maybe_num_nodes
4 |
5 |
6 | def softmax(src, index, num_nodes=None):
7 | num_nodes = maybe_num_nodes(index, num_nodes)
8 |
9 | out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
10 | out = out.exp()
11 | out = out / scatter_add(out, index, dim=0, dim_size=num_nodes)[index]
12 |
13 | return out
14 |
--------------------------------------------------------------------------------
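A small usage sketch of the grouped softmax above; the scores and index vector are invented, and the import follows the package layout shown in this dump:

import torch
from torch_geometric.utils import softmax

src = torch.tensor([1., 2., 3., 4.])   # unnormalized scores, e.g. attention logits
index = torch.tensor([0, 0, 1, 1])     # group (e.g. target node) of each score

out = softmax(src, index)
# Scores are normalized within each group, so out[0] + out[1] == 1 and out[2] + out[3] == 1.
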
/test/transforms/test_random_flip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import RandomFlip
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_random_flip():
7 | assert RandomFlip(axis=0).__repr__() == 'RandomFlip(axis=0, p=0.5)'
8 |
9 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float)
10 | data = Data(pos=pos)
11 |
12 | out = RandomFlip(axis=0, p=1)(data).pos.tolist()
13 | assert out == [[1, 1], [3, 0], [-2, -1]]
14 |
--------------------------------------------------------------------------------
/test/transforms/test_random_rotate.py:
--------------------------------------------------------------------------------
1 | from pytest import approx
2 | import torch
3 | from torch_geometric.transforms import RandomRotate
4 | from torch_geometric.data import Data
5 |
6 |
7 | def test_random_rotate():
8 | assert RandomRotate(-180).__repr__() == 'RandomRotate((-180, 180))'
9 |
10 | pos = torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
11 | data = Data(pos=pos)
12 |
13 | out = RandomRotate((90, 90))(data).pos.view(-1).tolist()
14 | assert approx(out) == [0, 1, -1, 0]
15 |
--------------------------------------------------------------------------------
/torch_geometric/data/download.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from six.moves import urllib
3 |
4 | from .makedirs import makedirs
5 |
6 |
7 | def download_url(url, folder, log=True):
8 | if log:
9 | print('Downloading', url)
10 |
11 | makedirs(folder)
12 |
13 | data = urllib.request.urlopen(url)
14 | filename = url.rpartition('/')[2]
15 | path = osp.join(folder, filename)
16 |
17 | with open(path, 'wb') as f:
18 | f.write(data.read())
19 |
20 | return path
21 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/normalize_scale.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.transforms import Center
2 |
3 |
4 | class NormalizeScale(object):
5 | def __init__(self):
6 | self.center = Center()
7 |
8 | def __call__(self, data):
9 | data = self.center(data)
10 |
11 | scale = (1 / data.pos.abs().max()) * 0.999999
12 | data.pos = data.pos * scale
13 |
14 | return data
15 |
16 | def __repr__(self):
17 | return '{}()'.format(self.__class__.__name__)
18 |
--------------------------------------------------------------------------------
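A hedged sketch of applying NormalizeScale to a toy point cloud; the positions are invented, and NormalizeScale is assumed to be exported from torch_geometric.transforms in the same way Center is imported above:

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeScale

data = Data(pos=torch.tensor([[0., 0.], [2., 0.], [0., 4.]]))
data = NormalizeScale()(data)
# Positions are first mean-centered, then scaled so that their largest absolute
# coordinate is just below 1 (the 0.999999 factor above).
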
/torch_geometric/data/extract.py:
--------------------------------------------------------------------------------
1 | import tarfile
2 | import zipfile
3 |
4 |
5 | def maybe_log(path, log=True):
6 | if log:
7 | print('Extracting', path)
8 |
9 |
10 | def extract_tar(path, folder, mode='r:gz', log=True):
11 | maybe_log(path, log)
12 | with tarfile.open(path, mode) as f:
13 | f.extractall(folder)
14 |
15 |
16 | def extract_zip(path, folder, log=True):
17 | maybe_log(path, log)
18 | with zipfile.ZipFile(path, 'r') as f:
19 | f.extractall(folder)
20 |
--------------------------------------------------------------------------------
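A hedged sketch of how download_url and extract_zip are typically combined when fetching a dataset archive; the URL and folder below are placeholders, not a real dataset location:

from torch_geometric.data import download_url, extract_zip

url = 'https://example.com/dataset.zip'   # placeholder URL
raw_dir = 'data/raw'                      # placeholder target folder

path = download_url(url, raw_dir)   # creates raw_dir if needed and returns the saved file path
extract_zip(path, raw_dir)          # unpacks the archive into the same folder
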
/test/utils/test_normalized_cut.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import normalized_cut
3 |
4 |
5 | def test_normalized_cut():
6 | row = torch.LongTensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4])
7 | col = torch.LongTensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3])
8 | edge_attr = torch.Tensor([3, 3, 6, 3, 6, 1, 3, 2, 1, 2])
9 | expected_output = [4, 4, 5, 2.5, 5, 1, 2.5, 2, 1, 2]
10 |
11 | output = normalized_cut(torch.stack([row, col], dim=0), edge_attr)
12 | assert output.tolist() == expected_output
13 |
--------------------------------------------------------------------------------
/torch_geometric/read/__init__.py:
--------------------------------------------------------------------------------
1 | from .txt_array import parse_txt_array, read_txt_array
2 | from .tu import read_tu_data
3 | from .planetoid import read_planetoid_data
4 | from .ply import read_ply
5 | from .sdf import read_sdf, parse_sdf
6 | from .off import read_off, parse_off
7 |
8 | __all__ = [
9 | 'parse_txt_array',
10 | 'read_txt_array',
11 | 'read_tu_data',
12 | 'read_planetoid_data',
13 | 'read_ply',
14 | 'read_sdf',
15 | 'parse_sdf',
16 | 'read_off',
17 | 'parse_off',
18 | ]
19 |
--------------------------------------------------------------------------------
/test/transforms/test_face_to_edge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import FaceToEdge
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_face_to_edge():
7 | assert FaceToEdge().__repr__() == 'FaceToEdge()'
8 |
9 | face = torch.tensor([[0, 0], [1, 1], [2, 3]])
10 | data = Data()
11 | data.face = face
12 |
13 | row, col = FaceToEdge()(data).edge_index
14 | assert row.tolist() == [0, 0, 0, 1, 1, 1, 2, 2, 3, 3]
15 | assert col.tolist() == [1, 2, 3, 0, 2, 3, 0, 1, 0, 1]
16 |
--------------------------------------------------------------------------------
/test/utils/test_to_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_batch
3 |
4 |
5 | def test_to_batch():
6 | x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
7 | batch = torch.tensor([0, 0, 1, 2, 2, 2])
8 |
9 | x, num_nodes = to_batch(x, batch)
10 | expected = [
11 | [[1, 2], [3, 4], [0, 0]],
12 | [[5, 6], [0, 0], [0, 0]],
13 | [[7, 8], [9, 10], [11, 12]],
14 | ]
15 | assert x.tolist() == expected
16 | assert num_nodes.tolist() == [2, 1, 3]
17 |
--------------------------------------------------------------------------------
/torch_geometric/read/txt_array.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def parse_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None):
5 | src = [[float(x) for x in line.split(sep)[start:end]] for line in src]
6 | src = torch.tensor(src, dtype=dtype).squeeze()
7 | return src
8 |
9 |
10 | def read_txt_array(path, sep=None, start=0, end=None, dtype=None, device=None):
11 | with open(path, 'r') as f:
12 | src = f.read().split('\n')[:-1]
13 | return parse_txt_array(src, sep, start, end, dtype, device)
14 |
--------------------------------------------------------------------------------
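A minimal sketch of parse_txt_array on an in-memory list of lines (the content is made up; read_txt_array does the same after splitting a file into lines):

import torch
from torch_geometric.read import parse_txt_array

lines = ['0 1 2', '3 4 5']                      # whitespace-separated numbers, one row per line
out = parse_txt_array(lines, dtype=torch.float)
# tensor([[0., 1., 2.],
#         [3., 4., 5.]])
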
/test/transforms/test_radius_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import RadiusGraph
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_radius_graph():
7 | assert RadiusGraph(1).__repr__() == 'RadiusGraph(r=1)'
8 |
9 | pos = [[0, 0], [1, 0], [2, 0], [0, 1], [-2, 0], [0, -2]]
10 | pos = torch.tensor(pos, dtype=torch.float)
11 | data = Data(pos=pos)
12 |
13 | edge_index = RadiusGraph(1)(data).edge_index
14 | assert edge_index.tolist() == [[0, 0, 1, 1, 2, 3], [1, 3, 0, 2, 1, 0]]
15 |
--------------------------------------------------------------------------------
/test/utils/test_isolated.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import contains_isolated_nodes
3 |
4 |
5 | def test_contains_isolated_nodes():
6 | row = torch.tensor([0, 1, 0])
7 | col = torch.tensor([1, 0, 0])
8 |
9 | assert not contains_isolated_nodes(torch.stack([row, col], dim=0))
10 | assert contains_isolated_nodes(torch.stack([row, col], dim=0), 3)
11 |
12 | row = torch.tensor([0, 1, 2, 0])
13 | col = torch.tensor([1, 0, 2, 0])
14 | assert contains_isolated_nodes(torch.stack([row, col], dim=0))
15 |
--------------------------------------------------------------------------------
/test/nn/pool/test_voxel_grid.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | # from torch_geometric.nn.pool.voxel_grid import voxel_grid
3 |
4 | # def test_voxel_grid():
5 | # pos = torch.Tensor([[0.5, 0.5], [1.2, 1.5], [0.2, 1.7], [0.3, 0.4],
6 | # [0.5, 0.5], [1.2, 1.5], [0.2, 1.7], [0.3, 0.4]])
7 | # batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
8 |
9 | # cluster, batch = voxel_grid(pos, size=1, start=0, end=2, batch=batch)
10 | # assert cluster.tolist() == [0, 2, 1, 0, 3, 5, 4, 3]
11 | # assert batch.tolist() == [0, 0, 0, 1, 1, 1]
12 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcn_conv import GCNConv
2 | from .cheb_conv import ChebConv
3 | from .sage_conv import SAGEConv
4 | from .graph_conv import GraphConv
5 | from .gat_conv import GATConv
6 | from .gin_conv import GINConv
7 | from .spline_conv import SplineConv
8 | from .nn_conv import NNConv
9 | from .edge_conv import EdgeConv
10 |
11 | __all__ = [
12 | 'GCNConv',
13 | 'ChebConv',
14 | 'SAGEConv',
15 | 'GraphConv',
16 | 'GATConv',
17 | 'GINConv',
18 | 'SplineConv',
19 | 'NNConv',
20 | 'EdgeConv',
21 | ]
22 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/face_to_edge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_undirected
3 |
4 |
5 | class FaceToEdge(object):
6 | def __call__(self, data):
7 | face = data.face
8 |
9 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
10 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
11 |
12 | data.edge_index = edge_index
13 | data.face = None
14 | return data
15 |
16 | def __repr__(self):
17 | return '{}()'.format(self.__class__.__name__)
18 |
--------------------------------------------------------------------------------
/torch_geometric/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .download import download_url
2 | from .extract import extract_tar, extract_zip
3 | from .data import Data
4 | from .batch import Batch
5 | from .dataset import Dataset
6 | from .in_memory_dataset import InMemoryDataset
7 | from .dataloader import DataLoader, DenseDataLoader
8 |
9 | __all__ = [
10 | 'download_url',
11 | 'extract_tar',
12 | 'extract_zip',
13 | 'Data',
14 | 'Batch',
15 | 'Dataset',
16 | 'DataLoader',
17 | 'DenseDataLoader',
18 | 'InMemoryDataset',
19 | ]
20 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/add_self_loops.py:
--------------------------------------------------------------------------------
1 | from torch_sparse import coalesce
2 | from torch_geometric.utils import add_self_loops
3 |
4 |
5 | class AddSelfLoops(object):
6 | def __call__(self, data):
7 | edge_index = data.edge_index
8 | num_nodes = data.num_nodes
9 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
10 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
11 | data.edge_index = edge_index
12 | return data
13 |
14 | def __repr__(self):
15 | return '{}()'.format(self.__class__.__name__)
16 |
--------------------------------------------------------------------------------
/torch_geometric/read/ply.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from plyfile import PlyData
3 | from torch_geometric.data import Data
4 |
5 |
6 | def read_ply(path):
7 | with open(path, 'rb') as f:
8 | data = PlyData.read(f)
9 |
10 | pos = [torch.tensor(data['vertex'][axis]) for axis in ['x', 'y', 'z']]
11 | pos = torch.stack(pos, dim=-1)
12 |
13 | faces = data['face']['vertex_indices']
14 | faces = [torch.tensor(face, dtype=torch.long) for face in faces]
15 | face = torch.stack(faces, dim=-1)
16 |
17 | data = Data(pos=pos)
18 | data.face = face
19 |
20 | return data
21 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/global_pool.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.utils import scatter_
2 |
3 |
4 | def global_add_pool(x, batch, size=None):
5 | size = batch[-1].item() + 1 if size is None else size
6 | return scatter_('add', x, batch, dim_size=size)
7 |
8 |
9 | def global_mean_pool(x, batch, size=None):
10 | size = batch[-1].item() + 1 if size is None else size
11 | return scatter_('mean', x, batch, dim_size=size)
12 |
13 |
14 | def global_max_pool(x, batch, size=None):
15 | size = batch[-1].item() + 1 if size is None else size
16 | return scatter_('max', x, batch, dim_size=size)
17 |
--------------------------------------------------------------------------------
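A minimal sketch of the global pooling helpers with an invented two-graph batch (assumes torch and torch_scatter are installed):

import torch
from torch_geometric.nn import global_mean_pool

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])  # three nodes, two features each
batch = torch.tensor([0, 0, 1])                   # nodes 0-1 belong to graph 0, node 2 to graph 1

out = global_mean_pool(x, batch)
# tensor([[2., 3.],
#         [5., 6.]])  -> one row per graph
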
/test/utils/test_one_hot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import one_hot
3 |
4 |
5 | def test_one_hot():
6 | src = torch.LongTensor([1, 0, 3])
7 | expected_output = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]
8 |
9 | assert one_hot(src).tolist() == expected_output
10 | assert one_hot(src, 4).tolist() == expected_output
11 |
12 | src = torch.LongTensor([[1, 0], [0, 1], [2, 0]])
13 | expected_output = [[0, 1, 0, 1, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0]]
14 | assert one_hot(src).tolist() == expected_output
15 | assert one_hot(src, torch.tensor([3, 2])).tolist() == expected_output
16 |
--------------------------------------------------------------------------------
/test/transforms/test_linear_transformation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import LinearTransformation
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_linear_transformation():
7 | matrix = torch.tensor([[2, 0], [0, 2]], dtype=torch.float)
8 | transform = LinearTransformation(matrix)
9 | out = 'LinearTransformation([[2.0, 0.0], [0.0, 2.0]])'
10 | assert transform.__repr__() == out
11 |
12 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float)
13 | data = Data(pos=pos)
14 |
15 | out = transform(data).pos.tolist()
16 | assert out == [[-2, 2], [-6, 0], [4, -2]]
17 |
--------------------------------------------------------------------------------
/test/utils/test_grid.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.utils import grid
2 |
3 |
4 | def test_grid():
5 | (row, col), pos = grid(height=3, width=2)
6 |
7 | expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]
8 | expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5]
9 | expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
10 | expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]
11 |
12 | expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]]
13 |
14 | assert row.tolist() == expected_row
15 | assert col.tolist() == expected_col
16 | assert pos.tolist() == expected_pos
17 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .karate import KarateClub
2 | from .tu_dataset import TUDataset
3 | from .planetoid import Planetoid
4 | from .mnist_superpixels import MNISTSuperpixels
5 | from .faust import FAUST
6 | from .qm7 import QM7
7 | from .qm9 import QM9
8 | from .ppi import PPI
9 | from .shapenet import ShapeNet
10 | from .modelnet import ModelNet
11 | from .coma import CoMA
12 |
13 | __all__ = [
14 | 'KarateClub',
15 | 'TUDataset',
16 | 'Planetoid',
17 | 'MNISTSuperpixels',
18 | 'FAUST',
19 | 'QM7',
20 | 'QM9',
21 | 'PPI',
22 | 'ShapeNet',
23 | 'ModelNet',
24 | 'CoMA',
25 | ]
26 |
--------------------------------------------------------------------------------
/test/nn/pool/test_graclus.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | # from torch_geometric.nn.pool import graclus
3 | # from torch_geometric.data import Batch
4 |
5 | # def test_graclus_pool():
6 | # x = torch.Tensor([[1, 6], [2, 5], [3, 4], [4, 3], [5, 2], [6, 1]])
7 | # row = torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
8 | # col = torch.LongTensor([1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4])
9 | # edge_index = torch.stack([row, col], dim=0)
10 | # batch = torch.LongTensor([0, 0, 0, 1, 1, 1])
11 | # data = Batch(x=x, edge_index=edge_index, batch=batch)
12 |
13 | # data = graclus_pool(data)
14 |
15 | # assert data.num_nodes == 4
16 |
--------------------------------------------------------------------------------
/test/data/test_collate.py:
--------------------------------------------------------------------------------
1 | # def test_collate_to_set():
2 | # x1 = torch.tensor([1, 2, 3], dtype=torch.float)
3 | # e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
4 | # x2 = torch.tensor([1, 2], dtype=torch.float)
5 | # e2 = torch.tensor([[0, 1], [1, 0]])
6 |
7 | # data, slices = collate_to_set([Data(x1, e1), Data(x2, e2)])
8 |
9 | # assert len(data) == 2
10 | # assert data.x.tolist() == [1, 2, 3, 1, 2]
11 | # data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1], [1, 0, 2, 1, 1, 0]]
12 | # assert len(slices.keys()) == 2
13 | # assert slices['x'].tolist() == [0, 3, 5]
14 | # assert slices['edge_index'].tolist() == [0, 4, 6]
15 |
--------------------------------------------------------------------------------
/test/transforms/test_nn_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import NNGraph
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_nn_graph():
7 | assert NNGraph().__repr__() == 'NNGraph(k=6)'
8 |
9 | pos = [[0, 0], [1, 0], [2, 0], [0, 1], [-2, 0], [0, -2]]
10 | pos = torch.tensor(pos, dtype=torch.float)
11 | data = Data(pos=pos)
12 |
13 | row, col = NNGraph(2)(data).edge_index
14 | expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5]
15 | expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1]
16 |
17 | assert row.tolist() == expected_row
18 | assert col.tolist() == expected_col
19 |
--------------------------------------------------------------------------------
/test/transforms/test_target_indegree.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import TargetIndegree
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_target_indegree():
7 | assert TargetIndegree().__repr__() == 'TargetIndegree(cat=True)'
8 |
9 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
10 | data = Data(edge_index=edge_index)
11 |
12 | out = TargetIndegree()(data).edge_attr.tolist()
13 | assert out == [[1], [0.5], [0.5], [1]]
14 |
15 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float)
16 | out = TargetIndegree()(data).edge_attr.tolist()
17 | assert out == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]]
18 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/__init__.py:
--------------------------------------------------------------------------------
1 | from .global_pool import global_add_pool, global_mean_pool, global_max_pool
2 | from .set2set import Set2Set
3 |
4 | from .max_pool import max_pool, max_pool_x
5 | from .avg_pool import avg_pool, avg_pool_x
6 | from .graclus import graclus
7 | from .voxel_grid import voxel_grid
8 | from .topk_pool import TopKPooling
9 | from .sort_pool import sort_pool
10 |
11 | __all__ = [
12 | 'global_add_pool',
13 | 'global_mean_pool',
14 | 'global_max_pool',
15 | 'Set2Set',
16 | 'max_pool',
17 | 'max_pool_x',
18 | 'avg_pool',
19 | 'avg_pool_x',
20 | 'graclus',
21 | 'voxel_grid',
22 | 'TopKPooling',
23 | 'sort_pool',
24 | ]
25 |
--------------------------------------------------------------------------------
/test/utils/test_undirected.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import is_undirected, to_undirected
3 |
4 |
5 | def test_is_undirected():
6 | row = torch.tensor([0, 1, 0])
7 | col = torch.tensor([1, 0, 0])
8 |
9 | assert is_undirected(torch.stack([row, col], dim=0))
10 |
11 | row = torch.tensor([0, 1, 1])
12 | col = torch.tensor([1, 0, 2])
13 |
14 | assert not is_undirected(torch.stack([row, col], dim=0))
15 |
16 |
17 | def test_to_undirected():
18 | row = torch.tensor([0, 1, 1])
19 | col = torch.tensor([1, 0, 2])
20 |
21 | edge_index = to_undirected(torch.stack([row, col], dim=0))
22 | assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
23 |
--------------------------------------------------------------------------------
/test/transforms/test_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import Distance
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_distance():
7 | assert Distance().__repr__() == 'Distance(cat=True)'
8 |
9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
11 | data = Data(edge_index=edge_index, pos=pos)
12 |
13 | out = Distance()(data).edge_attr.tolist()
14 | assert out == [[1], [1], [2], [2]]
15 |
16 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float)
17 | out = Distance()(data).edge_attr.tolist()
18 | assert out == [[1, 1], [1, 1], [1, 2], [1, 2]]
19 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/pool.py:
--------------------------------------------------------------------------------
1 | from torch_sparse import coalesce
2 | from torch_scatter import scatter_mean
3 | from torch_geometric.utils import remove_self_loops
4 |
5 |
6 | def pool_edge(cluster, edge_index, edge_attr=None):
7 | num_nodes = cluster.size(0)
8 | edge_index = cluster[edge_index.view(-1)].view(2, -1)
9 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
10 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
11 | num_nodes)
12 | return edge_index, edge_attr
13 |
14 |
15 | def pool_batch(perm, batch):
16 | return batch[perm]
17 |
18 |
19 | def pool_pos(cluster, pos):
20 | return scatter_mean(pos, cluster, dim=0)
21 |
--------------------------------------------------------------------------------
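A hedged sketch of pool_edge, which remaps edges onto cluster indices, drops the resulting self-loops, and sums duplicate edges; the cluster assignment is invented:

import torch
from torch_geometric.nn.pool.pool import pool_edge

cluster = torch.tensor([0, 0, 1, 1])   # nodes 0,1 -> cluster 0; nodes 2,3 -> cluster 1
edge_index = torch.tensor([[0, 1, 2, 3], [2, 3, 0, 1]])
edge_attr = torch.tensor([1., 1., 2., 2.])

edge_index, edge_attr = pool_edge(cluster, edge_index, edge_attr)
# edge_index -> tensor([[0, 1], [1, 0]]), edge_attr -> tensor([2., 4.])
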
/test/data/test_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data, Batch
3 |
4 |
5 | def test_batch():
6 | x1 = torch.tensor([1, 2, 3], dtype=torch.float)
7 | e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
8 | x2 = torch.tensor([1, 2], dtype=torch.float)
9 | e2 = torch.tensor([[0, 1], [1, 0]])
10 |
11 | data = Batch.from_data_list([Data(x1, e1), Data(x2, e2)])
12 |
13 | assert data.__repr__() == 'Batch(batch=[5], edge_index=[2, 6], x=[5])'
14 | assert len(data) == 3
15 | assert data.x.tolist() == [1, 2, 3, 1, 2]
16 | assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
17 | assert data.batch.tolist() == [0, 0, 0, 1, 1]
18 | assert data.num_graphs == 2
19 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/modules/nn.rst.txt:
--------------------------------------------------------------------------------
1 | torch_geometric.nn
2 | ==================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | Convolution Layers
8 | ------------------
9 |
10 | .. automodule:: torch_geometric.nn.meta
11 | :members:
12 | :undoc-members:
13 |
14 | .. automodule:: torch_geometric.nn.conv
15 | :members:
16 | :undoc-members:
17 |
18 | Propagation Layers
19 | ------------------
20 |
21 | .. automodule:: torch_geometric.nn.prop
22 | :members:
23 | :undoc-members:
24 |
25 | Pooling Layers
26 | --------------
27 |
28 | .. automodule:: torch_geometric.nn.pool
29 | :members:
30 | :undoc-members:
31 |
32 | .. automodule:: torch_geometric.nn.dense.diff_pool
33 | :members:
34 | :undoc-members:
35 |
--------------------------------------------------------------------------------
/test/transforms/test_sample_points.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import SamplePoints
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_sample_points():
7 | assert SamplePoints(1024).__repr__() == 'SamplePoints(1024)'
8 |
9 | pos = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]
10 | data = Data(pos=torch.tensor(pos, dtype=torch.float))
11 | data.face = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long)
12 |
13 | data = SamplePoints(8)(data)
14 | pos = data.pos
15 | assert pos[:, 0].min().item() >= 0 and pos[:, 0].max().item() <= 1
16 | assert pos[:, 1].min().item() >= 0 and pos[:, 1].max().item() <= 1
17 | assert pos[:, 2].abs().sum().item() == 0
18 | assert 'face' not in data
19 |
--------------------------------------------------------------------------------
/test/transforms/test_compose.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric.transforms as T
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_compose():
7 | transform = T.Compose([T.Cartesian(), T.TargetIndegree()])
8 | assert transform.__repr__() == ('Compose([\n'
9 | ' Cartesian(cat=True),\n'
10 | ' TargetIndegree(cat=True),\n'
11 | '])')
12 |
13 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
14 | edge_index = torch.tensor([[0, 1], [1, 2]])
15 | data = Data(edge_index=edge_index, pos=pos)
16 |
17 | out = transform(data).edge_attr.tolist()
18 | assert out == [[0.75, 0.5, 1], [1, 0.5, 1]]
19 |
--------------------------------------------------------------------------------
/test/transforms/test_spherical.py:
--------------------------------------------------------------------------------
1 | from pytest import approx
2 | import torch
3 | from torch_geometric.transforms import Spherical
4 | from torch_geometric.data import Data
5 |
6 |
7 | def test_spherical():
8 | assert Spherical().__repr__() == 'Spherical(cat=True)'
9 |
10 | pos = torch.tensor([[0, 0, 0], [0, 1, 1]], dtype=torch.float)
11 | edge_index = torch.tensor([[0, 1], [1, 0]])
12 | data = Data(edge_index=edge_index, pos=pos)
13 |
14 | out = Spherical()(data).edge_attr.view(-1).tolist()
15 | assert approx(out) == [1, 0.25, 0, 1, 0.75, 1]
16 |
17 | data.edge_attr = torch.tensor([1, 1], dtype=torch.float)
18 | out = Spherical()(data).edge_attr.view(-1).tolist()
19 | assert approx(out) == [1, 1, 0.25, 0, 1, 1, 0.75, 1]
20 |
--------------------------------------------------------------------------------
/test/transforms/test_cartesian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import Cartesian
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_cartesian():
7 | assert Cartesian().__repr__() == 'Cartesian(cat=True)'
8 |
9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
11 | data = Data(edge_index=edge_index, pos=pos)
12 |
13 | out = Cartesian()(data).edge_attr.tolist()
14 | assert out == [[0.75, 0.5], [0.25, 0.5], [1, 0.5], [0, 0.5]]
15 |
16 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float)
17 | out = Cartesian()(data).edge_attr.tolist()
18 | assert out == [[1, 0.75, 0.5], [1, 0.25, 0.5], [1, 1, 0.5], [1, 0, 0.5]]
19 |
--------------------------------------------------------------------------------
/test/transforms/test_two_hop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import TwoHop
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_two_hop():
7 | assert TwoHop().__repr__() == 'TwoHop()'
8 |
9 | edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
10 | edge_attr = torch.tensor([1, 2, 3, 1, 2, 3], dtype=torch.float)
11 | data = Data(edge_index=edge_index, edge_attr=edge_attr)
12 |
13 | data = TwoHop()(data)
14 | edge_index, edge_attr = data.edge_index, data.edge_attr
15 |
16 | assert edge_index.tolist() == [[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
17 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]
18 | assert edge_attr.tolist() == [1, 2, 3, 1, 0, 0, 2, 0, 0, 3, 0, 0]
19 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/sort_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_batch
3 |
4 |
5 | def sort_pool(x, batch, k):
6 | x, _ = x.sort(dim=-1)
7 |
8 | fill_value = x.min().item() - 1
9 | batch_x, num_nodes = to_batch(x, batch, fill_value)
10 | B, N, D = batch_x.size()
11 |
12 | _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)
13 | arange = torch.arange(B, dtype=torch.long, device=perm.device) * N
14 | perm = perm + arange.view(-1, 1)
15 |
16 | batch_x = batch_x.view(B * N, D)
17 | batch_x = batch_x[perm]
18 | batch_x = batch_x.view(B, N, D)
19 |
20 | batch_x = batch_x[:, :k].contiguous()
21 | batch_x[batch_x == fill_value] = 0
22 | x = batch_x.view(B, k * D)
23 |
24 | return x
25 |
--------------------------------------------------------------------------------
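A hedged usage sketch of sort_pool (a SortPooling-style readout) on a toy batch; the shapes are what matter here, the values are random:

import torch
from torch_geometric.nn import sort_pool

x = torch.randn(7, 16)                       # 7 nodes with 16 features each
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # two graphs with 3 and 4 nodes

out = sort_pool(x, batch, k=3)   # keep the k highest-ranked nodes per graph
# out.size() -> torch.Size([2, 48]), i.e. [num_graphs, k * num_features]
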
/test/transforms/test_polar.py:
--------------------------------------------------------------------------------
1 | from pytest import approx
2 | import torch
3 | from torch_geometric.transforms import Polar
4 | from torch_geometric.data import Data
5 |
6 |
7 | def test_polar():
8 | assert Polar().__repr__() == 'Polar(cat=True)'
9 |
10 | pos = torch.tensor([[-1, 0], [0, 0], [0, 2]], dtype=torch.float)
11 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
12 | data = Data(edge_index=edge_index, pos=pos)
13 |
14 | out = Polar()(data).edge_attr.view(-1).tolist()
15 | assert approx(out) == [0.5, 0, 0.5, 0.5, 1, 0.25, 1, 0.75]
16 |
17 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float)
18 | out = Polar()(data).edge_attr.view(-1).tolist()
19 | assert approx(out) == [1, 0.5, 0, 1, 0.5, 0.5, 1, 1, 0.25, 1, 1, 0.75]
20 |
--------------------------------------------------------------------------------
/torch_geometric/utils/one_hot.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def one_hot(src, num_classes=None, dtype=None):
5 | src = src.to(torch.long)
6 | src = src.unsqueeze(-1) if src.dim() == 1 else src
7 | assert src.dim() == 2
8 |
9 | if num_classes is None:
10 | num_classes = src.max(dim=0)[0] + 1
11 | elif isinstance(num_classes, int) or isinstance(num_classes, float):
12 | num_classes = torch.tensor(int(num_classes))
13 |
14 | if src.size(1) > 1:
15 | zero = torch.tensor([0], device=src.device)
16 | src = src + torch.cat([zero, torch.cumsum(num_classes, 0)[:-1]])
17 |
18 | size = src.size(0), num_classes.sum()
19 | out = torch.zeros(size, dtype=dtype, device=src.device)
20 | out.scatter_(1, src, 1)
21 | return out
22 |
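23 |
24 | # A minimal usage sketch; the labels below are illustrative values only and
25 | # are encoded over num_classes=3 classes.
26 | if __name__ == '__main__':
27 |     src = torch.tensor([0, 2, 1])
28 |     out = one_hot(src, num_classes=3)
29 |     print(out.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]] (as floats)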
--------------------------------------------------------------------------------
/test/transforms/test_local_cartesian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import LocalCartesian
3 | from torch_geometric.data import Data
4 |
5 |
6 | def test_local_cartesian():
7 | assert LocalCartesian().__repr__() == 'LocalCartesian(cat=True)'
8 |
9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
11 | data = Data(edge_index=edge_index, pos=pos)
12 |
13 | out = LocalCartesian()(data).edge_attr.tolist()
14 | assert out == [[1, 0.5], [0.25, 0.5], [1, 0.5], [0, 0.5]]
15 |
16 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float)
17 | out = LocalCartesian()(data).edge_attr.tolist()
18 | assert out == [[1, 1, 0.5], [1, 0.25, 0.5], [1, 1, 0.5], [1, 0, 0.5]]
19 |
--------------------------------------------------------------------------------
/torch_geometric/utils/to_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter_add
3 |
4 |
5 | def to_batch(x, batch, fill_value=0):
6 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
7 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
8 | cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)], dim=0)
9 |
10 | index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
11 | index = (index - cum_nodes[batch]) + (batch * max_num_nodes)
12 |
13 | size = [batch_size * max_num_nodes] + list(x.size())[1:]
14 | batch_x = x.new_full(size, fill_value)
15 | batch_x[index] = x
16 | size = [batch_size, max_num_nodes] + list(x.size())[1:]
17 | batch_x = batch_x.view(size)
18 |
19 | return batch_x, num_nodes
20 |
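21 |
22 | # A minimal usage sketch; the values are illustrative only: five nodes from
23 | # two graphs (of sizes 2 and 3) are packed into a dense, zero-padded tensor.
24 | if __name__ == '__main__':
25 |     x = torch.tensor([[1], [2], [3], [4], [5]], dtype=torch.float)
26 |     batch = torch.tensor([0, 0, 1, 1, 1])
27 |     batch_x, num_nodes = to_batch(x, batch)
28 |     print(batch_x.size())      # torch.Size([2, 3, 1])
29 |     print(num_nodes.tolist())  # [2, 3]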
--------------------------------------------------------------------------------
/test/nn/pool/test_topk_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj
3 |
4 |
5 | def test_topk():
6 | x = torch.tensor([2, 4, 5, 6, 2, 9], dtype=torch.float)
7 | batch = torch.tensor([0, 0, 1, 1, 1, 1])
8 |
9 | perm = topk(x, 0.5, batch)
10 |
11 | assert perm.tolist() == [1, 5, 3]
12 | assert x[perm].tolist() == [4, 9, 6]
13 |
14 |
15 | def test_filter_adj():
16 | edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],
17 | [1, 3, 0, 2, 1, 3, 0, 2]])
18 | edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
19 | perm = torch.tensor([2, 3])
20 |
21 | edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm)
22 | assert edge_index.tolist() == [[0, 1], [1, 0]]
23 | assert edge_attr.tolist() == [6, 8]
24 |
--------------------------------------------------------------------------------
/test/utils/test_softmax.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 |
3 | import torch
4 | from pytest import approx
5 | from torch.autograd import Variable
6 | from torch_geometric.utils import softmax
7 |
8 |
9 | def test_softmax():
10 |     row = torch.LongTensor([0, 0, 0, 1, 1, 1])
11 |     col = torch.LongTensor([1, 2, 3, 0, 2, 3])
12 |     edge_attr = torch.Tensor([0, 1, 2, 0, 1, 2])
13 |     e = [exp(1), exp(2), exp(3)]
14 |     e_sum = e[0] + e[1] + e[2]
15 |     e = [e[0] / e_sum, e[1] / e_sum, e[2] / e_sum]
16 |
17 |     output = softmax(edge_attr, row)
18 |     assert output.tolist() == approx(e + e)
19 |
20 |     output = softmax(Variable(edge_attr), row)
21 |     assert output.data.tolist() == approx(e + e)
22 |
23 |     output = softmax(edge_attr, col)
24 |     assert output.tolist() == approx([1, 0.5, 0.5, 1, 0.5, 0.5])
25 |
26 |     edge_attr = torch.Tensor([[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]])
27 |     output = softmax(edge_attr, row)
28 |     assert output.view(-1).tolist() == approx(
29 |         [e[0], e[0], e[1], e[1], e[2], e[2]] * 2)
30 |
31 |     output = softmax(Variable(edge_attr), row)
32 |     assert output.data.view(-1).tolist() == approx(
33 |         [e[0], e[0], e[1], e[1], e[2], e[2]] * 2)
34 |
--------------------------------------------------------------------------------
/test/utils/test_convert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.sparse
3 | import networkx
4 | from torch_geometric.utils import to_scipy_sparse_matrix, to_networkx
5 |
6 |
7 | def test_to_scipy_sparse_matrix():
8 | row = torch.tensor([0, 1, 0])
9 | col = torch.tensor([1, 0, 0])
10 |
11 | adj = to_scipy_sparse_matrix(torch.stack([row, col], dim=0))
12 |     assert isinstance(adj, scipy.sparse.coo_matrix)
13 | assert adj.shape == (2, 2)
14 | assert adj.row.tolist() == row.tolist()
15 | assert adj.col.tolist() == col.tolist()
16 | assert adj.data.tolist() == [1, 1, 1]
17 |
18 |
19 | def test_to_networkx():
20 | row = torch.tensor([0, 1, 0])
21 | col = torch.tensor([1, 0, 0])
22 |
23 | adj = to_networkx(torch.stack([row, col], dim=0))
24 | assert networkx.to_numpy_matrix(adj).tolist() == [[1, 1], [1, 0]]
25 |
--------------------------------------------------------------------------------
/docs/source/modules/nn.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.nn
2 | ==================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | Sparse Convolutional Layers
8 | ---------------------------
9 |
10 | .. automodule:: torch_geometric.nn.conv
11 | :members:
12 | :undoc-members:
13 |
14 | .. automodule:: torch_geometric.nn.meta
15 | :members:
16 | :undoc-members:
17 |
18 | Dense Convolutional Layers
19 | --------------------------
20 |
21 | Global Pooling Layers
22 | ---------------------
23 |
24 | Sparse Hierarchical Pooling Layers
25 | ----------------------------------
26 |
27 | .. automodule:: torch_geometric.nn.pool
28 | :members:
29 | :undoc-members:
30 |
31 | Dense Hierarchical Pooling Layers
32 | ---------------------------------
33 |
34 | .. automodule:: torch_geometric.nn.dense.diff_pool
35 | :members:
36 | :undoc-members:
37 |
--------------------------------------------------------------------------------
/torch_geometric/utils/undirected.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import coalesce
3 |
4 | from .num_nodes import maybe_num_nodes
5 |
6 |
7 | def is_undirected(edge_index, num_nodes=None):
8 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
9 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
10 | undirected_edge_index = to_undirected(edge_index, num_nodes=num_nodes)
11 | return edge_index.size(1) == undirected_edge_index.size(1)
12 |
13 |
14 | def to_undirected(edge_index, num_nodes=None):
15 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
16 |
17 | row, col = edge_index
18 | row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
19 | edge_index = torch.stack([row, col], dim=0)
20 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
21 |
22 | return edge_index
23 |
--------------------------------------------------------------------------------
/torch_geometric/nn/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | stdv = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-stdv, stdv)
8 |
9 |
10 | def glorot(tensor):
11 |     if tensor is not None:
12 |         stdv = math.sqrt(6.0 / (tensor.size(0) + tensor.size(1)))
13 |         tensor.data.uniform_(-stdv, stdv)
14 |
15 |
16 | def zeros(tensor):
17 | if tensor is not None:
18 | tensor.data.fill_(0)
19 |
20 |
21 | def ones(tensor):
22 | if tensor is not None:
23 | tensor.data.fill_(1)
24 |
25 |
26 | def reset(nn):
27 | def _reset(item):
28 | if hasattr(item, 'reset_parameters'):
29 | item.reset_parameters()
30 |
31 | if hasattr(nn, 'children'):
32 | for item in nn.children():
33 | _reset(item)
34 | else:
35 | _reset(nn)
36 |
--------------------------------------------------------------------------------
/torch_geometric/read/off.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.read import parse_txt_array
3 | from torch_geometric.data import Data
4 |
5 |
6 | def parse_off(src):
7 |     # Some files omit the newline after 'OFF'; counts follow on the same line.
8 | if src[0] == 'OFF':
9 | src = src[1:]
10 | else:
11 | src[0] = src[0][3:]
12 |
13 | num_nodes, num_faces = [int(item) for item in src[0].split()[:2]]
14 |
15 | pos = parse_txt_array(src[1:1 + num_nodes])
16 |
17 | face = src[1 + num_nodes:1 + num_nodes + num_faces]
18 | face = parse_txt_array(face, start=1, dtype=torch.long).t().contiguous()
19 |
20 | data = Data(pos=pos)
21 | data.face = face
22 |
23 | return data
24 |
25 |
26 | def read_off(path):
27 | with open(path, 'r') as f:
28 | src = f.read().split('\n')[:-1]
29 | return parse_off(src)
30 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/radius_graph.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 |
3 | import torch
4 | import scipy.spatial
5 | from torch_geometric.utils import remove_self_loops
6 |
7 |
8 | class RadiusGraph(object):
9 | def __init__(self, r):
10 | self.r = r
11 |
12 | def __call__(self, data):
13 | pos = data.pos
14 | assert not pos.is_cuda
15 |
16 | tree = scipy.spatial.cKDTree(pos)
17 | indices = tree.query_ball_tree(tree, self.r)
18 |
19 | row, col = [], []
20 | for i, neighbors in enumerate(indices):
21 | row += repeat(i, len(neighbors))
22 | col += neighbors
23 | edge_index = torch.tensor([row, col])
24 | edge_index, _ = remove_self_loops(edge_index)
25 |
26 | data.edge_index = edge_index
27 | return data
28 |
29 | def __repr__(self):
30 | return '{}(r={})'.format(self.__class__.__name__, self.r)
31 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/karate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import networkx as nx
3 | from torch_geometric.data import InMemoryDataset, Data
4 |
5 |
6 | class KarateClub(InMemoryDataset):
7 | def __init__(self, transform=None):
8 | super(KarateClub, self).__init__('.', transform, None, None)
9 |
10 | G = nx.karate_club_graph()
11 | adj = nx.to_scipy_sparse_matrix(G).tocoo()
12 | row = torch.from_numpy(adj.row).to(torch.long)
13 | col = torch.from_numpy(adj.col).to(torch.long)
14 | edge_index = torch.stack([row, col], dim=0)
15 | data = Data(edge_index=edge_index)
16 | data.x = torch.eye(data.num_nodes, dtype=torch.float)
17 | self.data, self.slices = self.collate([data])
18 |
19 | def _download(self):
20 | return
21 |
22 | def _process(self):
23 | return
24 |
25 | def __repr__(self):
26 | return '{}()'.format(self.__class__.__name__)
27 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/voxel_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_cluster import grid_cluster
3 |
4 | from ..repeat import repeat
5 |
6 |
7 | def voxel_grid(pos, batch, size, start=None, end=None):
8 | pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
9 | num_nodes, dim = pos.size()
10 |
11 | size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)
12 |
13 | pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
14 | size = size + [1]
15 | start = None if start is None else start + [0]
16 | end = None if end is None else end + [batch.max().item()]
17 |
18 | size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
19 | if start is not None:
20 | start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
21 | if end is not None:
22 | end = torch.tensor(end, dtype=pos.dtype, device=pos.device)
23 |
24 | return grid_cluster(pos, size, start, end)
25 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/nn_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.spatial
3 | from torch_geometric.utils import to_undirected
4 |
5 |
6 | class NNGraph(object):
7 | def __init__(self, k=6):
8 | self.k = k
9 |
10 | def __call__(self, data):
11 | pos = data.pos
12 | assert not pos.is_cuda
13 |
14 | row = torch.arange(pos.size(0), dtype=torch.long)
15 | row = row.view(-1, 1).repeat(1, self.k).view(-1)
16 |
17 | _, col = scipy.spatial.cKDTree(pos).query(pos, self.k + 1)
18 | col = torch.tensor(col)[:, 1:].contiguous().view(-1)
19 | mask = col < pos.size(0)
20 | edge_index = torch.stack([row[mask], col[mask]], dim=0)
21 | edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
22 |
23 | data.edge_index = edge_index
24 | return data
25 |
26 | def __repr__(self):
27 | return '{}(k={})'.format(self.__class__.__name__, self.k)
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | __version__ = '0.3.1'
4 | url = 'https://github.com/rusty1s/pytorch_geometric'
5 |
6 | install_requires = [
7 | 'numpy',
8 | 'scipy',
9 | 'networkx',
10 | 'plyfile',
11 | ]
12 | setup_requires = ['pytest-runner']
13 | tests_require = ['pytest', 'pytest-cov']
14 |
15 | setup(
16 | name='torch_geometric',
17 | version=__version__,
18 | description='Geometric Deep Learning Extension Library for PyTorch',
19 | author='Matthias Fey',
20 | author_email='matthias.fey@tu-dortmund.de',
21 | url=url,
22 | download_url='{}/archive/{}.tar.gz'.format(url, __version__),
23 | keywords=[
24 | 'pytorch', 'geometric-deep-learning', 'graph', 'mesh',
25 | 'neural-networks', 'spline-cnn'
26 | ],
27 | install_requires=install_requires,
28 | setup_requires=setup_requires,
29 | tests_require=tests_require,
30 | packages=find_packages())
31 |
--------------------------------------------------------------------------------
/torch_geometric/utils/loop.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .num_nodes import maybe_num_nodes
4 |
5 |
6 | def contains_self_loops(edge_index):
7 | row, col = edge_index
8 | mask = row == col
9 | return mask.sum().item() > 0
10 |
11 |
12 | def remove_self_loops(edge_index, edge_attr=None):
13 | row, col = edge_index
14 | mask = row != col
15 | edge_attr = edge_attr if edge_attr is None else edge_attr[mask]
16 | mask = mask.unsqueeze(0).expand_as(edge_index)
17 | edge_index = edge_index[mask].view(2, -1)
18 |
19 | return edge_index, edge_attr
20 |
21 |
22 | def add_self_loops(edge_index, num_nodes=None):
23 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
24 |
25 | dtype, device = edge_index.dtype, edge_index.device
26 | loop = torch.arange(0, num_nodes, dtype=dtype, device=device)
27 | loop = loop.unsqueeze(0).repeat(2, 1)
28 | edge_index = torch.cat([edge_index, loop], dim=1)
29 |
30 | return edge_index
31 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/random_scale.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | class RandomScale(object):
5 | r"""Scales node positions by a randomly sampled factor :math:`s` within a
6 | given interval, e.g., resulting in the transformation matrix
7 |
8 | .. math::
9 | \begin{bmatrix}
10 | s & 0 & 0 \\
11 | 0 & s & 0 \\
12 | 0 & 0 & s \\
13 | \end{bmatrix}
14 |
15 | for three-dimensional positions.
16 |
17 | Args:
18 |         scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then
19 |             scale is randomly sampled from the range
20 |             :math:`a \leq \mathrm{scale} \leq b`.
21 | """
22 |
23 | def __init__(self, scales):
24 | self.scales = scales
25 |
26 | def __call__(self, data):
27 | scale = random.uniform(*self.scales)
28 | data.pos = data.pos * scale
29 | return data
30 |
31 | def __repr__(self):
32 | return '{}({})'.format(self.__class__.__name__, self.scales)
33 |
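34 |
35 | # A minimal usage sketch; the positions are illustrative only. The degenerate
36 | # interval (2, 2) always samples the factor 2, so positions are doubled.
37 | if __name__ == '__main__':
38 |     import torch
39 |     from torch_geometric.data import Data
40 |
41 |     pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])
42 |     data = RandomScale(scales=(2, 2))(Data(pos=pos))
43 |     print(data.pos.tolist())  # [[-2.0, 0.0], [0.0, 0.0], [4.0, 0.0]]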
--------------------------------------------------------------------------------
/torch_geometric/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from torch.utils.data.dataloader import default_collate
3 |
4 | from torch_geometric.data import Batch
5 |
6 |
7 | class DataLoader(torch.utils.data.DataLoader):
8 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
9 | super(DataLoader, self).__init__(
10 | dataset,
11 | batch_size,
12 | shuffle,
13 | collate_fn=lambda batch: Batch.from_data_list(batch),
14 | **kwargs)
15 |
16 |
17 | class DenseDataLoader(torch.utils.data.DataLoader):
18 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
19 | def dense_collate(data_list):
20 | batch = Batch()
21 | for key in data_list[0].keys:
22 | batch[key] = default_collate([d[key] for d in data_list])
23 | return batch
24 |
25 | super(DenseDataLoader, self).__init__(
26 | dataset, batch_size, shuffle, collate_fn=dense_collate, **kwargs)
27 |
--------------------------------------------------------------------------------
/torch_geometric/utils/degree.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .num_nodes import maybe_num_nodes
4 |
5 |
6 | def degree(index, num_nodes=None, dtype=None):
7 | """Computes the degree of a given index tensor.
8 |
9 | Args:
10 | index (LongTensor): Source or target indices of edges.
11 | num_nodes (int, optional): The number of nodes in :attr:`index`.
12 | (default: :obj:`None`)
13 |         dtype (:obj:`torch.dtype`, optional): The desired data type of the
14 |             returned tensor.
15 |
16 | :rtype: :class:`Tensor`
17 |
18 | .. testsetup::
19 |
20 | import torch
21 |
22 | .. testcode::
23 |
24 | from torch_geometric.utils import degree
25 | index = torch.tensor([0, 1, 0, 2, 0])
26 | output = degree(index)
27 | print(output)
28 |
29 | .. testoutput::
30 |
31 | tensor([3., 1., 1.])
32 | """
33 |
34 | num_nodes = maybe_num_nodes(index, num_nodes)
35 | out = torch.zeros((num_nodes), dtype=dtype, device=index.device)
36 | return out.scatter_add_(0, index, out.new_ones((index.size(0))))
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 Matthias Fey
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/sample_points.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SamplePoints(object):
5 | def __init__(self, num):
6 | self.num = num
7 |
8 | def __call__(self, data):
9 | pos, face = data.pos, data.face
10 | assert not pos.is_cuda and not face.is_cuda
11 | assert pos.size(1) == 3 and face.size(0) == 3
12 |
13 | area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]])
14 | area = torch.sqrt((area**2).sum(dim=-1)) / 2
15 |
16 | prob = area / area.sum()
17 | sample = torch.multinomial(prob, self.num, replacement=True)
18 | face = face[:, sample]
19 |
20 | frac = torch.rand(self.num, 2)
21 | mask = frac.sum(dim=-1) > 1
22 | frac[mask] = 1 - frac[mask]
23 |
24 | pos_sampled = pos[face[0]]
25 | pos_sampled += frac[:, :1] * (pos[face[1]] - pos[face[0]])
26 | pos_sampled += frac[:, 1:] * (pos[face[2]] - pos[face[0]])
27 |
28 | data.pos = pos_sampled
29 | data.face = None
30 | return data
31 |
32 | def __repr__(self):
33 | return '{}({})'.format(self.__class__.__name__, self.num)
34 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sphinx_rtd_theme
3 | import doctest
4 | from torch_geometric import __version__
5 |
6 | extensions = [
7 | 'sphinx.ext.autodoc',
8 | 'sphinx.ext.doctest',
9 | 'sphinx.ext.intersphinx',
10 | 'sphinx.ext.mathjax',
11 | 'sphinx.ext.napoleon',
12 | 'sphinx.ext.viewcode',
13 | 'sphinx.ext.githubpages',
14 | ]
15 |
16 | source_suffix = '.rst'
17 | master_doc = 'index'
18 |
19 | author = 'Matthias Fey'
20 | project = 'pytorch_geometric'
21 | copyright = '{}, {}'.format(datetime.datetime.now().year, author)
22 |
23 | version = 'master ({})'.format(__version__)
24 | release = 'master'
25 |
26 | html_theme = 'sphinx_rtd_theme'
27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
28 |
29 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE
30 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)}
31 |
32 | html_theme_options = {
33 | 'collapse_navigation': False,
34 | 'display_version': True,
35 | 'logo_only': True,
36 | }
37 |
38 | html_logo = '_static/img/logo.svg'
39 | html_static_path = ['_static']
40 | html_context = {'css_files': ['_static/css/custom.css']}
41 |
42 | add_module_names = False
43 |
--------------------------------------------------------------------------------
/test/data/test_split.py:
--------------------------------------------------------------------------------
1 | # x1 = torch.tensor([1, 2, 3], dtype=torch.float)
2 | # edge_index1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
3 | # x2 = torch.tensor([1, 2], dtype=torch.float)
4 | # edge_index2 = torch.tensor([[0, 1], [1, 0]])
5 | # data1, data2 = Data(x1, edge_index1), Data(x2, edge_index2)
6 | # dataset, slices = collate_to_set([data1, data2])
7 |
8 | # def test_data_from_set():
9 | # data = data_from_set(dataset, slices, 0)
10 | # assert len(data) == 2
11 | # assert data.x.tolist() == x1.tolist()
12 | # assert data.edge_index.tolist() == edge_index1.tolist()
13 |
14 | # data = data_from_set(dataset, slices, 1)
15 | # assert len(data) == 2
16 | # assert data.x.tolist() == x2.tolist()
17 | # assert data.edge_index.tolist() == edge_index2.tolist()
18 |
19 | # def test_split_set():
20 | # output, output_slices = split_set(dataset, slices, torch.tensor([0]))
21 |
22 | # assert len(output) == 2
23 | # assert output.x.tolist() == x1.tolist()
24 | # assert output.edge_index.tolist() == edge_index1.tolist()
25 |
26 | # assert len(output_slices.keys()) == 2
27 | # assert output_slices['x'].tolist() == [0, 3]
28 | # assert output_slices['edge_index'].tolist() == [0, 4]
29 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/random_shear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import LinearTransformation
3 |
4 |
5 | class RandomShear(object):
6 | r"""Shears node positions by randomly sampled factors :math:`s` within a
7 | given interval, e.g., resulting in the transformation matrix
8 |
9 | .. math::
10 | \begin{bmatrix}
11 | 1 & s_{xy} & s_{xz} \\
12 | s_{yx} & 1 & s_{yz} \\
13 |             s_{zx} & s_{zy} & 1 \\
14 | \end{bmatrix}
15 |
16 | for three-dimensional positions.
17 |
18 | Args:
19 | shear (float or int): maximum shearing factor defining the range
20 | :math:`(-\mathrm{shear}, +\mathrm{shear})` to sample from.
21 | """
22 |
23 | def __init__(self, shear):
24 | self.shear = abs(shear)
25 |
26 | def __call__(self, data):
27 | dim = data.pos.size(1)
28 |
29 | matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear)
30 | eye = torch.arange(dim, dtype=torch.long)
31 | matrix[eye, eye] = 1
32 |
33 | return LinearTransformation(matrix)(data)
34 |
35 | def __repr__(self):
36 | return '{}({})'.format(self.__class__.__name__, self.shear)
37 |
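38 |
39 | # A minimal usage sketch; the positions are illustrative only: shear random
40 | # 3D positions with off-diagonal factors drawn from (-0.1, 0.1).
41 | if __name__ == '__main__':
42 |     from torch_geometric.data import Data
43 |
44 |     data = RandomShear(0.1)(Data(pos=torch.rand(4, 3)))
45 |     print(data.pos.size())  # torch.Size([4, 3])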
--------------------------------------------------------------------------------
/torch_geometric/nn/dense/diff_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def dense_diff_pool(x, adj, s, mask=None):
5 | r"""Differentiable pooling operator from the `"Hierarchical Graph
6 | Representation Learning with Differentiable Pooling"
7 | `_ paper
8 |
9 | .. math::
10 | \mathbf{X}^{\prime} &= \mathbf{S} \cdot \mathbf{X}
11 |
12 | \mathbf{A}^{\prime} &= \mathbf{S}^{\top} \cdot \mathbf{A} \cdot
13 | \mathbf{S}
14 |
15 | based on dense learned assignments :math:`\mathbf{S}`.
16 | """
17 |
18 | x = x.unsqueeze(0) if x.dim() == 2 else x
19 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
20 | s = s.unsqueeze(0) if s.dim() == 2 else s
21 |
22 | batch_size, num_nodes, _ = x.size()
23 |
24 | s = torch.softmax(s, dim=-1)
25 |
26 | if mask is not None:
27 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
28 | x, s = x * mask, s * mask
29 |
30 | out = torch.matmul(s.transpose(1, 2), x)
31 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
32 |
33 | reg = adj - torch.matmul(s, s.transpose(1, 2))
34 | reg = torch.norm(reg, p=2)
35 | reg = reg / adj.numel()
36 |
37 | return out, out_adj, reg
38 |
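39 |
40 | # A minimal usage sketch; the tensors are illustrative only: a single graph
41 | # with 5 nodes and 16 features is pooled into 2 clusters via a (here random)
42 | # assignment matrix.
43 | if __name__ == '__main__':
44 |     x, adj = torch.rand(1, 5, 16), torch.rand(1, 5, 5)
45 |     s = torch.rand(1, 5, 2)
46 |     out, out_adj, reg = dense_diff_pool(x, adj, s)
47 |     print(out.size(), out_adj.size())  # [1, 2, 16] and [1, 2, 2]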
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/rusty1s/pytorch_geometric
2 |
3 | PyTorch Geometric documentation
4 | ===============================
5 |
6 | `PyTorch Geometric `_ is a geometric deep learning extension library for `PyTorch `_.
7 |
8 | It consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning `_, from a variety of published papers.
9 | In addition, it provides an easy-to-use mini-batch loader, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs and on 3D meshes or point clouds.
10 |
11 | .. toctree::
12 | :glob:
13 | :maxdepth: 1
14 | :caption: Notes
15 |
16 | notes/installation
17 | notes/introduction
18 | notes/create_dataset
19 |
20 | .. toctree::
21 | :glob:
22 | :maxdepth: 1
23 | :caption: Package Reference
24 |
25 | modules/nn
26 | modules/data
27 | modules/datasets
28 | modules/transforms
29 | modules/utils
30 |
31 | Indices and Tables
32 | ==================
33 |
34 | * :ref:`genindex`
35 | * :ref:`modindex`
36 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | sudo: required
3 | dist: trusty
4 | matrix:
5 | include:
6 | - python: 2.7
7 | - python: 3.5
8 | - python: 3.6
9 | addons:
10 | apt:
11 | sources:
12 | - ubuntu-toolchain-r-test
13 | packages:
14 | - gcc-4.9
15 | - g++-4.9
16 | before_install:
17 | - export CC="gcc-4.9"
18 | - export CXX="g++-4.9"
19 | install:
20 | - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi
21 | - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi
22 | - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi
23 | - pip install pycodestyle
24 | - pip install flake8
25 | - pip install torch-scatter torch-sparse torch-cluster torch-spline-conv
26 | - pip install codecov
27 | script:
28 | - pycodestyle .
29 | - flake8 .
30 | - python setup.py install
31 | - python setup.py test
32 | - cd docs && pip install -r requirements.txt && make clean && make html && make doctest && cd ..
33 | after_success:
34 | - codecov
35 | notifications:
36 | email: false
37 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/random_rotate.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | import math
4 |
5 | import torch
6 | from torch_geometric.transforms import LinearTransformation
7 |
8 |
9 | class RandomRotate(object):
10 | def __init__(self, degrees, axis=0):
11 | if isinstance(degrees, numbers.Number):
12 | degrees = (-abs(degrees), abs(degrees))
13 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
14 | self.degrees = degrees
15 | self.axis = axis
16 |
17 | def __call__(self, data):
18 | degree = math.pi * random.uniform(*self.degrees) / 180.0
19 | sin, cos = math.sin(degree), math.cos(degree)
20 |
21 | if data.pos.size(1) == 2:
22 | matrix = [[cos, sin], [-sin, cos]]
23 | else:
24 | if self.axis == 0:
25 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
26 | elif self.axis == 1:
27 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
28 | else:
29 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]
30 | return LinearTransformation(torch.tensor(matrix))(data)
31 |
32 | def __repr__(self):
33 | return '{}({})'.format(self.__class__.__name__, self.degrees)
34 |
--------------------------------------------------------------------------------
/docs/build/html/_sources/index.rst.txt:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/rusty1s/pytorch_geometric
2 |
3 | PyTorch Geometric documentation
4 | ===============================
5 |
6 | `PyTorch Geometric `_ is a geometric deep learning extension library for `PyTorch `_.
7 |
8 | It consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning `_, from a variety of published papers.
9 | In addition, it provides an easy-to-use mini-batch loader, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs and on 3D meshes or point clouds.
10 |
11 | .. toctree::
12 | :glob:
13 | :maxdepth: 1
14 | :caption: Notes
15 |
16 | notes/installation
17 | notes/introduction
18 | notes/create_dataset
19 |
20 | .. toctree::
21 | :glob:
22 | :maxdepth: 1
23 | :caption: Package Reference
24 |
25 | modules/nn
26 | modules/data
27 | modules/datasets
28 | modules/transforms
29 | modules/utils
30 |
31 | Indices and Tables
32 | ==================
33 |
34 | * :ref:`genindex`
35 | * :ref:`modindex`
36 |
--------------------------------------------------------------------------------
/test/utils/test_loop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import (contains_self_loops, remove_self_loops,
3 | add_self_loops)
4 |
5 |
6 | def test_contains_self_loops():
7 | row = torch.tensor([0, 1, 0])
8 | col = torch.tensor([1, 0, 0])
9 |
10 | assert contains_self_loops(torch.stack([row, col], dim=0))
11 |
12 | row = torch.tensor([0, 1, 1])
13 | col = torch.tensor([1, 0, 2])
14 |
15 | assert not contains_self_loops(torch.stack([row, col], dim=0))
16 |
17 |
18 | def test_remove_self_loops():
19 | row = torch.tensor([1, 0, 1, 0, 2, 1])
20 | col = torch.tensor([0, 1, 1, 1, 2, 0])
21 | edge_index = torch.stack([row, col], dim=0)
22 | edge_attr = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
23 | edge_attr = torch.tensor(edge_attr)
24 |
25 | out = remove_self_loops(edge_index, edge_attr)
26 | assert out[0].tolist() == [[1, 0, 0, 1], [0, 1, 1, 0]]
27 | assert out[1].tolist() == [[1, 2], [3, 4], [7, 8], [11, 12]]
28 |
29 |
30 | def test_add_self_loops():
31 | row = torch.tensor([0, 1, 0])
32 | col = torch.tensor([1, 0, 0])
33 | edge_index = torch.stack([row, col], dim=0)
34 |
35 | expected = [[0, 1, 0, 0, 1], [1, 0, 0, 0, 1]]
36 | assert add_self_loops(edge_index).tolist() == expected
37 |
--------------------------------------------------------------------------------
/torch_geometric/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .degree import degree
2 | from .scatter import scatter_
3 | from .softmax import softmax
4 | from .undirected import is_undirected, to_undirected
5 | from .isolated import contains_isolated_nodes
6 | from .loop import contains_self_loops, remove_self_loops, add_self_loops
7 | from .one_hot import one_hot
8 | from .grid import grid
9 | from .normalized_cut import normalized_cut
10 | from .sparse import dense_to_sparse
11 | from .to_batch import to_batch
12 | from .convert import to_scipy_sparse_matrix, to_networkx
13 | from .metric import (accuracy, true_positive, true_negative, false_positive,
14 | false_negative, precision, recall, f1_score)
15 |
16 | __all__ = [
17 | 'degree',
18 | 'scatter_',
19 | 'softmax',
20 | 'is_undirected',
21 | 'to_undirected',
22 | 'contains_self_loops',
23 | 'remove_self_loops',
24 | 'add_self_loops',
25 | 'contains_isolated_nodes',
26 | 'one_hot',
27 | 'grid',
28 | 'normalized_cut',
29 | 'dense_to_sparse',
30 | 'to_batch',
31 | 'to_scipy_sparse_matrix',
32 | 'to_networkx',
33 | 'accuracy',
34 | 'true_positive',
35 | 'true_negative',
36 | 'false_positive',
37 | 'false_negative',
38 | 'precision',
39 | 'recall',
40 | 'f1_score',
41 | ]
42 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/compose.py:
--------------------------------------------------------------------------------
1 | class Compose(object):
2 | """Composes several transforms together.
3 |
4 | Args:
5 | transforms (list of :obj:`transform` objects): List of transforms to
6 | compose.
7 |
8 | .. testsetup::
9 |
10 | import torch
11 | from torch_geometric.data import Data
12 |
13 | .. testcode::
14 |
15 | import torch_geometric.transforms as T
16 |
17 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
18 | edge_index = torch.tensor([[0, 1], [1, 2]])
19 | data = Data(edge_index=edge_index, pos=pos)
20 |
21 | transform = T.Compose([T.Cartesian(), T.TargetIndegree()])
22 | data = transform(data)
23 |
24 | print(data.edge_attr)
25 |
26 | .. testoutput::
27 |
28 | tensor([[0.7500, 0.5000, 1.0000],
29 | [1.0000, 0.5000, 1.0000]])
30 | """
31 |
32 | def __init__(self, transforms):
33 | self.transforms = transforms
34 |
35 | def __call__(self, data):
36 | for t in self.transforms:
37 | data = t(data)
38 | return data
39 |
40 | def __repr__(self):
41 | args = [' {},'.format(t) for t in self.transforms]
42 | return '{}([\n{}\n])'.format(self.__class__.__name__, '\n'.join(args))
43 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/max_pool.py:
--------------------------------------------------------------------------------
1 | from torch_scatter import scatter_max
2 | from torch_geometric.data import Batch
3 |
4 | from .consecutive import consecutive_cluster
5 | from .pool import pool_edge, pool_batch, pool_pos
6 |
7 |
8 | def _max_pool_x(cluster, x, size=None):
9 | fill = -9999999
10 | x, _ = scatter_max(x, cluster, dim=0, dim_size=size, fill_value=fill)
11 | x[x == fill] = 0
12 | return x
13 |
14 |
15 | def max_pool_x(cluster, x, batch, size=None):
16 | if size is not None:
17 | return _max_pool_x(cluster, x, (batch.max().item() + 1) * size)
18 |
19 | cluster, perm = consecutive_cluster(cluster)
20 | x = _max_pool_x(cluster, x)
21 | batch = pool_batch(perm, batch)
22 |
23 | return x, batch
24 |
25 |
26 | def max_pool(cluster, data, transform=None):
27 | cluster, perm = consecutive_cluster(cluster)
28 |
29 | x = _max_pool_x(cluster, data.x)
30 | index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
31 | batch = None if data.batch is None else pool_batch(perm, data.batch)
32 | pos = None if data.pos is None else pool_pos(cluster, data.pos)
33 |
34 | data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)
35 |
36 | if transform is not None:
37 | data = transform(data)
38 |
39 | return data
40 |
--------------------------------------------------------------------------------
/test/datasets/test_planetoid.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 | import os.path as osp
4 | import shutil
5 |
6 | from torch_geometric.datasets import Planetoid
7 | from torch_geometric.data import DataLoader
8 |
9 |
10 | def test_citeseer():
11 | root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)), 'test')
12 | dataset = Planetoid(root, 'Citeseer')
13 | loader = DataLoader(dataset, batch_size=len(dataset))
14 |
15 | assert len(dataset) == 1
16 | assert dataset.__repr__() == 'Citeseer()'
17 |
18 | for data in loader:
19 | assert data.num_graphs == 1
20 | assert data.num_nodes == 3327
21 | assert data.num_edges / 2 == 4552
22 |
23 | assert len(data) == 7
24 | assert list(data.x.size()) == [data.num_nodes, 3703]
25 | assert list(data.y.size()) == [data.num_nodes]
26 | assert data.y.max() + 1 == 6
27 | assert data.train_mask.sum() == 6 * 20
28 | assert data.val_mask.sum() == 500
29 | assert data.test_mask.sum() == 1000
30 | assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0
31 | assert list(data.batch.size()) == [data.num_nodes]
32 |
33 | assert data.contains_isolated_nodes()
34 | assert not data.contains_self_loops()
35 | assert data.is_undirected()
36 |
37 | shutil.rmtree(root)
38 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/avg_pool.py:
--------------------------------------------------------------------------------
1 | from torch_scatter import scatter_mean
2 | from torch_geometric.data import Batch
3 |
4 | from .consecutive import consecutive_cluster
5 | from .pool import pool_edge, pool_batch, pool_pos
6 |
7 |
8 | def _avg_pool_x(cluster, x, size=None):
9 | fill = -9999999
10 | x, _ = scatter_mean(x, cluster, dim=0, dim_size=size, fill_value=fill)
11 | x[x == fill] = 0
12 | return x
13 |
14 |
15 | def avg_pool_x(cluster, x, batch=None, size=None):
16 | assert batch is None or size is None
17 |
18 | if size is not None:
19 | return _avg_pool_x(cluster, x, size)
20 |
21 | cluster, perm = consecutive_cluster(cluster)
22 | x = _avg_pool_x(cluster, x)
23 | batch = pool_batch(perm, batch)
24 |
25 | return x, batch
26 |
27 |
28 | def avg_pool(cluster, data, transform=None):
29 | cluster, perm = consecutive_cluster(cluster)
30 |
31 | x = _avg_pool_x(cluster, data.x)
32 | index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
33 | batch = None if data.batch is None else pool_batch(perm, data.batch)
34 | pos = None if data.pos is None else pool_pos(cluster, data.pos)
35 |
36 | data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)
37 |
38 | if transform is not None:
39 | data = transform(data)
40 |
41 | return data
42 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/random_translate.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from itertools import repeat
3 |
4 | import torch
5 |
6 |
7 | class RandomTranslate(object):
8 | r"""Translates node positions by randomly sampled translation values
9 | within a given interval. In contrast to other random transformations,
10 | translation is applied randomly at each position.
11 |
12 | Args:
13 | translate (sequence or float or int): maximum translation in each
14 | dimension, defining the range
15 | :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from.
16 | If :obj:`translate` is a number instead of a sequence, the same
17 | range is used for each dimension.
18 | """
19 |
20 | def __init__(self, translate):
21 | self.translate = translate
22 |
23 | def __call__(self, data):
24 | (n, dim), t = data.pos.size(), self.translate
25 | if isinstance(t, numbers.Number):
26 |             t = list(repeat(t, dim))
27 |
28 | ts = []
29 | for d in range(dim):
30 | ts.append(data.pos.new_empty(n).uniform_(-abs(t[d]), abs(t[d])))
31 | t = torch.stack(ts, dim=-1)
32 |
33 | data.pos = data.pos + t
34 | return data
35 |
36 | def __repr__(self):
37 |         return '{}({})'.format(self.__class__.__name__, self.translate)
38 |
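39 |
40 | # A minimal usage sketch; the positions are illustrative only: each 2D
41 | # position is translated independently by offsets drawn from (-0.1, 0.1).
42 | if __name__ == '__main__':
43 |     from torch_geometric.data import Data
44 |
45 |     data = RandomTranslate(0.1)(Data(pos=torch.zeros(3, 2)))
46 |     print(data.pos.size())  # torch.Size([3, 2])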
--------------------------------------------------------------------------------
/torch_geometric/transforms/two_hop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import spspmm, coalesce
3 |
4 | from torch_geometric.utils import remove_self_loops
5 |
6 |
7 | class TwoHop(object):
8 | def __call__(self, data):
9 | edge_index, edge_attr = data.edge_index, data.edge_attr
10 | n = data.num_nodes
11 |
12 | fill = 1e16
13 | value = edge_index.new_full(
14 | (edge_index.size(1), ), fill, dtype=torch.float)
15 |
16 | index, value = spspmm(edge_index, value, edge_index, value, n, n, n)
17 | index, value = remove_self_loops(index, value)
18 |
19 | edge_index = torch.cat([edge_index, index], dim=1)
20 | if edge_attr is None:
21 | data.edge_index, _ = coalesce(edge_index, None, n, n)
22 | else:
23 | value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
24 | value = value.expand(-1, *list(edge_attr.size())[1:])
25 | edge_attr = torch.cat([edge_attr, value], dim=0)
26 | data.edge_index, edge_attr = coalesce(
27 | edge_index, edge_attr, n, n, op='min', fill_value=fill)
28 | edge_attr[edge_attr >= fill] = 0
29 | data.edge_attr = edge_attr
30 |
31 | return data
32 |
33 | def __repr__(self):
34 | return '{}()'.format(self.__class__.__name__)
35 |
--------------------------------------------------------------------------------
/torch_geometric/read/sdf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import coalesce
3 | from torch_geometric.read import parse_txt_array
4 | from torch_geometric.utils import one_hot
5 | from torch_geometric.data import Data
6 |
7 | elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
8 |
9 |
10 | def parse_sdf(src):
11 | src = src.split('\n')[3:]
12 | num_atoms, num_bonds = [int(item) for item in src[0].split()[:2]]
13 |
14 | atom_block = src[1:num_atoms + 1]
15 | pos = parse_txt_array(atom_block, end=3)
16 | x = torch.tensor([elems[item.split()[3]] for item in atom_block])
17 | x = one_hot(x, len(elems))
18 |
19 | bond_block = src[1 + num_atoms:1 + num_atoms + num_bonds]
20 | row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1
21 | row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
22 | edge_index = torch.stack([row, col], dim=0)
23 | edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1
24 | edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
25 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms,
26 | num_atoms)
27 |
28 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
29 | return data
30 |
31 |
32 | def read_sdf(path):
33 | with open(path, 'r') as f:
34 | return parse_sdf(f.read())
35 |
--------------------------------------------------------------------------------
/test/datasets/test_tu_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import sys
4 | import random
5 | import os.path as osp
6 | import shutil
7 |
8 | import pytest
9 | from torch_geometric.datasets import TUDataset
10 | from torch_geometric.data import DataLoader
11 |
12 |
13 | def test_enzymes():
14 | root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)), 'test')
15 | dataset = TUDataset(root, 'ENZYMES')
16 | loader = DataLoader(dataset, batch_size=len(dataset))
17 |
18 | assert len(dataset) == 600
19 | assert dataset.__repr__() == 'ENZYMES(600)'
20 |
21 | for data in loader:
22 | assert data.num_graphs == 600
23 |
24 | avg_num_nodes = data.num_nodes / data.num_graphs
25 | assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63
26 |
27 | avg_num_edges = data.num_edges / (2 * data.num_graphs)
28 | assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14
29 |
30 | assert len(data) == 4
31 | assert list(data.x.size()) == [data.num_nodes, 21]
32 | assert list(data.y.size()) == [data.num_graphs]
33 | assert data.y.max() + 1 == 6
34 | assert list(data.batch.size()) == [data.num_nodes]
35 |
36 | assert data.contains_isolated_nodes()
37 | assert not data.contains_self_loops()
38 | assert data.is_undirected()
39 |
40 | shutil.rmtree(root)
41 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/planetoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import InMemoryDataset, download_url
3 | from torch_geometric.read import read_planetoid_data
4 |
5 |
6 | class Planetoid(InMemoryDataset):
7 | url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
8 |
9 | def __init__(self, root, name, transform=None, pre_transform=None):
10 | self.name = name
11 | super(Planetoid, self).__init__(root, transform, pre_transform)
12 | self.data, self.slices = torch.load(self.processed_paths[0])
13 |
14 | @property
15 | def raw_file_names(self):
16 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
17 | return ['ind.{}.{}'.format(self.name.lower(), name) for name in names]
18 |
19 | @property
20 | def processed_file_names(self):
21 | return 'data.pt'
22 |
23 | def download(self):
24 | for name in self.raw_file_names:
25 | download_url('{}/{}'.format(self.url, name), self.raw_dir)
26 |
27 | def process(self):
28 | data = read_planetoid_data(self.raw_dir, self.name)
29 | data = data if self.pre_transform is None else self.pre_transform(data)
30 | data, slices = self.collate([data])
31 | torch.save((data, slices), self.processed_paths[0])
32 |
33 | def __repr__(self):
34 | return '{}()'.format(self.name)
35 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/random_flip.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | class RandomFlip(object):
5 | """Flips node positions along a given axis randomly with a given
6 | probability.
7 |
8 | Args:
9 |         axis (int): The axis along which the position of nodes is flipped.
10 | p (float, optional): Probability of the position of nodes being
11 | flipped. (default: :obj:`0.5`)
12 |
13 | .. testsetup::
14 |
15 | import torch
16 | from torch_geometric.data import Data
17 |
18 | .. testcode::
19 |
20 | from torch_geometric.transforms import RandomFlip
21 |
22 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float)
23 | data = Data(pos=pos)
24 |
25 | data = RandomFlip(axis=0, p=1)(data)
26 |
27 | print(data.pos)
28 |
29 | .. testoutput::
30 |
31 | tensor([[ 1., 1.],
32 | [ 3., 0.],
33 | [-2., -1.]])
34 | """
35 |
36 | def __init__(self, axis, p=0.5):
37 | self.axis = axis
38 | self.p = p
39 |
40 | def __call__(self, data):
41 | if random.random() < self.p:
42 | data.pos[:, self.axis] = -data.pos[:, self.axis]
43 | return data
44 |
45 | def __repr__(self):
46 | return '{}(axis={}, p={})'.format(self.__class__.__name__, self.axis,
47 | self.p)
48 |
--------------------------------------------------------------------------------
/torch_geometric/nn/prop/gcn_prop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import spmm
3 | from torch_scatter import scatter_add
4 | from torch_geometric.utils import remove_self_loops, add_self_loops
5 |
6 |
7 | class GCNProp(torch.nn.Module):
8 | def __init__(self, improved=False):
9 | super(GCNProp, self).__init__()
10 | self.improved = improved
11 |
12 | def forward(self, x, edge_index, edge_attr=None):
13 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
14 |
15 | if edge_attr is None:
16 | edge_attr = x.new_ones((edge_index.size(1), ))
17 | assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)
18 |
19 | # Add self-loops to adjacency matrix.
20 | edge_index = add_self_loops(edge_index, x.size(0))
21 | loop_value = x.new_full((x.size(0), ), 1 if not self.improved else 2)
22 | edge_attr = torch.cat([edge_attr, loop_value], dim=0)
23 |
24 | # Normalize adjacency matrix.
25 | row, col = edge_index
26 | deg = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0))
27 | deg = deg.pow(-0.5)
28 | deg[deg == float('inf')] = 0
29 | edge_attr = deg[row] * edge_attr * deg[col]
30 |
31 | # Perform the propagation.
32 | out = spmm(edge_index, edge_attr, x.size(0), x)
33 |
34 | return out
35 |
36 | def __repr__(self):
37 | return '{}()'.format(self.__class__.__name__)
38 |
--------------------------------------------------------------------------------
/torch_geometric/data/batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 |
4 |
5 | class Batch(Data):
6 | def __init__(self, batch=None, **kwargs):
7 | super(Batch, self).__init__(**kwargs)
8 | self.batch = batch
9 |
10 | @staticmethod
11 | def from_data_list(data_list):
12 | keys = [set(data.keys) for data in data_list]
13 | keys = list(set.union(*keys))
14 | assert 'batch' not in keys
15 |
16 | batch = Batch()
17 |
18 | for key in keys:
19 | batch[key] = []
20 | batch.batch = []
21 |
22 | cumsum = 0
23 | for i, data in enumerate(data_list):
24 | num_nodes = data.num_nodes
25 | batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
26 | for key in data.keys:
27 | item = data[key]
28 | item = item + cumsum if batch.cumsum(key, item) else item
29 | batch[key].append(item)
30 | cumsum += num_nodes
31 |
32 | for key in keys:
33 | batch[key] = torch.cat(
34 | batch[key], dim=data_list[0].cat_dim(key, batch[key][0]))
35 | batch.batch = torch.cat(batch.batch, dim=-1)
36 | return batch.contiguous()
37 |
38 | def cumsum(self, key, item):
39 | return item.dim() > 1 and item.dtype == torch.long
40 |
41 | @property
42 | def num_graphs(self):
43 | return self.batch[-1].item() + 1
44 |
--------------------------------------------------------------------------------
/torch_geometric/utils/grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import coalesce
3 |
4 |
5 | def grid(height, width, dtype=None, device=None):
6 | edge_index = grid_index(height, width, device)
7 | pos = grid_pos(height, width, dtype, device)
8 | return edge_index, pos
9 |
10 |
11 | def grid_index(height, width, device=None):
12 | w = width
13 | kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]
14 | kernel = torch.tensor(kernel, device=device)
15 |
16 | row = torch.arange(height * width, dtype=torch.long, device=device)
17 | row = row.view(-1, 1).repeat(1, kernel.size(0))
18 | col = row + kernel.view(1, -1)
19 | row, col = row.view(height, -1), col.view(height, -1)
20 | index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device)
21 | row, col = row[:, index].view(-1), col[:, index].view(-1)
22 |
23 | mask = (col >= 0) & (col < height * width)
24 | row, col = row[mask], col[mask]
25 |
26 | edge_index = torch.stack([row, col], dim=0)
27 | edge_index, _ = coalesce(edge_index, None, height * width, height * width)
28 |
29 | return edge_index
30 |
31 |
32 | def grid_pos(height, width, dtype=None, device=None):
33 | x = torch.arange(width, dtype=dtype, device=device)
34 | y = (height - 1) - torch.arange(height, dtype=dtype, device=device)
35 |
36 | x = x.repeat(height)
37 | y = y.unsqueeze(-1).repeat(1, width).view(-1)
38 |
39 | return torch.stack([x, y], dim=-1)
40 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/constant.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Constant(object):
5 | r"""Adds a constant value to each node feature.
6 |
7 | Args:
8 | value (int): The value to add.
9 | cat (bool, optional): Concat value to node features instead
10 | of replacing them. (default: :obj:`True`)
11 |
12 | .. testsetup::
13 |
14 | import torch
15 | from torch_geometric.data import Data
16 |
17 | .. testcode::
18 |
19 | from torch_geometric.transforms import Constant
20 |
21 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
22 | data = Data(edge_index=edge_index)
23 |
24 | data = Constant(value=1)(data)
25 |
26 | print(data.x)
27 |
28 | .. testoutput::
29 |
30 | tensor([[1.],
31 | [1.],
32 | [1.]])
33 | """
34 |
35 | def __init__(self, value, cat=True):
36 | self.value = value
37 | self.cat = cat
38 |
39 | def __call__(self, data):
40 | x = data.x
41 |
42 | c = torch.full((data.num_nodes, 1), self.value)
43 |
44 | if x is not None and self.cat:
45 | x = x.view(-1, 1) if x.dim() == 1 else x
46 | data.x = torch.cat([x, c.to(x.dtype).to(x.device)], dim=-1)
47 | else:
48 | data.x = c
49 |
50 | return data
51 |
52 | def __repr__(self):
53 | return '{}({}, cat={})'.format(self.__class__.__name__, self.value,
54 | self.cat)
55 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/edge_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import scatter_
3 |
4 | from ..inits import reset
5 |
6 |
7 | class EdgeConv(torch.nn.Module):
8 | r"""The edge convolutional operator from the `"Dynamic Graph CNN for
9 | Learning on Point Clouds" `_ paper
10 |
11 | .. math::
12 | \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)}
13 | h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \,
14 | \mathbf{x}_j - \mathbf{x}_i),
15 |
16 |     where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*, an MLP.
17 |
18 | Args:
19 | nn (nn.Sequential): Neural network.
20 | aggr (string): The aggregation operator to use (one of :obj:`"add"`,
21 | :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`)
22 | """
23 |
24 | def __init__(self, nn, aggr='add'):
25 | super(EdgeConv, self).__init__()
26 | self.nn = nn
27 | self.aggr = aggr
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | reset(self.nn)
32 |
33 | def forward(self, x, edge_index):
34 | """"""
35 | row, col = edge_index
36 | x = x.unsqueeze(-1) if x.dim() == 1 else x
37 |
38 | out = torch.cat([x[row], x[col] - x[row]], dim=1)
39 | out = self.nn(out)
40 | out = scatter_(self.aggr, out, row, dim_size=x.size(0))
41 |
42 | return out
43 |
44 | def __repr__(self):
45 | return '{}({})'.format(self.__class__.__name__, self.nn)
46 |
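47 |
48 | # A minimal usage sketch; the tensors are illustrative only. The MLP maps the
49 | # concatenated features [x_i || x_j - x_i] (of size 2 * 3) to 8 channels.
50 | if __name__ == '__main__':
51 |     mlp = torch.nn.Sequential(torch.nn.Linear(2 * 3, 8), torch.nn.ReLU())
52 |     conv = EdgeConv(mlp, aggr='max')
53 |     x = torch.rand(4, 3)
54 |     edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
55 |     print(conv(x, edge_index).size())  # torch.Size([4, 8])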
--------------------------------------------------------------------------------
/test/data/test_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 |
4 |
5 | def test_data():
6 | x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t()
7 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
8 | data = Data(x=x, edge_index=edge_index)
9 |
10 | assert data.x.tolist() == x.tolist()
11 | assert data['x'].tolist() == x.tolist()
12 |
13 | assert sorted(data.keys) == ['edge_index', 'x']
14 | assert len(data) == 2
15 | assert 'x' in data and 'edge_index' in data and 'pos' not in data
16 |
17 | assert data.cat_dim('x', data.x) == 0
18 | assert data.cat_dim('edge_index', data.edge_index) == -1
19 |
20 | data.to(torch.device('cpu'))
21 | assert not data.x.is_contiguous()
22 | data.contiguous()
23 | assert data.x.is_contiguous()
24 |
25 | data['x'] = x + 1
26 | assert data.x.tolist() == (x + 1).tolist()
27 |
28 | assert data.__repr__() == 'Data(edge_index=[2, 4], x=[3, 2])'
29 |
30 | dictionary = {'x': x, 'edge_index': edge_index}
31 | data = Data.from_dict(dictionary)
32 | assert sorted(data.keys) == ['edge_index', 'x']
33 |
34 | assert not data.contains_isolated_nodes()
35 | assert not data.contains_self_loops()
36 | assert data.is_undirected()
37 | assert not data.is_directed()
38 |
39 | assert data.num_nodes == 3
40 | assert data.num_edges == 4
41 |
42 | data.x = None
43 | assert data.num_nodes == 3
44 |
45 | data.edge_index = None
46 | assert data.num_nodes is None
47 | assert data.num_edges is None
48 |
--------------------------------------------------------------------------------
/torch_geometric/utils/convert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.sparse
3 | import networkx as nx
4 |
5 | from .num_nodes import maybe_num_nodes
6 |
7 |
8 | def to_scipy_sparse_matrix(edge_index, edge_attr=None, num_nodes=None):
9 | row, col = edge_index.cpu()
10 |
11 | if edge_attr is None:
12 | edge_attr = torch.ones(row.size(0))
13 | else:
14 | edge_attr = edge_attr.view(-1).cpu()
15 | assert edge_attr.size(0) == row.size(0)
16 |
17 | N = maybe_num_nodes(edge_index, num_nodes)
18 | out = scipy.sparse.coo_matrix((edge_attr, (row, col)), (N, N))
19 | return out
20 |
21 |
22 | def to_networkx(edge_index, x=None, edge_attr=None, pos=None, num_nodes=None):
23 | num_nodes = num_nodes if x is None else x.size(0)
24 | num_nodes = num_nodes if pos is None else pos.size(0)
25 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
26 |
27 | G = nx.Graph()
28 |
29 | for i in range(num_nodes):
30 | G.add_node(i)
31 | if x is not None:
32 | G.nodes[i]['x'] = x[i].cpu().numpy()
33 | if pos is not None:
34 | G.nodes[i]['pos'] = pos[i].cpu().numpy()
35 |
36 | for i in range(edge_index.size(1)):
37 | source, target = edge_index[0][i].item(), edge_index[1][i].item()
38 | G.add_edge(source, target)
39 | if edge_attr is not None:
40 | if edge_attr.numel() == edge_attr.size(0):
41 | G[source][target]['weight'] = edge_attr[i].item()
42 | else:
43 | G[source][target]['weight'] = edge_attr[i].cpu().numpy()
44 |
45 | return G
46 |
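
A short usage sketch for the two converters above (not part of the file; imports use the module path shown, and the toy values follow directly from the definitions in the code):

    import torch
    from torch_geometric.utils.convert import to_scipy_sparse_matrix, to_networkx

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    # 3 x 3 scipy COO matrix with an implicit edge weight of 1 per stored entry.
    adj = to_scipy_sparse_matrix(edge_index)

    # networkx graph with node positions stored as node attributes.
    pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])
    G = to_networkx(edge_index, pos=pos)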
--------------------------------------------------------------------------------
/torch_geometric/transforms/distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Distance(object):
5 | r"""Saves the Euclidean distance of linked nodes in its edge attributes.
6 |
7 | Args:
8 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
9 | instead of replacing them. (default: :obj:`True`)
10 |
11 | .. testsetup::
12 |
13 | import torch
14 | from torch_geometric.data import Data
15 |
16 | .. testcode::
17 |
18 | from torch_geometric.transforms import Distance
19 |
20 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
21 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
22 | data = Data(edge_index=edge_index, pos=pos)
23 |
24 | data = Distance()(data)
25 |
26 | print(data.edge_attr)
27 |
28 | .. testoutput::
29 |
30 | tensor([[1.],
31 | [1.],
32 | [2.],
33 | [2.]])
34 | """
35 |
36 | def __init__(self, cat=True):
37 | self.cat = cat
38 |
39 | def __call__(self, data):
40 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
41 |
42 | dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
43 |
44 | if pseudo is not None and self.cat:
45 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
46 | data.edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1)
47 | else:
48 | data.edge_attr = dist
49 |
50 | return data
51 |
52 | def __repr__(self):
53 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
54 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .compose import Compose
2 | from .constant import Constant
3 | from .distance import Distance
4 | from .cartesian import Cartesian
5 | from .local_cartesian import LocalCartesian
6 | from .polar import Polar
7 | from .spherical import Spherical
8 | from .one_hot_degree import OneHotDegree
9 | from .target_indegree import TargetIndegree
10 | from .center import Center
11 | from .normalize_scale import NormalizeScale
12 | from .random_translate import RandomTranslate
13 | from .random_flip import RandomFlip
14 | from .linear_transformation import LinearTransformation
15 | from .random_scale import RandomScale
16 | from .random_rotate import RandomRotate
17 | from .random_shear import RandomShear
18 | from .normalize_features import NormalizeFeatures
19 | from .add_self_loops import AddSelfLoops
20 | from .nn_graph import NNGraph
21 | from .radius_graph import RadiusGraph
22 | from .face_to_edge import FaceToEdge
23 | from .sample_points import SamplePoints
24 | from .to_dense import ToDense
25 | from .two_hop import TwoHop
26 |
27 | __all__ = [
28 | 'Compose',
29 | 'Constant',
30 | 'Distance',
31 | 'Cartesian',
32 | 'LocalCartesian',
33 | 'Polar',
34 | 'Spherical',
35 | 'OneHotDegree',
36 | 'TargetIndegree',
37 | 'Center',
38 | 'NormalizeScale',
39 | 'RandomTranslate',
40 | 'RandomFlip',
41 | 'LinearTransformation',
42 | 'RandomScale',
43 | 'RandomRotate',
44 | 'RandomShear',
45 | 'NormalizeFeatures',
46 | 'AddSelfLoops',
47 | 'NNGraph',
48 | 'RadiusGraph',
49 | 'FaceToEdge',
50 | 'SamplePoints',
51 | 'ToDense',
52 | 'TwoHop',
53 | ]
54 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/qm9.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch_geometric.data import (InMemoryDataset, download_url, extract_tar,
5 | Data)
6 |
7 |
8 | class QM9(InMemoryDataset):
9 | url = 'http://www.roemisch-drei.de/qm9.tar.gz'
10 |
11 | def __init__(self,
12 | root,
13 | transform=None,
14 | pre_transform=None,
15 | pre_filter=None):
16 | super(QM9, self).__init__(root, transform, pre_transform, pre_filter)
17 | self.data, self.slices = torch.load(self.processed_paths[0])
18 |
19 | @property
20 | def raw_file_names(self):
21 | return 'qm9.pt'
22 |
23 | @property
24 | def processed_file_names(self):
25 | return 'data.pt'
26 |
27 | def download(self):
28 | file_path = download_url(self.url, self.raw_dir)
29 | extract_tar(file_path, self.raw_dir, mode='r')
30 | os.unlink(file_path)
31 |
32 | def process(self):
33 | raw_data_list = torch.load(self.raw_paths[0])
34 | data_list = [
35 | Data(
36 | x=d['x'],
37 | edge_index=d['edge_index'],
38 | edge_attr=d['edge_attr'],
39 | y=d['y'],
40 | pos=d['pos']) for d in raw_data_list
41 | ]
42 |
43 | if self.pre_filter is not None:
44 | data_list = [data for data in data_list if self.pre_filter(data)]
45 |
46 | if self.pre_transform is not None:
47 | data_list = [self.pre_transform(data) for data in data_list]
48 |
49 | data, slices = self.collate(data_list)
50 | torch.save((data, slices), self.processed_paths[0])
51 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/one_hot_degree.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import degree, one_hot
3 |
4 |
5 | class OneHotDegree(object):
6 | r"""Adds the node degree as one hot encodings to the node features.
7 |
8 | Args:
9 | max_degree (int): Maximum degree.
10 | cat (bool, optional): Concat node degrees to node features instead
11 | of replacing them. (default: :obj:`True`)
12 |
13 | .. testsetup::
14 |
15 | import torch
16 | from torch_geometric.data import Data
17 |
18 | .. testcode::
19 |
20 | from torch_geometric.transforms import OneHotDegree
21 |
22 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
23 | data = Data(edge_index=edge_index)
24 |
25 | data = OneHotDegree(max_degree=2)(data)
26 |
27 | print(data.x)
28 |
29 | .. testoutput::
30 |
31 | tensor([[0., 1., 0.],
32 | [0., 0., 1.],
33 | [0., 1., 0.]])
34 | """
35 |
36 | def __init__(self, max_degree, cat=True):
37 | self.max_degree = max_degree
38 | self.cat = cat
39 |
40 | def __call__(self, data):
41 | row, x = data.edge_index[0], data.x
42 | deg = degree(row, data.num_nodes)
43 | deg = one_hot(deg, num_classes=self.max_degree + 1)
44 |
45 | if x is not None and self.cat:
46 | x = x.view(-1, 1) if x.dim() == 1 else x
47 | data.x = torch.cat([x, deg.to(x.dtype)], dim=-1)
48 | else:
49 | data.x = deg
50 |
51 | return data
52 |
53 | def __repr__(self):
54 | return '{}({}, cat={})'.format(self.__class__.__name__,
55 | self.max_degree, self.cat)
56 |
--------------------------------------------------------------------------------
/torch_geometric/utils/scatter.py:
--------------------------------------------------------------------------------
1 | import torch_scatter
2 |
3 |
4 | def scatter_(name, src, index, dim_size=None):
5 | r"""Aggregates all values from the :attr:`src` tensor at the indices
6 | specified in the :attr:`index` tensor along the first dimension.
7 | If multiple indices reference the same location, their contributions
8 | are aggregated according to :attr:`name` (:obj:`"add"`, :obj:`"mean"`,
9 | :obj:`"max"`).
10 |
11 | Args:
12 | name (string): The aggregation to use (one of :obj:`"add"`,
13 | :obj:`"mean"`, :obj:`"max"`).
14 | src (Tensor): The source tensor.
15 | index (LongTensor): The indices of elements to scatter.
16 | dim_size (int, optional): Automatically create output tensor with size
17 | :attr:`dim_size` in the first dimension. If :attr:`None`, a minimal
18 | sized output tensor is returned. (default: :obj:`None`)
19 |
20 | :rtype: :class:`Tensor`
21 |
22 | .. testsetup::
23 |
24 | import torch
25 |
26 | .. testcode::
27 |
28 | from torch_geometric.utils import scatter_
29 | src = torch.Tensor([2, 3, -2, 1, 1])
30 | index = torch.tensor([0, 1, 0, 1, 2])
31 | output = scatter_("add", src, index)
32 | print(output)
33 |
34 | .. testoutput::
35 |
36 | tensor([0., 4., 1.])
37 | """
38 |
39 | assert name in ['add', 'mean', 'max']
40 |
41 | op = getattr(torch_scatter, 'scatter_{}'.format(name))
42 |     fill_value = -1e38 if name == 'max' else 0
43 |
44 | out = op(src, index, 0, None, dim_size, fill_value)
45 | if isinstance(out, tuple):
46 | out = out[0]
47 |
48 |     if name == 'max':
49 | out[out == fill_value] = 0
50 |
51 | return out
52 |
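
Complementing the docstring example, a small sketch of the :obj:`"mean"` and :obj:`"max"` variants (not part of the file; the expected values in the comments follow from the definitions above):

    import torch
    from torch_geometric.utils import scatter_

    src = torch.tensor([2.0, 3.0, -2.0, 1.0, 1.0])
    index = torch.tensor([0, 1, 0, 1, 2])

    scatter_('mean', src, index)             # tensor([0., 2., 1.])
    scatter_('max', src, index)              # tensor([2., 3., 1.])
    scatter_('max', src, index, dim_size=4)  # unused index 3 is reset to 0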
--------------------------------------------------------------------------------
/torch_geometric/nn/dense/sage_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Parameter
4 |
5 | from ..inits import uniform
6 |
7 |
8 | class DenseSAGEConv(torch.nn.Module):
9 | def __init__(self,
10 | in_channels,
11 | out_channels,
12 | norm=True,
13 | norm_embed=True,
14 | bias=True):
15 | super(DenseSAGEConv, self).__init__()
16 |
17 | self.in_channels = in_channels
18 | self.out_channels = out_channels
19 | self.norm = norm
20 | self.norm_embed = norm_embed
21 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
22 |
23 | if bias:
24 | self.bias = Parameter(torch.Tensor(out_channels))
25 | else:
26 | self.register_parameter('bias', None)
27 |
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | uniform(self.in_channels, self.weight)
32 | uniform(self.in_channels, self.bias)
33 |
34 | def forward(self, x, adj):
35 | x = x.unsqueeze(0) if x.dim() == 2 else x
36 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
37 |
38 | out = torch.matmul(adj, x)
39 |
40 | if self.norm:
41 | out = out / adj.sum(dim=-1, keepdim=True)
42 |
43 | out = torch.matmul(out, self.weight)
44 |
45 | if self.bias is not None:
46 | out = out + self.bias
47 |
48 | if self.norm_embed:
49 | out = F.normalize(out, p=2, dim=-1)
50 |
51 | return out
52 |
53 | def __repr__(self):
54 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
55 | self.out_channels)
56 |
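
A minimal usage sketch for the dense operator above (not part of the file; toy values, importing from the module path shown):

    import torch
    from torch_geometric.nn.dense.sage_conv import DenseSAGEConv

    # Mini-batch of 2 dense graphs with 4 nodes and 3 features each.
    x = torch.rand(2, 4, 3)
    adj = torch.ones(2, 4, 4)  # fully-connected toy adjacency

    conv = DenseSAGEConv(3, 8)
    out = conv(x, adj)  # shape [2, 4, 8]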
--------------------------------------------------------------------------------
/torch_geometric/datasets/qm7.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.io
3 | from torch_geometric.data import InMemoryDataset, download_url, Data
4 |
5 |
6 | class QM7(InMemoryDataset):
7 | url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
8 | 'datasets/qm7b.mat'
9 |
10 | def __init__(self,
11 | root,
12 | transform=None,
13 | pre_transform=None,
14 | pre_filter=None):
15 | super(QM7, self).__init__(root, transform, pre_transform, pre_filter)
16 | self.data, self.slices = torch.load(self.processed_paths[0])
17 |
18 | @property
19 | def raw_file_names(self):
20 | return 'qm7b.mat'
21 |
22 | @property
23 | def processed_file_names(self):
24 | return 'data.pt'
25 |
26 | def download(self):
27 | download_url(self.url, self.raw_dir)
28 |
29 | def process(self):
30 | data = scipy.io.loadmat(self.raw_paths[0])
31 | coulomb_matrix = torch.from_numpy(data['X'])
32 | target = torch.from_numpy(data['T']).to(torch.float)
33 |
34 | data_list = []
35 | for i in range(target.shape[0]):
36 | edge_index = coulomb_matrix[i].nonzero().t().contiguous()
37 | edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]]
38 | y = target[i].view(1, -1)
39 | data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y)
40 | data_list.append(data)
41 |
42 | if self.pre_filter is not None:
43 | data_list = [d for d in data_list if self.pre_filter(d)]
44 |
45 | if self.pre_transform is not None:
46 | data_list = [self.pre_transform(d) for d in data_list]
47 |
48 | data, slices = self.collate(data_list)
49 | torch.save((data, slices), self.processed_paths[0])
50 |
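
The ``nonzero().t()`` idiom used in ``process`` above turns a dense, adjacency-like matrix into an ``edge_index``; a standalone sketch with a toy matrix (not part of the file):

    import torch

    # Toy 3 x 3 "Coulomb" matrix; non-zero entries become edges.
    m = torch.tensor([[0.0, 1.5, 0.0],
                      [1.5, 0.0, 2.0],
                      [0.0, 2.0, 0.0]])

    edge_index = m.nonzero().t().contiguous()    # tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_attr = m[edge_index[0], edge_index[1]]  # tensor([1.5000, 1.5000, 2.0000, 2.0000])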
--------------------------------------------------------------------------------
/examples/gat.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import Planetoid
6 | import torch_geometric.transforms as T
7 | from torch_geometric.nn import GATConv
8 |
9 | dataset = 'Cora'
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
11 | data = Planetoid(path, dataset, T.NormalizeFeatures())[0]
12 |
13 |
14 | class Net(torch.nn.Module):
15 | def __init__(self):
16 | super(Net, self).__init__()
17 | self.att1 = GATConv(data.num_features, 8, heads=8, dropout=0.6)
18 | self.att2 = GATConv(8 * 8, data.num_classes, dropout=0.6)
19 |
20 | def forward(self):
21 | x = F.dropout(data.x, p=0.6, training=self.training)
22 | x = F.elu(self.att1(x, data.edge_index))
23 | x = F.dropout(x, p=0.6, training=self.training)
24 | x = self.att2(x, data.edge_index)
25 | return F.log_softmax(x, dim=1)
26 |
27 |
28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29 | model, data = Net().to(device), data.to(device)
30 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
31 |
32 |
33 | def train():
34 | model.train()
35 | optimizer.zero_grad()
36 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
37 | optimizer.step()
38 |
39 |
40 | def test():
41 | model.eval()
42 | logits, accs = model(), []
43 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
44 | pred = logits[mask].max(1)[1]
45 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
46 | accs.append(acc)
47 | return accs
48 |
49 |
50 | for epoch in range(1, 201):
51 | train()
52 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
53 | print(log.format(epoch, *test()))
54 |
--------------------------------------------------------------------------------
/torch_geometric/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(pred, target):
5 | return (pred == target).sum().item() / target.numel()
6 |
7 |
8 | def true_positive(pred, target, num_classes):
9 | out = []
10 | for i in range(num_classes):
11 | out.append(((pred == i) & (target == i)).sum())
12 |
13 | return torch.tensor(out)
14 |
15 |
16 | def true_negative(pred, target, num_classes):
17 | out = []
18 | for i in range(num_classes):
19 | out.append(((pred != i) & (target != i)).sum())
20 |
21 | return torch.tensor(out)
22 |
23 |
24 | def false_positive(pred, target, num_classes):
25 | out = []
26 | for i in range(num_classes):
27 | out.append(((pred == i) & (target != i)).sum())
28 |
29 | return torch.tensor(out)
30 |
31 |
32 | def false_negative(pred, target, num_classes):
33 | out = []
34 | for i in range(num_classes):
35 | out.append(((pred != i) & (target == i)).sum())
36 |
37 | return torch.tensor(out)
38 |
39 |
40 | def precision(pred, target, num_classes):
41 | tp = true_positive(pred, target, num_classes).to(torch.float)
42 | fp = false_positive(pred, target, num_classes).to(torch.float)
43 |
44 | out = tp / (tp + fp)
45 | out[torch.isnan(out)] = 0
46 |
47 | return out
48 |
49 |
50 | def recall(pred, target, num_classes):
51 | tp = true_positive(pred, target, num_classes).to(torch.float)
52 | fn = false_negative(pred, target, num_classes).to(torch.float)
53 |
54 | out = tp / (tp + fn)
55 | out[torch.isnan(out)] = 0
56 |
57 | return out
58 |
59 |
60 | def f1_score(pred, target, num_classes):
61 | prec = precision(pred, target, num_classes)
62 | rec = recall(pred, target, num_classes)
63 |
64 | score = 2 * (prec * rec) / (prec + rec)
65 | score[torch.isnan(score)] = 0
66 |
67 | return score
68 |
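
A short usage sketch of the metrics above (not part of the file; toy predictions, with per-class values following from the tp/fp/fn definitions):

    import torch
    from torch_geometric.utils.metric import accuracy, precision, recall, f1_score

    pred = torch.tensor([0, 1, 1, 2, 2, 2])
    target = torch.tensor([0, 1, 2, 2, 2, 0])

    accuracy(pred, target)                  # 4 / 6 = 0.666...
    precision(pred, target, num_classes=3)  # tensor([1.0000, 0.5000, 0.6667])
    recall(pred, target, num_classes=3)     # tensor([0.5000, 1.0000, 0.6667])
    f1_score(pred, target, num_classes=3)   # tensor([0.6667, 0.6667, 0.6667])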
--------------------------------------------------------------------------------
/torch_geometric/transforms/target_indegree.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import degree
3 |
4 |
5 | class TargetIndegree(object):
6 | r"""Saves the globally normalized degree of target nodes (mapped to the
7 | fixed interval :math:`[0, 1]`)
8 |
9 | .. math::
10 |
11 | \mathbf{u}(i,j) = \frac{\deg(j)}{\max_{v \in \mathcal{V}} \deg(v)}
12 |
13 | in its edge attributes.
14 |
15 | Args:
16 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
17 | instead of replacing them. (default: :obj:`True`)
18 |
19 | .. testsetup::
20 |
21 | import torch
22 | from torch_geometric.data import Data
23 |
24 | .. testcode::
25 |
26 | from torch_geometric.transforms import TargetIndegree
27 |
28 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
29 | data = Data(edge_index=edge_index)
30 |
31 | data = TargetIndegree()(data)
32 |
33 | print(data.edge_attr)
34 |
35 | .. testoutput::
36 |
37 | tensor([[1.0000],
38 | [0.5000],
39 | [0.5000],
40 | [1.0000]])
41 | """
42 |
43 | def __init__(self, cat=True):
44 | self.cat = cat
45 |
46 | def __call__(self, data):
47 | col, pseudo = data.edge_index[1], data.edge_attr
48 |
49 | deg = degree(col, data.num_nodes)
50 | deg = deg / deg.max()
51 | deg = deg[col]
52 | deg = deg.view(-1, 1)
53 |
54 | if pseudo is not None and self.cat:
55 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
56 | data.edge_attr = torch.cat([pseudo, deg.type_as(pseudo)], dim=-1)
57 | else:
58 | data.edge_attr = deg
59 |
60 | return data
61 |
62 | def __repr__(self):
63 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
64 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/linear_transformation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LinearTransformation(object):
5 | r"""Transforms node positions with a square transformation matrix computed
6 | offline.
7 |
8 | Args:
9 | matrix (Tensor): tensor with shape :math:`[D, D]` where :math:`D`
10 | corresponds to the dimensionality of node positions.
11 |
12 | .. testsetup::
13 |
14 | import torch
15 | from torch_geometric.data import Data
16 |
17 | .. testcode::
18 |
19 | from torch_geometric.transforms import LinearTransformation
20 |
21 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float)
22 | data = Data(pos=pos)
23 |
24 | matrix = torch.tensor([[2, 0], [0, 2]], dtype=torch.float)
25 | data = LinearTransformation(matrix)(data)
26 |
27 | print(data.pos)
28 |
29 | .. testoutput::
30 |
31 | tensor([[-2., 2.],
32 | [-6., 0.],
33 | [ 4., -2.]])
34 | """
35 |
36 | def __init__(self, matrix):
37 | assert matrix.dim() == 2, (
38 | 'Transformation matrix should be two-dimensional.')
39 | assert matrix.size(0) == matrix.size(1), (
40 | 'Transformation matrix should be square. Got [{} x {}] rectangular'
41 |             ' matrix.'.format(*matrix.size()))
42 |
43 | self.matrix = matrix
44 |
45 | def __call__(self, data):
46 | pos = data.pos.view(-1, 1) if data.pos.dim() == 1 else data.pos
47 |
48 | assert pos.size(1) == self.matrix.size(0), (
49 | 'Node position matrix and transformation matrix have incompatible '
50 | 'shape.')
51 |
52 | data.pos = torch.mm(pos, self.matrix.to(pos.dtype).to(pos.device))
53 |
54 | return data
55 |
56 | def __repr__(self):
57 | return '{}({})'.format(self.__class__.__name__, self.matrix.tolist())
58 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/gin_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter_add
3 | from torch_geometric.utils import remove_self_loops
4 |
5 | from ..inits import reset
6 |
7 |
8 | class GINConv(torch.nn.Module):
9 | r"""The graph isomorphism operator from the `"How Powerful are
10 | Graph Neural Networks?" <https://arxiv.org/abs/1810.00826>`_ paper
11 |
12 | .. math::
13 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
14 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),
15 |
16 | where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP.
17 |
18 | Args:
19 |         nn (nn.Sequential): A neural network :math:`h_{\mathbf{\Theta}}`, e.g. an MLP.
20 | eps (float, optional): (Initial) :math:`\epsilon` value.
21 | (default: :obj:`0`)
22 |         train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` will
23 | be a trainable parameter. (default: :obj:`False`)
24 | """
25 |
26 | def __init__(self, nn, eps=0, train_eps=False):
27 | super(GINConv, self).__init__()
28 | self.nn = nn
29 | self.initial_eps = eps
30 | if train_eps:
31 | self.eps = torch.nn.Parameter(torch.Tensor([eps]))
32 | else:
33 | self.register_buffer('eps', torch.Tensor([eps]))
34 | self.reset_parameters()
35 |
36 | def reset_parameters(self):
37 | reset(self.nn)
38 | self.eps.data.fill_(self.initial_eps)
39 |
40 | def forward(self, x, edge_index):
41 | """"""
42 | x = x.unsqueeze(-1) if x.dim() == 1 else x
43 | edge_index, _ = remove_self_loops(edge_index)
44 | row, col = edge_index
45 |
46 | out = scatter_add(x[col], row, dim=0, dim_size=x.size(0))
47 | out = (1 + self.eps) * x + out
48 | out = self.nn(out)
49 | return out
50 |
51 | def __repr__(self):
52 | return '{}({})'.format(self.__class__.__name__, self.nn)
53 |
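
A minimal usage sketch of the operator above (not part of the file; the toy graph and the MLP are illustrative, and the package-level import is assumed to mirror the other operators):

    import torch
    from torch_geometric.nn import GINConv  # assumed package-level export

    x = torch.rand(3, 2)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    # h_Theta maps the 2-dimensional aggregated features to 16 dimensions.
    mlp = torch.nn.Sequential(torch.nn.Linear(2, 16), torch.nn.ReLU(),
                              torch.nn.Linear(16, 16))
    conv = GINConv(mlp, eps=0, train_eps=True)
    out = conv(x, edge_index)  # shape [3, 16]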
--------------------------------------------------------------------------------
/examples/agnn.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import Planetoid
6 | import torch_geometric.transforms as T
7 | from torch_geometric.nn import AGNNProp
8 |
9 | dataset = 'Cora'
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
11 | data = Planetoid(path, dataset, T.NormalizeFeatures())[0]
12 |
13 |
14 | class Net(torch.nn.Module):
15 | def __init__(self):
16 | super(Net, self).__init__()
17 | self.fc1 = torch.nn.Linear(data.num_features, 16)
18 | self.prop1 = AGNNProp(requires_grad=False)
19 | self.prop2 = AGNNProp(requires_grad=True)
20 | self.fc2 = torch.nn.Linear(16, data.num_classes)
21 |
22 | def forward(self):
23 | x = F.dropout(data.x, training=self.training)
24 | x = F.relu(self.fc1(x))
25 | x = self.prop1(x, data.edge_index)
26 | x = self.prop2(x, data.edge_index)
27 | x = F.dropout(x, training=self.training)
28 | x = self.fc2(x)
29 | return F.log_softmax(x, dim=1)
30 |
31 |
32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33 | model, data = Net().to(device), data.to(device)
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
35 |
36 |
37 | def train():
38 | model.train()
39 | optimizer.zero_grad()
40 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
41 | optimizer.step()
42 |
43 |
44 | def test():
45 | model.eval()
46 | logits, accs = model(), []
47 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
48 | pred = logits[mask].max(1)[1]
49 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
50 | accs.append(acc)
51 | return accs
52 |
53 |
54 | for epoch in range(1, 101):
55 | train()
56 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
57 | print(log.format(epoch, *test()))
58 |
--------------------------------------------------------------------------------
/torch_geometric/nn/prop/agnn_prop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_sparse import spmm
4 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
5 |
6 |
7 | class AGNNProp(torch.nn.Module):
8 | """Graph Attentional Propagation Layer from the
9 | `"Attention-based Graph Neural Network for Semi-Supervised Learning (AGNN)"
10 | <https://arxiv.org/abs/1803.03735>`_ paper.
11 |
12 | Args:
13 | requires_grad (bool, optional): If set to :obj:`False`, the propagation
14 | layer will not be trainable. (default: :obj:`True`)
15 | """
16 |
17 | def __init__(self, requires_grad=True):
18 | super(AGNNProp, self).__init__()
19 |
20 | if requires_grad:
21 | self.beta = Parameter(torch.Tensor(1))
22 | else:
23 | self.register_buffer('beta', torch.ones(1))
24 |
25 | self.requires_grad = requires_grad
26 | self.reset_parameters()
27 |
28 | def reset_parameters(self):
29 | if self.requires_grad:
30 | self.beta.data.uniform_(0, 1)
31 |
32 | def forward(self, x, edge_index):
33 | num_nodes = x.size(0)
34 |
35 | x = x.unsqueeze(-1) if x.dim() == 1 else x
36 | beta = self.beta if self.requires_grad else self._buffers['beta']
37 |
38 | # Add self-loops to adjacency matrix.
39 |         edge_index, _ = remove_self_loops(edge_index)
40 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0))
41 | row, col = edge_index
42 |
43 | # Compute attention coefficients.
44 | norm = torch.norm(x, p=2, dim=1)
45 | alpha = (x[row] * x[col]).sum(dim=1) / (norm[row] * norm[col])
46 | alpha = softmax(alpha * beta, row, num_nodes=x.size(0))
47 |
48 | # Perform the propagation.
49 | out = spmm(edge_index, alpha, num_nodes, x)
50 |
51 | return out
52 |
53 | def __repr__(self):
54 | return '{}()'.format(self.__class__.__name__)
55 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/faust.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import shutil
3 |
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, extract_zip
6 | from torch_geometric.read import read_ply
7 |
8 |
9 | class FAUST(InMemoryDataset):
10 | url = 'http://faust.is.tue.mpg.de/'
11 |
12 | def __init__(self,
13 | root,
14 | train=True,
15 | transform=None,
16 | pre_transform=None,
17 | pre_filter=None):
18 | super(FAUST, self).__init__(root, transform, pre_transform, pre_filter)
19 | path = self.processed_paths[0] if train else self.processed_paths[1]
20 | self.data, self.slices = torch.load(path)
21 |
22 | @property
23 | def raw_file_names(self):
24 | return 'MPI-FAUST.zip'
25 |
26 | @property
27 | def processed_file_names(self):
28 | return ['training.pt', 'test.pt']
29 |
30 | def download(self):
31 | raise RuntimeError(
32 | 'Dataset not found. Please download MPI-FAUST.zip from {} and '
33 | 'move it to {}'.format(self.url, self.raw_dir))
34 |
35 | def process(self):
36 | extract_zip(self.raw_paths[0], self.raw_dir, log=False)
37 |
38 | path = osp.join(self.raw_dir, 'MPI-FAUST', 'training', 'registrations')
39 | path = osp.join(path, 'tr_reg_{0:03d}.ply')
40 | data_list = []
41 | for i in range(100):
42 | data = read_ply(path.format(i))
43 | data.y = torch.tensor([i % 10], dtype=torch.long)
44 | if self.pre_filter is not None and not self.pre_filter(data):
45 | continue
46 | if self.pre_transform is not None:
47 | data = self.pre_transform(data)
48 | data_list.append(data)
49 |
50 | torch.save(self.collate(data_list[:80]), self.processed_paths[0])
51 | torch.save(self.collate(data_list[80:]), self.processed_paths[1])
52 |
53 | shutil.rmtree(osp.join(self.raw_dir, 'MPI-FAUST'))
54 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/to_dense.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ToDense(object):
5 | def __init__(self, num_nodes=None):
6 | self.num_nodes = num_nodes
7 |
8 | def __call__(self, data):
9 | assert data.edge_index is not None
10 |
11 | orig_num_nodes = data.num_nodes
12 | if self.num_nodes is None:
13 | num_nodes = orig_num_nodes
14 | else:
15 | assert orig_num_nodes <= self.num_nodes
16 | num_nodes = self.num_nodes
17 |
18 | if data.edge_attr is None:
19 | edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float)
20 | else:
21 | edge_attr = data.edge_attr
22 |
23 | size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:])
24 | adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size)
25 | data.adj = adj.to_dense()
26 | data.edge_index = None
27 | data.edge_attr = None
28 |
29 | data.mask = torch.zeros(num_nodes, dtype=torch.uint8)
30 | data.mask[:orig_num_nodes] = 1
31 |
32 | if data.x is not None:
33 | size = [num_nodes - data.x.size(0)] + list(data.x.size())[1:]
34 | data.x = torch.cat([data.x, data.x.new_zeros(size)], dim=0)
35 |
36 | if data.pos is not None:
37 | size = [num_nodes - data.pos.size(0)] + list(data.pos.size())[1:]
38 | data.pos = torch.cat([data.pos, data.pos.new_zeros(size)], dim=0)
39 |
40 | if data.y is not None and (data.y.size(0) == orig_num_nodes):
41 | size = [num_nodes - data.y.size(0)] + list(data.y.size())[1:]
42 | data.y = torch.cat([data.y, data.y.new_zeros(size)], dim=0)
43 |
44 | return data
45 |
46 | def __repr__(self):
47 | if self.num_nodes is None:
48 | return '{}()'.format(self.__class__.__name__)
49 | else:
50 | return '{}(num_nodes={})'.format(self.__class__.__name__,
51 | self.num_nodes)
52 |
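
A minimal sketch of what the transform above produces (not part of the file; toy graph padded to ``num_nodes=4``, imports as exported in ``transforms/__init__.py``):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import ToDense

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    x = torch.rand(3, 5)
    data = ToDense(num_nodes=4)(Data(x=x, edge_index=edge_index))

    data.adj   # dense [4, 4] adjacency, zero-padded
    data.mask  # tensor([1, 1, 1, 0], dtype=torch.uint8)
    data.x     # [4, 5] node features, zero-padded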
--------------------------------------------------------------------------------
/torch_geometric/transforms/cartesian.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Cartesian(object):
5 | r"""Saves the globally normalized spatial relation of linked nodes as
6 | Cartesian coordinates (mapped to the fixed interval :math:`[0, 1]`)
7 |
8 | .. math::
9 | \mathbf{u}(i,j) = 0.5 + \frac{\mathbf{pos}_j - \mathbf{pos}_i}{2 \cdot
10 | \max_{(v, w) \in \mathcal{E}} | \mathbf{pos}_w - \mathbf{pos}_v|}
11 |
12 | in its edge attributes.
13 |
14 | Args:
15 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
16 | instead of replacing them. (default: :obj:`True`)
17 |
18 | .. testsetup::
19 |
20 | import torch
21 | from torch_geometric.data import Data
22 |
23 | .. testcode::
24 |
25 | from torch_geometric.transforms import Cartesian
26 |
27 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
28 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
29 | data = Data(edge_index=edge_index, pos=pos)
30 |
31 | data = Cartesian()(data)
32 |
33 | print(data.edge_attr)
34 |
35 | .. testoutput::
36 |
37 | tensor([[0.7500, 0.5000],
38 | [0.2500, 0.5000],
39 | [1.0000, 0.5000],
40 | [0.0000, 0.5000]])
41 | """
42 |
43 | def __init__(self, cat=True):
44 | self.cat = cat
45 |
46 | def __call__(self, data):
47 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
48 |
49 | cart = pos[col] - pos[row]
50 | cart = cart / (2 * cart.abs().max())
51 | cart += 0.5
52 | cart = cart.view(-1, 1) if cart.dim() == 1 else cart
53 |
54 | if pseudo is not None and self.cat:
55 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
56 | data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
57 | else:
58 | data.edge_attr = cart
59 |
60 | return data
61 |
62 | def __repr__(self):
63 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
64 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/polar.py:
--------------------------------------------------------------------------------
1 | from math import pi as PI
2 |
3 | import torch
4 |
5 |
6 | class Polar(object):
7 | r"""Saves the globally normalized two-dimensional spatial relation of
8 | linked nodes as polar coordinates (mapped to the fixed interval
9 | :math:`[0, 1]`) in its edge attributes.
10 |
11 | Args:
12 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
13 | instead of replacing them. (default: :obj:`True`)
14 |
15 | .. testsetup::
16 |
17 | import torch
18 | from torch_geometric.data import Data
19 |
20 | .. testcode::
21 |
22 | from torch_geometric.transforms import Polar
23 |
24 | pos = torch.tensor([[-1, 0], [0, 0], [0, 2]], dtype=torch.float)
25 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
26 | data = Data(edge_index=edge_index, pos=pos)
27 |
28 | data = Polar()(data)
29 |
30 | print(data.edge_attr)
31 |
32 | .. testoutput::
33 |
34 | tensor([[0.5000, 0.0000],
35 | [0.5000, 0.5000],
36 | [1.0000, 0.2500],
37 | [1.0000, 0.7500]])
38 | """
39 |
40 | def __init__(self, cat=True):
41 | self.cat = cat
42 |
43 | def __call__(self, data):
44 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
45 | assert pos.dim() == 2 and pos.size(1) == 2
46 |
47 | cart = pos[col] - pos[row]
48 | rho = torch.norm(cart, p=2, dim=-1)
49 | rho = rho / rho.max()
50 | theta = torch.atan2(cart[..., 1], cart[..., 0]) / (2 * PI)
51 | theta += (theta < 0).type_as(theta)
52 | polar = torch.stack([rho, theta], dim=1)
53 |
54 | if pseudo is not None and self.cat:
55 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
56 | data.edge_attr = torch.cat([pseudo, polar.type_as(pos)], dim=-1)
57 | else:
58 | data.edge_attr = polar
59 |
60 | return data
61 |
62 | def __repr__(self):
63 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
64 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/spherical.py:
--------------------------------------------------------------------------------
1 | from math import pi as PI
2 |
3 | import torch
4 |
5 |
6 | class Spherical(object):
7 | r"""Saves the globally normalized three-dimensional spatial relation of
8 | linked nodes as spherical coordinates (mapped to the fixed interval
9 | :math:`[0, 1]`) in its edge attributes.
10 |
11 | Args:
12 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
13 | instead of replacing them. (default: :obj:`True`)
14 |
15 | .. testsetup::
16 |
17 | import torch
18 | from torch_geometric.data import Data
19 |
20 | .. testcode::
21 |
22 | from torch_geometric.transforms import Spherical
23 |
24 | pos = torch.tensor([[0, 0, 0], [0, 1, 1]], dtype=torch.float)
25 | edge_index = torch.tensor([[0, 1], [1, 0]])
26 | data = Data(edge_index=edge_index, pos=pos)
27 |
28 | data = Spherical()(data)
29 |
30 | print(data.edge_attr)
31 |
32 | .. testoutput::
33 |
34 | tensor([[1.0000, 0.2500, 0.0000],
35 | [1.0000, 0.7500, 1.0000]])
36 | """
37 |
38 | def __init__(self, cat=True):
39 | self.cat = cat
40 |
41 | def __call__(self, data):
42 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
43 | assert pos.dim() == 2 and pos.size(1) == 3
44 |
45 | cart = pos[col] - pos[row]
46 | rho = torch.norm(cart, p=2, dim=-1)
47 | rho = rho / rho.max()
48 | theta = torch.atan2(cart[..., 1], cart[..., 0]) / (2 * PI)
49 | theta += (theta < 0).type_as(theta)
50 | phi = torch.acos(cart[..., 2] / rho) / PI
51 | spher = torch.stack([rho, theta, phi], dim=1)
52 |
53 | if pseudo is not None and self.cat:
54 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
55 | data.edge_attr = torch.cat([pseudo, spher.type_as(pos)], dim=-1)
56 | else:
57 | data.edge_attr = spher
58 |
59 | return data
60 |
61 | def __repr__(self):
62 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
63 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/tu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip
6 | from torch_geometric.read import read_tu_data
7 |
8 |
9 | class TUDataset(InMemoryDataset):
10 | url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/' \
11 | 'graphkerneldatasets'
12 |
13 | def __init__(self,
14 | root,
15 | name,
16 | transform=None,
17 | pre_transform=None,
18 | pre_filter=None):
19 | self.name = name
20 | super(TUDataset, self).__init__(root, transform, pre_transform,
21 | pre_filter)
22 | self.data, self.slices = torch.load(self.processed_paths[0])
23 |
24 | @property
25 | def raw_file_names(self):
26 | names = ['A', 'graph_indicator']
27 | return ['{}_{}.txt'.format(self.name, name) for name in names]
28 |
29 | @property
30 | def processed_file_names(self):
31 | return 'data.pt'
32 |
33 | def download(self):
34 | path = download_url('{}/{}.zip'.format(self.url, self.name), self.root)
35 | extract_zip(path, self.root)
36 | os.unlink(path)
37 | os.rename(osp.join(self.root, self.name), self.raw_dir)
38 |
39 | def process(self):
40 | self.data, self.slices = read_tu_data(self.raw_dir, self.name)
41 |
42 | if self.pre_filter is not None:
43 | data_list = [self.get(idx) for idx in range(len(self))]
44 | data_list = [data for data in data_list if self.pre_filter(data)]
45 | self.data, self.slices = self.collate(data_list)
46 |
47 | if self.pre_transform is not None:
48 | data_list = [self.get(idx) for idx in range(len(self))]
49 | data_list = [self.pre_transform(data) for data in data_list]
50 | self.data, self.slices = self.collate(data_list)
51 |
52 | torch.save((self.data, self.slices), self.processed_paths[0])
53 |
54 | def __repr__(self):
55 | return '{}({})'.format(self.name, len(self))
56 |
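
A hedged usage sketch (not part of the file): ``'MUTAG'`` is one of the benchmark names hosted at the URL above, the local root path is arbitrary, and the package-level import is assumed to mirror ``Planetoid`` in the examples.

    from torch_geometric.datasets import TUDataset  # assumed package-level export

    dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

    len(dataset)       # number of graphs
    data = dataset[0]  # a single `Data` object (x, edge_index, edge_attr, y)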
--------------------------------------------------------------------------------
/examples/cora.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import Planetoid
6 | import torch_geometric.transforms as T
7 | from torch_geometric.nn import SplineConv
8 |
9 | dataset = 'Cora'
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
11 | data = Planetoid(path, dataset, T.TargetIndegree())[0]
12 |
13 | data.train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
14 | data.train_mask[:data.num_nodes - 1000] = 1
15 | data.val_mask = None
16 | data.test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
17 | data.test_mask[data.num_nodes - 500:] = 1
18 |
19 |
20 | class Net(torch.nn.Module):
21 | def __init__(self):
22 | super(Net, self).__init__()
23 | self.conv1 = SplineConv(data.num_features, 16, dim=1, kernel_size=2)
24 | self.conv2 = SplineConv(16, data.num_classes, dim=1, kernel_size=2)
25 |
26 | def forward(self):
27 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
28 | x = F.dropout(x, training=self.training)
29 | x = F.elu(self.conv1(x, edge_index, edge_attr))
30 | x = F.dropout(x, training=self.training)
31 | x = self.conv2(x, edge_index, edge_attr)
32 | return F.log_softmax(x, dim=1)
33 |
34 |
35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36 | model, data = Net().to(device), data.to(device)
37 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)
38 |
39 |
40 | def train():
41 | model.train()
42 | optimizer.zero_grad()
43 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
44 | optimizer.step()
45 |
46 |
47 | def test():
48 | model.eval()
49 | logits, accs = model(), []
50 | for _, mask in data('train_mask', 'test_mask'):
51 | pred = logits[mask].max(1)[1]
52 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
53 | accs.append(acc)
54 | return accs
55 |
56 |
57 | for epoch in range(1, 201):
58 | train()
59 | log = 'Epoch: {:03d}, Train: {:.4f}, Test: {:.4f}'
60 | print(log.format(epoch, *test()))
61 |
--------------------------------------------------------------------------------
/examples/gcn.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import Planetoid
6 | import torch_geometric.transforms as T
7 | from torch_geometric.nn import GCNConv, ChebConv # noqa
8 |
9 | dataset = 'Cora'
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
11 | data = Planetoid(path, dataset, T.NormalizeFeatures())[0]
12 |
13 |
14 | class Net(torch.nn.Module):
15 | def __init__(self):
16 | super(Net, self).__init__()
17 | self.conv1 = GCNConv(data.num_features, 16, improved=False)
18 | self.conv2 = GCNConv(16, data.num_classes, improved=False)
19 | # self.conv1 = ChebConv(data.num_features, 16, K=2)
20 |         # self.conv2 = ChebConv(16, data.num_classes, K=2)
21 |
22 | def forward(self):
23 | x, edge_index = data.x, data.edge_index
24 | x = F.relu(self.conv1(x, edge_index))
25 | x = F.dropout(x, training=self.training)
26 | x = self.conv2(x, edge_index)
27 | return F.log_softmax(x, dim=1)
28 |
29 |
30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31 | model, data = Net().to(device), data.to(device)
32 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
33 |
34 |
35 | def train():
36 | model.train()
37 | optimizer.zero_grad()
38 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
39 | optimizer.step()
40 |
41 |
42 | def test():
43 | model.eval()
44 | logits, accs = model(), []
45 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
46 | pred = logits[mask].max(1)[1]
47 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
48 | accs.append(acc)
49 | return accs
50 |
51 |
52 | best_val_acc = test_acc = 0
53 | for epoch in range(1, 101):
54 | train()
55 | train_acc, val_acc, tmp_test_acc = test()
56 | if val_acc > best_val_acc:
57 | best_val_acc = val_acc
58 | test_acc = tmp_test_acc
59 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
60 | print(log.format(epoch, train_acc, best_val_acc, test_acc))
61 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/set2set.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Set2Set(torch.nn.Module):
5 | def __init__(self,
6 | in_channels,
7 | hidden_channels,
8 | processing_steps,
9 | num_layers=1):
10 | super(Set2Set, self).__init__()
11 |
12 | self.in_channels = in_channels
13 | self.out_channels = 2 * in_channels
14 |         self.hidden_channels = in_channels  # forward() needs hidden == in_channels
15 | self.processing_steps = processing_steps
16 | self.num_layers = num_layers
17 |
18 |         self.lstm = torch.nn.LSTM(self.out_channels, self.hidden_channels,
19 |                                   num_layers)
20 |
21 | self.reset_parameters()
22 |
23 | def reset_parameters(self):
24 | self.lstm.reset_parameters()
25 |
26 | def forward(self, x, batch):
27 | """"""
28 | batch_size = batch.max().item() + 1
29 |
30 | # Bring x into shape [batch_size, max_nodes, in_channels].
31 | xs = x.split(torch.bincount(batch).tolist())
32 | max_nodes = max([t.size(0) for t in xs])
33 | xs = [[t, t.new_zeros(max_nodes - t.size(0), t.size(1))] for t in xs]
34 | xs = [torch.cat(t, dim=0) for t in xs]
35 | x = torch.stack(xs, dim=0)
36 |
37 | h = (x.new_zeros((self.num_layers, batch_size, self.hidden_channels)),
38 | x.new_zeros((self.num_layers, batch_size, self.hidden_channels)))
39 | q_star = x.new_zeros(1, batch_size, self.out_channels)
40 |
41 | for i in range(self.processing_steps):
42 | q, h = self.lstm(q_star, h)
43 | q = q.view(batch_size, 1, self.in_channels)
44 | e = (x * q).sum(dim=-1) # Dot product.
45 | a = torch.softmax(e, dim=-1)
46 | a = a.view(batch_size, max_nodes, 1)
47 | r = (a * x).sum(dim=1, keepdim=True)
48 | q_star = torch.cat([q, r], dim=-1)
49 | q_star = q_star.view(1, batch_size, self.out_channels)
50 |
51 | q_star = q_star.view(batch_size, self.out_channels)
52 | return q_star
53 |
54 | def __repr__(self):
55 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
56 | self.out_channels)
57 |
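
A minimal usage sketch (not part of the file; toy values, importing from the module path shown). As noted in the constructor, the attention step requires the LSTM hidden size to equal ``in_channels``.

    import torch
    from torch_geometric.nn.pool.set2set import Set2Set

    x = torch.rand(6, 16)                     # 6 nodes with 16 features
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs of 3 nodes each

    pool = Set2Set(16, 16, processing_steps=3)
    out = pool(x, batch)                      # shape [2, 32] (= 2 * in_channels)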
--------------------------------------------------------------------------------
/torch_geometric/transforms/local_cartesian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter_max
3 |
4 |
5 | class LocalCartesian(object):
6 | r"""Saves the locally normalized spatial relation of linked nodes as
7 | Cartesian coordinates (mapped to the fixed interval :math:`[0, 1]`)
8 |
9 | .. math::
10 | \mathbf{u}(i,j) = 0.5 + \frac{\mathbf{pos}_j - \mathbf{pos}_i}{2 \cdot
11 | \max_{v \in \mathcal{N}(i)} | \mathbf{pos}_v - \mathbf{pos}_i|}
12 |
13 | in its edge attributes.
14 |
15 | Args:
16 | cat (bool, optional): Concat pseudo-coordinates to edge attributes
17 | instead of replacing them. (default: :obj:`True`)
18 |
19 | .. testsetup::
20 |
21 | import torch
22 | from torch_geometric.data import Data
23 |
24 | .. testcode::
25 |
26 | from torch_geometric.transforms import LocalCartesian
27 |
28 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
29 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
30 | data = Data(edge_index=edge_index, pos=pos)
31 |
32 | data = LocalCartesian()(data)
33 |
34 | print(data.edge_attr)
35 |
36 | .. testoutput::
37 |
38 | tensor([[1.0000, 0.5000],
39 | [0.2500, 0.5000],
40 | [1.0000, 0.5000],
41 | [0.0000, 0.5000]])
42 | """
43 |
44 | def __init__(self, cat=True):
45 | self.cat = cat
46 |
47 | def __call__(self, data):
48 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
49 |
50 | cart = pos[col] - pos[row]
51 | max_cart, _ = scatter_max(cart.abs(), row, 0, dim_size=pos.size(0))
52 | cart = cart / (2 * max_cart.max(dim=1, keepdim=True)[0][row])
53 | cart += 0.5
54 | cart = cart.view(-1, 1) if cart.dim() == 1 else cart
55 |
56 | if pseudo is not None and self.cat:
57 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
58 | data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
59 | else:
60 | data.edge_attr = cart
61 |
62 | return data
63 |
64 | def __repr__(self):
65 | return '{}(cat={})'.format(self.__class__.__name__, self.cat)
66 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/gcn_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 |
4 | from ..inits import uniform
5 | from ..prop import GCNProp
6 |
7 |
8 | class GCNConv(torch.nn.Module):
9 | r"""The graph convolutional operator from the `"Semi-supervised
10 | Classification with Graph Convolutional Networks"
11 | <https://arxiv.org/abs/1609.02907>`_ paper
12 |
13 | .. math::
14 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
15 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
16 |
17 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
18 | adjacency matrix with inserted self-loops and
19 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
20 |
21 | Args:
22 | in_channels (int): Size of each input sample.
23 | out_channels (int): Size of each output sample.
24 | improved (bool, optional): If set to :obj:`True`, the layer computes
25 | :math:`\hat{A}` as :math:`A + 2I`. (default: :obj:`False`)
26 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
27 | an additive bias. (default: :obj:`True`)
28 | """
29 |
30 | def __init__(self, in_channels, out_channels, improved=False, bias=True):
31 | super(GCNConv, self).__init__()
32 |
33 | self.in_channels = in_channels
34 | self.out_channels = out_channels
35 | self.prop = GCNProp(improved)
36 | self.weight = Parameter(torch.Tensor(in_channels, out_channels))
37 |
38 | if bias:
39 | self.bias = Parameter(torch.Tensor(out_channels))
40 | else:
41 | self.register_parameter('bias', None)
42 |
43 | self.reset_parameters()
44 |
45 | def reset_parameters(self):
46 | size = self.in_channels
47 | uniform(size, self.weight)
48 |         uniform(size, self.bias)
49 |
50 | def forward(self, x, edge_index, edge_attr=None):
51 | """"""
52 | out = torch.mm(x, self.weight)
53 | out = self.prop(out, edge_index, edge_attr)
54 |
55 | if self.bias is not None:
56 | out = out + self.bias
57 |
58 | return out
59 |
60 | def __repr__(self):
61 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
62 | self.out_channels)
63 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/graph_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.utils import remove_self_loops, scatter_
4 |
5 | from ..inits import uniform
6 |
7 |
8 | class GraphConv(torch.nn.Module):
9 | r"""The graph neural network operator from the `"Weisfeiler and Leman Go
10 | Neural: Higher-order Graph Neural Networks"
11 | <https://arxiv.org/abs/1810.02244>`_ paper
12 |
13 | .. math::
14 | \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{(1)} \mathbf{x}_i +
15 | \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}^{(2)} \mathbf{x}_j.
16 |
17 | Args:
18 | in_channels (int): Size of each input sample.
19 | out_channels (int): Size of each output sample.
20 | aggr (string): The aggregation operator to use (one of :obj:`"add"`,
21 | :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`)
22 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
23 | an additive bias. (default: :obj:`True`)
24 | """
25 |
26 | def __init__(self, in_channels, out_channels, aggr='add', bias=True):
27 | super(GraphConv, self).__init__()
28 |
29 | self.in_channels = in_channels
30 | self.out_channels = out_channels
31 | self.aggr = aggr
32 | self.weight = Parameter(torch.Tensor(in_channels, out_channels))
33 | self.root = Parameter(torch.Tensor(in_channels, out_channels))
34 |
35 | if bias:
36 | self.bias = Parameter(torch.Tensor(out_channels))
37 | else:
38 | self.register_parameter('bias', None)
39 |
40 | self.reset_parameters()
41 |
42 | def reset_parameters(self):
43 | size = self.in_channels
44 | uniform(size, self.weight)
45 | uniform(size, self.root)
46 | uniform(size, self.bias)
47 |
48 | def forward(self, x, edge_index):
49 | """"""
50 | x = x.unsqueeze(-1) if x.dim() == 1 else x
51 | edge_index, _ = remove_self_loops(edge_index)
52 | row, col = edge_index
53 |
54 | out = torch.mm(x, self.weight)
55 | out = scatter_(self.aggr, out[col], row, dim_size=x.size(0))
56 | out = out + torch.mm(x, self.root)
57 |
58 | if self.bias is not None:
59 | out = out + self.bias
60 |
61 | return out
62 |
63 | def __repr__(self):
64 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
65 | self.out_channels)
66 |
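
A minimal usage sketch of the operator above (not part of the file; toy values, package-level import assumed to mirror the other operators):

    import torch
    from torch_geometric.nn import GraphConv  # assumed package-level export

    x = torch.rand(3, 8)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    conv = GraphConv(8, 16, aggr='mean')
    out = conv(x, edge_index)  # shape [3, 16]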
--------------------------------------------------------------------------------
/torch_geometric/data/dataset.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os.path as osp
3 |
4 | import torch.utils.data
5 |
6 | from .makedirs import makedirs
7 |
8 |
9 | def to_list(x):
10 | if not isinstance(x, collections.Iterable) or isinstance(x, str):
11 | x = [x]
12 | return x
13 |
14 |
15 | def files_exist(files):
16 | return all([osp.exists(f) for f in files])
17 |
18 |
19 | class Dataset(torch.utils.data.Dataset):
20 | @property
21 | def raw_file_names(self):
22 | raise NotImplementedError
23 |
24 | @property
25 | def processed_file_names(self):
26 | raise NotImplementedError
27 |
28 | def download(self):
29 | raise NotImplementedError
30 |
31 | def process(self):
32 | raise NotImplementedError
33 |
34 | def __len__(self):
35 | raise NotImplementedError
36 |
37 | def get(self, idx):
38 | raise NotImplementedError
39 |
40 | def __init__(self,
41 | root,
42 | transform=None,
43 | pre_transform=None,
44 | pre_filter=None):
45 | super(Dataset, self).__init__()
46 |
47 | self.root = osp.expanduser(osp.normpath(root))
48 | self.raw_dir = osp.join(self.root, 'raw')
49 | self.processed_dir = osp.join(self.root, 'processed')
50 | self.transform = transform
51 | self.pre_transform = pre_transform
52 | self.pre_filter = pre_filter
53 |
54 | self._download()
55 | self._process()
56 |
57 | @property
58 | def raw_paths(self):
59 | files = to_list(self.raw_file_names)
60 | return [osp.join(self.raw_dir, f) for f in files]
61 |
62 | @property
63 | def processed_paths(self):
64 | files = to_list(self.processed_file_names)
65 | return [osp.join(self.processed_dir, f) for f in files]
66 |
67 | def _download(self):
68 | if files_exist(self.raw_paths):
69 | return
70 |
71 | makedirs(self.raw_dir)
72 | self.download()
73 |
74 | def _process(self):
75 | if files_exist(self.processed_paths):
76 | return
77 |
78 | print('Processing...')
79 |
80 | makedirs(self.processed_dir)
81 | self.process()
82 |
83 | print('Done!')
84 |
85 | def __getitem__(self, idx):
86 | data = self.get(idx)
87 | data = data if self.transform is None else self.transform(data)
88 | return data
89 |
90 | def __repr__(self):
91 | return '{}({})'.format(self.__class__.__name__, len(self))
92 |
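
The base class above encodes the download-then-process workflow. A hedged sketch of a minimal subclass (not part of the file; all names below are illustrative, and the package-level exports of ``Data`` and ``Dataset`` are assumed):

    import torch
    from torch_geometric.data import Data, Dataset  # assumed package-level exports

    class MyDataset(Dataset):
        @property
        def raw_file_names(self):
            return []  # nothing to download in this toy example

        @property
        def processed_file_names(self):
            return ['data_0.pt']

        def download(self):
            pass  # no raw files needed; everything is generated in process()

        def process(self):
            edge_index = torch.tensor([[0, 1], [1, 0]])
            data = Data(x=torch.rand(2, 4), edge_index=edge_index)
            torch.save(data, self.processed_paths[0])

        def __len__(self):
            return 1

        def get(self, idx):
            return torch.load(self.processed_paths[idx])

    dataset = MyDataset(root='/tmp/toy')  # triggers _download() / _process() once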
--------------------------------------------------------------------------------
/torch_geometric/datasets/mnist_superpixels.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch_geometric.data import (InMemoryDataset, Data, download_url,
5 | extract_tar)
6 |
7 |
8 | class MNISTSuperpixels(InMemoryDataset):
9 | url = 'http://ls7-www.cs.uni-dortmund.de/cvpr_geometric_dl/' \
10 | 'mnist_superpixels.tar.gz'
11 |
12 | def __init__(self,
13 | root,
14 | train=True,
15 | transform=None,
16 | pre_transform=None,
17 | pre_filter=None):
18 | super(MNISTSuperpixels, self).__init__(root, transform, pre_transform,
19 | pre_filter)
20 | path = self.processed_paths[0] if train else self.processed_paths[1]
21 | self.data, self.slices = torch.load(path)
22 |
23 | @property
24 | def raw_file_names(self):
25 | return ['training.pt', 'test.pt']
26 |
27 | @property
28 | def processed_file_names(self):
29 | return ['training.pt', 'test.pt']
30 |
31 | def download(self):
32 | path = download_url(self.url, self.raw_dir)
33 | extract_tar(path, self.raw_dir, mode='r')
34 | os.unlink(path)
35 |
36 | def process(self):
37 | for raw_path, path in zip(self.raw_paths, self.processed_paths):
38 | x, edge_index, edge_slice, pos, y = torch.load(raw_path)
39 | edge_index, y = edge_index.to(torch.long), y.to(torch.long)
40 | m, n = y.size(0), 75
41 | x, pos = x.view(m * n, 1), pos.view(m * n, 2)
42 | node_slice = torch.arange(0, (m + 1) * n, step=n, dtype=torch.long)
43 | graph_slice = torch.arange(m + 1, dtype=torch.long)
44 | self.data = Data(x=x, edge_index=edge_index, y=y, pos=pos)
45 | self.slices = {
46 | 'x': node_slice,
47 | 'edge_index': edge_slice,
48 | 'y': graph_slice,
49 | 'pos': node_slice
50 | }
51 |
52 | if self.pre_filter is not None:
53 | data_list = [self.get(idx) for idx in range(len(self))]
54 | data_list = [d for d in data_list if self.pre_filter(d)]
55 | self.data, self.slices = self.collate(data_list)
56 |
57 | if self.pre_transform is not None:
58 | data_list = [self.get(idx) for idx in range(len(self))]
59 | data_list = [self.pre_transform(data) for data in data_list]
60 | self.data, self.slices = self.collate(data_list)
61 |
62 | torch.save((self.data, self.slices), path)
63 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/sage_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Parameter
4 | from torch_scatter import scatter_mean
5 | from torch_geometric.utils import remove_self_loops, add_self_loops
6 |
7 | from ..inits import uniform
8 |
9 |
10 | class SAGEConv(torch.nn.Module):
11 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on
12 | Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper
13 |
14 | .. math::
15 | \mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot
16 | \mathrm{mean}_{j \in \mathcal{N}(i) \cup \{ i \}}(\mathbf{x}_j)
17 |
18 | \mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i}
19 | {\| \mathbf{\hat{x}}_i \|_2}.
20 |
21 | Args:
22 | in_channels (int): Size of each input sample.
23 | out_channels (int): Size of each output sample.
24 | normalize (bool, optional): If set to :obj:`False`, output features
25 | will not be :math:`\ell^2`-normalized. (default: :obj:`True`)
26 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
27 | an additive bias. (default: :obj:`True`)
28 | """
29 |
30 | def __init__(self, in_channels, out_channels, normalize=True, bias=True):
31 | super(SAGEConv, self).__init__()
32 |
33 | self.in_channels = in_channels
34 | self.out_channels = out_channels
35 | self.normalize = normalize
36 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
37 |
38 | if bias:
39 | self.bias = Parameter(torch.Tensor(out_channels))
40 | else:
41 | self.register_parameter('bias', None)
42 |
43 | self.reset_parameters()
44 |
45 | def reset_parameters(self):
46 | size = self.weight.size(0)
47 | uniform(size, self.weight)
48 | uniform(size, self.bias)
49 |
50 | def forward(self, x, edge_index):
51 | """"""
52 | edge_index, _ = remove_self_loops(edge_index)
53 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0))
54 |
55 | x = x.unsqueeze(-1) if x.dim() == 1 else x
56 | row, col = edge_index
57 |
58 | out = scatter_mean(x[col], row, dim=0, dim_size=x.size(0))
59 | out = torch.matmul(out, self.weight)
60 |
61 | if self.bias is not None:
62 | out = out + self.bias
63 |
64 | if self.normalize:
65 | out = F.normalize(out, p=2, dim=-1)
66 |
67 | return out
68 |
69 | def __repr__(self):
70 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
71 | self.out_channels)
72 |
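
A minimal usage sketch for the operator above (not part of the repository file; it assumes `SAGEConv` is re-exported from `torch_geometric.nn` like the other convolution operators, otherwise import it from `torch_geometric.nn.conv.sage_conv`):

    import torch
    from torch_geometric.nn import SAGEConv

    # Toy graph: 4 nodes with 16 features each; edges are given in both directions.
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]], dtype=torch.long)

    conv = SAGEConv(in_channels=16, out_channels=32)
    out = conv(x, edge_index)  # mean over N(i) and i itself, then a linear transform
    print(out.size())          # torch.Size([4, 32]); rows are l2-normalized (normalize=True)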
--------------------------------------------------------------------------------
/torch_geometric/datasets/ppi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 |
5 | import torch
6 | import numpy as np
7 | from torch_sparse import coalesce
8 | from torch_geometric.data import (InMemoryDataset, Data, download_url,
9 | extract_zip)
10 |
11 |
12 | class PPI(InMemoryDataset):
13 | url = 'http://snap.stanford.edu/graphsage/ppi.zip'
14 |
15 | def __init__(self, root, transform=None, pre_transform=None):
16 | super(PPI, self).__init__(root, transform, pre_transform)
17 | self.data, self.slices = torch.load(self.processed_paths[0])
18 |
19 | @property
20 | def raw_file_names(self):
21 | prefix = self.__class__.__name__.lower()
22 | suffix = ['G.json', 'feats.npy', 'class_map.json']
23 | return ['{}-{}'.format(prefix, s) for s in suffix]
24 |
25 | @property
26 | def processed_file_names(self):
27 | return 'data.pt'
28 |
29 | def download(self):
30 | path = download_url(self.url, self.root)
31 | extract_zip(path, self.root)
32 | os.unlink(path)
33 | name = self.__class__.__name__.lower()
34 | os.rename(osp.join(self.root, name), self.raw_dir)
35 |
36 | def process(self):
37 | with open(self.raw_paths[0], 'r') as f:
38 | graph_data = json.load(f)
39 |
40 | mask = torch.zeros(len(graph_data['nodes']), dtype=torch.uint8)
41 | for i in graph_data['nodes']:
42 | mask[i['id']] = 1 if i['val'] else (2 if i['test'] else 0)
43 | train_mask, val_mask, test_mask = mask == 0, mask == 1, mask == 2
44 |
45 | row, col = [], []
46 | for i in graph_data['links']:
47 | row.append(i['source'])
48 | col.append(i['target'])
49 | edge_index = torch.stack([torch.tensor(row), torch.tensor(col)], dim=0)
50 | edge_index, _ = coalesce(edge_index, None, mask.size(0), mask.size(0))
51 |
52 | x = torch.from_numpy(np.load(self.raw_paths[1])).float()
53 |
54 | with open(self.raw_paths[2], 'r') as f:
55 | y_data = json.load(f)
56 |
57 | y = []
58 | for i in range(len(y_data)):
59 | y.append(y_data[str(i)])
60 | y = torch.tensor(y, dtype=torch.float)
61 |
62 | data = Data(x=x, edge_index=edge_index, y=y)
63 | data.train_mask = train_mask
64 | data.val_mask = val_mask
65 | data.test_mask = test_mask
66 |
67 | data = data if self.pre_transform is None else self.pre_transform(data)
68 | data, slices = self.collate([data])
69 | torch.save((data, slices), self.processed_paths[0])
70 |
71 | def __repr__(self):
72 | return '{}()'.format(self.__class__.__name__)
73 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/coma.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from glob import glob
3 |
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, extract_zip
6 | from torch_geometric.read import read_ply
7 |
8 |
9 | class CoMA(InMemoryDataset):
10 | url = 'https://coma.is.tue.mpg.de/'
11 |
12 | categories = [
13 | 'bareteeth',
14 | 'cheeks_in',
15 | 'eyebrow',
16 | 'high_smile',
17 | 'lips_back',
18 | 'lips_up',
19 | 'mouth_down',
20 | 'mouth_extreme',
21 | 'mouth_middle',
22 | 'mouth_open',
23 | 'mouth_side',
24 | 'mouth_up',
25 | ]
26 |
27 | def __init__(self,
28 | root,
29 | train=True,
30 | transform=None,
31 | pre_transform=None,
32 | pre_filter=None):
33 | super(CoMA, self).__init__(root, transform, pre_transform, pre_filter)
34 | path = self.processed_paths[0] if train else self.processed_paths[1]
35 | self.data, self.slices = torch.load(path)
36 |
37 | @property
38 | def raw_file_names(self):
39 | return 'COMA_data.zip'
40 |
41 | @property
42 | def processed_file_names(self):
43 | return ['training.pt', 'test.pt']
44 |
45 | def download(self):
46 | raise RuntimeError(
47 | 'Dataset not found. Please download COMA_data.zip from {} and '
48 | 'move it to {}'.format(self.url, self.raw_dir))
49 |
50 | def process(self):
51 | folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*')))
52 | if len(folders) == 0:
53 | extract_zip(self.raw_paths[0], self.raw_dir, log=False)
54 | folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*')))
55 |
56 | train_data_list, test_data_list = [], []
57 | for folder in folders:
58 | for i, category in enumerate(self.categories):
59 | files = sorted(glob(osp.join(folder, category, '*.ply')))
60 | for j, f in enumerate(files):
61 | data = read_ply(f)
62 | data.y = torch.tensor([i], dtype=torch.long)
63 | if self.pre_filter is not None and\
64 | not self.pre_filter(data):
65 | continue
66 | if self.pre_transform is not None:
67 | data = self.pre_transform(data)
68 |
69 | if (j % 100) < 90:
70 | train_data_list.append(data)
71 | else:
72 | test_data_list.append(data)
73 |
74 | torch.save(self.collate(train_data_list), self.processed_paths[0])
75 | torch.save(self.collate(test_data_list), self.processed_paths[1])
76 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/modelnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import glob
4 |
5 | import torch
6 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip
7 | from torch_geometric.read import read_off
8 |
9 |
10 | class ModelNet(InMemoryDataset):
11 | urls = {
12 | '10':
13 | 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
14 | '40':
15 | 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
16 | }
17 |
18 | def __init__(self,
19 | root,
20 | name='10',
21 | train=True,
22 | transform=None,
23 | pre_transform=None,
24 | pre_filter=None):
25 | assert name in ['10', '40']
26 | self.name = name
27 | super(ModelNet, self).__init__(root, transform, pre_transform,
28 | pre_filter)
29 | path = self.processed_paths[0] if train else self.processed_paths[1]
30 | self.data, self.slices = torch.load(path)
31 |
32 | @property
33 | def raw_file_names(self):
34 | return [
35 | 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
36 | 'night_stand', 'sofa', 'table', 'toilet'
37 | ]
38 |
39 | @property
40 | def processed_file_names(self):
41 | return ['training.pt', 'test.pt']
42 |
43 | def download(self):
44 | path = download_url(self.urls[self.name], self.root)
45 | extract_zip(path, self.root)
46 | os.unlink(path)
47 | folder = osp.join(self.root, 'ModelNet{}'.format(self.name))
48 | os.rename(folder, self.raw_dir)
49 |
50 | def process(self):
51 | torch.save(self.process_set('train'), self.processed_paths[0])
52 | torch.save(self.process_set('test'), self.processed_paths[1])
53 |
54 | def process_set(self, dataset):
55 | categories = glob.glob(osp.join(self.raw_dir, '*', ''))
56 | categories = sorted([x.split('/')[-2] for x in categories])
57 |
58 | data_list = []
59 | for target, category in enumerate(categories):
60 | folder = osp.join(self.raw_dir, category, dataset)
61 | paths = glob.glob('{}/{}_*.off'.format(folder, category))
62 | for path in paths:
63 | data = read_off(path)
64 | data.y = torch.tensor([target])
65 | data_list.append(data)
66 |
67 | if self.pre_filter is not None:
68 | data_list = [d for d in data_list if self.pre_filter(d)]
69 |
70 | if self.pre_transform is not None:
71 | data_list = [self.pre_transform(d) for d in data_list]
72 |
73 | return self.collate(data_list)
74 |
75 | def __repr__(self):
76 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self))
77 |
--------------------------------------------------------------------------------
/examples/mnist_voxel_grid.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import MNISTSuperpixels
6 | import torch_geometric.transforms as T
7 | from torch_geometric.data import DataLoader
8 | from torch_geometric.nn import SplineConv, voxel_grid, max_pool, max_pool_x
9 |
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
11 | transform = T.Cartesian(cat=False)
12 | train_dataset = MNISTSuperpixels(path, True, transform=transform)
13 | test_dataset = MNISTSuperpixels(path, False, transform=transform)
14 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
15 | test_loader = DataLoader(test_dataset, batch_size=64)
16 |
17 |
18 | class Net(torch.nn.Module):
19 | def __init__(self):
20 | super(Net, self).__init__()
21 | self.conv1 = SplineConv(1, 32, dim=2, kernel_size=5)
22 | self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)
23 | self.conv3 = SplineConv(64, 64, dim=2, kernel_size=5)
24 | self.fc1 = torch.nn.Linear(4 * 64, 128)
25 | self.fc2 = torch.nn.Linear(128, 10)
26 |
27 | def forward(self, data):
28 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
29 | cluster = voxel_grid(data.pos, data.batch, size=5, start=0, end=28)
30 | data = max_pool(cluster, data, transform=transform)
31 |
32 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
33 | cluster = voxel_grid(data.pos, data.batch, size=7, start=0, end=28)
34 | data = max_pool(cluster, data, transform=transform)
35 |
36 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
37 | cluster = voxel_grid(data.pos, data.batch, size=14, start=0, end=27.99)
38 | x = max_pool_x(cluster, data.x, data.batch, size=4)
39 |
40 | x = x.view(-1, self.fc1.weight.size(1))
41 | x = F.elu(self.fc1(x))
42 | x = F.dropout(x, training=self.training)
43 | x = self.fc2(x)
44 | return F.log_softmax(x, dim=1)
45 |
46 |
47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48 | model = Net().to(device)
49 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
50 |
51 |
52 | def train(epoch):
53 | model.train()
54 |
55 | if epoch == 6:
56 | for param_group in optimizer.param_groups:
57 | param_group['lr'] = 0.001
58 |
59 | if epoch == 16:
60 | for param_group in optimizer.param_groups:
61 | param_group['lr'] = 0.0001
62 |
63 | for data in train_loader:
64 | data = data.to(device)
65 | optimizer.zero_grad()
66 | F.nll_loss(model(data), data.y).backward()
67 | optimizer.step()
68 |
69 |
70 | def test():
71 | model.eval()
72 | correct = 0
73 |
74 | for data in test_loader:
75 | data = data.to(device)
76 | pred = model(data).max(1)[1]
77 | correct += pred.eq(data.y).sum().item()
78 | return correct / len(test_dataset)
79 |
80 |
81 | for epoch in range(1, 21):
82 | train(epoch)
83 | test_acc = test()
84 | print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
85 |
--------------------------------------------------------------------------------
/docs/source/notes/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | We have outsourced a lot of functionality of `PyTorch Geometric `_ to other packages, which need to be installed in advance.
5 | These packages come with their own CPU and GPU kernel implementations based on the newly introduced `C++/CUDA extensions `_ in PyTorch 0.4.0.
6 |
7 | .. note::
8 | We do not recommend installation as a root user on your system Python.
9 | Please set up an `Anaconda/Miniconda `_ environment or create a `Docker image `_.
10 |
11 | Please follow the steps below for a successful installation:
12 |
13 | #. Ensure that at least PyTorch 0.4.1 is installed:
14 |
15 | .. code-block:: none
16 |
17 | $ python -c "import torch; print(torch.__version__)"
18 | >>> 0.4.1
19 |
20 | #. Ensure CUDA is set up correctly (optional):
21 |
22 | #. Check if PyTorch is installed with CUDA support:
23 |
24 | .. code-block:: none
25 |
26 | $ python -c "import torch; print(torch.cuda.is_available())"
27 | >>> True
28 |
29 | #. Add CUDA to ``$PATH`` and ``$CPATH`` (note that your actual CUDA path may vary from ``/usr/local/cuda``):
30 |
31 | .. code-block:: none
32 |
33 | $ PATH=/usr/local/cuda/bin:$PATH
34 | $ echo $PATH
35 | >>> /usr/local/cuda/bin:...
36 |
37 | $ CPATH=/usr/local/cuda/include:$CPATH
38 | $ echo $CPATH
39 | >>> /usr/local/cuda/include:...
40 |
41 | #. Verify that ``nvcc`` is accessible from the terminal:
42 |
43 | .. code-block:: none
44 |
45 | $ nvcc --version
46 |
47 | #. Install all needed packages:
48 |
49 | .. code-block:: none
50 |
51 | $ pip install --upgrade torch-scatter
52 | $ pip install --upgrade torch-sparse
53 | $ pip install --upgrade torch-cluster
54 | $ pip install --upgrade torch-spline-conv
55 | $ pip install torch-geometric
56 |
57 | In rare cases, CUDA or Python path issues can prevent a successful installation.
58 | Unfortunately, the error messages of ``pip`` are not very meaningful.
59 | You should therefore clone the respective package and check where the error occurs, e.g.:
60 |
61 | .. code-block:: none
62 |
63 | $ git clone https://github.com/rusty1s/pytorch_scatter
64 | $ cd pytorch_scatter
65 | $ python setup.py install # Check for CUDA compilation or link error.
66 | $ python setup.py test # Verify installation by running test suite.
67 |
68 | C++/CUDA extensions on macOS
69 | ----------------------------
70 |
71 | .. note::
72 | As reported by some users, the use of miniconda is absolutely necessary on macOS.
73 |
74 | In order to compile CUDA extensions on macOS, you need to replace the call
75 |
76 | .. code-block:: python
77 |
78 | def spawn(self, cmd):
79 | spawn(cmd, dry_run=self.dry_run)
80 |
81 | with
82 |
83 | .. code-block:: python
84 |
85 | import subprocess
86 |
87 | def spawn(self, cmd):
88 | subprocess.call(cmd)
89 |
90 | in ``lib/python{xxx}/distutils/ccompiler.py``.
91 |
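
If the steps above succeed, a quick sanity check (a minimal sketch, not part of the file above) is to import each compiled extension package once; the module names below simply mirror the pip packages installed earlier:

    # Post-install sanity check: all extension packages should import without errors.
    import torch
    import torch_scatter
    import torch_sparse
    import torch_cluster
    import torch_spline_conv
    import torch_geometric

    print(torch.__version__)
    print(torch.cuda.is_available())  # True if the CUDA kernels are usable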
--------------------------------------------------------------------------------
/docs/build/html/_sources/notes/installation.rst.txt:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | We have outsourced a lot of functionality of `PyTorch Geometric `_ to other packages, which need to be installed in advance.
5 | These packages come with their own CPU and GPU kernel implementations based on `C FFI `_ and the newly introduced `C++/CUDA extensions `_ in PyTorch 0.4.0.
6 |
7 | .. note::
8 | We do not recommend installation as a root user on your system Python.
9 | Please set up an `Anaconda/Miniconda `_ environment or create a `Docker image `_.
10 |
11 | Please follow the steps below for a successful installation:
12 |
13 | #. Ensure that at least PyTorch 0.4.1 is installed:
14 |
15 | .. code-block:: none
16 |
17 | $ python -c "import torch; print(torch.__version__)"
18 | >>> 0.4.1
19 |
20 | #. Ensure CUDA is set up correctly (optional):
21 |
22 | #. Check if PyTorch is installed with CUDA support:
23 |
24 | .. code-block:: none
25 |
26 | $ python -c "import torch; print(torch.cuda.is_available())"
27 | >>> True
28 |
29 | #. Add CUDA to ``$PATH`` and ``$CPATH`` (note that your actual CUDA path may vary from ``/usr/local/cuda``):
30 |
31 | .. code-block:: none
32 |
33 | $ PATH=/usr/local/cuda/bin:$PATH
34 | $ echo $PATH
35 | >>> /usr/local/cuda/bin:...
36 |
37 | $ CPATH=/usr/local/cuda/include:$CPATH
38 | $ echo $CPATH
39 | >>> /usr/local/cuda/include:...
40 |
41 | #. Verify that ``nvcc`` is accessible from the terminal:
42 |
43 | .. code-block:: none
44 |
45 | $ nvcc --version
46 |
47 | #. Install all needed packages:
48 |
49 | .. code-block:: none
50 |
51 | $ pip install cffi
52 | $ pip install --upgrade torch-scatter
53 | $ pip install --upgrade torch-sparse
54 | $ pip install --upgrade torch-cluster
55 | $ pip install --upgrade torch-spline-conv
56 | $ pip install torch-geometric
57 |
58 | In rare cases, CUDA or Python path issues can prevent a successful installation.
59 | Unfortunately, the error messages of ``pip`` are not very meaningful.
60 | You should therefore clone the respective package and check where the error occurs, e.g.:
61 |
62 | .. code-block:: none
63 |
64 | $ git clone https://github.com/rusty1s/pytorch_scatter
65 | $ cd pytorch_scatter
66 | $ python setup.py install # Check for CUDA compilation or link error.
67 | $ python setup.py test # Verify installation by running test suite.
68 |
69 | C++/CUDA extensions on macOS
70 | ----------------------------
71 |
72 | In order to compile C++/CUDA extensions on macOS, you need to replace the call
73 |
74 | .. code-block:: python
75 |
76 | def spawn(self, cmd):
77 | spawn(cmd, dry_run=self.dry_run)
78 |
79 | with
80 |
81 | .. code-block:: python
82 |
83 | def spawn(self, cmd):
84 | subprocess.call(cmd)
85 |
86 | in ``distutils/ccompiler.py``.
87 | Do not forget to ``import subprocess`` at the top of the file.
88 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/nn_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.utils import scatter_
4 |
5 | from ..inits import reset, uniform
6 |
7 |
8 | class NNConv(torch.nn.Module):
9 | r"""The continuous kernel-based convolutional operator adapted from the
10 | `"Neural Message Passing for Quantum Chemistry"
11 | <https://arxiv.org/abs/1704.01212>`_ paper
12 |
13 | .. math::
14 | \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{(1)} \mathbf{x}_i +
15 | \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
16 | h_{\mathbf{\Theta^{(2)}}}(\mathbf{e}_{i,j}),
17 |
18 | where :math:`h_{\mathbf{\Theta^{(2)}}}` denotes a neural network, *i.e.*
19 | an MLP.
20 |
21 | Args:
22 | in_channels (int): Size of each input sample.
23 | out_channels (int): Size of each output sample.
24 | nn (nn.Sequential): Neural network.
25 | aggr (string): The aggregation operator to use (one of :obj:`"add"`,
26 | :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`)
27 | root_weight (bool, optional): If set to :obj:`False`, the layer will
28 | not add the transformed root node features to the output.
29 | (default: :obj:`True`)
30 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
31 | an additive bias. (default: :obj:`True`)
32 | """
33 |
34 | def __init__(self,
35 | in_channels,
36 | out_channels,
37 | nn,
38 | aggr="add",
39 | root_weight=True,
40 | bias=True):
41 | super(NNConv, self).__init__()
42 |
43 | self.in_channels = in_channels
44 | self.out_channels = out_channels
45 | self.nn = nn
46 | self.aggr = aggr
47 |
48 | if root_weight:
49 | self.root = Parameter(torch.Tensor(in_channels, out_channels))
50 | else:
51 | self.register_parameter('root', None)
52 |
53 | if bias:
54 | self.bias = Parameter(torch.Tensor(out_channels))
55 | else:
56 | self.register_parameter('bias', None)
57 |
58 | self.reset_parameters()
59 |
60 | def reset_parameters(self):
61 | reset(self.nn)
62 | size = self.in_channels
63 | uniform(size, self.root)
64 | uniform(size, self.bias)
65 |
66 | def forward(self, x, edge_index, pseudo):
67 | """"""
68 | x = x.unsqueeze(-1) if x.dim() == 1 else x
69 | pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
70 | row, col = edge_index
71 |
72 | out = self.nn(pseudo)
73 | out = out.view(-1, self.in_channels, self.out_channels)
74 | out = torch.matmul(x[col].unsqueeze(1), out).squeeze(1)
75 | out = scatter_(self.aggr, out, row, dim_size=x.size(0))
76 |
77 | if self.root is not None:
78 | out = out + torch.mm(x, self.root)
79 |
80 | if self.bias is not None:
81 | out = out + self.bias
82 |
83 | return out
84 |
85 | def __repr__(self):
86 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
87 | self.out_channels)
88 |
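
A minimal usage sketch for `NNConv` (not part of the repository file; the toy sizes are made up). The edge network has to map each edge attribute to a vector of length `in_channels * out_channels`, which `forward` reshapes into a per-edge weight matrix:

    import torch
    import torch.nn as nn
    from torch_geometric.nn import NNConv

    # Toy graph: 4 nodes with 8 features, one 2-dimensional attribute per edge.
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    edge_attr = torch.randn(edge_index.size(1), 2)

    # Maps a 2-dimensional edge attribute to an 8 * 16 vector (the per-edge weight matrix).
    edge_net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 8 * 16))
    conv = NNConv(in_channels=8, out_channels=16, nn=edge_net, aggr='mean')
    out = conv(x, edge_index, edge_attr)  # shape [4, 16]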
--------------------------------------------------------------------------------
/docs/build/html/_static/css/badge_only.css:
--------------------------------------------------------------------------------
1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../font/fontawesome_webfont.eot");src:url("../font/fontawesome_webfont.eot?#iefix") format("embedded-opentype"),url("../font/fontawesome_webfont.woff") format("woff"),url("../font/fontawesome_webfont.ttf") format("truetype"),url("../font/fontawesome_webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:0.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;border-top:solid 10px #343131;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}
2 | /*# sourceMappingURL=badge_only.css.map */
3 |
--------------------------------------------------------------------------------
/examples/mnist_graclus.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import MNISTSuperpixels
6 | import torch_geometric.transforms as T
7 | from torch_geometric.data import DataLoader
8 | from torch_geometric.utils import normalized_cut
9 | from torch_geometric.nn import (SplineConv, graclus, max_pool, max_pool_x,
10 | global_mean_pool)
11 |
12 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
13 | train_dataset = MNISTSuperpixels(path, True, transform=T.Cartesian())
14 | test_dataset = MNISTSuperpixels(path, False, transform=T.Cartesian())
15 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
16 | test_loader = DataLoader(test_dataset, batch_size=64)
17 | d = train_dataset.data
18 |
19 |
20 | def normalized_cut_2d(edge_index, pos):
21 | row, col = edge_index
22 | edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
23 | return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
24 |
25 |
26 | class Net(torch.nn.Module):
27 | def __init__(self):
28 | super(Net, self).__init__()
29 | self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5)
30 | self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)
31 | self.fc1 = torch.nn.Linear(64, 128)
32 | self.fc2 = torch.nn.Linear(128, d.num_classes)
33 |
34 | def forward(self, data):
35 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
36 | weight = normalized_cut_2d(data.edge_index, data.pos)
37 | cluster = graclus(data.edge_index, weight, data.x.size(0))
38 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
39 |
40 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
41 | weight = normalized_cut_2d(data.edge_index, data.pos)
42 | cluster = graclus(data.edge_index, weight, data.x.size(0))
43 | x, batch = max_pool_x(cluster, data.x, data.batch)
44 |
45 | x = global_mean_pool(x, batch)
46 | x = F.elu(self.fc1(x))
47 | x = F.dropout(x, training=self.training)
48 | return F.log_softmax(self.fc2(x), dim=1)
49 |
50 |
51 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52 | model = Net().to(device)
53 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
54 |
55 |
56 | def train(epoch):
57 | model.train()
58 |
59 | if epoch == 16:
60 | for param_group in optimizer.param_groups:
61 | param_group['lr'] = 0.001
62 |
63 | if epoch == 26:
64 | for param_group in optimizer.param_groups:
65 | param_group['lr'] = 0.0001
66 |
67 | for data in train_loader:
68 | data = data.to(device)
69 | optimizer.zero_grad()
70 | F.nll_loss(model(data), data.y).backward()
71 | optimizer.step()
72 |
73 |
74 | def test():
75 | model.eval()
76 | correct = 0
77 |
78 | for data in test_loader:
79 | data = data.to(device)
80 | pred = model(data).max(1)[1]
81 | correct += pred.eq(data.y).sum().item()
82 | return correct / len(test_dataset)
83 |
84 |
85 | for epoch in range(1, 31):
86 | train(epoch)
87 | test_acc = test()
88 | print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
89 |
--------------------------------------------------------------------------------
/torch_geometric/data/in_memory_dataset.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat, product
2 |
3 | import torch
4 | from torch_geometric.data import Dataset, Data
5 |
6 |
7 | class InMemoryDataset(Dataset):
8 | @property
9 | def raw_file_names(self):
10 | raise NotImplementedError
11 |
12 | @property
13 | def processed_file_names(self):
14 | raise NotImplementedError
15 |
16 | def download(self):
17 | raise NotImplementedError
18 |
19 | def process(self):
20 | raise NotImplementedError
21 |
22 | def __init__(self,
23 | root,
24 | transform=None,
25 | pre_transform=None,
26 | pre_filter=None):
27 | super(InMemoryDataset, self).__init__(root, transform, pre_transform,
28 | pre_filter)
29 | self.data, self.slices = None, None
30 |
31 | @property
32 | def num_features(self):
33 | return self[0].num_features
34 |
35 | @property
36 | def num_classes(self):
37 | data = self.data
38 | return data.y.max().item() + 1 if data.y.dim() == 1 else data.y.size(1)
39 |
40 | def __len__(self):
41 | return self.slices[list(self.slices.keys())[0]].size(0) - 1
42 |
43 | def __getitem__(self, idx):
44 | if isinstance(idx, int):
45 | data = self.get(idx)
46 | data = data if self.transform is None else self.transform(data)
47 | return data
48 | elif isinstance(idx, slice):
49 | return self.split(range(*idx.indices(len(self))))
50 | elif isinstance(idx, torch.LongTensor):
51 | return self.split(idx)
52 | elif isinstance(idx, torch.ByteTensor):
53 | return self.split(idx.nonzero())
54 |
55 | raise IndexError(
56 | 'Only integers, slices (`:`) and long or byte tensors are valid '
57 | 'indices (got {}).'.format(type(idx).__name__))
58 |
59 | def shuffle(self):
60 | return self.split(torch.randperm(len(self)))
61 |
62 | def get(self, idx):
63 | data = Data()
64 | for key in self.data.keys:
65 | item, slices = self.data[key], self.slices[key]
66 | s = list(repeat(slice(None), item.dim()))
67 | s[self.data.cat_dim(key, item)] = slice(slices[idx],
68 | slices[idx + 1])
69 | data[key] = item[s]
70 | return data
71 |
72 | def split(self, index):
73 | copy = self.__class__.__new__(self.__class__)
74 | copy.__dict__ = self.__dict__.copy()
75 | copy.data, copy.slices = self.collate([self.get(i) for i in index])
76 | return copy
77 |
78 | def collate(self, data_list):
79 | keys = data_list[0].keys
80 | data = Data()
81 |
82 | for key in keys:
83 | data[key] = []
84 | slices = {key: [0] for key in keys}
85 |
86 | for item, key in product(data_list, keys):
87 | data[key].append(item[key])
88 | s = slices[key][-1] + item[key].size(item.cat_dim(key, item[key]))
89 | slices[key].append(s)
90 |
91 | for key in keys:
92 | data[key] = torch.cat(
93 | data[key], dim=data_list[0].cat_dim(key, data_list[0][key]))
94 | slices[key] = torch.LongTensor(slices[key])
95 |
96 | return data, slices
97 |
--------------------------------------------------------------------------------
/examples/enzymes_topk_pool.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import TUDataset
6 | from torch_geometric.data import DataLoader
7 | from torch_geometric.nn import GraphConv, TopKPooling
8 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
9 |
10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ENZYMES')
11 | dataset = TUDataset(path, name='ENZYMES')
12 | dataset = dataset.shuffle()
13 | n = len(dataset) // 10
14 | test_dataset = dataset[:n]
15 | train_dataset = dataset[n:]
16 | test_loader = DataLoader(test_dataset, batch_size=60)
17 | train_loader = DataLoader(train_dataset, batch_size=60)
18 |
19 |
20 | class Net(torch.nn.Module):
21 | def __init__(self):
22 | super(Net, self).__init__()
23 |
24 | self.conv1 = GraphConv(dataset.num_features, 128)
25 | self.pool1 = TopKPooling(128, ratio=0.8)
26 | self.conv2 = GraphConv(128, 128)
27 | self.pool2 = TopKPooling(128, ratio=0.8)
28 | self.conv3 = GraphConv(128, 128)
29 | self.pool3 = TopKPooling(128, ratio=0.8)
30 |
31 | self.lin1 = torch.nn.Linear(256, 128)
32 | self.lin2 = torch.nn.Linear(128, 64)
33 | self.lin3 = torch.nn.Linear(64, dataset.num_classes)
34 |
35 | def forward(self, data):
36 | x, edge_index, batch = data.x, data.edge_index, data.batch
37 |
38 | x = F.relu(self.conv1(x, edge_index))
39 | x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
40 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
41 |
42 | x = F.relu(self.conv2(x, edge_index))
43 | x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
44 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
45 |
46 | x = F.relu(self.conv3(x, edge_index))
47 | x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
48 | x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
49 |
50 | x = x1 + x2 + x3
51 |
52 | x = F.relu(self.lin1(x))
53 | x = F.dropout(x, p=0.5, training=self.training)
54 | x = F.relu(self.lin2(x))
55 | x = F.log_softmax(self.lin3(x), dim=-1)
56 |
57 | return x
58 |
59 |
60 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61 | model = Net().to(device)
62 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
63 |
64 |
65 | def train(epoch):
66 | model.train()
67 |
68 | loss_all = 0
69 | for data in train_loader:
70 | data = data.to(device)
71 | optimizer.zero_grad()
72 | output = model(data)
73 | loss = F.nll_loss(output, data.y)
74 | loss.backward()
75 | loss_all += data.num_graphs * loss.item()
76 | optimizer.step()
77 | return loss_all / len(train_dataset)
78 |
79 |
80 | def test(loader):
81 | model.eval()
82 |
83 | correct = 0
84 | for data in loader:
85 | data = data.to(device)
86 | pred = model(data).max(dim=1)[1]
87 | correct += pred.eq(data.y).sum().item()
88 | return correct / len(loader.dataset)
89 |
90 |
91 | for epoch in range(1, 201):
92 | loss = train(epoch)
93 | train_acc = test(train_loader)
94 | test_acc = test(test_loader)
95 | print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.5f}, Test Acc: {:.5f}'.
96 | format(epoch, loss, train_acc, test_acc))
97 |
--------------------------------------------------------------------------------
/examples/mnist_mpnn.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_geometric.datasets import MNISTSuperpixels
7 | import torch_geometric.transforms as T
8 | from torch_geometric.data import DataLoader
9 | from torch_geometric.utils import normalized_cut
10 | from torch_geometric.nn import (NNConv, graclus, max_pool, max_pool_x,
11 | global_mean_pool)
12 |
13 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
14 | train_dataset = MNISTSuperpixels(path, True, transform=T.Cartesian())
15 | test_dataset = MNISTSuperpixels(path, False, transform=T.Cartesian())
16 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
17 | test_loader = DataLoader(test_dataset, batch_size=64)
18 | d = train_dataset.data
19 |
20 |
21 | def normalized_cut_2d(edge_index, pos):
22 | row, col = edge_index
23 | edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
24 | return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
25 |
26 |
27 | class Net(nn.Module):
28 | def __init__(self):
29 | super(Net, self).__init__()
30 | n1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(), nn.Linear(25, 32))
31 | self.conv1 = NNConv(d.num_features, 32, n1)
32 |
33 | n2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(), nn.Linear(25, 2048))
34 | self.conv2 = NNConv(32, 64, n2)
35 |
36 | self.fc1 = torch.nn.Linear(64, 128)
37 | self.fc2 = torch.nn.Linear(128, d.num_classes)
38 |
39 | def forward(self, data):
40 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
41 | weight = normalized_cut_2d(data.edge_index, data.pos)
42 | cluster = graclus(data.edge_index, weight, data.x.size(0))
43 | data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
44 |
45 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
46 | weight = normalized_cut_2d(data.edge_index, data.pos)
47 | cluster = graclus(data.edge_index, weight, data.x.size(0))
48 | x, batch = max_pool_x(cluster, data.x, data.batch)
49 |
50 | x = global_mean_pool(x, batch)
51 | x = F.elu(self.fc1(x))
52 | x = F.dropout(x, training=self.training)
53 | return F.log_softmax(self.fc2(x), dim=1)
54 |
55 |
56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57 | model = Net().to(device)
58 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
59 |
60 |
61 | def train(epoch):
62 | model.train()
63 |
64 | if epoch == 16:
65 | for param_group in optimizer.param_groups:
66 | param_group['lr'] = 0.001
67 |
68 | if epoch == 26:
69 | for param_group in optimizer.param_groups:
70 | param_group['lr'] = 0.0001
71 |
72 | for data in train_loader:
73 | data = data.to(device)
74 | optimizer.zero_grad()
75 | F.nll_loss(model(data), data.y).backward()
76 | optimizer.step()
77 |
78 |
79 | def test():
80 | model.eval()
81 | correct = 0
82 |
83 | for data in test_loader:
84 | data = data.to(device)
85 | pred = model(data).max(1)[1]
86 | correct += pred.eq(data.y).sum().item()
87 | return correct / len(test_dataset)
88 |
89 |
90 | for epoch in range(1, 31):
91 | train(epoch)
92 | test_acc = test()
93 | print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
94 |
--------------------------------------------------------------------------------
/examples/faust.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch_geometric.datasets import FAUST
6 | import torch_geometric.transforms as T
7 | from torch_geometric.data import DataLoader
8 | from torch_geometric.nn import SplineConv
9 | from torch_geometric.utils import degree
10 |
11 |
12 | class MyTransform(object):
13 | def __call__(self, data):
14 | data.face, data.x = None, torch.ones(data.num_nodes, 1)
15 | return data
16 |
17 |
18 | def norm(x, edge_index):
19 | deg = degree(edge_index[0], x.size(0), x.dtype, x.device) + 1
20 | return x / deg.unsqueeze(-1)
21 |
22 |
23 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FAUST')
24 | pre_transform = T.Compose([T.FaceToEdge(), MyTransform()])
25 | train_dataset = FAUST(path, True, T.Cartesian(), pre_transform)
26 | test_dataset = FAUST(path, False, T.Cartesian(), pre_transform)
27 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
28 | test_loader = DataLoader(test_dataset, batch_size=1)
29 | d = train_dataset[0]
30 |
31 |
32 | class Net(torch.nn.Module):
33 | def __init__(self):
34 | super(Net, self).__init__()
35 | self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, norm=False)
36 | self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, norm=False)
37 | self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
38 | self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
39 | self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
40 | self.conv6 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
41 | self.fc1 = torch.nn.Linear(64, 256)
42 | self.fc2 = torch.nn.Linear(256, d.num_nodes)
43 |
44 | def forward(self, data):
45 | x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr
46 | x = F.elu(norm(self.conv1(x, edge_index, pseudo), edge_index))
47 | x = F.elu(norm(self.conv2(x, edge_index, pseudo), edge_index))
48 | x = F.elu(norm(self.conv3(x, edge_index, pseudo), edge_index))
49 | x = F.elu(norm(self.conv4(x, edge_index, pseudo), edge_index))
50 | x = F.elu(norm(self.conv5(x, edge_index, pseudo), edge_index))
51 | x = F.elu(norm(self.conv6(x, edge_index, pseudo), edge_index))
52 | x = F.elu(self.fc1(x))
53 | x = F.dropout(x, training=self.training)
54 | x = self.fc2(x)
55 | return F.log_softmax(x, dim=1)
56 |
57 |
58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59 | model = Net().to(device)
60 | target = torch.arange(d.num_nodes, dtype=torch.long, device=device)
61 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
62 |
63 |
64 | def train(epoch):
65 | model.train()
66 |
67 | if epoch == 61:
68 | for param_group in optimizer.param_groups:
69 | param_group['lr'] = 0.001
70 |
71 | for data in train_loader:
72 | optimizer.zero_grad()
73 | F.nll_loss(model(data.to(device)), target).backward()
74 | optimizer.step()
75 |
76 |
77 | def test():
78 | model.eval()
79 | correct = 0
80 |
81 | for data in test_loader:
82 | pred = model(data.to(device)).max(1)[1]
83 | correct += pred.eq(target).sum().item()
84 | return correct / (len(test_dataset) * d.num_nodes)
85 |
86 |
87 | for epoch in range(1, 101):
88 | train(epoch)
89 | test_acc = test()
90 | print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
91 |
--------------------------------------------------------------------------------
/torch_geometric/read/planetoid.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 | from itertools import repeat
4 |
5 | import torch
6 | from torch_sparse import coalesce
7 | from torch_geometric.data import Data
8 | from torch_geometric.read import read_txt_array
9 | from torch_geometric.utils import remove_self_loops
10 |
11 | try:
12 | import cPickle as pickle
13 | except ImportError:
14 | import pickle
15 |
16 |
17 | def read_planetoid_data(folder, prefix):
18 | """Reads the planetoid data format.
19 | ind.{}.x
20 | """
21 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
22 | items = [read_file(folder, prefix, name) for name in names]
23 | x, tx, allx, y, ty, ally, graph, test_index = items
24 | train_index = torch.arange(y.size(0), dtype=torch.long)
25 | val_index = torch.arange(y.size(0), y.size(0) + 500, dtype=torch.long)
26 | sorted_test_index = test_index.sort()[0]
27 |
28 | if prefix.lower() == 'citeseer':
29 | # There are some isolated nodes in the Citeseer graph, resulting in
30 | # non-consecutive test indices. We need to identify them and add them
31 | # as zero vectors to `tx` and `ty`.
32 | len_test_indices = (test_index.max() - test_index.min()).item() + 1
33 |
34 | tx_ext = torch.zeros(len_test_indices, tx.size(1))
35 | tx_ext[sorted_test_index - test_index.min(), :] = tx
36 | ty_ext = torch.zeros(len_test_indices, ty.size(1))
37 | ty_ext[sorted_test_index - test_index.min(), :] = ty
38 |
39 | tx, ty = tx_ext, ty_ext
40 |
41 | x = torch.cat([allx, tx], dim=0)
42 | y = torch.cat([ally, ty], dim=0).max(dim=1)[1]
43 |
44 | x[test_index] = x[sorted_test_index]
45 | y[test_index] = y[sorted_test_index]
46 |
47 | train_mask = sample_mask(train_index, num_nodes=y.size(0))
48 | val_mask = sample_mask(val_index, num_nodes=y.size(0))
49 | test_mask = sample_mask(test_index, num_nodes=y.size(0))
50 |
51 | edge_index = edge_index_from_dict(graph, num_nodes=y.size(0))
52 |
53 | data = Data(x=x, edge_index=edge_index, y=y)
54 | data.train_mask = train_mask
55 | data.val_mask = val_mask
56 | data.test_mask = test_mask
57 |
58 | return data
59 |
60 |
61 | def read_file(folder, prefix, name):
62 | path = osp.join(folder, 'ind.{}.{}'.format(prefix.lower(), name))
63 |
64 | if name == 'test.index':
65 | return read_txt_array(path, dtype=torch.long)
66 |
67 | with open(path, 'rb') as f:
68 | if sys.version_info > (3, 0):
69 | out = pickle.load(f, encoding='latin1')
70 | else:
71 | out = pickle.load(f)
72 |
73 | if name == 'graph':
74 | return out
75 |
76 | out = out.todense() if hasattr(out, 'todense') else out
77 | out = torch.Tensor(out)
78 | return out
79 |
80 |
81 | def edge_index_from_dict(graph_dict, num_nodes=None):
82 | row, col = [], []
83 | for key, value in graph_dict.items():
84 | row += repeat(key, len(value))
85 | col += value
86 | edge_index = torch.stack([torch.tensor(row), torch.tensor(col)], dim=0)
87 | # NOTE: There are duplicated edges and self loops in the datasets. Other
88 | # implementations do not remove them!
89 | edge_index, _ = remove_self_loops(edge_index)
90 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
91 | return edge_index
92 |
93 |
94 | def sample_mask(index, num_nodes):
95 | mask = torch.zeros((num_nodes, ), dtype=torch.uint8)
96 | mask[index] = 1
97 | return mask
98 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/topk_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_scatter import scatter_add
4 |
5 | from ..inits import uniform
6 | from ...utils.num_nodes import maybe_num_nodes
7 |
8 |
9 | def topk(x, ratio, batch):
10 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
11 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
12 | k = (ratio * num_nodes.to(torch.float)).round().to(torch.long)
13 |
14 | cum_num_nodes = torch.cat(
15 | [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
16 |
17 | index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
18 | index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
19 |
20 | x_min_value = x.min().item()
21 | dense_x = x.new_full((batch_size * max_num_nodes, ), x_min_value - 1)
22 | dense_x[index] = x
23 | dense_x = dense_x.view(batch_size, max_num_nodes)
24 | _, perm = dense_x.sort(dim=-1, descending=True)
25 |
26 | perm = perm + cum_num_nodes.view(-1, 1)
27 | perm = perm.view(-1)
28 |
29 | mask = [
30 | torch.arange(k[i], dtype=torch.long, device=x.device) + i *
31 | max_num_nodes for i in range(len(num_nodes))
32 | ]
33 | mask = torch.cat(mask, dim=0)
34 |
35 | perm = perm[mask]
36 |
37 | return perm
38 |
39 |
40 | def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
41 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
42 |
43 | mask = perm.new_full((num_nodes, ), -1)
44 | i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
45 | mask[perm] = i
46 |
47 | row, col = edge_index
48 | row, col = mask[row], mask[col]
49 | mask = (row >= 0) & (col >= 0)
50 | row, col = row[mask], col[mask]
51 |
52 | if edge_attr is not None:
53 | edge_attr = edge_attr[mask]
54 |
55 | return torch.stack([row, col], dim=0), edge_attr
56 |
57 |
58 | class TopKPooling(torch.nn.Module):
59 | r""":math:`\mathrm{top}_k` pooling from the `"Graph U-Net"
60 | `_ paper.
61 |
62 | Args:
63 | in_channels (int): Size of each input sample.
64 | ratio (float): Graph pooling ratio. (default: :obj:`0.5`)
65 | """
66 |
67 | def __init__(self, in_channels, ratio=0.5):
68 | super(TopKPooling, self).__init__()
69 |
70 | self.in_channels = in_channels
71 | self.ratio = ratio
72 |
73 | self.weight = Parameter(torch.Tensor(1, in_channels))
74 |
75 | self.reset_parameters()
76 |
77 | def reset_parameters(self):
78 | size = self.in_channels
79 | uniform(size, self.weight)
80 |
81 | def forward(self, x, edge_index, edge_attr=None, batch=None):
82 | """"""
83 | if batch is None:
84 | batch = edge_index.new_zeros(x.size(0))
85 |
86 | x = x.unsqueeze(-1) if x.dim() == 1 else x
87 |
88 | score = (x * self.weight).sum(dim=-1)
89 | score = score / self.weight.norm(p=2, dim=-1)
90 | perm = topk(score, self.ratio, batch)
91 | x = x[perm] * torch.tanh(score[perm]).view(-1, 1)
92 | batch = batch[perm]
93 | edge_index, edge_attr = filter_adj(
94 | edge_index, edge_attr, perm, num_nodes=score.size(0))
95 |
96 | return x, edge_index, edge_attr, batch, perm
97 |
98 | def __repr__(self):
99 | return '{}({})'.format(self.__class__.__name__, self.ratio)
100 |
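
A minimal usage sketch for `TopKPooling` on a small two-graph batch (not part of the repository file; it mirrors the call pattern used in `examples/enzymes_topk_pool.py`):

    import torch
    from torch_geometric.nn import TopKPooling

    x = torch.randn(6, 32)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 0, 4, 5, 3]], dtype=torch.long)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs with three nodes each

    pool = TopKPooling(in_channels=32, ratio=0.5)
    x, edge_index, edge_attr, batch, perm = pool(x, edge_index, None, batch)
    # Roughly ratio * num_nodes nodes survive per graph; `perm` holds their original
    # indices, and edge_index is filtered to edges between surviving nodes.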
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/cheb_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_sparse import spmm
4 | from torch_geometric.utils import degree, remove_self_loops
5 |
6 | from ..inits import uniform
7 |
8 |
9 | class ChebConv(torch.nn.Module):
10 | r"""The chebyshev spectral graph convolutional operator from the
11 | `"Convolutional Neural Networks on Graphs with Fast Localized Spectral
12 | Filtering" `_ paper
13 |
14 | .. math::
15 | \mathbf{X}^{\prime} = \sum_{k=0}^{K-1} \mathbf{\hat{X}}_k \cdot
16 | \mathbf{\Theta}_k
17 |
18 | where :math:`\mathbf{\hat{X}}_k` is computed recursively by
19 |
20 | .. math::
21 | \mathbf{\hat{X}}_0 &= \mathbf{X}
22 |
23 | \mathbf{\hat{X}}_1 &= \mathbf{\hat{L}} \cdot \mathbf{X}
24 |
25 | \mathbf{\hat{X}}_k &= 2 \cdot \mathbf{\hat{L}} \cdot
26 | \mathbf{\hat{X}}_{k-1} - \mathbf{\hat{X}}_{k-2}
27 |
28 | and :math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian.
29 |
30 | Args:
31 | in_channels (int): Size of each input sample.
32 | out_channels (int): Size of each output sample.
33 | K (int): Chebyshev filter size, *i.e.* number of hops.
34 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
35 | an additive bias. (default: :obj:`True`)
36 | """
37 |
38 | def __init__(self, in_channels, out_channels, K, bias=True):
39 | super(ChebConv, self).__init__()
40 |
41 | self.in_channels = in_channels
42 | self.out_channels = out_channels
43 | self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
44 |
45 | if bias:
46 | self.bias = Parameter(torch.Tensor(out_channels))
47 | else:
48 | self.register_parameter('bias', None)
49 |
50 | self.reset_parameters()
51 |
52 | def reset_parameters(self):
53 | size = self.in_channels * self.weight.size(0)
54 | uniform(size, self.weight)
55 | uniform(size, self.bias)
56 |
57 | def forward(self, x, edge_index, edge_attr=None):
58 | """"""
59 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
60 |
61 | row, col = edge_index
62 | num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)
63 |
64 | if edge_attr is None:
65 | edge_attr = x.new_ones((num_edges, ))
66 | assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)
67 |
68 | deg = degree(row, num_nodes, dtype=x.dtype)
69 |
70 | # Compute normalized and rescaled Laplacian.
71 | deg = deg.pow(-0.5)
72 | deg[deg == float('inf')] = 0
73 | lap = -deg[row] * edge_attr * deg[col]
74 |
75 | # Perform filter operation recurrently.
76 | Tx_0 = x
77 | out = torch.mm(Tx_0, self.weight[0])
78 |
79 | if K > 1:
80 | Tx_1 = spmm(edge_index, lap, num_nodes, x)
81 | out = out + torch.mm(Tx_1, self.weight[1])
82 |
83 | for k in range(2, K):
84 | Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
85 | out = out + torch.mm(Tx_2, self.weight[k])
86 | Tx_0, Tx_1 = Tx_1, Tx_2
87 |
88 | if self.bias is not None:
89 | out = out + self.bias
90 |
91 | return out
92 |
93 | def __repr__(self):
94 | return '{}({}, {}, K={})'.format(self.__class__.__name__,
95 | self.in_channels, self.out_channels,
96 | self.weight.size(0) - 1)
97 |
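
A minimal usage sketch for `ChebConv` (not part of the repository file; it assumes the class is re-exported from `torch_geometric.nn`, otherwise import it from `torch_geometric.nn.conv.cheb_conv`). With `edge_attr=None`, the edge weights default to ones:

    import torch
    from torch_geometric.nn import ChebConv

    x = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                               [1, 0, 2, 1, 4, 3]], dtype=torch.long)

    conv = ChebConv(in_channels=3, out_channels=8, K=3)  # K = number of Chebyshev terms
    out = conv(x, edge_index)  # unweighted edges; result has shape [5, 8]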
--------------------------------------------------------------------------------
/torch_geometric/data/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import (contains_isolated_nodes,
3 | contains_self_loops, is_undirected)
4 |
5 | from ..utils.num_nodes import maybe_num_nodes
6 |
7 |
8 | class Data(object):
9 | def __init__(self,
10 | x=None,
11 | edge_index=None,
12 | edge_attr=None,
13 | y=None,
14 | pos=None):
15 | self.x = x
16 | self.edge_index = edge_index
17 | self.edge_attr = edge_attr
18 | self.y = y
19 | self.pos = pos
20 |
21 | @staticmethod
22 | def from_dict(dictionary):
23 | data = Data()
24 | for key, item in dictionary.items():
25 | data[key] = item
26 | return data
27 |
28 | def __getitem__(self, key):
29 | return getattr(self, key)
30 |
31 | def __setitem__(self, key, item):
32 | setattr(self, key, item)
33 |
34 | @property
35 | def keys(self):
36 | return [key for key in self.__dict__.keys() if self[key] is not None]
37 |
38 | def __len__(self):
39 | return len(self.keys)
40 |
41 | def __contains__(self, key):
42 | return key in self.keys
43 |
44 | def __iter__(self):
45 | for key in sorted(self.keys):
46 | yield key, self[key]
47 |
48 | def __call__(self, *keys):
49 | for key in sorted(self.keys) if not keys else keys:
50 | if self[key] is not None:
51 | yield key, self[key]
52 |
53 | def cat_dim(self, key, item):
54 | return -1 if item.dtype == torch.long else 0
55 |
56 | @property
57 | def num_nodes(self):
58 | for key, item in self('x', 'pos'):
59 | return item.size(self.cat_dim(key, item))
60 | if self.edge_index is not None:
61 | return maybe_num_nodes(self.edge_index)
62 | return None
63 |
64 | @property
65 | def num_edges(self):
66 | for key, item in self('edge_index', 'edge_attr'):
67 | return item.size(self.cat_dim(key, item))
68 | return None
69 |
70 | @property
71 | def num_features(self):
72 | return 1 if self.x.dim() == 1 else self.x.size(1)
73 |
74 | @property
75 | def num_classes(self):
76 | return self.y.max().item() + 1 if self.y.dim() == 1 else self.y.size(1)
77 |
78 | def is_coalesced(self):
79 | row, col = self.edge_index
80 | index = self.num_nodes * row + col
81 | return row.size(0) == torch.unique(index).size(0)
82 |
83 | def contains_isolated_nodes(self):
84 | return contains_isolated_nodes(self.edge_index, self.num_nodes)
85 |
86 | def contains_self_loops(self):
87 | return contains_self_loops(self.edge_index)
88 |
89 | def is_undirected(self):
90 | return is_undirected(self.edge_index, self.num_nodes)
91 |
92 | def is_directed(self):
93 | return not self.is_undirected()
94 |
95 | def apply(self, func, *keys):
96 | for key, item in self(*keys):
97 | self[key] = func(item)
98 | return self
99 |
100 | def contiguous(self, *keys):
101 | return self.apply(lambda x: x.contiguous(), *keys)
102 |
103 | def to(self, device, *keys):
104 | return self.apply(lambda x: x.to(device), *keys)
105 |
106 | def __repr__(self):
107 | info = ['{}={}'.format(key, list(item.size())) for key, item in self]
108 | return '{}({})'.format(self.__class__.__name__, ', '.join(info))
109 |
--------------------------------------------------------------------------------