├── images └── rgcn.png ├── LICENSE ├── README.md ├── helper └── utils.py ├── .gitignore └── environment.yml /images/rgcn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dobraczka/GNNTutorial/HEAD/images/rgcn.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Daniel Obraczka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An intro to Graph Neural Networks 🕸️🧠 2 | 3 | Graph Neural Networks have seen a rise in popularity. 
This is no surprise, since many kinds of information, from social networks to molecules, can naturally be represented as graphs. 4 | This notebook intends to illuminate the inner workings of [Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) and to give an intuition for some other types of networks that extend this idea. 5 | 6 | # Environment and Kernel Setup 7 | 8 | If you work on the clara/paula cluster, [load](https://www.sc.uni-leipzig.de/user-doc/quickstart/hpc/#use-preinstalled-software) Python: 9 | ``` 10 | ml Python/3.9.5-GCCcore-10.3.0 11 | ``` 12 | or Anaconda: 13 | 14 | ``` 15 | ml Anaconda3/2021.11 16 | ``` 17 | 18 | and then create the environment with the required dependencies: 19 | 20 | ``` 21 | conda env create -n "PyG" -f environment.yml 22 | ``` 23 | 24 | Activate the environment: 25 | 26 | ``` 27 | conda activate PyG 28 | ``` 29 | 30 | Now create a kernel to use in the Jupyter Notebook: 31 | 32 | ``` 33 | ipython kernel install --user --name "PyG" --display-name "PyG" 34 | ``` 35 | 36 | Now you can go to JupyterLab, select the kernel `PyG`, and run the notebook.
# --------------------------------------------------------------------------------
# /helper/utils.py:
# --------------------------------------------------------------------------------
"""Plotting helpers for drawing small graphs together with their node attributes."""
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


def nudge(
    pos: Dict[int, np.ndarray], x_shift: float, y_shift: float
) -> Dict[int, Tuple[float, float]]:
    """Return a copy of ``pos`` with every coordinate shifted by a fixed offset.

    Used to place attribute labels slightly away from the nodes they
    describe, so that node labels and attribute labels do not overlap.

    Args:
        pos: Mapping from node id to a 2D ``(x, y)`` position, e.g. the
            output of a networkx layout function such as ``nx.spring_layout``.
        x_shift: Offset added to every x coordinate.
        y_shift: Offset added to every y coordinate.

    Returns:
        A new dict with the same keys, mapping each node to its shifted
        ``(x, y)`` tuple. ``pos`` itself is not modified.
    """
    return {n: (x + x_shift, y + y_shift) for n, (x, y) in pos.items()}


def draw_graph_with_attributes(
    G: nx.Graph,
    props: Optional[Dict[int, Any]] = None,
    pos: Optional[Dict[int, np.ndarray]] = None,
    figsize: Tuple[int, int] = (8, 8),
    x_nudge: float = 0.0,
    y_nudge: float = 0.07,
    ax: Optional[mpl.axes.Axes] = None,
    font_color: str = "green",
    edge_color: str = "white",
    node_color: Union[str, List] = "blue",
):
    """Draw a graph with node-id labels and, offset from each node, its attributes.

    Args:
        G: Graph to draw.
        props: Mapping from node id to the value displayed next to that node.
            Each value is converted with ``np.array`` and rendered via
            ``np.array2string`` with two-decimal precision. Defaults to the
            nodes' ``"x"`` attribute (``nx.get_node_attributes(G, "x")``).
        pos: Node positions. Defaults to ``nx.spring_layout(G)``.
        figsize: Size of the figure created when ``ax`` is None; ignored
            otherwise.
        x_nudge: Horizontal offset of the attribute labels from the nodes.
        y_nudge: Vertical offset of the attribute labels from the nodes.
        ax: Axes to draw on. A new figure with a single axes is created if None.
        font_color: Color of the attribute labels.
        edge_color: Color of the edges.
        node_color: A single color, or a per-node list of colors.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if pos is None:
        pos = nx.spring_layout(G)

    # Nodes with their ids as labels.
    nx.draw_networkx(
        G,
        pos=pos,
        with_labels=True,
        ax=ax,
        edge_color=edge_color,
        node_color=node_color,
    )

    # Attribute labels, drawn at positions offset from the nodes.
    if props is None:
        props = nx.get_node_attributes(G, "x")
    attribute_labels = {
        node_id: np.array2string(np.array(x), precision=2, separator=",")
        for node_id, x in props.items()
    }
    nx.draw_networkx_labels(
        G,
        pos=nudge(pos, x_nudge, y_nudge),
        labels=attribute_labels,
        ax=ax,
        font_color=font_color,
    )

    # Pad the y-range so the nudged labels are not clipped. Use additive
    # padding of 5% of the span on each side. Scaling the limits by 1.1 only
    # widens them when they straddle 0; for a layout lying entirely above
    # (or below) the x-axis it moves a bound inward and clips nodes. For
    # limits symmetric around 0 both forms give the same result.
    y_low, y_high = ax.get_ylim()
    pad = 0.05 * (y_high - y_low)
    ax.set_ylim(y_low - pad, y_high + pad)

    # Hide the axes frame; only the graph itself should be visible.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)
# --------------------------------------------------------------------------------
# /.gitignore:
# --------------------------------------------------------------------------------
data/ 2 | !blocking_gcn/data 3 | !tests/data 4 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python 5 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python 6 | 7 | ### Python ### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | pytestdebug.log 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | doc/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | # .env 116 | .env/ 117 | .venv/ 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | pythonenv* 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # operating system-related files 147 | # file properties cache/storage on macOS 148 | *.DS_Store 149 | # thumbnail cache on Windows 150 | Thumbs.db 151 | 152 | # profiling data 153 | *.prof 154 | 155 | 156 | ### Vim ### 157 | # Swap (comment out the !*.svg line below if you don't need vector files; gitignore has no inline comments) 158 | [._]*.s[a-v][a-z] 159 | !*.svg 160 | [._]*.sw[a-p] 161 | [._]s[a-rt-v][a-z] 162 | [._]ss[a-gi-z] 163 | [._]sw[a-p] 164 | 165 | # Session 166 | Session.vim 167 | Sessionx.vim 168 | 169 | # Temporary 170 | .netrwhist 171 | *~ 172 | # Auto-generated tag files 173 | tags 174 | # Persistent undo 175 | [._]*.un~ 176 | 177 | # End of https://www.toptal.com/developers/gitignore/api/vim,python 178 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: PyG 2 | channels: 3 | - pyg 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - 
_openmp_mutex=5.1=1_gnu 8 | - asttokens=2.0.5=pyhd3eb1b0_0 9 | - backcall=0.2.0=pyhd3eb1b0_0 10 | - blas=1.0=mkl 11 | - bottleneck=1.3.5=py39h7deecbd_0 12 | - brotli=1.0.9=h5eee18b_7 13 | - brotli-bin=1.0.9=h5eee18b_7 14 | - brotlipy=0.7.0=py39h27cfd23_1003 15 | - ca-certificates=2022.10.11=h06a4308_0 16 | - certifi=2022.9.24=py39h06a4308_0 17 | - cffi=1.15.1=py39h74dc2b5_0 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - cryptography=38.0.1=py39h9ce1e76_0 20 | - cudatoolkit=11.3.1=h2bc3f7f_2 21 | - cycler=0.11.0=pyhd3eb1b0_0 22 | - dbus=1.13.18=hb2f20db_0 23 | - debugpy=1.5.1=py39h295c915_0 24 | - decorator=5.1.1=pyhd3eb1b0_0 25 | - entrypoints=0.4=py39h06a4308_0 26 | - executing=0.8.3=pyhd3eb1b0_0 27 | - expat=2.4.9=h6a678d5_0 28 | - fftw=3.3.9=h27cfd23_1 29 | - fontconfig=2.13.1=hef1e5e3_1 30 | - fonttools=4.25.0=pyhd3eb1b0_0 31 | - freetype=2.12.1=h4a9f257_0 32 | - future=0.18.2=py39h06a4308_1 33 | - giflib=5.2.1=h7b6447c_0 34 | - glib=2.69.1=h4ff587b_1 35 | - gst-plugins-base=1.14.0=h8213a91_2 36 | - gstreamer=1.14.0=h28cd5cc_2 37 | - icu=58.2=he6710b0_3 38 | - idna=3.4=py39h06a4308_0 39 | - intel-openmp=2021.4.0=h06a4308_3561 40 | - ipykernel=6.15.2=py39h06a4308_0 41 | - ipython=8.6.0=py39h06a4308_0 42 | - jedi=0.18.1=py39h06a4308_1 43 | - jinja2=3.1.2=py39h06a4308_0 44 | - joblib=1.1.1=py39h06a4308_0 45 | - jpeg=9e=h7f8727e_0 46 | - jupyter_client=7.3.5=py39h06a4308_0 47 | - jupyter_core=4.11.2=py39h06a4308_0 48 | - kiwisolver=1.4.2=py39h295c915_0 49 | - krb5=1.19.2=hac12032_0 50 | - lcms2=2.12=h3be6417_0 51 | - ld_impl_linux-64=2.38=h1181459_1 52 | - lerc=3.0=h295c915_0 53 | - libbrotlicommon=1.0.9=h5eee18b_7 54 | - libbrotlidec=1.0.9=h5eee18b_7 55 | - libbrotlienc=1.0.9=h5eee18b_7 56 | - libclang=10.0.1=default_hb85057a_2 57 | - libdeflate=1.8=h7f8727e_5 58 | - libedit=3.1.20210910=h7f8727e_0 59 | - libevent=2.1.12=h8f2d780_0 60 | - libffi=3.3=he6710b0_2 61 | - libgcc-ng=11.2.0=h1234567_1 62 | - libgfortran-ng=11.2.0=h00389a5_1 63 | - 
libgfortran5=11.2.0=h1234567_1 64 | - libgomp=11.2.0=h1234567_1 65 | - libllvm10=10.0.1=hbcb73fb_5 66 | - libpng=1.6.37=hbc83047_0 67 | - libpq=12.9=h16c4e8d_3 68 | - libsodium=1.0.18=h7b6447c_0 69 | - libstdcxx-ng=11.2.0=h1234567_1 70 | - libtiff=4.4.0=hecacb30_2 71 | - libuuid=1.41.5=h5eee18b_0 72 | - libwebp=1.2.4=h11a3e52_0 73 | - libwebp-base=1.2.4=h5eee18b_0 74 | - libxcb=1.15=h7f8727e_0 75 | - libxkbcommon=1.0.1=hfa300c1_0 76 | - libxml2=2.9.14=h74e7548_0 77 | - libxslt=1.1.35=h4e12654_0 78 | - lz4-c=1.9.3=h295c915_1 79 | - markupsafe=2.1.1=py39h7f8727e_0 80 | - matplotlib=3.5.3=py39h06a4308_0 81 | - matplotlib-base=3.5.3=py39hf590b9c_0 82 | - matplotlib-inline=0.1.6=py39h06a4308_0 83 | - mkl=2021.4.0=h06a4308_640 84 | - mkl-service=2.4.0=py39h7f8727e_0 85 | - mkl_fft=1.3.1=py39hd3c417c_0 86 | - mkl_random=1.2.2=py39h51133e4_0 87 | - munkres=1.1.4=py_0 88 | - ncurses=6.3=h5eee18b_3 89 | - nest-asyncio=1.5.5=py39h06a4308_0 90 | - networkx=2.8.4=py39h06a4308_0 91 | - ninja=1.10.2=h06a4308_5 92 | - ninja-base=1.10.2=hd09550d_5 93 | - nspr=4.33=h295c915_0 94 | - nss=3.74=h0370c37_0 95 | - numexpr=2.8.3=py39h807cd23_0 96 | - numpy=1.23.4=py39h14f4228_0 97 | - numpy-base=1.23.4=py39h31eccc5_0 98 | - openssl=1.1.1s=h7f8727e_0 99 | - packaging=21.3=pyhd3eb1b0_0 100 | - pandas=1.5.1=py39h417a72b_0 101 | - parso=0.8.3=pyhd3eb1b0_0 102 | - pcre=8.45=h295c915_0 103 | - pexpect=4.8.0=pyhd3eb1b0_3 104 | - pickleshare=0.7.5=pyhd3eb1b0_1003 105 | - pillow=9.2.0=py39hace64e9_1 106 | - pip=22.2.2=py39h06a4308_0 107 | - ply=3.11=py39h06a4308_0 108 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 109 | - psutil=5.9.0=py39h5eee18b_0 110 | - ptyprocess=0.7.0=pyhd3eb1b0_2 111 | - pure_eval=0.2.2=pyhd3eb1b0_0 112 | - pycparser=2.21=pyhd3eb1b0_0 113 | - pyg=2.0.4=py39_torch_1.10.0_cu113 114 | - pygments=2.11.2=pyhd3eb1b0_0 115 | - pyopenssl=22.0.0=pyhd3eb1b0_0 116 | - pyparsing=3.0.9=py39h06a4308_0 117 | - pyqt=5.15.7=py39h6a678d5_1 118 | - pyqt5-sip=12.11.0=py39h6a678d5_1 119 | - 
pysocks=1.7.1=py39h06a4308_0 120 | - python=3.9.15=haa1d7c7_0 121 | - python-dateutil=2.8.2=pyhd3eb1b0_0 122 | - python-louvain=0.15=pyhd3eb1b0_0 123 | - pytorch=1.10.2=cpu_py39hfa7516b_0 124 | - pytorch-cluster=1.6.0=py39_torch_1.10.0_cu113 125 | - pytorch-scatter=2.0.9=py39_torch_1.10.0_cu113 126 | - pytorch-sparse=0.6.13=py39_torch_1.10.0_cu113 127 | - pytorch-spline-conv=1.2.1=py39_torch_1.10.0_cu113 128 | - pytz=2022.1=py39h06a4308_0 129 | - pyyaml=6.0=py39h7f8727e_1 130 | - pyzmq=23.2.0=py39h6a678d5_0 131 | - qt-main=5.15.2=h327a75a_7 132 | - qt-webengine=5.15.9=hd2b0992_4 133 | - qtwebkit=5.212=h4eab89a_4 134 | - readline=8.2=h5eee18b_0 135 | - requests=2.28.1=py39h06a4308_0 136 | - scikit-learn=1.1.3=py39h6a678d5_0 137 | - scipy=1.9.3=py39h14f4228_0 138 | - setuptools=65.5.0=py39h06a4308_0 139 | - sip=6.6.2=py39h6a678d5_0 140 | - six=1.16.0=pyhd3eb1b0_1 141 | - sqlite=3.39.3=h5082296_0 142 | - stack_data=0.2.0=pyhd3eb1b0_0 143 | - threadpoolctl=2.2.0=pyh0d69192_0 144 | - tk=8.6.12=h1ccaba5_0 145 | - toml=0.10.2=pyhd3eb1b0_0 146 | - tornado=6.2=py39h5eee18b_0 147 | - tqdm=4.64.1=py39h06a4308_0 148 | - traitlets=5.1.1=pyhd3eb1b0_0 149 | - typing-extensions=4.3.0=py39h06a4308_0 150 | - typing_extensions=4.3.0=py39h06a4308_0 151 | - tzdata=2022f=h04d1e81_0 152 | - urllib3=1.26.12=py39h06a4308_0 153 | - wcwidth=0.2.5=pyhd3eb1b0_0 154 | - wheel=0.37.1=pyhd3eb1b0_0 155 | - xz=5.2.6=h5eee18b_0 156 | - yacs=0.1.6=pyhd3eb1b0_1 157 | - yaml=0.2.5=h7b6447c_0 158 | - zeromq=4.3.4=h2531618_0 159 | - zlib=1.2.13=h5eee18b_0 160 | - zstd=1.5.2=ha4553b6_0 161 | prefix: /home/sc.uni-leipzig.de/se302qele/miniconda3/envs/PyG 162 | --------------------------------------------------------------------------------