├── environment.yml
├── LICENSE
├── .gitignore
├── README.md
└── notebooks
├── 06-graph-neural-networks-1-gnn-model.ipynb
└── 08-applications-of-graph-neural-networks.ipynb
/environment.yml:
--------------------------------------------------------------------------------
1 | # This is a Conda environment file
2 | # Usage: `conda env update -f environment.yml`
3 |
4 | name:
5 | cs224w
6 |
7 | channels:
8 | - pyg
9 | - pytorch
10 | - conda-forge
11 |
12 | dependencies:
13 | - grakel==0.1.8
14 | - ipywidgets
15 | - karateclub==1.2.1
16 | - notebook
17 | - ogb
18 | - pip
19 | - pyg
20 | - python==3.7.12
21 | - pytorch>=1.8.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Mario Namtao Shianti Larcher
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CS224W: Machine Learning with Graphs - Slides to Code (Unofficial)
2 |
3 | This repository is an attempt to convert the slides from Stanford's "CS224W: Machine Learning with Graphs" course into code. The notebooks presented here include code to implement techniques hinted at in the lectures but not shown in the official labs.
4 |
5 | My initial plan was to cover all the lessons, but by the eighth the computation already becomes challenging for Colab. Since I think these first eight are already a great introduction to the subject, I'll stop here.
6 |
7 | > **Disclaimer**: I am not a Stanford student and this material has not been reviewed by the course instructors. It may contain errors; if you find any, please open an issue.
8 |
9 | A few useful links:
10 | * [CS224W | Home](http://web.stanford.edu/class/cs224w/index.html#schedule)
11 | * [CS224W | YouTube](https://youtu.be/JAB_plj2rbA)
12 |
13 | # Notebooks
14 | 1. [Introduction; Machine Learning for Graphs](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/01-introduction-machine-learning-for-graphs.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/01-introduction-machine-learning-for-graphs.ipynb)
15 | 2. [Traditional Methods for ML on Graphs](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/02-traditional-methods-for-ml-on-graphs.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/02-traditional-methods-for-ml-on-graphs.ipynb)
16 | 3. [Node Embeddings](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/03-node-embeddings.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/03-node-embeddings.ipynb)
17 | 4. [Link Analysis: PageRank](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/04-link-analysis-pagerank.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/04-link-analysis-pagerank.ipynb)
18 | 5. [Label Propagation for Node Classification](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/05-label-propagation-for-node-classification.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/05-label-propagation-for-node-classification.ipynb)
19 | 6. [Graph Neural Networks 1: GNN Model](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/06-graph-neural-networks-1-gnn-model.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/06-graph-neural-networks-1-gnn-model.ipynb) [Medium article](https://blog.devgenius.io/how-to-train-a-graph-convolutional-network-on-the-cora-dataset-with-pytorch-geometric-847ed5fab9cb)
20 | 7. [Graph Neural Networks 2: Design Space](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/07-graph-neural-networks-2-design-space.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/07-graph-neural-networks-2-design-space.ipynb)
21 | 8. [Applications of Graph Neural Networks](https://github.com/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/08-applications-of-graph-neural-networks.ipynb) [Open in Colab](https://colab.research.google.com/github/mnslarcher/cs224w-slides-to-code/blob/main/notebooks/08-applications-of-graph-neural-networks.ipynb)
22 |
23 | # Changelog
24 |
25 | ## 2022-02-13
26 |
27 | - Add the link to the Medium article related to the 6th lesson "Graph Neural Networks 1: GNN Model"
28 |
29 | ## 2021-12-12
30 |
31 | - Complete the notebook related to the 8th lesson "Applications of Graph Neural Networks"
32 |
33 | ## 2021-12-08
34 |
35 | - Complete the notebook related to the 7th lesson "Graph Neural Networks 2: Design Space"
36 |
37 | ## 2021-12-04
38 |
39 | - Complete the notebook related to the 6th lesson "Graph Neural Networks 1: GNN Model"
40 | - Add Anaconda env. file
41 |
42 | ## 2021-11-28
43 |
44 | - Complete the notebook related to the 5th lesson "Label Propagation for Node Classification"
45 |
46 | ## 2021-11-14
47 |
48 | - Complete the notebook related to the 4th lesson "Link Analysis: PageRank"
49 |
50 | ## 2021-11-13
51 |
52 | - Complete the notebook related to the 3rd lesson "Node Embeddings"
53 |
54 | ## 2021-11-06
55 |
56 | - Complete the notebook related to the 2nd lesson "Traditional Methods for ML on Graphs"
57 |
58 | ## 2021-11-04
59 |
60 | - Complete the notebook related to the 1st lesson "Introduction; Machine Learning for Graphs"
61 |
--------------------------------------------------------------------------------
/notebooks/06-graph-neural-networks-1-gnn-model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "id": "PJSONe-045tN"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "try:\n",
22 | " # Check if PyTorch Geometric is installed:\n",
23 | " import torch_geometric\n",
24 | "except ImportError:\n",
25 | " # If PyTorch Geometric is not installed, install it.\n",
26 | " %pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n",
27 | " %pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n",
28 | " %pip install -q torch-geometric"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "oV56fjpNYjCR"
35 | },
36 | "source": [
37 | "# Graph Neural Networks 1: GNN Model"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {
44 | "colab": {
45 | "base_uri": "https://localhost:8080/"
46 | },
47 | "id": "2OtSkdGl48E0",
48 | "outputId": "7efa478a-75df-4efb-991a-a3ac877f6469"
49 | },
50 | "outputs": [],
51 | "source": [
52 | "from typing import Callable, List, Optional, Tuple\n",
53 | "\n",
54 | "import matplotlib.pyplot as plt\n",
55 | "import numpy as np\n",
56 | "import torch\n",
57 | "import torch.nn.functional as F\n",
58 | "import torch_geometric.transforms as T\n",
59 | "from torch import Tensor\n",
60 | "from torch.optim import Optimizer\n",
61 | "from torch_geometric.data import Data\n",
62 | "from torch_geometric.datasets import Planetoid\n",
63 | "from torch_geometric.nn import GCNConv\n",
64 | "from torch_geometric.utils import accuracy\n",
65 | "from typing_extensions import Literal, TypedDict"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {
71 | "id": "KYGSasdA5ziu"
72 | },
73 | "source": [
74 | "## Cora Dataset"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {
80 | "id": "_iseKiaT5GZ7"
81 | },
82 | "source": [
83 | "> From the [Papers With Code page of the Cora Dataset](https://paperswithcode.com/dataset/cora): \"The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.\""
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {
89 | "id": "z518Aul25JZv"
90 | },
91 | "source": [
92 | "> From [Kipf & Welling (ICLR 2017)](https://arxiv.org/abs/1609.02907): \"[...] evaluate prediction accuracy on a test set of 1,000 labeled examples. [...] validation set of 500 labeled examples for hyperparameter optimization (dropout rate for all layers, L2 regularization factor for the first GCN layer and number of hidden units). We do not use the validation set labels for training.\""
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 3,
98 | "metadata": {
99 | "colab": {
100 | "base_uri": "https://localhost:8080/"
101 | },
102 | "id": "6Gg7oGlS5NRX",
103 | "outputId": "b8d10e40-275d-4b51-9907-b66d23465c6b"
104 | },
105 | "outputs": [
106 | {
107 | "name": "stdout",
108 | "output_type": "stream",
109 | "text": [
110 | "Dataset: Cora\n",
111 | "Num. nodes: 2708 (train=140, val=500, test=1000, other=1068)\n",
112 | "Num. edges: 5278\n",
113 | "Num. node features: 1433\n",
114 | "Num. classes: 7\n",
115 | "Dataset len.: 1\n"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "dataset = Planetoid(\"/tmp/Cora\", name=\"Cora\")\n",
121 | "num_nodes = dataset.data.num_nodes\n",
122 | "# For num. edges see:\n",
123 | "# - https://github.com/pyg-team/pytorch_geometric/issues/343\n",
124 | "# - https://github.com/pyg-team/pytorch_geometric/issues/852\n",
125 | "num_edges = dataset.data.num_edges // 2\n",
126 | "train_len = dataset[0].train_mask.sum()\n",
127 | "val_len = dataset[0].val_mask.sum()\n",
128 | "test_len = dataset[0].test_mask.sum()\n",
129 | "other_len = num_nodes - train_len - val_len - test_len\n",
130 | "print(f\"Dataset: {dataset.name}\")\n",
131 | "print(f\"Num. nodes: {num_nodes} (train={train_len}, val={val_len}, test={test_len}, other={other_len})\")\n",
132 | "print(f\"Num. edges: {num_edges}\")\n",
133 | "print(f\"Num. node features: {dataset.num_node_features}\")\n",
134 | "print(f\"Num. classes: {dataset.num_classes}\")\n",
135 | "print(f\"Dataset len.: {dataset.len()}\")"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {
141 | "id": "hBzvCryI5Uba"
142 | },
143 | "source": [
144 | "> From [Kipf & Welling (ICLR 2017)](https://arxiv.org/abs/1609.02907): \"We initialize weights using the initialization described in Glorot & Bengio (2010) and accordingly (row-)normalize input feature vectors.\""
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 4,
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/"
153 | },
154 | "id": "nljKuAA85XDE",
155 | "outputId": "48bc3cd5-4d3b-4c1f-90ff-e435c28962da"
156 | },
157 | "outputs": [
158 | {
159 | "name": "stdout",
160 | "output_type": "stream",
161 | "text": [
162 | "Sum of row values without normalization: tensor([ 9., 23., 19., ..., 18., 14., 13.])\n",
163 | "Sum of row values with normalization: tensor([1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000])\n"
164 | ]
165 | }
166 | ],
167 | "source": [
168 | "dataset = Planetoid(\"/tmp/Cora\", name=\"Cora\")\n",
169 | "print(f\"Sum of row values without normalization: {dataset[0].x.sum(dim=-1)}\")\n",
170 | "\n",
171 | "dataset = Planetoid(\"/tmp/Cora\", name=\"Cora\", transform=T.NormalizeFeatures())\n",
172 | "print(f\"Sum of row values with normalization: {dataset[0].x.sum(dim=-1)}\")"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {
178 | "id": "c0Jw_TT35AiO"
179 | },
180 | "source": [
181 | "## Graph Convolutional Networks"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {
187 | "id": "b19trEwa5Y-W"
188 | },
189 | "source": [
190 | "> From [Kipf & Welling (ICLR 2017)](https://arxiv.org/abs/1609.02907): \"We used the following sets of hyperparameters for Citeseer, Cora and Pubmed: 0.5 (dropout rate), $5\\cdot10^{-4}$ (L2 regularization) and 16 (number of hidden units);\""
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 5,
196 | "metadata": {
197 | "id": "-m_i0aCa5eM4"
198 | },
199 | "outputs": [],
200 | "source": [
201 | "class GCN(torch.nn.Module):\n",
202 | " def __init__(\n",
203 | " self,\n",
204 | " num_node_features: int,\n",
205 | " num_classes: int,\n",
206 | " hidden_dim: int = 16,\n",
207 | " dropout_rate: float = 0.5,\n",
208 | " ) -> None:\n",
209 | " super().__init__()\n",
210 | " self.dropout1 = torch.nn.Dropout(dropout_rate)\n",
211 | " self.conv1 = GCNConv(num_node_features, hidden_dim)\n",
212 | " self.relu = torch.nn.ReLU(inplace=True)\n",
213 | " self.dropout2 = torch.nn.Dropout(dropout_rate)\n",
214 | " self.conv2 = GCNConv(hidden_dim, num_classes)\n",
215 | "\n",
216 | " def forward(self, x: Tensor, edge_index: Tensor) -> torch.Tensor:\n",
217 | " x = self.dropout1(x)\n",
218 | " x = self.conv1(x, edge_index)\n",
219 | " x = self.relu(x)\n",
220 | " x = self.dropout2(x)\n",
221 | " x = self.conv2(x, edge_index)\n",
222 | " return x"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 6,
228 | "metadata": {
229 | "colab": {
230 | "base_uri": "https://localhost:8080/"
231 | },
232 | "id": "jKQH-HKS5hVt",
233 | "outputId": "e98433b7-97c9-4469-e409-6aa8b07dea4e"
234 | },
235 | "outputs": [
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "Graph Convolutional Network (GCN):\n"
241 | ]
242 | },
243 | {
244 | "data": {
245 | "text/plain": [
246 | "GCN(\n",
247 | " (dropout1): Dropout(p=0.5, inplace=False)\n",
248 | " (conv1): GCNConv(1433, 16)\n",
249 | " (relu): ReLU(inplace=True)\n",
250 | " (dropout2): Dropout(p=0.5, inplace=False)\n",
251 | " (conv2): GCNConv(16, 7)\n",
252 | ")"
253 | ]
254 | },
255 | "execution_count": 6,
256 | "metadata": {},
257 | "output_type": "execute_result"
258 | }
259 | ],
260 | "source": [
261 | "print(\"Graph Convolutional Network (GCN):\")\n",
262 | "GCN(dataset.num_node_features, dataset.num_classes)"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {
268 | "id": "mN5NPwag6GPT"
269 | },
270 | "source": [
271 | "## Training and Evaluation"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 7,
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "LossFn = Callable[[Tensor, Tensor], Tensor]\n",
281 | "Stage = Literal[\"train\", \"val\", \"test\"]\n",
282 | "\n",
283 | "\n",
284 | "def train_step(\n",
285 | " model: torch.nn.Module, data: Data, optimizer: torch.optim.Optimizer, loss_fn: LossFn\n",
286 | ") -> Tuple[float, float]:\n",
287 | " model.train()\n",
288 | " optimizer.zero_grad()\n",
289 | " mask = data.train_mask\n",
290 | " logits = model(data.x, data.edge_index)[mask]\n",
291 | " preds = logits.argmax(dim=1)\n",
292 | " y = data.y[mask]\n",
293 | " loss = loss_fn(logits, y)\n",
294 | " # + L2 regularization to the first layer only\n",
295 | " # for name, params in model.state_dict().items():\n",
296 | " # if name.startswith(\"conv1\"):\n",
297 | " # loss += 5e-4 * params.square().sum() / 2.0\n",
298 | "\n",
299 | " acc = accuracy(preds, y)\n",
300 | " loss.backward()\n",
301 | " optimizer.step()\n",
302 | " return loss.item(), acc\n",
303 | "\n",
304 | "\n",
305 | "@torch.no_grad()\n",
306 | "def eval_step(model: torch.nn.Module, data: Data, loss_fn: LossFn, stage: Stage) -> Tuple[float, float]:\n",
307 | " model.eval()\n",
308 | " mask = getattr(data, f\"{stage}_mask\")\n",
309 | " logits = model(data.x, data.edge_index)[mask]\n",
310 | " preds = logits.argmax(dim=1)\n",
311 | " y = data.y[mask]\n",
312 | " loss = loss_fn(logits, y)\n",
313 | " # + L2 regularization to the first layer only\n",
314 | " # for name, params in model.state_dict().items():\n",
315 | " # if name.startswith(\"conv1\"):\n",
316 | " # loss += 5e-4 * params.square().sum() / 2.0\n",
317 | "\n",
318 | " acc = accuracy(preds, y)\n",
319 | " return loss.item(), acc"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {
325 | "id": "oMqTMjZD5mTy"
326 | },
327 | "source": [
328 | "> From [Kipf & Welling (ICLR 2017)](https://arxiv.org/abs/1609.02907): \"We train all models for a maximum of 200 epochs (training iterations) using Adam (Kingma & Ba, 2015) with a learning rate of 0.01 and early stopping with a window size of 10, i.e. we stop training if the validation loss does not decrease for 10 consecutive epochs.\""
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": 8,
334 | "metadata": {},
335 | "outputs": [],
336 | "source": [
337 | "class HistoryDict(TypedDict):\n",
338 | " loss: List[float]\n",
339 | " acc: List[float]\n",
340 | " val_loss: List[float]\n",
341 | " val_acc: List[float]\n",
342 | "\n",
343 | "\n",
344 | "def train(\n",
345 | " model: torch.nn.Module,\n",
346 | " data: Data,\n",
347 | " optimizer: torch.optim.Optimizer,\n",
348 | " loss_fn: LossFn = torch.nn.CrossEntropyLoss(),\n",
349 | " max_epochs: int = 200,\n",
350 | " early_stopping: int = 10,\n",
351 | " print_interval: int = 20,\n",
352 | " verbose: bool = True,\n",
353 | ") -> HistoryDict:\n",
354 | " history = {\"loss\": [], \"val_loss\": [], \"acc\": [], \"val_acc\": []}\n",
355 | " for epoch in range(max_epochs):\n",
356 | " loss, acc = train_step(model, data, optimizer, loss_fn)\n",
357 | " val_loss, val_acc = eval_step(model, data, loss_fn, \"val\")\n",
358 | " history[\"loss\"].append(loss)\n",
359 | " history[\"acc\"].append(acc)\n",
360 | " history[\"val_loss\"].append(val_loss)\n",
361 | " history[\"val_acc\"].append(val_acc)\n",
362 | " # The official implementation in TensorFlow is a little different from what is described in the paper...\n",
363 | " if epoch > early_stopping and val_loss > np.mean(history[\"val_loss\"][-(early_stopping + 1) : -1]):\n",
364 | " if verbose:\n",
365 | " print(\"\\nEarly stopping...\")\n",
366 | "\n",
367 | " break\n",
368 | "\n",
369 | " if verbose and epoch % print_interval == 0:\n",
370 | " print(f\"\\nEpoch: {epoch}\\n----------\")\n",
371 | " print(f\"Train loss: {loss:.4f} | Train acc: {acc:.4f}\")\n",
372 | " print(f\" Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}\")\n",
373 | "\n",
374 | " test_loss, test_acc = eval_step(model, data, loss_fn, \"test\")\n",
375 | " if verbose:\n",
376 | " print(f\"\\nEpoch: {epoch}\\n----------\")\n",
377 | " print(f\"Train loss: {loss:.4f} | Train acc: {acc:.4f}\")\n",
378 | " print(f\" Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}\")\n",
379 | " print(f\" Test loss: {test_loss:.4f} | Test acc: {test_acc:.4f}\")\n",
380 | "\n",
381 | " return history"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": 9,
387 | "metadata": {},
388 | "outputs": [],
389 | "source": [
390 | "def plot_history(history: HistoryDict, title: str, font_size: Optional[int] = 14) -> None:\n",
391 | " plt.suptitle(title, fontsize=font_size)\n",
392 | " ax1 = plt.subplot(121)\n",
393 | " ax1.set_title(\"Loss\")\n",
394 | " ax1.plot(history[\"loss\"], label=\"train\")\n",
395 | " ax1.plot(history[\"val_loss\"], label=\"val\")\n",
396 | " plt.xlabel(\"Epoch\")\n",
397 | " ax1.legend()\n",
398 | "\n",
399 | " ax2 = plt.subplot(122)\n",
400 | " ax2.set_title(\"Accuracy\")\n",
401 | " ax2.plot(history[\"acc\"], label=\"train\")\n",
402 | " ax2.plot(history[\"val_acc\"], label=\"val\")\n",
403 | " plt.xlabel(\"Epoch\")\n",
404 | " ax2.legend()"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 10,
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "name": "stdout",
414 | "output_type": "stream",
415 | "text": [
416 | "\n",
417 | "Epoch: 0\n",
418 | "----------\n",
419 | "Train loss: 1.9460 | Train acc: 0.1500\n",
420 | " Val loss: 1.9480 | Val acc: 0.0820\n",
421 | "\n",
422 | "Epoch: 20\n",
423 | "----------\n",
424 | "Train loss: 1.7267 | Train acc: 0.6500\n",
425 | " Val loss: 1.8245 | Val acc: 0.4860\n",
426 | "\n",
427 | "Epoch: 40\n",
428 | "----------\n",
429 | "Train loss: 1.4002 | Train acc: 0.7857\n",
430 | " Val loss: 1.6003 | Val acc: 0.6980\n",
431 | "\n",
432 | "Epoch: 60\n",
433 | "----------\n",
434 | "Train loss: 0.9969 | Train acc: 0.8714\n",
435 | " Val loss: 1.3578 | Val acc: 0.7540\n",
436 | "\n",
437 | "Epoch: 80\n",
438 | "----------\n",
439 | "Train loss: 0.7222 | Train acc: 0.9286\n",
440 | " Val loss: 1.1672 | Val acc: 0.7720\n",
441 | "\n",
442 | "Epoch: 100\n",
443 | "----------\n",
444 | "Train loss: 0.6089 | Train acc: 0.9214\n",
445 | " Val loss: 1.0402 | Val acc: 0.7800\n",
446 | "\n",
447 | "Epoch: 120\n",
448 | "----------\n",
449 | "Train loss: 0.5035 | Train acc: 0.9071\n",
450 | " Val loss: 0.9487 | Val acc: 0.7860\n",
451 | "\n",
452 | "Early stopping...\n",
453 | "\n",
454 | "Epoch: 127\n",
455 | "----------\n",
456 | "Train loss: 0.5273 | Train acc: 0.9500\n",
457 | " Val loss: 0.9461 | Val acc: 0.7760\n",
458 | " Test loss: 0.9202 | Test acc: 0.8080\n"
459 | ]
460 | },
461 | {
462 | "data": {
463 | "image/png": "\n",
464 | "text/plain": [
465 | ""
466 | ]
467 | },
468 | "metadata": {
469 | "needs_background": "light"
470 | },
471 | "output_type": "display_data"
472 | }
473 | ],
474 | "source": [
475 | "SEED = 42\n",
476 | "MAX_EPOCHS = 200\n",
477 | "LEARNING_RATE = 0.01\n",
478 | "WEIGHT_DECAY = 5e-4\n",
479 | "EARLY_STOPPING = 10\n",
480 | "\n",
481 | "\n",
482 | "torch.manual_seed(SEED)\n",
483 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
484 | "\n",
485 | "model = GCN(dataset.num_node_features, dataset.num_classes).to(device)\n",
486 | "data = dataset[0].to(device)\n",
487 | "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
488 | "history = train(model, data, optimizer, max_epochs=MAX_EPOCHS, early_stopping=EARLY_STOPPING)\n",
489 | "\n",
490 | "plt.figure(figsize=(12, 4))\n",
491 | "plot_history(history, \"GCN\")"
492 | ]
493 | }
494 | ],
495 | "metadata": {
496 | "colab": {
497 | "authorship_tag": "ABX9TyM4BJ9SfFG7mXOuMcdhOFtF",
498 | "include_colab_link": true,
499 | "name": "06-graph-neural-networks-1-gnn-model.ipynb",
500 | "provenance": []
501 | },
502 | "kernelspec": {
503 | "display_name": "Python 3",
504 | "language": "python",
505 | "name": "python3"
506 | },
507 | "language_info": {
508 | "codemirror_mode": {
509 | "name": "ipython",
510 | "version": 3
511 | },
512 | "file_extension": ".py",
513 | "mimetype": "text/x-python",
514 | "name": "python",
515 | "nbconvert_exporter": "python",
516 | "pygments_lexer": "ipython3",
517 | "version": "3.7.12"
518 | }
519 | },
520 | "nbformat": 4,
521 | "nbformat_minor": 1
522 | }
523 |
--------------------------------------------------------------------------------
/notebooks/08-applications-of-graph-neural-networks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e578f14e",
6 | "metadata": {},
7 | "source": [
8 | "
"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "8b54d1d9",
14 | "metadata": {},
15 | "source": [
16 | "> **Warning**: this notebook takes a very long time to run on the CPU and runs out of memory (OOM) on the standard Colab GPU; if possible, use a more powerful GPU."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "id": "c2d8a203",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "try:\n",
27 | " # Check if PyTorch Geometric is installed:\n",
28 | " import torch_geometric\n",
29 | "except ImportError:\n",
30 | " # If PyTorch Geometric is not installed, install it.\n",
31 | " %pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n",
32 | " %pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n",
33 | " %pip install -q torch-geometric"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "id": "bc6e944c",
39 | "metadata": {},
40 | "source": [
41 | "# Applications of Graph Neural Networks"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "id": "76f2ecef",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "import os\n",
52 | "from copy import deepcopy\n",
53 | "from typing import List, Optional, Tuple\n",
54 | "\n",
55 | "import matplotlib.pyplot as plt\n",
56 | "import networkx as nx\n",
57 | "import numpy as np\n",
58 | "import torch\n",
59 | "import torch.nn.functional as F\n",
60 | "import torch_geometric.transforms as T\n",
61 | "from sklearn.metrics import roc_auc_score\n",
62 | "from torch_geometric.data import Data\n",
63 | "from torch_geometric.loader.neighbor_sampler import EdgeIndex\n",
64 | "from torch_geometric.utils import to_networkx\n",
65 | "from tqdm import tqdm"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "id": "bc71bb37",
71 | "metadata": {},
72 | "source": [
73 | "# Feature Augmentation on Graphs"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 3,
79 | "id": "7d4c1589",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "def draw_graph_from_data(data: Data, node_size: Optional[int] = 1500, seed: Optional[int] = None) -> None:\n",
84 | " G = to_networkx(data, to_undirected=data.is_undirected())\n",
85 | " labels = {node_id: f\"ID: {node_id}\\n$x_{node_id}$: {node_x.tolist()}\" for node_id, node_x in enumerate(data.x)}\n",
86 | " pos = nx.spring_layout(G, seed=seed)\n",
87 | " nx.draw(G, pos=pos, labels=labels, node_size=node_size)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "5f1c355a",
93 | "metadata": {},
94 | "source": [
95 | "## Assign constant values to nodes"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 4,
101 | "id": "dd3c9d6e",
102 | "metadata": {},
103 | "outputs": [
104 | {
105 | "name": "stdout",
106 | "output_type": "stream",
107 | "text": [
108 | "Num. nodes: 6\n",
109 | "Num. node features: 1\n",
110 | "Num. edges: 14\n",
111 | "Is undirected? True\n"
112 | ]
113 | },
114 | {
115 | "data": {
116 | "image/png": "\n",
117 | "text/plain": [
118 | ""
119 | ]
120 | },
121 | "metadata": {},
122 | "output_type": "display_data"
123 | }
124 | ],
125 | "source": [
126 | "x = torch.ones((6, 1))\n",
127 | "edge_index = torch.tensor(\n",
128 | " [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5], [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]], dtype=torch.long\n",
129 | ")\n",
130 | "data = Data(x=x, edge_index=edge_index)\n",
131 | "print(f\"Num. nodes: {data.num_nodes}\")\n",
132 | "print(f\"Num. node features: {data.num_node_features}\")\n",
133 | "print(f\"Num. edges: {data.num_edges}\")\n",
134 | "print(f\"Is undirected? {data.is_undirected()}\")\n",
135 | "\n",
136 | "draw_graph_from_data(data, seed=42)"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "id": "844252c4",
142 | "metadata": {},
143 | "source": [
144 | "## Assign unique IDs to nodes"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "id": "3ca0acae",
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "Num. nodes: 6\n",
158 | "Num. node features: 6\n",
159 | "Num. edges: 14\n",
160 | "Is undirected? True\n"
161 | ]
162 | },
163 | {
164 | "data": {
165 | "image/png": "\n",
166 | "text/plain": [
167 | ""
168 | ]
169 | },
170 | "metadata": {},
171 | "output_type": "display_data"
172 | }
173 | ],
174 | "source": [
175 | "x = torch.eye(6)\n",
176 | "edge_index = torch.tensor(\n",
177 | " [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5], [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]], dtype=torch.long\n",
178 | ")\n",
179 | "data = Data(x=x, edge_index=edge_index)\n",
180 | "print(f\"Num. nodes: {data.num_nodes}\")\n",
181 | "print(f\"Num. node features: {data.num_node_features}\")\n",
182 | "print(f\"Num. edges: {data.num_edges}\")\n",
183 | "print(f\"Is undirected? {data.is_undirected()}\")\n",
184 | "\n",
185 | "draw_graph_from_data(data, seed=42)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "id": "3611f325",
191 | "metadata": {},
192 | "source": [
193 | "# Add Virtual Nodes/Edges"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "id": "e0d02317",
199 | "metadata": {},
200 | "source": [
201 | "## Add virtual edges"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 6,
207 | "id": "46ed6e1a",
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "from torch_geometric.utils import dense_to_sparse, to_dense_adj"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 7,
217 | "id": "c399a5f8",
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "def draw_bipartite(\n",
222 | " data: Data,\n",
223 | " node1_color: str = \"tab:orange\",\n",
224 | " node2_color: str = \"tab:blue\",\n",
225 | " node1_shape: str = \"o\",\n",
226 | " node2_shape: str = \"s\",\n",
227 | " set1_name: str = \"Authors\",\n",
228 | " set2_name: str = \"Papers\",\n",
229 | " text_size: int = 16,\n",
230 | " font_weight: str = \"bold\",\n",
231 | ") -> None:\n",
232 | " group0_len = (data.x == 0).sum().item()\n",
233 | " group1_len = (data.x == 1).sum().item()\n",
234 | " len_max = max(group0_len, group1_len)\n",
235 | " offset = (group0_len - group1_len) / 2\n",
236 | " bipartite = data.x.flatten().tolist()\n",
237 | " labels = {\n",
238 | " node_id: str(node_id - group0_len + 1) if node_group else chr(node_id + 65)\n",
239 | " for node_id, node_group in enumerate(bipartite)\n",
240 | " }\n",
241 | " pos = {\n",
242 | " node_id: (1 + node_group, group0_len - node_id + node_group * (offset + group0_len - 1))\n",
243 | " for node_id, node_group in enumerate(bipartite)\n",
244 | " }\n",
245 | " node_colors = [node1_color if gr else node2_color for gr in bipartite]\n",
246 | " node_shapes = [node1_shape if gr else node2_shape for gr in bipartite]\n",
247 | " G = to_networkx(data, to_undirected=data.is_undirected())\n",
248 | "\n",
249 | " nx.draw_networkx_nodes(range(group0_len), pos=pos, node_color=node1_color, node_shape=node1_shape)\n",
250 | " nx.draw_networkx_nodes(\n",
251 | " range(group0_len, group0_len + group1_len), pos=pos, node_color=node2_color, node_shape=node2_shape\n",
252 | " )\n",
253 | " nx.draw_networkx_labels(G, pos=pos, labels=labels)\n",
254 | " nx.draw_networkx_edges(G, pos=pos, width=2)\n",
255 | "\n",
256 | " plt.text(0.95, len_max + 0.75, set1_name, size=text_size, fontweight=font_weight)\n",
257 | " plt.text(1.95, len_max + 0.75, set2_name, size=text_size, fontweight=font_weight)\n",
258 | "\n",
259 | " plt.axis(\"off\")"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 8,
265 | "id": "3ab7a9b4",
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "def connect_two_hop_neighbors(data: Data) -> Data:\n",
270 | " data = deepcopy(data)\n",
271 | " A = to_dense_adj(data.edge_index)[0]\n",
272 | " idx = range(len(A))\n",
273 | " A2 = torch.matrix_power(A, 2)\n",
274 | " A2[idx, idx] = 0.0\n",
275 | " data.edge_index = dense_to_sparse(A + A2)[0]\n",
276 | " return data"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 9,
282 | "id": "e7224566",
283 | "metadata": {
284 | "scrolled": true
285 | },
286 | "outputs": [
287 | {
288 | "data": {
289 | "image/png": "\n",
290 | "text/plain": [
291 | ""
292 | ]
293 | },
294 | "metadata": {},
295 | "output_type": "display_data"
296 | }
297 | ],
298 | "source": [
299 | "x = torch.vstack([torch.zeros((5, 1)), torch.ones((4, 1))])\n",
300 | "edge_index = torch.tensor(\n",
301 | " [[0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8], [5, 6, 5, 7, 7, 6, 8, 5, 8, 0, 1, 4, 0, 3, 1, 2, 3, 4]],\n",
302 | " dtype=torch.long,\n",
303 | ")\n",
304 | "data = Data(x=x, edge_index=edge_index)\n",
305 | "data_with_virtual_edges = connect_two_hop_neighbors(data)\n",
306 | "\n",
307 | "plt.figure(figsize=(8, 4))\n",
308 | "\n",
309 | "plt.subplot(121)\n",
310 | "draw_bipartite(data)\n",
311 | "\n",
312 | "plt.subplot(122)\n",
313 | "draw_bipartite(data_with_virtual_edges)"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "id": "7e5ff688",
319 | "metadata": {},
320 | "source": [
321 | "## Add virtual nodes"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 10,
327 | "id": "a7d43d06",
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "def add_virtual_node(data: Data) -> Data:\n",
332 | " data = deepcopy(data)\n",
333 | " node_id = len(data.x)\n",
334 | " data.x = torch.vstack([data.x, torch.ones(1)])\n",
335 | " list1 = [node_id] * node_id\n",
336 | " list2 = list(range(node_id))\n",
337 | " data.edge_index = torch.hstack([data.edge_index, torch.tensor([list1 + list2, list2 + list1])])\n",
338 | " return data"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 11,
344 | "id": "d0a7a182",
345 | "metadata": {},
346 | "outputs": [],
347 | "source": [
348 | "def average_shortest_path_length(data: Data) -> float:\n",
349 | " G = to_networkx(data, to_undirected=data.is_undirected())\n",
350 | " return nx.average_shortest_path_length(G)"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 12,
356 | "id": "3bb20f1c",
357 | "metadata": {},
358 | "outputs": [
359 | {
360 | "data": {
361 | "image/png": "\n",
362 | "text/plain": [
363 | ""
364 | ]
365 | },
366 | "metadata": {},
367 | "output_type": "display_data"
368 | }
369 | ],
370 | "source": [
371 | "x = torch.ones((6, 1))\n",
372 | "edge_index = torch.tensor(\n",
373 | " [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5], [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]], dtype=torch.long\n",
374 | ")\n",
375 | "data = Data(x=x, edge_index=edge_index)\n",
376 | "data_with_virtual_node = add_virtual_node(data)\n",
377 | "\n",
378 | "plt.figure(figsize=(12, 4))\n",
379 | "\n",
380 | "ax1 = plt.subplot(121)\n",
381 | "ax1.set_title(f\"Average shortest path length: {average_shortest_path_length(data):.2f}\")\n",
382 | "draw_graph_from_data(data)\n",
383 | "\n",
384 | "ax2 = plt.subplot(122)\n",
385 | "ax2.set_title(f\"Average shortest path length: {average_shortest_path_length(data_with_virtual_node):.2f}\")\n",
386 | "draw_graph_from_data(data_with_virtual_node)"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "id": "4e0ef4c7",
392 | "metadata": {},
393 | "source": [
394 | "# Node Neighborhood Sampling"
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "id": "35e2d8c6",
400 | "metadata": {},
401 | "source": [
402 | "> The following part is an adaptation of a PyTorch Geometric [example](https://github.com/pyg-team/pytorch_geometric/blob/bee6ca2e78890e57c97f71b6110dc86cbdbf5efb/examples/reddit.py)."
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": 13,
408 | "id": "294405db",
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "from torch_geometric.datasets import Reddit\n",
413 | "from torch_geometric.loader import NeighborSampler\n",
414 | "from torch_geometric.nn import SAGEConv"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": 14,
420 | "id": "75147c14",
421 | "metadata": {},
422 | "outputs": [],
423 | "source": [
424 | "class SAGE(torch.nn.Module):\n",
425 | " def __init__(self, in_channels: int, hidden_channels: int, out_channels: int) -> None:\n",
426 | " super().__init__()\n",
427 | "\n",
428 | " self.num_layers = 2\n",
429 | "\n",
430 | " self.convs = torch.nn.ModuleList()\n",
431 | " self.convs.append(SAGEConv(in_channels, hidden_channels))\n",
432 | " self.convs.append(SAGEConv(hidden_channels, out_channels))\n",
433 | "\n",
434 | " def forward(self, x: torch.Tensor, adjs: List[EdgeIndex]) -> torch.Tensor:\n",
435 | " # `train_loader` computes the k-hop neighborhood of a batch of nodes,\n",
436 | " # and returns, for each layer, a bipartite graph object, holding the\n",
437 | " # bipartite edges `edge_index`, the index `e_id` of the original edges,\n",
438 | " # and the size/shape `size` of the bipartite graph.\n",
439 | " # Target nodes are also included in the source nodes so that one can\n",
440 | " # easily apply skip-connections or add self-loops.\n",
441 | " for i, (edge_index, _, size) in enumerate(adjs):\n",
442 | " x_target = x[: size[1]] # Target nodes are always placed first.\n",
443 | " x = self.convs[i]((x, x_target), edge_index)\n",
444 | " if i != self.num_layers - 1:\n",
445 | " x = F.relu(x)\n",
446 | " x = F.dropout(x, p=0.5, training=self.training)\n",
447 | " return x.log_softmax(dim=-1)\n",
448 | "\n",
449 | " def inference(self, x_all: torch.Tensor, subgraph_loader: NeighborSampler) -> torch.Tensor:\n",
450 | " pbar = tqdm(total=x_all.size(0) * self.num_layers)\n",
451 | " pbar.set_description(\"Evaluating\")\n",
452 | "\n",
453 | " # Compute representations of nodes layer by layer, using *all*\n",
454 | " # available edges. This leads to faster computation in contrast to\n",
455 | " # immediately computing the final representations of each batch.\n",
456 | " for i in range(self.num_layers):\n",
457 | " xs = []\n",
458 | " for batch_size, n_id, adj in subgraph_loader:\n",
459 | " edge_index, _, size = adj.to(device)\n",
460 | " x = x_all[n_id].to(device)\n",
461 | " x_target = x[: size[1]]\n",
462 | " x = self.convs[i]((x, x_target), edge_index)\n",
463 | " if i != self.num_layers - 1:\n",
464 | " x = F.relu(x)\n",
465 | " xs.append(x.cpu())\n",
466 | "\n",
467 | " pbar.update(batch_size)\n",
468 | "\n",
469 | " x_all = torch.cat(xs, dim=0)\n",
470 | "\n",
471 | " pbar.close()\n",
472 | "\n",
473 | " return x_all"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": 15,
479 | "id": "905b6b3a",
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "class Trainer:\n",
484 | " def __init__(\n",
485 | " self,\n",
486 | " model: torch.nn.Module,\n",
487 | " x: torch.Tensor,\n",
488 | " y: torch.Tensor,\n",
489 | " train_mask: torch.Tensor,\n",
490 | " val_mask: torch.Tensor,\n",
491 | " test_mask: torch.Tensor,\n",
492 | " optimizer: torch.optim.Optimizer,\n",
493 | " sizes: List[int] = [25, 10],\n",
494 | " batch_size: int = 512,\n",
495 | " num_workers: int = 2,\n",
496 | " ) -> None:\n",
497 | " self.model = model\n",
498 | " self.data = data\n",
499 | " self.x = x\n",
500 | " self.y = y\n",
501 | " self.train_mask = train_mask\n",
502 | " self.val_mask = val_mask\n",
503 | " self.test_mask = test_mask\n",
504 | " self.optimizer = optimizer\n",
505 | " self.train_loader = NeighborSampler(\n",
506 | " data.edge_index,\n",
507 | " node_idx=data.train_mask,\n",
508 | " sizes=sizes,\n",
509 | " batch_size=batch_size,\n",
510 | " shuffle=True,\n",
511 | " num_workers=num_workers,\n",
512 | " )\n",
513 | " self.subgraph_loader = NeighborSampler(\n",
514 | " data.edge_index, node_idx=None, sizes=[-1], batch_size=batch_size, shuffle=False, num_workers=num_workers\n",
515 | " )\n",
516 | "\n",
517 | " def training_epoch(self, epoch: int) -> float:\n",
518 | " self.model.train()\n",
519 | " pbar = tqdm(total=int(self.train_mask.sum()))\n",
520 | " pbar.set_description(f\"Epoch {epoch:02d}\")\n",
521 | " total_loss = 0\n",
522 | " for batch_size, n_id, adjs in self.train_loader:\n",
523 | " # `adjs` holds a list of `(edge_index, e_id, size)` tuples.\n",
524 | " adjs = [adj.to(device) for adj in adjs]\n",
525 | " self.optimizer.zero_grad()\n",
526 | " out = self.model(self.x[n_id], adjs)\n",
527 | " loss = F.nll_loss(out, self.y[n_id[:batch_size]])\n",
528 | " loss.backward()\n",
529 | " self.optimizer.step()\n",
530 | " total_loss += float(loss)\n",
531 | " pbar.update(batch_size)\n",
532 | "\n",
533 | " pbar.close()\n",
534 | " loss = total_loss / len(self.train_loader)\n",
535 | " return loss\n",
536 | "\n",
537 | " @torch.no_grad()\n",
538 | " def evaluate(self) -> List[float]:\n",
539 | " self.model.eval()\n",
540 | "\n",
541 | " out = self.model.inference(self.x, self.subgraph_loader)\n",
542 | " y_true = self.y.cpu().unsqueeze(-1)\n",
543 | " y_pred = out.argmax(dim=-1, keepdim=True)\n",
544 | "\n",
545 | " results = []\n",
546 | " for mask in [self.train_mask, self.val_mask, self.test_mask]:\n",
547 | " results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]\n",
548 | "\n",
549 | " return results\n",
550 | "\n",
551 | " def fit(self, num_epochs: int = 10) -> None:\n",
552 | " for epoch in range(1, num_epochs + 1):\n",
553 | " loss = self.training_epoch(epoch)\n",
554 | " print(f\"Epoch {epoch:02d}, Loss: {loss:.4f}\")\n",
555 | " train_acc, val_acc, test_acc = self.evaluate()\n",
556 | " print(f\"Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}\")"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": 16,
562 | "id": "cfdd899a",
563 | "metadata": {
564 | "scrolled": true
565 | },
566 | "outputs": [
567 | {
568 | "name": "stderr",
569 | "output_type": "stream",
570 | "text": [
571 | "Epoch 01: 100%|██████████| 153431/153431 [00:06<00:00, 23528.50it/s]\n"
572 | ]
573 | },
574 | {
575 | "name": "stdout",
576 | "output_type": "stream",
577 | "text": [
578 | "Epoch 01, Loss: 0.6978\n"
579 | ]
580 | },
581 | {
582 | "name": "stderr",
583 | "output_type": "stream",
584 | "text": [
585 | "Evaluating: 100%|██████████| 465930/465930 [00:21<00:00, 21265.90it/s]\n"
586 | ]
587 | },
588 | {
589 | "name": "stdout",
590 | "output_type": "stream",
591 | "text": [
592 | "Train: 0.9416, Val: 0.9423, Test: 0.9385\n"
593 | ]
594 | },
595 | {
596 | "name": "stderr",
597 | "output_type": "stream",
598 | "text": [
599 | "Epoch 02: 100%|██████████| 153431/153431 [00:06<00:00, 24566.45it/s]\n"
600 | ]
601 | },
602 | {
603 | "name": "stdout",
604 | "output_type": "stream",
605 | "text": [
606 | "Epoch 02, Loss: 0.8775\n"
607 | ]
608 | },
609 | {
610 | "name": "stderr",
611 | "output_type": "stream",
612 | "text": [
613 | "Evaluating: 100%|██████████| 465930/465930 [00:21<00:00, 21307.97it/s]\n"
614 | ]
615 | },
616 | {
617 | "name": "stdout",
618 | "output_type": "stream",
619 | "text": [
620 | "Train: 0.9482, Val: 0.9441, Test: 0.9448\n"
621 | ]
622 | },
623 | {
624 | "name": "stderr",
625 | "output_type": "stream",
626 | "text": [
627 | "Epoch 03: 100%|██████████| 153431/153431 [00:06<00:00, 24169.32it/s]\n"
628 | ]
629 | },
630 | {
631 | "name": "stdout",
632 | "output_type": "stream",
633 | "text": [
634 | "Epoch 03, Loss: 0.9626\n"
635 | ]
636 | },
637 | {
638 | "name": "stderr",
639 | "output_type": "stream",
640 | "text": [
641 | "Evaluating: 100%|██████████| 465930/465930 [00:21<00:00, 21296.08it/s]\n"
642 | ]
643 | },
644 | {
645 | "name": "stdout",
646 | "output_type": "stream",
647 | "text": [
648 | "Train: 0.9485, Val: 0.9454, Test: 0.9434\n"
649 | ]
650 | },
651 | {
652 | "name": "stderr",
653 | "output_type": "stream",
654 | "text": [
655 | "Epoch 04: 100%|██████████| 153431/153431 [00:06<00:00, 23724.96it/s]\n"
656 | ]
657 | },
658 | {
659 | "name": "stdout",
660 | "output_type": "stream",
661 | "text": [
662 | "Epoch 04, Loss: 0.9915\n"
663 | ]
664 | },
665 | {
666 | "name": "stderr",
667 | "output_type": "stream",
668 | "text": [
669 | "Evaluating: 100%|██████████| 465930/465930 [00:22<00:00, 21157.18it/s]\n"
670 | ]
671 | },
672 | {
673 | "name": "stdout",
674 | "output_type": "stream",
675 | "text": [
676 | "Train: 0.9507, Val: 0.9464, Test: 0.9452\n"
677 | ]
678 | },
679 | {
680 | "name": "stderr",
681 | "output_type": "stream",
682 | "text": [
683 | "Epoch 05: 100%|██████████| 153431/153431 [00:06<00:00, 23812.61it/s]\n"
684 | ]
685 | },
686 | {
687 | "name": "stdout",
688 | "output_type": "stream",
689 | "text": [
690 | "Epoch 05, Loss: 0.9560\n"
691 | ]
692 | },
693 | {
694 | "name": "stderr",
695 | "output_type": "stream",
696 | "text": [
697 | "Evaluating: 100%|██████████| 465930/465930 [00:22<00:00, 21124.83it/s]\n"
698 | ]
699 | },
700 | {
701 | "name": "stdout",
702 | "output_type": "stream",
703 | "text": [
704 | "Train: 0.9538, Val: 0.9472, Test: 0.9475\n"
705 | ]
706 | }
707 | ],
708 | "source": [
709 | "path = os.path.join(\"..\", \"tmp\", \"data\", \"Reddit\")\n",
710 | "dataset = Reddit(path)\n",
711 | "data = dataset[0]\n",
712 | "\n",
713 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
714 | "model = SAGE(dataset.num_features, 256, dataset.num_classes)\n",
715 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
716 | "\n",
717 | "trainer = Trainer(\n",
718 | " model=model.to(device),\n",
719 | " x=data.x.to(device),\n",
720 | " y=data.y.squeeze().to(device),\n",
721 | " train_mask=data.train_mask,\n",
722 | " val_mask=data.val_mask,\n",
723 | " test_mask=data.test_mask,\n",
724 | " optimizer=optimizer,\n",
725 | ")\n",
726 | "trainer.fit(num_epochs=5)\n",
727 | "\n",
728 | "torch.cuda.empty_cache()"
729 | ]
730 | },
731 | {
732 | "cell_type": "markdown",
733 | "id": "575bf948",
734 | "metadata": {},
735 | "source": [
736 | "# Prediction Heads: Node-level"
737 | ]
738 | },
739 | {
740 | "cell_type": "markdown",
741 | "id": "6cc46b18",
742 | "metadata": {},
743 | "source": [
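744 | "> See the previous section (Node Neighborhood Sampling): the GraphSAGE model there ends in a node-level classification head."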
745 | ]
746 | },
747 | {
748 | "cell_type": "markdown",
749 | "id": "a8a36ed8",
750 | "metadata": {},
751 | "source": [
752 | "# Prediction Heads: Edge-level"
753 | ]
754 | },
755 | {
756 | "cell_type": "markdown",
757 | "id": "e4d3422f",
758 | "metadata": {},
759 | "source": [
760 | "> The following part is an adaptation of a PyTorch Geometric [example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py)."
761 | ]
762 | },
763 | {
764 | "cell_type": "code",
765 | "execution_count": 17,
766 | "id": "490f4a4b",
767 | "metadata": {},
768 | "outputs": [],
769 | "source": [
770 | "from torch_geometric.datasets import Planetoid\n",
771 | "from torch_geometric.nn import GCNConv\n",
772 | "from torch_geometric.utils import negative_sampling"
773 | ]
774 | },
775 | {
776 | "cell_type": "code",
777 | "execution_count": 18,
778 | "id": "4302ba04",
779 | "metadata": {},
780 | "outputs": [],
781 | "source": [
782 | "class GCN(torch.nn.Module):\n",
783 | " def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, use_dot: bool = True) -> None:\n",
784 | " super().__init__()\n",
785 | " self.conv1 = GCNConv(in_channels, hidden_channels)\n",
786 | " self.conv2 = GCNConv(hidden_channels, out_channels)\n",
787 | " self.use_dot = use_dot\n",
788 | " if not use_dot:\n",
789 | " self.linear = torch.nn.Linear(out_channels * 2, 1)\n",
790 | "\n",
791 | " def encode(self, x: torch.Tensor, edge_index) -> torch.Tensor:\n",
792 | " x = self.conv1(x, edge_index).relu()\n",
793 | " return self.conv2(x, edge_index)\n",
794 | "\n",
795 | " def decode(self, z: torch.Tensor, edge_label_index) -> torch.Tensor:\n",
796 | " if self.use_dot:\n",
797 | " return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)\n",
798 | " else:\n",
799 | " z = torch.hstack([z[edge_label_index[0]], z[edge_label_index[1]]])\n",
800 | " return self.linear(z)"
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": 19,
806 | "id": "527bdd2d",
807 | "metadata": {},
808 | "outputs": [],
809 | "source": [
810 | "class Trainer:\n",
811 | " def __init__(\n",
812 | " self,\n",
813 | " model: torch.nn.Module,\n",
814 | " train_data: Data,\n",
815 | " val_data: Data,\n",
816 | " test_data: Data,\n",
817 | " optimizer: torch.optim.Optimizer,\n",
818 | " ) -> None:\n",
819 | " self.model = model\n",
820 | " self.train_data = train_data\n",
821 | " self.val_data = val_data\n",
822 | " self.test_data = test_data\n",
823 | " self.optimizer = optimizer\n",
824 | "\n",
825 | " def training_epoch(self) -> torch.Tensor:\n",
826 | " self.model.train()\n",
827 | " self.optimizer.zero_grad()\n",
828 | " z = model.encode(self.train_data.x, self.train_data.edge_index)\n",
829 | "\n",
830 | " # We perform a new round of negative sampling for every training epoch:\n",
831 | " neg_edge_index = negative_sampling(\n",
832 | " edge_index=self.train_data.edge_index,\n",
833 | " num_nodes=self.train_data.num_nodes,\n",
834 | " num_neg_samples=self.train_data.edge_label_index.size(1),\n",
835 | " method=\"sparse\",\n",
836 | " )\n",
837 | "\n",
838 | " edge_label_index = torch.cat(\n",
839 | " [train_data.edge_label_index, neg_edge_index],\n",
840 | " dim=-1,\n",
841 | " )\n",
842 | " edge_label = torch.cat([train_data.edge_label, train_data.edge_label.new_zeros(neg_edge_index.size(1))], dim=0)\n",
843 | "\n",
844 | " out = model.decode(z, edge_label_index).view(-1)\n",
845 | " loss = F.binary_cross_entropy_with_logits(out, edge_label)\n",
846 | " loss.backward()\n",
847 | " optimizer.step()\n",
848 | " return loss\n",
849 | "\n",
850 | " @torch.no_grad()\n",
851 | " def evaluate(self) -> List[float]:\n",
852 | " model.eval()\n",
853 | " results = []\n",
854 | " for data in [self.val_data, self.test_data]:\n",
855 | " z = model.encode(data.x, data.edge_index)\n",
856 | " out = model.decode(z, data.edge_label_index).view(-1).sigmoid()\n",
857 | " auc = roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())\n",
858 | " results.append(auc)\n",
859 | "\n",
860 | " return results\n",
861 | "\n",
862 | " def fit(self, num_epochs: int = 100, print_interval: int = 1) -> None:\n",
863 | " for epoch in range(1, num_epochs + 1):\n",
864 | " loss = self.training_epoch()\n",
865 | " val_auc, test_auc = self.evaluate()\n",
866 | " if epoch % print_interval == 0:\n",
867 | " print(f\"Epoch {epoch:02d}, Loss: {loss:.4f}\")\n",
868 | " print(f\"Val: {val_auc:.4f}, Test: {test_auc:.4f}\")"
869 | ]
870 | },
871 | {
872 | "cell_type": "markdown",
873 | "id": "84995c9b",
874 | "metadata": {},
875 | "source": [
876 | "## (1) Concatenation + Linear"
877 | ]
878 | },
879 | {
880 | "cell_type": "code",
881 | "execution_count": 20,
882 | "id": "1abeaa84",
883 | "metadata": {},
884 | "outputs": [
885 | {
886 | "name": "stdout",
887 | "output_type": "stream",
888 | "text": [
889 | "Epoch 20, Loss: 0.5618\n",
890 | "Val: 0.6947, Test: 0.7304\n",
891 | "Epoch 40, Loss: 0.5206\n",
892 | "Val: 0.7169, Test: 0.7433\n",
893 | "Epoch 60, Loss: 0.4923\n",
894 | "Val: 0.7146, Test: 0.7367\n",
895 | "Epoch 80, Loss: 0.5081\n",
896 | "Val: 0.7067, Test: 0.7336\n",
897 | "Epoch 100, Loss: 0.5037\n",
898 | "Val: 0.7050, Test: 0.7331\n"
899 | ]
900 | }
901 | ],
902 | "source": [
903 | "USE_DOT = False\n",
904 | "\n",
905 | "\n",
906 | "device = torch.device(\"cpu\")\n",
907 | "# See https://github.com/pyg-team/pytorch_geometric/issues/3641\n",
908 | "\n",
909 | "transform = T.Compose(\n",
910 | " [\n",
911 | " T.NormalizeFeatures(),\n",
912 | " T.ToDevice(device),\n",
913 | " T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False),\n",
914 | " ]\n",
915 | ")\n",
916 | "\n",
917 | "path = os.path.join(\"..\", \"tmp\", \"data\", \"Planetoid\")\n",
918 | "dataset = Planetoid(path, name=\"Cora\", transform=transform)\n",
919 | "# After applying the `RandomLinkSplit` transform, the data is transformed from\n",
920 | "# a data object to a list of tuples (train_data, val_data, test_data), with\n",
921 | "# each element representing the corresponding split.\n",
922 | "train_data, val_data, test_data = dataset[0]\n",
923 | "\n",
924 | "model = GCN(dataset.num_features, 128, 64, use_dot=USE_DOT).to(device)\n",
925 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
926 | "\n",
927 | "trainer = Trainer(\n",
928 | " model=model.to(device),\n",
929 | " train_data=train_data,\n",
930 | " val_data=val_data,\n",
931 | " test_data=test_data,\n",
932 | " optimizer=optimizer,\n",
933 | ")\n",
934 | "trainer.fit(print_interval=20)"
935 | ]
936 | },
937 | {
938 | "cell_type": "markdown",
939 | "id": "bdbaebc2",
940 | "metadata": {},
941 | "source": [
942 | "## (2) Dot product"
943 | ]
944 | },
945 | {
946 | "cell_type": "code",
947 | "execution_count": 21,
948 | "id": "3114da30",
949 | "metadata": {
950 | "scrolled": false
951 | },
952 | "outputs": [
953 | {
954 | "name": "stdout",
955 | "output_type": "stream",
956 | "text": [
957 | "Epoch 20, Loss: 0.6341\n",
958 | "Val: 0.7855, Test: 0.7926\n",
959 | "Epoch 40, Loss: 0.5270\n",
960 | "Val: 0.8369, Test: 0.8390\n",
961 | "Epoch 60, Loss: 0.4692\n",
962 | "Val: 0.8991, Test: 0.8986\n",
963 | "Epoch 80, Loss: 0.4549\n",
964 | "Val: 0.9197, Test: 0.9046\n",
965 | "Epoch 100, Loss: 0.4411\n",
966 | "Val: 0.9274, Test: 0.9039\n"
967 | ]
968 | }
969 | ],
970 | "source": [
971 | "USE_DOT = True\n",
972 | "\n",
973 | "\n",
974 | "device = torch.device(\"cpu\")\n",
975 | "# See https://github.com/pyg-team/pytorch_geometric/issues/3641\n",
976 | "\n",
977 | "transform = T.Compose(\n",
978 | " [\n",
979 | " T.NormalizeFeatures(),\n",
980 | " T.ToDevice(device),\n",
981 | " T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False),\n",
982 | " ]\n",
983 | ")\n",
984 | "\n",
985 | "path = os.path.join(\"..\", \"tmp\", \"data\", \"Planetoid\")\n",
986 | "dataset = Planetoid(path, name=\"Cora\", transform=transform)\n",
987 | "# After applying the `RandomLinkSplit` transform, the data is transformed from\n",
988 | "# a data object to a list of tuples (train_data, val_data, test_data), with\n",
989 | "# each element representing the corresponding split.\n",
990 | "train_data, val_data, test_data = dataset[0]\n",
991 | "\n",
992 | "model = GCN(dataset.num_features, 128, 64, use_dot=USE_DOT).to(device)\n",
993 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
994 | "\n",
995 | "trainer = Trainer(\n",
996 | " model=model.to(device),\n",
997 | " train_data=train_data,\n",
998 | " val_data=val_data,\n",
999 | " test_data=test_data,\n",
1000 | " optimizer=optimizer,\n",
1001 | ")\n",
1002 | "trainer.fit(print_interval=20)"
1003 | ]
1004 | },
1005 | {
1006 | "cell_type": "markdown",
1007 | "id": "d011895b",
1008 | "metadata": {},
1009 | "source": [
1010 | "# Prediction Heads: Graph-level"
1011 | ]
1012 | },
1013 | {
1014 | "cell_type": "markdown",
1015 | "id": "ba875e25",
1016 | "metadata": {},
1017 | "source": [
1018 | "## (1) Global mean pooling"
1019 | ]
1020 | },
1021 | {
1022 | "cell_type": "markdown",
1023 | "id": "6659379b",
1024 | "metadata": {},
1025 | "source": [
1026 | "TODO, see [global_mean_pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=global%20pooling#torch_geometric.nn.glob.global_mean_pool)"
1027 | ]
1028 | },
1029 | {
1030 | "cell_type": "markdown",
1031 | "id": "5bc7769b",
1032 | "metadata": {},
1033 | "source": [
1034 | "## (2) Global max pooling"
1035 | ]
1036 | },
1037 | {
1038 | "cell_type": "markdown",
1039 | "id": "2e7c2dbe",
1040 | "metadata": {},
1041 | "source": [
1042 | "TODO, see [global_max_pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=global%20pooling#torch_geometric.nn.glob.global_max_pool)"
1043 | ]
1044 | },
1045 | {
1046 | "cell_type": "markdown",
1047 | "id": "87b09c4e",
1048 | "metadata": {},
1049 | "source": [
1050 | "## (3) Global sum pooling"
1051 | ]
1052 | },
1053 | {
1054 | "cell_type": "markdown",
1055 | "id": "3e677b16",
1056 | "metadata": {},
1057 | "source": [
1058 | "TODO, see [global_add_pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=global%20pooling#torch_geometric.nn.glob.global_add_pool)"
1059 | ]
1060 | },
1061 | {
1062 | "cell_type": "markdown",
1063 | "id": "fe4f9d37",
1064 | "metadata": {},
1065 | "source": [
1066 | "# Hierarchical Global Pooling"
1067 | ]
1068 | },
1069 | {
1070 | "cell_type": "markdown",
1071 | "id": "086d0ec4",
1072 | "metadata": {},
1073 | "source": [
1074 | "TODO, see [pooling-layers](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#pooling-layers)"
1075 | ]
1076 | }
1077 | ],
1078 | "metadata": {
1079 | "kernelspec": {
1080 | "display_name": "Python 3",
1081 | "language": "python",
1082 | "name": "python3"
1083 | },
1084 | "language_info": {
1085 | "codemirror_mode": {
1086 | "name": "ipython",
1087 | "version": 3
1088 | },
1089 | "file_extension": ".py",
1090 | "mimetype": "text/x-python",
1091 | "name": "python",
1092 | "nbconvert_exporter": "python",
1093 | "pygments_lexer": "ipython3",
1094 | "version": "3.7.12"
1095 | }
1096 | },
1097 | "nbformat": 4,
1098 | "nbformat_minor": 5
1099 | }
1100 |
--------------------------------------------------------------------------------