├── docs ├── .nojekyll ├── build │ └── html │ │ ├── .nojekyll │ │ ├── _static │ │ ├── pygments.css │ │ ├── down.png │ │ ├── file.png │ │ ├── plus.png │ │ ├── up.png │ │ ├── minus.png │ │ ├── comment.png │ │ ├── ajax-loader.gif │ │ ├── up-pressed.png │ │ ├── comment-bright.png │ │ ├── comment-close.png │ │ ├── down-pressed.png │ │ ├── fonts │ │ │ ├── Lato-Bold.ttf │ │ │ ├── Lato-Regular.ttf │ │ │ ├── RobotoSlab-Bold.ttf │ │ │ ├── Inconsolata-Bold.ttf │ │ │ ├── RobotoSlab-Regular.ttf │ │ │ ├── Inconsolata-Regular.ttf │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.ttf │ │ │ └── fontawesome-webfont.woff │ │ └── css │ │ │ ├── custom.css │ │ │ └── badge_only.css │ │ └── _sources │ │ ├── modules │ │ ├── data.rst.txt │ │ ├── utils.rst.txt │ │ ├── datasets.rst.txt │ │ ├── transforms.rst.txt │ │ └── nn.rst.txt │ │ ├── index.rst.txt │ │ └── notes │ │ └── installation.rst.txt ├── source │ ├── _figures │ │ ├── .gitignore │ │ ├── build.sh │ │ └── graph.tex │ ├── modules │ │ ├── data.rst │ │ ├── utils.rst │ │ ├── datasets.rst │ │ ├── transforms.rst │ │ └── nn.rst │ ├── _static │ │ └── css │ │ │ └── custom.css │ ├── conf.py │ ├── index.rst │ └── notes │ │ └── installation.rst ├── requirements.txt ├── .gitignore ├── index.html └── Makefile ├── test ├── nn │ ├── conv │ │ └── test_spline_conv.py │ └── pool │ │ ├── test_consecutive.py │ │ ├── test_voxel_grid.py │ │ ├── test_graclus.py │ │ └── test_topk_pool.py ├── utils │ ├── test_degree.py │ ├── test_normalized_cut.py │ ├── test_to_batch.py │ ├── test_isolated.py │ ├── test_one_hot.py │ ├── test_grid.py │ ├── test_undirected.py │ ├── test_softmax.py │ ├── test_convert.py │ └── test_loop.py ├── transforms │ ├── test_random_flip.py │ ├── test_random_rotate.py │ ├── test_face_to_edge.py │ ├── test_radius_graph.py │ ├── test_linear_transformation.py │ ├── test_nn_graph.py │ ├── test_target_indegree.py │ ├── test_distance.py │ ├── test_sample_points.py │ ├── test_compose.py │ ├── test_spherical.py │ ├── test_cartesian.py 
│ ├── test_two_hop.py │ ├── test_polar.py │ └── test_local_cartesian.py ├── data │ ├── test_collate.py │ ├── test_batch.py │ ├── test_split.py │ └── test_data.py └── datasets │ ├── test_planetoid.py │ └── test_tu_dataset.py ├── MANIFEST.in ├── .coveragerc ├── torch_geometric ├── __init__.py ├── utils │ ├── num_nodes.py │ ├── sparse.py │ ├── normalized_cut.py │ ├── isolated.py │ ├── softmax.py │ ├── one_hot.py │ ├── to_batch.py │ ├── undirected.py │ ├── loop.py │ ├── degree.py │ ├── __init__.py │ ├── grid.py │ ├── convert.py │ ├── scatter.py │ └── metric.py ├── nn │ ├── prop │ │ ├── __init__.py │ │ ├── gcn_prop.py │ │ └── agnn_prop.py │ ├── dense │ │ ├── __init__.py │ │ ├── diff_pool.py │ │ └── sage_conv.py │ ├── repeat.py │ ├── pool │ │ ├── graclus.py │ │ ├── consecutive.py │ │ ├── global_pool.py │ │ ├── __init__.py │ │ ├── pool.py │ │ ├── sort_pool.py │ │ ├── voxel_grid.py │ │ ├── max_pool.py │ │ ├── avg_pool.py │ │ ├── set2set.py │ │ └── topk_pool.py │ ├── __init__.py │ ├── conv │ │ ├── __init__.py │ │ ├── edge_conv.py │ │ ├── gin_conv.py │ │ ├── gcn_conv.py │ │ ├── graph_conv.py │ │ ├── sage_conv.py │ │ ├── nn_conv.py │ │ └── cheb_conv.py │ └── inits.py ├── transforms │ ├── normalize_features.py │ ├── center.py │ ├── normalize_scale.py │ ├── face_to_edge.py │ ├── add_self_loops.py │ ├── radius_graph.py │ ├── nn_graph.py │ ├── random_scale.py │ ├── sample_points.py │ ├── random_shear.py │ ├── random_rotate.py │ ├── compose.py │ ├── random_translate.py │ ├── two_hop.py │ ├── random_flip.py │ ├── constant.py │ ├── distance.py │ ├── __init__.py │ ├── one_hot_degree.py │ ├── target_indegree.py │ ├── linear_transformation.py │ ├── to_dense.py │ ├── cartesian.py │ ├── polar.py │ ├── spherical.py │ └── local_cartesian.py ├── data │ ├── makedirs.py │ ├── download.py │ ├── extract.py │ ├── __init__.py │ ├── dataloader.py │ ├── batch.py │ ├── dataset.py │ ├── in_memory_dataset.py │ └── data.py ├── read │ ├── __init__.py │ ├── txt_array.py │ ├── ply.py │ ├── off.py │ ├── 
sdf.py │ └── planetoid.py └── datasets │ ├── __init__.py │ ├── karate.py │ ├── planetoid.py │ ├── qm9.py │ ├── qm7.py │ ├── faust.py │ ├── tu_dataset.py │ ├── mnist_superpixels.py │ ├── ppi.py │ ├── coma.py │ └── modelnet.py ├── setup.cfg ├── .gitignore ├── setup.py ├── LICENSE ├── .travis.yml └── examples ├── gat.py ├── agnn.py ├── cora.py ├── gcn.py ├── mnist_voxel_grid.py ├── mnist_graclus.py ├── enzymes_topk_pool.py ├── mnist_mpnn.py └── faust.py /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/build/html/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/nn/conv/test_spline_conv.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/_figures/.gitignore: -------------------------------------------------------------------------------- 1 | *.aux 2 | *.log 3 | *.pdf 4 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==1.6.6 2 | sphinx_rtd_theme==0.2.4 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | 3 | recursive-include examples * 4 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover 4 | raise 5 | -------------------------------------------------------------------------------- 
/torch_geometric/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.3.1' 2 | 3 | __all__ = ['__version__'] 4 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/doctest 2 | build/doctrees 3 | build/html/.buildinfo 4 | build/html/objects.inv 5 | -------------------------------------------------------------------------------- /docs/build/html/_static/pygments.css: -------------------------------------------------------------------------------- 1 | .highlight .hll { background-color: #ffffcc } 2 | .highlight { background: #ffffff; } -------------------------------------------------------------------------------- /docs/build/html/_static/down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/down.png -------------------------------------------------------------------------------- /docs/build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/file.png -------------------------------------------------------------------------------- /docs/build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/build/html/_static/up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/up.png 
-------------------------------------------------------------------------------- /docs/build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/build/html/_static/comment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment.png -------------------------------------------------------------------------------- /docs/build/html/_static/ajax-loader.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/ajax-loader.gif -------------------------------------------------------------------------------- /docs/build/html/_static/up-pressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/up-pressed.png -------------------------------------------------------------------------------- /docs/build/html/_static/comment-bright.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment-bright.png -------------------------------------------------------------------------------- /docs/build/html/_static/comment-close.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/comment-close.png -------------------------------------------------------------------------------- /docs/build/html/_static/down-pressed.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/down-pressed.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | addopts = --capture=no --cov 9 | -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Inconsolata-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab-Regular.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs/source/modules/data.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.data 2 | ==================== 3 | 4 | .. automodule:: torch_geometric.data 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /torch_geometric/utils/num_nodes.py: -------------------------------------------------------------------------------- 1 | def maybe_num_nodes(edge_index, num_nodes=None): 2 | return edge_index.max().item() + 1 if num_nodes is None else num_nodes 3 | -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Inconsolata-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- 
/docs/build/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/pytorch_geometric/HEAD/docs/build/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.utils 2 | ===================== 3 | 4 | .. automodule:: torch_geometric.utils 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /torch_geometric/nn/prop/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcn_prop import GCNProp 2 | from .agnn_prop import AGNNProp 3 | 4 | __all__ = [ 5 | 'GCNProp', 6 | 'AGNNProp', 7 | ] 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/data.rst.txt: -------------------------------------------------------------------------------- 1 | torch_geometric.data 2 | ==================== 3 | 4 | .. automodule:: torch_geometric.data 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/modules/datasets.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.datasets 2 | ======================== 3 | 4 | .. automodule:: torch_geometric.datasets 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/utils.rst.txt: -------------------------------------------------------------------------------- 1 | torch_geometric.utils 2 | ===================== 3 | 4 | .. 
automodule:: torch_geometric.utils 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Redirect 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/source/modules/transforms.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.transforms 2 | ========================== 3 | 4 | .. automodule:: torch_geometric.transforms 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | build/ 4 | dist/ 5 | alpha/ 6 | .cache/ 7 | .eggs/ 8 | *.egg-info/ 9 | .coverage 10 | 11 | !docs/build/ 12 | !torch_geometric/data/ 13 | !test/data/ 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/datasets.rst.txt: -------------------------------------------------------------------------------- 1 | torch_geometric.datasets 2 | ======================== 3 | 4 | .. 
automodule:: torch_geometric.datasets 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /torch_geometric/nn/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .sage_conv import DenseSAGEConv 2 | from .diff_pool import dense_diff_pool 3 | 4 | __all__ = [ 5 | 'DenseSAGEConv', 6 | 'dense_diff_pool', 7 | ] 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/transforms.rst.txt: -------------------------------------------------------------------------------- 1 | torch_geometric.transforms 2 | ========================== 3 | 4 | .. automodule:: torch_geometric.transforms 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/source/_figures/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for filename in *.tex; do 4 | basename=$(basename $filename .tex) 5 | pdflatex "$basename.tex" 6 | pdf2svg "$basename.pdf" "$basename.svg" 7 | done 8 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Use white for logo background. */ 2 | .wy-side-nav-search { 3 | background-color: #fff; 4 | } 5 | 6 | .wy-side-nav-search > div.version { 7 | color: #000; 8 | } 9 | -------------------------------------------------------------------------------- /docs/build/html/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Use white for logo background. 
*/ 2 | .wy-side-nav-search { 3 | background-color: #fff; 4 | } 5 | 6 | .wy-side-nav-search > div.version { 7 | color: #000; 8 | } 9 | -------------------------------------------------------------------------------- /test/utils/test_degree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree 3 | 4 | 5 | def test_degree(): 6 | row = torch.tensor([0, 1, 0, 2, 0]) 7 | assert degree(row).tolist() == [3, 1, 1] 8 | -------------------------------------------------------------------------------- /torch_geometric/nn/repeat.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import itertools 3 | 4 | 5 | def repeat(src, length): 6 | if isinstance(src, numbers.Number): 7 | src = list(itertools.repeat(src, length)) 8 | return src 9 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/graclus.py: -------------------------------------------------------------------------------- 1 | from torch_cluster import graclus_cluster 2 | 3 | 4 | def graclus(edge_index, weight=None, num_nodes=None): 5 | row, col = edge_index 6 | return graclus_cluster(row, col, weight, num_nodes) 7 | -------------------------------------------------------------------------------- /torch_geometric/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import * # noqa 2 | from .prop import * # noqa 3 | from .pool import * # noqa 4 | from .dense import * # noqa 5 | from .meta import MetaLayer 6 | 7 | __all__ = [ 8 | 'MetaLayer', 9 | ] 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD = sphinx-build 2 | SPHINXPROJ = pytorch_geometric 3 | SOURCEDIR = source 4 | BUILDDIR = build 5 | 6 | .PHONY: help Makefile 7 | 
8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" 10 | -------------------------------------------------------------------------------- /torch_geometric/transforms/normalize_features.py: -------------------------------------------------------------------------------- 1 | class NormalizeFeatures(object): 2 | def __call__(self, data): 3 | data.x = data.x / data.x.sum(1, keepdim=True).clamp(min=1) 4 | return data 5 | 6 | def __repr__(self): 7 | return '{}()'.format(self.__class__.__name__) 8 | -------------------------------------------------------------------------------- /torch_geometric/data/makedirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import errno 4 | 5 | 6 | def makedirs(path): 7 | try: 8 | os.makedirs(osp.expanduser(osp.normpath(path))) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST and osp.isdir(path): 11 | raise e 12 | -------------------------------------------------------------------------------- /torch_geometric/utils/sparse.py: -------------------------------------------------------------------------------- 1 | from torch_sparse import coalesce 2 | 3 | 4 | def dense_to_sparse(tensor): 5 | index = tensor.nonzero() 6 | value = tensor[index] 7 | index = index.t().contiguous() 8 | index, value = coalesce(index, value, tensor.size(0), tensor.size(1)) 9 | return index, value 10 | -------------------------------------------------------------------------------- /torch_geometric/utils/normalized_cut.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.utils import degree 2 | 3 | 4 | def normalized_cut(edge_index, edge_attr, num_nodes=None): 5 | row, col = edge_index 6 | deg = 1 / degree(row, num_nodes, edge_attr.dtype) 7 | deg = deg[row] + deg[col] 8 | cut = edge_attr * deg 9 | return cut 10 | -------------------------------------------------------------------------------- 
/torch_geometric/nn/pool/consecutive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def consecutive_cluster(src): 5 | unique, inv = torch.unique(src, sorted=True, return_inverse=True) 6 | perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device) 7 | perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm) 8 | return inv, perm 9 | -------------------------------------------------------------------------------- /torch_geometric/transforms/center.py: -------------------------------------------------------------------------------- 1 | class Center(object): 2 | def __call__(self, data): 3 | pos = data.pos 4 | 5 | mean = data.pos.mean(dim=0).view(1, -1) 6 | data.pos = pos - mean 7 | 8 | return data 9 | 10 | def __repr__(self): 11 | return '{}()'.format(self.__class__.__name__) 12 | -------------------------------------------------------------------------------- /test/nn/pool/test_consecutive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 3 | 4 | 5 | def test_consecutive_cluster(): 6 | src = torch.tensor([8, 2, 10, 15, 100, 1, 100]) 7 | 8 | out, perm = consecutive_cluster(src) 9 | assert out.tolist() == [2, 1, 3, 4, 5, 0, 5] 10 | assert perm.tolist() == [5, 1, 0, 2, 3, 6] 11 | -------------------------------------------------------------------------------- /torch_geometric/utils/isolated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .num_nodes import maybe_num_nodes 4 | from .loop import remove_self_loops 5 | 6 | 7 | def contains_isolated_nodes(edge_index, num_nodes=None): 8 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 9 | (row, _), _ = remove_self_loops(edge_index) 10 | return torch.unique(row).size(0) < num_nodes 11 | 
-------------------------------------------------------------------------------- /docs/source/_figures/graph.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | \usepackage{tikz} 4 | 5 | \begin{document} 6 | 7 | \begin{tikzpicture} 8 | \node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0}; 9 | \node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1}; 10 | \node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2}; 11 | 12 | \path[draw] (0) -- (1); 13 | \path[draw] (1) -- (2); 14 | \end{tikzpicture} 15 | 16 | \end{document} 17 | -------------------------------------------------------------------------------- /torch_geometric/utils/softmax.py: -------------------------------------------------------------------------------- 1 | from torch_scatter import scatter_max, scatter_add 2 | 3 | from .num_nodes import maybe_num_nodes 4 | 5 | 6 | def softmax(src, index, num_nodes=None): 7 | num_nodes = maybe_num_nodes(index, num_nodes) 8 | 9 | out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] 10 | out = out.exp() 11 | out = out / scatter_add(out, index, dim=0, dim_size=num_nodes)[index] 12 | 13 | return out 14 | -------------------------------------------------------------------------------- /test/transforms/test_random_flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import RandomFlip 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_cartesian(): 7 | assert RandomFlip(axis=0).__repr__() == 'RandomFlip(axis=0, p=0.5)' 8 | 9 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float) 10 | data = Data(pos=pos) 11 | 12 | out = RandomFlip(axis=0, p=1)(data).pos.tolist() 13 | assert out == [[1, 1], [3, 0], [-2, -1]] 14 | -------------------------------------------------------------------------------- /test/transforms/test_random_rotate.py: 
-------------------------------------------------------------------------------- 1 | from pytest import approx 2 | import torch 3 | from torch_geometric.transforms import RandomRotate 4 | from torch_geometric.data import Data 5 | 6 | 7 | def test_spherical(): 8 | assert RandomRotate(-180).__repr__() == 'RandomRotate((-180, 180))' 9 | 10 | pos = torch.tensor([[1, 0], [0, 1]], dtype=torch.float) 11 | data = Data(pos=pos) 12 | 13 | out = RandomRotate((90, 90))(data).pos.view(-1).tolist() 14 | assert approx(out) == [0, 1, -1, 0] 15 | -------------------------------------------------------------------------------- /torch_geometric/data/download.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from six.moves import urllib 3 | 4 | from .makedirs import makedirs 5 | 6 | 7 | def download_url(url, folder, log=True): 8 | if log: 9 | print('Downloading', url) 10 | 11 | makedirs(folder) 12 | 13 | data = urllib.request.urlopen(url) 14 | filename = url.rpartition('/')[2] 15 | path = osp.join(folder, filename) 16 | 17 | with open(path, 'wb') as f: 18 | f.write(data.read()) 19 | 20 | return path 21 | -------------------------------------------------------------------------------- /torch_geometric/transforms/normalize_scale.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.transforms import Center 2 | 3 | 4 | class NormalizeScale(object): 5 | def __init__(self): 6 | self.center = Center() 7 | 8 | def __call__(self, data): 9 | data = self.center(data) 10 | 11 | scale = (1 / data.pos.abs().max()) * 0.999999 12 | data.pos = data.pos * scale 13 | 14 | return data 15 | 16 | def __repr__(self): 17 | return '{}()'.format(self.__class__.__name__) 18 | -------------------------------------------------------------------------------- /torch_geometric/data/extract.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | 
import zipfile 3 | 4 | 5 | def maybe_log(path, log=True): 6 | if log: 7 | print('Extracting', path) 8 | 9 | 10 | def extract_tar(path, folder, mode='r:gz', log=True): 11 | maybe_log(path, log) 12 | with tarfile.open(path, mode) as f: 13 | f.extractall(folder) 14 | 15 | 16 | def extract_zip(path, folder, log=True): 17 | maybe_log(path, log) 18 | with zipfile.ZipFile(path, 'r') as f: 19 | f.extractall(folder) 20 | -------------------------------------------------------------------------------- /test/utils/test_normalized_cut.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import normalized_cut 3 | 4 | 5 | def test_normalized_cut(): 6 | row = torch.LongTensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4]) 7 | col = torch.LongTensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3]) 8 | edge_attr = torch.Tensor([3, 3, 6, 3, 6, 1, 3, 2, 1, 2]) 9 | expected_output = [4, 4, 5, 2.5, 5, 1, 2.5, 2, 1, 2] 10 | 11 | output = normalized_cut(torch.stack([row, col], dim=0), edge_attr) 12 | assert output.tolist() == expected_output 13 | -------------------------------------------------------------------------------- /torch_geometric/read/__init__.py: -------------------------------------------------------------------------------- 1 | from .txt_array import parse_txt_array, read_txt_array 2 | from .tu import read_tu_data 3 | from .planetoid import read_planetoid_data 4 | from .ply import read_ply 5 | from .sdf import read_sdf, parse_sdf 6 | from .off import read_off, parse_off 7 | 8 | __all__ = [ 9 | 'parse_txt_array', 10 | 'read_txt_array', 11 | 'read_tu_data', 12 | 'read_planetoid_data', 13 | 'read_ply', 14 | 'read_sdf', 15 | 'parse_sdf', 16 | 'read_off', 17 | 'parse_off', 18 | ] 19 | -------------------------------------------------------------------------------- /test/transforms/test_face_to_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch_geometric.transforms import FaceToEdge 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_face_to_edge(): 7 | assert FaceToEdge().__repr__() == 'FaceToEdge()' 8 | 9 | face = torch.tensor([[0, 0], [1, 1], [2, 3]]) 10 | data = Data() 11 | data.face = face 12 | 13 | row, col = FaceToEdge()(data).edge_index 14 | assert row.tolist() == [0, 0, 0, 1, 1, 1, 2, 2, 3, 3] 15 | assert col.tolist() == [1, 2, 3, 0, 2, 3, 0, 1, 0, 1] 16 | -------------------------------------------------------------------------------- /test/utils/test_to_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_batch 3 | 4 | 5 | def test_to_batch(): 6 | x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) 7 | batch = torch.tensor([0, 0, 1, 2, 2, 2]) 8 | 9 | x, num_nodes = to_batch(x, batch) 10 | expected = [ 11 | [[1, 2], [3, 4], [0, 0]], 12 | [[5, 6], [0, 0], [0, 0]], 13 | [[7, 8], [9, 10], [11, 12]], 14 | ] 15 | assert x.tolist() == expected 16 | assert num_nodes.tolist() == [2, 1, 3] 17 | -------------------------------------------------------------------------------- /torch_geometric/read/txt_array.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def parse_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None): 5 | src = [[float(x) for x in line.split(sep)[start:end]] for line in src] 6 | src = torch.tensor(src, dtype=dtype).squeeze() 7 | return src 8 | 9 | 10 | def read_txt_array(path, sep=None, start=0, end=None, dtype=None, device=None): 11 | with open(path, 'r') as f: 12 | src = f.read().split('\n')[:-1] 13 | return parse_txt_array(src, sep, start, end, dtype, device) 14 | -------------------------------------------------------------------------------- /test/transforms/test_radius_graph.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | from torch_geometric.transforms import RadiusGraph 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_radius_graph(): 7 | assert RadiusGraph(1).__repr__() == 'RadiusGraph(r=1)' 8 | 9 | pos = [[0, 0], [1, 0], [2, 0], [0, 1], [-2, 0], [0, -2]] 10 | pos = torch.tensor(pos, dtype=torch.float) 11 | data = Data(pos=pos) 12 | 13 | edge_index = RadiusGraph(1)(data).edge_index 14 | assert edge_index.tolist() == [[0, 0, 1, 1, 2, 3], [1, 3, 0, 2, 1, 0]] 15 | -------------------------------------------------------------------------------- /test/utils/test_isolated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import contains_isolated_nodes 3 | 4 | 5 | def test_contains_isolated_nodes(): 6 | row = torch.tensor([0, 1, 0]) 7 | col = torch.tensor([1, 0, 0]) 8 | 9 | assert not contains_isolated_nodes(torch.stack([row, col], dim=0)) 10 | assert contains_isolated_nodes(torch.stack([row, col], dim=0), 3) 11 | 12 | row = torch.tensor([0, 1, 2, 0]) 13 | col = torch.tensor([1, 0, 2, 0]) 14 | assert contains_isolated_nodes(torch.stack([row, col], dim=0)) 15 | -------------------------------------------------------------------------------- /test/nn/pool/test_voxel_grid.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # from torch_geometric.nn.pool.voxel_grid import voxel_grid 3 | 4 | # def test_voxel_grid(): 5 | # pos = torch.Tensor([[0.5, 0.5], [1.2, 1.5], [0.2, 1.7], [0.3, 0.4], 6 | # [0.5, 0.5], [1.2, 1.5], [0.2, 1.7], [0.3, 0.4]]) 7 | # batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1]) 8 | 9 | # cluster, batch = voxel_grid(pos, size=1, start=0, end=2, batch=batch) 10 | # assert cluster.tolist() == [0, 2, 1, 0, 3, 5, 4, 3] 11 | # assert batch.tolist() == [0, 0, 0, 1, 1, 1] 12 | -------------------------------------------------------------------------------- /torch_geometric/nn/conv/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .gcn_conv import GCNConv 2 | from .cheb_conv import ChebConv 3 | from .sage_conv import SAGEConv 4 | from .graph_conv import GraphConv 5 | from .gat_conv import GATConv 6 | from .gin_conv import GINConv 7 | from .spline_conv import SplineConv 8 | from .nn_conv import NNConv 9 | from .edge_conv import EdgeConv 10 | 11 | __all__ = [ 12 | 'GCNConv', 13 | 'ChebConv', 14 | 'SAGEConv', 15 | 'GraphConv', 16 | 'GATConv', 17 | 'GINConv', 18 | 'SplineConv', 19 | 'NNConv', 20 | 'EdgeConv', 21 | ] 22 | -------------------------------------------------------------------------------- /torch_geometric/transforms/face_to_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_undirected 3 | 4 | 5 | class FaceToEdge(object): 6 | def __call__(self, data): 7 | face = data.face 8 | 9 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) 10 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) 11 | 12 | data.edge_index = edge_index 13 | data.face = None 14 | return data 15 | 16 | def __repr__(self): 17 | return '{}()'.format(self.__class__.__name__) 18 | -------------------------------------------------------------------------------- /torch_geometric/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import download_url 2 | from .extract import extract_tar, extract_zip 3 | from .data import Data 4 | from .batch import Batch 5 | from .dataset import Dataset 6 | from .in_memory_dataset import InMemoryDataset 7 | from .dataloader import DataLoader, DenseDataLoader 8 | 9 | __all__ = [ 10 | 'download_url', 11 | 'extract_tar', 12 | 'extract_zip', 13 | 'Data', 14 | 'Batch', 15 | 'Dataset', 16 | 'DataLoader', 17 | 'DenseDataLoader', 18 | 'InMemoryDataset', 19 | 'DataLoader', 20 | ] 21 | 
-------------------------------------------------------------------------------- /torch_geometric/transforms/add_self_loops.py: -------------------------------------------------------------------------------- 1 | from torch_sparse import coalesce 2 | from torch_geometric.utils import add_self_loops 3 | 4 | 5 | class AddSelfLoops(object): 6 | def __call__(self, data): 7 | edge_index = data.edge_index 8 | num_nodes = data.num_nodes 9 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 10 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes) 11 | data.edge_index = edge_index 12 | return data 13 | 14 | def __repr__(self): 15 | return '{}()'.format(self.__class__.__name__) 16 | -------------------------------------------------------------------------------- /torch_geometric/read/ply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from plyfile import PlyData 3 | from torch_geometric.data import Data 4 | 5 | 6 | def read_ply(path): 7 | with open(path, 'rb') as f: 8 | data = PlyData.read(f) 9 | 10 | pos = ([torch.tensor(data['vertex'][axis]) for axis in ['x', 'y', 'z']]) 11 | pos = torch.stack(pos, dim=-1) 12 | 13 | faces = data['face']['vertex_indices'] 14 | faces = [torch.tensor(face, dtype=torch.long) for face in faces] 15 | face = torch.stack(faces, dim=-1) 16 | 17 | data = Data(pos=pos) 18 | data.face = face 19 | 20 | return data 21 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/global_pool.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.utils import scatter_ 2 | 3 | 4 | def global_add_pool(x, batch, size=None): 5 | size = batch[-1].item() + 1 if size is None else size 6 | return scatter_('add', x, batch, dim_size=size) 7 | 8 | 9 | def global_mean_pool(x, batch, size=None): 10 | size = batch[-1].item() + 1 if size is None else size 11 | return scatter_('mean', x, 
batch, dim_size=size) 12 | 13 | 14 | def global_max_pool(x, batch, size=None): 15 | size = batch[-1].item() + 1 if size is None else size 16 | return scatter_('max', x, batch, dim_size=size) 17 | -------------------------------------------------------------------------------- /test/utils/test_one_hot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import one_hot 3 | 4 | 5 | def test_one_hot(): 6 | src = torch.LongTensor([1, 0, 3]) 7 | expected_output = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]] 8 | 9 | assert one_hot(src).tolist() == expected_output 10 | assert one_hot(src, 4).tolist() == expected_output 11 | 12 | src = torch.LongTensor([[1, 0], [0, 1], [2, 0]]) 13 | expected_output = [[0, 1, 0, 1, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0]] 14 | assert one_hot(src).tolist() == expected_output 15 | assert one_hot(src, torch.tensor([3, 2])).tolist() == expected_output 16 | -------------------------------------------------------------------------------- /test/transforms/test_linear_transformation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import LinearTransformation 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_cartesian(): 7 | matrix = torch.tensor([[2, 0], [0, 2]], dtype=torch.float) 8 | transform = LinearTransformation(matrix) 9 | out = 'LinearTransformation([[2.0, 0.0], [0.0, 2.0]])' 10 | assert transform.__repr__() == out 11 | 12 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float) 13 | data = Data(pos=pos) 14 | 15 | out = transform(data).pos.tolist() 16 | assert out == [[-2, 2], [-6, 0], [4, -2]] 17 | -------------------------------------------------------------------------------- /test/utils/test_grid.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.utils import grid 2 | 3 | 4 | def test_grid(): 5 
| (row, col), pos = grid(height=3, width=2) 6 | 7 | expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2] 8 | expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5] 9 | expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] 10 | expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5] 11 | 12 | expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]] 13 | 14 | assert row.tolist() == expected_row 15 | assert col.tolist() == expected_col 16 | assert pos.tolist() == expected_pos 17 | -------------------------------------------------------------------------------- /torch_geometric/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .karate import KarateClub 2 | from .tu_dataset import TUDataset 3 | from .planetoid import Planetoid 4 | from .mnist_superpixels import MNISTSuperpixels 5 | from .faust import FAUST 6 | from .qm7 import QM7 7 | from .qm9 import QM9 8 | from .ppi import PPI 9 | from .shapenet import ShapeNet 10 | from .modelnet import ModelNet 11 | from .coma import CoMA 12 | 13 | __all__ = [ 14 | 'KarateClub', 15 | 'TUDataset', 16 | 'Planetoid', 17 | 'MNISTSuperpixels', 18 | 'FAUST', 19 | 'QM7', 20 | 'QM9', 21 | 'PPI', 22 | 'ShapeNet', 23 | 'ModelNet', 24 | 'CoMA', 25 | ] 26 | -------------------------------------------------------------------------------- /test/nn/pool/test_graclus.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # from torch_geometric.nn.pool import graclus 3 | # from torch_geometric.data import Batch 4 | 5 | # def test_graclus_pool(): 6 | # x = torch.Tensor([[1, 6], [2, 5], [3, 4], [4, 3], [5, 2], [6, 1]]) 7 | # row = torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]) 8 | # col = torch.LongTensor([1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4]) 9 | # edge_index = torch.stack([row, col], dim=0) 10 | # batch = torch.LongTensor([0, 0, 0, 1, 1, 1]) 11 | # data = Batch(x=x, edge_index=edge_index, batch=batch) 12 | 
13 | # data = graclus_pool(data) 14 | 15 | # assert data.num_nodes == 4 16 | -------------------------------------------------------------------------------- /test/data/test_collate.py: -------------------------------------------------------------------------------- 1 | # def test_collate_to_set(): 2 | # x1 = torch.tensor([1, 2, 3], dtype=torch.float) 3 | # e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 4 | # x2 = torch.tensor([1, 2], dtype=torch.float) 5 | # e2 = torch.tensor([[0, 1], [1, 0]]) 6 | 7 | # data, slices = collate_to_set([Data(x1, e1), Data(x2, e2)]) 8 | 9 | # assert len(data) == 2 10 | # assert data.x.tolist() == [1, 2, 3, 1, 2] 11 | # data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1], [1, 0, 2, 1, 1, 0]] 12 | # assert len(slices.keys()) == 2 13 | # assert slices['x'].tolist() == [0, 3, 5] 14 | # assert slices['edge_index'].tolist() == [0, 4, 6] 15 | -------------------------------------------------------------------------------- /test/transforms/test_nn_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import NNGraph 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_nn_graph(): 7 | assert NNGraph().__repr__() == 'NNGraph(k=6)' 8 | 9 | pos = [[0, 0], [1, 0], [2, 0], [0, 1], [-2, 0], [0, -2]] 10 | pos = torch.tensor(pos, dtype=torch.float) 11 | data = Data(pos=pos) 12 | 13 | row, col = NNGraph(2)(data).edge_index 14 | expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5] 15 | expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1] 16 | 17 | assert row.tolist() == expected_row 18 | assert col.tolist() == expected_col 19 | -------------------------------------------------------------------------------- /test/transforms/test_target_indegree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import TargetIndegree 3 | from 
torch_geometric.data import Data 4 | 5 | 6 | def test_target_indegree(): 7 | assert TargetIndegree().__repr__() == 'TargetIndegree(cat=True)' 8 | 9 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 10 | data = Data(edge_index=edge_index) 11 | 12 | out = TargetIndegree()(data).edge_attr.tolist() 13 | assert out == [[1], [0.5], [0.5], [1]] 14 | 15 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float) 16 | out = TargetIndegree()(data).edge_attr.tolist() 17 | assert out == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]] 18 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/__init__.py: -------------------------------------------------------------------------------- 1 | from .global_pool import global_add_pool, global_mean_pool, global_max_pool 2 | from .set2set import Set2Set 3 | 4 | from .max_pool import max_pool, max_pool_x 5 | from .avg_pool import avg_pool, avg_pool_x 6 | from .graclus import graclus 7 | from .voxel_grid import voxel_grid 8 | from .topk_pool import TopKPooling 9 | from .sort_pool import sort_pool 10 | 11 | __all__ = [ 12 | 'global_add_pool', 13 | 'global_mean_pool', 14 | 'global_max_pool', 15 | 'Set2Set', 16 | 'max_pool', 17 | 'max_pool_x', 18 | 'avg_pool', 19 | 'avg_pool_x', 20 | 'graclus', 21 | 'voxel_grid', 22 | 'TopKPooling', 23 | 'sort_pool', 24 | ] 25 | -------------------------------------------------------------------------------- /test/utils/test_undirected.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import is_undirected, to_undirected 3 | 4 | 5 | def test_is_undirected(): 6 | row = torch.tensor([0, 1, 0]) 7 | col = torch.tensor([1, 0, 0]) 8 | 9 | assert is_undirected(torch.stack([row, col], dim=0)) 10 | 11 | row = torch.tensor([0, 1, 1]) 12 | col = torch.tensor([1, 0, 2]) 13 | 14 | assert not is_undirected(torch.stack([row, col], dim=0)) 15 | 16 | 17 | def test_to_undirected(): 18 | row 
= torch.tensor([0, 1, 1]) 19 | col = torch.tensor([1, 0, 2]) 20 | 21 | edge_index = to_undirected(torch.stack([row, col], dim=0)) 22 | assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] 23 | -------------------------------------------------------------------------------- /test/transforms/test_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import Distance 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_distance(): 7 | assert Distance().__repr__() == 'Distance(cat=True)' 8 | 9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 11 | data = Data(edge_index=edge_index, pos=pos) 12 | 13 | out = Distance()(data).edge_attr.tolist() 14 | assert out == [[1], [1], [2], [2]] 15 | 16 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float) 17 | out = Distance()(data).edge_attr.tolist() 18 | assert out == [[1, 1], [1, 1], [1, 2], [1, 2]] 19 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/pool.py: -------------------------------------------------------------------------------- 1 | from torch_sparse import coalesce 2 | from torch_scatter import scatter_mean 3 | from torch_geometric.utils import remove_self_loops 4 | 5 | 6 | def pool_edge(cluster, edge_index, edge_attr=None): 7 | num_nodes = cluster.size(0) 8 | edge_index = cluster[edge_index.view(-1)].view(2, -1) 9 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 10 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, 11 | num_nodes) 12 | return edge_index, edge_attr 13 | 14 | 15 | def pool_batch(perm, batch): 16 | return batch[perm] 17 | 18 | 19 | def pool_pos(cluster, pos): 20 | return scatter_mean(pos, cluster, dim=0) 21 | -------------------------------------------------------------------------------- /test/data/test_batch.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data, Batch 3 | 4 | 5 | def test_batch(): 6 | x1 = torch.tensor([1, 2, 3], dtype=torch.float) 7 | e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 8 | x2 = torch.tensor([1, 2], dtype=torch.float) 9 | e2 = torch.tensor([[0, 1], [1, 0]]) 10 | 11 | data = Batch.from_data_list([Data(x1, e1), Data(x2, e2)]) 12 | 13 | assert data.__repr__() == 'Batch(batch=[5], edge_index=[2, 6], x=[5])' 14 | assert len(data) == 3 15 | assert data.x.tolist() == [1, 2, 3, 1, 2] 16 | assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]] 17 | assert data.batch.tolist() == [0, 0, 0, 1, 1] 18 | assert data.num_graphs == 2 19 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/nn.rst.txt: -------------------------------------------------------------------------------- 1 | torch_geometric.nn 2 | ================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Convolution Layers 8 | ------------------ 9 | 10 | .. automodule:: torch_geometric.nn.meta 11 | :members: 12 | :undoc-members: 13 | 14 | .. automodule:: torch_geometric.nn.conv 15 | :members: 16 | :undoc-members: 17 | 18 | Propagation Layers 19 | ------------------ 20 | 21 | .. automodule:: torch_geometric.nn.prop 22 | :members: 23 | :undoc-members: 24 | 25 | Pooling Layers 26 | -------------- 27 | 28 | .. automodule:: torch_geometric.nn.pool 29 | :members: 30 | :undoc-members: 31 | 32 | .. 
automodule:: torch_geometric.nn.dense.diff_pool 33 | :members: 34 | :undoc-members: 35 | -------------------------------------------------------------------------------- /test/transforms/test_sample_points.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import SamplePoints 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_sample_points(): 7 | assert SamplePoints(1024).__repr__() == 'SamplePoints(1024)' 8 | 9 | pos = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]] 10 | data = Data(pos=torch.tensor(pos, dtype=torch.float)) 11 | data.face = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long) 12 | 13 | data = SamplePoints(8)(data) 14 | pos = data.pos 15 | assert pos[:, 0].min().item() >= 0 and pos[:, 0].max().item() <= 1 16 | assert pos[:, 1].min().item() >= 0 and pos[:, 1].max().item() <= 1 17 | assert pos[:, 2].abs().sum().item() == 0 18 | assert 'face' not in data 19 | -------------------------------------------------------------------------------- /test/transforms/test_compose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.transforms as T 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_compose(): 7 | transform = T.Compose([T.Cartesian(), T.TargetIndegree()]) 8 | assert transform.__repr__() == ('Compose([\n' 9 | ' Cartesian(cat=True),\n' 10 | ' TargetIndegree(cat=True),\n' 11 | '])') 12 | 13 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 14 | edge_index = torch.tensor([[0, 1], [1, 2]]) 15 | data = Data(edge_index=edge_index, pos=pos) 16 | 17 | out = transform(data).edge_attr.tolist() 18 | assert out == [[0.75, 0.5, 1], [1, 0.5, 1]] 19 | -------------------------------------------------------------------------------- /test/transforms/test_spherical.py: -------------------------------------------------------------------------------- 1 | from pytest 
import approx 2 | import torch 3 | from torch_geometric.transforms import Spherical 4 | from torch_geometric.data import Data 5 | 6 | 7 | def test_spherical(): 8 | assert Spherical().__repr__() == 'Spherical(cat=True)' 9 | 10 | pos = torch.tensor([[0, 0, 0], [0, 1, 1]], dtype=torch.float) 11 | edge_index = torch.tensor([[0, 1], [1, 0]]) 12 | data = Data(edge_index=edge_index, pos=pos) 13 | 14 | out = Spherical()(data).edge_attr.view(-1).tolist() 15 | assert approx(out) == [1, 0.25, 0, 1, 0.75, 1] 16 | 17 | data.edge_attr = torch.tensor([1, 1], dtype=torch.float) 18 | out = Spherical()(data).edge_attr.view(-1).tolist() 19 | assert approx(out) == [1, 1, 0.25, 0, 1, 1, 0.75, 1] 20 | -------------------------------------------------------------------------------- /test/transforms/test_cartesian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import Cartesian 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_cartesian(): 7 | assert Cartesian().__repr__() == 'Cartesian(cat=True)' 8 | 9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 11 | data = Data(edge_index=edge_index, pos=pos) 12 | 13 | out = Cartesian()(data).edge_attr.tolist() 14 | assert out == [[0.75, 0.5], [0.25, 0.5], [1, 0.5], [0, 0.5]] 15 | 16 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float) 17 | out = Cartesian()(data).edge_attr.tolist() 18 | assert out == [[1, 0.75, 0.5], [1, 0.25, 0.5], [1, 1, 0.5], [1, 0, 0.5]] 19 | -------------------------------------------------------------------------------- /test/transforms/test_two_hop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import TwoHop 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_two_hop(): 7 | assert TwoHop().__repr__() == 'TwoHop()' 8 | 9 | edge_index 
= torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) 10 | edge_attr = torch.tensor([1, 2, 3, 1, 2, 3], dtype=torch.float) 11 | data = Data(edge_index=edge_index, edge_attr=edge_attr) 12 | 13 | data = TwoHop()(data) 14 | edge_index, edge_attr = data.edge_index, data.edge_attr 15 | 16 | assert edge_index.tolist() == [[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], 17 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]] 18 | assert edge_attr.tolist() == [1, 2, 3, 1, 0, 0, 2, 0, 0, 3, 0, 0] 19 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/sort_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_batch 3 | 4 | 5 | def sort_pool(x, batch, k): 6 | x, _ = x.sort(dim=-1) 7 | 8 | fill_value = x.min().item() - 1 9 | batch_x, num_nodes = to_batch(x, batch, fill_value) 10 | B, N, D = batch_x.size() 11 | 12 | _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True) 13 | arange = torch.arange(B, dtype=torch.long, device=perm.device) * N 14 | perm = perm + arange.view(-1, 1) 15 | 16 | batch_x = batch_x.view(B * N, D) 17 | batch_x = batch_x[perm] 18 | batch_x = batch_x.view(B, N, D) 19 | 20 | batch_x = batch_x[:, :k].contiguous() 21 | batch_x[batch_x == fill_value] = 0 22 | x = batch_x.view(B, k * D) 23 | 24 | return x 25 | -------------------------------------------------------------------------------- /test/transforms/test_polar.py: -------------------------------------------------------------------------------- 1 | from pytest import approx 2 | import torch 3 | from torch_geometric.transforms import Polar 4 | from torch_geometric.data import Data 5 | 6 | 7 | def test_polar(): 8 | assert Polar().__repr__() == 'Polar(cat=True)' 9 | 10 | pos = torch.tensor([[-1, 0], [0, 0], [0, 2]], dtype=torch.float) 11 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 12 | data = Data(edge_index=edge_index, pos=pos) 13 | 14 | out = 
Polar()(data).edge_attr.view(-1).tolist() 15 | assert approx(out) == [0.5, 0, 0.5, 0.5, 1, 0.25, 1, 0.75] 16 | 17 | data.edge_attr = torch.tensor([1, 1, 1, 1], dtype=torch.float) 18 | out = Polar()(data).edge_attr.view(-1).tolist() 19 | assert approx(out) == [1, 0.5, 0, 1, 0.5, 0.5, 1, 1, 0.25, 1, 1, 0.75] 20 | -------------------------------------------------------------------------------- /torch_geometric/utils/one_hot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def one_hot(src, num_classes=None, dtype=None): 5 | src = src.to(torch.long) 6 | src = src.unsqueeze(-1) if src.dim() == 1 else src 7 | assert src.dim() == 2 8 | 9 | if num_classes is None: 10 | num_classes = src.max(dim=0)[0] + 1 11 | elif isinstance(num_classes, int) or isinstance(num_classes, float): 12 | num_classes = torch.tensor(int(num_classes)) 13 | 14 | if src.size(1) > 1: 15 | zero = torch.tensor([0], device=src.device) 16 | src = src + torch.cat([zero, torch.cumsum(num_classes, 0)[:-1]]) 17 | 18 | size = src.size(0), num_classes.sum() 19 | out = torch.zeros(size, dtype=dtype, device=src.device) 20 | out.scatter_(1, src, 1) 21 | return out 22 | -------------------------------------------------------------------------------- /test/transforms/test_local_cartesian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import LocalCartesian 3 | from torch_geometric.data import Data 4 | 5 | 6 | def test_local_cartesian(): 7 | assert LocalCartesian().__repr__() == 'LocalCartesian(cat=True)' 8 | 9 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 11 | data = Data(edge_index=edge_index, pos=pos) 12 | 13 | out = LocalCartesian()(data).edge_attr.tolist() 14 | assert out == [[1, 0.5], [0.25, 0.5], [1, 0.5], [0, 0.5]] 15 | 16 | data.edge_attr = torch.tensor([1, 1, 1, 1], 
dtype=torch.float) 17 | out = LocalCartesian()(data).edge_attr.tolist() 18 | assert out == [[1, 1, 0.5], [1, 0.25, 0.5], [1, 1, 0.5], [1, 0, 0.5]] 19 | -------------------------------------------------------------------------------- /torch_geometric/utils/to_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_add 3 | 4 | 5 | def to_batch(x, batch, fill_value=0): 6 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0) 7 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item() 8 | cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)], dim=0) 9 | 10 | index = torch.arange(batch.size(0), dtype=torch.long, device=x.device) 11 | index = (index - cum_nodes[batch]) + (batch * max_num_nodes) 12 | 13 | size = [batch_size * max_num_nodes] + list(x.size())[1:] 14 | batch_x = x.new_full(size, fill_value) 15 | batch_x[index] = x 16 | size = [batch_size, max_num_nodes] + list(x.size())[1:] 17 | batch_x = batch_x.view(size) 18 | 19 | return batch_x, num_nodes 20 | -------------------------------------------------------------------------------- /test/nn/pool/test_topk_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj 3 | 4 | 5 | def test_topk(): 6 | x = torch.tensor([2, 4, 5, 6, 2, 9], dtype=torch.float) 7 | batch = torch.tensor([0, 0, 1, 1, 1, 1]) 8 | 9 | perm = topk(x, 0.5, batch) 10 | 11 | assert perm.tolist() == [1, 5, 3] 12 | assert x[perm].tolist() == [4, 9, 6] 13 | 14 | 15 | def test_filter_adj(): 16 | edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3], 17 | [1, 3, 0, 2, 1, 3, 0, 2]]) 18 | edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) 19 | perm = torch.tensor([2, 3]) 20 | 21 | edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm) 22 | assert edge_index.tolist() == [[0, 1], [1, 0]] 23 | assert 
edge_attr.tolist() == [6, 8] 24 | -------------------------------------------------------------------------------- /test/utils/test_softmax.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | from torch_geometric.utils import softmax 6 | 7 | 8 | def test_softmax(): 9 | row = torch.LongTensor([0, 0, 0, 1, 1, 1]) 10 | col = torch.LongTensor([1, 2, 3, 0, 2, 3]) 11 | edge_attr = torch.Tensor([0, 1, 2, 0, 1, 2]) 12 | e = [exp(1), exp(2), exp(3)] 13 | e_sum = e[0] + e[1] + e[2] 14 | e = [e[0] / e_sum, e[1] / e_sum, e[2] / e_sum] 15 | 16 | output = softmax(edge_attr, row) 17 | 18 | output = softmax(Variable(edge_attr), row) 19 | 20 | output = softmax(edge_attr, col) 21 | 22 | edge_attr = torch.Tensor([[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]]) 23 | output = softmax(edge_attr, row) 24 | 25 | output = softmax(Variable(edge_attr), row) 26 | output 27 | -------------------------------------------------------------------------------- /test/utils/test_convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.sparse 3 | import networkx 4 | from torch_geometric.utils import to_scipy_sparse_matrix, to_networkx 5 | 6 | 7 | def test_to_scipy_sparse_matrix(): 8 | row = torch.tensor([0, 1, 0]) 9 | col = torch.tensor([1, 0, 0]) 10 | 11 | adj = to_scipy_sparse_matrix(torch.stack([row, col], dim=0)) 12 | assert isinstance(adj, scipy.sparse.coo_matrix) is True 13 | assert adj.shape == (2, 2) 14 | assert adj.row.tolist() == row.tolist() 15 | assert adj.col.tolist() == col.tolist() 16 | assert adj.data.tolist() == [1, 1, 1] 17 | 18 | 19 | def test_to_networkx(): 20 | row = torch.tensor([0, 1, 0]) 21 | col = torch.tensor([1, 0, 0]) 22 | 23 | adj = to_networkx(torch.stack([row, col], dim=0)) 24 | assert networkx.to_numpy_matrix(adj).tolist() == [[1, 1], [1, 0]] 25 | 
-------------------------------------------------------------------------------- /docs/source/modules/nn.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.nn 2 | ================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Sparse Convolutional Layers 8 | --------------------------- 9 | 10 | .. automodule:: torch_geometric.nn.conv 11 | :members: 12 | :undoc-members: 13 | 14 | .. automodule:: torch_geometric.nn.meta 15 | :members: 16 | :undoc-members: 17 | 18 | Dense Convolutional Layers 19 | -------------------------- 20 | 21 | Global Pooling Layers 22 | --------------------- 23 | 24 | Sparse Hierarchical Pooling Layers 25 | ---------------------------------- 26 | 27 | .. automodule:: torch_geometric.nn.pool 28 | :members: 29 | :undoc-members: 30 | 31 | Dense Hierarchical Pooling Layers 32 | --------------------------------- 33 | 34 | .. automodule:: torch_geometric.nn.dense.diff_pool 35 | :members: 36 | :undoc-members: 37 | -------------------------------------------------------------------------------- /torch_geometric/utils/undirected.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import coalesce 3 | 4 | from .num_nodes import maybe_num_nodes 5 | 6 | 7 | def is_undirected(edge_index, num_nodes=None): 8 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 9 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes) 10 | undirected_edge_index = to_undirected(edge_index, num_nodes=num_nodes) 11 | return edge_index.size(1) == undirected_edge_index.size(1) 12 | 13 | 14 | def to_undirected(edge_index, num_nodes=None): 15 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 16 | 17 | row, col = edge_index 18 | row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) 19 | edge_index = torch.stack([row, col], dim=0) 20 | edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes) 21 | 22 | return edge_index 
23 | -------------------------------------------------------------------------------- /torch_geometric/nn/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | stdv = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-stdv, stdv) 8 | 9 | 10 | def glorot(tensor): 11 | stdv = math.sqrt(6.0 / (tensor.size(0) + tensor.size(1))) 12 | if tensor is not None: 13 | tensor.data.uniform_(-stdv, stdv) 14 | 15 | 16 | def zeros(tensor): 17 | if tensor is not None: 18 | tensor.data.fill_(0) 19 | 20 | 21 | def ones(tensor): 22 | if tensor is not None: 23 | tensor.data.fill_(1) 24 | 25 | 26 | def reset(nn): 27 | def _reset(item): 28 | if hasattr(item, 'reset_parameters'): 29 | item.reset_parameters() 30 | 31 | if hasattr(nn, 'children'): 32 | for item in nn.children(): 33 | _reset(item) 34 | else: 35 | _reset(nn) 36 | -------------------------------------------------------------------------------- /torch_geometric/read/off.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.read import parse_txt_array 3 | from torch_geometric.data import Data 4 | 5 | 6 | def parse_off(src): 7 | # Some files may contain a bug and do not have a carriage return after OFF. 
8 | if src[0] == 'OFF': 9 | src = src[1:] 10 | else: 11 | src[0] = src[0][3:] 12 | 13 | num_nodes, num_faces = [int(item) for item in src[0].split()[:2]] 14 | 15 | pos = parse_txt_array(src[1:1 + num_nodes]) 16 | 17 | face = src[1 + num_nodes:1 + num_nodes + num_faces] 18 | face = parse_txt_array(face, start=1, dtype=torch.long).t().contiguous() 19 | 20 | data = Data(pos=pos) 21 | data.face = face 22 | 23 | return data 24 | 25 | 26 | def read_off(path): 27 | with open(path, 'r') as f: 28 | src = f.read().split('\n')[:-1] 29 | return parse_off(src) 30 | -------------------------------------------------------------------------------- /torch_geometric/transforms/radius_graph.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | 3 | import torch 4 | import scipy.spatial 5 | from torch_geometric.utils import remove_self_loops 6 | 7 | 8 | class RadiusGraph(object): 9 | def __init__(self, r): 10 | self.r = r 11 | 12 | def __call__(self, data): 13 | pos = data.pos 14 | assert not pos.is_cuda 15 | 16 | tree = scipy.spatial.cKDTree(pos) 17 | indices = tree.query_ball_tree(tree, self.r) 18 | 19 | row, col = [], [] 20 | for i, neighbors in enumerate(indices): 21 | row += repeat(i, len(neighbors)) 22 | col += neighbors 23 | edge_index = torch.tensor([row, col]) 24 | edge_index, _ = remove_self_loops(edge_index) 25 | 26 | data.edge_index = edge_index 27 | return data 28 | 29 | def __repr__(self): 30 | return '{}(r={})'.format(self.__class__.__name__, self.r) 31 | -------------------------------------------------------------------------------- /torch_geometric/datasets/karate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import networkx as nx 3 | from torch_geometric.data import InMemoryDataset, Data 4 | 5 | 6 | class KarateClub(InMemoryDataset): 7 | def __init__(self, transform=None): 8 | super(KarateClub, self).__init__('.', transform, None, None) 9 | 
10 | G = nx.karate_club_graph() 11 | adj = nx.to_scipy_sparse_matrix(G).tocoo() 12 | row = torch.from_numpy(adj.row).to(torch.long) 13 | col = torch.from_numpy(adj.col).to(torch.long) 14 | edge_index = torch.stack([row, col], dim=0) 15 | data = Data(edge_index=edge_index) 16 | data.x = torch.eye(data.num_nodes, dtype=torch.float) 17 | self.data, self.slices = self.collate([data]) 18 | 19 | def _download(self): 20 | return 21 | 22 | def _process(self): 23 | return 24 | 25 | def __repr__(self): 26 | return '{}()'.format(self.__class__.__name__) 27 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/voxel_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_cluster import grid_cluster 3 | 4 | from ..repeat import repeat 5 | 6 | 7 | def voxel_grid(pos, batch, size, start=None, end=None): 8 | pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos 9 | num_nodes, dim = pos.size() 10 | 11 | size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim) 12 | 13 | pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1) 14 | size = size + [1] 15 | start = None if start is None else start + [0] 16 | end = None if end is None else end + [batch.max().item()] 17 | 18 | size = torch.tensor(size, dtype=pos.dtype, device=pos.device) 19 | if start is not None: 20 | start = torch.tensor(start, dtype=pos.dtype, device=pos.device) 21 | if end is not None: 22 | end = torch.tensor(end, dtype=pos.dtype, device=pos.device) 23 | 24 | return grid_cluster(pos, size, start, end) 25 | -------------------------------------------------------------------------------- /torch_geometric/transforms/nn_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.spatial 3 | from torch_geometric.utils import to_undirected 4 | 5 | 6 | class NNGraph(object): 7 | def __init__(self, k=6): 8 | self.k = k 9 
| 10 | def __call__(self, data): 11 | pos = data.pos 12 | assert not pos.is_cuda 13 | 14 | row = torch.arange(pos.size(0), dtype=torch.long) 15 | row = row.view(-1, 1).repeat(1, self.k).view(-1) 16 | 17 | _, col = scipy.spatial.cKDTree(pos).query(pos, self.k + 1) 18 | col = torch.tensor(col)[:, 1:].contiguous().view(-1) 19 | mask = col < pos.size(0) 20 | edge_index = torch.stack([row[mask], col[mask]], dim=0) 21 | edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) 22 | 23 | data.edge_index = edge_index 24 | return data 25 | 26 | def __repr__(self): 27 | return '{}(k={})'.format(self.__class__.__name__, self.k) 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = '0.3.1' 4 | url = 'https://github.com/rusty1s/pytorch_geometric' 5 | 6 | install_requires = [ 7 | 'numpy', 8 | 'scipy', 9 | 'networkx', 10 | 'plyfile', 11 | ] 12 | setup_requires = ['pytest-runner'] 13 | tests_require = ['pytest', 'pytest-cov'] 14 | 15 | setup( 16 | name='torch_geometric', 17 | version=__version__, 18 | description='Geometric Deep Learning Extension Library for PyTorch', 19 | author='Matthias Fey', 20 | author_email='matthias.fey@tu-dortmund.de', 21 | url=url, 22 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), 23 | keywords=[ 24 | 'pytorch', 'geometric-deep-learning', 'graph', 'mesh', 25 | 'neural-networks', 'spline-cnn' 26 | ], 27 | install_requires=install_requires, 28 | setup_requires=setup_requires, 29 | tests_require=tests_require, 30 | packages=find_packages()) 31 | -------------------------------------------------------------------------------- /torch_geometric/utils/loop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .num_nodes import maybe_num_nodes 4 | 5 | 6 | def contains_self_loops(edge_index): 
7 | row, col = edge_index 8 | mask = row == col 9 | return mask.sum().item() > 0 10 | 11 | 12 | def remove_self_loops(edge_index, edge_attr=None): 13 | row, col = edge_index 14 | mask = row != col 15 | edge_attr = edge_attr if edge_attr is None else edge_attr[mask] 16 | mask = mask.unsqueeze(0).expand_as(edge_index) 17 | edge_index = edge_index[mask].view(2, -1) 18 | 19 | return edge_index, edge_attr 20 | 21 | 22 | def add_self_loops(edge_index, num_nodes=None): 23 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 24 | 25 | dtype, device = edge_index.dtype, edge_index.device 26 | loop = torch.arange(0, num_nodes, dtype=dtype, device=device) 27 | loop = loop.unsqueeze(0).repeat(2, 1) 28 | edge_index = torch.cat([edge_index, loop], dim=1) 29 | 30 | return edge_index 31 | -------------------------------------------------------------------------------- /torch_geometric/transforms/random_scale.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class RandomScale(object): 5 | r"""Scales node positions by a randomly sampled factor :math:`s` within a 6 | given interval, e.g., resulting in the transformation matrix 7 | 8 | .. math:: 9 | \begin{bmatrix} 10 | s & 0 & 0 \\ 11 | 0 & s & 0 \\ 12 | 0 & 0 & s \\ 13 | \end{bmatrix} 14 | 15 | for three-dimensional positions. 16 | 17 | Args: 18 | scale (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale 19 | is randomly sampled from the range 20 | :math:`a \leq \mathrm{scale} \leq b`. 
21 | """ 22 | 23 | def __init__(self, scales): 24 | self.scales = scales 25 | 26 | def __call__(self, data): 27 | scale = random.uniform(*self.scales) 28 | data.pos = data.pos * scale 29 | return data 30 | 31 | def __repr__(self): 32 | return '{}({})'.format(self.__class__.__name__, self.scales) 33 | -------------------------------------------------------------------------------- /torch_geometric/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from torch_geometric.data import Batch 5 | 6 | 7 | class DataLoader(torch.utils.data.DataLoader): 8 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 9 | super(DataLoader, self).__init__( 10 | dataset, 11 | batch_size, 12 | shuffle, 13 | collate_fn=lambda batch: Batch.from_data_list(batch), 14 | **kwargs) 15 | 16 | 17 | class DenseDataLoader(torch.utils.data.DataLoader): 18 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 19 | def dense_collate(data_list): 20 | batch = Batch() 21 | for key in data_list[0].keys: 22 | batch[key] = default_collate([d[key] for d in data_list]) 23 | return batch 24 | 25 | super(DenseDataLoader, self).__init__( 26 | dataset, batch_size, shuffle, collate_fn=dense_collate, **kwargs) 27 | -------------------------------------------------------------------------------- /torch_geometric/utils/degree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .num_nodes import maybe_num_nodes 4 | 5 | 6 | def degree(index, num_nodes=None, dtype=None): 7 | """Computes the degree of a given index tensor. 8 | 9 | Args: 10 | index (LongTensor): Source or target indices of edges. 11 | num_nodes (int, optional): The number of nodes in :attr:`index`. 12 | (default: :obj:`None`) 13 | dtype (:obj:`torch.dtype`, optional). The desired data type of returned 14 | tensor. 
15 | 16 | :rtype: :class:`Tensor` 17 | 18 | .. testsetup:: 19 | 20 | import torch 21 | 22 | .. testcode:: 23 | 24 | from torch_geometric.utils import degree 25 | index = torch.tensor([0, 1, 0, 2, 0]) 26 | output = degree(index) 27 | print(output) 28 | 29 | .. testoutput:: 30 | 31 | tensor([3., 1., 1.]) 32 | """ 33 | 34 | num_nodes = maybe_num_nodes(index, num_nodes) 35 | out = torch.zeros((num_nodes), dtype=dtype, device=index.device) 36 | return out.scatter_add_(0, index, out.new_ones((index.size(0)))) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Matthias Fey 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /torch_geometric/transforms/sample_points.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SamplePoints(object): 5 | def __init__(self, num): 6 | self.num = num 7 | 8 | def __call__(self, data): 9 | pos, face = data.pos, data.face 10 | assert not pos.is_cuda and not face.is_cuda 11 | assert pos.size(1) == 3 and face.size(0) == 3 12 | 13 | area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]]) 14 | area = torch.sqrt((area**2).sum(dim=-1)) / 2 15 | 16 | prob = area / area.sum() 17 | sample = torch.multinomial(prob, self.num, replacement=True) 18 | face = face[:, sample] 19 | 20 | frac = torch.rand(self.num, 2) 21 | mask = frac.sum(dim=-1) > 1 22 | frac[mask] = 1 - frac[mask] 23 | 24 | pos_sampled = pos[face[0]] 25 | pos_sampled += frac[:, :1] * (pos[face[1]] - pos[face[0]]) 26 | pos_sampled += frac[:, 1:] * (pos[face[2]] - pos[face[0]]) 27 | 28 | data.pos = pos_sampled 29 | data.face = None 30 | return data 31 | 32 | def __repr__(self): 33 | return '{}({})'.format(self.__class__.__name__, self.num) 34 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sphinx_rtd_theme 3 | import doctest 4 | from torch_geometric import __version__ 5 | 6 | extensions = [ 7 | 'sphinx.ext.autodoc', 8 | 'sphinx.ext.doctest', 9 | 'sphinx.ext.intersphinx', 10 | 'sphinx.ext.mathjax', 11 | 'sphinx.ext.napoleon', 12 | 'sphinx.ext.viewcode', 13 | 'sphinx.ext.githubpages', 14 | ] 15 | 16 | source_suffix = '.rst' 17 | master_doc = 'index' 18 | 19 | author = 'Matthias Fey' 20 | project = 'pytorch_geometric' 21 | copyright = '{}, {}'.format(datetime.datetime.now().year, author) 22 | 23 | version = 'master ({})'.format(__version__) 24 | release = 'master' 25 | 26 | 
html_theme = 'sphinx_rtd_theme' 27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 28 | 29 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE 30 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)} 31 | 32 | html_theme_options = { 33 | 'collapse_navigation': False, 34 | 'display_version': True, 35 | 'logo_only': True, 36 | } 37 | 38 | html_logo = '_static/img/logo.svg' 39 | html_static_path = ['_static'] 40 | html_context = {'css_files': ['_static/css/custom.css']} 41 | 42 | add_module_names = False 43 | -------------------------------------------------------------------------------- /test/data/test_split.py: -------------------------------------------------------------------------------- 1 | # x1 = torch.tensor([1, 2, 3], dtype=torch.float) 2 | # edge_index1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 3 | # x2 = torch.tensor([1, 2], dtype=torch.float) 4 | # edge_index2 = torch.tensor([[0, 1], [1, 0]]) 5 | # data1, data2 = Data(x1, edge_index1), Data(x2, edge_index2) 6 | # dataset, slices = collate_to_set([data1, data2]) 7 | 8 | # def test_data_from_set(): 9 | # data = data_from_set(dataset, slices, 0) 10 | # assert len(data) == 2 11 | # assert data.x.tolist() == x1.tolist() 12 | # assert data.edge_index.tolist() == edge_index1.tolist() 13 | 14 | # data = data_from_set(dataset, slices, 1) 15 | # assert len(data) == 2 16 | # assert data.x.tolist() == x2.tolist() 17 | # assert data.edge_index.tolist() == edge_index2.tolist() 18 | 19 | # def test_split_set(): 20 | # output, output_slices = split_set(dataset, slices, torch.tensor([0])) 21 | 22 | # assert len(output) == 2 23 | # assert output.x.tolist() == x1.tolist() 24 | # assert output.edge_index.tolist() == edge_index1.tolist() 25 | 26 | # assert len(output_slices.keys()) == 2 27 | # assert output_slices['x'].tolist() == [0, 3] 28 | # assert output_slices['edge_index'].tolist() == [0, 4] 29 | -------------------------------------------------------------------------------- 
/torch_geometric/transforms/random_shear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import LinearTransformation 3 | 4 | 5 | class RandomShear(object): 6 | r"""Shears node positions by randomly sampled factors :math:`s` within a 7 | given interval, e.g., resulting in the transformation matrix 8 | 9 | .. math:: 10 | \begin{bmatrix} 11 | 1 & s_{xy} & s_{xz} \\ 12 | s_{yx} & 1 & s_{yz} \\ 13 | s_{zx} & z_{zy} & 1 \\ 14 | \end{bmatrix} 15 | 16 | for three-dimensional positions. 17 | 18 | Args: 19 | shear (float or int): maximum shearing factor defining the range 20 | :math:`(-\mathrm{shear}, +\mathrm{shear})` to sample from. 21 | """ 22 | 23 | def __init__(self, shear): 24 | self.shear = abs(shear) 25 | 26 | def __call__(self, data): 27 | dim = data.pos.size(1) 28 | 29 | matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear) 30 | eye = torch.arange(dim, dtype=torch.long) 31 | matrix[eye, eye] = 1 32 | 33 | return LinearTransformation(matrix)(data) 34 | 35 | def __repr__(self): 36 | return '{}({})'.format(self.__class__.__name__, self.shear) 37 | -------------------------------------------------------------------------------- /torch_geometric/nn/dense/diff_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def dense_diff_pool(x, adj, s, mask=None): 5 | r"""Differentiable pooling operator from the `"Hierarchical Graph 6 | Representation Learning with Differentiable Pooling" 7 | `_ paper 8 | 9 | .. math:: 10 | \mathbf{X}^{\prime} &= \mathbf{S} \cdot \mathbf{X} 11 | 12 | \mathbf{A}^{\prime} &= \mathbf{S}^{\top} \cdot \mathbf{A} \cdot 13 | \mathbf{S} 14 | 15 | based on dense learned assignments :math:`\mathbf{S}`. 
16 | """ 17 | 18 | x = x.unsqueeze(0) if x.dim() == 2 else x 19 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 20 | s = s.unsqueeze(0) if s.dim() == 2 else s 21 | 22 | batch_size, num_nodes, _ = x.size() 23 | 24 | s = torch.softmax(s, dim=-1) 25 | 26 | if mask is not None: 27 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 28 | x, s = x * mask, s * mask 29 | 30 | out = torch.matmul(s.transpose(1, 2), x) 31 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 32 | 33 | reg = adj - torch.matmul(s, s.transpose(1, 2)) 34 | reg = torch.norm(reg, p=2) 35 | reg = reg / adj.numel() 36 | 37 | return out, out_adj, reg 38 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/rusty1s/pytorch_geometric 2 | 3 | PyTorch Geometric documentation 4 | =============================== 5 | 6 | `PyTorch Geometric `_ is a geometric deep learning extension library for `PyTorch `_. 7 | 8 | It consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning `_, from a variety of published papers. 9 | In addition, it consists of an easy-to-use mini-batch loader, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds. 10 | 11 | .. toctree:: 12 | :glob: 13 | :maxdepth: 1 14 | :caption: Notes 15 | 16 | notes/installation 17 | notes/introduction 18 | notes/create_dataset 19 | 20 | .. 
toctree:: 21 | :glob: 22 | :maxdepth: 1 23 | :caption: Package Reference 24 | 25 | modules/nn 26 | modules/data 27 | modules/datasets 28 | modules/transforms 29 | modules/utils 30 | 31 | Indices and Tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: required 3 | dist: trusty 4 | matrix: 5 | include: 6 | - python: 2.7 7 | - python: 3.5 8 | - python: 3.6 9 | addons: 10 | apt: 11 | sources: 12 | - ubuntu-toolchain-r-test 13 | packages: 14 | - gcc-4.9 15 | - g++-4.9 16 | before_install: 17 | - export CC="gcc-4.9" 18 | - export CXX="g++-4.9" 19 | install: 20 | - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi 21 | - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi 22 | - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi 23 | - pip install pycodestyle 24 | - pip install flake8 25 | - pip install torch-scatter torch-sparse torch-cluster torch-spline-conv 26 | - pip install codecov 27 | script: 28 | - pycodestyle . 29 | - flake8 . 30 | - python setup.py install 31 | - python setup.py test 32 | - cd docs && pip install -r requirements.txt && make clean && make html && make doctest && cd .. 
33 | after_success: 34 | - codecov 35 | notifications: 36 | email: false 37 | -------------------------------------------------------------------------------- /torch_geometric/transforms/random_rotate.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import math 4 | 5 | import torch 6 | from torch_geometric.transforms import LinearTransformation 7 | 8 | 9 | class RandomRotate(object): 10 | def __init__(self, degrees, axis=0): 11 | if isinstance(degrees, numbers.Number): 12 | degrees = (-abs(degrees), abs(degrees)) 13 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 14 | self.degrees = degrees 15 | self.axis = axis 16 | 17 | def __call__(self, data): 18 | degree = math.pi * random.uniform(*self.degrees) / 180.0 19 | sin, cos = math.sin(degree), math.cos(degree) 20 | 21 | if data.pos.size(1) == 2: 22 | matrix = [[cos, sin], [-sin, cos]] 23 | else: 24 | if self.axis == 0: 25 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 26 | elif self.axis == 1: 27 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 28 | else: 29 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 30 | return LinearTransformation(torch.tensor(matrix))(data) 31 | 32 | def __repr__(self): 33 | return '{}({})'.format(self.__class__.__name__, self.degrees) 34 | -------------------------------------------------------------------------------- /docs/build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/rusty1s/pytorch_geometric 2 | 3 | PyTorch Geometric documentation 4 | =============================== 5 | 6 | `PyTorch Geometric `_ is a geometric deep learning extension library for `PyTorch `_. 7 | 8 | It consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning `_, from a variety of published papers. 
9 | In addition, it consists of an easy-to-use mini-batch loader, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds. 10 | 11 | .. toctree:: 12 | :glob: 13 | :maxdepth: 1 14 | :caption: Notes 15 | 16 | notes/installation 17 | notes/introduction 18 | notes/create_dataset 19 | 20 | .. toctree:: 21 | :glob: 22 | :maxdepth: 1 23 | :caption: Package Reference 24 | 25 | modules/nn 26 | modules/data 27 | modules/datasets 28 | modules/transforms 29 | modules/utils 30 | 31 | Indices and Tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | -------------------------------------------------------------------------------- /test/utils/test_loop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import (contains_self_loops, remove_self_loops, 3 | add_self_loops) 4 | 5 | 6 | def test_contains_self_loops(): 7 | row = torch.tensor([0, 1, 0]) 8 | col = torch.tensor([1, 0, 0]) 9 | 10 | assert contains_self_loops(torch.stack([row, col], dim=0)) 11 | 12 | row = torch.tensor([0, 1, 1]) 13 | col = torch.tensor([1, 0, 2]) 14 | 15 | assert not contains_self_loops(torch.stack([row, col], dim=0)) 16 | 17 | 18 | def test_remove_self_loops(): 19 | row = torch.tensor([1, 0, 1, 0, 2, 1]) 20 | col = torch.tensor([0, 1, 1, 1, 2, 0]) 21 | edge_index = torch.stack([row, col], dim=0) 22 | edge_attr = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] 23 | edge_attr = torch.tensor(edge_attr) 24 | 25 | out = remove_self_loops(edge_index, edge_attr) 26 | assert out[0].tolist() == [[1, 0, 0, 1], [0, 1, 1, 0]] 27 | assert out[1].tolist() == [[1, 2], [3, 4], [7, 8], [11, 12]] 28 | 29 | 30 | def test_add_self_loops(): 31 | row = torch.tensor([0, 1, 0]) 32 | col = torch.tensor([1, 0, 0]) 33 | edge_index = torch.stack([row, col], dim=0) 34 | 35 | 
expected = [[0, 1, 0, 0, 1], [1, 0, 0, 0, 1]] 36 | assert add_self_loops(edge_index).tolist() == expected 37 | -------------------------------------------------------------------------------- /torch_geometric/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .degree import degree 2 | from .scatter import scatter_ 3 | from .softmax import softmax 4 | from .undirected import is_undirected, to_undirected 5 | from .isolated import contains_isolated_nodes 6 | from .loop import contains_self_loops, remove_self_loops, add_self_loops 7 | from .one_hot import one_hot 8 | from .grid import grid 9 | from .normalized_cut import normalized_cut 10 | from .sparse import dense_to_sparse 11 | from .to_batch import to_batch 12 | from .convert import to_scipy_sparse_matrix, to_networkx 13 | from .metric import (accuracy, true_positive, true_negative, false_positive, 14 | false_negative, precision, recall, f1_score) 15 | 16 | __all__ = [ 17 | 'degree', 18 | 'scatter_', 19 | 'softmax', 20 | 'is_undirected', 21 | 'to_undirected', 22 | 'contains_self_loops', 23 | 'remove_self_loops', 24 | 'add_self_loops', 25 | 'contains_isolated_nodes', 26 | 'one_hot', 27 | 'grid', 28 | 'normalized_cut', 29 | 'dense_to_sparse', 30 | 'to_batch', 31 | 'to_scipy_sparse_matrix', 32 | 'to_networkx', 33 | 'accuracy', 34 | 'true_positive', 35 | 'true_negative', 36 | 'false_positive', 37 | 'false_negative', 38 | 'precision', 39 | 'recall', 40 | 'f1_score', 41 | ] 42 | -------------------------------------------------------------------------------- /torch_geometric/transforms/compose.py: -------------------------------------------------------------------------------- 1 | class Compose(object): 2 | """Composes several transforms together. 3 | 4 | Args: 5 | transforms (list of :obj:`transform` objects): List of transforms to 6 | compose. 7 | 8 | .. testsetup:: 9 | 10 | import torch 11 | from torch_geometric.data import Data 12 | 13 | .. 
testcode:: 14 | 15 | import torch_geometric.transforms as T 16 | 17 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 18 | edge_index = torch.tensor([[0, 1], [1, 2]]) 19 | data = Data(edge_index=edge_index, pos=pos) 20 | 21 | transform = T.Compose([T.Cartesian(), T.TargetIndegree()]) 22 | data = transform(data) 23 | 24 | print(data.edge_attr) 25 | 26 | .. testoutput:: 27 | 28 | tensor([[0.7500, 0.5000, 1.0000], 29 | [1.0000, 0.5000, 1.0000]]) 30 | """ 31 | 32 | def __init__(self, transforms): 33 | self.transforms = transforms 34 | 35 | def __call__(self, data): 36 | for t in self.transforms: 37 | data = t(data) 38 | return data 39 | 40 | def __repr__(self): 41 | args = [' {},'.format(t) for t in self.transforms] 42 | return '{}([\n{}\n])'.format(self.__class__.__name__, '\n'.join(args)) 43 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/max_pool.py: -------------------------------------------------------------------------------- 1 | from torch_scatter import scatter_max 2 | from torch_geometric.data import Batch 3 | 4 | from .consecutive import consecutive_cluster 5 | from .pool import pool_edge, pool_batch, pool_pos 6 | 7 | 8 | def _max_pool_x(cluster, x, size=None): 9 | fill = -9999999 10 | x, _ = scatter_max(x, cluster, dim=0, dim_size=size, fill_value=fill) 11 | x[x == fill] = 0 12 | return x 13 | 14 | 15 | def max_pool_x(cluster, x, batch, size=None): 16 | if size is not None: 17 | return _max_pool_x(cluster, x, (batch.max().item() + 1) * size) 18 | 19 | cluster, perm = consecutive_cluster(cluster) 20 | x = _max_pool_x(cluster, x) 21 | batch = pool_batch(perm, batch) 22 | 23 | return x, batch 24 | 25 | 26 | def max_pool(cluster, data, transform=None): 27 | cluster, perm = consecutive_cluster(cluster) 28 | 29 | x = _max_pool_x(cluster, data.x) 30 | index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) 31 | batch = None if data.batch is None else pool_batch(perm, data.batch) 
32 | pos = None if data.pos is None else pool_pos(cluster, data.pos) 33 | 34 | data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) 35 | 36 | if transform is not None: 37 | data = transform(data) 38 | 39 | return data 40 | -------------------------------------------------------------------------------- /test/datasets/test_planetoid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import os.path as osp 4 | import shutil 5 | 6 | from torch_geometric.datasets import Planetoid 7 | from torch_geometric.data import DataLoader 8 | 9 | 10 | def test_citeseer(): 11 | root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)), 'test') 12 | dataset = Planetoid(root, 'Citeseer') 13 | loader = DataLoader(dataset, batch_size=len(dataset)) 14 | 15 | assert len(dataset) == 1 16 | assert dataset.__repr__() == 'Citeseer()' 17 | 18 | for data in loader: 19 | assert data.num_graphs == 1 20 | assert data.num_nodes == 3327 21 | assert data.num_edges / 2 == 4552 22 | 23 | assert len(data) == 7 24 | assert list(data.x.size()) == [data.num_nodes, 3703] 25 | assert list(data.y.size()) == [data.num_nodes] 26 | assert data.y.max() + 1 == 6 27 | assert data.train_mask.sum() == 6 * 20 28 | assert data.val_mask.sum() == 500 29 | assert data.test_mask.sum() == 1000 30 | assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0 31 | assert list(data.batch.size()) == [data.num_nodes] 32 | 33 | assert data.contains_isolated_nodes() 34 | assert not data.contains_self_loops() 35 | assert data.is_undirected() 36 | 37 | shutil.rmtree(root) 38 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/avg_pool.py: -------------------------------------------------------------------------------- 1 | from torch_scatter import scatter_mean 2 | from torch_geometric.data import Batch 3 | 4 | from .consecutive import consecutive_cluster 5 | from .pool 
import pool_edge, pool_batch, pool_pos 6 | 7 | 8 | def _avg_pool_x(cluster, x, size=None): 9 | fill = -9999999 10 | x, _ = scatter_mean(x, cluster, dim=0, dim_size=size, fill_value=fill) 11 | x[x == fill] = 0 12 | return x 13 | 14 | 15 | def avg_pool_x(cluster, x, batch=None, size=None): 16 | assert batch is None or size is None 17 | 18 | if size is not None: 19 | return _avg_pool_x(cluster, x, size) 20 | 21 | cluster, perm = consecutive_cluster(cluster) 22 | x = _avg_pool_x(cluster, x) 23 | batch = pool_batch(perm, batch) 24 | 25 | return x, batch 26 | 27 | 28 | def avg_pool(cluster, data, transform=None): 29 | cluster, perm = consecutive_cluster(cluster) 30 | 31 | x = _avg_pool_x(cluster, data.x) 32 | index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) 33 | batch = None if data.batch is None else pool_batch(perm, data.batch) 34 | pos = None if data.pos is None else pool_pos(cluster, data.pos) 35 | 36 | data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) 37 | 38 | if transform is not None: 39 | data = transform(data) 40 | 41 | return data 42 | -------------------------------------------------------------------------------- /torch_geometric/transforms/random_translate.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from itertools import repeat 3 | 4 | import torch 5 | 6 | 7 | class RandomTranslate(object): 8 | r"""Translates node positions by randomly sampled translation values 9 | within a given interval. In contrast to other random transformations, 10 | translation is applied randomly at each position. 11 | 12 | Args: 13 | translate (sequence or float or int): maximum translation in each 14 | dimension, defining the range 15 | :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from. 16 | If :obj:`translate` is a number instead of a sequence, the same 17 | range is used for each dimension. 
18 | """ 19 | 20 | def __init__(self, translate): 21 | self.translate = translate 22 | 23 | def __call__(self, data): 24 | (n, dim), t = data.pos.size(), self.translate 25 | if isinstance(t, numbers.Number): 26 | t = repeat(t, dim) 27 | 28 | ts = [] 29 | for d in range(dim): 30 | ts.append(data.pos.new_empty(n).uniform_(-abs(t[d]), abs(t[d]))) 31 | t = torch.stack(ts, dim=-1) 32 | 33 | data.pos = data.pos + t 34 | return data 35 | 36 | def __repr__(self): 37 | return '{}({})'.format(self.__class__.__name__, *self.scale) 38 | -------------------------------------------------------------------------------- /torch_geometric/transforms/two_hop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import spspmm, coalesce 3 | 4 | from torch_geometric.utils import remove_self_loops 5 | 6 | 7 | class TwoHop(object): 8 | def __call__(self, data): 9 | edge_index, edge_attr = data.edge_index, data.edge_attr 10 | n = data.num_nodes 11 | 12 | fill = 1e16 13 | value = edge_index.new_full( 14 | (edge_index.size(1), ), fill, dtype=torch.float) 15 | 16 | index, value = spspmm(edge_index, value, edge_index, value, n, n, n) 17 | index, value = remove_self_loops(index, value) 18 | 19 | edge_index = torch.cat([edge_index, index], dim=1) 20 | if edge_attr is None: 21 | data.edge_index, _ = coalesce(edge_index, None, n, n) 22 | else: 23 | value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)]) 24 | value = value.expand(-1, *list(edge_attr.size())[1:]) 25 | edge_attr = torch.cat([edge_attr, value], dim=0) 26 | data.edge_index, edge_attr = coalesce( 27 | edge_index, edge_attr, n, n, op='min', fill_value=fill) 28 | edge_attr[edge_attr >= fill] = 0 29 | data.edge_attr = edge_attr 30 | 31 | return data 32 | 33 | def __repr__(self): 34 | return '{}()'.format(self.__class__.__name__) 35 | -------------------------------------------------------------------------------- /torch_geometric/read/sdf.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import coalesce 3 | from torch_geometric.read import parse_txt_array 4 | from torch_geometric.utils import one_hot 5 | from torch_geometric.data import Data 6 | 7 | elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} 8 | 9 | 10 | def parse_sdf(src): 11 | src = src.split('\n')[3:] 12 | num_atoms, num_bonds = [int(item) for item in src[0].split()[:2]] 13 | 14 | atom_block = src[1:num_atoms + 1] 15 | pos = parse_txt_array(atom_block, end=3) 16 | x = torch.tensor([elems[item.split()[3]] for item in atom_block]) 17 | x = one_hot(x, len(elems)) 18 | 19 | bond_block = src[1 + num_atoms:1 + num_atoms + num_bonds] 20 | row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1 21 | row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) 22 | edge_index = torch.stack([row, col], dim=0) 23 | edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 24 | edge_attr = torch.cat([edge_attr, edge_attr], dim=0) 25 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms, 26 | num_atoms) 27 | 28 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) 29 | return data 30 | 31 | 32 | def read_sdf(path): 33 | with open(path, 'r') as f: 34 | return parse_sdf(f.read()) 35 | -------------------------------------------------------------------------------- /test/datasets/test_tu_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import sys 4 | import random 5 | import os.path as osp 6 | import shutil 7 | 8 | import pytest 9 | from torch_geometric.datasets import TUDataset 10 | from torch_geometric.data import DataLoader 11 | 12 | 13 | def test_enzymes(): 14 | root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)), 'test') 15 | dataset = TUDataset(root, 'ENZYMES') 16 | loader = DataLoader(dataset, batch_size=len(dataset)) 17 | 18 
| assert len(dataset) == 600 19 | assert dataset.__repr__() == 'ENZYMES(600)' 20 | 21 | for data in loader: 22 | assert data.num_graphs == 600 23 | 24 | avg_num_nodes = data.num_nodes / data.num_graphs 25 | assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63 26 | 27 | avg_num_edges = data.num_edges / (2 * data.num_graphs) 28 | assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14 29 | 30 | assert len(data) == 4 31 | assert list(data.x.size()) == [data.num_nodes, 21] 32 | assert list(data.y.size()) == [data.num_graphs] 33 | assert data.y.max() + 1 == 6 34 | assert list(data.batch.size()) == [data.num_nodes] 35 | 36 | assert data.contains_isolated_nodes() 37 | assert not data.contains_self_loops() 38 | assert data.is_undirected() 39 | 40 | shutil.rmtree(root) 41 | -------------------------------------------------------------------------------- /torch_geometric/datasets/planetoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset, download_url 3 | from torch_geometric.read import read_planetoid_data 4 | 5 | 6 | class Planetoid(InMemoryDataset): 7 | url = 'https://github.com/kimiyoung/planetoid/raw/master/data' 8 | 9 | def __init__(self, root, name, transform=None, pre_transform=None): 10 | self.name = name 11 | super(Planetoid, self).__init__(root, transform, pre_transform) 12 | self.data, self.slices = torch.load(self.processed_paths[0]) 13 | 14 | @property 15 | def raw_file_names(self): 16 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] 17 | return ['ind.{}.{}'.format(self.name.lower(), name) for name in names] 18 | 19 | @property 20 | def processed_file_names(self): 21 | return 'data.pt' 22 | 23 | def download(self): 24 | for name in self.raw_file_names: 25 | download_url('{}/{}'.format(self.url, name), self.raw_dir) 26 | 27 | def process(self): 28 | data = read_planetoid_data(self.raw_dir, self.name) 29 | data = data if 
self.pre_transform is None else self.pre_transform(data) 30 | data, slices = self.collate([data]) 31 | torch.save((data, slices), self.processed_paths[0]) 32 | 33 | def __repr__(self): 34 | return '{}()'.format(self.name) 35 | -------------------------------------------------------------------------------- /torch_geometric/transforms/random_flip.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class RandomFlip(object): 5 | """Flips node positions along a given axis randomly with a given 6 | probability. 7 | 8 | Args: 9 | axis (int): The axis along the position of nodes being flipped. 10 | p (float, optional): Probability of the position of nodes being 11 | flipped. (default: :obj:`0.5`) 12 | 13 | .. testsetup:: 14 | 15 | import torch 16 | from torch_geometric.data import Data 17 | 18 | .. testcode:: 19 | 20 | from torch_geometric.transforms import RandomFlip 21 | 22 | pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float) 23 | data = Data(pos=pos) 24 | 25 | data = RandomFlip(axis=0, p=1)(data) 26 | 27 | print(data.pos) 28 | 29 | .. 
testoutput:: 30 | 31 | tensor([[ 1., 1.], 32 | [ 3., 0.], 33 | [-2., -1.]]) 34 | """ 35 | 36 | def __init__(self, axis, p=0.5): 37 | self.axis = axis 38 | self.p = p 39 | 40 | def __call__(self, data): 41 | if random.random() < self.p: 42 | data.pos[:, self.axis] = -data.pos[:, self.axis] 43 | return data 44 | 45 | def __repr__(self): 46 | return '{}(axis={}, p={})'.format(self.__class__.__name__, self.axis, 47 | self.p) 48 | -------------------------------------------------------------------------------- /torch_geometric/nn/prop/gcn_prop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import spmm 3 | from torch_scatter import scatter_add 4 | from torch_geometric.utils import remove_self_loops, add_self_loops 5 | 6 | 7 | class GCNProp(torch.nn.Module): 8 | def __init__(self, improved=False): 9 | super(GCNProp, self).__init__() 10 | self.improved = improved 11 | 12 | def forward(self, x, edge_index, edge_attr=None): 13 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 14 | 15 | if edge_attr is None: 16 | edge_attr = x.new_ones((edge_index.size(1), )) 17 | assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1) 18 | 19 | # Add self-loops to adjacency matrix. 20 | edge_index = add_self_loops(edge_index, x.size(0)) 21 | loop_value = x.new_full((x.size(0), ), 1 if not self.improved else 2) 22 | edge_attr = torch.cat([edge_attr, loop_value], dim=0) 23 | 24 | # Normalize adjacency matrix. 25 | row, col = edge_index 26 | deg = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0)) 27 | deg = deg.pow(-0.5) 28 | deg[deg == float('inf')] = 0 29 | edge_attr = deg[row] * edge_attr * deg[col] 30 | 31 | # Perform the propagation. 
32 | out = spmm(edge_index, edge_attr, x.size(0), x) 33 | 34 | return out 35 | 36 | def __repr__(self): 37 | return '{}()'.format(self.__class__.__name__) 38 | -------------------------------------------------------------------------------- /torch_geometric/data/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | 4 | 5 | class Batch(Data): 6 | def __init__(self, batch=None, **kwargs): 7 | super(Batch, self).__init__(**kwargs) 8 | self.batch = batch 9 | 10 | @staticmethod 11 | def from_data_list(data_list): 12 | keys = [set(data.keys) for data in data_list] 13 | keys = list(set.union(*keys)) 14 | assert 'batch' not in keys 15 | 16 | batch = Batch() 17 | 18 | for key in keys: 19 | batch[key] = [] 20 | batch.batch = [] 21 | 22 | cumsum = 0 23 | for i, data in enumerate(data_list): 24 | num_nodes = data.num_nodes 25 | batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long)) 26 | for key in data.keys: 27 | item = data[key] 28 | item = item + cumsum if batch.cumsum(key, item) else item 29 | batch[key].append(item) 30 | cumsum += num_nodes 31 | 32 | for key in keys: 33 | batch[key] = torch.cat( 34 | batch[key], dim=data_list[0].cat_dim(key, batch[key][0])) 35 | batch.batch = torch.cat(batch.batch, dim=-1) 36 | return batch.contiguous() 37 | 38 | def cumsum(self, key, item): 39 | return item.dim() > 1 and item.dtype == torch.long 40 | 41 | @property 42 | def num_graphs(self): 43 | return self.batch[-1].item() + 1 44 | -------------------------------------------------------------------------------- /torch_geometric/utils/grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import coalesce 3 | 4 | 5 | def grid(height, width, dtype=None, device=None): 6 | edge_index = grid_index(height, width, device) 7 | pos = grid_pos(height, width, dtype, device) 8 | return edge_index, pos 9 | 10 | 11 | 
def grid_index(height, width, device=None): 12 | w = width 13 | kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1] 14 | kernel = torch.tensor(kernel, device=device) 15 | 16 | row = torch.arange(height * width, dtype=torch.long, device=device) 17 | row = row.view(-1, 1).repeat(1, kernel.size(0)) 18 | col = row + kernel.view(1, -1) 19 | row, col = row.view(height, -1), col.view(height, -1) 20 | index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device) 21 | row, col = row[:, index].view(-1), col[:, index].view(-1) 22 | 23 | mask = (col >= 0) & (col < height * width) 24 | row, col = row[mask], col[mask] 25 | 26 | edge_index = torch.stack([row, col], dim=0) 27 | edge_index, _ = coalesce(edge_index, None, height * width, height * width) 28 | 29 | return edge_index 30 | 31 | 32 | def grid_pos(height, width, dtype=None, device=None): 33 | x = torch.arange(width, dtype=dtype, device=device) 34 | y = (height - 1) - torch.arange(height, dtype=dtype, device=device) 35 | 36 | x = x.repeat(height) 37 | y = y.unsqueeze(-1).repeat(1, width).view(-1) 38 | 39 | return torch.stack([x, y], dim=-1) 40 | -------------------------------------------------------------------------------- /torch_geometric/transforms/constant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Constant(object): 5 | r"""Adds a constant value to each node feature. 6 | 7 | Args: 8 | value (int): The value to add. 9 | cat (bool, optional): Concat value to node features instead 10 | of replacing them. (default: :obj:`True`) 11 | 12 | .. testsetup:: 13 | 14 | import torch 15 | from torch_geometric.data import Data 16 | 17 | .. testcode:: 18 | 19 | from torch_geometric.transforms import Constant 20 | 21 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 22 | data = Data(edge_index=edge_index) 23 | 24 | data = Constant(value=1)(data) 25 | 26 | print(data.x) 27 | 28 | .. 
testoutput:: 29 | 30 | tensor([[1.], 31 | [1.], 32 | [1.]]) 33 | """ 34 | 35 | def __init__(self, value, cat=True): 36 | self.value = value 37 | self.cat = cat 38 | 39 | def __call__(self, data): 40 | x = data.x 41 | 42 | c = torch.full((data.num_nodes, 1), self.value) 43 | 44 | if x is not None and self.cat: 45 | x = x.view(-1, 1) if x.dim() == 1 else x 46 | data.x = torch.cat([x, c.to(x.dtype).to(x.device)], dim=-1) 47 | else: 48 | data.x = c 49 | 50 | return data 51 | 52 | def __repr__(self): 53 | return '{}({}, cat={})'.format(self.__class__.__name__, self.value, 54 | self.cat) 55 | -------------------------------------------------------------------------------- /torch_geometric/nn/conv/edge_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import scatter_ 3 | 4 | from ..inits import reset 5 | 6 | 7 | class EdgeConv(torch.nn.Module): 8 | r"""The edge convolutional operator from the `"Dynamic Graph CNN for 9 | Learning on Point Clouds" `_ paper 10 | 11 | .. math:: 12 | \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} 13 | h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, 14 | \mathbf{x}_j - \mathbf{x}_i), 15 | 16 | where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* a MLP. 17 | 18 | Args: 19 | nn (nn.Sequential): Neural network. 20 | aggr (string): The aggregation operator to use (one of :obj:`"add"`, 21 | :obj:`"mean"`, :obj:`"max"`). 
(default: :obj:`"add"`) 22 | """ 23 | 24 | def __init__(self, nn, aggr='add'): 25 | super(EdgeConv, self).__init__() 26 | self.nn = nn 27 | self.aggr = aggr 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | reset(self.nn) 32 | 33 | def forward(self, x, edge_index): 34 | """""" 35 | row, col = edge_index 36 | x = x.unsqueeze(-1) if x.dim() == 1 else x 37 | 38 | out = torch.cat([x[row], x[col] - x[row]], dim=1) 39 | out = self.nn(out) 40 | out = scatter_(self.aggr, out, row, dim_size=x.size(0)) 41 | 42 | return out 43 | 44 | def __repr__(self): 45 | return '{}({})'.format(self.__class__.__name__, self.nn) 46 | -------------------------------------------------------------------------------- /test/data/test_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | 4 | 5 | def test_data(): 6 | x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t() 7 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 8 | data = Data(x=x, edge_index=edge_index) 9 | 10 | assert data.x.tolist() == x.tolist() 11 | assert data['x'].tolist() == x.tolist() 12 | 13 | assert sorted(data.keys) == ['edge_index', 'x'] 14 | assert len(data) == 2 15 | assert 'x' in data and 'edge_index' in data and 'pos' not in data 16 | 17 | assert data.cat_dim('x', data.x) == 0 18 | assert data.cat_dim('edge_index', data.edge_index) == -1 19 | 20 | data.to(torch.device('cpu')) 21 | assert not data.x.is_contiguous() 22 | data.contiguous() 23 | assert data.x.is_contiguous() 24 | 25 | data['x'] = x + 1 26 | assert data.x.tolist() == (x + 1).tolist() 27 | 28 | assert data.__repr__() == 'Data(edge_index=[2, 4], x=[3, 2])' 29 | 30 | dictionary = {'x': x, 'edge_index': edge_index} 31 | data = Data.from_dict(dictionary) 32 | assert sorted(data.keys) == ['edge_index', 'x'] 33 | 34 | assert not data.contains_isolated_nodes() 35 | assert not data.contains_self_loops() 36 | assert 
data.is_undirected() 37 | assert not data.is_directed() 38 | 39 | assert data.num_nodes == 3 40 | assert data.num_edges == 4 41 | 42 | data.x = None 43 | assert data.num_nodes == 3 44 | 45 | data.edge_index = None 46 | assert data.num_nodes is None 47 | assert data.num_edges is None 48 | -------------------------------------------------------------------------------- /torch_geometric/utils/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.sparse 3 | import networkx as nx 4 | 5 | from .num_nodes import maybe_num_nodes 6 | 7 | 8 | def to_scipy_sparse_matrix(edge_index, edge_attr=None, num_nodes=None): 9 | row, col = edge_index.cpu() 10 | 11 | if edge_attr is None: 12 | edge_attr = torch.ones(row.size(0)) 13 | else: 14 | edge_attr = edge_attr.view(-1).cpu() 15 | assert edge_attr.size(0) == row.size(0) 16 | 17 | N = maybe_num_nodes(edge_index, num_nodes) 18 | out = scipy.sparse.coo_matrix((edge_attr, (row, col)), (N, N)) 19 | return out 20 | 21 | 22 | def to_networkx(edge_index, x=None, edge_attr=None, pos=None, num_nodes=None): 23 | num_nodes = num_nodes if x is None else x.size(0) 24 | num_nodes = num_nodes if pos is None else pos.size(0) 25 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 26 | 27 | G = nx.Graph() 28 | 29 | for i in range(num_nodes): 30 | G.add_node(i) 31 | if x is not None: 32 | G.nodes[i]['x'] = x[i].cpu().numpy() 33 | if pos is not None: 34 | G.nodes[i]['pos'] = pos[i].cpu().numpy() 35 | 36 | for i in range(edge_index.size(1)): 37 | source, target = edge_index[0][i].item(), edge_index[1][i].item() 38 | G.add_edge(source, target) 39 | if edge_attr is not None: 40 | if edge_attr.numel() == edge_attr.size(0): 41 | G[source][target]['weight'] = edge_attr[i].item() 42 | else: 43 | G[source][target]['weight'] = edge_attr[i].cpu().numpy() 44 | 45 | return G 46 | -------------------------------------------------------------------------------- 
/torch_geometric/transforms/distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Distance(object): 5 | r"""Saves the Euclidean distance of linked nodes in its edge attributes. 6 | 7 | Args: 8 | cat (bool, optional): Concat pseudo-coordinates to edge attributes 9 | instead of replacing them. (default: :obj:`True`) 10 | 11 | .. testsetup:: 12 | 13 | import torch 14 | from torch_geometric.data import Data 15 | 16 | .. testcode:: 17 | 18 | from torch_geometric.transforms import Distance 19 | 20 | pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float) 21 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 22 | data = Data(edge_index=edge_index, pos=pos) 23 | 24 | data = Distance()(data) 25 | 26 | print(data.edge_attr) 27 | 28 | .. testoutput:: 29 | 30 | tensor([[1.], 31 | [1.], 32 | [2.], 33 | [2.]]) 34 | """ 35 | 36 | def __init__(self, cat=True): 37 | self.cat = cat 38 | 39 | def __call__(self, data): 40 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr 41 | 42 | dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1) 43 | 44 | if pseudo is not None and self.cat: 45 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo 46 | data.edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1) 47 | else: 48 | data.edge_attr = dist 49 | 50 | return data 51 | 52 | def __repr__(self): 53 | return '{}(cat={})'.format(self.__class__.__name__, self.cat) 54 | -------------------------------------------------------------------------------- /torch_geometric/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .compose import Compose 2 | from .constant import Constant 3 | from .distance import Distance 4 | from .cartesian import Cartesian 5 | from .local_cartesian import LocalCartesian 6 | from .polar import Polar 7 | from .spherical import Spherical 8 | from .one_hot_degree import OneHotDegree 
9 | from .target_indegree import TargetIndegree 10 | from .center import Center 11 | from .normalize_scale import NormalizeScale 12 | from .random_translate import RandomTranslate 13 | from .random_flip import RandomFlip 14 | from .linear_transformation import LinearTransformation 15 | from .random_scale import RandomScale 16 | from .random_rotate import RandomRotate 17 | from .random_shear import RandomShear 18 | from .normalize_features import NormalizeFeatures 19 | from .add_self_loops import AddSelfLoops 20 | from .nn_graph import NNGraph 21 | from .radius_graph import RadiusGraph 22 | from .face_to_edge import FaceToEdge 23 | from .sample_points import SamplePoints 24 | from .to_dense import ToDense 25 | from .two_hop import TwoHop 26 | 27 | __all__ = [ 28 | 'Compose', 29 | 'Constant', 30 | 'Distance', 31 | 'Cartesian', 32 | 'LocalCartesian', 33 | 'Polar', 34 | 'Spherical', 35 | 'OneHotDegree', 36 | 'TargetIndegree', 37 | 'Center', 38 | 'NormalizeScale', 39 | 'RandomTranslate', 40 | 'RandomFlip', 41 | 'LinearTransformation', 42 | 'RandomScale', 43 | 'RandomRotate', 44 | 'RandomShear', 45 | 'NormalizeFeatures', 46 | 'AddSelfLoops', 47 | 'NNGraph', 48 | 'RadiusGraph', 49 | 'FaceToEdge', 50 | 'SamplePoints', 51 | 'ToDense', 52 | 'TwoHop', 53 | ] 54 | -------------------------------------------------------------------------------- /torch_geometric/datasets/qm9.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch_geometric.data import (InMemoryDataset, download_url, extract_tar, 5 | Data) 6 | 7 | 8 | class QM9(InMemoryDataset): 9 | url = 'http://www.roemisch-drei.de/qm9.tar.gz' 10 | 11 | def __init__(self, 12 | root, 13 | transform=None, 14 | pre_transform=None, 15 | pre_filter=None): 16 | super(QM9, self).__init__(root, transform, pre_transform, pre_filter) 17 | self.data, self.slices = torch.load(self.processed_paths[0]) 18 | 19 | @property 20 | def raw_file_names(self): 21 | return 
'qm9.pt' 22 | 23 | @property 24 | def processed_file_names(self): 25 | return 'data.pt' 26 | 27 | def download(self): 28 | file_path = download_url(self.url, self.raw_dir) 29 | extract_tar(file_path, self.raw_dir, mode='r') 30 | os.unlink(file_path) 31 | 32 | def process(self): 33 | raw_data_list = torch.load(self.raw_paths[0]) 34 | data_list = [ 35 | Data( 36 | x=d['x'], 37 | edge_index=d['edge_index'], 38 | edge_attr=d['edge_attr'], 39 | y=d['y'], 40 | pos=d['pos']) for d in raw_data_list 41 | ] 42 | 43 | if self.pre_filter is not None: 44 | data_list = [data for data in data_list if self.pre_filter(data)] 45 | 46 | if self.pre_transform is not None: 47 | data_list = [self.pre_transform(data) for data in data_list] 48 | 49 | data, slices = self.collate(data_list) 50 | torch.save((data, slices), self.processed_paths[0]) 51 | -------------------------------------------------------------------------------- /torch_geometric/transforms/one_hot_degree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree, one_hot 3 | 4 | 5 | class OneHotDegree(object): 6 | r"""Adds the node degree as one hot encodings to the node features. 7 | 8 | Args: 9 | max_degree (int): Maximum degree. 10 | cat (bool, optional): Concat node degrees to node features instead 11 | of replacing them. (default: :obj:`True`) 12 | 13 | .. testsetup:: 14 | 15 | import torch 16 | from torch_geometric.data import Data 17 | 18 | .. testcode:: 19 | 20 | from torch_geometric.transforms import OneHotDegree 21 | 22 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 23 | data = Data(edge_index=edge_index) 24 | 25 | data = OneHotDegree(max_degree=2)(data) 26 | 27 | print(data.x) 28 | 29 | .. 
testoutput:: 30 | 31 | tensor([[0., 1., 0.], 32 | [0., 0., 1.], 33 | [0., 1., 0.]]) 34 | """ 35 | 36 | def __init__(self, max_degree, cat=True): 37 | self.max_degree = max_degree 38 | self.cat = cat 39 | 40 | def __call__(self, data): 41 | row, x = data.edge_index[0], data.x 42 | deg = degree(row, data.num_nodes) 43 | deg = one_hot(deg, num_classes=self.max_degree + 1) 44 | 45 | if x is not None and self.cat: 46 | x = x.view(-1, 1) if x.dim() == 1 else x 47 | data.x = torch.cat([x, deg.to(x.dtype)], dim=-1) 48 | else: 49 | data.x = deg 50 | 51 | return data 52 | 53 | def __repr__(self): 54 | return '{}({}, cat={})'.format(self.__class__.__name__, 55 | self.max_degree, self.cat) 56 | -------------------------------------------------------------------------------- /torch_geometric/utils/scatter.py: -------------------------------------------------------------------------------- 1 | import torch_scatter 2 | 3 | 4 | def scatter_(name, src, index, dim_size=None): 5 | r"""Aggregates all values from the :attr:`src` tensor at the indices 6 | specified in the :attr:`index` tensor along the first dimension. 7 | If multiple indices reference the same location, their contributions 8 | are aggregated according to :attr:`name` (:obj:`"add"`, :obj:`"mean"`, 9 | :obj:`"max"`). 10 | 11 | Args: 12 | name (string): The aggregation to use (one of :obj:`"add"`, 13 | :obj:`"mean"`, :obj:`"max"`). 14 | src (Tensor): The source tensor. 15 | index (LongTensor): The indices of elements to scatter. 16 | dim_size (int, optional): Automatically create output tensor with size 17 | :attr:`dim_size` in the first dimension. If :attr:`None`, a minimal 18 | sized output tensor is returned. (default: :obj:`None`) 19 | 20 | :rtype: :class:`Tensor` 21 | 22 | .. testsetup:: 23 | 24 | import torch 25 | 26 | .. 
testcode:: 27 | 28 | from torch_geometric.utils import scatter_ 29 | src = torch.Tensor([2, 3, -2, 1, 1]) 30 | index = torch.tensor([0, 1, 0, 1, 2]) 31 | output = scatter_("add", src, index) 32 | print(output) 33 | 34 | .. testoutput:: 35 | 36 | tensor([0., 4., 1.]) 37 | """ 38 | 39 | assert name in ['add', 'mean', 'max'] 40 | 41 | op = getattr(torch_scatter, 'scatter_{}'.format(name)) 42 | fill_value = -1e38 if name is 'max' else 0 43 | 44 | out = op(src, index, 0, None, dim_size, fill_value) 45 | if isinstance(out, tuple): 46 | out = out[0] 47 | 48 | if name is 'max': 49 | out[out == fill_value] = 0 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /torch_geometric/nn/dense/sage_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Parameter 4 | 5 | from ..inits import uniform 6 | 7 | 8 | class DenseSAGEConv(torch.nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | out_channels, 12 | norm=True, 13 | norm_embed=True, 14 | bias=True): 15 | super(DenseSAGEConv, self).__init__() 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.norm = norm 20 | self.norm_embed = norm_embed 21 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 22 | 23 | if bias: 24 | self.bias = Parameter(torch.Tensor(out_channels)) 25 | else: 26 | self.register_parameter('bias', None) 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | uniform(self.in_channels, self.weight) 32 | uniform(self.in_channels, self.bias) 33 | 34 | def forward(self, x, adj): 35 | x = x.unsqueeze(0) if x.dim() == 2 else x 36 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 37 | 38 | out = torch.matmul(adj, x) 39 | 40 | if self.norm: 41 | out = out / adj.sum(dim=-1, keepdim=True) 42 | 43 | out = torch.matmul(out, self.weight) 44 | 45 | if self.bias is not None: 46 | 
out = out + self.bias 47 | 48 | if self.norm_embed: 49 | out = F.normalize(out, p=2, dim=-1) 50 | 51 | return out 52 | 53 | def __repr__(self): 54 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 55 | self.out_channels) 56 | -------------------------------------------------------------------------------- /torch_geometric/datasets/qm7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.io 3 | from torch_geometric.data import InMemoryDataset, download_url, Data 4 | 5 | 6 | class QM7(InMemoryDataset): 7 | url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \ 8 | 'datasets/qm7b.mat' 9 | 10 | def __init__(self, 11 | root, 12 | transform=None, 13 | pre_transform=None, 14 | pre_filter=None): 15 | super(QM7, self).__init__(root, transform, pre_transform, pre_filter) 16 | self.data, self.slices = torch.load(self.processed_paths[0]) 17 | 18 | @property 19 | def raw_file_names(self): 20 | return 'qm7b.mat' 21 | 22 | @property 23 | def processed_file_names(self): 24 | return 'data.pt' 25 | 26 | def download(self): 27 | download_url(self.url, self.raw_dir) 28 | 29 | def process(self): 30 | data = scipy.io.loadmat(self.raw_paths[0]) 31 | coulomb_matrix = torch.from_numpy(data['X']) 32 | target = torch.from_numpy(data['T']).to(torch.float) 33 | 34 | data_list = [] 35 | for i in range(target.shape[0]): 36 | edge_index = coulomb_matrix[i].nonzero().t().contiguous() 37 | edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]] 38 | y = target[i].view(1, -1) 39 | data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y) 40 | data_list.append(data) 41 | 42 | if self.pre_filter is not None: 43 | data_list = [d for d in data_list if self.pre_filter(d)] 44 | 45 | if self.pre_transform is not None: 46 | data_list = [self.pre_transform(d) for d in data_list] 47 | 48 | data, slices = self.collate(data_list) 49 | torch.save((data, slices), self.processed_paths[0]) 50 | 
-------------------------------------------------------------------------------- /examples/gat.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.datasets import Planetoid 6 | import torch_geometric.transforms as T 7 | from torch_geometric.nn import GATConv 8 | 9 | dataset = 'Cora' 10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) 11 | data = Planetoid(path, dataset, T.NormalizeFeatures())[0] # the dataset holds a single graph 12 | 13 | 14 | class Net(torch.nn.Module): 15 | def __init__(self): 16 | super(Net, self).__init__() 17 | self.att1 = GATConv(data.num_features, 8, heads=8, dropout=0.6) 18 | self.att2 = GATConv(8 * 8, data.num_classes, dropout=0.6) # input width = 8 heads x 8 channels from att1 19 | 20 | def forward(self): 21 | x = F.dropout(data.x, p=0.6, training=self.training) 22 | x = F.elu(self.att1(x, data.edge_index)) 23 | x = F.dropout(x, p=0.6, training=self.training) 24 | x = self.att2(x, data.edge_index) 25 | return F.log_softmax(x, dim=1) 26 | 27 | 28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 29 | model, data = Net().to(device), data.to(device) 30 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) 31 | 32 | 33 | def train(): 34 | model.train() 35 | optimizer.zero_grad() 36 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward() # loss only on training nodes 37 | optimizer.step() 38 | 39 | 40 | def test(): 41 | model.eval() 42 | logits, accs = model(), [] # full-graph forward pass with dropout disabled by eval() 43 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): # unpacked as (key, mask) pairs 44 | pred = logits[mask].max(1)[1] 45 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 46 | accs.append(acc) 47 | return accs 48 | 49 | 50 | for epoch in range(1, 201): 51 | train() 52 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 53 | print(log.format(epoch, *test())) 54 | -------------------------------------------------------------------------------- 
# /torch_geometric/utils/metric.py:
# -----------------------------------------------------------------------------
import torch


def accuracy(pred, target):
    """Return the fraction of entries where ``pred`` equals ``target``."""
    num_correct = (pred == target).sum().item()
    return num_correct / target.numel()


def _per_class(num_classes, hits):
    """Collect ``hits(c).sum()`` for every class ``c`` into one tensor."""
    return torch.tensor([hits(c).sum() for c in range(num_classes)])


def _zero_nan(values):
    """Replace NaN entries (produced by 0 / 0) with 0 in place."""
    values[torch.isnan(values)] = 0
    return values


def true_positive(pred, target, num_classes):
    """Per class ``c``: samples predicted ``c`` whose label is ``c``."""
    return _per_class(num_classes, lambda c: (pred == c) & (target == c))


def true_negative(pred, target, num_classes):
    """Per class ``c``: samples neither predicted nor labelled ``c``."""
    return _per_class(num_classes, lambda c: (pred != c) & (target != c))


def false_positive(pred, target, num_classes):
    """Per class ``c``: samples predicted ``c`` whose label is not ``c``."""
    return _per_class(num_classes, lambda c: (pred == c) & (target != c))


def false_negative(pred, target, num_classes):
    """Per class ``c``: samples labelled ``c`` but predicted otherwise."""
    return _per_class(num_classes, lambda c: (pred != c) & (target == c))


def precision(pred, target, num_classes):
    """Per-class ``tp / (tp + fp)``; classes never predicted yield 0."""
    tp = true_positive(pred, target, num_classes).to(torch.float)
    fp = false_positive(pred, target, num_classes).to(torch.float)
    return _zero_nan(tp / (tp + fp))


def recall(pred, target, num_classes):
    """Per-class ``tp / (tp + fn)``; classes never labelled yield 0."""
    tp = true_positive(pred, target, num_classes).to(torch.float)
    fn = false_negative(pred, target, num_classes).to(torch.float)
    return _zero_nan(tp / (tp + fn))


def f1_score(pred, target, num_classes):
    """Per-class harmonic mean of precision and recall (0 where undefined)."""
    prec = precision(pred, target, num_classes)
    rec = recall(pred, target, num_classes)
    return _zero_nan(2 * (prec * rec) / (prec + rec))
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/target_indegree.py:
# -----------------------------------------------------------------------------
import torch
from torch_geometric.utils import degree


class TargetIndegree(object):
    r"""Saves the globally normalized degree of target nodes (mapped to the
    fixed interval :math:`[0, 1]`)

    .. math::

        \mathbf{u}(i,j) = \frac{\deg(j)}{\max_{v \in \mathcal{V}} \deg(v)}

    in its edge attributes.

    Args:
        cat (bool, optional): Concat pseudo-coordinates to edge attributes
            instead of replacing them. (default: :obj:`True`)

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import TargetIndegree

        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        data = Data(edge_index=edge_index)

        data = TargetIndegree()(data)

        print(data.edge_attr)

    .. testoutput::

        tensor([[1.0000],
                [0.5000],
                [0.5000],
                [1.0000]])
    """

    def __init__(self, cat=True):
        self.cat = cat

    def __call__(self, data):
        col, pseudo = data.edge_index[1], data.edge_attr

        # In-degree of every node, scaled by the largest degree, gathered
        # per edge at its target node.
        deg = degree(col, data.num_nodes)
        deg = deg / deg.max()
        deg = deg[col]
        deg = deg.view(-1, 1)

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            data.edge_attr = torch.cat([pseudo, deg.type_as(pseudo)], dim=-1)
        else:
            data.edge_attr = deg

        return data

    def __repr__(self):
        return '{}(cat={})'.format(self.__class__.__name__, self.cat)
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/linear_transformation.py:
# -----------------------------------------------------------------------------
import torch


class LinearTransformation(object):
    r"""Transforms node positions with a square transformation matrix computed
    offline.

    Args:
        matrix (Tensor): tensor with shape :math:`[D, D]` where :math:`D`
            corresponds to the dimensionality of node positions.

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import LinearTransformation

        pos = torch.tensor([[-1, 1], [-3, 0], [2, -1]], dtype=torch.float)
        data = Data(pos=pos)

        matrix = torch.tensor([[2, 0], [0, 2]], dtype=torch.float)
        data = LinearTransformation(matrix)(data)

        print(data.pos)

    .. testoutput::

        tensor([[-2.,  2.],
                [-6.,  0.],
                [ 4., -2.]])
    """

    def __init__(self, matrix):
        assert matrix.dim() == 2, (
            'Transformation matrix should be two-dimensional.')
        # The trailing space keeps "rectangular" and "matrix." apart after
        # implicit string concatenation.
        assert matrix.size(0) == matrix.size(1), (
            'Transformation matrix should be square. Got [{} x {}] rectangular '
            'matrix.'.format(*matrix.size()))

        self.matrix = matrix

    def __call__(self, data):
        # A 1-D position vector is treated as D = 1.
        pos = data.pos.view(-1, 1) if data.pos.dim() == 1 else data.pos

        assert pos.size(1) == self.matrix.size(0), (
            'Node position matrix and transformation matrix have incompatible '
            'shape.')

        # Positions are row vectors, so they are right-multiplied.
        data.pos = torch.mm(pos, self.matrix.to(pos.dtype).to(pos.device))

        return data

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.matrix.tolist())
# -----------------------------------------------------------------------------
# /torch_geometric/nn/conv/gin_conv.py:
# -----------------------------------------------------------------------------
import torch
from torch_scatter import scatter_add
from torch_geometric.utils import remove_self_loops

from ..inits import reset


class GINConv(torch.nn.Module):
    r"""The graph isomorphism operator from the `"How Powerful are
    Graph Neural Networks?" <https://arxiv.org/abs/1810.00826>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
        \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

    here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* a MLP.

    Args:
        nn (nn.Sequential): Neural network.
        eps (float, optional): (Initial) :math:`\epsilon` value.
            (default: :obj:`0`)
        train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
            will be a trainable parameter. (default: :obj:`False`)
    """

    def __init__(self, nn, eps=0, train_eps=False):
        super(GINConv, self).__init__()
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x, edge_index):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        edge_index, _ = remove_self_loops(edge_index)
        row, col = edge_index

        # Sum neighbour features into each row node, then add the weighted
        # centre node before applying the network.
        out = scatter_add(x[col], row, dim=0, dim_size=x.size(0))
        out = (1 + self.eps) * x + out
        out = self.nn(out)
        return out

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.nn)
# -----------------------------------------------------------------------------
# /examples/agnn.py:
# -----------------------------------------------------------------------------
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import AGNNProp

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
data = Planetoid(path, dataset, T.NormalizeFeatures())[0]


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(data.num_features, 16)
        self.prop1 = AGNNProp(requires_grad=False)
        self.prop2 = AGNNProp(requires_grad=True)
        self.fc2 = torch.nn.Linear(16, data.num_classes)

    def forward(self):
        x = F.dropout(data.x, training=self.training)
        x = F.relu(self.fc1(x))
        x = self.prop1(x, data.edge_index)
        x = self.prop2(x, data.edge_index)
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


for epoch in range(1, 101):
    train()
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, *test()))
# -----------------------------------------------------------------------------
# /torch_geometric/nn/prop/agnn_prop.py:
# -----------------------------------------------------------------------------
import torch
from torch.nn import Parameter
from torch_sparse import spmm
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class AGNNProp(torch.nn.Module):
    """Graph Attentional Propagation Layer from the
    `"Attention-based Graph Neural Network for Semi-Supervised Learning (AGNN)"
    <https://arxiv.org/abs/1803.03735>`_ paper.

    Args:
        requires_grad (bool, optional): If set to :obj:`False`, the propagation
            layer will not be trainable. (default: :obj:`True`)
    """

    def __init__(self, requires_grad=True):
        super(AGNNProp, self).__init__()

        if requires_grad:
            self.beta = Parameter(torch.Tensor(1))
        else:
            # Fixed temperature of 1, stored as a buffer so it follows .to().
            self.register_buffer('beta', torch.ones(1))

        self.requires_grad = requires_grad
        self.reset_parameters()

    def reset_parameters(self):
        if self.requires_grad:
            self.beta.data.uniform_(0, 1)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        beta = self.beta if self.requires_grad else self._buffers['beta']

        # Add self-loops to adjacency matrix.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        row, col = edge_index

        # Compute attention coefficients: cosine similarity of the endpoint
        # features, scaled by beta and normalized with softmax grouped by row.
        norm = torch.norm(x, p=2, dim=1)
        alpha = (x[row] * x[col]).sum(dim=1) / (norm[row] * norm[col])
        alpha = softmax(alpha * beta, row, num_nodes=num_nodes)

        # Perform the propagation.
        out = spmm(edge_index, alpha, num_nodes, x)

        return out

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
# -----------------------------------------------------------------------------
# /torch_geometric/datasets/faust.py:
# -----------------------------------------------------------------------------
import os.path as osp
import shutil

import torch
from torch_geometric.data import InMemoryDataset, extract_zip
from torch_geometric.read import read_ply


class FAUST(InMemoryDataset):
    """The 100 MPI-FAUST registration meshes.

    Meshes ``0-79`` form the training split and ``80-99`` the test split;
    mesh ``i`` is labelled ``y = i % 10``. The archive must be downloaded
    manually into :obj:`raw_dir`.
    """

    url = 'http://faust.is.tue.mpg.de/'

    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        super(FAUST, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return 'MPI-FAUST.zip'

    @property
    def processed_file_names(self):
        return ['training.pt', 'test.pt']

    def download(self):
        raise RuntimeError(
            'Dataset not found. Please download MPI-FAUST.zip from {} and '
            'move it to {}'.format(self.url, self.raw_dir))

    def process(self):
        extract_zip(self.raw_paths[0], self.raw_dir, log=False)

        path = osp.join(self.raw_dir, 'MPI-FAUST', 'training', 'registrations')
        path = osp.join(path, 'tr_reg_{0:03d}.ply')
        # Assign each mesh to its split by index, so meshes dropped by
        # pre_filter cannot shift test meshes into the training split.
        train_list, test_list = [], []
        for i in range(100):
            data = read_ply(path.format(i))
            data.y = torch.tensor([i % 10], dtype=torch.long)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            (train_list if i < 80 else test_list).append(data)

        torch.save(self.collate(train_list), self.processed_paths[0])
        torch.save(self.collate(test_list), self.processed_paths[1])

        shutil.rmtree(osp.join(self.raw_dir, 'MPI-FAUST'))
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/to_dense.py:
# -----------------------------------------------------------------------------
import torch


class ToDense(object):
    """Converts the sparse graph into a dense adjacency matrix ``adj``.

    ``edge_index``/``edge_attr`` are replaced by ``adj`` (edge weights default
    to 1). Node-level ``x``, ``pos`` and (if node-level) ``y`` are zero-padded
    to :obj:`num_nodes` rows, and ``mask`` marks the original nodes.

    Args:
        num_nodes (int, optional): Number of nodes to pad to. Must not be
            smaller than the graph's node count. (default: :obj:`None`,
            no padding)
    """

    def __init__(self, num_nodes=None):
        self.num_nodes = num_nodes

    def __call__(self, data):
        assert data.edge_index is not None

        orig_num_nodes = data.num_nodes
        if self.num_nodes is None:
            num_nodes = orig_num_nodes
        else:
            assert orig_num_nodes <= self.num_nodes
            num_nodes = self.num_nodes

        if data.edge_attr is None:
            # Create the default weights on the same device as the indices;
            # sparse_coo_tensor rejects mixed devices.
            edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float,
                                   device=data.edge_index.device)
        else:
            edge_attr = data.edge_attr

        size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:])
        adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size)
        data.adj = adj.to_dense()
        data.edge_index = None
        data.edge_attr = None

        data.mask = torch.zeros(num_nodes, dtype=torch.uint8)
        data.mask[:orig_num_nodes] = 1

        if data.x is not None:
            size = [num_nodes - data.x.size(0)] + list(data.x.size())[1:]
            data.x = torch.cat([data.x, data.x.new_zeros(size)], dim=0)

        if data.pos is not None:
            size = [num_nodes - data.pos.size(0)] + list(data.pos.size())[1:]
            data.pos = torch.cat([data.pos, data.pos.new_zeros(size)], dim=0)

        if data.y is not None and (data.y.size(0) == orig_num_nodes):
            size = [num_nodes - data.y.size(0)] + list(data.y.size())[1:]
            data.y = torch.cat([data.y, data.y.new_zeros(size)], dim=0)

        return data

    def __repr__(self):
        if self.num_nodes is None:
            return '{}()'.format(self.__class__.__name__)
        else:
            return '{}(num_nodes={})'.format(self.__class__.__name__,
                                             self.num_nodes)
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/cartesian.py:
# -----------------------------------------------------------------------------
import torch


class Cartesian(object):
    r"""Saves the globally normalized spatial relation of linked nodes as
    Cartesian coordinates (mapped to the fixed interval :math:`[0, 1]`)

    .. math::
        \mathbf{u}(i,j) = 0.5 + \frac{\mathbf{pos}_j - \mathbf{pos}_i}{2 \cdot
        \max_{(v, w) \in \mathcal{E}} | \mathbf{pos}_w - \mathbf{pos}_v|}

    in its edge attributes.

    Args:
        cat (bool, optional): Concat pseudo-coordinates to edge attributes
            instead of replacing them. (default: :obj:`True`)

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import Cartesian

        pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        data = Data(edge_index=edge_index, pos=pos)

        data = Cartesian()(data)

        print(data.edge_attr)

    .. testoutput::

        tensor([[0.7500, 0.5000],
                [0.2500, 0.5000],
                [1.0000, 0.5000],
                [0.0000, 0.5000]])
    """

    def __init__(self, cat=True):
        self.cat = cat

    def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

        cart = pos[col] - pos[row]
        cart = cart / (2 * cart.abs().max())
        cart += 0.5
        cart = cart.view(-1, 1) if cart.dim() == 1 else cart

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
        else:
            data.edge_attr = cart

        return data

    def __repr__(self):
        return '{}(cat={})'.format(self.__class__.__name__, self.cat)
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/polar.py:
# -----------------------------------------------------------------------------
from math import pi as PI

import torch


class Polar(object):
    r"""Saves the globally normalized two-dimensional spatial relation of
    linked nodes as polar coordinates (mapped to the fixed interval
    :math:`[0, 1]`) in its edge attributes.

    The radius is divided by the longest edge and the angle by
    :math:`2\pi`.

    Args:
        cat (bool, optional): Concat pseudo-coordinates to edge attributes
            instead of replacing them. (default: :obj:`True`)

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import Polar

        pos = torch.tensor([[-1, 0], [0, 0], [0, 2]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        data = Data(edge_index=edge_index, pos=pos)

        data = Polar()(data)

        print(data.edge_attr)

    .. testoutput::

        tensor([[0.5000, 0.0000],
                [0.5000, 0.5000],
                [1.0000, 0.2500],
                [1.0000, 0.7500]])
    """

    def __init__(self, cat=True):
        self.cat = cat

    def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
        assert pos.dim() == 2 and pos.size(1) == 2

        cart = pos[col] - pos[row]
        rho = torch.norm(cart, p=2, dim=-1)
        rho = rho / rho.max()
        theta = torch.atan2(cart[..., 1], cart[..., 0]) / (2 * PI)
        # Shift negative angles from (-0.5, 0) into (0.5, 1).
        theta += (theta < 0).type_as(theta)
        polar = torch.stack([rho, theta], dim=1)

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            # Match the existing edge attributes' dtype, as the other
            # pseudo-coordinate transforms do, so the concat cannot mix types.
            data.edge_attr = torch.cat([pseudo, polar.type_as(pseudo)], dim=-1)
        else:
            data.edge_attr = polar

        return data

    def __repr__(self):
        return '{}(cat={})'.format(self.__class__.__name__, self.cat)
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/spherical.py:
# -----------------------------------------------------------------------------
from math import pi as PI

import torch


class Spherical(object):
    r"""Saves the globally normalized three-dimensional spatial relation of
    linked nodes as spherical coordinates (mapped to the fixed interval
    :math:`[0, 1]`) in its edge attributes.

    The radius is divided by the longest edge, the azimuth by :math:`2\pi`
    and the polar angle (measured from the z-axis) by :math:`\pi`.

    Args:
        cat (bool, optional): Concat pseudo-coordinates to edge attributes
            instead of replacing them. (default: :obj:`True`)

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import Spherical

        pos = torch.tensor([[0, 0, 0], [0, 1, 1]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1], [1, 0]])
        data = Data(edge_index=edge_index, pos=pos)

        data = Spherical()(data)

        print(data.edge_attr)

    .. testoutput::

        tensor([[1.0000, 0.2500, 0.2500],
                [1.0000, 0.7500, 0.7500]])
    """

    def __init__(self, cat=True):
        self.cat = cat

    def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
        assert pos.dim() == 2 and pos.size(1) == 3

        cart = pos[col] - pos[row]
        rho = torch.norm(cart, p=2, dim=-1)
        theta = torch.atan2(cart[..., 1], cart[..., 0]) / (2 * PI)
        # Shift negative angles from (-0.5, 0) into (0.5, 1).
        theta += (theta < 0).type_as(theta)
        # The polar angle must use the true edge length; dividing by the
        # globally rescaled rho gives wrong angles and can leave acos's
        # domain. The clamp guards against rounding just outside [-1, 1].
        phi = torch.acos((cart[..., 2] / rho).clamp(-1, 1)) / PI
        rho = rho / rho.max()
        spher = torch.stack([rho, theta, phi], dim=1)

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            # Match the existing edge attributes' dtype, as the other
            # pseudo-coordinate transforms do, so the concat cannot mix types.
            data.edge_attr = torch.cat([pseudo, spher.type_as(pseudo)], dim=-1)
        else:
            data.edge_attr = spher

        return data

    def __repr__(self):
        return '{}(cat={})'.format(self.__class__.__name__, self.cat)
# -----------------------------------------------------------------------------
# /torch_geometric/datasets/tu_dataset.py:
# -----------------------------------------------------------------------------
import os
import os.path as osp

import torch
from torch_geometric.data import InMemoryDataset, download_url, extract_zip
from torch_geometric.read import read_tu_data


class TUDataset(InMemoryDataset):
    """A dataset from the TU Dortmund graph kernel collection.

    The archive ``{name}.zip`` is downloaded from :obj:`url`, and its
    ``{name}_A.txt`` / ``{name}_graph_indicator.txt`` files are parsed by
    ``read_tu_data``.
    """

    url = 'https://ls11-www.cs.uni-dortmund.de/people/morris/' \
          'graphkerneldatasets'

    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        super(TUDataset, self).__init__(root, transform, pre_transform,
                                        pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        names = ['A', 'graph_indicator']
        return ['{}_{}.txt'.format(self.name, name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        path = download_url('{}/{}.zip'.format(self.url, self.name), self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        os.rename(osp.join(self.root, self.name), self.raw_dir)

    def process(self):
        self.data, self.slices = read_tu_data(self.raw_dir, self.name)

        if self.pre_filter is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [data for data in data_list if self.pre_filter(data)]
            self.data, self.slices = self.collate(data_list)

        if self.pre_transform is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [self.pre_transform(data) for data in data_list]
            self.data, self.slices = self.collate(data_list)

        torch.save((self.data, self.slices), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))
# -----------------------------------------------------------------------------
# /examples/cora.py:
# -----------------------------------------------------------------------------
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import SplineConv

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
data = Planetoid(path, dataset, T.TargetIndegree())[0]

# Custom split: all but the last 1000 nodes train, the last 500 test.
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
data.train_mask[:data.num_nodes - 1000] = 1
data.val_mask = None
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
data.test_mask[data.num_nodes - 500:] = 1


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(data.num_features, 16, dim=1, kernel_size=2)
        self.conv2 = SplineConv(16, data.num_classes, dim=1, kernel_size=2)

    def forward(self):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.dropout(x, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)


def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


for epoch in range(1, 201):
    train()
    log = 'Epoch: {:03d}, Train: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, *test()))
# -----------------------------------------------------------------------------
# /examples/gcn.py:
# -----------------------------------------------------------------------------
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv  # noqa

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
data = Planetoid(path, dataset, T.NormalizeFeatures())[0]


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(data.num_features, 16, improved=False)
        self.conv2 = GCNConv(16, data.num_classes, improved=False)
        # Chebyshev alternative; the output layer must also produce
        # data.num_classes channels.
        # self.conv1 = ChebConv(data.num_features, 16, K=2)
        # self.conv2 = ChebConv(16, data.num_classes, K=2)

    def forward(self):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


# Report the test accuracy at the epoch with the best validation accuracy.
best_val_acc = test_acc = 0
for epoch in range(1, 101):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))
# -----------------------------------------------------------------------------
# /torch_geometric/nn/pool/set2set.py:
# -----------------------------------------------------------------------------
import torch


class Set2Set(torch.nn.Module):
    r"""Set2Set global pooling over the nodes of each graph in a batch.

    Runs :obj:`processing_steps` rounds of LSTM-driven attention over the
    node features of every graph and returns, per graph, the concatenation
    of the final LSTM query and the attention readout
    (:obj:`2 * in_channels` features).

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of the LSTM hidden state. The query is
            dotted with the node features, so this must equal
            :obj:`in_channels`.
        processing_steps (int): Number of attention iterations.
        num_layers (int, optional): Number of stacked LSTM layers.
            (default: :obj:`1`)

    Raises:
        ValueError: If :obj:`hidden_channels` differs from
            :obj:`in_channels`.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 processing_steps,
                 num_layers=1):
        super(Set2Set, self).__init__()

        # A mismatch previously slipped through here and only failed inside
        # the LSTM during forward.
        if hidden_channels != in_channels:
            raise ValueError(
                'Set2Set requires hidden_channels == in_channels (got {} and '
                '{}).'.format(hidden_channels, in_channels))

        self.in_channels = in_channels
        self.out_channels = 2 * in_channels
        self.hidden_channels = hidden_channels
        self.processing_steps = processing_steps
        self.num_layers = num_layers

        self.lstm = torch.nn.LSTM(self.out_channels, self.hidden_channels,
                                  num_layers)

        self.reset_parameters()

    def reset_parameters(self):
        self.lstm.reset_parameters()

    def forward(self, x, batch):
        """"""
        batch_size = batch.max().item() + 1

        # Bring x into shape [batch_size, max_nodes, in_channels] by
        # zero-padding smaller graphs (assumes batch is sorted by graph).
        counts = torch.bincount(batch)
        xs = x.split(counts.tolist())
        max_nodes = max([t.size(0) for t in xs])
        xs = [[t, t.new_zeros(max_nodes - t.size(0), t.size(1))] for t in xs]
        xs = [torch.cat(t, dim=0) for t in xs]
        x = torch.stack(xs, dim=0)

        # True where a slot is padding; such slots must get no attention.
        # Built with a comparison (not ``~``) so it is also correct for
        # uint8 masks on older PyTorch releases.
        pad = (torch.arange(max_nodes, device=x.device).view(1, -1) >=
               counts.view(-1, 1))

        h = (x.new_zeros((self.num_layers, batch_size, self.hidden_channels)),
             x.new_zeros((self.num_layers, batch_size, self.hidden_channels)))
        q_star = x.new_zeros(1, batch_size, self.out_channels)

        for i in range(self.processing_steps):
            q, h = self.lstm(q_star, h)
            q = q.view(batch_size, 1, self.in_channels)
            e = (x * q).sum(dim=-1)  # Dot product, [batch_size, max_nodes].
            e = e.masked_fill(pad, float('-inf'))
            a = torch.softmax(e, dim=-1)
            # A graph without nodes has only -inf scores (NaN softmax); zero
            # its weights so its readout stays zero.
            a = a.masked_fill(pad, 0)
            a = a.view(batch_size, max_nodes, 1)
            r = (a * x).sum(dim=1, keepdim=True)
            q_star = torch.cat([q, r], dim=-1)
            q_star = q_star.view(1, batch_size, self.out_channels)

        q_star = q_star.view(batch_size, self.out_channels)
        return q_star

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
# -----------------------------------------------------------------------------
# /torch_geometric/transforms/local_cartesian.py:
# -----------------------------------------------------------------------------
import torch
from torch_scatter import scatter_max


class LocalCartesian(object):
    r"""Saves the locally normalized spatial relation of linked nodes as
    Cartesian coordinates (mapped to the fixed interval :math:`[0, 1]`)

    .. math::
        \mathbf{u}(i,j) = 0.5 + \frac{\mathbf{pos}_j - \mathbf{pos}_i}{2 \cdot
        \max_{v \in \mathcal{N}(i)} | \mathbf{pos}_v - \mathbf{pos}_i|}

    in its edge attributes.

    Args:
        cat (bool, optional): Concat pseudo-coordinates to edge attributes
            instead of replacing them. (default: :obj:`True`)

    .. testsetup::

        import torch
        from torch_geometric.data import Data

    .. testcode::

        from torch_geometric.transforms import LocalCartesian

        pos = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        data = Data(edge_index=edge_index, pos=pos)

        data = LocalCartesian()(data)

        print(data.edge_attr)

    .. testoutput::

        tensor([[1.0000, 0.5000],
                [0.2500, 0.5000],
                [1.0000, 0.5000],
                [0.0000, 0.5000]])
    """

    def __init__(self, cat=True):
        self.cat = cat

    def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

        cart = pos[col] - pos[row]
        # Per source node: largest absolute offset over its edges, reduced
        # over the coordinate axes, then gathered back per edge.
        max_cart, _ = scatter_max(cart.abs(), row, 0, dim_size=pos.size(0))
        cart = cart / (2 * max_cart.max(dim=1, keepdim=True)[0][row])
        cart += 0.5
        cart = cart.view(-1, 1) if cart.dim() == 1 else cart

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
        else:
            data.edge_attr = cart

        return data

    def __repr__(self):
        return '{}(cat={})'.format(self.__class__.__name__, self.cat)
# -----------------------------------------------------------------------------
# /torch_geometric/nn/conv/gcn_conv.py:
# -----------------------------------------------------------------------------
import torch
from torch.nn import Parameter

from ..inits import uniform
from ..prop import GCNProp
class GCNConv(torch.nn.Module):
    r"""The graph convolutional operator from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\hat{A}` as :math:`A + 2I`. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(GCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.prop = GCNProp(improved)
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight and bias uniformly with fan-in ``in_channels``.

        There is no ``att_weight`` on this layer; referencing one raised
        ``AttributeError`` on construction and left the bias uninitialized.
        ``bias`` may be ``None``; this relies on ``uniform`` accepting
        ``None``, as ``GraphConv.reset_parameters`` already does.
        """
        size = self.in_channels
        uniform(size, self.weight)
        uniform(size, self.bias)

    def forward(self, x, edge_index, edge_attr=None):
        """"""
        # Project features first, then propagate with GCNProp.
        out = torch.mm(x, self.weight)
        out = self.prop(out, edge_index, edge_attr)

        if self.bias is not None:
            out = out + self.bias

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
# -----------------------------------------------------------------------------
/torch_geometric/nn/conv/graph_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.utils import remove_self_loops, scatter_ 4 | 5 | from ..inits import uniform 6 | 7 | 8 | class GraphConv(torch.nn.Module): 9 | r"""The graph neural network operator from the `"Weisfeiler and Leman Go 10 | Neural: Higher-order Graph Neural Networks" 11 | `_ paper 12 | 13 | .. math:: 14 | \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{(1)} \mathbf{x}_i + 15 | \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}^{(2)} \mathbf{x}_j. 16 | 17 | Args: 18 | in_channels (int): Size of each input sample. 19 | out_channels (int): Size of each output sample. 20 | aggr (string): The aggregation operator to use (one of :obj:`"add"`, 21 | :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) 22 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 23 | an additive bias. (default: :obj:`True`) 24 | """ 25 | 26 | def __init__(self, in_channels, out_channels, aggr='add', bias=True): 27 | super(GraphConv, self).__init__() 28 | 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.aggr = aggr 32 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 33 | self.root = Parameter(torch.Tensor(in_channels, out_channels)) 34 | 35 | if bias: 36 | self.bias = Parameter(torch.Tensor(out_channels)) 37 | else: 38 | self.register_parameter('bias', None) 39 | 40 | self.reset_parameters() 41 | 42 | def reset_parameters(self): 43 | size = self.in_channels 44 | uniform(size, self.weight) 45 | uniform(size, self.root) 46 | uniform(size, self.bias) 47 | 48 | def forward(self, x, edge_index): 49 | """""" 50 | x = x.unsqueeze(-1) if x.dim() == 1 else x 51 | edge_index, _ = remove_self_loops(edge_index) 52 | row, col = edge_index 53 | 54 | out = torch.mm(x, self.weight) 55 | out = scatter_(self.aggr, out[col], row, dim_size=x.size(0)) 56 | out = out + torch.mm(x, 
self.root) 57 | 58 | if self.bias is not None: 59 | out = out + self.bias 60 | 61 | return out 62 | 63 | def __repr__(self): 64 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 65 | self.out_channels) 66 | -------------------------------------------------------------------------------- /torch_geometric/data/dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os.path as osp 3 | 4 | import torch.utils.data 5 | 6 | from .makedirs import makedirs 7 | 8 | 9 | def to_list(x): 10 | if not isinstance(x, collections.Iterable) or isinstance(x, str): 11 | x = [x] 12 | return x 13 | 14 | 15 | def files_exist(files): 16 | return all([osp.exists(f) for f in files]) 17 | 18 | 19 | class Dataset(torch.utils.data.Dataset): 20 | @property 21 | def raw_file_names(self): 22 | raise NotImplementedError 23 | 24 | @property 25 | def processed_file_names(self): 26 | raise NotImplementedError 27 | 28 | def download(self): 29 | raise NotImplementedError 30 | 31 | def process(self): 32 | raise NotImplementedError 33 | 34 | def __len__(self): 35 | raise NotImplementedError 36 | 37 | def get(self, idx): 38 | raise NotImplementedError 39 | 40 | def __init__(self, 41 | root, 42 | transform=None, 43 | pre_transform=None, 44 | pre_filter=None): 45 | super(Dataset, self).__init__() 46 | 47 | self.root = osp.expanduser(osp.normpath(root)) 48 | self.raw_dir = osp.join(self.root, 'raw') 49 | self.processed_dir = osp.join(self.root, 'processed') 50 | self.transform = transform 51 | self.pre_transform = pre_transform 52 | self.pre_filter = pre_filter 53 | 54 | self._download() 55 | self._process() 56 | 57 | @property 58 | def raw_paths(self): 59 | files = to_list(self.raw_file_names) 60 | return [osp.join(self.raw_dir, f) for f in files] 61 | 62 | @property 63 | def processed_paths(self): 64 | files = to_list(self.processed_file_names) 65 | return [osp.join(self.processed_dir, f) for f in files] 66 | 67 | 
def _download(self): 68 | if files_exist(self.raw_paths): 69 | return 70 | 71 | makedirs(self.raw_dir) 72 | self.download() 73 | 74 | def _process(self): 75 | if files_exist(self.processed_paths): 76 | return 77 | 78 | print('Processing...') 79 | 80 | makedirs(self.processed_dir) 81 | self.process() 82 | 83 | print('Done!') 84 | 85 | def __getitem__(self, idx): 86 | data = self.get(idx) 87 | data = data if self.transform is None else self.transform(data) 88 | return data 89 | 90 | def __repr__(self): 91 | return '{}({})'.format(self.__class__.__name__, len(self)) 92 | -------------------------------------------------------------------------------- /torch_geometric/datasets/mnist_superpixels.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 5 | extract_tar) 6 | 7 | 8 | class MNISTSuperpixels(InMemoryDataset): 9 | url = 'http://ls7-www.cs.uni-dortmund.de/cvpr_geometric_dl/' \ 10 | 'mnist_superpixels.tar.gz' 11 | 12 | def __init__(self, 13 | root, 14 | train=True, 15 | transform=None, 16 | pre_transform=None, 17 | pre_filter=None): 18 | super(MNISTSuperpixels, self).__init__(root, transform, pre_transform, 19 | pre_filter) 20 | path = self.processed_paths[0] if train else self.processed_paths[1] 21 | self.data, self.slices = torch.load(path) 22 | 23 | @property 24 | def raw_file_names(self): 25 | return ['training.pt', 'test.pt'] 26 | 27 | @property 28 | def processed_file_names(self): 29 | return ['training.pt', 'test.pt'] 30 | 31 | def download(self): 32 | path = download_url(self.url, self.raw_dir) 33 | extract_tar(path, self.raw_dir, mode='r') 34 | os.unlink(path) 35 | 36 | def process(self): 37 | for raw_path, path in zip(self.raw_paths, self.processed_paths): 38 | x, edge_index, edge_slice, pos, y = torch.load(raw_path) 39 | edge_index, y = edge_index.to(torch.long), y.to(torch.long) 40 | m, n = y.size(0), 75 41 | x, 
pos = x.view(m * n, 1), pos.view(m * n, 2) 42 | node_slice = torch.arange(0, (m + 1) * n, step=n, dtype=torch.long) 43 | graph_slice = torch.arange(m + 1, dtype=torch.long) 44 | self.data = Data(x=x, edge_index=edge_index, y=y, pos=pos) 45 | self.slices = { 46 | 'x': node_slice, 47 | 'edge_index': edge_slice, 48 | 'y': graph_slice, 49 | 'pos': node_slice 50 | } 51 | 52 | if self.pre_filter is not None: 53 | data_list = [self.get(idx) for idx in range(len(self))] 54 | data_list = [d for d in data_list if self.pre_filter(d)] 55 | self.data, self.slices = self.collate(data_list) 56 | 57 | if self.pre_transform is not None: 58 | data_list = [self.get(idx) for idx in range(len(self))] 59 | data_list = [self.pre_transform(data) for data in data_list] 60 | self.data, self.slices = self.collate(data_list) 61 | 62 | torch.save((self.data, self.slices), path) 63 | -------------------------------------------------------------------------------- /torch_geometric/nn/conv/sage_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Parameter 4 | from torch_scatter import scatter_mean 5 | from torch_geometric.utils import remove_self_loops, add_self_loops 6 | 7 | from ..inits import uniform 8 | 9 | 10 | class SAGEConv(torch.nn.Module): 11 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on 12 | Large Graphs" `_ paper 13 | 14 | .. math:: 15 | \mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot 16 | \mathrm{mean}_{j \in \mathcal{N(i) \cup \{ i \}}}(\mathbf{x}_j) 17 | 18 | \mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i} 19 | {\| \mathbf{\hat{x}}_i \|_2}. 20 | 21 | Args: 22 | in_channels (int): Size of each input sample. 23 | out_channels (int): Size of each output sample. 24 | normalize (bool, optional): If set to :obj:`False`, output features 25 | will not be :math:`\ell^2`-normalized. 
26 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 27 | an additive bias. (default: :obj:`True`) 28 | """ 29 | 30 | def __init__(self, in_channels, out_channels, normalize=True, bias=True): 31 | super(SAGEConv, self).__init__() 32 | 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.normalize = normalize 36 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 37 | 38 | if bias: 39 | self.bias = Parameter(torch.Tensor(out_channels)) 40 | else: 41 | self.register_parameter('bias', None) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | size = self.weight.size(0) 47 | uniform(size, self.weight) 48 | uniform(size, self.bias) 49 | 50 | def forward(self, x, edge_index): 51 | """""" 52 | edge_index, _ = remove_self_loops(edge_index) 53 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) 54 | 55 | x = x.unsqueeze(-1) if x.dim() == 1 else x 56 | row, col = edge_index 57 | 58 | out = scatter_mean(x[col], row, dim=0, dim_size=x.size(0)) 59 | out = torch.matmul(out, self.weight) 60 | 61 | if self.bias is not None: 62 | out = out + self.bias 63 | 64 | if self.normalize: 65 | out = F.normalize(out, p=2, dim=-1) 66 | 67 | return out 68 | 69 | def __repr__(self): 70 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 71 | self.out_channels) 72 | -------------------------------------------------------------------------------- /torch_geometric/datasets/ppi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | 5 | import torch 6 | import numpy as np 7 | from torch_sparse import coalesce 8 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 9 | extract_zip) 10 | 11 | 12 | class PPI(InMemoryDataset): 13 | url = 'http://snap.stanford.edu/graphsage/ppi.zip' 14 | 15 | def __init__(self, root, transform=None, pre_transform=None): 16 | super(PPI, 
self).__init__(root, transform, pre_transform) 17 | self.data, self.slices = torch.load(self.processed_paths[0]) 18 | 19 | @property 20 | def raw_file_names(self): 21 | prefix = self.__class__.__name__.lower() 22 | suffix = ['G.json', 'feats.npy', 'class_map.json'] 23 | return ['{}-{}'.format(prefix, s) for s in suffix] 24 | 25 | @property 26 | def processed_file_names(self): 27 | return 'data.pt' 28 | 29 | def download(self): 30 | path = download_url(self.url, self.root) 31 | extract_zip(path, self.root) 32 | os.unlink(path) 33 | name = self.__class__.__name__.lower() 34 | os.rename(osp.join(self.root, name), self.raw_dir) 35 | 36 | def process(self): 37 | with open(self.raw_paths[0], 'r') as f: 38 | graph_data = json.load(f) 39 | 40 | mask = torch.zeros(len(graph_data['nodes']), dtype=torch.uint8) 41 | for i in graph_data['nodes']: 42 | mask[i['id']] = 1 if i['val'] else (2 if i['test'] else 0) 43 | train_mask, val_mask, test_mask = mask == 0, mask == 1, mask == 2 44 | 45 | row, col = [], [] 46 | for i in graph_data['links']: 47 | row.append(i['source']) 48 | col.append(i['target']) 49 | edge_index = torch.stack([torch.tensor(row), torch.tensor(col)], dim=0) 50 | edge_index, _ = coalesce(edge_index, None, mask.size(0), mask.size(0)) 51 | 52 | x = torch.from_numpy(np.load(self.raw_paths[1])).float() 53 | 54 | with open(self.raw_paths[2], 'r') as f: 55 | y_data = json.load(f) 56 | 57 | y = [] 58 | for i in range(len(y_data)): 59 | y.append(y_data[str(i)]) 60 | y = torch.tensor(y, dtype=torch.float) 61 | 62 | data = Data(x=x, edge_index=edge_index, y=y) 63 | data.train_mask = train_mask 64 | data.val_mask = val_mask 65 | data.test_mask = test_mask 66 | 67 | data = data if self.pre_transform is None else self.pre_transform(data) 68 | data, slices = self.collate([data]) 69 | torch.save((data, slices), self.processed_paths[0]) 70 | 71 | def __repr__(self): 72 | return '{}()'.format(self.__class__.__name__) 73 | 
-------------------------------------------------------------------------------- /torch_geometric/datasets/coma.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | import torch 5 | from torch_geometric.data import InMemoryDataset, extract_zip 6 | from torch_geometric.read import read_ply 7 | 8 | 9 | class CoMA(InMemoryDataset): 10 | url = 'https://coma.is.tue.mpg.de/' 11 | 12 | categories = [ 13 | 'bareteeth', 14 | 'cheeks_in', 15 | 'eyebrow', 16 | 'high_smile', 17 | 'lips_back', 18 | 'lips_up', 19 | 'mouth_down', 20 | 'mouth_extreme', 21 | 'mouth_middle', 22 | 'mouth_open', 23 | 'mouth_side', 24 | 'mouth_up', 25 | ] 26 | 27 | def __init__(self, 28 | root, 29 | train=True, 30 | transform=None, 31 | pre_transform=None, 32 | pre_filter=None): 33 | super(CoMA, self).__init__(root, transform, pre_transform, pre_filter) 34 | path = self.processed_paths[0] if train else self.processed_paths[1] 35 | self.data, self.slices = torch.load(path) 36 | 37 | @property 38 | def raw_file_names(self): 39 | return 'COMA_data.zip' 40 | 41 | @property 42 | def processed_file_names(self): 43 | return ['training.pt', 'test.pt'] 44 | 45 | def download(self): 46 | raise RuntimeError( 47 | 'Dataset not found. 
Please download COMA_data.zip from {} and ' 48 | 'move it to {}'.format(self.url, self.raw_dir)) 49 | 50 | def process(self): 51 | folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*'))) 52 | if len(folders) == 0: 53 | extract_zip(self.raw_paths[0], self.raw_dir, log=False) 54 | folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*'))) 55 | 56 | train_data_list, test_data_list = [], [] 57 | for folder in folders: 58 | for i, category in enumerate(self.categories): 59 | files = sorted(glob(osp.join(folder, category, '*.ply'))) 60 | for j, f in enumerate(files): 61 | data = read_ply(f) 62 | data.y = torch.tensor([i], dtype=torch.long) 63 | if self.pre_filter is not None and\ 64 | not self.pre_filter(data): 65 | continue 66 | if self.pre_transform is not None: 67 | data = self.pre_transform(data) 68 | 69 | if (j % 100) < 90: 70 | train_data_list.append(data) 71 | else: 72 | test_data_list.append(data) 73 | 74 | torch.save(self.collate(train_data_list), self.processed_paths[0]) 75 | torch.save(self.collate(test_data_list), self.processed_paths[1]) 76 | -------------------------------------------------------------------------------- /torch_geometric/datasets/modelnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | 5 | import torch 6 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 7 | from torch_geometric.read import read_off 8 | 9 | 10 | class ModelNet(InMemoryDataset): 11 | urls = { 12 | '10': 13 | 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', 14 | '40': 15 | 'http://modelnet.cs.princeton.edu/ModelNet40.zip' 16 | } 17 | 18 | def __init__(self, 19 | root, 20 | name='10', 21 | train=True, 22 | transform=None, 23 | pre_transform=None, 24 | pre_filter=None): 25 | assert name in ['10', '40'] 26 | self.name = name 27 | super(ModelNet, self).__init__(root, transform, pre_transform, 28 | pre_filter) 29 | path = 
self.processed_paths[0] if train else self.processed_paths[1] 30 | self.data, self.slices = torch.load(path) 31 | 32 | @property 33 | def raw_file_names(self): 34 | return [ 35 | 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 36 | 'night_stand', 'sofa', 'table', 'toilet' 37 | ] 38 | 39 | @property 40 | def processed_file_names(self): 41 | return ['training.pt', 'test.pt'] 42 | 43 | def download(self): 44 | path = download_url(self.urls[self.name], self.root) 45 | extract_zip(path, self.root) 46 | os.unlink(path) 47 | folder = osp.join(self.root, 'ModelNet{}'.format(self.name)) 48 | os.rename(folder, self.raw_dir) 49 | 50 | def process(self): 51 | torch.save(self.process_set('train'), self.processed_paths[0]) 52 | torch.save(self.process_set('test'), self.processed_paths[1]) 53 | 54 | def process_set(self, dataset): 55 | categories = glob.glob(osp.join(self.raw_dir, '*', '')) 56 | categories = sorted([x.split('/')[-2] for x in categories]) 57 | 58 | data_list = [] 59 | for target, category in enumerate(categories): 60 | folder = osp.join(self.raw_dir, category, dataset) 61 | paths = glob.glob('{}/{}_*.off'.format(folder, category)) 62 | for path in paths: 63 | data = read_off(path) 64 | data.y = torch.tensor([target]) 65 | data_list.append(data) 66 | 67 | if self.pre_filter is not None: 68 | data_list = [d for d in data_list if self.pre_filter(d)] 69 | 70 | if self.pre_transform is not None: 71 | data_list = [self.pre_transform(d) for d in data_list] 72 | 73 | return self.collate(data_list) 74 | 75 | def __repr__(self): 76 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self)) 77 | -------------------------------------------------------------------------------- /examples/mnist_voxel_grid.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_geometric.datasets import MNISTSuperpixels 6 | import 
torch_geometric.transforms as T 7 | from torch_geometric.data import DataLoader 8 | from torch_geometric.nn import SplineConv, voxel_grid, max_pool, max_pool_x 9 | 10 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST') 11 | transform = T.Cartesian(cat=False) 12 | train_dataset = MNISTSuperpixels(path, True, transform=transform) 13 | test_dataset = MNISTSuperpixels(path, False, transform=transform) 14 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 15 | test_loader = DataLoader(test_dataset, batch_size=64) 16 | 17 | 18 | class Net(torch.nn.Module): 19 | def __init__(self): 20 | super(Net, self).__init__() 21 | self.conv1 = SplineConv(1, 32, dim=2, kernel_size=5) 22 | self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5) 23 | self.conv3 = SplineConv(64, 64, dim=2, kernel_size=5) 24 | self.fc1 = torch.nn.Linear(4 * 64, 128) 25 | self.fc2 = torch.nn.Linear(128, 10) 26 | 27 | def forward(self, data): 28 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 29 | cluster = voxel_grid(data.pos, data.batch, size=5, start=0, end=28) 30 | data = max_pool(cluster, data, transform=transform) 31 | 32 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 33 | cluster = voxel_grid(data.pos, data.batch, size=7, start=0, end=28) 34 | data = max_pool(cluster, data, transform=transform) 35 | 36 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 37 | cluster = voxel_grid(data.pos, data.batch, size=14, start=0, end=27.99) 38 | x = max_pool_x(cluster, data.x, data.batch, size=4) 39 | 40 | x = x.view(-1, self.fc1.weight.size(1)) 41 | x = F.elu(self.fc1(x)) 42 | x = F.dropout(x, training=self.training) 43 | x = self.fc2(x) 44 | return F.log_softmax(x, dim=1) 45 | 46 | 47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | model = Net().to(device) 49 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 50 | 51 | 52 | def train(epoch): 53 | model.train() 54 | 
55 | if epoch == 6: 56 | for param_group in optimizer.param_groups: 57 | param_group['lr'] = 0.001 58 | 59 | if epoch == 16: 60 | for param_group in optimizer.param_groups: 61 | param_group['lr'] = 0.0001 62 | 63 | for data in train_loader: 64 | data = data.to(device) 65 | optimizer.zero_grad() 66 | F.nll_loss(model(data), data.y).backward() 67 | optimizer.step() 68 | 69 | 70 | def test(): 71 | model.eval() 72 | correct = 0 73 | 74 | for data in test_loader: 75 | data = data.to(device) 76 | pred = model(data).max(1)[1] 77 | correct += pred.eq(data.y).sum().item() 78 | return correct / len(test_dataset) 79 | 80 | 81 | for epoch in range(1, 21): 82 | train(epoch) 83 | test_acc = test() 84 | print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc)) 85 | -------------------------------------------------------------------------------- /docs/source/notes/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | We have outsourced a lot of functionality of `PyTorch Geometric `_ to other packages, which needs to be installed in advance. 5 | These packages come with their own CPU and GPU kernel implementations based on the newly introduced `C++/CUDA extensions `_ in PyTorch 0.4.0. 6 | 7 | .. note:: 8 | We do not recommend installation as root user on your system python. 9 | Please setup an `Anaconda/Miniconda `_ environment or create a `Docker image `_. 10 | 11 | Please follow the steps below for a successful installation: 12 | 13 | #. Ensure that at least PyTorch 0.4.1 is installed: 14 | 15 | .. code-block:: none 16 | 17 | $ python -c "import torch; print(torch.__version__)" 18 | >>> 0.4.1 19 | 20 | #. Ensure CUDA is setup correctly (optional): 21 | 22 | #. Check if PyTorch is installed with CUDA support: 23 | 24 | .. code-block:: none 25 | 26 | $ python -c "import torch; print(torch.cuda.is_available())" 27 | >>> True 28 | 29 | #. 
Add CUDA to ``$PATH`` and ``$CPATH`` (note that your actual CUDA path may vary from ``/usr/local/cuda``): 30 | 31 | .. code-block:: none 32 | 33 | $ PATH=/usr/local/cuda/bin:$PATH 34 | $ echo $PATH 35 | >>> /usr/local/cuda/bin:... 36 | 37 | $ CPATH=/usr/local/cuda/include:$CPATH 38 | $ echo $CPATH 39 | >>> /usr/local/cuda/include:... 40 | 41 | #. Verify that ``nvcc`` is accessible from terminal: 42 | 43 | .. code-block:: none 44 | 45 | $ nvcc --version 46 | 47 | #. Install all needed packages: 48 | 49 | .. code-block:: none 50 | 51 | $ pip install --upgrade torch-scatter 52 | $ pip install --upgrade torch-sparse 53 | $ pip install --upgrade torch-cluster 54 | $ pip install --upgrade torch-spline-conv 55 | $ pip install torch-geometric 56 | 57 | In rare cases, CUDA or python path issues can prevent a succesful installation. 58 | Unfortunately, the error messages of ``pip`` are not very meaningful. 59 | You should therefore clone the respective package and check where the error occurs, e.g.: 60 | 61 | .. code-block:: none 62 | 63 | $ git clone https://github.com/rusty1s/pytorch_scatter 64 | $ cd pytorch_scatter 65 | $ python setup.py install # Check for CUDA compilation or link error. 66 | $ python setup.py test # Verify installation by running test suite. 67 | 68 | C++/CUDA extensions on macOS 69 | ---------------------------- 70 | 71 | .. note:: 72 | As reported by some users, the use of miniconda is absolutely necessary on macOS. 73 | 74 | In order to compile CUDA extensions on macOS, you need to replace the call 75 | 76 | .. code-block:: python 77 | 78 | def spawn(self, cmd): 79 | spawn(cmd, dry_run=self.dry_run) 80 | 81 | with 82 | 83 | .. code-block:: python 84 | 85 | import subprocess 86 | 87 | def spawn(self, cmd): 88 | subprocess.call(cmd) 89 | 90 | in ``lib/python{xxx}/distutils/ccompiler.py``. 
91 | -------------------------------------------------------------------------------- /docs/build/html/_sources/notes/installation.rst.txt: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | We have outsourced a lot of functionality of `PyTorch Geometric `_ to other packages, which needs to be installed in advance. 5 | These packages come with their own CPU and GPU kernel implementations based on `C FFI `_ and the newly introduced `C++/CUDA extensions `_ in PyTorch 0.4.0. 6 | 7 | .. note:: 8 | We do not recommend installation as root user on your system python. 9 | Please setup an `Anaconda/Miniconda `_ environment or create a `Docker image `_. 10 | 11 | Please follow the steps below for a successful installation: 12 | 13 | #. Ensure that at least PyTorch 0.4.1 is installed: 14 | 15 | .. code-block:: none 16 | 17 | $ python -c "import torch; print(torch.__version__)" 18 | >>> 0.4.1 19 | 20 | #. Ensure CUDA is setup correctly (optional): 21 | 22 | #. Check if PyTorch is installed with CUDA support: 23 | 24 | .. code-block:: none 25 | 26 | $ python -c "import torch; print(torch.cuda.is_available())" 27 | >>> True 28 | 29 | #. Add CUDA to ``$PATH`` and ``$CPATH`` (note that your actual CUDA path may vary from ``/usr/local/cuda``): 30 | 31 | .. code-block:: none 32 | 33 | $ PATH=/usr/local/cuda/bin:$PATH 34 | $ echo $PATH 35 | >>> /usr/local/cuda/bin:... 36 | 37 | $ CPATH=/usr/local/cuda/include:$CPATH 38 | $ echo $CPATH 39 | >>> /usr/local/cuda/include:... 40 | 41 | #. Verify that ``nvcc`` is accessible from terminal: 42 | 43 | .. code-block:: none 44 | 45 | $ nvcc --version 46 | 47 | #. Install all needed packages: 48 | 49 | .. 
code-block:: none 50 | 51 | $ pip install cffi 52 | $ pip install --upgrade torch-scatter 53 | $ pip install --upgrade torch-sparse 54 | $ pip install --upgrade torch-cluster 55 | $ pip install --upgrade torch-spline-conv 56 | $ pip install torch-geometric 57 | 58 | In rare cases, CUDA or python path issues can prevent a succesful installation. 59 | Unfortunately, the error messages of ``pip`` are not very meaningful. 60 | You should therefore clone the respective package and check where the error occurs, e.g.: 61 | 62 | .. code-block:: none 63 | 64 | $ git clone https://github.com/rusty1s/pytorch_scatter 65 | $ cd pytorch_scatter 66 | $ python setup.py install # Check for CUDA compilation or link error. 67 | $ python setup.py test # Verify installation by running test suite. 68 | 69 | C++/CUDA extensions on macOS 70 | ---------------------------- 71 | 72 | In order to compile C++/CUDA extensions on macOS, you need to replace the call 73 | 74 | .. code-block:: python 75 | 76 | def spawn(self, cmd): 77 | spawn(cmd, dry_run=self.dry_run) 78 | 79 | with 80 | 81 | .. code-block:: python 82 | 83 | def spawn(self, cmd): 84 | subprocess.call(cmd) 85 | 86 | in ``distutils/ccompiler.py``. 87 | Do not forget to ``import subprocess`` at the top of the file. 88 | -------------------------------------------------------------------------------- /torch_geometric/nn/conv/nn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.utils import scatter_ 4 | 5 | from ..inits import reset, uniform 6 | 7 | 8 | class NNConv(torch.nn.Module): 9 | r"""The continuous kernel-based convolutional operator adapted from the 10 | `"Neural Message Passing for Quantum Chemistry" 11 | `_ paper 12 | 13 | .. 
math:: 14 | \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{(1)} \mathbf{x}_i + 15 | \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot 16 | h_{\mathbf{\Theta^{(2)}}}(\mathbf{e}_{i,j}), 17 | 18 | where :math:`h_{\mathbf{\Theta^{(2)}}}` denotes a neural network, *.i.e.* 19 | a MLP. 20 | 21 | Args: 22 | in_channels (int): Size of each input sample. 23 | out_channels (int): Size of each output sample. 24 | nn (nn.Sequential): Neural network. 25 | aggr (string): The aggregation operator to use (one of :obj:`"add"`, 26 | :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) 27 | root_weight (bool, optional): If set to :obj:`False`, the layer will 28 | not add the transformed root node features to the output. 29 | (default: :obj:`True`) 30 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 31 | an additive bias. (default: :obj:`True`) 32 | """ 33 | 34 | def __init__(self, 35 | in_channels, 36 | out_channels, 37 | nn, 38 | aggr="add", 39 | root_weight=True, 40 | bias=True): 41 | super(NNConv, self).__init__() 42 | 43 | self.in_channels = in_channels 44 | self.out_channels = out_channels 45 | self.nn = nn 46 | self.aggr = aggr 47 | 48 | if root_weight: 49 | self.root = Parameter(torch.Tensor(in_channels, out_channels)) 50 | else: 51 | self.register_parameter('root', None) 52 | 53 | if bias: 54 | self.bias = Parameter(torch.Tensor(out_channels)) 55 | else: 56 | self.register_parameter('bias', None) 57 | 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | reset(self.nn) 62 | size = self.in_channels 63 | uniform(size, self.root) 64 | uniform(size, self.bias) 65 | 66 | def forward(self, x, edge_index, pseudo): 67 | """""" 68 | x = x.unsqueeze(-1) if x.dim() == 1 else x 69 | pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo 70 | row, col = edge_index 71 | 72 | out = self.nn(pseudo) 73 | out = out.view(-1, self.in_channels, self.out_channels) 74 | out = torch.matmul(x[col].unsqueeze(1), out).squeeze(1) 75 | out = scatter_(self.aggr, 
out, row, dim_size=x.size(0)) 76 | 77 | if self.root is not None: 78 | out = out + torch.mm(x, self.root) 79 | 80 | if self.bias is not None: 81 | out = out + self.bias 82 | 83 | return out 84 | 85 | def __repr__(self): 86 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 87 | self.out_channels) 88 | -------------------------------------------------------------------------------- /docs/build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../font/fontawesome_webfont.eot");src:url("../font/fontawesome_webfont.eot?#iefix") format("embedded-opentype"),url("../font/fontawesome_webfont.woff") format("woff"),url("../font/fontawesome_webfont.ttf") format("truetype"),url("../font/fontawesome_webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:0.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;border-top:solid 10px #343131;font-family:"Lato","proxima-nova","Helvetica 
Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | /*# sourceMappingURL=badge_only.css.map */ 3 | -------------------------------------------------------------------------------- 
# ==============================================================================
# NOTE: This span belongs to a concatenated repository dump. Each section
# below reproduces one source file and is introduced by a comment naming
# its path.
# ==============================================================================

# ------------------------------------------------------------------------------
# /examples/mnist_graclus.py
# ------------------------------------------------------------------------------
# MNIST superpixel classification with SplineConv layers and graclus pooling.
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.utils import normalized_cut
from torch_geometric.nn import (SplineConv, graclus, max_pool, max_pool_x,
                                global_mean_pool)

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
train_dataset = MNISTSuperpixels(path, True, transform=T.Cartesian())
test_dataset = MNISTSuperpixels(path, False, transform=T.Cartesian())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
# Collated dataset storage; only used for `num_features` / `num_classes`.
d = train_dataset.data


def normalized_cut_2d(edge_index, pos):
    """Return normalized-cut edge weights derived from node distances.

    The length of each edge is the L2 distance between the positions of its
    endpoints; `normalized_cut` turns these lengths into edge weights that
    are fed to `graclus` below.
    """
    row, col = edge_index
    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))


class Net(torch.nn.Module):
    """SplineConv -> graclus max-pool -> SplineConv -> graclus max-pool ->
    global mean pool -> two-layer MLP with log-softmax output."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5)
        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Coarsen the graph; recompute Cartesian edge attributes for the new
        # (pooled) edges instead of concatenating onto the old ones.
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Only node features are needed after the second pooling step.
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    """Run one training epoch.

    The learning rate is lowered to 1e-3 at epoch 16 and to 1e-4 at epoch
    26; the new value persists in the optimizer for later epochs.
    """
    model.train()

    if epoch == 16:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    if epoch == 26:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        F.nll_loss(model(data), data.y).backward()
        optimizer.step()


def test():
    """Return the classification accuracy on the test set."""
    model.eval()
    correct = 0

    for data in test_loader:
        data = data.to(device)
        pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)


for epoch in range(1, 31):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))

# ------------------------------------------------------------------------------
# /torch_geometric/data/in_memory_dataset.py
# ------------------------------------------------------------------------------
from itertools import repeat, product

import torch
from torch_geometric.data import Dataset, Data


class InMemoryDataset(Dataset):
    """Dataset whose samples are all held in memory.

    All samples are collated into one :class:`Data` object (`self.data`).
    `self.slices[key]` holds the cumulative boundaries of each sample along
    that attribute's concatenation dimension. Subclasses provide the file
    names and implement `download` / `process`.
    """

    @property
    def raw_file_names(self):
        raise NotImplementedError

    @property
    def processed_file_names(self):
        raise NotImplementedError

    def download(self):
        raise NotImplementedError

    def process(self):
        raise NotImplementedError

    def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        super(InMemoryDataset, self).__init__(root, transform, pre_transform,
                                              pre_filter)
        self.data, self.slices = None, None

    @property
    def num_features(self):
        return self[0].num_features

    @property
    def num_classes(self):
        data = self.data
        return data.y.max().item() + 1 if data.y.dim() == 1 else data.y.size(1)

    def __len__(self):
        # N samples produce N + 1 boundaries for every key; any key will do.
        return self.slices[list(self.slices.keys())[0]].size(0) - 1

    def __getitem__(self, idx):
        """Index the dataset.

        An `int` returns a single sample (with `transform` applied). A slice,
        a LongTensor of indices or a ByteTensor mask returns a new dataset
        restricted to the selected samples.

        Raises:
            IndexError: For any other index type.
        """
        if isinstance(idx, int):
            data = self.get(idx)
            data = data if self.transform is None else self.transform(data)
            return data
        elif isinstance(idx, slice):
            return self.split(range(*idx.indices(len(self))))
        elif isinstance(idx, torch.LongTensor):
            return self.split(idx)
        elif isinstance(idx, torch.ByteTensor):
            # `nonzero()` returns an (N, 1) tensor; flatten it so that every
            # element passed on to `get` is a scalar index.
            return self.split(idx.nonzero().view(-1))

        raise IndexError(
            'Only integers, slices (`:`) and long or byte tensors are valid '
            'indices (got {}).'.format(type(idx).__name__))

    def shuffle(self):
        """Return a copy of the dataset with its samples randomly permuted."""
        return self.split(torch.randperm(len(self)))

    def get(self, idx):
        """Rebuild sample `idx` by slicing every stored attribute along its
        concatenation dimension."""
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[self.data.cat_dim(key, item)] = slice(slices[idx],
                                                    slices[idx + 1])
            # Index with an explicit tuple; indexing with a list of slices
            # relies on deprecated sequence-as-tuple heuristics.
            data[key] = item[tuple(s)]
        return data

    def split(self, index):
        """Return a shallow copy of the dataset that contains (re-collated)
        only the samples at `index`."""
        copy = self.__class__.__new__(self.__class__)
        copy.__dict__ = self.__dict__.copy()
        copy.data, copy.slices = self.collate([self.get(i) for i in index])
        return copy

    def collate(self, data_list):
        """Concatenate `data_list` into one :class:`Data` object.

        The keys of the first element are used for all elements.

        Returns:
            tuple: `(data, slices)`, where `slices[key]` is a LongTensor of
            cumulative sample boundaries starting at 0.
        """
        keys = data_list[0].keys
        data = Data()

        for key in keys:
            data[key] = []
        slices = {key: [0] for key in keys}

        for item, key in product(data_list, keys):
            data[key].append(item[key])
            s = slices[key][-1] + item[key].size(item.cat_dim(key, item[key]))
            slices[key].append(s)

        for key in keys:
            data[key] = torch.cat(
                data[key], dim=data_list[0].cat_dim(key, data_list[0][key]))
            slices[key] = torch.LongTensor(slices[key])

        return data, slices

# ------------------------------------------------------------------------------
# /examples/enzymes_topk_pool.py
# ------------------------------------------------------------------------------
# ENZYMES graph classification with GraphConv layers and TopK pooling.
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ENZYMES')
dataset = TUDataset(path, name='ENZYMES')
dataset = dataset.shuffle()
# The first 10% of the shuffled graphs form the test split.
n = len(dataset) // 10
test_dataset = dataset[:n]
train_dataset = dataset[n:]
test_loader = DataLoader(test_dataset, batch_size=60)
train_loader = DataLoader(train_dataset, batch_size=60)


class Net(torch.nn.Module):
    """Three GraphConv + TopKPooling stages. After every stage a graph-level
    readout (global max pool concatenated with global mean pool) is taken;
    the three readouts are summed and classified by a three-layer MLP."""

    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = GraphConv(dataset.num_features, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)

        # 256 = 128 (max readout) + 128 (mean readout).
        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, dataset.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.log_softmax(self.lin3(x), dim=-1)

        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)


def train(epoch):
    """Run one training epoch and return the per-graph average loss."""
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        # `loss` is a batch mean; weight it by the number of graphs.
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
    return loss_all / len(train_dataset)


def test(loader):
    """Return the classification accuracy over all graphs in `loader`."""
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


for epoch in range(1, 201):
    loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.5f}, Test Acc: {:.5f}'.
          format(epoch, loss, train_acc, test_acc))

# ------------------------------------------------------------------------------
# /examples/mnist_mpnn.py
# ------------------------------------------------------------------------------
# MNIST superpixel classification with edge-conditioned NNConv layers and
# graclus pooling.
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.utils import normalized_cut
from torch_geometric.nn import (NNConv, graclus, max_pool, max_pool_x,
                                global_mean_pool)

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
train_dataset = MNISTSuperpixels(path, True, transform=T.Cartesian())
test_dataset = MNISTSuperpixels(path, False, transform=T.Cartesian())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
# Collated dataset storage; only used for `num_features` / `num_classes`.
d = train_dataset.data


def normalized_cut_2d(edge_index, pos):
    """Return normalized-cut edge weights derived from node distances.

    The length of each edge is the L2 distance between the positions of its
    endpoints; `normalized_cut` turns these lengths into edge weights that
    are fed to `graclus` below.
    """
    row, col = edge_index
    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))


class Net(nn.Module):
    """NNConv -> graclus max-pool -> NNConv -> graclus max-pool ->
    global mean pool -> two-layer MLP with log-softmax output."""

    def __init__(self):
        super(Net, self).__init__()
        # Edge networks map edge attributes (input size 2, presumably the 2-D
        # Cartesian pseudo-coordinates) to flattened weight matrices. 32
        # equals in * out only if d.num_features == 1 -- TODO confirm.
        n1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(), nn.Linear(25, 32))
        self.conv1 = NNConv(d.num_features, 32, n1)

        # 2048 = 32 * 64 (in_channels * out_channels of conv2).
        n2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(), nn.Linear(25, 2048))
        self.conv2 = NNConv(32, 64, n2)

        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Coarsen the graph; recompute Cartesian edge attributes for the new
        # (pooled) edges instead of concatenating onto the old ones.
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    """Run one training epoch.

    The learning rate is lowered to 1e-3 at epoch 16 and to 1e-4 at epoch
    26; the new value persists in the optimizer for later epochs.
    """
    model.train()

    if epoch == 16:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    if epoch == 26:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        F.nll_loss(model(data), data.y).backward()
        optimizer.step()


def test():
    """Return the classification accuracy on the test set."""
    model.eval()
    correct = 0

    for data in test_loader:
        data = data.to(device)
        pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)


for epoch in range(1, 31):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))

# ------------------------------------------------------------------------------
# /examples/faust.py
# ------------------------------------------------------------------------------
# FAUST shape correspondence: every mesh vertex is classified as its own
# vertex index, using six SplineConv layers on 3-D pseudo-coordinates.
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import FAUST
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import SplineConv
from torch_geometric.utils import degree


class MyTransform(object):
    """Drop mesh faces and give every node a single constant feature 1."""

    def __call__(self, data):
        data.face, data.x = None, torch.ones(data.num_nodes, 1)
        return data


def norm(x, edge_index):
    """Divide each node's features by 1 + the number of edges whose source
    is that node."""
    deg = degree(edge_index[0], x.size(0), x.dtype, x.device) + 1
    return x / deg.unsqueeze(-1)


path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FAUST')
pre_transform = T.Compose([T.FaceToEdge(), MyTransform()])
train_dataset = FAUST(path, True, T.Cartesian(), pre_transform)
test_dataset = FAUST(path, False, T.Cartesian(), pre_transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)
d = train_dataset[0]


class Net(torch.nn.Module):
    """Six SplineConv layers (each followed by degree normalization and ELU),
    then a two-layer MLP with one output class per mesh vertex."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, norm=False)
        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, norm=False)
        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
        self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
        self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
        self.conv6 = SplineConv(64, 64, dim=3, kernel_size=5, norm=False)
        self.fc1 = torch.nn.Linear(64, 256)
        self.fc2 = torch.nn.Linear(256, d.num_nodes)

    def forward(self, data):
        x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr
        x = F.elu(norm(self.conv1(x, edge_index, pseudo), edge_index))
        x = F.elu(norm(self.conv2(x, edge_index, pseudo), edge_index))
        x = F.elu(norm(self.conv3(x, edge_index, pseudo), edge_index))
        x = F.elu(norm(self.conv4(x, edge_index, pseudo), edge_index))
        x = F.elu(norm(self.conv5(x, edge_index, pseudo), edge_index))
        x = F.elu(norm(self.conv6(x, edge_index, pseudo), edge_index))
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
# The target class of vertex i is i itself.
target = torch.arange(d.num_nodes, dtype=torch.long, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    """Run one training epoch; the learning rate drops to 1e-3 at epoch 61
    and stays there."""
    model.train()

    if epoch == 61:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    for data in train_loader:
        optimizer.zero_grad()
        F.nll_loss(model(data.to(device)), target).backward()
        optimizer.step()


def test():
    """Return the fraction of correctly labelled vertices over all test
    meshes."""
    model.eval()
    correct = 0

    for data in test_loader:
        pred = model(data.to(device)).max(1)[1]
        correct += pred.eq(target).sum().item()
    return correct / (len(test_dataset) * d.num_nodes)


for epoch in range(1, 101):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))

# ------------------------------------------------------------------------------
# /torch_geometric/read/planetoid.py
# ------------------------------------------------------------------------------
import sys
import os.path as osp
from itertools import repeat

import torch
from torch_sparse import coalesce
from torch_geometric.data import Data
from torch_geometric.read import read_txt_array
from torch_geometric.utils import remove_self_loops

try:
    import cPickle as pickle
except ImportError:
    import pickle


def read_planetoid_data(folder, prefix):
    """Reads the Planetoid data format.

    Loads the files ``ind.{prefix}.{name}`` from `folder` for `name` in
    ``x``, ``tx``, ``allx``, ``y``, ``ty``, ``ally``, ``graph`` and
    ``test.index``, and assembles them into a single :class:`Data` object.

    The first ``y.size(0)`` nodes are marked as training nodes, the next 500
    as validation nodes, and the nodes listed in ``test.index`` as test
    nodes. These masks are stored as ``train_mask``, ``val_mask`` and
    ``test_mask``.
    """
    names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
    items = [read_file(folder, prefix, name) for name in names]
    x, tx, allx, y, ty, ally, graph, test_index = items
    train_index = torch.arange(y.size(0), dtype=torch.long)
    val_index = torch.arange(y.size(0), y.size(0) + 500, dtype=torch.long)
    sorted_test_index = test_index.sort()[0]

    if prefix.lower() == 'citeseer':
        # There are some isolated nodes in the Citeseer graph, resulting in
        # non-consecutive test indices. We need to identify them and add them
        # as zero vectors to `tx` and `ty`.
        len_test_indices = (test_index.max() - test_index.min()).item() + 1

        tx_ext = torch.zeros(len_test_indices, tx.size(1))
        tx_ext[sorted_test_index - test_index.min(), :] = tx
        ty_ext = torch.zeros(len_test_indices, ty.size(1))
        ty_ext[sorted_test_index - test_index.min(), :] = ty

        tx, ty = tx_ext, ty_ext

    x = torch.cat([allx, tx], dim=0)
    # Labels are stored one-hot; convert them to class indices.
    y = torch.cat([ally, ty], dim=0).max(dim=1)[1]

    # Test rows are stored in sorted order; move them to the positions given
    # by the (unsorted) test index file.
    x[test_index] = x[sorted_test_index]
    y[test_index] = y[sorted_test_index]

    train_mask = sample_mask(train_index, num_nodes=y.size(0))
    val_mask = sample_mask(val_index, num_nodes=y.size(0))
    test_mask = sample_mask(test_index, num_nodes=y.size(0))

    edge_index = edge_index_from_dict(graph, num_nodes=y.size(0))

    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


def read_file(folder, prefix, name):
    """Load the file ``ind.{prefix.lower()}.{name}`` from `folder`.

    ``test.index`` is read as a text array of longs. ``graph`` is returned
    as the unpickled object. All other entries are densified (when they
    provide ``todense``) and converted to a float tensor.
    """
    path = osp.join(folder, 'ind.{}.{}'.format(prefix.lower(), name))

    if name == 'test.index':
        return read_txt_array(path, dtype=torch.long)

    # NOTE(review): `pickle.load` can execute arbitrary code; only use this
    # on trusted dataset files.
    with open(path, 'rb') as f:
        if sys.version_info > (3, 0):
            # The files were pickled under Python 2.
            out = pickle.load(f, encoding='latin1')
        else:
            out = pickle.load(f)

    if name == 'graph':
        return out

    out = out.todense() if hasattr(out, 'todense') else out
    out = torch.Tensor(out)
    return out


def edge_index_from_dict(graph_dict, num_nodes=None):
    """Convert an adjacency dict (node -> list of neighbours) into a
    coalesced `edge_index` without self loops."""
    row, col = [], []
    for key, value in graph_dict.items():
        row += repeat(key, len(value))
        col += value
    edge_index = torch.stack([torch.tensor(row), torch.tensor(col)], dim=0)
    # NOTE: There are duplicated edges and self loops in the datasets. Other
    # implementations do not remove them!
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
    return edge_index


def sample_mask(index, num_nodes):
    """Return a uint8 mask of length `num_nodes` that is 1 at `index`."""
    mask = torch.zeros((num_nodes, ), dtype=torch.uint8)
    mask[index] = 1
    return mask

# ------------------------------------------------------------------------------
# /torch_geometric/nn/pool/topk_pool.py
# ------------------------------------------------------------------------------
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add

from ..inits import uniform
from ...utils.num_nodes import maybe_num_nodes


def topk(x, ratio, batch):
    """Select the highest-scoring nodes of every graph.

    Args:
        x (Tensor): One score per node.
        ratio (float): Fraction of nodes to keep per graph (rounded). Every
            non-empty graph keeps at least one node.
        batch (LongTensor): Graph assignment of every node. Nodes of the same
            graph are assumed to be stored contiguously and in graph order.

    Returns:
        LongTensor: Indices of the kept nodes, grouped by graph and sorted by
        descending score within each graph.
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
    k = (ratio * num_nodes.to(torch.float)).round().to(torch.long)
    # Rounding can yield k == 0 (e.g. ratio 0.5 on a 1-node graph), which
    # would silently drop the whole graph from the batch. Keep at least one
    # node, but never more than the graph actually has.
    k = torch.min(k.clamp(min=1), num_nodes)

    cum_num_nodes = torch.cat(
        [num_nodes.new_zeros(1),
         num_nodes.cumsum(dim=0)[:-1]], dim=0)

    # Scatter the scores into a dense [batch_size, max_num_nodes] matrix:
    # row = graph, column = position of the node within its graph.
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

    # Padding slots get a value below every real score so they sort last.
    x_min_value = x.min().item()
    dense_x = x.new_full((batch_size * max_num_nodes, ), x_min_value - 1)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)
    _, perm = dense_x.sort(dim=-1, descending=True)

    # Map within-graph positions back to global node indices.
    perm = perm + cum_num_nodes.view(-1, 1)
    perm = perm.view(-1)

    # Keep the first k[i] entries of every row of the flattened matrix.
    mask = [
        torch.arange(k[i], dtype=torch.long, device=x.device) +
        i * max_num_nodes for i in range(len(num_nodes))
    ]
    mask = torch.cat(mask, dim=0)

    perm = perm[mask]

    return perm


def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
    """Keep only edges whose endpoints are both in `perm`.

    Kept nodes are relabelled to their position in `perm`; `edge_attr` (if
    given) is filtered alongside.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # mask[old_index] = new_index, or -1 for removed nodes.
    mask = perm.new_full((num_nodes, ), -1)
    i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
    mask[perm] = i

    row, col = edge_index
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]

    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    return torch.stack([row, col], dim=0), edge_attr


class TopKPooling(torch.nn.Module):
    r""":math:`\mathrm{top}_k` pooling from the `"Graph U-Net"
    `_ paper.

    Nodes are scored by projecting their features onto a learnable vector;
    the top-scoring fraction of every graph is kept and gated by
    :math:`\tanh(\mathrm{score})`. The forward pass returns
    ``(x, edge_index, edge_attr, batch, perm)``.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio. (default: :obj:`0.5`)
    """

    def __init__(self, in_channels, ratio=0.5):
        super(TopKPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio

        self.weight = Parameter(torch.Tensor(1, in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        size = self.in_channels
        uniform(size, self.weight)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        # Without `batch`, treat all nodes as one graph.
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # Projection onto the weight vector, normalized by its L2 norm.
        score = (x * self.weight).sum(dim=-1)
        score = score / self.weight.norm(p=2, dim=-1)
        perm = topk(score, self.ratio, batch)
        # Gating by tanh(score) lets gradients reach `self.weight`.
        x = x[perm] * torch.tanh(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.ratio)

# ------------------------------------------------------------------------------
# /torch_geometric/nn/conv/cheb_conv.py
# ------------------------------------------------------------------------------
import torch
from torch.nn import Parameter
from torch_sparse import spmm
from torch_geometric.utils import degree, remove_self_loops

from ..inits import uniform


class ChebConv(torch.nn.Module):
    r"""The Chebyshev spectral graph convolutional operator from the
    `"Convolutional Neural Networks on Graphs with Fast Localized Spectral
    Filtering" `_ paper.

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^{K-1} \mathbf{\hat{X}}_k \cdot
        \mathbf{\Theta}_k

    where :math:`\mathbf{\hat{X}}_k` is computed recursively by

    .. math::
        \mathbf{\hat{X}}_0 &= \mathbf{X}

        \mathbf{\hat{X}}_1 &= \mathbf{\hat{L}} \cdot \mathbf{X}

        \mathbf{\hat{X}}_k &= 2 \cdot \mathbf{\hat{L}} \cdot
        \mathbf{\hat{X}}_{k-1} - \mathbf{\hat{X}}_{k-2}

    and :math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        K (int): Chebyshev filter size, *i.e.* number of hops.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(self, in_channels, out_channels, K, bias=True):
        super(ChebConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # One [in_channels, out_channels] filter per Chebyshev order.
        self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        size = self.in_channels * self.weight.size(0)
        uniform(size, self.weight)
        uniform(size, self.bias)

    def forward(self, x, edge_index, edge_attr=None):
        """"""
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        row, col = edge_index
        num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)

        # Unweighted graphs use edge weight 1.
        if edge_attr is None:
            edge_attr = x.new_ones((num_edges, ))
        assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

        # Allocate the degree vector on the same device as `x` / `row`.
        deg = degree(row, num_nodes, x.dtype, x.device)

        # Compute normalized and rescaled Laplacian. With the largest
        # eigenvalue approximated by 2, 2 * L / lambda_max - I reduces to
        # -D^{-1/2} A D^{-1/2}.
        deg = deg.pow(-0.5)
        # Isolated nodes have degree 0 -> inf after pow(-0.5); zero them.
        deg[deg == float('inf')] = 0
        lap = -deg[row] * edge_attr * deg[col]

        # Perform filter operation recurrently.
        Tx_0 = x
        out = torch.mm(Tx_0, self.weight[0])

        if K > 1:
            Tx_1 = spmm(edge_index, lap, num_nodes, x)
            out = out + torch.mm(Tx_1, self.weight[1])

        for k in range(2, K):
            Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
            out = out + torch.mm(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out

    def __repr__(self):
        # `weight` stores exactly K filters, so its first dimension is K.
        return '{}({}, {}, K={})'.format(self.__class__.__name__,
                                         self.in_channels, self.out_channels,
                                         self.weight.size(0))

# ------------------------------------------------------------------------------
# /torch_geometric/data/data.py
# ------------------------------------------------------------------------------
import torch
from torch_geometric.utils import (contains_isolated_nodes,
                                   contains_self_loops, is_undirected)

from ..utils.num_nodes import maybe_num_nodes


class Data(object):
    """Attribute container for graph data.

    The constructor sets ``x``, ``edge_index``, ``edge_attr``, ``y`` and
    ``pos``. Further attributes can be attached through attribute or item
    assignment. Every instance attribute that is not ``None`` counts as a
    key.
    """

    def __init__(self,
                 x=None,
                 edge_index=None,
                 edge_attr=None,
                 y=None,
                 pos=None):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos

    @staticmethod
    def from_dict(dictionary):
        """Create a :class:`Data` object from a ``{key: item}`` dict."""
        data = Data()
        for key, item in dictionary.items():
            data[key] = item
        return data

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, item):
        setattr(self, key, item)

    @property
    def keys(self):
        """Names of all attributes that are not ``None``."""
        return [key for key in self.__dict__.keys() if self[key] is not None]

    def __len__(self):
        return len(self.keys)

    def __contains__(self, key):
        return key in self.keys

    def __iter__(self):
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        """Yield ``(key, item)`` pairs for `keys` (all keys in sorted order
        if none are given), skipping attributes that are ``None``."""
        for key in sorted(self.keys) if not keys else keys:
            if self[key] is not None:
                yield key, self[key]

    def cat_dim(self, key, item):
        """Dimension along which `item` is concatenated: -1 for long
        tensors, 0 otherwise."""
        return -1 if item.dtype == torch.long else 0

    @property
    def num_nodes(self):
        """Taken from ``x`` or ``pos`` if present, otherwise inferred from
        ``edge_index``; ``None`` if none of them is set."""
        for key, item in self('x', 'pos'):
            return item.size(self.cat_dim(key, item))
        if self.edge_index is not None:
            return maybe_num_nodes(self.edge_index)
        return None

    @property
    def num_edges(self):
        for key, item in self('edge_index', 'edge_attr'):
            return item.size(self.cat_dim(key, item))
        return None

    @property
    def num_features(self):
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_classes(self):
        return self.y.max().item() + 1 if self.y.dim() == 1 else self.y.size(1)

    def is_coalesced(self):
        """Return ``True`` if ``edge_index`` holds no duplicate edges.

        Only duplicates are checked; the edge ordering is not.
        """
        row, col = self.edge_index
        # Encode every (row, col) pair as a unique scalar.
        index = self.num_nodes * row + col
        return row.size(0) == torch.unique(index).size(0)

    def contains_isolated_nodes(self):
        return contains_isolated_nodes(self.edge_index, self.num_nodes)

    def contains_self_loops(self):
        return contains_self_loops(self.edge_index)

    def is_undirected(self):
        return is_undirected(self.edge_index, self.num_nodes)

    def is_directed(self):
        return not self.is_undirected()

    def apply(self, func, *keys):
        """Replace each selected attribute (all by default) with
        ``func(item)`` and return `self`."""
        for key, item in self(*keys):
            self[key] = func(item)
        return self

    def contiguous(self, *keys):
        return self.apply(lambda x: x.contiguous(), *keys)

    def to(self, device, *keys):
        return self.apply(lambda x: x.to(device), *keys)

    def __repr__(self):
        info = ['{}={}'.format(key, list(item.size())) for key, item in self]
        return '{}({})'.format(self.__class__.__name__, ', '.join(info))

# ------------------------------------------------------------------------------