├── 0_intro
│   └── README.md
├── 1_reasoning
│   ├── README.md
│   ├── Reasoning_tutorial.ipynb
│   ├── Reasoning_tutorial_solution.ipynb
│   ├── ins_sort.png
│   └── mul.png
├── 2_reinforcement_learning
│   ├── README.md
│   ├── RL_tutorial.ipynb
│   └── RL_tutorial_solution.ipynb
├── 3_vision_language_models
│   ├── README.md
│   ├── VLM_tutorial.ipynb
│   └── VLM_tutorial_solution.ipynb
├── 4_geometric_deep_learning
│   ├── GDL_tutorial.ipynb
│   ├── GDL_tutorial_solution.ipynb
│   └── README.md
└── README.md
/0_intro/README.md:
--------------------------------------------------------------------------------
1 | # [[EEML2024](https://www.eeml.eu)] Tutorial 0: Introduction to Colab and PyTorch
2 |
3 | **Authors:** Mandana Samiei and Teodor Szente
4 |
5 | ---
6 |
7 | In this tutorial, we will learn how to use Colab notebooks and how to train a simple model in PyTorch.
8 |
9 | ### Outline
10 |
11 | - [Tutorial 1] Intro to Colab
12 | - [Tutorial 2] Intro to PyTorch (MLP)
13 | - [Tutorial 3] Advanced PyTorch (CNN)
14 |
15 |
16 | ### Notebooks
17 |
18 | Tutorial 1: [Open in Colab](https://colab.research.google.com/drive/1ocFPgLnE1X7YypI35x835AVhAj-gZulB?usp=sharing)
19 |
20 | Tutorial 2: [Open in Colab](https://colab.research.google.com/drive/1B8JQGsoXJcZQ9n4fLivyOQTo27LY1ALd?usp=sharing)
21 |
22 | Tutorial 3: [Open in Colab](https://colab.research.google.com/drive/1ekLHmSVpuHaFHscZL_CrvnJgNaQpKTbi?usp=sharing)
23 |
24 | ---
25 |
--------------------------------------------------------------------------------
/1_reasoning/README.md:
--------------------------------------------------------------------------------
1 | # [[EEML2024](https://www.eeml.eu)] Tutorial 1: Reasoning
2 |
3 | **Authors:** Petar Veličković and Matko Bošnjak
4 |
5 | ---
6 |
7 | In this tutorial, we will explore the wonderful and challenging domain of algorithmic reasoning 🔢 with deep neural networks 🤖.
8 |
9 |
10 | [Introduction video](https://www.youtube.com/watch?v=CyIuM5eQZ5A)
11 |
12 |
13 | ### Outline
14 |
15 | - What even _is_ reasoning?
16 | - Setup and installation of necessary Python libraries.
17 | - Practical 1: How can we train a model to robustly execute computation?
18 | - Practical 2: Exploration of model / algorithm variations
19 | - Practical 3: Algorithmically tuning a large language model
20 |
21 |
22 | ### Notebooks
23 |
24 | Tutorial: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/1_reasoning/Reasoning_tutorial.ipynb)
26 |
27 |
28 | Solved: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/1_reasoning/Reasoning_tutorial_solution.ipynb)
30 |
31 | ---
32 |
--------------------------------------------------------------------------------
/1_reasoning/ins_sort.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eemlcommunity/PracticalSessions2024/ec7d5b1aa9869060e4cd222525f54bec22a24024/1_reasoning/ins_sort.png
--------------------------------------------------------------------------------
/1_reasoning/mul.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eemlcommunity/PracticalSessions2024/ec7d5b1aa9869060e4cd222525f54bec22a24024/1_reasoning/mul.png
--------------------------------------------------------------------------------
/2_reinforcement_learning/README.md:
--------------------------------------------------------------------------------
1 | # [[EEML2024](https://www.eeml.eu)] Tutorial 2: Reinforcement Learning
2 |
3 | **Authors:** Andreea Deac and Ognjen Milinković
4 |
5 | ---
6 |
7 | This tutorial will explore policy-based reinforcement learning agents, starting from a random agent in a grid-like environment and finishing with fine-tuning a pre-trained language model in the style of RLHF.
8 |
9 |
10 | [Introduction video](https://www.youtube.com/watch?v=nCKXzXrdWYo)
11 |
12 |
13 | ### Outline
14 |
15 | - Intro to policy-based RL
16 | - Setup and installation of necessary Python libraries.
17 | - Part I: Training a simple RL agent using REINFORCE.
18 | - Part II: Building more complex agents: values, advantages, entropy -- A2C and PPO.
19 | - Part III: Fine-tune GPT2 to give more positive reviews.
20 |
21 |
22 | ### Notebooks
23 |
24 | Tutorial: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/2_reinforcement_learning/RL_tutorial.ipynb)
26 |
27 |
28 | Solution: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/2_reinforcement_learning/RL_tutorial_solution.ipynb)
30 |
31 | ---
32 |
--------------------------------------------------------------------------------
/3_vision_language_models/README.md:
--------------------------------------------------------------------------------
1 | # [[EEML2024](https://www.eeml.eu)] Tutorial 3: Vision-Language Models
2 |
3 | **Authors:** Aishwarya Kamath, Anastasia Ilić and Ioana Bica
4 |
5 | ---
6 |
7 | In this tutorial we'll explore how we can use image-text data to build Vision Language Models 🚀. We'll start with an introduction to multimodal understanding that describes the main components of a Vision Language Model and provides a brief history of how these models have evolved in recent years. Then, we'll dive deep into Contrastive Language-Image Pre-training (CLIP), a model for learning general representations from image-text pairs that can be used for a wide range of downstream tasks. We'll then explore how CLIP can be used for semantic image search, followed by a showcase of its failure cases. Finally, we'll fine-tune PaliGemma, a powerful 3B-parameter vision language model, together.
8 |
9 |
10 | [Introduction video](https://www.youtube.com/watch?v=zdejKiH06CU)
11 |
12 |
13 | ### Outline
14 |
15 | - Part I: Introduction to multimodal understanding.
16 | - Part II: Contrastive Language-Image Pre-training (CLIP).
17 | - Part III: PaliGemma.
18 |
19 |
20 | ### Notebooks
21 |
22 | Tutorial: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/3_vision_language_models/VLM_tutorial.ipynb)
24 |
25 | Solution: [Open in Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/3_vision_language_models/VLM_tutorial_solution.ipynb)
27 |
28 | ---
29 |
--------------------------------------------------------------------------------
/4_geometric_deep_learning/GDL_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "\n",
7 | "\n"
8 | ],
9 | "metadata": {
10 | "id": "a-KY9hPOj7Jz"
11 | }
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "source": [
16 | "**Here are the authors.** Do reach out to us for any queries or feedback!\n",
17 | "\n",
18 | "* Viktor Mirjanic (vvm22@cam.ac.uk) \\\\\n",
19 | "* Iulia Duta (id366@cam.ac.uk)\n",
20 | "\n",
21 | "\n",
22 | "**Abstract:** This tutorial is designed to be a stand-alone introduction to the world of Graph Neural Networks. You will learn the basics of working with graph data, implement a standard Graph Neural Network architecture, and learn about current challenges and open research problems in the field, such as rewiring and positional encoding.\n",
23 | "\n",
24 | "\n",
25 | "**Outline:**\n",
26 | "\n",
27 | "The tutorial is structured as follows:\n",
28 | "1. Implement both a sparse and a dense version of a **Graph Convolutional Network in PyTorch**.\n",
29 | "2. Write a training pipeline for graph inputs, including **graph-level representation** and **custom mini-batching**.\n",
30 | "3. Improve the Graph Convolutional Network using attention mechanisms: **Graph Attention Network**.\n",
31 | "4. Take our first steps into **PyTorch Geometric**, a library dedicated to geometric deep learning.\n",
32 | "5. Re-implement **Graph Attention Network in PyTorch Geometric**.\n",
33 | "6. Understand the **over-squashing challenge** and experiment with **two graph rewiring techniques**: Graph Transformer and Expander Graph Propagation.\n",
34 | "7. Explore various **positional encodings** for graph data.\n"
35 | ],
36 | "metadata": {
37 | "id": "5KgbGnPp-gkM"
38 | }
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "source": [
43 | " ❗ Note: While a GPU is not mandatory for this tutorial, we recommend using one to speed up training. You can enable it by clicking `Runtime -> Change runtime type` and setting the hardware accelerator to GPU."
44 | ],
45 | "metadata": {
46 | "id": "G8LYCIB0RJtc"
47 | }
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "source": [
52 | "## 😴 Preliminaries: Install, import and other modules"
53 | ],
54 | "metadata": {
55 | "id": "lKxFuNjpAeVz"
56 | }
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "source": [
62 | "# @title [RUN] Install required python libraries\n",
63 | "!pip install -q networkx\n",
64 | "!pip install -q mycolorpy\n",
65 | "!pip install -q colorama\n",
66 | "!pip install torch==2.1.0\n",
67 | "!pip install -q torch-geometric\n",
68 | "\n",
69 | "import os\n",
70 | "import torch\n",
71 | "os.environ['TORCH'] = torch.__version__\n",
72 | "!pip install -q pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
73 | "\n",
74 | "!sudo apt-get install -qq graphviz graphviz-dev\n",
75 | "!pip install -q pygraphviz\n"
76 | ],
77 | "outputs": [],
78 | "metadata": {
79 | "cellView": "form",
80 | "id": "u-lbFF3r08Um"
81 | }
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "source": [
87 | "# @title [RUN] Import modules\n",
88 | "import sys\n",
89 | "import time\n",
90 | "import math\n",
91 | "import random\n",
92 | "import itertools\n",
93 | "from typing import Mapping, Tuple, Sequence, List\n",
94 | "\n",
95 | "import pandas as pd\n",
96 | "import networkx as nx\n",
97 | "import numpy as np\n",
98 | "import scipy as sp\n",
99 | "from scipy.stats import ortho_group\n",
100 | "from scipy.linalg import block_diag\n",
101 | "\n",
102 | "import pdb\n",
103 | "\n",
104 | "import torch\n",
105 | "import torch.nn.functional as F\n",
106 | "from torch.optim import Adam\n",
107 | "from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, Module, ModuleList, Sequential, Parameter, LayerNorm\n",
108 | "from torch import Tensor\n",
109 | "from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum\n",
110 | "\n",
111 | "import torch_geometric\n",
112 | "from torch_geometric.data import Data, Batch\n",
113 | "from torch_geometric.loader import DataLoader\n",
114 | "import torch_geometric.transforms as T\n",
115 | "from torch_geometric.transforms import BaseTransform\n",
116 | "from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, softmax, get_laplacian, cumsum, add_self_loops\n",
117 | "from torch_geometric.utils.convert import to_scipy_sparse_matrix\n",
118 | "from torch_geometric.nn import MessagePassing, global_mean_pool, global_add_pool\n",
119 | "# from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum\n",
120 | "\n",
121 | "from torch_geometric.nn.inits import glorot\n",
122 | "\n",
123 | "from sklearn.metrics import accuracy_score, roc_auc_score\n",
124 | "\n",
125 | "import matplotlib.pyplot as plt\n",
126 | "import seaborn as sns\n",
127 | "import matplotlib.cm as cm\n",
128 | "\n",
129 | "from google.colab import files\n",
130 | "from IPython.display import HTML\n",
131 | "\n",
132 | "from numpy.linalg import eig, eigh\n",
133 | "\n",
134 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
135 | ],
136 | "outputs": [],
137 | "metadata": {
138 | "cellView": "form",
139 | "id": "R7Xw6Me_NRQ9"
140 | }
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "source": [
146 | "# @title [RUN] Helper functions for plots and visualisations\n",
147 | "\n",
148 | "\n",
149 | "def draw_one_graph(ax, edges, label=None, node_emb=None, layout=None,\n",
150 | " special_color=False, pos=None):\n",
151 | " \"\"\"draw a graph with networkx from an edge list (edges, shape 2 x num_edges)\n",
152 | " graph labels could be displayed as a title for each graph\n",
153 | " node_emb could be displayed in colors\n",
154 | " \"\"\"\n",
155 | " graph = nx.Graph()\n",
156 | " edges = zip(edges[0], edges[1])\n",
157 | " graph.add_edges_from(edges)\n",
158 | " if layout == 'custom':\n",
159 | " node_pos = pos\n",
160 | " elif layout == 'tree':\n",
161 | " node_pos=nx.nx_agraph.graphviz_layout(graph, prog='dot')\n",
162 | " else:\n",
163 | " node_pos = layout(graph)\n",
164 | " #add colors according to node embedding\n",
165 | " if (node_emb is not None) or special_color:\n",
166 | " color_map = []\n",
167 | " node_list = [node[0] for node in graph.nodes(data = True)]\n",
168 | " for i,node in enumerate(node_list):\n",
169 | " #just ignore this branch\n",
170 | " if special_color:\n",
171 | " if len(node_list) == 3:\n",
172 | " crt_color = (1,0,0)\n",
173 | " elif len(node_list) == 5:\n",
174 | " crt_color = (0,1,0)\n",
175 | " elif len(node_list) == 4:\n",
176 | " crt_color = (1,1,0)\n",
177 | " else:\n",
178 | " special_list = [(1,0,0)] * 3 + [(0,1,0)] * 5 + [(1,1,0)] * 4\n",
179 | " crt_color = special_list[i]\n",
180 | " else:\n",
181 | " crt_node_emb = node_emb[node]\n",
182 | " #map float number (node embedding) to a color\n",
183 | " crt_color = cm.gist_rainbow(crt_node_emb, bytes=True)\n",
184 | " crt_color = (crt_color[0]/255.0, crt_color[1]/255.0, crt_color[2]/255.0, crt_color[3]/255.0)\n",
185 | " color_map.append(crt_color)\n",
186 | "\n",
187 | " nx.draw_networkx_nodes(graph,node_pos, node_color=color_map,\n",
188 | " nodelist = node_list, ax=ax)\n",
189 | " nx.draw_networkx_edges(graph, node_pos, ax=ax)\n",
190 | " nx.draw_networkx_labels(graph,node_pos, ax=ax)\n",
191 | " else:\n",
192 | " nx.draw_networkx(graph, node_pos, ax=ax)\n",
193 | "\n",
194 | "def gallery(graphs, labels=None, node_emb=None, special_color=False, max_graphs=4, max_fig_size=(40, 10), layout=nx.layout.kamada_kawai_layout):\n",
195 | " ''' Draw multiple graphs as a gallery\n",
196 | " Args:\n",
197 | " graphs: torch_geometric dataset or list of Data objects\n",
198 | " labels: num_graphs\n",
199 | " node_emb: num_graphs* [num_nodes x num_ch]\n",
200 | " max_graphs: maximum graphs display\n",
201 | " '''\n",
202 | " num_graphs = min(len(graphs), max_graphs)\n",
203 | " ff, axes = plt.subplots(1, num_graphs,\n",
204 | " figsize=max_fig_size,\n",
205 | " subplot_kw={'xticks': [], 'yticks': []})\n",
206 | " if num_graphs == 1:\n",
207 | " axes = [axes]\n",
208 | " if node_emb is None:\n",
209 | " node_emb = num_graphs*[None]\n",
210 | " if labels is None:\n",
211 | " labels = num_graphs * [\" \"]\n",
212 | "\n",
213 | "\n",
214 | " for i in range(num_graphs):\n",
215 | " draw_one_graph(axes[i], graphs[i].edge_index.numpy(), labels[i], node_emb[i], layout, special_color,\n",
216 | " pos=graphs[i].pos)\n",
217 | " if labels[i] != \" \":\n",
218 | " axes[i].set_title(f\"Target: {labels[i]}\", fontsize=28)\n",
219 | " axes[i].set_axis_off()\n",
220 | " plt.show()\n",
221 | "\n",
222 | "def draw_one_tree(data):\n",
223 | " \"\"\"draw a tree with networkx\n",
224 | " this function is only used to display the tree\n",
225 | " objects from the synthetic datasets\n",
226 | " \"\"\"\n",
227 | " edges = data.edge_index.detach().cpu().numpy()\n",
228 | " feats = data.x.detach().cpu().numpy()\n",
229 | "\n",
230 | " graph = nx.DiGraph()\n",
231 | " edges = zip(edges[0], edges[1])\n",
232 | " graph.add_edges_from(edges)\n",
233 | " graph = graph.reverse()\n",
234 | "\n",
235 | " node_pos=nx.nx_agraph.graphviz_layout(graph, prog='dot')\n",
236 | "\n",
237 | " labels = dict()\n",
238 | " for i,node in enumerate(graph.nodes):\n",
239 | " labels[node] = f\"{feats[node,0]}, {feats[node,1]}\"\n",
240 | "\n",
241 | "\n",
242 | " colors = [0.5 for node in graph.nodes()]\n",
243 | " for i,node in enumerate(graph.nodes):\n",
244 | " if feats[node,0] != 0 and feats[node,1] == 0:\n",
245 | " colors[i] = 0.25\n",
246 | " #add colors according to node embedding\n",
247 | " nx.draw_networkx(graph, node_pos, node_size=500, labels=labels,\n",
248 | " cmap=plt.get_cmap('Pastel1'), node_color=colors)\n",
249 | "\n",
250 | "\n",
251 | "\n",
252 | "def plot_results(perf_per_epoch):\n",
253 | " df_results = pd.DataFrame(perf_per_epoch,\n",
254 | " columns=[\"Test AUC\", \"Val AUC\", \"Test loss\",\n",
255 | " \"Val loss\", \"Train loss\", \"Epoch\", \"Model\"])\n",
256 | " p = sns.lineplot(x=\"Epoch\", y=\"Val AUC\", hue=\"Model\", data=df_results)\n",
257 | " plt.show()\n",
258 | " p = sns.lineplot(x=\"Epoch\", y=\"Test AUC\", hue=\"Model\", data=df_results)\n",
259 | " plt.show()\n",
260 | "\n",
261 | "def get_color_coded_str(i, color):\n",
262 | " return \"\\033[3{}m{}\\033[0m\".format(int(color), int(i))\n",
263 | "\n",
264 | "\n",
265 | "def print_color_numpy(map, list_graphs):\n",
266 | " \"\"\" print matrix map in color according to list_graphs\n",
267 | " \"\"\"\n",
268 | " list_blocks = []\n",
269 | " total_num_nodes = 0\n",
270 | " for i,graph in enumerate(list_graphs):\n",
271 | " block_i = (i+1)*np.ones((graph.num_nodes,graph.num_nodes))\n",
272 | " list_blocks += [block_i]\n",
273 | " total_num_nodes += graph.num_nodes\n",
274 | " block_color = block_diag(*list_blocks)\n",
275 | "\n",
276 | " map_modified = np.vectorize(get_color_coded_str)(map, block_color)\n",
277 | "\n",
278 | " colored_matrix = \"\\n\".join([\" \".join([\"{}\"]*total_num_nodes)]*total_num_nodes).format(*[x for y in map_modified.tolist() for x in y])\n",
279 | " print(colored_matrix)\n",
280 | "\n",
281 | "def update_stats(training_stats, epoch_stats):\n",
282 | " \"\"\" Store metrics along the training\n",
283 | " Args:\n",
284 | " epoch_stats: dict containing metrics about one epoch\n",
285 | " training_stats: dict containing lists of metrics along training\n",
286 | " Returns:\n",
287 | " updated training_stats\n",
288 | " \"\"\"\n",
289 | " if training_stats is None:\n",
290 | " training_stats = {}\n",
291 | " for key in epoch_stats.keys():\n",
292 | " training_stats[key] = []\n",
293 | " for key,val in epoch_stats.items():\n",
294 | " training_stats[key].append(val)\n",
295 | " return training_stats\n",
296 | "\n",
297 | "def plot_stats(training_stats, figsize=(5, 5), name=\"\"):\n",
298 | " \"\"\" Create one plot for each metric stored in training_stats\n",
299 | " \"\"\"\n",
300 | " stats_names = [key[6:] for key in training_stats.keys() if key.startswith('train_')]\n",
301 | " f, ax = plt.subplots(len(stats_names), 1, figsize=figsize)\n",
302 | " if len(stats_names)==1:\n",
303 | " ax = np.array([ax])\n",
304 | " for key, axx in zip(stats_names, ax.reshape(-1,)):\n",
305 | " axx.plot(\n",
306 | " training_stats['epoch'],\n",
307 | " training_stats[f'train_{key}'],\n",
308 | " label=f\"Training {key}\")\n",
309 | " axx.plot(\n",
310 | " training_stats['epoch'],\n",
311 | " training_stats[f'val_{key}'],\n",
312 | " label=f\"Validation {key}\")\n",
313 | " axx.set_xlabel(\"Training epoch\")\n",
314 | " axx.set_ylabel(key)\n",
315 | " axx.legend()\n",
316 | " plt.title(name)\n"
317 | ],
318 | "outputs": [],
319 | "metadata": {
320 | "cellView": "form",
321 | "id": "yzE7HwUvoaUv"
322 | }
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "source": [
328 | "# @title [RUN] Helper functions for generating synthetic data\n",
329 | "\n",
330 | "# code based on: https://jhuow.fun/posts/barbell_graph/\n",
331 | "def one_barbell_graph(n_clique, n_path):\n",
332 | " clique1 = nx.complete_graph(n_clique)\n",
333 | " clique1_pos = nx.circular_layout(clique1)\n",
334 | " clique2 = nx.complete_graph(n_clique)\n",
335 | " clique2_mapping = {node: node + n_clique for node in clique2}\n",
336 | " nx.relabel_nodes(clique2, clique2_mapping, copy=False) # avoids repeated nodes\n",
337 | " x_diff, y_diff = 8, -1\n",
338 | " clique2_pos = {node: clique1_pos[node-n_clique] + (x_diff, y_diff) for node in clique2}\n",
339 | " path = nx.path_graph(n_path)\n",
340 | " path_mapping = {node: node + 2 * n_clique for node in path}\n",
341 | " nx.relabel_nodes(path, path_mapping, copy=False) # avoids repeated nodes\n",
342 | " path_nodes = list(path.nodes)\n",
343 | " path_half1_nodes = path_nodes[:n_path//2]\n",
344 | " path_half2_nodes = path_nodes[n_path//2:]\n",
345 | " path_dist = 0.9\n",
346 | " clique2_entry = n_clique + n_clique // 2\n",
347 | " path_half1_pos = {node: clique1_pos[0] + (path_dist + i * path_dist, 0) for i, node in enumerate(path_half1_nodes)}\n",
348 | " path_half2_pos = {node: clique2_pos[clique2_entry] - (path_dist + i * path_dist, 0) for i, node in enumerate(path_half2_nodes[::-1])}\n",
349 | " path_pos = {**path_half1_pos, **path_half2_pos}\n",
350 | " barbell = nx.Graph()\n",
351 | " barbell.add_edges_from(clique1.edges)\n",
352 | " barbell.add_edges_from(clique2.edges)\n",
353 | " barbell.add_edges_from(path.edges)\n",
354 | " barbell.add_edges_from([(path_half1_nodes[0], 0), (path_half2_nodes[-1], clique2_entry)])\n",
355 | " clique_pos = {**clique1_pos, **clique2_pos}\n",
356 | " barbell_pos = {**clique_pos, **path_pos}\n",
357 | " barbell_dict = {}\n",
358 | " for k,v in barbell_pos.items():\n",
359 | " barbell_dict[k] = {\"pos\":v}\n",
360 | " nx.set_node_attributes(barbell, barbell_dict)\n",
361 | " return barbell\n",
362 | "\n",
363 | "# code inspired by https://github.com/lrnzgiusti/on-oversquashing\n",
364 | "def generate_barbell_graph(m1, m2, target_label):\n",
365 | " \"\"\"\n",
366 | " Generate a barbell graph.\n",
367 | "\n",
368 | " Args:\n",
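369 | " - m1: number of nodes in each of the two cliques\n",
370 | " - m2: number of nodes in the path connecting the two cliques\n",
371 | " - target_label: one-hot label vector, stored as the features of the last node\n",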
372 | "\n",
373 | " Returns:\n",
374 | " - Data: Torch geometric data structure containing graph details.\n",
375 | " \"\"\"\n",
376 | " barbell_graph = one_barbell_graph(m1, m2)\n",
377 | " nodes = 2 * m1 + m2\n",
378 | " # Initialize node features. The first node gets 0s, while the last gets the target label\n",
379 | " x = np.ones((nodes, len(target_label)))\n",
380 | " x[0, :] = 0.0\n",
381 | " x[nodes - 1, :] = target_label\n",
382 | " x = torch.tensor(x, dtype=torch.float32)\n",
383 | "\n",
384 | " edge_index = torch.tensor(list(barbell_graph.edges())).T\n",
385 | "\n",
386 | " # Create a mask to indicate the target node (in this case, the first node)\n",
387 | " mask = torch.zeros(nodes, dtype=torch.bool)\n",
388 | " mask[0] = 1\n",
389 | "\n",
390 | " # Convert the one-hot encoded target label to its corresponding class index\n",
391 | " y = torch.tensor([np.argmax(target_label)], dtype=torch.long)\n",
392 | "\n",
393 | " return Data(x=x, edge_index=edge_index, mask=mask, y=y, pos=nx.get_node_attributes(barbell_graph, 'pos'))\n",
394 | "\n",
395 | "\n",
396 | "# The code for generating the tree dataset is based on:\n",
397 | "# https://github.com/tech-srl/bottleneck\n",
398 | "import math\n",
399 | "class TreeDataset(object):\n",
400 | " def __init__(self, depth):\n",
401 | " super(TreeDataset, self).__init__()\n",
402 | " self.depth = depth\n",
403 | " self.num_nodes, self.edges, self.leaf_indices = self._create_blank_tree()\n",
404 | " # self.criterion = F.cross_entropy\n",
405 | "\n",
406 | " def add_child_edges(self, cur_node, max_node):\n",
407 | " edges = []\n",
408 | " leaf_indices = []\n",
409 | " stack = [(cur_node, max_node)]\n",
410 | " while len(stack) > 0:\n",
411 | " cur_node, max_node = stack.pop()\n",
412 | " if cur_node == max_node:\n",
413 | " leaf_indices.append(cur_node)\n",
414 | " continue\n",
415 | " left_child = cur_node + 1\n",
416 | " right_child = cur_node + 1 + ((max_node - cur_node) // 2)\n",
417 | " edges.append([left_child, cur_node])\n",
418 | " edges.append([right_child, cur_node])\n",
419 | " stack.append((right_child, max_node))\n",
420 | " stack.append((left_child, right_child - 1))\n",
421 | " return edges, leaf_indices\n",
422 | "\n",
423 | " def _create_blank_tree(self):\n",
424 | " max_node_id = 2 ** (self.depth + 1) - 2\n",
425 | " edges, leaf_indices = self.add_child_edges(cur_node=0, max_node=max_node_id)\n",
426 | " return max_node_id + 1, edges, leaf_indices\n",
427 | "\n",
428 | " def create_blank_tree(self, add_self_loops=True):\n",
429 | " edge_index = torch.tensor(self.edges).t()\n",
430 | " if add_self_loops:\n",
431 | " edge_index, _ = torch_geometric.utils.add_remaining_self_loops(edge_index=edge_index, )\n",
432 | " return edge_index\n",
433 | "\n",
434 | " def generate_data(self, transform=None, add_self_loops=True):\n",
435 | " data_list = []\n",
436 | "\n",
437 | " for comb in self.get_combinations():\n",
438 | " edge_index = self.create_blank_tree(add_self_loops=add_self_loops)\n",
439 | " nodes = torch.tensor(self.get_nodes_features(comb), dtype=torch.long)\n",
440 | " root_mask = torch.tensor([True] + [False] * (len(nodes) - 1))\n",
441 | " label = self.label(comb)\n",
442 | " graph = Data(x=nodes, edge_index=edge_index, root_mask=root_mask, y=label)\n",
443 | " if transform:\n",
444 | " graph = transform(graph)\n",
445 | " data_list.append(graph)\n",
446 | "\n",
447 | " dim0, out_dim = self.get_dims()\n",
448 | " return data_list, dim0, out_dim\n",
449 | " # X_train, X_test = train_test_split(\n",
450 | " # data_list, train_size=train_fraction, shuffle=True, stratify=[data.y for data in data_list])\n",
451 | "\n",
452 | "\n",
453 | " # return X_train, X_test, dim0, out_dim, self.criterion\n",
454 | "\n",
455 | " # Every sub-class should implement the following methods:\n",
456 | " def get_combinations(self):\n",
457 | " raise NotImplementedError\n",
458 | "\n",
459 | " def get_nodes_features(self, combination):\n",
460 | " raise NotImplementedError\n",
461 | "\n",
462 | " def label(self, combination):\n",
463 | " raise NotImplementedError\n",
464 | "\n",
465 | " def get_dims(self):\n",
466 | " raise NotImplementedError\n",
467 | "\n",
468 | "class DictionaryLookupDataset(TreeDataset):\n",
469 | " def __init__(self, depth):\n",
470 | " super(DictionaryLookupDataset, self).__init__(depth)\n",
471 | "\n",
472 | " def get_combinations(self):\n",
473 | " # returns: an iterable of [key, permutation(leaves)]\n",
474 | " # number of combinations: (num_leaves!)*num_choices\n",
475 | " num_leaves = len(self.leaf_indices)\n",
476 | " num_permutations = 1000\n",
477 | " max_examples = 8000\n",
478 | "\n",
479 | " if self.depth > 3:\n",
480 | " per_depth_num_permutations = min(num_permutations, math.factorial(num_leaves), max_examples // num_leaves)\n",
481 | " permutations = [np.random.permutation(range(1, num_leaves + 1)) for _ in\n",
482 | " range(per_depth_num_permutations)]\n",
483 | " else:\n",
484 | " permutations = random.sample(list(itertools.permutations(range(1, num_leaves + 1))),\n",
485 | " min(num_permutations, math.factorial(num_leaves)))\n",
486 | "\n",
487 | " return itertools.chain.from_iterable(\n",
488 | "\n",
489 | " zip(range(1, num_leaves + 1), itertools.repeat(perm))\n",
490 | " for perm in permutations)\n",
491 | "\n",
492 | " def get_nodes_features(self, combination):\n",
493 | " # combination: a list of indices\n",
494 | " # Each leaf stores a (key, value) pair as integer indices\n",
495 | " # Every other node is empty, for now\n",
496 | " selected_key, values = combination\n",
497 | "\n",
498 | " # The root stores (selected_key, 0): the key to look up and an empty value slot\n",
499 | " nodes = [ (selected_key, 0) ]\n",
500 | "\n",
501 | " for i in range(1, self.num_nodes):\n",
502 | " if i in self.leaf_indices:\n",
503 | " leaf_num = self.leaf_indices.index(i)\n",
504 | " node = (leaf_num+1, values[leaf_num])\n",
505 | " else:\n",
506 | " node = (0, 0)\n",
507 | " nodes.append(node)\n",
508 | " return nodes\n",
509 | "\n",
510 | " def label(self, combination):\n",
511 | " selected_key, values = combination\n",
512 | " return int(values[selected_key - 1])\n",
513 | "\n",
514 | " def get_dims(self):\n",
515 | " # get input and output dims\n",
516 | " in_dim = len(self.leaf_indices)\n",
517 | " out_dim = len(self.leaf_indices)\n",
518 | " return in_dim, out_dim"
519 | ],
520 | "outputs": [],
521 | "metadata": {
522 | "cellView": "form",
523 | "id": "5Q6RyEaZ1IoB"
524 | }
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "source": [
530 | "#@title [RUN] Set random seed for deterministic results\n",
531 | "def seed(seed=0):\n",
532 | " random.seed(seed)\n",
533 | " np.random.seed(seed)\n",
534 | " torch.manual_seed(seed)\n",
535 | " torch.cuda.manual_seed(seed)\n",
536 | " torch.cuda.manual_seed_all(seed)\n",
537 | " torch.backends.cudnn.deterministic = True\n",
538 | " torch.backends.cudnn.benchmark = False\n",
539 | "\n",
540 | "seed(0)"
541 | ],
542 | "outputs": [],
543 | "metadata": {
544 | "cellView": "form",
545 | "id": "KM4ynsEfVTZM"
546 | }
547 | },
548 | {
549 | "cell_type": "markdown",
550 | "source": [
551 | "# 🔨 [Basic] **Towards Implementing our own Graph Neural Network** "
552 | ],
553 | "metadata": {
554 | "id": "97XbufFzYVQO"
555 | }
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "source": [
560 | "\n",
561 | "Graphs are very general data structures, which can represent a wide variety of natural phenomena. Mathematically, a graph is a tuple $G = (V, E)$ where\n",
562 | "- $V$ is a set of vertices or **nodes** representing some entities\n",
563 | "- $E$ is a set of **edges** between nodes representing pairwise relationships between entities.\n",
564 | " \n",
565 | "\n",
566 | "Graphs can be **directed** or **undirected**. A graph is undirected if an edge from $u$ to $v$ exists exactly when a \"backwards\" edge from $v$ to $u$ also exists.\n",
567 | "\n",
568 | "A **neighborhood** of a node $u$ is the set ${N}_u=\\lbrace v\\ | \\left(u,v\\right)\\in E\\rbrace$ of all nodes $v$ directly connected to $u$.\n",
569 | "\n",
570 | "There exist different ways to store the edges of a graph:\n",
571 | "- As an **adjacency matrix** $\\mathbf{A}$, a $|V|\\times|V|$ matrix such that $\\mathbf{A}_{ij}=1$ if there is an edge between the $i$th and $j$th nodes, and $0$ otherwise.\n",
572 | "- As an **edge list** $\\mathbf{E}$, a $2\\times |E|$ matrix storing start and end indices of each edge.\n",
573 | "\n",
574 | "Note that the order of vertices and edges is in principle arbitrary, but we have to decide on one to store the graph in memory.\n",
575 | "\n"
576 | ],
577 | "metadata": {
578 | "id": "Fu7fs6IZYVQg"
579 | }
580 | },
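The two storage formats above can be converted into each other with a few lines of code. The sketch below uses plain Python lists (an assumption made so it runs without PyTorch) and labels nodes from 0; the helper names `edge_list_to_adjacency`, `adjacency_to_edge_list`, `is_undirected` and `neighborhood` are hypothetical. On tensors, PyTorch Geometric's `to_dense_adj` and `dense_to_sparse`, imported above, cover the same conversions.

```python
# Hypothetical helpers converting between an edge list (2 x |E|) and an
# adjacency matrix (|V| x |V|). Nodes are assumed to be labelled 0..num_nodes-1.

def edge_list_to_adjacency(edge_list, num_nodes):
    """Build a 0/1 adjacency matrix from a 2 x |E| edge list."""
    adj = [[0] * num_nodes for _ in range(num_nodes)]
    for u, v in zip(edge_list[0], edge_list[1]):
        adj[u][v] = 1
    return adj

def adjacency_to_edge_list(adj):
    """Recover a 2 x |E| edge list, scanning the matrix row by row."""
    sources, targets = [], []
    for i, row in enumerate(adj):
        for j, entry in enumerate(row):
            if entry:
                sources.append(i)
                targets.append(j)
    return [sources, targets]

def is_undirected(edge_list):
    """Undirected iff every edge (u, v) has a matching backwards edge (v, u)."""
    edges = set(zip(edge_list[0], edge_list[1]))
    return all((v, u) in edges for u, v in edges)

def neighborhood(edge_list, u):
    """N_u = {v | (u, v) in E}."""
    return {v for s, v in zip(edge_list[0], edge_list[1]) if s == u}

# A 3-node path 0 - 1 - 2, each undirected edge stored in both directions.
E = [[0, 1, 1, 2],
     [1, 0, 2, 1]]
A = edge_list_to_adjacency(E, num_nodes=3)

print(A)                          # [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
print(adjacency_to_edge_list(A))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(is_undirected(E))           # True
print(neighborhood(E, 1))         # {0, 2}
```

The edge list only stores existing edges, so it stays small for sparse graphs, while the adjacency matrix always needs $|V|^2$ entries but answers "is there an edge from $i$ to $j$?" with a single lookup.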
581 | {
582 | "cell_type": "markdown",
583 | "source": [
584 | "🔎 Let's look at an example graph:\n",
585 | "\n",
586 | "\n",
587 | "\n",
588 | "_(Image source: Wikipedia)_\n",
589 | "\n",
590 | "It has 6 nodes and 7 edges. However, because it is _undirected_, each edge is stored in both directions, so the edge list holds 14 entries!\n",
591 | "\n",
592 | "\n",
593 | "Its edge list is:\n",
594 | "$$\\mathbf{E}=\\begin{pmatrix}\n",
595 | "1\\ 2\\ 1\\ 5\\ 2\\ 5\\ 2\\ 3\\ 3\\ 4\\ 4\\ 5\\ 4\\ 6 \\\\\n",
596 | "2\\ 1\\ 5\\ 1\\ 5\\ 2\\ 3\\ 2\\ 4\\ 3\\ 5\\ 4\\ 6\\ 4\n",
597 | "\\end{pmatrix}$$\n",
598 | "\n",
599 | "and its adjacency matrix is:\n",
600 | "\n",
601 | "$$\\mathbf{A}=\\begin{pmatrix}\n",
602 | "0\\ {\\color{red}{1}}\\ 0\\ 0\\ {\\color{red}{1}}\\ 0 \\\\\n",
603 | "{\\color{red}{1}}\\ 0\\ {\\color{red}{1}}\\ 0\\ {\\color{red}{1}}\\ 0 \\\\\n",
604 | "0\\ {\\color{red}{1}}\\ 0\\ {\\color{red}{1}}\\ 0\\ 0 \\\\\n",
605 | "0\\ 0\\ {\\color{red}{1}}\\ 0\\ {\\color{red}{1}}\\ {\\color{red}{1}} \\\\\n",
606 | "{\\color{red}{1}}\\ {\\color{red}{1}}\\ 0\\ {\\color{red}{1}}\\ 0\\ 0 \\\\\n",
607 | "0\\ 0\\ 0\\ {\\color{red}{1}}\\ 0\\ 0\n",
608 | "\\end{pmatrix}$$\n",
609 | "\n",
610 | "Note that as the graph is undirected, $\\mathbf{A}$ is symmetric about the main diagonal. Also, the main diagonal of $\\mathbf{A}$ is zero because the graph has no *self-loops* (i.e. edges connecting a node to itself, $u\\to u$).\n",
611 | "\n",
612 | "The neighborhood of node $2$ is $N_2 =\\{1,3,5\\}$."
613 | ],
614 | "metadata": {
615 | "id": "LspRQOdgYVQg"
616 | }
617 | },
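As a sanity check on the numbers above, the short sketch below rebuilds the edge list and adjacency matrix of the example graph from its 7 undirected edges (read off the figure, keeping the 1-based node labels used in the text) and confirms the 14 stored entries, the symmetry, the zero diagonal, and $N_2$. It uses plain Python lists rather than PyTorch, purely as an illustration.

```python
n = 6
# The 7 undirected edges of the example graph, with 1-based node labels.
undirected_edges = [(1, 2), (1, 5), (2, 5), (2, 3), (3, 4), (4, 5), (4, 6)]

# Store every undirected edge in both directions (u -> v and v -> u).
sources, targets = [], []
for u, v in undirected_edges:
    sources += [u, v]
    targets += [v, u]
E = [sources, targets]

# Dense adjacency matrix: node labels are 1-based, list indices are 0-based.
A = [[0] * n for _ in range(n)]
for u, v in zip(sources, targets):
    A[u - 1][v - 1] = 1

symmetric = all(A[i][j] == A[j][i] for i in range(n) for j in range(n))
no_self_loops = all(A[i][i] == 0 for i in range(n))
N_2 = {j + 1 for j in range(n) if A[2 - 1][j] == 1}

print(len(undirected_edges), len(sources))  # 7 14
print(symmetric, no_self_loops)             # True True
print(N_2)                                  # {1, 3, 5}
```

Listing the undirected edges in this order reproduces the columns of $\mathbf{E}$ above exactly.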
618 | {
619 | "cell_type": "markdown",
620 | "source": [
621 | "## Problem Setup"
622 | ],
623 | "metadata": {
624 | "id": "J-OWdKGpYVQh"
625 | }
626 | },
627 | {
628 | "cell_type": "markdown",
629 | "source": [
630 | "The diversity of tasks involving graph data led to the development of three major types of problems:\n",
631 | "\n",
632 | "- **Node prediction**: we require one target for each node. *e.g. what is the topic of a paper given a citation network of papers?*\n",
633 | "- **Link (edge) prediction**: we require one target for each pair of nodes. *e.g. are two people in a social network friends?*\n",
634 | "- **Graph prediction**: we require a single target for the entire graph. *e.g. is this protein molecule (represented as a graph) likely going to be effective?*\n",
635 | "\n",
636 | "\n",
637 | "\n",
638 | "_(Image source: Petar Veličković)_\n",
639 | "\n"
640 | ],
641 | "metadata": {
642 | "id": "jB-cMp4bYVQh"
643 | }
644 | },
645 | {
646 | "cell_type": "markdown",
647 | "source": [
648 | "In this part of the tutorial, we will work with the ZINC12K benchmark - a dataset containing about 12k molecules. Our goal is to predict, for each molecule, its \"constrained solubility\". This is a graph regression task as we need to predict a single value per graph.\n",
649 | "\n",
650 | "Let's load the data. It might take a while 🫖 :\n"
651 | ],
652 | "metadata": {
653 | "id": "ZZWIp1IxYVQh"
654 | }
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": null,
659 | "source": [
660 | "from torch_geometric.datasets import ZINC\n",
661 | "\n",
662 | "train_ds = ZINC(root='data/ZINC', split='train', subset=True)\n",
663 | "val_ds = ZINC(root='data/ZINC', split='val', subset=True)\n",
664 | "test_ds = ZINC(root='data/ZINC', split='test', subset=True)"
665 | ],
666 | "outputs": [],
667 | "metadata": {
668 | "id": "Gwt4MTnGYVQh"
669 | }
670 | },
671 | {
672 | "cell_type": "markdown",
673 | "source": [
674 | "Now let's look at one molecule from the dataset:"
675 | ],
676 | "metadata": {
677 | "id": "PJOljVmSYVQi"
678 | }
679 | },
680 | {
681 | "cell_type": "code",
682 | "execution_count": null,
683 | "source": [
684 | "sample = train_ds[0]\n",
685 | "gallery([sample], max_fig_size=(4, 6))"
686 | ],
687 | "outputs": [],
688 | "metadata": {
689 | "id": "y5rlsmcgYVQi"
690 | }
691 | },
692 | {
693 | "cell_type": "markdown",
694 | "source": [
695 | "Graphs in ZINC have only one input feature, the atom type, stored as an integer id in `graph.x`. This molecule has 29 atoms (we can inspect this by looking at the `graph.num_nodes` attribute) and 64 edges (stored in `graph.num_edges`; each chemical bond is counted once in each direction).\n",
696 | "\n",
697 | "The graph connectivity is represented as an edge list `graph.edge_index`, a $2\\times |E|$ tensor containing starting nodes in the first row, and ending nodes in the second. This sparse representation has benefits that we will explore in a later section. For now, let's quickly convert this list into an adjacency matrix. We will store this dense representation in `graph.adj`.\n",
698 | "\n",
699 | "❗️Note: the molecule also has _edge features_ stored in `graph.edge_attr`, which, for simplicity, we will ignore in this part of the tutorial."
700 | ],
701 | "metadata": {
702 | "id": "zVQC9NN5YVQi"
703 | }
704 | },
705 | {
706 | "cell_type": "code",
707 | "execution_count": null,
708 | "source": [
709 | "from torch_geometric.utils import to_torch_coo_tensor\n",
710 | "\n",
711 | "sample_adj = Data(x=sample.x, adj=to_torch_coo_tensor(sample.edge_index).to_dense())\n",
712 | "print(\"The matrix containing the node features has shape: \", sample_adj.x.shape)\n",
713 | "print(\"The adjacency matrix has shape: \", sample_adj.adj.shape)"
714 | ],
715 | "outputs": [],
716 | "metadata": {
717 | "id": "JFcyRzroYVQi"
718 | }
719 | },
720 | {
721 | "cell_type": "markdown",
722 | "source": [
723 | "## Graph Convolutional Network"
724 | ],
725 | "metadata": {
726 | "id": "cq9dvgwuYVQi"
727 | }
728 | },
729 | {
730 | "cell_type": "markdown",
731 | "source": [
732 | "As a first step in modeling our graph-shaped molecule, we will construct a [Graph Convolutional Network](https://arxiv.org/abs/1609.02907) (GCN) layer."
733 | ],
734 | "metadata": {
735 | "id": "zzvEG8QEYVQi"
736 | }
737 | },
738 | {
739 | "cell_type": "markdown",
740 | "source": [
741 | "The GCN is comparable to convolutional layers on image data, with one significant difference. Images are uniformly structured (their pixels are always arranged in a grid), and that allows convolutional layers to apply kernels on local patches. Following the assumption that pixels that are close in space should be correlated, a classical Convolutional Neural Network aggregates information from a local $H \\times W$ patch.\n",
742 | "\n",
743 | "Unlike pixels in images, nodes don't have a position in space. However, we can develop a similar intuition: nodes that are connected to each other in the graph should be more correlated than distant nodes. So, just like the kernels of a Convolutional Neural Network are _local_ operations, we want our graph layers to respect the graph topology in a similar way. **Nodes that are close to each other according to the graph topology should influence the representation more than far-away nodes.**\n",
744 | "\n",
745 | "\n",
746 | "\n",
747 | "\n",
748 | "\n"
749 | ],
750 | "metadata": {
751 | "id": "6ysADiXDYVQi"
752 | }
753 | },
754 | {
755 | "cell_type": "markdown",
756 | "source": [
757 | "For a graph $\\mathbf{G}$ with $n$ nodes and $e$ edges, suppose we start with input node features $\\mathbf{X}_{n\\times d}$ and the adjacency matrix $\\mathbf{A}_{n\\times n}$. Let's consider the following operation:\n",
758 | "\n",
759 | "$$\n",
760 | "\\mathbf{H} = \\sigma \\big( \\mathbf{A} \\mathbf{X} \\mathbf{W} \\big)\n",
761 | "$$\n",
762 | "\n",
763 | "where $\\mathbf{W}_{d\\times d'}$ are trainable weights, $\\mathbf{H}_{n\\times d'}$ are output node features, and $\\sigma$ is a pointwise non-linearity.\n",
764 | "\n",
765 | "\n",
766 | "On a per-node level, we can re-write this layer as\n",
767 | "\n",
768 | "$$\n",
769 | "\\mathbf{h_i} = \\sigma \\big( \\sum_{j \\in \\mathbf{N}_i} \\mathbf{x_j}\\mathbf{W}\\big)\n",
770 | "$$\n",
771 | "\n",
772 | "\n",
773 | "where $\\mathbf{N}_i$ denotes the neighbourhood of node $i$.\n",
774 | "\n",
775 | "✍ **Exercise:** Convince yourself that these two equations really mean the same thing!\n",
776 | "\n",
777 | "\n"
778 | ],
779 | "metadata": {
780 | "id": "A7ZJpjzjYVQi"
781 | }
782 | },
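As a quick numerical sanity check for the exercise above, this plain-Python sketch computes the layer both ways on a hypothetical 3-node path graph and confirms that the rows of $\mathbf{A}\mathbf{X}\mathbf{W}$ match the per-node neighbourhood sums. The non-linearity $\sigma$ is omitted: it acts elementwise, so applying it to both results would not change the comparison. The `matmul` helper is just for illustration.

```python
def matmul(P, Q):
    """Plain-Python matrix product; matrices are lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


# Hypothetical toy inputs: the path graph 0-1-2, 2-d features, 2x2 weights.
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
X = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]
W = [[1.0, 0.0],
     [1.0, -1.0]]

# Matrix form: A X W.
H_matrix = matmul(matmul(A, X), W)

# Node form: h_i = sum over j in N_i of x_j W.
XW = matmul(X, W)
H_nodes = []
for i in range(len(A)):
    neighbours = [j for j in range(len(A)) if A[i][j] == 1]  # N_i read off row i of A
    H_nodes.append([sum(XW[j][c] for j in neighbours) for c in range(len(W[0]))])

print(H_matrix == H_nodes)  # True
```

Row $i$ of $\mathbf{A}$ acts as a mask that selects exactly the neighbours of node $i$, which is why the two forms agree.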
783 | {
784 | "cell_type": "markdown",
785 | "source": [
786 | "\n",
787 | "\n",
789 | "\n",
790 | "\n"
791 | ],
792 | "metadata": {
793 | "id": "y6jfoSnoYVQi"
794 | }
795 | },
796 | {
797 | "cell_type": "markdown",
798 | "source": [
799 | "Observe how this simple layer **respects the graph topology** by only exchanging information between neighboring nodes via the adjacency matrix. This is perfectly in line with our requirements: a model in which **neighbours influence the representation more than the far away nodes**. Moreover, it nicely mirrors the Convolutional Network, where instead of the local image patches we are now aggregating neighbourhoods.\n",
800 | "\n",
801 | "This layer has some problems, though. Graph nodes may have very different degrees and, since the layer sums over neighbourhoods, the magnitude of $\\mathbf{h_i}$ grows with the degree of node $i$, possibly leading to numerical instabilities during training. To alleviate this, the [Graph Convolutional Network (GCN)](https://arxiv.org/abs/1609.02907) model introduces _symmetric normalization_:\n",
802 | "\n",
803 | "$$\n",
804 | "\\mathbf{H}_\\text{GCN} = \\sigma \\big( \\mathbf{\\tilde{D}}^{-\\frac{1}{2}} \\mathbf{\\tilde{A}} \\mathbf{\\tilde{D}}^{-\\frac{1}{2}} \\mathbf{X} \\mathbf{W} \\big)\n",
805 | "$$\n",
806 | "\n",
807 | "In this equation, $\\mathbf{\\tilde{A}}=\\mathbf{A}+\\mathbf{I}$ represents the adjacency matrix enriched with self-loops (edges that connect a node with itself), and $\\mathbf{\\tilde{D}}$ is the degree matrix, a diagonal matrix containing the node degrees of $\\mathbf{\\tilde{A}}$, $\\mathbf{\\tilde{D}}_{ii} = \\sum_j \\mathbf{\\tilde{A}}_{ij}$.\n",
808 | "\n",
809 | "✍ **Exercise:** Think about how you can write down the node-level action of this normalised GCN. What does adding $\\mathbf{\\tilde{D}}$ change in [this equation](#node-level-gcn)?\n",
810 | "\n"
811 | ],
812 | "metadata": {
813 | "id": "Ppof2XK8YVQi"
814 | }
815 | },
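To see why the normalisation helps, consider a toy scalar example (an illustration, not part of the original tutorial): a hypothetical star graph whose hub has $k$ leaves, with every feature $x_j = 1$ and weight $W = 1$. Since entry $(i,j)$ of $\mathbf{\tilde{D}}^{-\frac{1}{2}} \mathbf{\tilde{A}} \mathbf{\tilde{D}}^{-\frac{1}{2}}$ is $\tilde{A}_{ij}/\sqrt{\tilde{d}_i \tilde{d}_j}$, the hub receives one term per leaf plus its self-loop term. The sketch compares this with the unnormalised sum.

```python
import math


def hub_output(k):
    """One-layer output at the hub of a hypothetical star graph with k leaves.

    Assumes every node feature is x_j = 1 and the weight is W = 1 (scalars),
    so the two variants differ only in how they aggregate.
    Returns (unnormalised A X W, symmetrically normalised GCN).
    """
    # Unnormalised layer: the hub simply sums its k neighbours.
    plain = float(k)
    # Normalised GCN on A~ = A + I: degree k + 1 at the hub, 2 at every leaf.
    d_hub, d_leaf = k + 1, 2
    gcn = 1.0 / math.sqrt(d_hub * d_hub)          # self-loop term (j = hub)
    gcn += k * (1.0 / math.sqrt(d_hub * d_leaf))  # one term per leaf
    return plain, gcn


for k in (1, 10, 100, 1000):
    plain, gcn = hub_output(k)
    print(f"k={k:5d}   unnormalised={plain:8.1f}   GCN={gcn:7.3f}")
```

The unnormalised output grows linearly with the degree, whereas the normalised one grows roughly like $\sqrt{k/2}$, so high-degree nodes no longer dominate the scale of the features.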
816 | {
817 | "cell_type": "markdown",
818 | "source": [
819 | "## 🖋 **Task** Implement the symmetric normalization for the GCN layer using PyTorch.\n",
820 | "\n",
821 | "> We expect you to implement $\\mathbf{\\tilde{D}}^{-\\frac{1}{2}} \\mathbf{\\tilde{A}} \\mathbf{\\tilde{D}}^{-\\frac{1}{2}}$ as used in the GCN Layer\n",
822 | ">\n",
823 | "> 🆘 **Hint:** you may assume that the initial adjacency matrix $\\mathbf{A}$ does not have self-loops\n"
824 | ],
825 | "metadata": {
826 | "id": "G6GawNjEYVQi"
827 | }
828 | },
829 | {
830 | "cell_type": "code",
831 | "execution_count": null,
832 | "source": [
833 | "class GCNLayer(Module):\n",
834 | "\n",
835 | "    def __init__(self, in_channels, out_channels):\n",
836 | "        \"\"\"\n",
837 | "        One layer of Graph Convolutional Network (GCN)\n",
838 | "        using the dense adjacency matrix\n",
839 | "\n",
840 | "        Args:\n",
841 | "            in_channels: (int) - input dimension\n",
842 | "            out_channels: (int) - output dimension\n",
843 | "        \"\"\"\n",
844 | "        super(GCNLayer, self).__init__()\n",
845 | "        self.linear = Linear(in_channels, out_channels)\n",
846 | "\n",
847 | "    def forward(self, x, A):\n",
848 | "        \"\"\"\n",
849 | "        Args:\n",
850 | "            x: (n, in_dim) - initial node features\n",
851 | "            A: (n, n) - adjacency matrix\n",
852 | "\n",
853 | "        Returns:\n",
854 | "            out: (n, out_dim) - updated node features\n",
855 | "        \"\"\"\n",
856 | "        # ============ YOUR CODE HERE ==============\n",
857 | "        # Compute the normalised adjacency matrix\n",
858 | "        # as denoted by the equation above\n",
859 | "        #\n",
860 | "        # Atilde = ...\n",
861 | "        # Dtilde = ...\n",
862 | "        # adj_norm = ...\n",
863 | "        #\n",
864 | "        # ===========================================\n",
865 | "\n",
866 | "        x = self.linear(x)\n",
867 | "        out = adj_norm @ x\n",
868 | "        return out"
869 | ],
870 | "outputs": [],
871 | "metadata": {
872 | "id": "PDg41GtAYVQj"
873 | }
874 | },
875 | {
876 | "cell_type": "markdown",
877 | "source": [
878 | "We will also create a simple model that uses our GCNLayer:"
879 | ],
880 | "metadata": {
881 | "id": "pKnglfMDYVQj"
882 | }
883 | },
884 | {
885 | "cell_type": "code",
886 | "execution_count": null,
887 | "source": [
888 | "class SimpleGCN(Module):\n",
889 | "    def __init__(self):\n",
890 | "        \"\"\"\n",
891 | "        A GNN model applying one GCN layer to create a graph-level representation\n",
892 | "\n",
893 | "        \"\"\"\n",
894 | "        super().__init__()\n",
895 | "        hidden = 64\n",
896 | "        self.embed = Embedding(28, hidden)  # There are 28 different atom types in ZINC\n",
897 | "        self.gcn = GCNLayer(hidden, 1)\n",
898 | "\n",
899 | "    def forward(self, data):\n",
900 | "        \"\"\"\n",
901 | "        Args:\n",
902 | "            data: (PyG.Data) - one graph from the dataset\n",
903 | "\n",
904 | "        Returns:\n",
905 | "            out: (float) - a scalar representing the output for the entire graph\n",
906 | "        \"\"\"\n",
907 | "        x = data.x\n",
908 | "        A = data.adj\n",
909 | "\n",
910 | "        x = self.embed(x).squeeze(1)\n",
911 | "        x = F.relu(x)\n",
912 | "        x = self.gcn(x, A)\n",
913 | "\n",
914 | "        # ============ YOUR CODE HERE ==============\n",
915 | "        # (For Task 3)\n",
916 | "        # return ...\n",
917 | "        # ==================================\n",
918 | "\n",
919 | "        return torch.sum(x)\n",
920 | "\n",
921 | "\n",
922 | "simple_gcn_model = SimpleGCN()"
923 | ],
924 | "outputs": [],
925 | "metadata": {
926 | "id": "Ej_-Vv0lYVQj"
927 | }
928 | },
929 | {
930 | "cell_type": "markdown",
931 | "source": [
932 | "Let's apply our GCN network to our sample graph (the result should be a single number):"
933 | ],
934 | "metadata": {
935 | "id": "AtiIkv72YVQj"
936 | }
937 | },
938 | {
939 | "cell_type": "code",
940 | "execution_count": null,
941 | "source": [
942 | "simple_gcn_model(sample_adj).detach().numpy()"
943 | ],
944 | "outputs": [],
945 | "metadata": {
946 | "id": "RR3a_HIZYVQj"
947 | }
948 | },
949 | {
950 | "cell_type": "markdown",
951 | "source": [
952 | "## Permutation Invariance and Equivariance"
953 | ],
954 | "metadata": {
955 | "id": "Wo_zKXGBYVQj"
956 | }
957 | },
958 | {
959 | "cell_type": "markdown",
960 | "source": [
961 | "In the previous section we introduced the GCN layer in a hands-on way, by comparing it to convolutional layers used in image processing. Now, let's take a step back and think about the high level effect these layers have and why we would want them.\n",
962 | "\n",
963 | "The two main concepts we will use to analyse the behavior of these layers are **invariance** and **equivariance**.\n",
964 | "\n",
965 | "Suppose we want to detect whether a cat is in an image. The classifier should behave the same whether the cat is in the top-left, center, or right. Thus, it should be **invariant** to the **translation** of the image.\n",
966 | "On the other hand, suppose we want to identify the _coordinates_ of a cat in the image. Then, the network should follow along with the image transformations - it should be **equivariant** to **translation**.\n",
967 | "CNNs, by sharing the same kernel weights across all positions of the image, enforce **translation equivariance** on the network. If needed for our task, we can then convert equivariance into invariance by pooling the entire image into a single output.\n",
968 | "\n",
969 | "\n"
970 | ],
971 | "metadata": {
972 | "id": "-00TOI-KYVQj"
973 | }
974 | },
975 | {
976 | "cell_type": "markdown",
977 | "source": [
978 | "Now, let's turn to graphs.\n",
979 | "\n",
980 | "On graphs we have no left, right, or up, so we cannot talk about translation symmetries of GNNs. But... do you remember that at the beginning of this tutorial we had to pick an order of the nodes in order to store them in memory? Since this choice is arbitrary, we want our model to be invariant/equivariant to it.\n",
981 | "\n",
982 | "We call this permutation invariance / permutation equivariance. A GNN is permutation invariant if, no matter how we permute the nodes, the output doesn't change. On the other hand, a GNN is permutation equivariant if permuting the input permutes the output in the same way. Generally, if our model outputs a graph-level representation, we expect it to be permutation invariant. If, instead, the task prediction is at the node level, the model should be permutation equivariant.\n",
983 | "\n",
984 | "\n",
985 | "\n",
986 | "Let's make this more formal.\n",
987 | "\n",
988 | "\n",
999 | "\n",
1000 | "A GNN $\\mathbf{f}$ is **permutation invariant** if, for every permutation matrix $\\mathbf{P}$,\n",
1001 | "\n",
1002 | "$$\n",
1003 | "\\mathbf{f}\\left(\\mathbf{X}, \\mathbf{A}\\right)=\\mathbf{f}\\left(\\mathbf{P}\\mathbf{X}, \\mathbf{P}\\mathbf{A}\\mathbf{P}^\\intercal\\right)\n",
1004 | "$$\n",
1005 | "\n",
1006 | "A GNN $\\mathbf{f}$ is **permutation equivariant** if, for every permutation matrix $\\mathbf{P}$,\n",
1007 | "\n",
1008 | "$$\n",
1009 | "\\mathbf{P}\\mathbf{f}\\left(\\mathbf{X}, \\mathbf{A}\\right)=\\mathbf{f}\\left(\\mathbf{P}\\mathbf{X}, \\mathbf{P}\\mathbf{A}\\mathbf{P}^\\intercal\\right)\n",
1010 | "$$\n",
1011 | "\n",
1012 | "\n",
1013 | "\n",
1014 | "Note how we have to transform the adjacency matrix as well as node features. Why?\n",
1015 | "\n",
1016 | "As we mentioned before, tasks on graphs typically don't depend on node ordering. That's why, when designing GNNs, **it is crucial to ensure that our layers fulfill these properties**.\n",
1017 | "\n",
1018 | ""
1019 | ],
1020 | "metadata": {
1021 | "id": "ec5xUFFgYVQj"
1022 | }
1023 | },
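Both definitions can be checked numerically. This dependency-free sketch uses a weight-free sum layer $\mathbf{A}\mathbf{X}$ as a stand-in for a GNN, kept this simple for brevity, on a hypothetical path graph. A permutation is stored as a list `perm`, which corresponds to the matrix with $\mathbf{P}_{i,\text{perm}[i]} = 1$. Then $(\mathbf{P}\mathbf{X})_i = \mathbf{X}_{\text{perm}[i]}$ and $(\mathbf{P}\mathbf{A}\mathbf{P}^\intercal)_{ij} = \mathbf{A}_{\text{perm}[i],\text{perm}[j]}$.

```python
def aggregate(X, A):
    """Node-level layer without weights: row i is sum_j A[i][j] * X[j]."""
    n, d = len(X), len(X[0])
    return [[sum(A[i][j] * X[j][c] for j in range(n)) for c in range(d)]
            for i in range(n)]


def readout(X, A):
    """Graph-level model: sum-pool every entry of aggregate(X, A)."""
    return sum(sum(row) for row in aggregate(X, A))


def permute(X, A, perm):
    """Return (P X, P A P^T) for the permutation matrix with P[i][perm[i]] = 1."""
    n = len(perm)
    PX = [X[perm[i]] for i in range(n)]
    PAPt = [[A[perm[i]][perm[j]] for j in range(n)] for i in range(n)]
    return PX, PAPt


# Hypothetical toy graph: the path 0-1-2-3, with 2-d integer node features.
A = [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]]
X = [[1, 0], [2, 1], [0, 3], [4, 2]]
perm = [2, 0, 3, 1]  # an arbitrary re-ordering of the 4 nodes

PX, PAPt = permute(X, A, perm)
out = aggregate(X, A)
P_out = [out[perm[i]] for i in range(len(perm))]  # P f(X, A)

print(aggregate(PX, PAPt) == P_out)        # True: the layer is equivariant
print(readout(PX, PAPt) == readout(X, A))  # True: sum pooling makes it invariant
print(aggregate(PX, A) == P_out)           # False: permuting X alone is not enough
```

The last line shows what goes wrong if only the node features are reordered: the adjacency matrix would then still refer to the old node positions.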
1024 | {
1025 | "cell_type": "markdown",
1026 | "source": [
1027 | "## 🖋 **Task** Write a unit test to check that the SimpleGCN is permutation invariant.\n",
1028 | "\n",
1029 | "> Since on ZINC we predict one value per graph, the output should be permutation invariant. We will now test the permutation invariance of our simple model. All you need to do is to permute our sample graph. If everything is correct, `SimpleGCN` should output the same result for both the original and the permuted molecule.\n",
1030 | "\n",
1031 | "\n",
1032 | "\n",
1033 | "\n"
1034 | ],
1035 | "metadata": {
1036 | "id": "vDVqbLHNYVQj"
1037 | }
1038 | },
1039 | {
1040 | "cell_type": "code",
1041 | "execution_count": null,
1042 | "source": [
1043 | "def test_permutation(graph):\n",
1044 | "    # torch.random.manual_seed(42)\n",
1045 | "    # np.random.seed(42)\n",
1046 | "\n",
1047 | "    perm = np.random.permutation(graph.x.shape[0])\n",
1048 | "\n",
1049 | "    permuted_graph = Data(\n",
1050 | "        # ============ YOUR CODE HERE ==============\n",
1051 | "        # Permute the features and the adjacency matrix according to perm\n",
1052 | "        #\n",
1053 | "        # x = ...\n",
1054 | "        # adj = ...\n",
1055 | "        #\n",
1056 | "        # ===========================================\n",
1057 | "    )\n",
1058 | "\n",
1059 | "\n",
1060 | "    assert torch.allclose(\n",
1061 | "\n",
1062 | "        # ============ YOUR CODE HERE ==============\n",
1063 | "        # (For Task 3)\n",
1064 | "        # Change the following lines to\n",
1065 | "        # accommodate permutation equivariance\n",
1066 | "        # ==================================\n",
1067 | "\n",
1068 | "        simple_gcn_model(graph),\n",
1069 | "        simple_gcn_model(permuted_graph)\n",
1070 | "    ), \"Your model is not permutation invariant :(\"\n",
1071 | "    print(\"Success: Your GCN model is permutation invariant by construction!\")\n",
1072 | "\n",
1073 | "test_permutation(sample_adj)"
1074 | ],
1075 | "outputs": [],
1076 | "metadata": {
1077 | "id": "3BZY_tnYYVQj"
1078 | }
1079 | },
1080 | {
1081 | "cell_type": "markdown",
1082 | "source": [
1083 | "## 🖋 **Task** Modify the `SimpleGCN` to produce node-level as opposed to aggregated graph-level features. Check that this breaks the permutation invariance test, and edit the test to check for permutation equivariance instead.\n",
1084 | "\n",
1085 | "\n",
1086 | "> Feel free to experiment with `SimpleGCN` and `GCNLayer` as we won't be using them beyond this point.\n",
1087 | ">\n",
1088 | "> 🆘 **Hint:** You need to edit only one line in each of `SimpleGCN` and `test_permutation` for this task!\n"
1089 | ],
1090 | "metadata": {
1091 | "id": "-8RsyC4vYVQj"
1092 | }
1093 | },
1094 | {
1095 | "cell_type": "code",
1096 | "execution_count": null,
1097 | "source": [
1098 | "# There is no extra code for this task - go back and edit SimpleGCN and test_permutation!"
1099 | ],
1100 | "outputs": [],
1101 | "metadata": {
1102 | "id": "SOdFWF7PYVQj"
1103 | }
1104 | },
1105 | {
1106 | "cell_type": "markdown",
1107 | "source": [
1108 | "✍ **Exercise:** Based on the formal definition, prove that the GCN layer is permutation equivariant. (**Hint:** Which matrices commute?)"
1109 | ],
1110 | "metadata": {
1111 | "id": "0fYoNbEuYVQj"
1112 | }
1113 | },
1114 | {
1115 | "cell_type": "markdown",
1116 | "source": [
1117 | "## From Dense to Sparse Adjacency Matrix "
1118 | ],
1119 | "metadata": {
1120 | "id": "uwzL5bwBYVQj"
1121 | }
1122 | },
1123 | {
1124 | "cell_type": "markdown",
1125 | "source": [
1126 | "The GCN layer we implemented above works correctly, but it is _inefficient_. Remember how we decided to use an adjacency matrix? This means that we created a $29\\times 29$ matrix (841 entries) with only $64$ nonzero entries! Real-world graphs are often sparse, and we need to exploit their sparsity to save memory. (As we will see in the next section, with batching this effect becomes even more pronounced.)\n",
1127 | "\n",
1128 | "To accomplish this, we will now work directly with the provided `edge_index` matrix, without converting it into an adjacency matrix. Luckily, there are tools that will make this process easy!\n",
1129 | "\n",
1130 | "Our main tool is the `scatter_sum` operator from [`torch_scatter`](https://pytorch-scatter.readthedocs.io/en/latest/index.html). A classical `torch.sum` takes two arguments and contracts a tensor by computing $\\text{torch.sum}\\left(\\mathbf{A}, \\text{dim}\\right)=\\sum_i \\mathbf{A}\\lbrack\\dots, \\underbrace{i}_{\\text{dim}}, \\dots\\rbrack$. On the other hand, `scatter_sum` takes a third argument, `index`. The operation will first group the elements based on the indices represented in `index`, then sum-aggregate the elements in each group.\n",
1131 | "\n",
1132 | "$$\n",
1133 | "\\text{scatter_sum}\\left(\\mathbf{A},\\text{index},\\text{dim}\\right)\n",
1134 | "\\lbrack\\dots, \\underbrace{j}_{\\text{dim}}, \\dots\\rbrack = \\sum_i \\mathbf{A}\\lbrack\\dots, \\underbrace{i}_{\\text{dim}}, \\dots\\rbrack \\text{, where } \\text{index}\\lbrack i\\rbrack = j\n",
1135 | "$$\n",
1136 | "\n",
1137 | "Let's look at an example:\n",
1138 | "\n",
1139 | "\n",
1140 | "\n",
1141 | "\n",
1144 | "\n",
1145 | "The purpose of `scatter_sum` therefore is to aggregate groups of values in parallel, such as when summing over node neighborhoods in GCN!\n",
1146 | "Other `scatter_*` operations exist, including `scatter_mean` or `scatter_max`, which act in a similar fashion but perform a different operation on each group.\n",
1147 | "\n",
1148 | "The scatter sum has, loosely speaking, a one-sided inverse[[1]](#cite_note-1) called `torch.index_select`, which is also useful for creating GNN layers. Whereas `scatter_sum` _reduces_ tensor size by _aggregating_ entries based on index, `index_select` _increases_ size by _copying_ them. The operation it performs is:\n",
1149 | "\n",
1150 | "$$\n",
1151 | "\\text{index_select}\\left(\\mathbf{A},\\text{index},\\text{dim}\\right)\n",
1152 | "\\lbrack\\dots, \\underbrace{i}_{\\text{dim}}, \\dots\\rbrack = \\mathbf{A}\\lbrack\\dots, \\underbrace{\\text{index}\\lbrack i \\rbrack}_{\\text{dim}}, \\dots\\rbrack\n",
1153 | "$$\n",
1154 | "\n",
1155 | "where $i$ now ranges over $\\text{index}$ and not over dimension $\\text{dim}$ of\n",
1156 | "$\\mathbf{A}$.\n",
1157 | "\n",
1158 | "\n",
1159 | "\n",
1160 | "1. [^](#cite_ref-1) This is not entirely true. What `scatter_*` operation is `torch.index_select` actually an inverse of, and under which conditions?"
1161 | ],
1162 | "metadata": {
1163 | "id": "pu5V9RoqYVQj"
1164 | }
1165 | },
1166 | {
1167 | "cell_type": "markdown",
1168 | "source": [
1169 | "Let's see it in practice:"
1170 | ],
1171 | "metadata": {
1172 | "id": "IO9zzYvzYVQk"
1173 | }
1174 | },
1175 | {
1176 | "cell_type": "code",
1177 | "execution_count": null,
1178 | "source": [
1179 | "# group and aggregate elements based on group index\n",
1180 | "array = torch.tensor([88, 15, 10, 7, 24, 30, 9])\n",
1181 | "index = torch.tensor([0,1,1,1,2,0,1])\n",
1182 | "agg_sum = scatter_sum(array, index=index, dim=0)\n",
1183 | "print(f\"The result of scatter_sum is {agg_sum}\")\n",
1184 | "\n",
1185 | "# select elements based on array index\n",
1186 | "array = torch.tensor([88, 15, 10, 7, 24, 30, 9])\n",
1187 | "index = torch.tensor([1, 1, 0, 2, 5, 4])\n",
1188 | "selected_array = torch.index_select(array, index=index, dim=0)\n",
1189 | "print(f\"The selected elements are {selected_array}\")\n"
1190 | ],
1191 | "outputs": [],
1192 | "metadata": {
1193 | "id": "xpv9Oz7hYVQk"
1194 | }
1195 | },
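To spell out what these two calls compute element by element, here is a plain-Python, one-dimensional reference version of each, following the two formulas given earlier. It is an illustration only, not how the libraries implement them: they operate on tensors along an arbitrary `dim`, in parallel. The functions are named with a `_ref` suffix so they don't shadow the imported `scatter_sum`. The optional `dim_size` argument fixes the number of output groups when some trailing groups are empty.

```python
def scatter_sum_ref(src, index, dim_size=None):
    """1D reference for scatter-sum: out[j] = sum of src[i] over all i with index[i] == j."""
    if dim_size is None:
        dim_size = max(index) + 1  # one output slot per group id 0..max(index)
    out = [0] * dim_size
    for value, group in zip(src, index):
        out[group] += value
    return out


def index_select_ref(src, index):
    """1D reference for index-select: out[i] = src[index[i]]; entries may repeat."""
    return [src[j] for j in index]


array = [88, 15, 10, 7, 24, 30, 9]
print(scatter_sum_ref(array, [0, 1, 1, 1, 2, 0, 1]))  # [118, 41, 24]
print(index_select_ref(array, [1, 1, 0, 2, 5, 4]))    # [15, 15, 88, 10, 30, 24]
```

These should print the same values as the PyTorch cell above: group 0 collects $88 + 30$, group 1 collects $15 + 10 + 7 + 9$, and `index_select` simply copies entries by position, repeating index 1 twice.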
1196 | {
1197 | "cell_type": "markdown",
1198 | "source": [
1199 | "## 🖋 **Task** Implement the same GCN as before, but using the sparse representation of the edges.\n",
1200 | "\n",
1201 | "> This is not an easy task, but we will guide you step by step.\n",
1202 | ">\n",
1203 | "> The equation of a symmetrically-normalised GCN, represented at the node-level is:\n",
1204 | ">\n",
1205 | "> $$\n",
1206 | "\\mathbf{h_i} = \\sum_{j \\in \\mathbf{N}_i} \\frac{1}{\\sqrt{d_id_j}}\\mathbf{x_j}\\mathbf{W}\n",
1207 | "$$\n",
1208 | "> where $d_i$ represents the degree of node $i$ in $\\mathbf{\\tilde{A}}$ (i.e. counting its self-loop) and $N_i$ is the neighbourhood of node $i$ including self-loops.\n",
1209 | ">\n",
1210 | "> Our function will receive:\n",
1211 | "> - `x`: an $n \\times f$ matrix of node features\n",
1212 | "> - `edge_index`: a $2 \\times e$ matrix of edge indices.\n",
1213 | "\n"
1214 | ],
1215 | "metadata": {
1216 | "id": "4PhDiGzZYVQk"
1217 | }
1218 | },
1219 | {
1220 | "cell_type": "markdown",
1221 | "source": [
1222 | "> **Step 1:** Compute the degree vector $\\mathbf{d}$ using the `edge_index` representation. To achieve this, we will start with a vector full of ones $\\mathbf{1}_e$ and group-sum them based on the list of source nodes stored in `edge_index[0]`. This way, we add 1 to $\\mathbf{d}[i]$ whenever node $i$ is identified as the source node of an edge.\n",
1223 | "\n"
1224 | ],
1225 | "metadata": {
1226 | "id": "Ax3wSkElYVQk"
1227 | }
1228 | },
1229 | {
1230 | "cell_type": "code",
1231 | "execution_count": null,
1232 | "source": [
1233 | "# Step 1: compute the degree vector\n",
1234 | "\n",
1235 | "# edge_index[0][k] represents the source node of edge k\n",
1236 | "# edge_index[1][k] represents the destination node of edge k\n",
1237 | "edge_index = torch.tensor([[0,1,1,2,2,3], [1,0,2,1,3,2]])  # 2xe tensor\n",
1238 | "d = torch.ones_like(edge_index[0])  # e-dim tensor of ones, one per edge\n",
1239 | "\n",
1240 | "# ============ YOUR CODE HERE ==============\n",
1241 | "# d = ..  # n-dim tensor, one degree per node\n",
1242 | "# ===========================================\n",
1243 | "\n",
1244 | "assert(torch.allclose(d, torch.tensor([1,2,2,1]), atol=1e-6))\n"
1245 | ],
1246 | "outputs": [],
1247 | "metadata": {
1248 | "id": "ei5BdkgZYVQk"
1249 | }
1250 | },
1251 | {
1252 | "cell_type": "markdown",
1253 | "source": [
1254 | "> **Step 2:** For each edge `(edge_index[0][k], edge_index[1][k])`, compute the normalising coefficient $\\frac{1}{\\sqrt{\\textbf{d}\\text{[edge_index[0][k]]}\\textbf{d}\\text{[edge_index[1][k]]}}}$. To achieve that, we form a vector $d_j$ of dimension $e$ in which $d_j[k]$ contains the degree of the source node `edge_index[0][k]`. Similarly, $d_i[k]$ contains the degree of the destination node `edge_index[1][k]`. We can easily achieve this by index-selecting from $\\mathbf{d}$ based on `edge_index[0]` and `edge_index[1]` respectively. The final normalising coefficient should be obtained as an element-wise operation between the vectors $d_i$ and $d_j$: $c = 1/\\sqrt{d_i d_j}$.\n"
1255 | ],
1256 | "metadata": {
1257 | "id": "wwbKByolYVQk"
1258 | }
1259 | },
1260 | {
1261 | "cell_type": "code",
1262 | "execution_count": null,
1263 | "source": [
1264 | "# Step 2: compute the normalising coefficients\n",
1265 | "\n",
1266 | "edge_index = torch.tensor([[0,1,1,2,2,3], [1,0,2,1,3,2]]) # 2xe tensor\n",
1267 | "# d[i] represents the degree of node i\n",
1268 | "d = torch.tensor([1,2,2,1])  # n-dim tensor\n",
1269 | "\n",
1270 | "# ============ YOUR CODE HERE ==============\n",
1271 | "# for each edge k, extract from d the degree of the source node edge_index[0][k]\n",
1272 | "# Hint: index_select might come in handy\n",
1273 | "# d_j = ..  # ex1 tensor\n",
1274 | "\n",
1275 | "# for each edge k, extract from d the degree of the destination node edge_index[1][k]\n",
1276 | "# d_i = .. # ex1 tensor\n",
1277 | "\n",
1278 | "\n",
1279 | "# c_ij = ... # ex1 tensor\n",
1280 | "# ===========================================\n",
1281 | "\n",
1282 | "assert(torch.allclose(c_ij, torch.tensor([[0.7071],[0.7071],[0.5000],[0.5000],[0.7071],[0.7071]]), atol=1e-6))"
1283 | ],
1284 | "outputs": [],
1285 | "metadata": {
1286 | "id": "2Af2THtCYVQk"
1287 | }
1288 | },
1289 | {
1290 | "cell_type": "markdown",
1291 | "source": [
1292 | "> **Step 3:** For each edge `(edge_index[0][k], edge_index[1][k])`, we compute the \"message\" sent from $j$=`edge_index[0][k]` to $i$=`edge_index[1][k]` as $\\frac{1}{\\sqrt{d_i d_j}}\\mathbf{x_j}\\mathbf{W}$. To do this, we first linearly project `x`, then create the vector $x_j$ containing, for each edge, the projected representation of its source node (we can easily do that by index-selecting from `x` using the list of source nodes `edge_index[0]`), and finally rescale it using our precomputed normalising scalars `c` from Step 2.\n",
1293 | "\n",
1294 | ""
1301 | ],
1302 | "metadata": {
1303 | "id": "oCtBMzCMYVQk"
1304 | }
1305 | },
1306 | {
1307 | "cell_type": "code",
1308 | "execution_count": null,
1309 | "source": [
1310 | "# Step 3: compute the messages from j to i\n",
1311 | "\n",
1312 | "edge_index = torch.tensor([[0,1,1,2,2,3], [1,0,2,1,3,2]]) # 2xe tensor\n",
1313 | "# c[k] represents the normalisation coefficient 1/sqrt(d_i[k] d_j[k]) of edge k\n",
1314 | "c = torch.tensor([[0.7071],[0.7071],[0.5000],[0.5000],[0.7071],[0.7071]]) # ex1 tensor\n",
1315 | "# x represent the input features\n",
1316 | "x = torch.ones(4,2) # nxf tensor\n",
1317 | "\n",
1318 | "linear_W = Linear(2, 3)\n",
1319 | "\n",
1320 | "# ============ YOUR CODE HERE ==============\n",
1321 | "# linearly project the node features\n",
1322 | "# x = .. # nxf' tensor\n",
1323 | "#\n",
1324 | "# for each edge k select as a message x_j\n",
1325 | "# this is equivalent to selecting from x based on the source nodes edge_index[0]\n",
1326 | "# x_j = .. # exf' tensor\n",
1327 | "#\n",
1328 | "# normalise the representation x_j using the coefficients c\n",
1329 | "# x_j = ..\n",
1330 | "#\n",
1331 | "# ==========================================="
1332 | ],
1333 | "outputs": [],
1334 | "metadata": {
1335 | "id": "LDg82u1wYVQk"
1336 | }
1337 | },
1338 | {
1339 | "cell_type": "markdown",
1340 | "source": [
1341 | "> **Step 4:** For each node, aggregate the received messages using a sum aggregator."
1342 | ],
1343 | "metadata": {
1344 | "id": "z71gYuEbYVQk"
1345 | }
1346 | },
1347 | {
1348 | "cell_type": "code",
1349 | "execution_count": null,
1350 | "source": [
1351 | "edge_index = torch.tensor([[0,1,1,2,2,3], [1,0,2,1,3,2]]) # 2xe tensor\n",
1352 | "# m_ij[k] is the message sent along edge k, from edge_index[0][k] to edge_index[1][k]\n",
1353 | "m_ij = torch.ones(6,3) # exf' tensor\n",
1354 | "\n",
1355 | "# ============ YOUR CODE HERE ==============\n",
1356 | "# for each node, aggregate the messages coming from its neighbours\n",
1357 | "# the receiving (destination) node of each message is stored in edge_index[1]\n",
1358 | "# Hint: scatter_sum might be useful\n",
1359 | "#\n",
1360 | "# out = .. # nxf' tensor\n",
1361 | "# ==========================================="
1362 | ],
1363 | "outputs": [],
1364 | "metadata": {
1365 | "id": "4AEDqJQFYVQk"
1366 | }
1367 | },
1368 | {
1369 | "cell_type": "markdown",
1370 | "source": [
1371 | "> Now we are ready to put everything together:"
1372 | ],
1373 | "metadata": {
1374 | "id": "HXK6uhFjYVQk"
1375 | }
1376 | },
1377 | {
1378 | "cell_type": "code",
1379 | "execution_count": null,
1380 | "source": [
1381 | "class SparseGCNLayer(Module):\n",
1382 | "    def __init__(self, in_channels, out_channels):\n",
1383 | "        \"\"\"\n",
1384 | "        One layer of Graph Convolutional Network (GCN)\n",
1385 | "        similar to the one implemented before, but using the sparse\n",
1386 | "        representation of edges as opposed to the dense\n",
1387 | "        adjacency matrix.\n",
1388 | "\n",
1389 | "        Args:\n",
1390 | "            in_channels: (int) - input dimension\n",
1391 | "            out_channels: (int) - output dimension\n",
1392 | "        \"\"\"\n",
1393 | "        super().__init__()\n",
1394 | "        self.linear = Linear(in_channels, out_channels)\n",
1395 | "\n",
1396 | "    def forward(self, x, edge_index):\n",
1397 | "        \"\"\"\n",
1398 | "        Args:\n",
1399 | "            x: (n, in_dim) - initial node features\n",
1400 | "            edge_index: (2, e) - list of edge indices\n",
1401 | "\n",
1402 | "        Returns:\n",
1403 | "            out: (n, out_dim) - updated node features\n",
1404 | "        \"\"\"\n",
1405 | "        x = self.linear(x)\n",
1406 | "\n",
1407 | "        edge_index = add_self_loops(edge_index)[0]  # Equivalent of Atilde = A + I\n",
1408 | "        x_j = torch.index_select(x, index=edge_index[0], dim=0)\n",
1409 | "        # x_i = torch.index_select(x, index=edge_index[1], dim=0) -- not used here\n",
1410 | "\n",
1411 | "        # ============ YOUR CODE HERE ==============\n",
1412 | "        #\n",
1413 | "        # Step 1: compute degree d\n",
1414 | "        # d = ...\n",
1415 | "        # Step 2: compute coefficients c_ij\n",
1416 | "        # d_j = ...\n",
1417 | "        # d_i = ...\n",
1418 | "        # c_ij = ...\n",
1419 | "        # Step 3: Compute messages m_ij\n",
1420 | "        # m_ij = ..\n",
1421 | "        # Step 4: Aggregate messages into node representations\n",
1422 | "        # out = ...\n",
1423 | "        #\n",
1424 | "        # ===========================================\n",
1425 | "\n",
1426 | "        return out"
1427 | ],
1428 | "outputs": [],
1429 | "metadata": {
1430 | "id": "_SD6ZI3nYVQl"
1431 | }
1432 | },
1433 | {
1434 | "cell_type": "markdown",
1435 | "source": [
1436 | "Let's test the layer by passing our `sample` through it:\n",
1437 | "\n",
1438 | "(note that we have to embed the node features first)"
1439 | ],
1440 | "metadata": {
1441 | "id": "ml7uyGRUYVQn"
1442 | }
1443 | },
1444 | {
1445 | "cell_type": "code",
1446 | "execution_count": null,
1447 | "source": [
1448 | "out = SparseGCNLayer(64, 1)(Embedding(28, 64)(sample.x).squeeze(1), sample.edge_index)\n",
1449 | "print(out.shape)"
1450 | ],
1451 | "outputs": [],
1452 | "metadata": {
1453 | "id": "tEFSyqbMYVQn"
1454 | }
1455 | },
1456 | {
1457 | "cell_type": "markdown",
1458 | "source": [
1459 | "The printed shape should be `torch.Size([29, 1])`: one output value for each of the $29$ nodes in `sample`."
1460 | ],
1461 | "metadata": {
1462 | "id": "_AMDCZi_YVQn"
1463 | }
1464 | },
1465 | {
1466 | "cell_type": "markdown",
1467 | "source": [
1468 | "## Mini-batching for graph data"
1469 | ],
1470 | "metadata": {
1471 | "id": "9_KwYsGoYVQn"
1472 | }
1473 | },
1474 | {
1475 | "cell_type": "markdown",
1476 | "source": [
1477 | "There is one final thing we need to add to create a training pipeline, and that is batching. \n",
1478 | "\n",
1479 | "When we think of batching we typically think of adding a whole new dimension to our input. However, this won't work here because graphs come in different shapes and sizes.\n",
1480 | "\n",
1481 | "Instead, we perform mini-batching by **merging all graphs into one large batch graph**.\n",
1482 | "\n",
1483 | "Let's look at an example. Suppose we want to batch 3 graphs $G_1=(V_1, E_1)$, $G_2=(V_2, E_2)$ and $G_3=(V_3, E_3)$. Our batch graph will be one large graph $G=(V_1 \\cup V_2 \\cup V_3, E_1 \\cup E_2 \\cup E_3)$ in which the three graphs are disjoint components: node indices are shifted so that they do not collide, and no edges connect nodes from different graphs. To keep track of which nodes came from which graph, we store a vector of graph indices in the `batch` attribute.\n",
1484 | "\n",
1485 | " ✍ **Exercise:** Why does graph sparsity become an even bigger issue with batching?"
1486 | ],
1487 | "metadata": {
1488 | "id": "lT5mEKczYVQn"
1489 | }
1490 | },
1500 | {
1501 | "cell_type": "markdown",
1502 | "source": [
1503 | "## 🖋 **Task** Implement your own mini-batch by concatenating all graphs in a given list.\n",
1504 | "\n",
1505 | "> ❗Be careful: when creating the new `edge_index`, you will need to increment node indices to match new node positions in the large graph! You can use our already implemented `collate` function for this.\n",
1506 | ">\n",
1507 | "> You also need to create a `batch` array assigning a graph index to each node, so 0s for nodes in the first graph, 1s for nodes in the second graph and so on."
1508 | ],
1509 | "metadata": {
1510 | "id": "kNzhJPFXYVQn"
1511 | }
1512 | },
1513 | {
1514 | "cell_type": "code",
1515 | "execution_count": null,
1516 | "source": [
1517 | "def collate(values: List[Tensor], incs: List[int], dim: int) -> Tensor:\n",
1518 | " \"\"\" Concatenate values along dim while incrementing based on incs\"\"\"\n",
1519 | " incs = cumsum(torch.tensor(incs))\n",
1520 | " return torch.cat([v + i for v, i in zip(values, incs)], dim=dim)\n",
1521 | "\n",
1522 | "\n",
1523 | "def create_mini_batch(graph_list: List[Data]) -> Data:\n",
1524 | " \"\"\" Build a sparse graph from a batch of graphs \"\"\"\n",
1525 | "\n",
1526 | " # ============ YOUR CODE HERE ==============\n",
1527 | " # To compute a batch graph you need to 1) concatenate the features\n",
1528 | " # 2) concatenate and update the edge indexes and 3) create the batching\n",
1529 | " # vector assigning each node to a graph\n",
1530 | " #\n",
1531 | " # batch_x = torch.cat(...)\n",
1532 | " # batch_y = ...\n",
1533 | " # batch_edge_index = collate(...)\n",
1534 | " # batch_batch = ...\n",
1535 | " #\n",
1536 | " # ===========================================\n",
1537 | "\n",
1538 | " batch_batch = batch_batch.to(torch.int64)\n",
1539 | "\n",
1540 | " #create the big sparse graph\n",
1541 | " batch_graph = Data(x=batch_x,\n",
1542 | " y=batch_y,\n",
1543 | " batch=batch_batch,\n",
1544 | " edge_index=batch_edge_index)\n",
1545 | " return batch_graph"
1546 | ],
1547 | "outputs": [],
1548 | "metadata": {
1549 | "id": "hF4FOFsFYVQo"
1550 | }
1551 | },
1552 | {
1553 | "cell_type": "markdown",
1554 | "source": [
1555 | "To test the batching implementation, we will visualise 2 graphs from the test dataset. If all is well, you should see two distinct, disconnected molecules drawn in the same figure:"
1556 | ],
1557 | "metadata": {
1558 | "id": "07q61xWGYVQo"
1559 | }
1560 | },
1561 | {
1562 | "cell_type": "code",
1563 | "execution_count": null,
1564 | "source": [
1565 | "test_minibatch = create_mini_batch(test_ds[:2])\n",
1566 | "gallery([test_minibatch], max_fig_size=(8,6), node_emb=[test_minibatch.batch/2])"
1567 | ],
1568 | "outputs": [],
1569 | "metadata": {
1570 | "id": "B8CUGfbIYVQo"
1571 | }
1572 | },
1573 | {
1574 | "cell_type": "markdown",
1575 | "source": [
1576 | "## 🖋 **Task** Aggregate the node representations into graph-level representation for a batch of graphs.\n",
1577 | "\n",
1578 | "> The GNN layers we implemented so far output a representation for each node. In order to obtain graph-level representations, we need to aggregate them according to the graph they are part of.\n",
1579 | ">\n",
1580 | "> Use the tools you learned about so far to complete the model below, which takes batched graphs and outputs a single prediction per graph. You need to implement the final step where node-level embeddings are converted to graph-level ones.\n",
1581 | ">\n",
1582 | "> 🆘 **Hint:** Compare this code to `SimpleGCN`. What needs to be done differently here? Where in the code do we need to use `batch`?"
1583 | ],
1584 | "metadata": {
1585 | "id": "KCrUw1teYVQt"
1586 | }
1587 | },
1588 | {
1589 | "cell_type": "code",
1590 | "execution_count": null,
1591 | "source": [
1592 | "class ZINCModel(Module):\n",
1593 | " def __init__(self, model, hidden_dim, num_layers,\n",
1594 | " **kwargs):\n",
1595 | " \"\"\"\n",
1596 | " A GNN model applying a series of `model` layers\n",
1597 | " to create graph-level representation\n",
1598 | "\n",
1599 | " Args:\n",
1600 | " model: (Module) - the class of layers applied\n",
1601 | " inside the model\n",
1602 | " hidden_dim: (int) - hidden dimension\n",
1603 | " num_layers: (int) - number of layers\n",
1604 | " \"\"\"\n",
1605 | " super().__init__()\n",
1606 | "\n",
1607 | " self.embed = Embedding(28, hidden_dim)\n",
1608 | " self.layers = torch.nn.ModuleList(\n",
1609 | " [model(in_channels=hidden_dim, out_channels=d, **kwargs)\n",
1610 | " for d in (num_layers - 1) * [hidden_dim] + [1]])\n",
1611 | "\n",
1612 | "\n",
1613 | " def forward(self, graph):\n",
1614 | " \"\"\"\n",
1615 | " Args:\n",
1616 | " graph: (PyG.Data) - a batch of graphs\n",
1617 | " Returns:\n",
1618 | " out: (num_graphs,) - one scalar prediction per graph in the batch\n",
1619 | " \"\"\"\n",
1620 | " x = graph.x\n",
1621 | " batch = graph.batch\n",
1622 | " edge_index = graph.edge_index\n",
1623 | "\n",
1624 | " x = self.embed(x).squeeze(1)\n",
1625 | " for i in range(len(self.layers)):\n",
1626 | " x = self.layers[i](x, edge_index=edge_index)\n",
1627 | " if i < len(self.layers) - 1:\n",
1628 | " x = F.relu(x)\n",
1629 | "\n",
1630 | " # ============ YOUR CODE HERE ==============\n",
1631 | " # Aggregate the node information into graph information\n",
1632 | " # for each graph in the batch\n",
1633 | " #\n",
1634 | " # x = ...\n",
1635 | " #\n",
1636 | " # ===========================================\n",
1637 | "\n",
1638 | " out = x.squeeze(-1)\n",
1639 | " return out"
1640 | ],
1641 | "outputs": [],
1642 | "metadata": {
1643 | "id": "dFZTesUAYVQt"
1644 | }
1645 | },
1646 | {
1647 | "cell_type": "markdown",
1648 | "source": [
1649 | "Let's test the model:"
1650 | ],
1651 | "metadata": {
1652 | "id": "lhH7PQ3BYVQt"
1653 | }
1654 | },
1655 | {
1656 | "cell_type": "code",
1657 | "execution_count": null,
1658 | "source": [
1659 | "ZINCModel(model=SparseGCNLayer, hidden_dim=32, num_layers=2)(create_mini_batch(train_ds[:2]))"
1660 | ],
1661 | "outputs": [],
1662 | "metadata": {
1663 | "id": "qoVPjQhRYVQt"
1664 | }
1665 | },
1666 | {
1667 | "cell_type": "markdown",
1668 | "source": [
1669 | "We can finally train our model! We just need to load some boilerplate code for the training loop..."
1670 | ],
1671 | "metadata": {
1672 | "id": "U7NfbljRYVQt"
1673 | }
1674 | },
1675 | {
1676 | "cell_type": "code",
1677 | "execution_count": null,
1678 | "source": [
1679 | "# @title [RUN] Hyperparameters GCN\n",
1680 | "BATCH_SIZE = 64 #@param {type:\"integer\"}\n",
1681 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
1682 | "LR = 0.005 #@param {type:\"number\"}\n",
1683 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
1684 | "NUM_LAYERS = 4 #@param {type:\"integer\"}\n",
1685 | "\n",
1686 | "\n",
1687 | "#you can add more here if you need"
1688 | ],
1689 | "outputs": [],
1690 | "metadata": {
1691 | "cellView": "form",
1692 | "id": "4C5geEmPYVQt"
1693 | }
1694 | },
1695 | {
1696 | "cell_type": "code",
1697 | "execution_count": null,
1698 | "source": [
1699 | "def train(dataset, model, optimiser, epoch, loss_fct):\n",
1700 | " \"\"\" Train model for one epoch\n",
1701 | " \"\"\"\n",
1702 | " model.train()\n",
1703 | " num_iter = int(len(dataset)/BATCH_SIZE)\n",
1704 | " for i in range(num_iter):\n",
1705 | " batch_list = dataset[i*BATCH_SIZE:(i+1)*BATCH_SIZE]\n",
1706 | " batch = create_mini_batch(batch_list)\n",
1707 | " optimiser.zero_grad()\n",
1708 | "\n",
1709 | " batch = batch.to(DEVICE)\n",
1710 | " y_hat = model(batch)\n",
1711 | " loss = loss_fct(y_hat, batch.y)\n",
1712 | "\n",
1713 | " loss.backward()\n",
1714 | " optimiser.step()\n",
1715 | " return loss.data # loss of the last batch of the epoch\n",
1716 | "\n",
1717 | "def evaluate(dataset, model, loss_fct):\n",
1718 | " \"\"\" Evaluate model on dataset\n",
1719 | " \"\"\"\n",
1720 | " model.eval()\n",
1721 | " # math.ceil (and the min() below) keep the last, shorter batch when\n",
1722 | " # len(dataset) % BATCH_SIZE != 0, so no examples are dropped. Note that\n",
1723 | " # averaging per-batch losses gives that batch the same weight as a full one.\n",
1724 | " num_iter = math.ceil(len(dataset)/BATCH_SIZE)\n",
1725 | " loss_eval = 0\n",
1726 | " for i in range(num_iter):\n",
1727 | " batch_list = dataset[i*BATCH_SIZE:min((i+1)*BATCH_SIZE, len(dataset))] # Last batch is shorter\n",
1728 | " batch = create_mini_batch(batch_list)\n",
1729 | " batch = batch.to(DEVICE)\n",
1730 | "\n",
1731 | " y_hat = model(batch)\n",
1732 | " loss = loss_fct(y_hat, batch.y)\n",
1733 | "\n",
1734 | " loss_eval += loss.data\n",
1735 | "\n",
1736 | " loss_eval /= num_iter\n",
1737 | " return loss_eval"
1738 | ],
1739 | "outputs": [],
1740 | "metadata": {
1741 | "id": "2FiFtZwmYVQt"
1742 | }
1743 | },
1744 | {
1745 | "cell_type": "code",
1746 | "execution_count": null,
1747 | "source": [
1748 | "def run_exp(model, train_dataset, val_dataset, test_dataset,\n",
1749 | " loss_fct, lr, num_epochs):\n",
1750 | " \"\"\" Train the model for `num_epochs` epochs\n",
1751 | " \"\"\"\n",
1752 | " print(\"\\nModel architecture:\")\n",
1753 | " print(model)\n",
1754 | "\n",
1755 | " model = model.to(DEVICE)\n",
1756 | " # Instantiate our optimiser\n",
1757 | " optimiser = torch.optim.Adam(model.parameters(), lr=lr)\n",
1758 | " training_stats = None\n",
1759 | "\n",
1760 | " #initial evaluation (before training)\n",
1761 | " val_loss = evaluate(val_dataset, model, loss_fct)\n",
1762 | " train_loss = evaluate(train_dataset[:BATCH_SIZE], model,\n",
1763 | " loss_fct)\n",
1764 | " epoch_stats = {'train_loss': train_loss.cpu(), 'val_loss': val_loss.cpu(), 'epoch':0}\n",
1765 | " training_stats = update_stats(training_stats, epoch_stats)\n",
1766 | "\n",
1767 | " print(\"\\nStart training:\")\n",
1768 | " for epoch in range(num_epochs):\n",
1769 | " if isinstance(train_dataset, list):\n",
1770 | " random.shuffle(train_dataset)\n",
1771 | " else:\n",
1772 | " train_dataset.shuffle()\n",
1773 | " train_loss = train(train_dataset, model, optimiser, epoch,\n",
1774 | " loss_fct)\n",
1775 | " val_loss = evaluate(val_dataset, model, loss_fct)\n",
1776 | " print(f\"[Epoch {epoch+1}]\",\n",
1777 | " f\"train loss: {train_loss:.3f} val loss: {val_loss:.3f}\",\n",
1778 | " )\n",
1779 | " # store the loss and the computed metric for the final plot\n",
1780 | " epoch_stats = {'train_loss': train_loss.cpu(), 'val_loss': val_loss.cpu(),\n",
1781 | " 'epoch':epoch+1}\n",
1782 | " training_stats = update_stats(training_stats, epoch_stats)\n",
1783 | "\n",
1784 | " test_loss = evaluate(test_dataset, model, loss_fct)\n",
1785 | " print(f\"Done! Test loss: {test_loss:.3f}\")\n",
1786 | " return training_stats"
1787 | ],
1788 | "outputs": [],
1789 | "metadata": {
1790 | "id": "eE7E-c3FYVQt"
1791 | }
1792 | },
1793 | {
1794 | "cell_type": "markdown",
1795 | "source": [
1796 | "... and let the magic happen:"
1797 | ],
1798 | "metadata": {
1799 | "id": "QmYf2fkLYVQu"
1800 | }
1801 | },
1802 | {
1803 | "cell_type": "code",
1804 | "execution_count": null,
1805 | "source": [
1806 | "gcn_model = ZINCModel(model=SparseGCNLayer, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
1807 | "\n",
1808 | "stats = run_exp(gcn_model, train_ds, val_ds,\n",
1809 | " test_ds, loss_fct=F.mse_loss, lr=LR, num_epochs=NUM_EPOCHS)"
1810 | ],
1811 | "outputs": [],
1812 | "metadata": {
1813 | "id": "f6_rCWG2YVQu"
1814 | }
1815 | },
1816 | {
1817 | "cell_type": "code",
1818 | "execution_count": null,
1819 | "source": [
1820 | "plot_stats(stats)"
1821 | ],
1822 | "outputs": [],
1823 | "metadata": {
1824 | "id": "_B7w_uedYVQu"
1825 | }
1826 | },
1827 | {
1828 | "cell_type": "markdown",
1829 | "source": [
1830 | "## Graph Attention Network\n",
1831 | "\n",
1832 | "We will wrap up this section by implementing another popular GNN layer called [Graph Attention Network (GAT)](https://arxiv.org/abs/1710.10903).\n"
1833 | ],
1834 | "metadata": {
1835 | "id": "3IiXH2JoYVQu"
1836 | }
1837 | },
1838 | {
1839 | "cell_type": "markdown",
1840 | "source": [
1841 | "A Graph Convolutional Network aggregates information from a node's neighbourhood with fixed coefficients. A GAT layer instead uses an attention mechanism to learn how much each neighbour contributes to the aggregation.\n",
1842 | "\n",
1843 | "\n",
1844 | "We define the GAT layer as\n",
1845 | "\n",
1846 | "\n",
1847 | "$$\n",
1848 | "\\mathbf{h_i} = \\sigma \\big( \\sum_{j \\in N_i} \\alpha\\left(\\mathbf{x_i}, \\mathbf{x_j}\\right) \\mathbf{x_j} \\mathbf{W}\\big)\n",
1849 | "$$\n",
1850 | "\n",
1851 | "where\n",
1852 | "\n",
1853 | "$$\n",
1854 | "\\alpha\\left(\\mathbf{x_i}, \\mathbf{x_j}\\right) = \\operatorname{softmax}_{j \\in N_i}\\Big(\\text{LeakyReLU}\\left(\\mathbf{x_i}\\mathbf{q}_1^T + \\mathbf{x_j}\\mathbf{q}_2^T\\right)\\Big) \\in \\mathbb{R}\n",
1855 | "$$\n",
1856 | "\n",
1857 | "with $\\mathbf{q}_1^T, \\mathbf{q}_2^T \\in \\mathbb{R}^{d \\times 1}$. Here $\\sigma$ is an element-wise nonlinearity, while the softmax normalises the scores of node $i$ over its neighbourhood $N_i$, so that $\\sum_{j \\in N_i} \\alpha\\left(\\mathbf{x_i}, \\mathbf{x_j}\\right) = 1$.\n",
1858 | "\n",
1859 | "Concretely, in GCN the summation coefficients are *statically* determined by $\\mathbf{A}$, while in GAT the aggregation is *dynamic*: the weights depend both on the adjacency matrix $\\mathbf{A}$ and on the features $\\mathbf{X}$. The attention coefficient scores the importance of each neighbour, allowing the model to down-weight (close to zero) the information coming from nodes that are not relevant for the task.\n",
1860 | "\n",
1861 | "\n",
1862 | "\n",
1863 | "\n"
1864 | ],
1865 | "metadata": {
1866 | "id": "XUjnms5cYVQu"
1867 | }
1868 | },
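{
"cell_type": "markdown",
"source": [
"To make the attention formula concrete, here is a small worked sketch in plain PyTorch for a single node $i$ with three neighbours. It is an illustration only, assuming random features and random vectors $\\mathbf{q}_1, \\mathbf{q}_2$ (all variable names are hypothetical). It computes the scores, normalises them with a softmax over $N_i$ and forms the weighted sum of the neighbour features. For simplicity it uses dense tensors and leaves out $\\mathbf{W}$ and the outer nonlinearity $\\sigma$, so it is not a substitute for the sparse `GATLayer` exercise.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"d = 4                         # feature dimension (illustrative)\n",
"x_i = torch.randn(1, d)       # features of the receiving node i\n",
"x_nbrs = torch.randn(3, d)    # features of its 3 neighbours j in N_i\n",
"q1 = torch.randn(1, d)        # q_1, so that x_i @ q1.T is a scalar\n",
"q2 = torch.randn(1, d)        # q_2, so that x_j @ q2.T is a scalar\n",
"\n",
"# Unnormalised scores LeakyReLU(x_i q_1^T + x_j q_2^T), one per neighbour\n",
"scores = F.leaky_relu(x_i @ q1.T + x_nbrs @ q2.T, negative_slope=0.2).squeeze(-1)\n",
"\n",
"# Softmax over the neighbourhood: the coefficients are positive and sum to 1\n",
"alpha = torch.softmax(scores, dim=0)\n",
"print(alpha, alpha.sum())\n",
"\n",
"# Attention-weighted sum of the neighbour features\n",
"h_i = (alpha.unsqueeze(-1) * x_nbrs).sum(dim=0)\n",
"print(h_i.shape)  # torch.Size([4])\n",
"```\n",
"\n",
"Because the softmax is taken per neighbourhood, a sparse implementation has to group the edge scores by their destination node before normalising them."
],
"metadata": {
"id": "gatAlphaWorkedExample"
}
},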
1880 | {
1881 | "cell_type": "markdown",
1882 | "source": [
1883 | "## 🖋 **Task** Implement the Graph Attention Layer using Pytorch.\n",
1884 | "\n",
1885 | "> Implement the attention coefficients $\\alpha$ for GAT in sparse form using `scatter_sum`.\n",
1886 | ">\n",
1887 | "> 🆘 **Hint:** we already imported `softmax` from `torch_geometric.utils`. It takes `index` as a second parameter, which determines groups that will be softmaxed independently of each other. Why is this useful here?"
1888 | ],
1889 | "metadata": {
1890 | "id": "RJvSG9TGYVQu"
1891 | }
1892 | },
1893 | {
1894 | "cell_type": "code",
1895 | "execution_count": null,
1896 | "source": [
1897 | "class GATLayer(Module):\n",
1898 | " def __init__(self, in_channels, out_channels):\n",
1899 | " \"\"\"\n",
1900 | " One layer of Graph Attention Network (GAT)\n",
1901 | "\n",
1902 | " Args:\n",
1903 | " in_channels: (int) - input dimension\n",
1904 | " out_channels: (int) - output dimension\n",
1905 | " \"\"\"\n",
1906 | " super().__init__()\n",
1907 | " self.a_i = Linear(out_channels,1, bias=False)\n",
1908 | " self.a_j = Linear(out_channels,1, bias=False)\n",
1909 | "\n",
1910 | " self.linear = Linear(in_channels, out_channels)\n",
1911 | " self.negative_slope = 0.2\n",
1912 | "\n",
1913 | " def forward(self, x, edge_index):\n",
1914 | " \"\"\"\n",
1915 | " Args:\n",
1916 | " x: (n, in_dim) - initial node features\n",
1917 | " edge_index: (2, e) - list of edge indices\n",
1918 | "\n",
1919 | " Returns:\n",
1920 | " out: (n, out_dim) - updated node features\n",
1921 | " \"\"\"\n",
1922 | " x = self.linear(x)\n",
1923 | " # select all the source nodes for each edge\n",
1924 | " x_j = torch.index_select(x, index=edge_index[0], dim=0)\n",
1925 | " # select all the destination nodes for each edge\n",
1926 | " x_i = torch.index_select(x, index=edge_index[1], dim=0)\n",
1927 | "\n",
1928 | " # ============ YOUR CODE HERE ==============\n",
1929 | " # Implement the equation above to compute the\n",
1930 | " # node update corresponding to GAT\n",
1931 | " #\n",
1932 | " # e = ...\n",
1933 | " # alpha = ...\n",
1934 | " # out = ...\n",
1935 | " #\n",
1936 | " # ===========================================\n",
1937 | "\n",
1938 | " return out"
1939 | ],
1940 | "outputs": [],
1941 | "metadata": {
1942 | "id": "E8etWmG8YVQu"
1943 | }
1944 | },
1945 | {
1946 | "cell_type": "code",
1947 | "execution_count": null,
1948 | "source": [
1949 | "# @title [RUN] Hyperparameters GAT\n",
1950 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
1951 | "LR = 0.005 #@param {type:\"number\"}\n",
1952 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
1953 | "NUM_LAYERS = 4 #@param {type:\"integer\"}\n",
1954 | "\n",
1955 | "\n",
1956 | "#you can add more here if you need"
1957 | ],
1958 | "outputs": [],
1959 | "metadata": {
1960 | "cellView": "form",
1961 | "id": "M7XF7PKSYVQu"
1962 | }
1963 | },
1964 | {
1965 | "cell_type": "code",
1966 | "execution_count": null,
1967 | "source": [
1968 | "gat_model = ZINCModel(model=GATLayer, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
1969 | "\n",
1970 | "stats = run_exp(gat_model, train_ds, val_ds, test_ds, loss_fct=F.mse_loss, lr=LR, num_epochs=NUM_EPOCHS)"
1971 | ],
1972 | "outputs": [],
1973 | "metadata": {
1974 | "id": "UA6w5BI9YVQv"
1975 | }
1976 | },
1977 | {
1978 | "cell_type": "code",
1979 | "execution_count": null,
1980 | "source": [
1981 | "plot_stats(stats)"
1982 | ],
1983 | "outputs": [],
1984 | "metadata": {
1985 | "id": "yfEKl8MDYVQv"
1986 | }
1987 | },
1988 | {
1989 | "cell_type": "markdown",
1990 | "source": [
1991 | "If all goes well, you should see GAT slightly outperform GCN."
1992 | ],
1993 | "metadata": {
1994 | "id": "ug5nzSjsYVQv"
1995 | }
1996 | },
1997 | {
1998 | "cell_type": "markdown",
1999 | "source": [
2000 | "🎉 Congratulations, you have finished the first half of the tutorial!\n",
2001 | "\n",
2002 | "In the process, you reinvented many primitives that already exist in off-the-shelf GNN libraries. For example, lifting node representations to edges with `index_select` is already done more efficiently in `torch_geometric.nn.conv.MessagePassing.lift`, batching is implemented in `torch_geometric.loader.DataLoader`, and our `collate` is an oversimplified version of `torch_geometric.data.collate`. Feel free to check out these functions and compare.\n",
2003 | "\n",
2004 | "In the second half of this tutorial you will learn to work with this high-level API, and you now hopefully understand that there is nothing mysterious going on under the hood!"
2005 | ],
2006 | "metadata": {
2007 | "id": "EKwHHfDHYVQv"
2008 | }
2009 | },
2010 | {
2011 | "cell_type": "markdown",
2012 | "source": [
2013 | "# 🚼 [Intermediate] **First steps in Pytorch-Geometric** \n"
2014 | ],
2015 | "metadata": {
2016 | "id": "783cmcM-dhG-"
2017 | }
2018 | },
2019 | {
2020 | "cell_type": "markdown",
2021 | "source": [
2022 | "So far, we learned how to implement some GNN architectures using the standard [Pytorch library](https://pytorch.org/docs/stable/index.html). While this gives us a better understanding of what operations are happening inside a GNN, it can become quite cumbersome for more complex projects. Just think about how painful it was to implement your own graph batching (well done! 🥂) or how much memory you wasted storing a big adjacency matrix full of zeros.\n",
2023 | "\n",
2024 | "To avoid these issues, for the rest of this practical we will use the **Pytorch-Geometric library** (PyG). Pytorch Geometric is a library built on top of Pytorch, designed specifically for geometric deep learning. It provides easy-to-use mini-batching, many built-in GNN layers, datasets and much more. If you have never encountered PyG before, don't worry: we will teach you step by step everything you need for this tutorial. If you want to explore further, there are some very useful tutorials on the [official platform](https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html).\n",
2025 | "\n",
2026 | "\n",
2027 | "## Data object and mini-batching\n",
2028 | "\n",
2029 | "In PyG, a graph is stored as a `Data` object, containing all the information about a graph $G=(V,E)$:\n",
2030 | "\n",
2031 | "\n",
2032 | "> **`data.x`**: a tensor of dimension $|V| \\times C$ containing the node features \\\n",
2033 | " **`data.edge_index`**: a tensor of dimension $2 \\times |E|$ describing the graph connectivity \\\n",
2034 | " **`data.edge_attr`**: if applicable, a tensor of dimension $|E| \\times K$ containing edge features \\\n",
2035 | " **`data.y`**: depending on the task, the target can contain labels either for each node, each edge or each graph\n",
2036 | "\n",
2037 | "If needed, you can add your own attributes to the `data` object (we will see this later). Moreover, PyG provides a number of convenience attributes and methods that you can access directly, such as `data.num_nodes`, `data.num_edges`, `data.has_isolated_nodes()` and [much more](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data).\n",
2038 | "\n",
2039 | "🔍 Let's look at an example:"
2040 | ],
2041 | "metadata": {
2042 | "id": "9iOnSZ4MegDo"
2043 | }
2044 | },
2045 | {
2046 | "cell_type": "code",
2047 | "execution_count": null,
2048 | "source": [
2049 | "x = torch.rand(3, 5)\n",
2050 | "edge_index = torch.tensor(\n",
2051 | " [[0,1,2,0],[1,0,0,2]],\n",
2052 | " dtype=torch.long\n",
2053 | ")\n",
2054 | "graph = Data(x=x, edge_index=edge_index)\n",
2055 | "\n",
2056 | "print(f\"Graph contains {graph.num_edges} edges, {graph.num_nodes} nodes, each characterised by {graph.x.shape[1]} features.\")\n"
2057 | ],
2058 | "outputs": [],
2059 | "metadata": {
2060 | "id": "96nOk-FJnjpF"
2061 | }
2062 | },
2063 | {
2064 | "cell_type": "markdown",
2065 | "source": [
2066 | "This is similar to the objects returned by our ZINC datasets. However, PyG also lets us batch and unbatch graphs much more easily, using the `Batch` object. Let's try it:"
2067 | ],
2068 | "metadata": {
2069 | "id": "OVtwiGAupJ1K"
2070 | }
2071 | },
2072 | {
2073 | "cell_type": "code",
2074 | "execution_count": null,
2075 | "source": [
2076 | "# Graph 1\n",
2077 | "x_1 = torch.tensor([[10.0], [0.5], [1.7]], dtype=torch.float)\n",
2078 | "edge_index_1 = torch.tensor(\n",
2079 | " [[0, 1, 1, 2], [1, 0, 2, 1]],\n",
2080 | " dtype=torch.long\n",
2081 | ")\n",
2082 | "data_1 = Data(x=x_1, edge_index=edge_index_1)\n",
2083 | "\n",
2084 | "# Graph 2\n",
2085 | "x_2 = torch.tensor([[11.5], [120.2], [-100], [40.5]], dtype=torch.float)\n",
2086 | "edge_index_2 = torch.tensor(\n",
2087 | " [[0, 2, 1, 0, 1], [2, 0, 0, 1, 3]],\n",
2088 | " dtype=torch.long\n",
2089 | ")\n",
2090 | "data_2 = Data(x=x_2, edge_index=edge_index_2)\n",
2091 | "\n",
2092 | "# Create a batch from the 2 graphs using Batch.from_data_list\n",
2093 | "data_list = [data_1, data_2]\n",
2094 | "batch = Batch.from_data_list(data_list)\n",
2095 | "\n",
2096 | "dense_adjacency = to_dense_adj(batch.edge_index).numpy().squeeze()\n",
2097 | "print(\"Let's see if it does what we expect:\")\n",
2098 | "print_color_numpy(dense_adjacency, data_list)\n",
2099 | "print(f\"\\nWe also know which graph each node belongs to via `batch.batch`: {batch.batch}\")\n",
2100 | "\n",
2101 | "# Unbatch the graphs via simple indexing\n",
2102 | "print(f\"And we can still access the individual graphs in the batch: \\n Graph 0: {batch[0]}, \\n Graph 1: {batch[1]}\")"
2103 | ],
2104 | "outputs": [],
2105 | "metadata": {
2106 | "id": "zquabOeAqDQ_"
2107 | }
2108 | },
2109 | {
2110 | "cell_type": "markdown",
2111 | "source": [
2112 | "`Batch.from_data_list()` is useful when we have our own list of graphs to batch. In a training loop, the PyG `DataLoader` object handles all of the batching for us under the hood:"
2113 | ],
2114 | "metadata": {
2115 | "id": "oyEUXoZZyAHl"
2116 | }
2117 | },
2118 | {
2119 | "cell_type": "code",
2120 | "execution_count": null,
2121 | "source": [
2122 | "train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)\n",
2123 | "\n",
2124 | "# extract first batch from the ZINC dataset\n",
2125 | "first_batch = next(iter(train_loader))\n",
2126 | "print(\"This is the first batch in our dataset:\", first_batch)"
2127 | ],
2128 | "outputs": [],
2129 | "metadata": {
2130 | "id": "HUXdqHrqdnVq"
2131 | }
2132 | },
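{
"cell_type": "markdown",
"source": [
"The `batch` vector is what links node-level outputs back to the graph they came from. The sketch below, built on two tiny hypothetical graphs `g1` and `g2`, prints a few `Batch` attributes (`num_graphs`, `batch`, `ptr`, `edge_index`) and then uses `global_mean_pool` from `torch_geometric.nn` to average the node features of each graph, one common way to get a graph-level readout (`global_add_pool` is the sum-based alternative).\n",
"\n",
"```python\n",
"import torch\n",
"from torch_geometric.data import Batch, Data\n",
"from torch_geometric.nn import global_mean_pool\n",
"\n",
"# Two small hypothetical graphs with 3 and 2 nodes\n",
"g1 = Data(x=torch.tensor([[1.0], [2.0], [3.0]]), edge_index=torch.tensor([[0, 1], [1, 2]]))\n",
"g2 = Data(x=torch.tensor([[10.0], [20.0]]), edge_index=torch.tensor([[0], [1]]))\n",
"batch = Batch.from_data_list([g1, g2])\n",
"\n",
"print(batch.num_graphs)  # 2\n",
"print(batch.batch)       # tensor([0, 0, 0, 1, 1]): graph index of each node\n",
"print(batch.ptr)         # tensor([0, 3, 5]): offset where each graph's nodes start\n",
"print(batch.edge_index)  # g2's edge (0 -> 1) becomes (3 -> 4) in the batch\n",
"\n",
"# Graph-level readout: one averaged feature vector per graph\n",
"pooled = global_mean_pool(batch.x, batch.batch)\n",
"print(pooled.shape)  # torch.Size([2, 1]); the values are 2.0 and 15.0\n",
"```"
],
"metadata": {
"id": "batchAttributesSketch"
}
},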
2133 | {
2134 | "cell_type": "markdown",
2135 | "source": [
2136 | "## Message-Passing framework\n",
2137 | "\n",
2138 | "Now that we are familiar with how graphs are stored in PyG, let's see how to create models with it. To understand how to implement a GNN model in PyG, we first need to recap the general **Message-Passing framework**.\n",
2139 | "\n",
2140 | "Let's consider a graph $G = (V,E)$, with $n$ nodes and $e$ edges, where each node $i$ is characterised by a feature vector $x_i \\in \\mathbb{R}^d$. If edge features are available, we denote them by $e_{ij} \\in \\mathbb{R}^f$. We denote by $N_i$ the set of neighbours of node $i$.\n",
2141 | "\n",
2142 | "Most Graph Neural Network architectures follow a $3$-step framework:\n",
2143 | "\n",
2144 | "\n",
2145 | "\n",
2146 | "\n",
2147 | "* **Message step.** For each pair of connected nodes $(i,j)$, the function $f_{msg}(x_i, x_j, e_{ij})$ computes a message $m_{ij}$ representing the information sent from node $j$ to node $i$. Usually, this is implemented as an MLP which takes as input the concatenation of the source node features $x_j$, the destination node features $x_i$ and the edge features $e_{ij}$. *Note that each of these $3$ inputs is optional and can be omitted depending on the architecture.*\n",
2148 | "\n",
2149 | "* **Aggregate step.** For each node $i$, the incoming messages from all neighbouring nodes are aggregated as $m_i = \\bigoplus_{j \\in N_i} m_{ij}$, using a permutation-invariant operator $\\bigoplus$ such as summation, average or maximum.\n",
2150 | "\n",
2151 | "* **Update step.** For each node $i$, the aggregated information $m_i$ is combined with the previous representation $x_i$ using a function $f_{upd}(x_i, m_i)$ which can be implemented as an MLP over the concatenation of the input vectors.\n",
2152 | "\n",
2153 | "🤔 Note that most existing architectures, including [GCN](https://arxiv.org/abs/1609.02907), [GAT](https://arxiv.org/abs/1710.10903) and [MPNN](https://arxiv.org/abs/1704.01212), are instances of this general framework. Take some time to think about which message, aggregation and update functions each of them uses.\n",
2154 | "\n",
2155 | "\n",
2156 | "\n"
2157 | ],
2158 | "metadata": {
2159 | "id": "Woj_lLFFzcJL"
2160 | }
2161 | },
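{
"cell_type": "markdown",
"source": [
"The three steps can also be sketched in plain PyTorch, without any PyG machinery. The layer below is a hypothetical example (the class name `ToyMessagePassing` and all dimensions are assumptions made for illustration): $f_{msg}$ and $f_{upd}$ are small MLPs over concatenated inputs, and the aggregation $\\bigoplus$ is a sum implemented with `Tensor.index_add_`. As elsewhere in this notebook, `edge_index[0]` holds the source nodes $j$ and `edge_index[1]` the destination nodes $i$.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class ToyMessagePassing(nn.Module):\n",
"    # Hypothetical layer: MLP messages, sum aggregation, MLP update\n",
"    def __init__(self, d, f):\n",
"        super().__init__()\n",
"        self.f_msg = nn.Sequential(nn.Linear(2 * d + f, d), nn.ReLU())\n",
"        self.f_upd = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU())\n",
"\n",
"    def forward(self, x, edge_index, edge_attr):\n",
"        src, dst = edge_index        # edge (j -> i): src = j, dst = i\n",
"        x_j, x_i = x[src], x[dst]    # lift node features to the edges\n",
"        # 1) Message step: m_ij = f_msg(x_i, x_j, e_ij)\n",
"        m = self.f_msg(torch.cat([x_i, x_j, edge_attr], dim=-1))\n",
"        # 2) Aggregate step: m_i = sum of m_ij over j in N_i\n",
"        m_i = torch.zeros_like(x).index_add_(0, dst, m)\n",
"        # 3) Update step: new x_i = f_upd(x_i, m_i)\n",
"        return self.f_upd(torch.cat([x, m_i], dim=-1))\n",
"\n",
"x = torch.randn(3, 8)                                    # 3 nodes, d = 8\n",
"edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 0, 2]])  # 4 directed edges\n",
"edge_attr = torch.randn(4, 2)                            # f = 2 edge features\n",
"out = ToyMessagePassing(d=8, f=2)(x, edge_index, edge_attr)\n",
"print(out.shape)  # torch.Size([3, 8])\n",
"```\n",
"\n",
"Different architectures mostly differ in these choices, for example the aggregation (sum, mean or max) and which inputs $f_{msg}$ uses."
],
"metadata": {
"id": "toyMessagePassingSketch"
}
},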
2162 | {
2163 | "cell_type": "markdown",
2164 | "source": [
2165 | "We are now ready to implement such models in Pytorch Geometric. To create a message-passing based architecture, we inherit from the `MessagePassing` class, which [takes care of most of the propagation pipeline](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html#the-messagepassing-base-class). We then only need to define the three functions described above: `message`, `aggregate` and `update`. PyG provides defaults for `aggregate` (selected via the `aggr` argument) and `update` (which returns the aggregated messages unchanged), so in practice we often override only some of them.\n",
2166 | "\n",
2167 | "\n",
2168 | "\n",
2169 | "```\n",
2170 | "class MPNNLayer(MessagePassing):\n",
2171 | " def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):\n",
2172 | " ...\n",
2173 | "\n",
2174 | " def forward(self, x, edge_index, edge_attr):\n",
2175 | " ...\n",
2176 | "\n",
2177 | " def message(self, x_i, x_j, edge_attr):\n",
2178 | " ...\n",
2179 | "\n",
2180 | " def aggregate(self, inputs, index):\n",
2181 | " ...\n",
2182 | "\n",
2183 | " def update(self, aggr_out, x):\n",
2184 | " ...\n",
2185 | "```\n",
2186 | "\n",
2187 | "> \\\n",
2188 | "**`forward(x, edge_index, edge_attr)`**: in the forward function, we simply need to call the `propagate()` function which automatically triggers the message passing procedure: `message()` -> `aggregate()` -> `update()`. \\\n",
2189 | "\\\n",
2190 | " **`message(x_i, x_j, edge_attr)`**: constructs a message from node $j$ to node $i$ for every edge $(j,i)$ in `edge_index`. It can receive any of the arguments passed to `propagate()`. In addition, for every node-level argument $h$ passed to `propagate()`, it can request the features of the source node by appending `_j` to the name and those of the destination node by appending `_i`. So, if we pass two node-level arguments $x$ and $h$ to `propagate()`, `message()` can receive the 4 arguments `x_j`, `h_j` (source node) and `x_i`, `h_i` (destination node). \\\n",
2191 | " \\\n",
2192 | " **`aggregate(inputs, index)`**: receives all the messages computed above (`inputs`) together with `index`, the destination node of each message. Functions like `scatter_mean` and `scatter_max` can be useful to implement various aggregators.\\\n",
2193 | "\\\n",
2194 | " **`update(aggr_out, x)`**: receives the aggregated information from the above function (one for each node) together with any extra arguments received in the propagate, in order to compute the final node representation.\\\n",
2195 | " \\\n",
2196 | "\n"
2197 | ],
2198 | "metadata": {
2199 | "id": "LkZq0xTu6laH"
2200 | }
2201 | },
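{
"cell_type": "markdown",
"source": [
"As a minimal sketch of this API, here is a hypothetical layer (`MeanNeighbourLayer` is our own illustrative name, not a PyG class) that averages linearly transformed neighbour features and adds a transformed copy of each node's own features. It only overrides `message()` and `update()`: passing `aggr='mean'` to the constructor lets the default `aggregate()` compute the mean of the incoming messages. It assumes PyG's default `flow='source_to_target'`, under which `x_j` is gathered from `edge_index[0]`.\n",
"\n",
"```python\n",
"import torch\n",
"from torch_geometric.nn import MessagePassing\n",
"\n",
"class MeanNeighbourLayer(MessagePassing):\n",
"    def __init__(self, in_dim, out_dim):\n",
"        super().__init__(aggr='mean')  # the default aggregate() averages messages\n",
"        self.lin_nbr = torch.nn.Linear(in_dim, out_dim)\n",
"        self.lin_self = torch.nn.Linear(in_dim, out_dim)\n",
"\n",
"    def forward(self, x, edge_index):\n",
"        # propagate() runs message() -> aggregate() -> update()\n",
"        return self.propagate(edge_index, x=x)\n",
"\n",
"    def message(self, x_j):\n",
"        # x_j: source-node features of every edge, shape (num_edges, in_dim)\n",
"        return self.lin_nbr(x_j)\n",
"\n",
"    def update(self, aggr_out, x):\n",
"        # aggr_out: averaged messages per node, shape (num_nodes, out_dim)\n",
"        return aggr_out + self.lin_self(x)\n",
"\n",
"x = torch.randn(3, 5)\n",
"edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 0, 2]])\n",
"out = MeanNeighbourLayer(5, 16)(x, edge_index)\n",
"print(out.shape)  # torch.Size([3, 16])\n",
"```\n",
"\n",
"Note how `x_j` is requested simply by naming it in the signature of `message()`, and how `update()` receives `x` because it was passed to `propagate()`."
],
"metadata": {
"id": "meanNeighbourLayerSketch"
}
},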
2202 | {
2203 | "cell_type": "markdown",
2204 | "source": [
2205 | "## Graph Attention Network in Pytorch-Geometric\n",
2206 | "\n",
2207 | "As mentioned before, most GNN architectures can be seen as particular instances of the $3$-step framework above.\n",
2208 | "\n",
2209 | "The **Graph Attention Network** that you implemented at the beginning of this tutorial is also an instance of the Message Passing framework.\n",
2210 | "\n",
2211 | "\n",
2212 | "\n",
2213 | "\n",
2214 | "\n",
2215 | "To implement the attention coefficient, various functions can be used. In [the original paper](https://arxiv.org/abs/1710.10903):\n",
2216 | "\n",
2217 | "\n",
2222 | "\n",
2223 | "$$\n",
2224 | "\\alpha\\left(\\mathbf{x_i}, \\mathbf{x_j}\\right) = \\sigma(\\text{LeakyReLU}\\left( \\left(\\mathbf{x_i}\\mathbf{q}_1^T + \\mathbf{x_j}\\mathbf{q}_2^T\\right)\\right)) \\in \\mathbb{R}\n",
2225 | "$$\n",
2226 | "\n",
2227 | "with $\\mathbf{q}_1^T, \\mathbf{q}_2^T \\in \\mathbb{R}^{d \\times 1}$ and $\\sigma$ is the softmax nonlinearity applied across each node's neighbourhood.\n",
2228 | "\n"
2229 | ],
2230 | "metadata": {
2231 | "id": "UUeaRBiuNnTP"
2232 | }
2233 | },
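{
"cell_type": "markdown",
"source": [
"The per-neighbourhood softmax $\\sigma$ is the only non-local part of the attention coefficient. The scores of all edges pointing into the same node must be normalised together. PyG provides `torch_geometric.utils.softmax(src, index)` for this. The cell below is a pure-PyTorch sketch of what such a grouped softmax computes. `neighbourhood_softmax` is our own hypothetical helper, not a PyG function, and arbitrary numbers stand in for the LeakyReLU scores."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"def neighbourhood_softmax(scores, index, num_nodes):\n",
"    # Normalise the per-edge `scores` over all edges sharing the same `index`\n",
"    # entry. Subtracting one global max keeps exp() numerically stable and\n",
"    # does not change the result inside any group.\n",
"    exp = (scores - scores.max()).exp()\n",
"    denom = torch.zeros(num_nodes, dtype=scores.dtype).index_add_(0, index, exp)\n",
"    return exp / denom[index]\n",
"\n",
"toy_edge_index = torch.tensor([[0, 1, 2, 0],   # sources j\n",
"                               [1, 0, 0, 2]])  # destinations i\n",
"toy_scores = torch.tensor([0.5, 1.0, 3.0, -2.0])  # arbitrary score per edge\n",
"\n",
"# Group by destination, so each node's incoming weights sum to 1\n",
"toy_alpha = neighbourhood_softmax(toy_scores, toy_edge_index[1], num_nodes=3)\n",
"print(toy_alpha)  # the two edges into node 0 share the mass; nodes 1 and 2 have one incoming edge each"
],
"outputs": [],
"metadata": {}
},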
2234 | {
2235 | "cell_type": "markdown",
2236 | "source": [
2237 | "Let's use what we have learned so far to implement the GAT layer in PyG."
2238 | ],
2239 | "metadata": {
2240 | "id": "21JIJbdCAM1U"
2241 | }
2242 | },
2243 | {
2244 | "cell_type": "markdown",
2245 | "source": [
2246 | "\n",
2247 | "\n",
2248 | "## 🖋 **Task** Implement a Graph Attention Network layer using the PyTorch Geometric tools.\n",
2249 | "\n",
2250 | "> Most of the code is provided for you. You only need to fill in the code for the message function (and the two projections it uses, defined in `__init__`); the update function is already implemented:\n",
2251 | ">\n",
2252 | "> \\begin{align}\n",
2253 | "f_{msg}(x_i, x_j)&=\\alpha(x_i, x_j)x_j \\\\\n",
2254 | "f_{upd}(x, m)&=f(m)\n",
2255 | "\\end{align}"
2256 | ],
2257 | "metadata": {
2258 | "id": "W2Bt6NiiITHB"
2259 | }
2260 | },
2261 | {
2262 | "cell_type": "code",
2263 | "execution_count": null,
2264 | "source": [
2265 | "class GATpygLayer(MessagePassing):\n",
2266 | " def __init__(self, in_dim, out_dim):\n",
2267 | " \"\"\"GAT Layer implemented using Pytorch Geometric\n",
2268 | "\n",
2269 | " Args:\n",
2270 | " in_dim: (int) - input dimension for nodes `d`\n",
2271 | " out_dim: (int) - output dimension `d_e`\n",
2272 | " \"\"\"\n",
2273 | " super().__init__(node_dim=0, aggr='add')\n",
2274 | " self.in_dim = in_dim\n",
2275 | " self.out_dim = out_dim\n",
2276 | "\n",
2277 | " # ============ YOUR CODE HERE ==============\n",
2278 | " # Add new layers to compute the two linear projections\n",
2279 | " # x_i -> x_i q_1^T and x_j -> x_j q_2^T\n",
2280 | " #\n",
2281 | " # self.lin_q = ...\n",
2282 | " # self.lin_k = ...\n",
2283 | " # ===========================================\n",
2284 | " self.lin_upd = Linear(in_dim, out_dim, bias=False)\n",
2285 | "\n",
2286 | "\n",
2287 | " def forward(self, x, edge_index):\n",
2288 | " \"\"\"\n",
2289 | " The forward pass performs one round of message passing.\n",
2290 | "\n",
2291 | " By calling the `propagate()` function it automatically starts the\n",
2292 | " `message()` -> `aggregate()` -> `update()` pipeline.\n",
2293 | "\n",
2294 | " Args:\n",
2295 | " x: (n, in_dim) - initial node features\n",
2296 | " edge_index: (2, e) - list of edges as a tuple\n",
2297 | "\n",
2298 | " Returns:\n",
2299 | " out: (n, out_dim) - updated node features\n",
2300 | " \"\"\"\n",
2301 | " out = self.propagate(edge_index, x=x)\n",
2302 | " return out\n",
2303 | "\n",
2304 | "\n",
2305 | " def message(self, x_i, x_j, edge_index_i):\n",
2306 | " \"\"\" The `message()` function constructs the messages from nodes j\n",
2307 | " to nodes i for each edge (j, i) in `edge_index`.\n",
2308 | "\n",
2309 | " This function receives all the arguments passed to the `propagate` function.\n",
2310 | " For node-related features, it distinguishes between the source and\n",
2311 | " the destination node by appending `_j` or `_i` to the name.\n",
2312 | " E.g. if `x` represents the node features, message will receive both `x_j`,\n",
2313 | " the features of the source node, and `x_i`, the features of the destination.\n",
2314 | "\n",
2315 | " Args:\n",
2316 | " x_j: (e, d) - source node features, essentially x[edge_index[0]]\n",
2317 | " x_i: (e, d) - destination node features: x[edge_index[1]]\n",
2318 | " edge_index_i: (e,) - destination node of each edge; groups the softmax per neighbourhood\n",
2319 | " Returns:\n",
2320 | " out: (e, d) - messages `m_ji`\n",
2321 | " \"\"\"\n",
2322 | "\n",
2323 | " # ============ YOUR CODE HERE ==============\n",
2324 | " # Compute alpha(x_i, x_j) as a tensor of dimension (e,)\n",
2325 | " # storing the attention score for each edge in edge_index\n",
2326 | " #\n",
2327 | " # alpha = ...\n",
2328 | " # alpha = softmax(alpha, edge_index_i)\n",
2329 | " # out = ...\n",
2330 | " # ===========================================\n",
2331 | "\n",
2332 | " return out\n",
2333 | "\n",
2334 | " def update(self, aggr_out, x):\n",
2335 | " \"\"\" `update()` combines the aggregated messages with the initial node features.\n",
2336 | "\n",
2337 | " `aggr_out` represents the result of the `aggregate()` step, while the\n",
2338 | " rest of the arguments are all the arguments initially passed to\n",
2339 | " `propagate()`.\n",
2340 | "\n",
2341 | " Args:\n",
2342 | " aggr_out: (n, d) - aggregated messages `m_i`\n",
2343 | " x: (n, d) - initial node features\n",
2344 | "\n",
2345 | " Returns:\n",
2346 | " out: (n, d') - updated node features\n",
2347 | " \"\"\"\n",
2348 | " out = self.lin_upd(aggr_out)\n",
2349 | " return out\n"
2350 | ],
2351 | "outputs": [],
2352 | "metadata": {
2353 | "id": "kFa4Yne1N79A"
2354 | }
2355 | },
2356 | {
2357 | "cell_type": "code",
2358 | "execution_count": null,
2359 | "source": [
2360 | "layer_gat = GATpygLayer(5,10)\n",
2361 | "x = torch.rand(3, 5)\n",
2362 | "edge_index = torch.tensor([[0,1,2,0],[1,0,0,2]], dtype=torch.long)\n",
2363 | "out = layer_gat(x, edge_index)\n",
2364 | "print(f\"Layer output has shape: {out.shape}\")"
2365 | ],
2366 | "outputs": [],
2367 | "metadata": {
2368 | "id": "v-wFAB8YRlQk"
2369 | }
2370 | },
2371 | {
2372 | "cell_type": "markdown",
2373 | "source": [
2374 | "That's awesome!! You implemented your first model in PyG!! Let's use this layer to create a full GAT model for graph-level prediction.\n",
2375 | "\n",
2376 | "You might have noticed that the code is very similar to what you implemented in the PyTorch section. However, thanks to the *magic* behind PyG, you don't need to `index_select` and `scatter_` all the time. 🪄 "
2377 | ],
2378 | "metadata": {
2379 | "id": "DHLWTLKvJ_HA"
2380 | }
2381 | },
2382 | {
2383 | "cell_type": "code",
2384 | "execution_count": null,
2385 | "source": [
2386 | "class GATpygModel(Module):\n",
2387 | " def __init__(self, hidden_dim, num_layers=2):\n",
2388 | " \"\"\"GAT Neural Network model for graph-level regression\n",
2389 | "\n",
2390 | " Args:\n",
2391 | " hidden_dim: (int) - hidden dimension\n",
2392 | " num_layers: (int) - number of GAT layers used in the model\n",
2393 | " \"\"\"\n",
2394 | " super(GATpygModel, self).__init__()\n",
2395 | " self.num_layers = num_layers\n",
2396 | "\n",
2397 | " # the ZINC node features are integers representing the atom type\n",
2398 | " # we convert them into vector representations using a learnable embedding\n",
2399 | " self.embed_x = Embedding(28, hidden_dim)\n",
2400 | "\n",
2401 | " self.layers = [GATpygLayer(hidden_dim, hidden_dim) for _ in range(num_layers-1)]\n",
2402 | " self.layers += [GATpygLayer(hidden_dim, 1)]\n",
2403 | " self.layers = ModuleList(self.layers)\n",
2404 | "\n",
2405 | " def forward(self, graph):\n",
2406 | " \"\"\"\n",
2407 | " Args:\n",
2408 | " graph: (PyG.Data) - batch of PyG graphs\n",
2409 | "\n",
2410 | " Returns:\n",
2411 | " out: (batch_size,) - scalar prediction for each graph\n",
2412 | " \"\"\"\n",
2413 | " x = self.embed_x(graph.x).squeeze(1)\n",
2414 | "\n",
2415 | " for i in range(self.num_layers-1):\n",
2416 | " x = self.layers[i](x, graph.edge_index)\n",
2417 | " x = F.relu(x)\n",
2418 | " x = self.layers[-1](x, graph.edge_index)\n",
2419 | " out = global_add_pool(x, graph.batch).squeeze(-1)\n",
2420 | "\n",
2421 | " return out"
2422 | ],
2423 | "outputs": [],
2424 | "metadata": {
2425 | "id": "ZgKoEFGNmifv"
2426 | }
2427 | },
2428 | {
2429 | "cell_type": "code",
2430 | "execution_count": null,
2431 | "source": [
2432 | "#@title [RUN] Helper functions for train-eval models with PyG\n",
2433 | "# these are very similar to the ones used before, but now we are\n",
2434 | "# taking advantage of the PyG tools for loading and batching data\n",
2435 | "\n",
2436 | "def train_pyg(loader, model, optimiser, epoch, loss_fct):\n",
2437 | " \"\"\" Train model for one epoch\n",
2438 | " \"\"\"\n",
2439 | " model.train()\n",
2440 | " for batch in loader:\n",
2441 | " batch = batch.to(DEVICE)\n",
2442 | " optimiser.zero_grad()\n",
2443 | "\n",
2444 | " y_hat = model(batch)\n",
2445 | " loss = loss_fct(y_hat, batch.y)\n",
2446 | "\n",
2447 | " loss.backward()\n",
2448 | " optimiser.step()\n",
2449 | " return loss.item()\n",
2450 | "\n",
2451 | "def evaluate_pyg(loader, model, loss_fct):\n",
2452 | " \"\"\" Evaluate model on dataset\n",
2453 | " \"\"\"\n",
2454 | " model.eval()\n",
2455 | " loss_eval = 0\n",
2456 | " for batch in loader:\n",
2457 | " batch = batch.to(DEVICE)\n",
2458 | " with torch.no_grad():\n",
2459 | " y_hat = model(batch)\n",
2460 | " loss = loss_fct(y_hat, batch.y)\n",
2461 | "\n",
2462 | " loss_eval += loss.item() * batch.num_graphs\n",
2463 | " loss_eval /= len(loader.dataset)\n",
2464 | " return loss_eval\n",
2465 | "\n",
2466 | "def run_exp_pyg(model, train_loader, val_loader, test_loader, loss_fct,\n",
2467 | " lr=0.001, num_epochs=100):\n",
2468 | " \"\"\" Train the model for NUM_EPOCHS epochs\n",
2469 | " \"\"\"\n",
2470 | " print(\"\\nModel architecture:\")\n",
2471 | " print(model)\n",
2472 | "\n",
2473 | " model = model.to(DEVICE)\n",
2474 | "\n",
2475 | " # Instantiate our optimiser\n",
2476 | " optimiser = torch.optim.Adam(model.parameters(), lr=lr)\n",
2477 | " training_stats = None\n",
2478 | "\n",
2479 | " #initial evaluation (before training)\n",
2480 | " val_loss = evaluate_pyg(val_loader, model, loss_fct)\n",
2481 | " train_loss = evaluate_pyg(train_loader, model, loss_fct)\n",
2482 | " epoch_stats = {'train_loss': train_loss, 'val_loss': val_loss, 'epoch':0}\n",
2483 | " training_stats = update_stats(training_stats, epoch_stats)\n",
2484 | "\n",
2485 | " print(\"\\nStart training:\")\n",
2486 | " for epoch in range(num_epochs):\n",
2487 | " train_loss = train_pyg(train_loader, model, optimiser, epoch,\n",
2488 | " loss_fct)\n",
2489 | " val_loss = evaluate_pyg(val_loader, model, loss_fct)\n",
2490 | " print(f\"[Epoch {epoch+1}]\",\n",
2491 | " f\"train loss: {train_loss:.3f} val loss: {val_loss:.3f}\",\n",
2492 | " )\n",
2493 | " # store the loss and the computed metric for the final plot\n",
2494 | " epoch_stats = {'train_loss': train_loss, 'val_loss': val_loss,\n",
2495 | " 'epoch':epoch+1}\n",
2496 | " training_stats = update_stats(training_stats, epoch_stats)\n",
2497 | "\n",
2498 | " test_loss = evaluate_pyg(test_loader, model, loss_fct)\n",
2499 | " print(f\"Done! Test loss: {test_loss:.3f}\")\n",
2500 | " return training_stats"
2501 | ],
2502 | "outputs": [],
2503 | "metadata": {
2504 | "id": "nPi8mIXmGv5H",
2505 | "cellView": "form"
2506 | }
2507 | },
2508 | {
2509 | "cell_type": "markdown",
2510 | "source": [
2511 | "We are now ready to train our model on the molecular prediction problem."
2512 | ],
2513 | "metadata": {
2514 | "id": "7ERtqTufKrtk"
2515 | }
2516 | },
2517 | {
2518 | "cell_type": "code",
2519 | "execution_count": null,
2520 | "source": [
2521 | "batch_size = 64\n",
2522 | "\n",
2523 | "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
2524 | "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
2525 | "test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)"
2526 | ],
2527 | "outputs": [],
2528 | "metadata": {
2529 | "id": "9VntJG8UIKPu"
2530 | }
2531 | },
2532 | {
2533 | "cell_type": "code",
2534 | "execution_count": null,
2535 | "source": [
2536 | "# @title [RUN] Hyperparameters GAT\n",
2537 | "\n",
2538 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
2539 | "LR = 0.005 #@param {type:\"number\"}\n",
2540 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
2541 | "NUM_LAYERS = 4 #@param {type:\"integer\"}\n",
2542 | "\n",
2543 | "\n",
2544 | "#you can add more here if you need"
2545 | ],
2546 | "outputs": [],
2547 | "metadata": {
2548 | "id": "71ckTAfAL0X8"
2549 | }
2550 | },
2551 | {
2552 | "cell_type": "code",
2553 | "execution_count": null,
2554 | "source": [
2555 | "model_gat_pyg = GATpygModel(HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
2556 | "stats = run_exp_pyg(model_gat_pyg, train_loader, val_loader, test_loader, loss_fct=F.mse_loss,\n",
2557 | " lr=LR, num_epochs=NUM_EPOCHS)"
2558 | ],
2559 | "outputs": [],
2560 | "metadata": {
2561 | "id": "QKheqXR0Ii0z"
2562 | }
2563 | },
2564 | {
2565 | "cell_type": "code",
2566 | "execution_count": null,
2567 | "source": [
2568 | "plot_stats(stats)"
2569 | ],
2570 | "outputs": [],
2571 | "metadata": {
2572 | "id": "eYcR0m5Wb11r"
2573 | }
2574 | },
2575 | {
2576 | "cell_type": "markdown",
2577 | "source": [
2578 | "❗ Note that, since this is exactly the same architecture as the one you implemented in PyTorch, we don't expect an improvement in performance. This exercise is only meant to give you hands-on experience with PyG."
2579 | ],
2580 | "metadata": {
2581 | "id": "O_25cGkFLM72"
2582 | }
2583 | },
2584 | {
2585 | "cell_type": "markdown",
2586 | "source": [
2587 | "## From Graph Attention Network to Graph Transformer\n",
2588 | "\n",
2589 | "\n",
2590 | "\n",
2591 | "\n",
2592 | "As you've already seen several times during the summer school, one of the most popular neural network architectures these days, across several domains, is the [Transformer](https://arxiv.org/abs/1706.03762). The key element of the Transformer layer is a particular type of attention known as key-query similarity.\n",
2593 | "\n",
2594 | "To get a step closer to Transformers, we will modify the previous Graph Attention Layer by replacing the original function used to compute the attention coefficient $\\alpha(x_i, x_j)$ with the key-query similarity used in Transformers. This is the idea behind the first [**Graph Transformer** architecture](https://arxiv.org/abs/2312.11109) and the core observation behind the popular idea that [*Transformers are Graph Neural Networks*](https://thegradient.pub/transformers-are-graph-neural-networks/).\n",
2595 | "\n",
2596 | "Concretely, we will implement the following message function:\n",
2597 | "\n",
2598 | "\\begin{align}\n",
2599 | "f_{msg}(x_i, x_j)=\\sigma(q(x_i)^Tk(x_j))v(x_j)\n",
2600 | "\\end{align}\n",
2601 | "\n",
2602 | "where $k,q: \\mathbb{R}^d \\rightarrow \\mathbb{R}^f$ and $v: \\mathbb{R}^d \\rightarrow \\mathbb{R}^o$ are 3 different linear layers and $\\sigma$ is the softmax non-linearity.\n",
2603 | "\n",
2604 | "❗️ *Note that, since we are already processing $v(x_j)$ in the message function, we can omit the update function entirely.*\n"
2605 | ],
2606 | "metadata": {
2607 | "id": "VMOleceFNqEZ"
2608 | }
2609 | },
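{
"cell_type": "markdown",
"source": [
"The key-query message above is Transformer self-attention restricted to the edges of the graph. The cell below illustrates the *Transformers are Graph Neural Networks* view with a dense sketch. It is not the sparse PyG layer you are asked to implement. An $n \\times n$ matrix of scores $q(x_i)^T k(x_j)$ is masked so that node $i$ only attends to nodes $j$ with an edge $j \\rightarrow i$. With an all-True mask (a complete graph), this reduces to standard single-head self-attention, without the $1/\\sqrt{f}$ scaling of the original Transformer. `dense_masked_attention` is a hypothetical helper. The sketch assumes every node has at least one incoming edge; otherwise its row would be all $-\\infty$ and the softmax would return NaN."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"def dense_masked_attention(x, edge_index, lin_q, lin_k, lin_v):\n",
"    # Transformer-style attention restricted to the graph: node i may only\n",
"    # attend to the nodes j that have an edge j -> i\n",
"    n = x.size(0)\n",
"    scores = lin_q(x) @ lin_k(x).T  # scores[i, j] = q(x_i)^T k(x_j)\n",
"    mask = torch.zeros(n, n, dtype=torch.bool)\n",
"    mask[edge_index[1], edge_index[0]] = True\n",
"    scores = scores.masked_fill(~mask, float('-inf'))\n",
"    attn = torch.softmax(scores, dim=-1)  # each row sums to 1 over the neighbourhood\n",
"    return attn @ lin_v(x), attn\n",
"\n",
"torch.manual_seed(0)\n",
"toy_x = torch.rand(3, 5)\n",
"toy_edge_index = torch.tensor([[0, 1, 2, 0],   # sources j\n",
"                               [1, 0, 0, 2]])  # destinations i\n",
"lin_q, lin_k = torch.nn.Linear(5, 16), torch.nn.Linear(5, 16)  # q, k: R^5 -> R^16\n",
"lin_v = torch.nn.Linear(5, 10)  # v: R^5 -> R^10\n",
"with torch.no_grad():\n",
"    toy_out, toy_attn = dense_masked_attention(toy_x, toy_edge_index, lin_q, lin_k, lin_v)\n",
"print(toy_out.shape)  # one 10-dimensional output per node\n",
"print(toy_attn)  # zero wherever there is no edge j -> i"
],
"outputs": [],
"metadata": {}
},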
2610 | {
2611 | "cell_type": "markdown",
2612 | "source": [
2613 | "## 🖋 **Task** Implement the modified Graph Attention Layer, using dot-product attention."
2614 | ],
2615 | "metadata": {
2616 | "id": "swraStnlDk7Q"
2617 | }
2618 | },
2619 | {
2620 | "cell_type": "markdown",
2621 | "source": [
2622 | "> For our implementation:\n",
2623 | ">\n",
2624 | "> \\begin{align}\n",
2625 | "f_{msg}(x_i, x_j)&=\\sigma(q(x_i)^Tk(x_j))v(x_j)\\\\\n",
2626 | "f_{upd}(x,m)&=m \\hspace{40mm} \\small\\text{ # this is the default update() behaviour}\n",
2627 | "\\end{align}\n",
2628 | ">\n",
2629 | "> This is just one option to align the model with the message passing framework. Several equivalent alignments exist, and you are free to choose the one you prefer.\n",
2630 | ">\n",
2631 | "> ❗ Note that, in our code template, we apply the projections $q$, $k$ and $v$ inside the message function. Applying them once per node in `forward` would be more efficient, but we preferred this approach for educational purposes."
2632 | ],
2633 | "metadata": {
2634 | "id": "P12JLWRyDyhB"
2635 | }
2636 | },
2637 | {
2638 | "cell_type": "code",
2639 | "execution_count": null,
2640 | "source": [
2641 | "class GTpygLayer(MessagePassing):\n",
2642 | " def __init__(self, in_dim, out_dim, hid_dim):\n",
2643 | " \"\"\"Sparse Graph Transformer Layer implemented using Pytorch Geometric\n",
2644 | "\n",
2645 | " Args:\n",
2646 | " in_dim: (int) - input dimension for node features\n",
2647 | " out_dim: (int) - output dimension\n",
2648 | " hid_dim: (int) - hidden dimension\n",
2650 | " \"\"\"\n",
2651 | " super().__init__(node_dim=0, aggr='add')\n",
2652 | " self.in_dim = in_dim\n",
2653 | " self.out_dim = out_dim\n",
2654 | "\n",
2655 | " # ============ YOUR CODE HERE ==============\n",
2656 | " # Add new layers to compute the three linear projections\n",
2657 | " # x_i -> q(x_i); x_j -> k(x_j); x_j -> v(x_j)\n",
2658 | " #\n",
2659 | " # self.lin_q = ...\n",
2660 | " # self.lin_k = ...\n",
2661 | " # self.lin_v = ...\n",
2662 | " # ===========================================\n",
2663 | "\n",
2664 | " def forward(self, x, edge_index):\n",
2665 | " \"\"\"\n",
2666 | " Args:\n",
2667 | " x: (n, in_dim) - initial node features\n",
2668 | " edge_index: (2, e) - list of edges as a tuple\n",
2669 | "\n",
2670 | " Returns:\n",
2671 | " out: (n, out_dim) - updated node features\n",
2672 | " \"\"\"\n",
2673 | " out = self.propagate(edge_index, x=x)\n",
2674 | " return out\n",
2675 | "\n",
2676 | "\n",
2677 | " def message(self, x_i, x_j, edge_index_i):\n",
2678 | " \"\"\"\n",
2679 | " Args:\n",
2680 | " x_i: (e, in_dim) - features corresponding to destination nodes\n",
2681 | " x_j: (e, in_dim) - features corresponding to source nodes\n",
2682 | " edge_index_i: (e,) - destination node of each edge; groups the softmax per neighbourhood\n",
2683 | "\n",
2684 | " Returns:\n",
2685 | " out: (e, out_dim) - messages `m_ji`\n",
2686 | " \"\"\"\n",
2687 | "\n",
2688 | " # ============ YOUR CODE HERE ==============\n",
2689 | " # Compute alpha(x_i, x_j) as a tensor of dimension (e,)\n",
2690 | " # storing the attention score for each edge in edge_index\n",
2691 | " #\n",
2692 | " # alpha = ...\n",
2693 | " # alpha = softmax(alpha, edge_index_i)\n",
2694 | " # out = ...\n",
2695 | " # ===========================================\n",
2696 | "\n",
2697 | " return out"
2698 | ],
2699 | "outputs": [],
2700 | "metadata": {
2701 | "id": "LW-te-URY3ti"
2702 | }
2703 | },
2704 | {
2705 | "cell_type": "code",
2706 | "execution_count": null,
2707 | "source": [
2708 | "model = GTpygLayer(5,10, 16)\n",
2709 | "x = torch.rand(3, 5)\n",
2710 | "edge_index = torch.tensor([[0,1,2,0],[1,0,0,2]], dtype=torch.long)\n",
2711 | "out = model(x, edge_index)\n",
2712 | "print(out.shape)"
2713 | ],
2714 | "outputs": [],
2715 | "metadata": {
2716 | "id": "KsozvluxZ9M-"
2717 | }
2718 | },
2719 | {
2720 | "cell_type": "code",
2721 | "execution_count": null,
2722 | "source": [
2723 | "class GTpygModel(Module):\n",
2724 | " def __init__(self, hidden_dim, num_layers=2):\n",
2725 | " \"\"\"\n",
2726 | " Sparse Graph Transformer Neural Network model for graph-level regression\n",
2727 | "\n",
2728 | " Args:\n",
2729 | " hidden_dim: (int) - hidden dimension\n",
2730 | " num_layers: (int) - number of GT layers used in the model\n",
2731 | " \"\"\"\n",
2732 | " super(GTpygModel, self).__init__()\n",
2733 | " self.num_layers = num_layers # please select num_layers>=2\n",
2734 | "\n",
2735 | " # the ZINC node features are integers representing the atom type\n",
2736 | " # we convert them into vector representations using a learnable embedding\n",
2737 | " self.embed_x = Embedding(28,hidden_dim)\n",
2738 | "\n",
2739 | " self.layers = [GTpygLayer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers-1)]\n",
2740 | " self.layers += [GTpygLayer(hidden_dim, 1, hidden_dim)]\n",
2741 | " self.layers = ModuleList(self.layers)\n",
2742 | "\n",
2743 | " def forward(self, graph):\n",
2744 | " \"\"\"\n",
2745 | " Args:\n",
2746 | " graph: (PyG.Data) - batch of PyG graphs\n",
2747 | " Returns:\n",
2748 | " out: (batch_size,) - scalar prediction for each graph in the batch\n",
2749 | " \"\"\"\n",
2750 | " x = self.embed_x(graph.x).squeeze(1)\n",
2751 | "\n",
2752 | " for i in range(self.num_layers-1):\n",
2753 | " x = self.layers[i](x, graph.edge_index)\n",
2754 | " x = F.relu(x)\n",
2755 | " x = self.layers[-1](x, graph.edge_index)\n",
2756 | " out = global_add_pool(x, graph.batch)\n",
2757 | "\n",
2758 | " out = out.squeeze(-1)\n",
2759 | " return out"
2760 | ],
2761 | "outputs": [],
2762 | "metadata": {
2763 | "id": "D89d9oY4-37y"
2764 | }
2765 | },
2766 | {
2767 | "cell_type": "markdown",
2768 | "source": [
2769 | "Let's train and evaluate this new model on the molecular prediction task we used before."
2770 | ],
2771 | "metadata": {
2772 | "id": "WphZUhBxW8ns"
2773 | }
2774 | },
2775 | {
2776 | "cell_type": "code",
2777 | "execution_count": null,
2778 | "source": [
2779 | "# @title [RUN] Hyperparameters GT\n",
2780 | "\n",
2781 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
2782 | "LR = 0.001 #@param {type:\"number\"}\n",
2783 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
2784 | "NUM_LAYERS = 4 #@param {type:\"integer\"}\n",
2785 | "\n",
2786 | "\n",
2787 | "#you can add more here if you need"
2788 | ],
2789 | "outputs": [],
2790 | "metadata": {
2791 | "id": "Sz0glF9-CRUN"
2792 | }
2793 | },
2794 | {
2795 | "cell_type": "code",
2796 | "execution_count": null,
2797 | "source": [
2798 | "model_gt_pyg = GTpygModel(HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
2799 | "stats = run_exp_pyg(model_gt_pyg, train_loader, val_loader, test_loader, loss_fct=F.mse_loss,\n",
2800 | " lr=LR, num_epochs=NUM_EPOCHS)"
2801 | ],
2802 | "outputs": [],
2803 | "metadata": {
2804 | "id": "EWFFGo5r-l0b"
2805 | }
2806 | },
2807 | {
2808 | "cell_type": "code",
2809 | "execution_count": null,
2810 | "source": [
2811 | "plot_stats(stats)"
2812 | ],
2813 | "outputs": [],
2814 | "metadata": {
2815 | "id": "4KgEm4EL_uDe"
2816 | }
2817 | },
2818 | {
2819 | "cell_type": "markdown",
2820 | "source": [
2821 | "If everything goes well, the results should be on par with the previous GAT model. Multiple attention heads can be added to improve performance, but this is beyond the scope of this tutorial."
2822 | ],
2823 | "metadata": {
2824 | "id": "ALDPIgLCXj_k"
2825 | }
2826 | },
2827 | {
2828 | "cell_type": "markdown",
2829 | "source": [
2830 | "# ⏳ [Advanced] **Over-squashing: a bottleneck problem in Graph Networks** "
2831 | ],
2832 | "metadata": {
2833 | "id": "6gCvStRGNs0O"
2834 | }
2835 | },
2836 | {
2837 | "cell_type": "markdown",
2838 | "source": [
2839 | "Graph Neural Networks are powerful models that, [under certain conditions](https://arxiv.org/abs/1810.00826), can distinguish a large class of non-isomorphic graphs. However, there are still scenarios where standard graph neural networks fall short.\n",
2840 | "\n",
2841 | "To get an intuition about the challenges faced by Graph Neural Networks, let's look at the following ❚█══█❚ graph."
2842 | ],
2843 | "metadata": {
2844 | "id": "OOpElCqyX9vN"
2845 | }
2846 | },
2847 | {
2848 | "cell_type": "code",
2849 | "execution_count": null,
2850 | "source": [
2851 | "graph = generate_barbell_graph(m1=6, m2=7, target_label=[0,0,1])\n",
2852 | "gallery([graph], layout='custom')"
2853 | ],
2854 | "outputs": [],
2855 | "metadata": {
2856 | "id": "tYRFm_AbnfSj"
2857 | }
2858 | },
2859 | {
2860 | "cell_type": "markdown",
2861 | "source": [
2862 | "Suppose our task requires sending a message from a node in the left clique (let's call this clique A) to a node in the right clique (let's call it B). This message needs to traverse the thin *bridge* connecting A and B. However, besides the information we are interested in, a large number of other messages traverse that bridge (all the possible messages from a node in A to a node in B, a total of $|A| \\times |B|$). This means that the messages exchanged by the nodes forming the bridge $\\{0, 9, 12, 13 \\dots 18\\}$ need to preserve all the information from this growing number of messages.\n",
2863 | "\n",
2864 | "This problem is called **[over-squashing](https://arxiv.org/pdf/2006.05205)**. In simpler words, the information from an exponentially-growing set of messages is compressed into fixed-length node vectors.\n",
2865 | "\n",
2866 | "\n",
2867 | "\n",
2868 | "\n",
2869 | "Image from the paper \n",
2870 | "\n"
2871 | ],
2872 | "metadata": {
2873 | "id": "s0w2EZARZMbX"
2874 | }
2875 | },
2876 | {
2877 | "cell_type": "markdown",
2878 | "source": [
2879 | "## Long-range Node Transfer in a Tree 🌴\n",
2880 | "\n",
2881 | "To understand the over-squashing problem better, we are going to use a [synthetic dataset](https://arxiv.org/pdf/2006.05205) specifically designed for it. Let's consider a tree (a connected graph without cycles) of depth $r$. All nodes have $(0,0)$ features, except for the leaves, which are labelled (id, value), and the root node, which is labelled (id, $0$). Each value is uniquely associated with an id. The goal is to predict, from the features of the root node, the value associated with the root's id. To solve this task, all the leaves need to send their (id, value) message to the root node, and the root needs to select the one with the same id as itself to find its corresponding value. "
2882 | ],
2883 | "metadata": {
2884 | "id": "vFtiWxljvCKA"
2885 | }
2886 | },
2887 | {
2888 | "cell_type": "markdown",
2889 | "source": [
2890 | "Let's look at an example:"
2891 | ],
2892 | "metadata": {
2893 | "id": "yVKJgMTgTlUO"
2894 | }
2895 | },
2896 | {
2897 | "cell_type": "code",
2898 | "execution_count": null,
2899 | "source": [
2900 | "tree_dataset, dim0, num_classes = DictionaryLookupDataset(3).generate_data(add_self_loops=False)\n",
2901 | "draw_one_tree(tree_dataset[1])"
2902 | ],
2903 | "outputs": [],
2904 | "metadata": {
2905 | "id": "mm_RX0aVXMv7"
2906 | }
2907 | },
2908 | {
2909 | "cell_type": "markdown",
2910 | "source": [
2911 | "In the example above, the root node has id $2$. Among the leaf nodes, the id $2$ corresponds to the value 1. Thus, the root node should predict 1.\n",
2912 | "\n",
2913 | "As the tree becomes deeper, both the distance and the number of messages from the leaves to the root increase. This is a particularly challenging problem for a message-passing GNN, since it requires summarising an *increasing number of messages* into a fixed-size vector representation.\n"
2914 | ],
2915 | "metadata": {
2916 | "id": "fXhQvks8aYq6"
2917 | }
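{
"cell_type": "markdown",
"source": [
"To put numbers on this growth, the sketch below builds a perfect binary tree and counts how many leaves sit exactly $r$ hops from the root. The binary tree is an assumption made for illustration; the depth-$r$ trees of the dataset above may be generated differently, and `count_leaves_at_depth` is a hypothetical helper. A message-passing GNN needs at least $r$ layers before the root can see any leaf. By then, in a binary tree, the root's fixed-size representation has to account for $2^r$ leaves."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"from collections import deque\n",
"\n",
"def count_leaves_at_depth(depth):\n",
"    # Perfect binary tree in heap layout: node 0 is the root and the\n",
"    # parent of node c is (c - 1) // 2\n",
"    num_nodes = 2 ** (depth + 1) - 1\n",
"    adj = [[] for _ in range(num_nodes)]\n",
"    for c in range(1, num_nodes):\n",
"        adj[c].append((c - 1) // 2)\n",
"        adj[(c - 1) // 2].append(c)\n",
"    # Breadth-first search from the root gives the hop distance of every node\n",
"    hops = [-1] * num_nodes\n",
"    hops[0] = 0\n",
"    queue = deque([0])\n",
"    while queue:\n",
"        u = queue.popleft()\n",
"        for v in adj[u]:\n",
"            if hops[v] < 0:\n",
"                hops[v] = hops[u] + 1\n",
"                queue.append(v)\n",
"    return sum(1 for h in hops if h == depth)\n",
"\n",
"for r in range(2, 7):\n",
"    print(f'depth {r}: {count_leaves_at_depth(r)} leaves at {r} hops from the root')"
],
"outputs": [],
"metadata": {}
},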
2918 | },
2919 | {
2920 | "cell_type": "markdown",
2921 | "source": [
2922 | "To understand to what extent over-squashing is a problem in practice, we will train and evaluate our current models on this synthetic dataset."
2923 | ],
2924 | "metadata": {
2925 | "id": "rQYdTC8hFDqQ"
2926 | }
2927 | },
2928 | {
2929 | "cell_type": "code",
2930 | "execution_count": null,
2931 | "source": [
2932 | "tree_dataset, dim0, num_classes = DictionaryLookupDataset(4).generate_data()\n",
2933 | "\n",
2934 | "train_syn_data = tree_dataset[:1000]\n",
2935 | "val_syn_data = tree_dataset[1000:2000]\n",
2936 | "test_syn_data = tree_dataset[2000:3000]\n",
2937 | "\n",
2938 | "batch_size = 128\n",
2939 | "train_syn_loader = DataLoader(train_syn_data, batch_size=batch_size, shuffle=True)\n",
2940 | "val_syn_loader = DataLoader(val_syn_data, batch_size=batch_size, shuffle=False)\n",
2941 | "test_syn_loader = DataLoader(test_syn_data, batch_size=batch_size, shuffle=False)"
2942 | ],
2943 | "outputs": [],
2944 | "metadata": {
2945 | "id": "zd4Rzt71zFib"
2946 | }
2947 | },
2948 | {
2949 | "cell_type": "code",
2950 | "execution_count": null,
2951 | "source": [
2952 | "print(f\"Number of classes: {num_classes}\")\n",
2953 | "graph = tree_dataset[0]\n",
2954 | "print(f\"Node features for graph 0:\\n {tree_dataset[0].x.T}\")\n",
2955 | "print(f\"Target label for graph 0: {tree_dataset[0].y}\") # first graph has class 0\n",
2956 | "print(f\"Source node for graph 0: {tree_dataset[0].root_mask.int()}\") # the node denoted by 1 is the source node"
2957 | ],
2958 | "outputs": [],
2959 | "metadata": {
2960 | "id": "aWu_RYSu4Ywo"
2961 | }
2962 | },
2963 | {
2964 | "cell_type": "markdown",
2965 | "source": [
2966 | "## 🖋 **Task** Modify the graph model class below to output the prediction *only* from the source (root) node information (as denoted by the `data.root_mask` attribute), instead of using global sum pooling (`global_add_pool`) as before."
2967 | ],
2968 | "metadata": {
2969 | "id": "IgDQtyXZEwck"
2970 | }
2971 | },
2972 | {
2973 | "cell_type": "markdown",
2974 | "source": [
2975 | "\n",
2976 | "> ❗ Note that, since we are interested in analysing the capacity of GNN architectures to model this task, our primary metrics will be the train loss / train accuracy. A model incapable of sending long-distance messages will not only struggle to generalise to the validation set, but will also perform poorly on the training set."
2977 | ],
2978 | "metadata": {
2979 | "id": "M_9YFXUZcuZ7"
2980 | }
2981 | },
2982 | {
2983 | "cell_type": "code",
2984 | "execution_count": null,
2985 | "source": [
2986 | "class GTpygSynModel(Module):\n",
2987 | " def __init__(self, output_dim, hidden_dim, num_layers=2, dim_emb=0):\n",
2988 | " \"\"\"\n",
2989 | " Sparse Graph Transformer Neural Network model for root-node-level classification\n",
2990 | "\n",
2991 | " In our synthetic task the prediction needs to be done only\n",
2992 | " from the features of the root node stored in root_mask attribute\n",
2993 | "\n",
2994 | " Args:\n",
2995 | " output_dim: (int) - output dimension (number of classes)\n",
2996 | " hidden_dim: (int) - hidden dimension\n",
2997 | " num_layers: (int) - number of GT layers used in the model\n",
2998 | " dim_emb: (int) - number of potential ids/vals in the TREE\n",
2999 | " (characteristic of the dataset)\n",
3000 | " \"\"\"\n",
3001 | " super(GTpygSynModel, self).__init__()\n",
3002 | " self.num_layers = num_layers # please select num_layers>=2\n",
3003 | "\n",
3004 | " self.layer0_keys = Embedding(num_embeddings=dim_emb + 1, embedding_dim=hidden_dim)\n",
3005 | " self.layer0_values = Embedding(num_embeddings=dim_emb + 1, embedding_dim=hidden_dim)\n",
3006 | "\n",
3007 | " self.layers = [GTpygLayer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers-1)]\n",
3008 | " self.layers += [GTpygLayer(hidden_dim, output_dim, hidden_dim)]\n",
3009 | " self.layer_norms = [torch_geometric.nn.LayerNorm(hidden_dim) for _ in range(num_layers-1)]\n",
3010 | "\n",
3011 | " self.layers = ModuleList(self.layers)\n",
3012 | " self.layer_norms = ModuleList(self.layer_norms)\n",
3013 | "\n",
3014 | " def forward(self, graph):\n",
3015 | " \"\"\"\n",
3016 | " Args:\n",
3017 | " graph (PyG.Data): batch of PyG graphs\n",
3018 | " Returns:\n",
3019 | " out (batch_size, output_dim): updated representation for the root nodes\n",
3020 | " \"\"\"\n",
3021 | " # we replaced the atom embedding with an id-embedding and a value-embedding\n",
3022 | " # the final node input features will be the sum of the id and value embedding\n",
3023 | " x_key, x_val = graph.x[:, 0], graph.x[:, 1]\n",
3024 | " x_key_embed = self.layer0_keys(x_key)\n",
3025 | " x_val_embed = self.layer0_values(x_val)\n",
3026 | " new_x = x_key_embed + x_val_embed\n",
3027 | "\n",
3028 | " x = new_x\n",
3029 | "\n",
3030 | " for i in range(self.num_layers-1):\n",
3031 | " x = self.layers[i](x, graph.edge_index)\n",
3032 | " x = F.relu(x)\n",
3033 | " x = x + new_x\n",
3034 | " x = self.layer_norms[i](x)\n",
3035 | " new_x = x\n",
3036 | " x = self.layers[-1](x, graph.edge_index)\n",
3037 | "\n",
3038 | " # ============ YOUR CODE HERE ==============\n",
3039 | " # From x extract only the information corresponding to\n",
3040 | " # the root node as indicated by graph.root_mask\n",
3041 | " #\n",
3042 | " # out = ...\n",
3043 | " # ===========================================\n",
3044 | "\n",
3045 | " return out"
3046 | ],
3047 | "outputs": [],
3048 | "metadata": {
3049 | "id": "l0SYGrH46hQk"
3050 | }
3051 | },
3052 | {
3053 | "cell_type": "code",
3054 | "execution_count": null,
3055 | "source": [
3056 | "#@title [RUN] Helper functions for train-eval models.\n",
3057 | "\n",
3058 | "def train_syn(model, train_loader, optimizer, device):\n",
3059 | " model.train()\n",
3060 | " loss_all = 0\n",
3061 | " pred_all = []\n",
3062 | " gt_all = []\n",
3063 | "\n",
3064 | " for data in train_loader:\n",
3065 | " data = data.to(device)\n",
3066 | " optimizer.zero_grad()\n",
3067 | "\n",
3068 | " y_pred = model(data)\n",
3069 | " pred_all.extend(y_pred.detach().cpu())\n",
3070 | " gt_all.extend(data.y)\n",
3071 | "\n",
3072 | " loss = F.cross_entropy(y_pred, data.y)\n",
3073 | " loss.backward()\n",
3074 | " loss_all += loss.item() * data.num_graphs\n",
3075 | " optimizer.step()\n",
3076 | "\n",
3077 | " pred_all = torch.stack(pred_all).cpu().numpy()\n",
3078 | " gt_all = torch.stack(gt_all).cpu().numpy()\n",
3079 | " return loss_all / len(train_loader.dataset), pred_all, gt_all\n",
3080 | "\n",
3081 | "\n",
3082 | "def eval_syn(model, loader, device):\n",
3083 | " model.eval()\n",
3084 | " loss_all = 0\n",
3085 | "\n",
3086 | " pred_all = []\n",
3087 | " gt_all = []\n",
3088 | " for data in loader:\n",
3089 | " data = data.to(device)\n",
3090 | " with torch.no_grad():\n",
3091 | " y_pred = model(data)\n",
3092 | " pred_all.extend(y_pred.detach().cpu())\n",
3093 | " gt_all.extend(data.y.detach().cpu())\n",
3094 | "\n",
3095 | " loss = F.cross_entropy(y_pred, data.y)\n",
3096 | " loss_all += loss.item() * data.num_graphs # running sum of per-graph losses\n",
3097 | "\n",
3098 | " pred_all = torch.stack(pred_all).cpu().numpy()\n",
3099 | " gt_all = torch.stack(gt_all).cpu().numpy()\n",
3100 | " return loss_all / len(loader.dataset), pred_all, gt_all\n",
3101 | "\n",
3102 | "\n",
3103 | "def run_exp_syn(model, train_loader, val_loader, test_loader,\n",
3104 | " lr=0.001, n_epochs=100):\n",
3105 | "\n",
3106 | " print(\"\\nModel architecture:\")\n",
3107 | " print(model)\n",
3108 | " device = DEVICE\n",
3109 | " model = model.to(device)\n",
3110 | "\n",
3111 | " # Adam optimizer\n",
3112 | " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
3113 | "\n",
3114 | " # LR scheduler which decays LR when validation metric doesn't improve\n",
3115 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
3116 | " optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)\n",
3117 | "\n",
3118 | " print(\"\\nStart training:\")\n",
3119 | " best_val_acc = None\n",
3120 | " training_stats = None # Track train/val loss and train accuracy vs. epoch (for plotting)\n",
3121 | "\n",
3122 | " for epoch in range(1, n_epochs+1):\n",
3123 | " # Read the current LR (the scheduler is stepped on val_loss after validation)\n",
3124 | " lr = scheduler.optimizer.param_groups[0]['lr']\n",
3125 | "\n",
3126 | " # Train model for one epoch, return avg. training loss\n",
3127 | " train_loss, train_pred, train_gt = train_syn(model, train_loader, optimizer, device)\n",
3128 | " # print(np.argmax(train_pred, -1), train_gt)\n",
3129 | " train_acc = accuracy_score(train_gt, np.argmax(train_pred, -1))\n",
3130 | "\n",
3131 | " # Evaluate model on validation set\n",
3132 | " val_loss, val_pred, val_gt = eval_syn(model, val_loader, device)\n",
3133 | " val_acc = accuracy_score(val_gt, np.argmax(val_pred, -1))\n",
3134 | " scheduler.step(val_loss) # decay the LR when the validation loss plateaus\n",
3135 | " if best_val_acc is None or val_acc >= best_val_acc:\n",
3136 | " # Evaluate model on test set if validation metric improves\n",
3137 | " test_loss, test_pred, test_gt = eval_syn(model, test_loader, device)\n",
3138 | " test_acc = accuracy_score(test_gt, np.argmax(test_pred, -1))\n",
3139 | " best_val_loss = val_loss\n",
3140 | " best_val_acc = val_acc\n",
3141 | "\n",
3142 | "\n",
3143 | " if epoch % 10 == 0:\n",
3144 | " # Print and track stats every 10 epochs\n",
3145 | " print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Train Loss: {train_loss:.7f},\\n '\n",
3146 | " f'Val Loss: {val_loss:.7f}, Test Loss: {test_loss:.7f}\\n',\n",
3147 | " f'Train Acc: {train_acc:.7f}, Val Acc: {val_acc:.7f}, Test Acc: {test_acc:.7f}\\n')\n",
3148 | " # Test metrics come from the epoch with the best validation accuracy so far\n",
3149 | "\n",
3150 | " epoch_stats = {'train_loss': train_loss, 'val_loss': val_loss, 'train_acc':train_acc, 'epoch':epoch}\n",
3151 | " training_stats = update_stats(training_stats, epoch_stats)\n",
3152 | "\n",
3153 | " print(\"Done\")\n",
3154 | "\n",
3155 | " return training_stats"
3156 | ],
3157 | "outputs": [],
3158 | "metadata": {
3159 | "cellView": "form",
3160 | "id": "TkyIezwB7On-"
3161 | }
3162 | },
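{
"cell_type": "markdown",
"source": [
"💡 Both helpers weight each batch loss by `data.num_graphs` and divide by `len(loader.dataset)`. `F.cross_entropy` returns the mean over the graphs in a batch, so this recovers the mean loss over the whole dataset even when the last batch is smaller. The plain-Python sketch below illustrates the arithmetic. The function name and the batch losses are hypothetical.\n",
"\n",
"```python\n",
"def dataset_mean_loss(batch_losses, batch_sizes):\n",
"    # Each entry of batch_losses is a mean over its batch, so multiplying by\n",
"    # the batch size recovers the summed loss of that batch.\n",
"    total = 0.0\n",
"    for loss, size in zip(batch_losses, batch_sizes):\n",
"        total += loss * size\n",
"    # Dividing by the total number of graphs gives the dataset-level mean.\n",
"    return total / sum(batch_sizes)\n",
"\n",
"\n",
"# 1000 graphs with batch_size=128: seven full batches and one batch of 104.\n",
"sizes = [128] * 7 + [104]\n",
"losses = [0.5] * 7 + [1.5]  # hypothetical per-batch mean losses\n",
"print(dataset_mean_loss(losses, sizes))\n",
"print(sum(losses) / len(losses))  # naive average of batch means\n",
"```\n",
"\n",
"Averaging the batch means directly would give the small last batch the same weight as a full one. That over-weights it: 1/8 of the average instead of 104/1000."
],
"metadata": {}
},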
3163 | {
3164 | "cell_type": "code",
3165 | "execution_count": null,
3166 | "source": [
3167 | "# @title [RUN] Hyperparameters GT\n",
3168 | "\n",
3169 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
3170 | "LR = 0.001 #@param {type:\"number\"}\n",
3171 | "HIDDEN_DIM = 32 #@param {type:\"integer\"}\n",
3172 | "NUM_LAYERS = 5 #@param {type:\"integer\"}\n",
3173 | "\n",
3174 | "\n",
3175 | "#you can add more here if you need"
3176 | ],
3177 | "outputs": [],
3178 | "metadata": {
3179 | "id": "WCzo1sAD-Kyd"
3180 | }
3181 | },
3182 | {
3183 | "cell_type": "code",
3184 | "execution_count": null,
3185 | "source": [
3186 | "model_gt_syn = GTpygSynModel(num_classes+1, HIDDEN_DIM, num_layers=NUM_LAYERS, dim_emb=dim0)\n",
3187 | "gt_perf_per_epoch = run_exp_syn(model_gt_syn, train_syn_loader, val_syn_loader, test_syn_loader,\n",
3188 | " lr=LR, n_epochs=NUM_EPOCHS)"
3189 | ],
3190 | "outputs": [],
3191 | "metadata": {
3192 | "id": "2hJq8TxU9MTT"
3193 | }
3194 | },
3195 | {
3196 | "cell_type": "markdown",
3197 | "source": [
3198 | "The performance of our model is far from satisfactory, due to the over-squashing phenomenon we discussed above.\n",
3199 | "\n",
3200 | "Some real-world problems might not require long-distance messages, in which case a standard message-passing GNN would be enough to solve them. However, if sending messages between distant nodes is a characteristic of our task, we need better solutions."
3201 | ],
3202 | "metadata": {
3203 | "id": "fQqH1MnGgFZ3"
3204 | }
3205 | },
3206 | {
3207 | "cell_type": "markdown",
3208 | "source": [
3209 | "## Graph Transformer: message passing on fully connected graphs \n",
3210 | "\n"
3211 | ],
3212 | "metadata": {
3213 | "id": "Kr9wHG1wNy38"
3214 | }
3215 | },
3216 | {
3217 | "cell_type": "markdown",
3218 | "source": [
3219 | "Hopefully we managed to convince you that sending messages between nodes situated far away in the graph is not an easy job. But how can we fix that?\n",
3220 | "\n",
3221 | "One simple solution is to ensure that all the nodes in our graph stay reasonably close to each other. This is the key principle behind a new and exciting research direction known as **graph rewiring**: if sending messages on the original graph is too hard for our current tools, let's help them a bit by modifying the graph topology, i.e. adding/removing edges or adding/removing nodes.\n",
3222 | "\n",
3223 | "Several papers analyse the over-squashing problem by [characterising the ingredients that lead to it](https://proceedings.mlr.press/v202/di-giovanni23a.html) and [proposing methods to alleviate it](https://arxiv.org/abs/2305.08018). In this tutorial, we will focus on two simple but effective strategies. If you are curious to find out more, you can find a list of alternative approaches at the end of the tutorial."
3224 | ],
3225 | "metadata": {
3226 | "id": "ATl811kkhrt7"
3227 | }
3228 | },
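{
"cell_type": "markdown",
"source": [
"To make the idea concrete, the plain-Python sketch below measures how many hops separate the leaves of a tree from its root, before and after a simple rewiring. The complete binary tree (heap numbering) and the root-to-leaf shortcut edges are hypothetical choices for illustration. They are not the rewirings used later in this tutorial.\n",
"\n",
"```python\n",
"from collections import deque\n",
"\n",
"\n",
"def hops_from(source, num_nodes, edges):\n",
"    # Breadth-first search on an undirected graph: returns the hop distance\n",
"    # from source to every reachable node.\n",
"    neighbours = [[] for _ in range(num_nodes)]\n",
"    for u, v in edges:\n",
"        neighbours[u].append(v)\n",
"        neighbours[v].append(u)\n",
"    dist = {source: 0}\n",
"    queue = deque([source])\n",
"    while queue:\n",
"        u = queue.popleft()\n",
"        for v in neighbours[u]:\n",
"            if v not in dist:\n",
"                dist[v] = dist[u] + 1\n",
"                queue.append(v)\n",
"    return dist\n",
"\n",
"\n",
"# Hypothetical complete binary tree of depth 3: node i has parent (i - 1) // 2.\n",
"n = 15\n",
"leaves = range(7, 15)\n",
"tree_edges = [(i, (i - 1) // 2) for i in range(1, n)]\n",
"# Hypothetical rewiring: add a shortcut edge from the root to every leaf.\n",
"rewired_edges = tree_edges + [(0, leaf) for leaf in leaves]\n",
"\n",
"print(max(hops_from(0, n, tree_edges)[leaf] for leaf in leaves))\n",
"print(max(hops_from(0, n, rewired_edges)[leaf] for leaf in leaves))\n",
"```\n",
"\n",
"A message needs at least as many message-passing layers as hops to reach the root. Shrinking the hop count therefore shortens the chain of layers through which leaf information has to be squeezed."
],
"metadata": {}
},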
3229 | {
3230 | "cell_type": "markdown",
3231 | "source": [
3232 | "Arguably the easiest way to get rid of the over-squashing problem is to perform message passing on a fully connected graph, where every pair of nodes is connected. This way, a single iteration of a message-passing GNN exchanges messages between every pair of nodes in the graph. This can be seen as a form of graph rewiring that adds all the missing edges to the graph."
3233 | ],
3234 | "metadata": {
3235 | "id": "3bmYyvB6lJ2J"
3236 | }
3237 | },
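{
"cell_type": "markdown",
"source": [
"The price of this rewiring is the number of edges. The rough count below assumes `edge_index` stores each undirected edge in both directions. Under that assumption, a tree with n nodes has 2(n - 1) directed edges, while a fully connected graph has n(n - 1), or n² if self-loops are included. The graph sizes are illustrative only.\n",
"\n",
"```python\n",
"def tree_edges_count(num_nodes):\n",
"    # A tree on n nodes has n - 1 undirected edges, i.e. 2 * (n - 1) directed ones.\n",
"    return 2 * (num_nodes - 1)\n",
"\n",
"\n",
"def fully_connected_edges_count(num_nodes, self_loops=False):\n",
"    # Every ordered pair (i, j) with i != j is an edge; self-loops add n more.\n",
"    n = num_nodes\n",
"    return n * n if self_loops else n * (n - 1)\n",
"\n",
"\n",
"for n in (15, 127, 1000):\n",
"    print(n, tree_edges_count(n), fully_connected_edges_count(n))\n",
"```\n",
"\n",
"The tree grows linearly while the fully connected graph grows quadratically. This is a point to keep in mind when thinking about the limitations of this approach."
],
"metadata": {}
},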
3238 | {
3239 | "cell_type": "markdown",
3240 | "source": [
3241 | "## 🖋 **Task** Fill in the code to generate the edge_index corresponding to a fully connected graph"
3242 | ],
3243 | "metadata": {
3244 | "id": "b-JVHNrEFUTF"
3245 | }
3246 | },
3247 | {
3248 | "cell_type": "markdown",
3249 | "source": [
3250 | "> Since this transformation is independent of the GNN processing, we will perform it as a pre-processing stage, applied when we first generate the data. This is a common approach in PyTorch Geometric and is implemented by subclassing a special class called [`BaseTransform`](https://colab.research.google.com/drive/1VKXNhdtKS3piWqAE5h1OHdZlBcJkpx4m#scrollTo=x3VFuQbamUIf). Each call of the transform (its `__call__` method) receives one graph instance and returns the transformed version (in our case, the same graph with an extra attribute `rewire_edge_index`)."
3251 | ],
3252 | "metadata": {
3253 | "id": "x3VFuQbamUIf"
3254 | }
3255 | },
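{
"cell_type": "markdown",
"source": [
"The transform pattern itself does not depend on PyG. The plain-Python sketch below mimics it with two hypothetical pieces: a `ToyGraph` class standing in for `torch_geometric.data.Data`, and an `AddInDegreeTransform`. The callable receives one graph, attaches a new attribute and returns the same object. It deliberately computes node in-degrees rather than the fully connected `edge_index`, which is left to the task below.\n",
"\n",
"```python\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class ToyGraph:\n",
"    # Hypothetical stand-in for a PyG Data object; edge_index is\n",
"    # [source_nodes, target_nodes], like a 2 x num_edges tensor.\n",
"    num_nodes: int\n",
"    edge_index: list\n",
"\n",
"\n",
"class AddInDegreeTransform:\n",
"    # Hypothetical transform following the same calling convention as a\n",
"    # PyG BaseTransform: take one graph, add an attribute, return the graph.\n",
"    def __call__(self, data):\n",
"        in_degree = [0] * data.num_nodes\n",
"        for target in data.edge_index[1]:\n",
"            in_degree[target] += 1\n",
"        data.in_degree = in_degree\n",
"        return data\n",
"\n",
"\n",
"graph = ToyGraph(num_nodes=3, edge_index=[[0, 1, 1, 2], [1, 0, 2, 1]])\n",
"graph = AddInDegreeTransform()(graph)\n",
"print(graph.in_degree)\n",
"```"
],
"metadata": {}
},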
3256 | {
3257 | "cell_type": "code",
3258 | "execution_count": null,
3259 | "source": [
3260 | "class FullyConnectedTransform(BaseTransform):\n",
3261 | " def __init__(self):\n",
3262 | " super(FullyConnectedTransform, self).__init__()\n",
3263 | "\n",
3264 | " def __call__(self, data):\n",
3265 | " \"\"\"\n",
3266 | " Args:\n",
3267 | " data (PyG.Data): one instance of PyG graph\n",
3268 | " Returns:\n",
3269 | " data (PyG.Data): original graph with an additional attribute\n",
3270 | " rewire_edge_index containing the edge_index\n",
3271 | " for a fully connected graph\n",
3272 | " \"\"\"\n",
3273 | " num_nodes = data.x.shape[0]\n",
3274 | "\n",
3275 | " # ============ YOUR CODE HERE ==============\n",
3276 | " # Create a list of indices corresponding to a\n",
3277 | " # fully connected graph with num_nodes\n",
3278 | " #\n",
3279 | " # fully_edge_index = ...\n",
3280 | " # ===========================================\n",
3281 | "\n",
3282 | " # Batching works because the attribute name contains 'index': like edge_index, its node ids are offset per graph in the batch\n",
3283 | " data.rewire_edge_index = fully_edge_index\n",
3284 | " return data"
3285 | ],
3286 | "outputs": [],
3287 | "metadata": {
3288 | "id": "nNZ6uMBDnAsn"
3289 | }
3290 | },
3291 | {
3292 | "cell_type": "code",
3293 | "execution_count": null,
3294 | "source": [
3295 | "# One way to apply the transformation is to directly apply it to our graphs\n",
3296 | "fc_transform = FullyConnectedTransform()\n",
3297 | "modified_graph = fc_transform(tree_dataset[0])\n",
3298 | "\n",
3299 | "# Another way is to send it as argument when we first create the dataset\n",
3300 | "tree_dataset, _, _ = DictionaryLookupDataset(3).generate_data(transform=FullyConnectedTransform())\n",
3301 | "draw_one_graph(plt.axes(), tree_dataset[0].rewire_edge_index.numpy(), layout='tree')\n"
3302 | ],
3303 | "outputs": [],
3304 | "metadata": {
3305 | "id": "e3vSTSyrmmp_"
3306 | }
3307 | },
3308 | {
3309 | "cell_type": "markdown",
3310 | "source": [
3311 | "Things look as expected, and we definitely no longer have distant nodes to worry about."
3312 | ],
3313 | "metadata": {
3314 | "id": "vp64bCgUoFfY"
3315 | }
3316 | },
3317 | {
3318 | "cell_type": "markdown",
3319 | "source": [
3320 | "## 🖋 **Task** Modify the graph neural network model to perform the message passing propagation based on the fully connected adjacency matrix\n",
3321 | "\n"
3322 | ],
3323 | "metadata": {
3324 | "id": "I_-nIAmXj3I_"
3325 | }
3326 | },
3327 | {
3328 | "cell_type": "code",
3329 | "execution_count": null,
3330 | "source": [
3331 | "class GTpygRewireModel(Module):\n",
3332 | " def __init__(self, output_dim, hidden_dim, num_layers=2, dim_emb=0):\n",
3333 | " \"\"\"\n",
3334 | " Graph Transformer Neural Network using the rewired adjacency matrix\n",
3335 | " as indicated by graph.rewire_edge_index\n",
3336 | "\n",
3337 | " The model performs root-node-level classification.\n",
3338 | " In our synthetic task the prediction needs to be made only\n",
3339 | " from the features of the root node, identified by the root_mask attribute\n",
3340 | "\n",
3341 | " Args:\n",
3342 | " output_dim: (int) - output dimension (number of classes)\n",
3343 | " hidden_dim: (int) - hidden dimension\n",
3344 | " num_layers: (int) - number of GT layers used in the model\n",
3345 | " dim_emb: (int) - number of potential ids/vals in the TREE\n",
3346 | " (characteristic of the dataset)\n",
3347 | " \"\"\"\n",
3348 | " super(GTpygRewireModel, self).__init__()\n",
3349 | " self.num_layers = num_layers # please select num_layers>=2\n",
3350 | "\n",
3351 | " self.layer0_keys = Embedding(num_embeddings=dim_emb + 1, embedding_dim=hidden_dim)\n",
3352 | " self.layer0_values = Embedding(num_embeddings=dim_emb + 1, embedding_dim=hidden_dim)\n",
3353 | "\n",
3354 | " self.layers = [GTpygLayer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers-1)]\n",
3355 | " self.layers += [GTpygLayer(hidden_dim, output_dim, hidden_dim)]\n",
3356 | " self.layer_norms = [torch_geometric.nn.LayerNorm(hidden_dim) for _ in range(num_layers-1)]\n",
3357 | "\n",
3358 | " self.layers = ModuleList(self.layers)\n",
3359 | " self.layer_norms = ModuleList(self.layer_norms)\n",
3360 | "\n",
3361 | " def forward(self, graph):\n",
3362 | " \"\"\"\n",
3363 | " Args:\n",
3364 | " graph (PyG.Data): batch of PyG graphs\n",
3365 | " Returns:\n",
3366 | " out (batch_size, output_dim): updated representation for the root nodes\n",
3367 | " \"\"\"\n",
3368 | "\n",
3369 | " # we changed the atom embedding with an id-embedding and a value-embedding\n",
3370 | " # the final node input features will be the sum of the id and value embedding\n",
3371 | " x_key, x_val = graph.x[:, 0], graph.x[:, 1]\n",
3372 | " x_key_embed = self.layer0_keys(x_key)\n",
3373 | " x_val_embed = self.layer0_values(x_val)\n",
3374 | " new_x = x_key_embed + x_val_embed\n",
3375 | " x = new_x\n",
3376 | "\n",
3377 | " # ============ YOUR CODE HERE ==============\n",
3378 | " # This is the ORIGINAL code. Your task is to MODIFY\n",
3379 | " # it such that instead of the original graph topology, it uses\n",
3380 | " # the rewired version as denoted by `graph.rewire_edge_index`\n",
3381 | " #\n",
3382 | " # for i in range(self.num_layers-1):\n",
3383 | " # x = self.layers[i](x, graph.edge_index)\n",
3384 | " # x = F.relu(x)\n",
3385 | " # x = x + new_x\n",
3386 | " # x = self.layer_norms[i](x)\n",
3387 | " # new_x = x\n",
3388 | " # x = self.layers[-1](x, graph.edge_index)\n",
3389 | " # ===========================================\n",
3390 | "\n",
3391 | " out = x[graph.root_mask]\n",
3392 | " return out"
3393 | ],
3394 | "outputs": [],
3395 | "metadata": {
3396 | "id": "wp01bNm7J1rV"
3397 | }
3398 | },
3399 | {
3400 | "cell_type": "markdown",
3401 | "source": [
3402 | "It's time to train our rewired model on the synthetic dataset. We will reuse the same training and evaluation utilities, but with the preprocessed dataset and the new model."
3403 | ],
3404 | "metadata": {
3405 | "id": "HxJ8QzS-kUut"
3406 | }
3407 | },
3408 | {
3409 | "cell_type": "code",
3410 | "execution_count": null,
3411 | "source": [
3412 | "# @title [RUN] Hyperparameters GT\n",
3413 | "\n",
3414 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
3415 | "LR = 0.001 #@param {type:\"number\"}\n",
3416 | "HIDDEN_DIM = 32 #@param {type:\"integer\"}\n",
3417 | "NUM_LAYERS = 2 #@param {type:\"integer\"}\n",
3418 | "\n",
3419 | "\n",
3420 | "#you can add more here if you need"
3421 | ],
3422 | "outputs": [],
3423 | "metadata": {
3424 | "id": "307ze-xBiMmT"
3425 | }
3426 | },
3427 | {
3428 | "cell_type": "code",
3429 | "execution_count": null,
3430 | "source": [
3431 | "tree_dataset, dim0, num_classes = DictionaryLookupDataset(4).generate_data(transform=FullyConnectedTransform())\n",
3432 | "\n",
3433 | "train_syn_data = tree_dataset[:1000]\n",
3434 | "val_syn_data = tree_dataset[1000:2000]\n",
3435 | "test_syn_data = tree_dataset[2000:3000]\n",
3436 | "\n",
3437 | "batch_size = 128\n",
3438 | "train_syn_loader = DataLoader(train_syn_data, batch_size=batch_size, shuffle=True)\n",
3439 | "val_syn_loader = DataLoader(val_syn_data, batch_size=batch_size, shuffle=False)\n",
3440 | "test_syn_loader = DataLoader(test_syn_data, batch_size=batch_size, shuffle=False)"
3441 | ],
3442 | "outputs": [],
3443 | "metadata": {
3444 | "id": "ysOyGpnowSVU"
3445 | }
3446 | },
3447 | {
3448 | "cell_type": "code",
3449 | "execution_count": null,
3450 | "source": [
3451 | "seed(0)\n",
3452 | "model_gt_syn_full = GTpygRewireModel(num_classes+1, HIDDEN_DIM, num_layers=NUM_LAYERS, dim_emb=dim0)\n",
3453 | "gt_perf_per_epoch = run_exp_syn(model_gt_syn_full, train_syn_loader, val_syn_loader, test_syn_loader,\n",
3454 | " lr=LR, n_epochs=NUM_EPOCHS)"
3455 | ],
3456 | "outputs": [],
3457 | "metadata": {
3458 | "id": "OUkmMKrjRAMB"
3459 | }
3460 | },
3461 | {
3462 | "cell_type": "markdown",
3463 | "source": [
3464 | "While the GNN running on the original graph topology could not learn the task, the same architecture running on the fully connected graph is able to learn it. This suggests that this rewiring approach is effective for our synthetic problem.\n",
3465 | "\n",
3466 | "🤔 However, this technique brings some issues. Before moving forward, take some time to think about potential problems."
3467 | ],
3468 | "metadata": {
3469 | "id": "6NKXaR0yjk8l"
3470 | }
3471 | },
3472 | {
3473 | "cell_type": "markdown",
3474 | "source": [
3475 | "## Expander Graphs: aiming for a sparser rewiring\n",
3476 | "\n",
3477 | "Rewiring the graph into a fully connected one helps us avoid the over-squashing issue. However, fully connected rewiring has a strong limitation: connecting every pair of nodes produces a number of edges that grows quadratically with the number of nodes. For large graphs, this is memory-inefficient and can also cause numerical and optimisation issues.\n",
3478 | "\n",
3479 | "An elegant solution is to [rewire the graph based on **expander graphs**](https://arxiv.org/abs/2210.02997). Expander graphs are a class of graphs that are both sparse and highly connected. The approach remains the same: instead of the original topology or the fully connected graph, we perform message passing on the expander graph.\n",
3480 | "\n",
3481 | "We provide the code to generate the expander graph below.\n"
3482 | ],
3483 | "metadata": {
3484 | "id": "etPjeuQiDe2S"
3485 | }
3486 | },
3487 | {
3488 | "cell_type": "code",
3489 | "execution_count": null,
3490 | "source": [
3491 | "# @title [RUN] Generate a set of graph expanders\n",
3492 | "# code based on: https://github.com/kpetrovicc/TGR/blob/main/modules/cayley_construction.py\n",
3493 | "\"\"\"# Cayley Graph Generation\"\"\"\n",
3494 | "\n",
3495 | "import math\n",
3496 | "\n",
3497 | "from collections import deque\n",
3498 | "import re\n",
3499 | "import numpy as np\n",
3500 | "import torch\n",
3501 | "from torch_geometric.utils import subgraph\n",
3502 | "import networkx as nx\n",
3503 | "\n",
3504 | "_CAYLEY_BOUNDS = [\n",
3505 | " (6, 2),\n",
3506 | " (24, 3),\n",
3507 | " (120, 5),\n",
3508 | " (336, 7),\n",
3509 | " (1320, 11),\n",
3510 | "]\n",
3511 | "\n",
3512 | "def build_cayley_bank():\n",
3513 | "\n",
3514 | " ret_edges = []\n",
3515 | "\n",
3516 | " for _, p in _CAYLEY_BOUNDS:\n",
3517 | " generators = np.array([\n",
3518 | " [[1, 1], [0, 1]],\n",
3519 | " [[1, p-1], [0, 1]],\n",
3520 | " [[1, 0], [1, 1]],\n",
3521 | " [[1, 0], [p-1, 1]]])\n",
3522 | " ind = 1\n",
3523 | "\n",
3524 | " queue = deque([np.array([[1, 0], [0, 1]])])\n",
3525 | " nodes = {(1, 0, 0, 1): 0}\n",
3526 | "\n",
3527 | " senders = []\n",
3528 | " receivers = []\n",
3529 | "\n",
3530 | " while queue:\n",
3531 | " x = queue.pop()\n",
3532 | " x_flat = (x[0][0], x[0][1], x[1][0], x[1][1])\n",
3533 | " assert x_flat in nodes\n",
3534 | " ind_x = nodes[x_flat]\n",
3535 | " for i in range(4):\n",
3536 | " tx = np.matmul(x, generators[i])\n",
3537 | " tx = np.mod(tx, p)\n",
3538 | " tx_flat = (tx[0][0], tx[0][1], tx[1][0], tx[1][1])\n",
3539 | " if tx_flat not in nodes:\n",
3540 | " nodes[tx_flat] = ind\n",
3541 | " ind += 1\n",
3542 | " queue.append(tx)\n",
3543 | " ind_tx = nodes[tx_flat]\n",
3544 | "\n",
3545 | " senders.append(ind_x)\n",
3546 | " receivers.append(ind_tx)\n",
3547 | "\n",
3548 | " ret_edges.append((p, [senders, receivers]))\n",
3549 | "\n",
3550 | " return ret_edges\n",
3551 | "\n",
3552 | "def batched_augment_cayley(num_nodes, cayley_bank):\n",
3553 | "\n",
3554 | " # Find the appropriate cayley graph\n",
3555 | " p = 2\n",
3556 | " chosen_i = -1\n",
3557 | "\n",
3558 | " senders=[]\n",
3559 | " receivers=[]\n",
3560 | "\n",
3561 | " for i in range(len(_CAYLEY_BOUNDS)):\n",
3562 | " sz, p = _CAYLEY_BOUNDS[i]\n",
3563 | " if sz >= num_nodes:\n",
3564 | " chosen_i = i\n",
3565 | " break\n",
3566 | " assert chosen_i >= 0\n",
3567 | "\n",
3568 | " _p, edge_pack = cayley_bank[chosen_i]\n",
3569 | " assert p == _p\n",
3570 | "\n",
3571 | " for v, w in zip(*edge_pack):\n",
3572 | " if v < num_nodes and w < num_nodes:\n",
3573 | " senders.append(v)\n",
3574 | " receivers.append(w)\n",
3575 | "\n",
3576 | " # Create dummy all-zero edge attributes (272 features per edge; ExpanderTransform below does not use them)\n",
3577 | " edge_attr = [[0]*272 for _ in range(len(senders))]\n",
3578 | " edge_index = [senders, receivers]\n",
3579 | " return edge_index, edge_attr"
3580 | ],
3581 | "outputs": [],
3582 | "metadata": {
3583 | "id": "GJiLHs8Excog"
3584 | }
3585 | },
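{
"cell_type": "markdown",
"source": [
"The graphs built above are Cayley graphs of SL(2, Z_p), the group of 2×2 matrices with determinant 1 modulo p. This group has p(p² - 1) elements, which matches the sizes 6, 24, 120, 336 and 1320 in `_CAYLEY_BOUNDS`. The plain-Python sketch below is a hypothetical re-implementation, not the code above. It rebuilds the graphs and checks two things: the node count, and that every node has exactly 4 outgoing edges (one per generator). It explores the group with a first-in-first-out queue, whereas `build_cayley_bank` pops from the right end of its deque. Node numbering, and hence which nodes survive the truncation in `batched_augment_cayley`, can therefore differ.\n",
"\n",
"```python\n",
"from collections import deque\n",
"\n",
"\n",
"def cayley_graph_sl2(p):\n",
"    # The four generators used in build_cayley_bank, with each 2x2 matrix\n",
"    # [[a, b], [c, d]] flattened to (a, b, c, d).\n",
"    gens = [(1, 1, 0, 1), (1, p - 1, 0, 1), (1, 0, 1, 1), (1, 0, p - 1, 1)]\n",
"\n",
"    def matmul_mod(x, g):\n",
"        # 2x2 matrix product x @ g with entries reduced modulo p.\n",
"        a, b, c, d = x\n",
"        e, f, h, k = g\n",
"        return ((a * e + b * h) % p, (a * f + b * k) % p,\n",
"                (c * e + d * h) % p, (c * f + d * k) % p)\n",
"\n",
"    identity = (1, 0, 0, 1)\n",
"    index = {identity: 0}\n",
"    queue = deque([identity])\n",
"    senders, receivers = [], []\n",
"    while queue:\n",
"        x = queue.popleft()\n",
"        for g in gens:\n",
"            y = matmul_mod(x, g)\n",
"            if y not in index:\n",
"                index[y] = len(index)\n",
"                queue.append(y)\n",
"            senders.append(index[x])\n",
"            receivers.append(index[y])\n",
"    return len(index), senders, receivers\n",
"\n",
"\n",
"for p in (2, 3, 5, 7, 11):\n",
"    num_nodes, senders, receivers = cayley_graph_sl2(p)\n",
"    print(p, num_nodes, p * (p ** 2 - 1), len(senders))\n",
"```\n",
"\n",
"Each generator's inverse is also in the generating set, so the resulting graph is symmetric: every edge u → v comes with v → u."
],
"metadata": {}
},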
3586 | {
3587 | "cell_type": "code",
3588 | "execution_count": null,
3589 | "source": [
3590 | "class ExpanderTransform(BaseTransform):\n",
3591 | " def __init__(self):\n",
3592 | " super(ExpanderTransform, self).__init__()\n",
3593 | " self.cayley_bank = build_cayley_bank()\n",
3594 | "\n",
3595 | " def __call__(self, data):\n",
3596 | " \"\"\"\n",
3597 | " Args:\n",
3598 | " data (PyG.Data): one instance of PyG graph\n",
3599 | " Returns:\n",
3600 | " data (PyG.Data): original graph with an additional attribute\n",
3601 | " rewire_edge_index containing the edge_index for\n",
3602 | " an expander graph\n",
3603 | " \"\"\"\n",
3604 | " num_nodes = data.x.shape[0]\n",
3605 | " cayley_edge_index, cayley_edge_attr = batched_augment_cayley(num_nodes, self.cayley_bank)\n",
3606 | " data.rewire_edge_index = torch.tensor(cayley_edge_index)\n",
3607 | " return data"
3608 | ],
3609 | "outputs": [],
3610 | "metadata": {
3611 | "id": "FFxzd9i909DU"
3612 | }
3613 | },
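{
"cell_type": "markdown",
"source": [
"`batched_augment_cayley` works in two steps. It picks the smallest Cayley graph with at least as many nodes as the input graph, then keeps only the edges whose endpoints both fall among the first `num_nodes` node ids. The hypothetical helpers below restate those two steps in plain Python on a toy edge list. Unlike the original, which uses `assert`, `pick_cayley_size` raises a `ValueError` for graphs with more than 1320 nodes.\n",
"\n",
"```python\n",
"# (num_nodes, p) pairs, copied from _CAYLEY_BOUNDS above.\n",
"CAYLEY_BOUNDS = [(6, 2), (24, 3), (120, 5), (336, 7), (1320, 11)]\n",
"\n",
"\n",
"def pick_cayley_size(num_nodes, bounds=CAYLEY_BOUNDS):\n",
"    # Smallest available Cayley graph that has at least num_nodes nodes.\n",
"    for size, p in bounds:\n",
"        if size >= num_nodes:\n",
"            return size, p\n",
"    raise ValueError('no Cayley graph with at least %d nodes' % num_nodes)\n",
"\n",
"\n",
"def truncate_edges(senders, receivers, num_nodes):\n",
"    # Keep an edge only if both endpoints are among the first num_nodes ids.\n",
"    kept = [(v, w) for v, w in zip(senders, receivers) if v < num_nodes and w < num_nodes]\n",
"    return [v for v, _ in kept], [w for _, w in kept]\n",
"\n",
"\n",
"print(pick_cayley_size(6), pick_cayley_size(7), pick_cayley_size(31))\n",
"# Toy directed 4-cycle 0 -> 1 -> 2 -> 3 -> 0, truncated to the first 3 nodes.\n",
"print(truncate_edges([0, 1, 2, 3], [1, 2, 3, 0], 3))\n",
"```\n",
"\n",
"Because of this truncation, nodes in the rewired graph can lose some of the 4 edges they have in the full Cayley graph."
],
"metadata": {}
},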
3614 | {
3615 | "cell_type": "code",
3616 | "execution_count": null,
3617 | "source": [
3618 | "expander = ExpanderTransform()"
3619 | ],
3620 | "outputs": [],
3621 | "metadata": {
3622 | "id": "c88PbzeqDCbZ"
3623 | }
3624 | },
3625 | {
3626 | "cell_type": "markdown",
3627 | "source": [
3628 | "Let's have a look at what the two rewiring techniques look like for one graph in our dataset."
3629 | ],
3630 | "metadata": {
3631 | "id": "5PfHYkwl7B97"
3632 | }
3633 | },
3634 | {
3635 | "cell_type": "code",
3636 | "execution_count": null,
3637 | "source": [
3638 | "tree_dataset, dim0, num_classes = DictionaryLookupDataset(5).generate_data(add_self_loops=False)\n",
3639 | "# original graph\n",
3640 | "orig_data = tree_dataset[0].clone()\n",
3641 | "\n",
3642 | "# expander graph\n",
3643 | "cayley_data = orig_data.clone()\n",
3644 | "cayley_data = expander(cayley_data)\n",
3645 | "cayley_data.edge_index = cayley_data.rewire_edge_index\n",
3646 | "\n",
3647 | "# fully connected graph\n",
3648 | "fully_data = orig_data.clone()\n",
3649 | "fully_data = FullyConnectedTransform()(fully_data)\n",
3650 | "fully_data.edge_index = fully_data.rewire_edge_index\n",
3651 | "\n"
3652 | ],
3653 | "outputs": [],
3654 | "metadata": {
3655 | "id": "zH6ZxqbExnjk"
3656 | }
3657 | },
3658 | {
3659 | "cell_type": "code",
3660 | "execution_count": null,
3661 | "source": [
3662 | "# plot graphs\n",
3663 | "gallery([orig_data, cayley_data, fully_data], labels=[\"original\", \"expander\", \"fully connected\"])\n"
3664 | ],
3665 | "outputs": [],
3666 | "metadata": {
3667 | "id": "HoShW9H0_RcA"
3668 | }
3669 | },
3670 | {
3671 | "cell_type": "markdown",
3672 | "source": [
3673 | "We can see that the expander graph is more connected than the original graph (thus helping in our long-distance messages problem), but far less dense compared to the fully connected approach."
3674 | ],
3675 | "metadata": {
3676 | "id": "qllw-CRw8Afz"
3677 | }
3678 | },
3679 | {
3680 | "cell_type": "code",
3681 | "execution_count": null,
3682 | "source": [
3683 | "# @title [RUN] Hyperparameters GT\n",
3684 | "\n",
3685 | "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n",
3686 | "LR = 0.001 #@param {type:\"number\"}\n",
3687 | "HIDDEN_DIM = 32 #@param {type:\"integer\"}\n",
3688 | "NUM_LAYERS = 5 #@param {type:\"integer\"}\n",
3689 | "\n",
3690 | "\n",
3691 | "#you can add more here if you need"
3692 | ],
3693 | "outputs": [],
3694 | "metadata": {
3695 | "id": "qFvi2IOO7EpX"
3696 | }
3697 | },
3698 | {
3699 | "cell_type": "code",
3700 | "execution_count": null,
3701 | "source": [
3702 | "tree_dataset, dim0, num_classes = DictionaryLookupDataset(4).generate_data(transform=ExpanderTransform())\n",
3703 | "\n",
3704 | "train_syn_data = tree_dataset[:1000]\n",
3705 | "val_syn_data = tree_dataset[1000:2000]\n",
3706 | "test_syn_data = tree_dataset[2000:3000]\n",
3707 | "\n",
3708 | "batch_size = 128\n",
3709 | "train_syn_loader = DataLoader(train_syn_data, batch_size=batch_size, shuffle=True)\n",
3710 | "val_syn_loader = DataLoader(val_syn_data, batch_size=batch_size, shuffle=False)\n",
3711 | "test_syn_loader = DataLoader(test_syn_data, batch_size=batch_size, shuffle=False)"
3712 | ],
3713 | "outputs": [],
3714 | "metadata": {
3715 | "id": "vasBgdWmEksy"
3716 | }
3717 | },
3718 | {
3719 | "cell_type": "markdown",
3720 | "source": [
3721 | "Since the only modification is in the pre-processing stage (rewiring based on the expander graph rather than the fully connected one), the graph model remains the same as above. It's time to train the model."
3722 | ],
3723 | "metadata": {
3724 | "id": "PcNivxwGN5hq"
3725 | }
3726 | },
3727 | {
3728 | "cell_type": "code",
3729 | "execution_count": null,
3730 | "source": [
3731 | "model_gt_syn_expander = GTpygRewireModel(num_classes+1, HIDDEN_DIM, num_layers=NUM_LAYERS, dim_emb=dim0)\n",
3732 | "gt_perf_per_epoch = run_exp_syn(model_gt_syn_expander, train_syn_loader, val_syn_loader, test_syn_loader,\n",
3733 | " lr=LR, n_epochs=NUM_EPOCHS)"
3734 | ],
3735 | "outputs": [],
3736 | "metadata": {
3737 | "id": "JBYe26_2euNn"
3738 | }
3739 | },
3740 | {
3741 | "cell_type": "markdown",
3742 | "source": [
3743 | "❗Note that while a single layer is enough to send messages between any pair of nodes in the fully connected graph, the expander graph might need several layers: although it is better connected than our original graph, it is still sparse.\n",
3744 | "\n",
3745 | "For our synthetic task, the model might converge more slowly than the previous version, but it is still capable of learning the task. Moreover, the operations inside the model are much more efficient due to the sparser topology."
3746 | ],
3747 | "metadata": {
3748 | "id": "1xilrmzD6SQL"
3749 | }
3750 | },
3751 | {
3752 | "cell_type": "markdown",
3753 | "source": [
3754 | "🍸 It has been an intense journey so far! We learned how to implement and train a GNN both from scratch and using the PyG library, and we even touched on some of the hot open problems in GNN research.\n",
3755 | "\n",
3756 | "⏩ The main purpose of this tutorial is to make you familiar with graph deep learning and what you can do with it. However, besides the over-squashing solutions presented here, there is much more out there that we didn't have time to cover, including techniques such as virtual nodes or dynamic rewiring. If you are interested in finding out more, we included a list of papers [at the end of this tutorial](#references)."
3757 | ],
3758 | "metadata": {
3759 | "id": "rnoEGAe-69cT"
3760 | }
3761 | },
3762 | {
3763 | "cell_type": "markdown",
3764 | "source": [
3765 | "# 🧘 [Bonus] **Positional Encoding for Graph Data** "
3766 | ],
3767 | "metadata": {
3768 | "id": "Dn3HjUK0gN-2"
3769 | }
3770 | },
3771 | {
3772 | "cell_type": "markdown",
3773 | "source": [
3774 | "The major issue with the rewiring approaches presented so far is that we **completely ignore the original graph topology**. While the graph does not influence the output in our synthetic problem, this is unlikely to be the case in real-world problems, such as solubility prediction on the ZINC dataset. Ideally, we would like to overcome over-squashing while still being aware of the graph structure we are working on.\n",
3775 | "\n",
3776 | "One way to recover the graph topology is to incorporate extra features, known as **positional encodings**. These are additional input features summarizing various graph statistics. Some popular choices include [Laplacian eigenvectors](https://arxiv.org/abs/2003.00982), features extracted using [random walks](https://arxiv.org/pdf/2110.07875) or [learnable positional encodings](https://arxiv.org/pdf/2307.07107)."
3777 | ],
3778 | "metadata": {
3779 | "id": "4fSDnNWkotLh"
3780 | }
3781 | },
3782 | {
3783 | "cell_type": "markdown",
3784 | "source": [
3785 | "In the following, we implement a data transformation that pre-computes the Laplacian eigenvectors used to encode graph structure information.\n",
3786 | "\n",
3787 | "For each node, the positional encoding is given by that node's entries in the eigenvectors associated with the `emb_dim` smallest eigenvalues, skipping the first (trivial) eigenvector. It is zero-padded when the graph has fewer than `emb_dim + 1` nodes. Its signs are flipped at random, since eigenvectors are only defined up to sign. The result is stored in `data.pos`.\n"
3788 | ],
3789 | "metadata": {
3790 | "id": "np_L1FJDq1ed"
3791 | }
3792 | },
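{
"cell_type": "markdown",
"source": [
"In formulas, assume an undirected, connected graph with adjacency matrix $A$, diagonal degree matrix $D$ and $k$ = `emb_dim`. The transform computes the symmetrically normalised Laplacian (`normalization='sym'` in `get_laplacian`) and sorts its eigenpairs by eigenvalue. It then keeps the eigenvectors $u_2, \\dots, u_{k+1}$, with a random sign $s_j$ per dimension:\n",
"\n",
"$$\n",
"L_{\\mathrm{sym}} = I - D^{-1/2} A D^{-1/2}, \\qquad L_{\\mathrm{sym}}\\, u_j = \\lambda_j u_j, \\qquad 0 = \\lambda_1 \\le \\lambda_2 \\le \\dots \\le \\lambda_n,\n",
"$$\n",
"\n",
"$$\n",
"\\mathrm{pe}(i) = \\big(s_1\\, u_2(i),\\ s_2\\, u_3(i),\\ \\dots,\\ s_k\\, u_{k+1}(i)\\big), \\qquad s_j \\in \\{-1, +1\\}.\n",
"$$\n",
"\n",
"For a connected graph, $u_1 \\propto D^{1/2}\\mathbf{1}$ depends only on node degrees, which is the usual justification for dropping it."
],
"metadata": {}
},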
3793 | {
3794 | "cell_type": "code",
3795 | "execution_count": null,
3796 | "source": [
3797 | "# this is the dimension of the positional encoding\n",
3798 | "POSEMB_DIM = 16 #@param {type:\"integer\"}"
3799 | ],
3800 | "outputs": [],
3801 | "metadata": {
3802 | "cellView": "form",
3803 | "id": "zKUM_-Lz_XHe"
3804 | }
3805 | },
3806 | {
3807 | "cell_type": "code",
3808 | "execution_count": null,
3809 | "source": [
3810 | "class AddLaplacianPETransform(BaseTransform):\n",
3811 | " def __init__(self, emb_dim):\n",
3812 | " \"\"\"\n",
3813 | " Args:\n",
3814 | " emb_dim (int): dimension of the positional encoding\n",
3815 | " \"\"\"\n",
3816 | " super(AddLaplacianPETransform, self).__init__()\n",
3817 | " self.emb_dim = emb_dim\n",
3818 | "\n",
3819 | " def __call__(self, data):\n",
3820 | " \"\"\"\n",
3821 | " Args:\n",
3822 | " data (PyG.Data): one instance of PyG graph\n",
3823 | " Returns:\n",
3824 | " data (PyG.Data): original graph with an additional attribute\n",
3825 | " pos (n, emb_dim) containing the emb_dim eigenvectors with the smallest non-trivial eigenvalues\n",
3826 | " \"\"\"\n",
3827 | " num_nodes = data.x.shape[0]\n",
3828 | "\n",
3829 | " edge_index, edge_weight = get_laplacian(\n",
3830 | " data.edge_index,\n",
3831 | " normalization='sym',\n",
3832 | " num_nodes=num_nodes,\n",
3833 | " )\n",
3834 | "\n",
3835 | " L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)\n",
3836 | " eig_vals, eig_vecs = eig(L.todense())\n",
3837 | "\n",
3838 | "\n",
3839 | " eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])\n",
3840 | "\n",
3841 | " pe = torch.from_numpy(eig_vecs[:, 1:self.emb_dim + 1])\n",
3842 | " if pe.shape[1] < self.emb_dim:\n",
3843 | " pe = torch.cat((pe, torch.zeros(pe.shape[0], self.emb_dim-pe.shape[1])), axis=-1)\n",
3844 | " sign = -1 + 2 * torch.randint(0, 2, (self.emb_dim, ))\n",
3845 | " pe *= sign\n",
3846 | "\n",
3847 | " data.pos = pe\n",
3848 | " return data"
3849 | ],
3850 | "outputs": [],
3851 | "metadata": {
3852 | "id": "5RMIic2GrARM"
3853 | }
3854 | },
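{
"cell_type": "markdown",
"source": [
"A quick sanity check of two conventions used in `AddLaplacianPETransform`. It uses a hypothetical 3-node path graph 0 - 1 - 2 and plain Python rather than `get_laplacian`/`eig`. First, the eigenvector for eigenvalue 0 is proportional to the square roots of the node degrees, which is the usual reason the transform starts from column 1. Second, flipping the sign of an eigenvector gives another valid eigenvector, which is why the transform randomises signs. The eigenvectors below are unnormalised, whereas `eig` returns unit-norm ones.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def sym_laplacian(num_nodes, edges):\n",
"    # Dense L_sym = I - D^(-1/2) A D^(-1/2) for an undirected graph without\n",
"    # isolated nodes or self-loops, given as a list of (u, v) pairs.\n",
"    adj = [[0.0] * num_nodes for _ in range(num_nodes)]\n",
"    for u, v in edges:\n",
"        adj[u][v] = adj[v][u] = 1.0\n",
"    deg = [sum(row) for row in adj]\n",
"    lap = [[0.0] * num_nodes for _ in range(num_nodes)]\n",
"    for i in range(num_nodes):\n",
"        lap[i][i] = 1.0\n",
"        for j in range(num_nodes):\n",
"            if adj[i][j]:\n",
"                lap[i][j] -= adj[i][j] / math.sqrt(deg[i] * deg[j])\n",
"    return lap\n",
"\n",
"\n",
"def matvec(m, v):\n",
"    return [sum(a * b for a, b in zip(row, v)) for row in m]\n",
"\n",
"\n",
"def is_eigenpair(m, v, lam, tol=1e-9):\n",
"    # True if m @ v equals lam * v up to floating-point tolerance.\n",
"    return all(abs(x - lam * y) < tol for x, y in zip(matvec(m, v), v))\n",
"\n",
"\n",
"L = sym_laplacian(3, [(0, 1), (1, 2)])  # path graph 0 - 1 - 2\n",
"s2 = math.sqrt(2)\n",
"trivial = [1.0, s2, 1.0]  # eigenvalue 0, proportional to sqrt(degree); skipped\n",
"u2 = [1.0, 0.0, -1.0]     # eigenvalue 1, first column kept as encoding\n",
"u3 = [1.0, -s2, 1.0]      # eigenvalue 2\n",
"\n",
"print(is_eigenpair(L, trivial, 0.0), is_eigenpair(L, u2, 1.0), is_eigenpair(L, u3, 2.0))\n",
"print(is_eigenpair(L, [-x for x in u2], 1.0))  # sign-flipped eigenvector\n",
"```"
],
"metadata": {}
},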
3855 | {
3856 | "cell_type": "code",
3857 | "execution_count": null,
3858 | "source": [
3859 | "tree_dataset, _, _ = DictionaryLookupDataset(3).generate_data(transform=AddLaplacianPETransform(POSEMB_DIM))\n",
3860 | "print(f\"The positional encoding for node 5 is: {tree_dataset[0].pos[5]}\")"
3861 | ],
3862 | "outputs": [],
3863 | "metadata": {
3864 | "id": "OnIC67FOsrjY"
3865 | }
3866 | },
3867 | {
3868 | "cell_type": "markdown",
3869 | "source": [
3870 | "## 🖋 **Task** Modify the graph neural network model to incorporate the positional encoding, while still using the fully connected adjacency matrix \n",
3871 | "\n",
3872 | "> We expect two distinct modifications for this task:\n",
3873 | " 1. as before, use the rewired (fully connected) adjacency matrix instead of the original graph topology\n",
3874 | " 2. before any processing, concatenate the additional positional encoding computed above to the initial features."
3875 | ],
3876 | "metadata": {
3877 | "id": "B1ArQV2BG0lS"
3878 | }
3879 | },
3880 | {
3881 | "cell_type": "code",
3882 | "execution_count": null,
3883 | "source": [
3884 | "class GTpygPEModel(Module):\n",
3885 | " def __init__(self, hidden_dim, num_layers=2):\n",
3886 | " \"\"\"Graph Transformer Neural Network model using node-level PE\n",
3887 | "\n",
3888 | " Args:\n",
3889 | " hidden_dim: (int) - hidden dimension\n",
3890 | " num_layers: (int) - number of layers\n",
3891 | " \"\"\"\n",
3892 | " super(GTpygPEModel, self).__init__()\n",
3893 | " self.num_layers = num_layers # please select num_layers>=2\n",
3894 | "\n",
3895 | " self.embed_x = Embedding(28, hidden_dim)\n",
3896 | "\n",
3897 | " # ============ YOUR CODE HERE ==============\n",
3898 | " # This is the ORIGINAL code used before. Your task is to MODIFY\n",
3899 | " # it such that it takes into account the newly added features.\n",
3900 | " # Hint: What is happening with the input dimension in the first layer?\n",
3901 | " #\n",
3902 | " # self.layers = [GTpygLayer(hidden_dim, hidden_dim, hidden_dim)]\n",
3903 | " # ===========================================\n",
3904 | "\n",
3905 | " self.layers += [GTpygLayer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers-2)]\n",
3906 | " self.layers += [GTpygLayer(hidden_dim, 1, hidden_dim)]\n",
3907 | "\n",
3908 | " self.layers = ModuleList(self.layers)\n",
3909 | "\n",
3910 | " def forward(self, graph):\n",
3911 | " \"\"\"\n",
3912 | " Args:\n",
3913 | " graph: (PyG.Data) - batch of PyG graphs\n",
3914 | " Returns:\n",
3915 | " out: (batch_size,) - scalar prediction for each graph in the batch\n",
3916 | " \"\"\"\n",
3917 | " new_x = self.embed_x(graph.x).squeeze(1)\n",
3918 | "\n",
3919 | " # ============ YOUR CODE HERE ==============\n",
3920 | " # Concatenate the positional encoding computed in the transformation\n",
3921 | " # with the input features new_x\n",
3922 | " #\n",
3923 | " # x = torch.cat((new_x, graph.pos), axis=-1)\n",
3924 | " # ===========================================\n",
3925 | "\n",
3926 | " for i in range(self.num_layers-1):\n",
3927 | " x = self.layers[i](x, graph.rewire_edge_index)\n",
3928 | " x = F.relu(x)\n",
3929 | "\n",
3930 | " x = self.layers[-1](x, graph.rewire_edge_index)\n",
3931 | "\n",
3932 | " out = global_add_pool(x, graph.batch)\n",
3933 | " out = out.squeeze(-1)\n",
3934 | " return out"
3935 | ],
3936 | "outputs": [],
3937 | "metadata": {
3938 | "id": "fSjKhPBskD-O"
3939 | }
3940 | },
3941 | {
3942 | "cell_type": "code",
3943 | "execution_count": null,
3944 | "source": [
3945 | "# @title [RUN] Hyperparameters GT\n",
3946 | "\n",
3947 | "NUM_EPOCHS = 50 #@param {type:\"integer\"}\n",
3948 | "LR = 0.001 #@param {type:\"number\"}\n",
3949 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
3950 | "NUM_LAYERS = 2 #@param {type:\"integer\"}\n",
3951 | "\n",
3952 | "\n",
3953 | "# You can add more hyperparameters here if you need to"
3954 | ],
3955 | "outputs": [],
3956 | "metadata": {
3957 | "id": "i2M4uEtiwkZG"
3958 | }
3959 | },
3960 | {
3961 | "cell_type": "code",
3962 | "execution_count": null,
3963 | "source": [
3964 | "batch_size = 64\n",
3965 | "\n",
3966 | "train_ds = ZINC(root='data/ZINC', split='train', subset=True, transform=T.Compose([FullyConnectedTransform(), AddLaplacianPETransform(POSEMB_DIM)]))\n",
3967 | "val_ds = ZINC(root='data/ZINC', split='val', subset=True, transform=T.Compose([FullyConnectedTransform(), AddLaplacianPETransform(POSEMB_DIM)]))\n",
3968 | "test_ds = ZINC(root='data/ZINC', split='test', subset=True, transform=T.Compose([FullyConnectedTransform(), AddLaplacianPETransform(POSEMB_DIM)]))\n",
3969 | "\n",
3970 | "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
3971 | "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
3972 | "test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)"
3973 | ],
3974 | "outputs": [],
3975 | "metadata": {
3976 | "id": "3opHPfUyUg1B"
3977 | }
3978 | },
3979 | {
3980 | "cell_type": "code",
3981 | "execution_count": null,
3982 | "source": [
3983 | "model_gt_pe = GTpygPEModel(HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
3984 | "stats = run_exp_pyg(model_gt_pe, train_loader, val_loader, test_loader, loss_fct=F.mse_loss,\n",
3985 | " lr=LR, num_epochs=NUM_EPOCHS)"
3986 | ],
3987 | "outputs": [],
3988 | "metadata": {
3989 | "id": "1XV3lDhAUg1C"
3990 | }
3991 | },
3992 | {
3993 | "cell_type": "code",
3994 | "execution_count": null,
3995 | "source": [
3996 | "plot_stats(stats)"
3997 | ],
3998 | "outputs": [],
3999 | "metadata": {
4000 | "id": "hHlUoTyAZ2uK"
4001 | }
4002 | },
4003 | {
4004 | "cell_type": "markdown",
4005 | "source": [
4006 | "If you are familiar with NLP research, this model might look very familiar: we have a set of nodes enriched with positional encodings, and then we exchange information between every pair of nodes using a query-key-value attention mechanism. It is essentially a **Transformer architecture** applied to a set of nodes instead of a sequence of word tokens."
4007 | ],
4008 | "metadata": {
4009 | "id": "oMjCto-MuIhA"
4010 | }
4011 | },
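{
"cell_type": "markdown",
"source": [
"To make the analogy concrete, here is a minimal dense sketch of one single-head attention step in which every node attends to every node, as on the fully connected rewired graph. It is an illustration under our own assumptions, not the tutorial's `GTpygLayer`: the helper `dense_node_attention` and the weight matrices `w_q`, `w_k`, `w_v` are hypothetical, and the $1/\\sqrt{d}$ scaling of the original Transformer is omitted. Note that nothing in this computation depends on the original edges, which is why the positional encoding concatenated to the node features is needed to expose the graph structure."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"def dense_node_attention(x, w_q, w_k, w_v):\n",
"    # x: (n, d) node features, e.g. atom embeddings concatenated with a positional encoding\n",
"    q, k, v = x @ w_q, x @ w_k, x @ w_v   # (n, h) queries, keys and values\n",
"    scores = q @ k.T                       # (n, n): one score for every pair of nodes\n",
"    alpha = torch.softmax(scores, dim=-1)  # each row (one destination node) sums to 1\n",
"    return alpha @ v                       # (n, h): weighted sum over all nodes\n",
"\n",
"n, d, h = 5, 8, 4\n",
"x = torch.randn(n, d)\n",
"w_q, w_k, w_v = (torch.randn(d, h) for _ in range(3))\n",
"out = dense_node_attention(x, w_q, w_k, w_v)\n",
"print(out.shape)  # torch.Size([5, 4])"
],
"outputs": [],
"metadata": {}
},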
4012 | {
4013 | "cell_type": "markdown",
4014 | "source": [
4015 | "## Absolute vs Relative Positional Encoding"
4016 | ],
4017 | "metadata": {
4018 | "id": "f679FnVE7OHR"
4019 | }
4020 | },
4021 | {
4022 | "cell_type": "markdown",
4023 | "source": [
4024 | "Inspired by the NLP community, [two types of positional encoding have recently emerged](https://arxiv.org/abs/2402.14202): **absolute positional encoding** and **relative positional encoding**. *Absolute positional encoding* summarizes the graph structure as node features (as we did above with the Laplacian eigenvectors). *Relative positional encoding*, on the other hand, represents the graph structure as an embedding of pairs of nodes (i.e. edges). Examples of relative positional encoding include boolean identifiers ($1$ for edges of the original topology and $0$ for edges added by rewiring), the [resistance distance](https://arxiv.org/abs/2301.09505) or the [shortest-path distance](https://arxiv.org/pdf/2106.05234)."
4025 | ],
4026 | "metadata": {
4027 | "id": "Vj4AJ7v2wH1g"
4028 | }
4029 | },
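{
"cell_type": "markdown",
"source": [
"As an illustration of the simplest relative encoding listed above, the sketch below labels each edge of a fully connected graph with $1$ if it belongs to the original topology and $0$ otherwise. It assumes both graphs are stored as `(2, e)` edge-index tensors and that the rewired graph has no self-loops; `boolean_relative_pe` is a hypothetical helper, not a PyTorch Geometric function. Such integer labels could then be embedded in the same way as the shortest-path encoding used in the rest of this section."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"def boolean_relative_pe(edge_index, rewire_edge_index, num_nodes):\n",
"    # Encode each directed edge (u, v) as the single integer u * num_nodes + v\n",
"    orig_ids = edge_index[0] * num_nodes + edge_index[1]\n",
"    rewire_ids = rewire_edge_index[0] * num_nodes + rewire_edge_index[1]\n",
"    # 1 for rewired edges that exist in the original graph, 0 for added ones\n",
"    return torch.isin(rewire_ids, orig_ids).long()\n",
"\n",
"# Toy path graph 0 - 1 - 2, with both edge directions listed as in PyG\n",
"edge_index = torch.tensor([[0, 1, 1, 2],\n",
"                           [1, 0, 2, 1]])\n",
"\n",
"# Fully connected rewiring without self-loops\n",
"n = 3\n",
"src, dst = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')\n",
"mask = src != dst\n",
"rewire_edge_index = torch.stack([src[mask], dst[mask]])\n",
"\n",
"print(rewire_edge_index.tolist())  # [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]\n",
"print(boolean_relative_pe(edge_index, rewire_edge_index, n).tolist())  # [1, 0, 1, 1, 0, 1]"
],
"outputs": [],
"metadata": {}
},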
4030 | {
4031 | "cell_type": "markdown",
4032 | "source": [
4033 | "## 🖋 **Task** Implement a transformation to compute the relative positional encoding based on the shortest-path distance.\n",
4034 | "\n",
4035 | "> For each pair of nodes connected in the rewired graph, the extra feature is an integer denoting the length of the shortest path between the two nodes in the original graph. Note that this only makes sense in conjunction with a rewiring transformation; otherwise, all shortest-path distances would be 1.\n",
4036 | "\n"
4037 | ],
4038 | "metadata": {
4039 | "id": "ACopExRjxqjt"
4040 | }
4041 | },
4042 | {
4043 | "cell_type": "code",
4044 | "execution_count": null,
4045 | "source": [
4046 | "class AddRelativePETransform(BaseTransform):\n",
4047 | " def __init__(self):\n",
4048 | " super().__init__()\n",
4049 | " def __call__(self, data):\n",
4050 | " \"\"\"\n",
4051 | " Args:\n",
4052 | " data (PyG.Data): one instance of PyG graph\n",
4053 | " Returns:\n",
4054 | " data (PyG.Data): original graph with an additional attribute\n",
4055 | " edge_attr (e,) containing the shortest-path distance (in the original graph)\n",
4056 | " between each pair of nodes connected in the rewired graph\n",
4057 | " \"\"\"\n",
4058 | " orig_data = Data(x=data.x, edge_index=data.edge_index)\n",
4059 | " G = torch_geometric.utils.to_networkx(orig_data).to_undirected()\n",
4060 | "\n",
4061 | " # ============ YOUR CODE HERE ==============\n",
4062 | " # For each pair of 2 nodes in data.rewire_edge_index (2, e),\n",
4063 | " # compute the length of their shortest path in G and store it in `edge_attr`: (e,)\n",
4064 | " # Hint: nx.shortest_path_length can be used to compute the shortest paths\n",
4065 | " #\n",
4066 | " # edge_attr = []\n",
4067 | " # num_edges = data.rewire_edge_index.shape[1]\n",
4068 | " # for i in range(num_edges):\n",
4069 | " #\n",
4070 | " #\n",
4071 | " # edge_attr.append(length)\n",
4072 | " # ===========================================\n",
4073 | "\n",
4074 | " data.edge_attr = torch.tensor(edge_attr)\n",
4075 | " return data"
4076 | ],
4077 | "outputs": [],
4078 | "metadata": {
4079 | "id": "bAZuMU54OB5E"
4080 | }
4081 | },
4082 | {
4083 | "cell_type": "code",
4084 | "execution_count": null,
4085 | "source": [
4086 | "tree_dataset, _, _ = DictionaryLookupDataset(3).generate_data(add_self_loops=False, transform=T.Compose([FullyConnectedTransform(), AddRelativePETransform()]))\n",
4087 | "draw_one_graph(plt.axes(), tree_dataset[0].edge_index.numpy(), layout='tree')\n",
4088 | "print(f\"The relative positional encoding for the edge between {tree_dataset[0].rewire_edge_index[0][5]} and {tree_dataset[0].rewire_edge_index[1][5]} is: {tree_dataset[0].edge_attr[5]}\")"
4089 | ],
4090 | "outputs": [],
4091 | "metadata": {
4092 | "id": "Q4EKdEFgyPIa"
4093 | }
4094 | },
4095 | {
4096 | "cell_type": "markdown",
4097 | "source": [
4098 | "## 🖋 **Task** Modify the GNN Layer below to incorporate the relative positional encoding when computing the attention coefficients.\n",
4099 | "\n",
4100 | "> While there are several ways in which we can do that, for this practical we will implement the following function:\n",
4101 | ">\n",
4102 | "> \\begin{align}\n",
4103 | "f_{msg}(x_i, x_j)=\\sigma\\left(k(x_i)^Tq(x_j) {\\color{red}{+ r_{ij}}}\\right)v(x_j)\n",
4104 | "\\end{align}\n",
4105 | ">\n",
4106 | "> where $\\sigma$ is the softmax non-linearity and $r_{ij}$ is the relative positional encoding (in our case, a learnable scalar embedding of the shortest-path distance between $i$ and $j$)."
4107 | ],
4108 | "metadata": {
4109 | "id": "jpnSoLa00Xrz"
4110 | }
4111 | },
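{
"cell_type": "markdown",
"source": [
"For intuition, the sketch below evaluates the formula above densely for a toy path graph with 3 nodes: each logit $k(x_i)^T q(x_j)$ is shifted by a learnable scalar $r_{ij}$ looked up from the shortest-path distance, in the same spirit as `self.lin_attr = Embedding(23, 1)` in the layer below. The dimensions and the normalisation over $j$ for each node $i$ are assumptions made for illustration; this dense version is not the sparse `message` implementation the task asks for. Because $r_{ij}$ is learned separately for each distance value, such a model can, for example, learn to up-weight close nodes while still letting distant nodes exchange information."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"from torch.nn import Embedding, Linear\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Toy path graph 0 - 1 - 2: shortest-path distance between every pair of nodes\n",
"dist = torch.tensor([[0, 1, 2],\n",
"                     [1, 0, 1],\n",
"                     [2, 1, 0]])\n",
"\n",
"in_dim, hid_dim = 4, 8\n",
"x = torch.randn(3, in_dim)\n",
"lin_q, lin_k, lin_v = Linear(in_dim, hid_dim), Linear(in_dim, hid_dim), Linear(in_dim, hid_dim)\n",
"lin_attr = Embedding(23, 1)  # one learnable scalar bias per distance value\n",
"\n",
"logits = lin_k(x) @ lin_q(x).T                # logits[i, j] = k(x_i)^T q(x_j)\n",
"logits = logits + lin_attr(dist).squeeze(-1)  # add r_ij, looked up from the distance\n",
"alpha = torch.softmax(logits, dim=-1)         # normalise over j for each node i\n",
"out = alpha @ lin_v(x)                        # out_i = sum_j alpha_ij v(x_j)\n",
"\n",
"print(alpha.sum(dim=-1))  # each row sums to 1 (up to floating-point error)\n",
"print(out.shape)          # torch.Size([3, 8])"
],
"outputs": [],
"metadata": {}
},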
4112 | {
4113 | "cell_type": "code",
4114 | "execution_count": null,
4115 | "source": [
4116 | "class GTpygRPeLayer(MessagePassing):\n",
4117 | " def __init__(self, in_dim, out_dim, hid_dim):\n",
4118 | " \"\"\"Sparse Graph Transformer Layer with relative positional encoding, implemented using PyTorch Geometric\n",
4119 | "\n",
4120 | " Args:\n",
4121 | " in_dim: (int) - input dimension for node features\n",
4122 | " out_dim: (int) - output dimension\n",
4123 | " hid_dim: (int) - hidden dimension\n",
4125 | " \"\"\"\n",
4126 | " super().__init__(node_dim=0, aggr='add')\n",
4127 | " self.in_dim = in_dim\n",
4128 | " self.out_dim = out_dim\n",
4129 | "\n",
4130 | " # ============ YOUR CODE HERE ==============\n",
4131 | " # Add new layers to compute the three linear projections\n",
4132 | " # these should be the same as before\n",
4133 | " # x_i -> q(x_i); x_j -> k(x_j); x_j -> v(x_j)\n",
4134 | " #\n",
4135 | " # self.lin_q = ...\n",
4136 | " # self.lin_k = ...\n",
4137 | " # self.lin_v = ...\n",
4138 | " # ===========================================\n",
4139 | "\n",
4140 | " # convert the scalar denoting shortest-path into an embedding\n",
4141 | " # the table has 23 entries, so distances from 0 to 22 are supported.\n",
4142 | " self.lin_attr = Embedding(23, 1)\n",
4143 | "\n",
4144 | " def forward(self, x, edge_index, edge_attr):\n",
4145 | " \"\"\"\n",
4146 | " Args:\n",
4147 | " x: (n, in_dim) - initial node features\n",
4148 | " edge_index: (2, e) - list of edges as a tuple\n",
4149 | " edge_attr: (e, ) - relative positional encoding\n",
4150 | "\n",
4151 | " Returns:\n",
4152 | " out: (n, out_dim) - updated node features\n",
4153 | " \"\"\"\n",
4154 | " out = self.propagate(edge_index, x=x, edge_attr=edge_attr)\n",
4155 | " return out\n",
4156 | "\n",
4157 | "\n",
4158 | " def message(self, x_i, x_j, edge_index_j, edge_attr):\n",
4159 | " \"\"\"\n",
4160 | " Args:\n",
4161 | " x_i: (e, in_dim) - features corresponding to destination nodes\n",
4162 | " x_j: (e, in_dim) - features corresponding to source nodes\n",
4163 | " edge_index_j: (e,) - source node index of each edge, used to group the softmax computation\n",
4165 | " edge_attr: (e,) - shortest-path distance for each edge\n",
4166 | "\n",
4167 | " Returns:\n",
4168 | " out: (e, out_dim) - message sent along each edge\n",
4169 | " \"\"\"\n",
4170 | "\n",
4171 | " # ============ YOUR CODE HERE ==============\n",
4172 | " # Compute the message function from the equation above\n",
4173 | " # Hint: some of the code should be similar to the one that\n",
4174 | " # you wrote for the GTpygLayer.\n",
4175 | " #\n",
4176 | " # alpha = ...\n",
4177 | " # alpha = softmax(alpha, edge_index_j)\n",
4178 | " # out = ...\n",
4179 | " # ===========================================\n",
4180 | "\n",
4181 | " return out"
4182 | ],
4183 | "outputs": [],
4184 | "metadata": {
4185 | "id": "YuQfwbGYaN1v"
4186 | }
4187 | },
4188 | {
4189 | "cell_type": "markdown",
4190 | "source": [
4191 | "The model class is the same as before, except that every layer also receives `edge_attr`, the relative positional encoding, as an extra argument."
4192 | ],
4193 | "metadata": {
4194 | "id": "tLAGlQ9p1nGJ"
4195 | }
4196 | },
4197 | {
4198 | "cell_type": "code",
4199 | "execution_count": null,
4200 | "source": [
4201 | "class GTpygRPeModel(Module):\n",
4202 | " def __init__(self, hidden_dim, num_layers=2):\n",
4203 | " \"\"\"\n",
4204 | " Graph Transformer Neural Network model using rewiring and\n",
4205 | " relative positional encoding\n",
4206 | "\n",
4207 | " Args:\n",
4208 | " hidden_dim: (int) - hidden dimension\n",
4209 | " num_layers: (int) - number of GT layers used in the model\n",
4210 | " \"\"\"\n",
4211 | " super(GTpygRPeModel, self).__init__()\n",
4212 | " self.num_layers = num_layers\n",
4213 | "\n",
4214 | " self.embed_x = Embedding(28, hidden_dim)\n",
4215 | "\n",
4216 | " self.layers = [GTpygRPeLayer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers-1)]\n",
4217 | " self.layers += [GTpygRPeLayer(hidden_dim, 1, hidden_dim)]\n",
4218 | "\n",
4219 | " self.layers = ModuleList(self.layers)\n",
4220 | "\n",
4221 | " def forward(self, graph):\n",
4222 | " \"\"\"\n",
4223 | " Args:\n",
4224 | " graph: (PyG.Data) - batch of PyG graphs\n",
4225 | " Returns:\n",
4226 | " out: (batch_size,) - scalar prediction for each graph in the batch\n",
4227 | " \"\"\"\n",
4228 | "\n",
4229 | " new_x = self.embed_x(graph.x).squeeze(1)\n",
4230 | " x = new_x\n",
4231 | " for i in range(self.num_layers-1):\n",
4232 | " x = self.layers[i](x, graph.rewire_edge_index, graph.edge_attr)\n",
4233 | " x = F.relu(x)\n",
4234 | "\n",
4235 | " x = self.layers[-1](x, graph.rewire_edge_index, graph.edge_attr)\n",
4236 | "\n",
4237 | " out = global_add_pool(x, graph.batch)\n",
4238 | " out = out.squeeze(-1)\n",
4239 | " return out"
4240 | ],
4241 | "outputs": [],
4242 | "metadata": {
4243 | "id": "3vX3k0-RdXaA"
4244 | }
4245 | },
4246 | {
4247 | "cell_type": "markdown",
4248 | "source": [
4249 | "Let's train the model again, this time using the relative positional encoding.\n",
4250 | "\n",
4251 | "❗️Note that training the `GTpygRPeModel` is quite slow 🐢. For the purpose of this tutorial, it is enough to check that it runs for a couple of iterations. If you want to see full results and your Colab credits are not enough, we suggest moving the training to a local machine."
4252 | ],
4253 | "metadata": {
4254 | "id": "rToIjzKu1y0Q"
4255 | }
4256 | },
4257 | {
4258 | "cell_type": "code",
4259 | "execution_count": null,
4260 | "source": [
4261 | "# @title [RUN] Hyperparameters GT with relative PE\n",
4262 | "\n",
4263 | "NUM_EPOCHS = 50 #@param {type:\"integer\"}\n",
4264 | "LR = 0.001 #@param {type:\"number\"}\n",
4265 | "HIDDEN_DIM = 64 #@param {type:\"integer\"}\n",
4266 | "NUM_LAYERS = 2 #@param {type:\"integer\"}\n",
4267 | "\n",
4268 | "\n",
4269 | "# You can add more hyperparameters here if you need to"
4270 | ],
4271 | "outputs": [],
4272 | "metadata": {
4273 | "cellView": "form",
4274 | "id": "3VC3mh3uaKe_"
4275 | }
4276 | },
4277 | {
4278 | "cell_type": "code",
4279 | "execution_count": null,
4280 | "source": [
4281 | "batch_size = 64\n",
4282 | "\n",
4283 | "train_ds = ZINC(root='data/ZINC', split='train', subset=True, transform=T.Compose([FullyConnectedTransform(), AddRelativePETransform()]))\n",
4284 | "val_ds = ZINC(root='data/ZINC', split='val', subset=True, transform=T.Compose([FullyConnectedTransform(), AddRelativePETransform()]))\n",
4285 | "test_ds = ZINC(root='data/ZINC', split='test', subset=True, transform=T.Compose([FullyConnectedTransform(), AddRelativePETransform()]))\n",
4286 | "\n",
4287 | "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
4288 | "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
4289 | "test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)"
4290 | ],
4291 | "outputs": [],
4292 | "metadata": {
4293 | "id": "uVPGsXqGyE5v"
4294 | }
4295 | },
4296 | {
4297 | "cell_type": "code",
4298 | "execution_count": null,
4299 | "source": [
4300 | "model_gt_rpe = GTpygRPeModel(HIDDEN_DIM, num_layers=NUM_LAYERS)\n",
4301 | "stats = run_exp_pyg(model_gt_rpe, train_loader, val_loader, test_loader, loss_fct=F.mse_loss,\n",
4302 | " lr=LR, num_epochs=NUM_EPOCHS)"
4303 | ],
4304 | "outputs": [],
4305 | "metadata": {
4306 | "id": "w28zbovgyE5w"
4307 | }
4308 | },
4309 | {
4310 | "cell_type": "code",
4311 | "execution_count": null,
4312 | "source": [
4313 | "plot_stats(stats)"
4314 | ],
4315 | "outputs": [],
4316 | "metadata": {
4317 | "id": "bYrr3-IbQc12"
4318 | }
4319 | },
4320 | {
4321 | "cell_type": "markdown",
4322 | "source": [
4323 | "💡 Positional encoding is not the only way to preserve the original topology when using rewiring techniques. Another possible solution is to alternate between layers of message passing on the original graph and layers of message passing on the rewired graph. Feel free to experiment with this."
4324 | ],
4325 | "metadata": {
4326 | "id": "8Pvu2WXbsTWz"
4327 | }
4328 | },
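{
"cell_type": "markdown",
"source": [
"A minimal sketch of this idea is shown below. It assumes PyTorch Geometric's `GCNConv` as the message-passing layer and reuses the `rewire_edge_index` attribute added by `FullyConnectedTransform`; the class name `AlternatingRewireModel`, the choice of `GCNConv` and the final linear readout are our own hypothetical choices, not a prescribed solution. Since it takes a batch of graphs and returns one scalar per graph, it could be trained with the same `run_exp_pyg` call used above."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.nn import Embedding, Linear, ModuleList\n",
"from torch_geometric.nn import GCNConv, global_add_pool\n",
"\n",
"class AlternatingRewireModel(torch.nn.Module):\n",
"    def __init__(self, hidden_dim, num_layers=4):\n",
"        super().__init__()\n",
"        self.embed_x = Embedding(28, hidden_dim)  # 28 atom types, as in the models above\n",
"        self.layers = ModuleList([GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)])\n",
"        self.readout = Linear(hidden_dim, 1)\n",
"\n",
"    def forward(self, graph):\n",
"        x = self.embed_x(graph.x).squeeze(1)\n",
"        for i, layer in enumerate(self.layers):\n",
"            # even layers: original topology; odd layers: rewired (fully connected) graph\n",
"            edge_index = graph.edge_index if i % 2 == 0 else graph.rewire_edge_index\n",
"            x = F.relu(layer(x, edge_index))\n",
"        out = global_add_pool(x, graph.batch)  # (batch_size, hidden_dim)\n",
"        return self.readout(out).squeeze(-1)   # (batch_size,)"
],
"outputs": [],
"metadata": {}
},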
4329 | {
4330 | "cell_type": "markdown",
4331 | "source": [
4332 | "🏆 Well done! You have reached the end of this tutorial. We hope that you are now comfortable working with graph data and have developed a good intuition about the causes of, and solutions to, some of the hard problems in the graph learning community. If you have any questions or feedback, don't hesitate to contact us!"
4333 | ],
4334 | "metadata": {
4335 | "id": "nIjsnIDKp1-W"
4336 | }
4337 | },
4338 | {
4339 | "cell_type": "markdown",
4340 | "source": [
4341 | "\n",
4342 | "# 🍫 Want to learn more?"
4343 | ],
4344 | "metadata": {
4345 | "id": "kGs9slSTThj8"
4346 | }
4347 | },
4348 | {
4349 | "cell_type": "markdown",
4350 | "source": [
4351 | "If you are interested in finding more about these topics, here is a selection of papers that you might find useful:\n",
4352 | "\n",
4353 | "⚫ [On the Bottleneck of Graph Neural Networks and its Practical Implications](https://arxiv.org/abs/2006.05205) \\\\\n",
4354 | "🟢 [Expander graph propagation](https://arxiv.org/abs/2210.02997) \\\\\n",
4355 | "🟣 [DRew: Dynamically Rewired Message Passing with Delay](https://arxiv.org/abs/2305.08018) \\\\\n",
4356 | "🔵 [On Over-Squashing in Message Passing Neural Networks: The Impact of Width, Depth, and Topology](https://arxiv.org/abs/2302.02941) \\\\\n",
4357 | "🟡 [Understanding over-squashing and bottlenecks on graphs via curvature](https://arxiv.org/abs/2111.14522) \\\\\n",
4358 | "🟠 [Probabilistically Rewired Message-Passing Neural Networks](https://arxiv.org/abs/2310.02156) \\\\\n",
4359 | "🔴 [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875) \\\\\n",
4360 | "🟤 [A Generalization of Transformer Networks to Graphs](https://arxiv.org/abs/2012.09699)"
4361 | ],
4362 | "metadata": {
4363 | "id": "sngoJa0QF16i"
4364 | }
4365 | }
4366 | ],
4367 | "metadata": {
4368 | "accelerator": "GPU",
4369 | "colab": {
4370 | "gpuType": "T4",
4371 | "provenance": [],
4372 | "collapsed_sections": [
4373 | "lKxFuNjpAeVz",
4374 | "VMOleceFNqEZ"
4375 | ]
4376 | },
4377 | "kernelspec": {
4378 | "display_name": "Python 3",
4379 | "name": "python3"
4380 | },
4381 | "language_info": {
4382 | "name": "python"
4383 | }
4384 | },
4385 | "nbformat": 4,
4386 | "nbformat_minor": 2
4387 | }
--------------------------------------------------------------------------------
/4_geometric_deep_learning/README.md:
--------------------------------------------------------------------------------
1 | # [[EEML2024](https://www.eeml.eu)] Tutorial 4: Geometric Deep Learning
2 |
3 | **Authors:** Iulia Duta and Vladimir Mirjanić
4 |
5 | ---
6 |
7 | This tutorial is meant to be a stand-alone introduction to the world of Graph Neural Networks. You will learn the basics of working with graph data, implement a standard Graph Neural Network architecture, and learn about current challenges and open research problems in the field, such as over-squashing, together with techniques such as rewiring and positional encoding.
8 |
9 |
10 | [Introduction video](https://www.youtube.com/watch?v=ZvpxLfDz_mk)
11 |
12 |
13 | ### Outline
14 |
15 | 1. Implement both a sparse and a dense version of **Graph Convolutional Network in PyTorch**.
16 | 2. Write a training pipeline for graph inputs including **graph-level representation** and **custom mini-batching**.
17 | 3. Improve the Graph Convolutional Network using attention mechanisms - **Graph Attention Network**.
18 | 4. Take our first steps into **PyTorch Geometric**, a library dedicated to geometric deep learning.
19 | 5. Re-implement **Graph Attention Network in PyTorch Geometric**.
20 | 6. Understand the **over-squashing challenge** and experiment with **two graph rewiring techniques**: Graph Transformer and Expander Graph Propagation.
21 | 7. Explore various **positional encodings** for graph data.
22 |
23 |
24 | ### Notebooks
25 |
26 | Tutorial: [Open In Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/4_geometric_deep_learning/GDL_tutorial.ipynb)
28 |
29 |
30 | Solution: [Open In Colab](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2024/blob/main/4_geometric_deep_learning/GDL_tutorial_solution.ipynb)
32 |
33 | ---
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PracticalSessions2024
2 |
3 | Repository for tutorial sessions at EEML2024
4 |
5 | Designed for educational purposes. Please do not distribute without permission. Write to contact@eeml.eu if you have any questions.
6 |
7 | You are welcome to reuse this material in other courses or schools, but please reach out to contact@eeml.eu if you plan to do so. We would appreciate it if you could acknowledge that the materials come from EEML2024 and give credits to the authors. Also please keep a link in your materials to the original repo, in case updates occur.
8 |
9 | MIT License
10 |
11 | Copyright (c) 2024 EEML
12 |
13 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18 |
--------------------------------------------------------------------------------