├── ReadME.md ├── [public]_Hypergraph_Lab_(solved).ipynb └── [public]_Hypergraph_Lab_(unsolved).ipynb /ReadME.md: -------------------------------------------------------------------------------- 1 | ## Tutorial on Hypergraph Networks 2 | 3 | 4 | This lab is designed to showcase and familiarise you with the basics of learning on hypergraphs. It will encompass both the theory behind hypergraph networks and putting them into practice. 5 | 6 |
7 | 8 |
9 | 10 | 11 | The lab is split into 4 main parts. 12 | 13 | **Part 1:** 14 | 15 | * Creating the hypergraph object 16 | * Understand how to work with hypergraph-structure data 17 | * Building a basic HNN for node-level prediction 18 | 19 | **Part 2:** 20 | 21 | * Comparing hypergraph methods to graph based methods 22 | 23 | **Part 3:** 24 | 25 | * Creating mini-batches for hypergraphs 26 | * Hypergraphs for graph-level prediction 27 | 28 | **Part 4:** 29 | 30 | * Hypergraph attention networks 31 | 32 | **If you find any mistakes please raise it as an issue or get in touch directly at hypergraph.practical@gmail.com !** 33 | 34 | Likewise, please get in touch with any feedback from the tutorial via email. 35 | 36 | Daniel 37 | -------------------------------------------------------------------------------- /[public]_Hypergraph_Lab_(unsolved).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "31RTznj3UwEX" 17 | }, 18 | "source": [ 19 | "## Exploring Higher Order Processing Using Hypergraph Neural Networks" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "mxWzBVTzB5ed" 26 | }, 27 | "source": [ 28 | "To start this lab, save a copy of this colab file and edit that 🙂." 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "6ekuceu8G-Td" 35 | }, 36 | "source": [ 37 | "This lab is designed to showcase and familiarise you with the basics of learning on hypergraphs. 
It will encompass both the theory behind hypergraph networks and putting them into practice.\n", 38 | "\n", 39 | "The lab is split into 4 main parts.\n", 40 | "\n", 41 | "**Part 1**:\n", 42 | "* Creating the hypergraph object\n", 43 | "* Understanding how to work with hypergraph-structured data\n", 44 | "* Building a basic HNN for node-level prediction\n", 45 | "\n", 46 | "**Part 2**:\n", 47 | "* Comparing hypergraph methods to graph-based methods\n", 48 | "\n", 49 | "**Part 3**:\n", 50 | "* Creating mini-batches for hypergraphs\n", 51 | "* Hypergraphs for graph-level prediction\n", 52 | "\n", 53 | "**Part 4**:\n", 54 | "* Hypergraph attention networks\n", 55 | "\n", 56 | "Hopefully these parts will give natural breaks in the lab, so you can come and go as you need or easily find pieces of code you may need elsewhere.\n", 57 | "\n", 58 | "\n", 59 | "\n", 60 | "\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "id": "IOlIWW6zkLLH" 67 | }, 68 | "source": [ 69 | "# 🤷 **Why hypergraphs?**\n", 70 | "\n", 71 | "Hypergraphs allow us to study relationships that go beyond pairwise interactions, which graph-structured data fails to capture concisely.\n", 72 | "\n", 73 | "Consider a group of 5 researchers: Ann, Beth, Chris, Dave, and Erica. Ann and Chris wrote a paper together: ***Theory is all you need (2022)***. Ann was more interested in the theory presented in the paper while Chris was more interested in the applications. Ann then collaborated with Beth and Dave to develop the theory in the paper ***Everything goes neural: a general recipe for success (2023)***. Chris collaborated with Dave and Erica to develop the applications of the theory: ***Neural networks for forecasting crocodile attacks (2023)***.\n", 74 | "\n", 75 | "The structure of this network can be easily captured using a hypergraph.\n", 76 | "\n", 77 | "
\n", 78 | "\n", 79 | "
\n", 80 | "\n", 81 | "\n", 82 | "\n", 83 | "\n", 84 | "How, then, should we represent this as a graph? The simplest option is to connect every pair of authors who have written a paper together.\n", 85 | "\n", 86 | "
\n", 87 | "\n", 88 | "
\n", 89 | "\n", 90 | "\n", 91 | "However, this fails to capture the whole structure of the network: it is unclear who the authors of a given paper were. For example, if Beth and Dave also wrote a paper together, ***Neural Networks: where did it all go wrong? (2024)***, the graph would remain exactly the same as before.\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "Another way to represent this data with a graph is to introduce a second type of node, one for each paper. Not only does this increase the computational cost of a graph neural network (there are more nodes to process and store in memory), it also produces a heterogeneous graph, in which nodes have different meanings. This departs from the way graphs are usually interpreted, and heterogeneous graphs are harder to process than ordinary ones (some papers are dedicated to heterogeneous graphs, but without great results).\n", 96 | "\n", 97 | "
\n", 98 | "\n", 99 | "
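To make the information loss concrete, here is a small sketch in plain Python. It is not part of the lab code: `clique_expansion`, `papers` and `papers_with_extra` are hypothetical names introduced only for this illustration. Each paper is stored as a set of authors (a hyperedge), every pair of co-authors is connected, and the check shows that adding the Beth and Dave paper leaves the pairwise graph unchanged even though the hypergraph has gained a hyperedge.

```python
from itertools import combinations


def clique_expansion(hyperedges):
    """Connect every pair of nodes that appear together in some hyperedge."""
    edges = set()
    for members in hyperedges:
        # sorted() gives each undirected edge a single canonical orientation
        for u, v in combinations(sorted(members), 2):
            edges.add((u, v))
    return edges


# The three papers from the example, one hyperedge per paper.
papers = [
    {"Ann", "Chris"},
    {"Ann", "Beth", "Dave"},
    {"Chris", "Dave", "Erica"},
]
# The same network after Beth and Dave write a fourth paper together.
papers_with_extra = papers + [{"Beth", "Dave"}]

# The hypergraphs differ (3 vs 4 hyperedges), yet the pairwise graphs coincide.
print(len(papers), len(papers_with_extra))
print(clique_expansion(papers) == clique_expansion(papers_with_extra))  # True
```

The Beth and Dave edge already exists because of the three-author paper, so the pairwise graph cannot tell the two situations apart.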
\n", 100 | "\n", 101 | "\n", 102 | "\n", 103 | "\n", 104 | "\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "id": "W5DtRGF2Z7mJ" 111 | }, 112 | "source": [ 113 | "# **😫 Preliminaries**: Install and import modules\n", 114 | "\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "cNA-6kpE4wL1" 121 | }, 122 | "source": [ 123 | "❗**Note 🤕:** Due to updated package versions, there is a high chance you will receive a warning after running each of the following cells that contain `pip install` commands. For each cell, when the warning appears, please press \"restart the session\" and proceed to the next cell." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "id": "PW-uyKVLaOFd" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "#@title [RUN] install\n", 135 | "!pip install --force-reinstall dhg" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "source": [ 141 | "!pip install hypernetx" 142 | ], 143 | "metadata": { 144 | "id": "uOuozVTkoQjI" 145 | }, 146 | "execution_count": null, 147 | "outputs": [] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "source": [ 152 | "!pip install torch==2.1.0\n", 153 | "import torch\n", 154 | "!pip install --force-reinstall torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html" 155 | ], 156 | "metadata": { 157 | "id": "-tfFipvyoS8v" 158 | }, 159 | "execution_count": null, 160 | "outputs": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "source": [ 165 | "!pip install -U --force-reinstall numpy==1.24.0" 166 | ], 167 | "metadata": { 168 | "id": "CL7BJofgavFE" 169 | }, 170 | "execution_count": null, 171 | "outputs": [] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "cellView": "form", 178 | "id": "dA5QXrU_LWPj" 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "#@title [RUN] Import
modules\n", 183 | "import torch\n", 184 | "import torch.nn as nn\n", 185 | "import torch.nn.functional as F\n", 186 | "import numpy as np\n", 187 | "import dhg\n", 188 | "import pdb\n", 189 | "import torch.optim as optim\n", 190 | "import matplotlib.pyplot as plt\n", 191 | "import networkx as nx\n", 192 | "import hypernetx as hnx\n", 193 | "import random\n", 194 | "import torch_geometric\n", 195 | "from torch_scatter import scatter_mean, scatter_max, scatter_sum\n", 196 | "from datetime import datetime" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "cellView": "form", 204 | "id": "GjrxJ6_B6wiJ" 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "#@title [RUN] Helper functions\n", 209 | "def draw_one_graph(ax, edges, label=None, node_emb=None, layout=None, special_color=False):\n", 210 | " \"\"\"draw a graph with networkx based on adjacency matrix (edges)\n", 211 | " graph labels could be displayed as a title for each graph\n", 212 | " node_emb could be displayed in colors\n", 213 | " \"\"\"\n", 214 | " graph = nx.Graph()\n", 215 | " edges = zip(edges[0], edges[1])\n", 216 | " graph.add_edges_from(edges)\n", 217 | " node_pos = layout(graph)\n", 218 | " #add colors according to node embeding\n", 219 | " if (node_emb is not None) or special_color:\n", 220 | " color_map = []\n", 221 | " node_list = [node[0] for node in graph.nodes(data = True)]\n", 222 | " for i,node in enumerate(node_list):\n", 223 | " #just ignore this branch\n", 224 | " if special_color:\n", 225 | " if len(node_list) == 3:\n", 226 | " crt_color = (1,0,0)\n", 227 | " elif len(node_list) == 5:\n", 228 | " crt_color = (0,1,0)\n", 229 | " elif len(node_list) == 4:\n", 230 | " crt_color = (1,1,0)\n", 231 | " else:\n", 232 | " special_list = [(1,0,0)] * 3 + [(0,1,0)] * 5 + [(1,1,0)] * 4\n", 233 | " crt_color = special_list[i]\n", 234 | " else:\n", 235 | " crt_node_emb = node_emb[node]\n", 236 | " #map float number (node embeding) to a 
color\n", 237 | " crt_color = plt.cm.gist_rainbow(crt_node_emb, bytes=True)\n", 238 | " crt_color = (crt_color[0]/255.0, crt_color[1]/255.0, crt_color[2]/255.0, crt_color[3]/255.0)\n", 239 | " color_map.append(crt_color)\n", 240 | "\n", 241 | " nx.draw_networkx_nodes(graph,node_pos, node_color=color_map,\n", 242 | " nodelist = node_list, ax=ax)\n", 243 | " nx.draw_networkx_edges(graph, node_pos, ax=ax)\n", 244 | " nx.draw_networkx_labels(graph,node_pos, ax=ax)\n", 245 | " else:\n", 246 | " nx.draw_networkx(graph, node_pos, ax=ax)\n", 247 | "\n", 248 | "def gallery(graphs, labels=None, node_emb=None, special_color=False, max_graphs=4, max_fig_size=(40, 10), layout=nx.layout.kamada_kawai_layout):\n", 249 | " ''' Draw multiple graphs as a gallery\n", 250 | " Args:\n", 251 | " graphs: torch_geometric dataset object / list of Graph objects\n", 252 | " labels: num_graphs\n", 253 | " node_emb: num_graphs* [num_nodes x num_ch]\n", 254 | " max_graphs: maximum number of graphs to display\n", 255 | " '''\n", 256 | " num_graphs = min(len(graphs), max_graphs)\n", 257 | " ff, axes = plt.subplots(1, num_graphs,\n", 258 | " figsize=max_fig_size,\n", 259 | " subplot_kw={'xticks': [], 'yticks': []})\n", 260 | " if num_graphs == 1:\n", 261 | " axes = [axes]\n", 262 | " if node_emb is None:\n", 263 | " node_emb = num_graphs*[None]\n", 264 | " if labels is None:\n", 265 | " labels = num_graphs * [\" \"]\n", 266 | "\n", 267 | "\n", 268 | " for i in range(num_graphs):\n", 269 | " draw_one_graph(axes[i], graphs[i].edge_index.numpy(), labels[i], node_emb[i], layout, special_color)\n", 270 | " if labels[i] != \" \":\n", 271 | " axes[i].set_title(f\"Target: {labels[i]}\", fontsize=28)\n", 272 | " axes[i].set_axis_off()\n", 273 | " plt.show()\n", 274 | "\n", 275 | "def update_stats(training_stats, epoch_stats):\n", 276 | " \"\"\" Store metrics along the training\n", 277 | " Args:\n", 278 | " epoch_stats: dict containing metrics about one epoch\n", 279 | " training_stats: dict containing lists of metrics along
training\n", 280 | " Returns:\n", 281 | " updated training_stats\n", 282 | " \"\"\"\n", 283 | " if training_stats is None:\n", 284 | " training_stats = {}\n", 285 | " for key in epoch_stats.keys():\n", 286 | " training_stats[key] = []\n", 287 | " for key,val in epoch_stats.items():\n", 288 | " training_stats[key].append(val)\n", 289 | " return training_stats\n", 290 | "# epoch_stats = {'train_acc': ep_train_ac, 'val_acc', ep_val_acc, 'train_loss':}\n", 291 | "# training_stats = None\n", 292 | "\n", 293 | "def plot_stats(training_stats, figsize=(5, 5), name=\"\"):\n", 294 | " \"\"\" Create one plot for each metric stored in training_stats\n", 295 | " \"\"\"\n", 296 | " stats_names = [key[6:] for key in training_stats.keys() if key.startswith('train_')]\n", 297 | " f, ax = plt.subplots(len(stats_names), 1, figsize=figsize)\n", 298 | " if len(stats_names)==1:\n", 299 | " ax = np.array([ax])\n", 300 | " for key, axx in zip(stats_names, ax.reshape(-1,)):\n", 301 | " axx.plot(\n", 302 | " training_stats['epoch'],\n", 303 | " training_stats[f'train_{key}'],\n", 304 | " label=f\"Training {key}\")\n", 305 | " axx.plot(\n", 306 | " training_stats['epoch'],\n", 307 | " training_stats[f'val_{key}'],\n", 308 | " label=f\"Validation {key}\")\n", 309 | " axx.set_xlabel(\"Training epoch\")\n", 310 | " axx.set_ylabel(key)\n", 311 | " axx.legend()\n", 312 | " plt.title(name)\n", 313 | "\n", 314 | "def get_list_of_edges(hypergraph):\n", 315 | " incidence_matrix = hypergraph.incidence_matrix()\n", 316 | " list_of_edges = []\n", 317 | " for i in range(incidence_matrix.shape[-1]):\n", 318 | " current_list = []\n", 319 | " for j in range(incidence_matrix.shape[0]):\n", 320 | " if incidence_matrix[j,i] != 0:\n", 321 | " current_list.append(j)\n", 322 | " list_of_edges.append(current_list)\n", 323 | " return list_of_edges\n", 324 | "\n", 325 | "\n", 326 | "def visualise(hypergraph):\n", 327 | " list_of_edges = get_list_of_edges(hypergraph)\n", 328 | " list_of_nodes = 
list(range(hypergraph.num_nodes))\n", 329 | " empty_dic = {}\n", 330 | " i=0\n", 331 | " for current_list in list_of_edges:\n", 332 | " empty_dic[str(i)] = current_list\n", 333 | " i+=1\n", 334 | " for current_node in list_of_nodes:\n", 335 | " empty_dic[str(i)] = [current_node]\n", 336 | " i+=1\n", 337 | " final_graph = hnx.Hypergraph(empty_dic)\n", 338 | " hnx.draw(final_graph)\n", 339 | " plt.title(f'Value = {hypergraph.y}')\n", 340 | " plt.show()\n", 341 | "\n", 342 | "def hnxHyperGraph(hypergraph):\n", 343 | " list_of_edges = get_list_of_edges(hypergraph)\n", 344 | " list_of_nodes = list(range(hypergraph.num_nodes))\n", 345 | " empty_dic = {}\n", 346 | " i=0\n", 347 | " for current_list in list_of_edges:\n", 348 | " empty_dic[str(i)] = current_list\n", 349 | " i+=1\n", 350 | " for current_node in list_of_nodes:\n", 351 | " empty_dic[str(i)] = [current_node]\n", 352 | " i+=1\n", 353 | " final_graph = hnx.Hypergraph(empty_dic)\n", 354 | " return final_graph\n", 355 | "\n", 356 | "def incidence_to_edgeindex(incidence_matrix):\n", 357 | " incidence_matrix = incidence_matrix.to_sparse()\n", 358 | " values = incidence_matrix.values()\n", 359 | " indices = incidence_matrix.indices()\n", 360 | "\n", 361 | " edge_list = []\n", 362 | " vertex_list = []\n", 363 | " for t in range(int(values.shape[0])):\n", 364 | " for j in range(int(values[t])):\n", 365 | " edge_list.append(indices[1,t])\n", 366 | " vertex_list.append(indices[0,t])\n", 367 | " edge_index = torch.tensor([vertex_list, edge_list])\n", 368 | " return edge_index\n", 369 | "\n", 370 | "\n", 371 | "##graph object\n", 372 | "\n", 373 | "class Graph(object):\n", 374 | " def __init__(self, edge_index, x, y, weighted = None):\n", 375 | " \"\"\" Graph structure\n", 376 | " for a mini-batch it will store a big (sparse) graph\n", 377 | " representing the entire batch\n", 378 | " Args:\n", 379 | " x: node features [num_nodes x num_feats]\n", 380 | " y: graph labels [num_graphs]\n", 381 | " edge_index: list of edges [2 
x num_edges]\n", 382 | " \"\"\"\n", 383 | " self.edge_index = edge_index\n", 384 | " self.x = x.to(torch.float32)\n", 385 | " self.y = y\n", 386 | " self.num_nodes = self.x.shape[0]\n", 387 | " if weighted is None:\n", 388 | " self.values = torch.ones(self.edge_index.shape[1])\n", 389 | " else:\n", 390 | " self.values = weighted\n", 391 | "\n", 392 | " #ignore this for now, it will be useful for batching\n", 393 | " def set_batch(self, batch):\n", 394 | " \"\"\" list of ints that maps each node to the graph it belongs to\n", 395 | " e.g. for batch = [0,0,0,1,1,1,1]: the first 3 nodes belong to graph_0 while\n", 396 | " the last 4 belong to graph_1\n", 397 | " \"\"\"\n", 398 | " self.batch = batch\n", 399 | "\n", 400 | " # this function returns a sparse tensor\n", 401 | " def get_adjacency_matrix(self):\n", 402 | " \"\"\" from the list of edges create\n", 403 | " a num_nodes x num_nodes sparse adjacency matrix\n", 404 | " \"\"\"\n", 405 | " return torch.sparse.LongTensor(self.edge_index,\n", 406 | " # we work with a binary adj containing 1 if an edge exist\n", 407 | " self.values,\n", 408 | " torch.Size((self.num_nodes, self.num_nodes))\n", 409 | " ).to_dense()\n", 410 | "\n", 411 | "\n" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "DVQmvJn0Ofdk" 418 | }, 419 | "source": [ 420 | "# 🐤 **Part 1: The basics of learning on hypergraphs** " 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "-8n9Y4wVP7SL" 427 | }, 428 | "source": [ 429 | "# 👶 **What is a hypergraph?**\n", 430 | "\n", 431 | "A **hypergraph** can be viewed as a generalisation of a graph. An edge of a graph is a pair ($2$-tuple) of vertices in that graph. 
A hyperedge generalises this to an unordered $k$-tuple of vertices, i.e. a set of $k$ vertices.\n", 432 | "\n", 433 | "More formally, a **hypergraph** is an ordered pair $\\mathcal{H} = (\\mathcal{V}, \\mathcal{E})$ where $\\mathcal{V}$ is a set of vertices/nodes and $\\mathcal{E}$ is a set of unordered $k$-tuples of elements of $\\mathcal{V}$ for $k\\in \\{1, 2, \\dots, |\\mathcal{V} |\\}$, known as **hyperedges**.\n", 434 | "\n", 435 | "**Example**: Consider the following hypergraph, which will become our running example.\n", 436 | "\n", 437 | "
\n", 438 | "\n", 439 | "
\n", 440 | "\n", 441 | "\n", 442 | "We have that $\\mathcal{V} = \\{v_1, v_2, v_3,v_4,v_5\\}$ and $\\mathcal{E} = \\{ (v_1, v_2), (v_1, v_3, v_4), (v_2, v_4, v_5)\\}$.\n", 443 | "\n", 444 | "\n", 445 | "\n", 446 | "We can easily represent a hypergraph using matrices. For a hypergraph with $n$ vertices and $m$ hyperedges, $H \\in \\mathbb{R}^{n \\times m}$ is known as the **incidence matrix**. It is defined as\n", 447 | "$$ h_{i,j} = \\left\\{\n", 448 | "\\begin{array}{ll}\n", 449 | " 1 & \\text{if }v_i\\text{ is in }e_j \\\\\n", 450 | " 0 & \\text{otherwise}\n", 451 | "\\end{array}\n", 452 | "\\right. $$\n", 453 | "\n", 454 | "Occasionally, one might wish to weight this incidence matrix. This can be done by replacing the value of 1 with a weight; more will be said about this later.\n", 455 | "\n", 456 | "So for our running example, the corresponding incidence matrix is $$ H =\\begin{pmatrix}\n", 457 | "1 & 1 & 0 \\\\\n", 458 | "1 & 0 & 1 \\\\\n", 459 | "0 & 1 & 0 \\\\\n", 460 | "0 & 1 & 1 \\\\\n", 461 | "0 & 0 & 1\n", 462 | "\\end{pmatrix}$$\n", 463 | "\n", 464 | "Unlike in graphs, where only nodes have a degree, in a hypergraph we have a notion of degree for an edge as well. The degree of a node in a hypergraph is the number of hyperedges that contain that node. The **node degree matrix** $D_v \\in \\mathbb{R}^{n\\times n}$ is the diagonal matrix whose entries are the row sums of the incidence matrix, $(D_v)_{ii} = \\sum_j h_{i,j}$. Similarly, the degree of a hyperedge is the number of vertices contained in that edge. The **hyperedge degree matrix** $D_e \\in \\mathbb{R}^{m \\times m}$ is the diagonal matrix whose entries are the column sums of the incidence matrix, $(D_e)_{jj} = \\sum_i h_{i,j}$. For our running example, $D_v = \\operatorname{diag}(2, 2, 1, 2, 1)$ and $D_e = \\operatorname{diag}(2, 3, 3)$.\n", 465 | "\n", 466 | "\n", 467 | "For hypergraphs with node features, each node $v_i \\in \\mathcal{V}$ has an associated $d$-dimensional feature vector $\\mathbf{x_i} \\in \\mathbb{R}^{d}$.
Then the **feature matrix** $\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$ can be used to represent the feature vectors for every node in the hypergraph. Can you think of how a hyperedge feature matrix could be defined?" 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "id": "QWLMlZMYaBsK" 474 | }, 475 | "source": [ 476 | "# 🎒 Learning on hypergraph-structured data\n", 477 | "\n", 478 | "Machine learning tasks on hypergraph-structured data can be categorised according to the dataset and the task being studied. Most generally we are interested in the following tasks:\n", 479 | "\n", 480 | "- **Node/vertex-level prediction:** a data observation is a node within a hypergraph, and we are interested in node-level classification/regression. An example is the Cora dataset, where we are interested in categorising papers that are nodes of a larger citation network.\n", 481 | "- **Hyperedge-level prediction:** we are interested in predicting hyperedges between samples in the dataset, or properties of a hyperedge.\n", 482 | "- **Hypergraph-level prediction:** a data observation is a hypergraph, i.e. our dataset consists of several hypergraphs and we are interested in hypergraph-level classification/regression, i.e. a single output for the entire hypergraph.\n", 483 | "\n", 484 | "We start with a node-level prediction task in Part 1, while Part 3 tackles hypergraph-level prediction.\n", 485 | "\n", 486 | "
\n", 487 | "\n", 488 | "
\n" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": { 494 | "id": "27QLUsf5w_Gl" 495 | }, 496 | "source": [ 497 | "# 🤖 **Creating a hypergraph object**\n", 498 | "\n", 499 | "To store data about the hypergraph we are going to create the following `HyperGraph` class.\n", 500 | "\n", 501 | "\n", 502 | "The `HyperGraph` class stores the hypergraph according to its 'hyper edge index', which is a tensor of size `[2, k]` where $$ k = \\sum_{e \\in \\mathcal{E}}\\sum_{v \\in e} 1 ,$$ that is, $k$ is the number of non-zero entries of the incidence matrix.\n", 503 | "\n", 504 | "Each column $[v_r, e_j]^T$ of this tensor records that $v_r \\in e_j$.\n", 505 | "\n", 506 | "**Example**: Going back to our running example, and indexing nodes and hyperedges from 0, we would represent our hypergraph as\n", 507 | "\n", 508 | "$$\\begin{bmatrix}\n", 509 | "0 & 1 & 0 & 2 & 3 & 1 & 3 & 4 \\\\\n", 510 | "0 & 0 & 1 & 1 & 1 & 2 & 2 & 2\n", 511 | "\\end{bmatrix} $$\n", 512 | "\n", 513 | "The `HyperGraph` class has a method `incidence_matrix()` which returns the incidence matrix when it is required."
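As a sanity check on the running example, the following sketch turns the hyperedge index above into a dense incidence matrix and reads off the node and hyperedge degrees. It uses NumPy rather than the `HyperGraph` class, and `to_incidence` is a hypothetical helper written for this illustration, not part of the lab.

```python
import numpy as np

# Hyperedge index of the running example (0-based):
# row 0 holds node ids, row 1 holds the hyperedge each node belongs to.
hyper_edge_index = np.array([
    [0, 1, 0, 2, 3, 1, 3, 4],
    [0, 0, 1, 1, 1, 2, 2, 2],
])


def to_incidence(hyper_edge_index, num_nodes):
    """Build a dense binary incidence matrix of shape [num_nodes, num_hyper_edges]."""
    num_hyper_edges = int(hyper_edge_index[1].max()) + 1
    H = np.zeros((num_nodes, num_hyper_edges), dtype=int)
    # Set h[i, j] = 1 for every (node i, hyperedge j) column of the index.
    H[hyper_edge_index[0], hyper_edge_index[1]] = 1
    return H


H = to_incidence(hyper_edge_index, num_nodes=5)
node_degrees = H.sum(axis=1)       # row sums: hyperedges per node
hyperedge_degrees = H.sum(axis=0)  # column sums: nodes per hyperedge

print(H)
print(node_degrees)       # [2 2 1 2 1]
print(hyperedge_degrees)  # [2 3 3]
```

Note that the assignment above collapses repeated (node, hyperedge) pairs to a single 1; a conversion that sums duplicates would instead produce entries larger than 1.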
514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": { 520 | "id": "vfyoW5hcV85u" 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "class HyperGraph(object):\n", 525 | " def __init__(self, hyper_edge_index, x, y):\n", 526 | " \"\"\" HyperGraph structure\n", 527 | " for a mini-batch it will store a big hypergraph\n", 528 | " representing the entire batch\n", 529 | " Args:\n", 530 | " x: node features [num_nodes x num_feats]\n", 531 | " y: hypergraph labels [num_hyper_graphs]\n", 532 | " hyper_edge_index: list of hyperedges [2, k] where k is defined in\n", 533 | " the above cell.\n", 534 | " \"\"\"\n", 535 | " self.hyper_edge_index = hyper_edge_index\n", 536 | " self.x = x.to(torch.float32)\n", 537 | " self.y = y\n", 538 | " self.num_nodes = self.x.shape[0] # number of nodes\n", 539 | " self.num_hyper_edges = self.hyper_edge_index[1,:].max().item()+1 # number of hyperedges (ids are 0-based)\n", 540 | "\n", 541 | " #this will be useful later, but we can ignore it for now\n", 542 | " def set_batch(self, batch):\n", 543 | " self.batch = batch\n", 544 | "\n", 545 | "\n", 546 | " #returns the dense incidence matrix H [num_nodes, num_hyper_edges]\n", 547 | " def incidence_matrix(self):\n", 548 | " return torch.sparse_coo_tensor(self.hyper_edge_index,\n", 549 | " # we work with a binary incidence matrix containing 1 where a node belongs to a hyperedge\n", 550 | " torch.ones((self.hyper_edge_index.shape[1])),\n", 551 | " torch.Size((self.num_nodes, self.num_hyper_edges))\n", 552 | " ).to_dense()\n", 553 | "\n" 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "metadata": { 559 | "id": "B_ayNwU0vDFf" 560 | }, 561 | "source": [ 562 | "As an aside, we are also able to draw hypergraphs in Python. It took us some time to find a good drawing package, but we finally settled on HyperNetX, which is quite nice, with the minor annoyance that every vertex needs to be in at least one hyperedge. Documentation can be found [here](https://pnnl.github.io/HyperNetX/index.html).
Let us visualise a hypergraph now and look at its incidence matrix." 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": { 569 | "id": "5krdgQLf9Nv0" 570 | }, 571 | "outputs": [], 572 | "source": [ 573 | "num_vertices = 6\n", 574 | "num_edges = 3\n", 575 | "num_features = 16\n", 576 | "n=8 # number of (node, hyperedge) memberships to sample; repeated pairs give entries > 1\n", 577 | "x = torch.rand((num_vertices,3))\n", 578 | "y = 1\n", 579 | "hyper_edge_index = torch.concat((torch.randint(0, num_vertices, (1,n)), torch.randint(0, num_edges, (1,n))), dim = 0)\n", 580 | "random_hypergraph = HyperGraph(hyper_edge_index, x, y)\n", 581 | "\n", 582 | "visualise(random_hypergraph)\n", 583 | "print(f'The incidence matrix of this hypergraph is \\n {random_hypergraph.incidence_matrix()}')" 584 | ] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "metadata": { 589 | "id": "CwxEULih5KV2" 590 | }, 591 | "source": [ 592 | "Let us now look at the corresponding `hyper_edge_index`. Is it as you would expect?" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": null, 598 | "metadata": { 599 | "id": "3dbkXEEU-but" 600 | }, 601 | "outputs": [], 602 | "source": [ 603 | "random_hypergraph.hyper_edge_index" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": { 609 | "id": "37ZWIdiacCpW" 610 | }, 611 | "source": [ 612 | "# 📖 **CoAuthorshipCora Dataset**\n", 613 | "\n", 614 | "For the first part of this lab, we will be working with the CoAuthorshipCora dataset. CoAuthorshipCora is a co-authorship network used for node classification.\n", 615 | "\n", 616 | "As with the Cora dataset for graphs, CoAuthorshipCora contains 2708 papers and our goal is to classify them into 7 categories. Each paper is characterised by a 1433-dimensional bag-of-words vector.
All papers written by the same author form a hyperedge.\n", 617 | "\n", 618 | "We will be working with CoAuthorshipCora (the name CoCora will be used interchangeably) through a `CoCoraHyperDataset` class which, currently, has methods to turn the data into a hypergraph." 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": { 625 | "id": "h_fmluDSbiWC" 626 | }, 627 | "outputs": [], 628 | "source": [ 629 | "dataset = dhg.data.CoauthorshipCora(data_root='')\n", 630 | "\n", 631 | "class CoCoraHyperDataset(object):\n", 632 | " def __init__(self, Dataset):\n", 633 | " \"\"\"\n", 634 | " Dataset is a dhg object storing CoCora data. We will convert it to a\n", 635 | " hypergraph, a graph, and a weighted graph, the latter two using clique and\n", 636 | " weighted clique expansion respectively (more on this later).\n", 637 | " \"\"\"\n", 638 | " self.train_mask = Dataset['train_mask']\n", 639 | " self.val_mask = Dataset['val_mask']\n", 640 | " self.test_mask = Dataset['test_mask']\n", 641 | " self.features = Dataset['features']\n", 642 | " self.num_vertices = Dataset['features'].shape[0]\n", 643 | " self.labels = Dataset['labels']\n", 644 | " self.edge_list = Dataset['edge_list']\n", 645 | "\n", 646 | " # creating the hypergraph from the data\n", 647 | " self.edge_index = self.edge_list_2_index(self.edge_list)\n", 648 | " self.hyper_graph = HyperGraph(self.edge_index, self.features, self.labels)\n", 649 | "\n", 650 | "\n", 651 | " # creating a graph from the data using clique expansion\n", 652 | " self.graph_edge_index = self.edge_list_2_graph_index(self.edge_list)\n", 653 | " self.graph = Graph(self.graph_edge_index, self.features, self.labels)\n", 654 | "\n", 655 | " # creating a graph using weighted clique expansion\n", 656 | " self.wgraph_edge_index, self.wgraph_weights = self.edge_list_2_weighted_edge_index(self.edge_list, self.features.shape[0])\n", 657 | " self.wgraph = Graph(self.wgraph_edge_index, self.features, self.labels, self.wgraph_weights)\n", 658 |
"\n", 659 | " def edge_list_2_index(self, edge_list):\n", 660 | " \"\"\"\n", 661 | " Args\n", 662 | " edge_list : a list of lists each representing nodes forming hyperedge\n", 663 | " returns :\n", 664 | " hyper_edge_index : tensor containing indices of incidence matrix [2, k]\n", 665 | " where k is number of non-zero values in incidence matrix\n", 666 | " \"\"\"\n", 667 | " index = 0\n", 668 | " hyper_vertices_list = []\n", 669 | " hyper_edges_list = []\n", 670 | " for i in edge_list:\n", 671 | " for j in range(len(i)):\n", 672 | " hyper_edges_list+=[index]\n", 673 | " hyper_vertices_list+=i\n", 674 | " index+=1\n", 675 | " hyper_edge_index = torch.tensor([hyper_vertices_list, hyper_edges_list])\n", 676 | " return hyper_edge_index\n", 677 | "\n", 678 | " def edge_list_2_graph_index(self, edge_list):\n", 679 | " \"\"\"\n", 680 | " function that creates adjacency matrix associated to clique expansion of\n", 681 | " hypergraph\n", 682 | " Args\n", 683 | " edge_list : a list of lists each representing nodes forming hyperedge\n", 684 | " returns :\n", 685 | " edge_index: tensor containing indices of adjacency matrix [2, k] where k\n", 686 | " is number of non-zero values in adjacency matrix\n", 687 | " \"\"\"\n", 688 | " source_list = []\n", 689 | " dest_list = []\n", 690 | " for current_edges in edge_list:\n", 691 | " # for every hyperedge draw an edge between every distinct pair of nodes\n", 692 | " n = len(current_edges)\n", 693 | " for i in range(n):\n", 694 | " for j in range(n):\n", 695 | " if i==j:\n", 696 | " continue\n", 697 | " source_list.append(current_edges[i])\n", 698 | " dest_list.append(current_edges[j])\n", 699 | " edge_index = torch.tensor([source_list, dest_list])\n", 700 | " return edge_index\n", 701 | "\n", 702 | " def edge_list_2_weighted_edge_index(self, edge_list, num_vertices):\n", 703 | " \"\"\"\n", 704 | " function that creates adjacency matrix associated to weighted clique\n", 705 | " expansion of hypergraph\n", 706 | " Args:\n", 707 | " 
edge_list : a list of lists each representing nodes forming hyperedge\n", 708 | " returns :\n", 709 | " index: tensor containing indices of adjacency matrix [2, k] where k is\n", 710 | " number of non-zero values in adjacency matrix\n", 711 | " vals: weight associated to each edge\n", 712 | " \"\"\"\n", 713 | " adjacency_matrix = torch.zeros((num_vertices, num_vertices))\n", 714 | " for current_edges in edge_list:\n", 715 | " # for every hyperedge draw an edge between every pair of nodes\n", 716 | " # for each pair of nodes the weight corresponds to number hyperedges they\n", 717 | " # have in common\n", 718 | " for i in current_edges:\n", 719 | " for j in current_edges:\n", 720 | " # this can be done with HH^T and subtracting degree matrix\n", 721 | " adjacency_matrix[i,j] += 1\n", 722 | " adjacency_matrix = adjacency_matrix.to_sparse()\n", 723 | " index = adjacency_matrix.indices()\n", 724 | " vals = adjacency_matrix.values()\n", 725 | "\n", 726 | " return index,vals\n", 727 | "\n", 728 | "\n", 729 | "Data = CoCoraHyperDataset(dataset)" 730 | ] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": { 735 | "id": "fJfjrJgE2vH8" 736 | }, 737 | "source": [ 738 | "# 🚀 **Our first HGNN**\n", 739 | "\n", 740 | "\n", 741 | "\n", 742 | "We construct a layer of our [hypergraph neural network](https://arxiv.org/abs/1809.09401) by first creating a representation for each hyperedge. This is typically done by summing or averaging the node representations for all nodes contained in the hyperedge. 
Then, for each node we typically sum or average the representations of all hyperedges that contain this node, before combining that with our original node representation.\n", 743 | "\n", 744 | "Mathematically, with sum aggregation, this can be written as $$HH^TXW,$$ where $W$ is our learnable linear projection.\n", 745 | "\n", 746 | "❓**Can you convince yourself of this?**\n", 747 | "\n", 748 | "\n", 749 | "In practice we normalise the aggregation by the degrees to keep the values from exploding, so a layer of HGNN is $$ D_v^{-\\frac{1}{2}} H D_e^{-1} H^T D_v^{-\\frac{1}{2}}XW.$$\n", 750 | "\n", 751 | "Once again, $D_v$ is the node degree matrix and $D_e$ is the hyperedge degree matrix.\n", 752 | "\n", 753 | "
\n", 754 | "\n", 755 | "
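One way to convince yourself of the unnormalised formula is to compare the matrix product against explicit loops. The sketch below is an illustration only: it uses NumPy, the incidence matrix of the running example, one-hot node features chosen for readability, and no learnable $W$ and no degree normalisation. The names `edge_repr`, `node_repr` and `manual` are assumptions made for this example.

```python
import numpy as np

# Incidence matrix of the running example: 5 nodes (rows) x 3 hyperedges (columns).
H = np.array([
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
    [0, 1, 1],
    [0, 0, 1],
], dtype=float)
X = np.eye(5)  # one-hot node features, so row i shows which nodes reach node i

# Step 1: each hyperedge sums the features of the nodes it contains.
edge_repr = H.T @ X        # shape [num_hyper_edges, d]
# Step 2: each node sums the representations of the hyperedges that contain it.
node_repr = H @ edge_repr  # shape [num_nodes, d], equal to H @ H.T @ X

# The same two steps written as explicit loops over nodes and hyperedges.
manual = np.zeros_like(X)
for node in range(H.shape[0]):
    for edge in range(H.shape[1]):
        if H[node, edge]:
            members = np.flatnonzero(H[:, edge])
            manual[node] += X[members].sum(axis=0)

print(np.allclose(manual, node_repr))  # True
print(node_repr)
```

With one-hot features, entry (i, j) of the result counts the hyperedges shared by nodes i and j, and the diagonal holds the node degrees, which is one way to see why the normalised layer rescales by the degrees.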
\n", 756 | "\n" 757 | ] 758 | }, 759 | { 760 | "cell_type": "markdown", 761 | "source": [ 762 | "# ❓**Task 1**: Implement the hypergraph neural network described above" 763 | ], 764 | "metadata": { 765 | "id": "fyTm_81QqV_8" 766 | } 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": null, 771 | "metadata": { 772 | "id": "SprBcItvYD8D" 773 | }, 774 | "outputs": [], 775 | "source": [ 776 | "class HyperNNLayer(nn.Module):\n", 777 | " def __init__(self, input_dim, output_dim):\n", 778 | " \"\"\"\n", 779 | " One layer of hypergraph neural network\n", 780 | "\n", 781 | " Args:\n", 782 | " input_dim : number of features of each node in hypergraph\n", 783 | " output_dim : number of output features\n", 784 | " \"\"\"\n", 785 | " super(HyperNNLayer, self).__init__()\n", 786 | " self.input_dim = input_dim\n", 787 | " self.output_dim = output_dim\n", 788 | " self.Linear = nn.Linear(input_dim,output_dim)\n", 789 | "\n", 790 | " def forward(self, x,H):\n", 791 | " \"\"\"\n", 792 | " Args:\n", 793 | " x : feature matrix [num_nodes, input_dim]\n", 794 | " H : incidence matrix [num_nodes, num_hyper_edges]\n", 795 | " returns:\n", 796 | " x : output of one layer of hypergraph neural network [num_nodes, output_dim]\n", 797 | " \"\"\"\n", 798 | " # (D_v)^-0.5 from the row sums of H; inf entries (off-diagonal zeros, isolated nodes) are set to 0\n", 799 | " degree_of_nodes = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=-1)), -0.5), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 800 | " # (D_e)^-1 from the column sums of H; inf entries are set to 0 in the same way\n", 801 | " degree_of_edges = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=0)), -1.0), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 802 | "\n", 803 | " # ============ YOUR CODE HERE =============\n", 804 | " # Compute D_v^{-0.5} H D_e^{-1} H^T D_v^{-0.5} x\n", 805 | " #\n", 806 | " # x = ...\n", 807 | " # =========================================\n", 808 | "\n", 809 | " #apply linear layer\n", 810 | " x = self.Linear(x)\n", 811 | " return x\n", 812 | "\n", 813 | "class
HyperNN(nn.Module):\n", 814 | " def __init__(self, input_dim, output_dim, hidden_dim, num_layers):\n", 815 | " \"\"\"\n", 816 | " Hypergraph neural network containing num_layers HyperNNLayer\n", 817 | "\n", 818 | " Args:\n", 819 | " input_dim : number of features of each node in hypergraph\n", 820 | " output_dim : number of output features\n", 821 | " hidden_dim : hidden dimension\n", 822 | " num_layers : number of layers\n", 823 | " \"\"\"\n", 824 | " super(HyperNN, self).__init__()\n", 825 | " self.input_dim = input_dim\n", 826 | " self.output_dim = output_dim\n", 827 | " self.hidden_dim=hidden_dim\n", 828 | "\n", 829 | " if num_layers > 1:\n", 830 | " self.hnn_layers = [HyperNNLayer(input_dim, hidden_dim)]\n", 831 | " self.hnn_layers+= [HyperNNLayer(hidden_dim, hidden_dim) for i in range(num_layers-2)]\n", 832 | " self.hnn_layers+= [HyperNNLayer(hidden_dim, output_dim)]\n", 833 | " else:\n", 834 | " self.hnn_layers = [HyperNNLayer(input_dim, output_dim)]\n", 835 | "\n", 836 | " self.hnn_layers = nn.ModuleList(self.hnn_layers)\n", 837 | " self.num_layers = num_layers\n", 838 | "\n", 839 | " def forward(self,hgraph):\n", 840 | " \"\"\"\n", 841 | " Args:\n", 842 | " hgraph : input hypergraph stored as HyperGraph class\n", 843 | " returns:\n", 844 | " y_hat : logits for each node [num_nodes, output_dim]\n", 845 | " \"\"\"\n", 846 | " H = hgraph.incidence_matrix()\n", 847 | " x = hgraph.x.to(torch.float32)\n", 848 | "\n", 849 | " # ============ YOUR CODE HERE =============\n", 850 | " # Apply (self.num_layers) HyperNNLayer(s), with ReLU nonlinearity\n", 851 | " #\n", 852 | " # ...\n", 853 | " # y_hat = ...\n", 854 | " # =========================================\n", 855 | " return y_hat" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": { 862 | "cellView": "form", 863 | "id": "3jfYpPFjzZYl" 864 | }, 865 | "outputs": [], 866 | "source": [ 867 | "# @title ✅ [RUN] **Please run this unit test to validate your code. 
You might still have bugs but this is a good sanity check.**\n", 868 | "def testing_hnn():\n", 869 | " torch.random.manual_seed(0)\n", 870 | " np.random.seed(0)\n", 871 | "\n", 872 | " input_dim = 64\n", 873 | " output_dim = 128\n", 874 | " #A = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1], [1, 1, 1, 0]])\n", 875 | " H = random_hypergraph.incidence_matrix()\n", 876 | " model = HyperNNLayer(input_dim, output_dim)\n", 877 | "\n", 878 | " x = torch.rand(H.shape[0], input_dim)\n", 879 | " out = model(x, H)\n", 880 | "\n", 881 | " assert(out.shape == (H.shape[0], output_dim)), \"Oups! 🤭 Output shape is wrong\"\n", 882 | "\n", 883 | " np.random.seed(0)\n", 884 | " perm_x = torch.tensor(np.random.permutation(x.numpy()))\n", 885 | " np.random.seed(0)\n", 886 | " perm_out = torch.tensor(np.random.permutation(out.detach().numpy()))\n", 887 | "\n", 888 | " np.random.seed(0)\n", 889 | " A_perm = np.random.permutation(H.detach().numpy().transpose()).transpose()\n", 890 | " np.random.seed(0)\n", 891 | " A_perm = torch.tensor(np.random.permutation(A_perm))\n", 892 | "\n", 893 | " torch.random.manual_seed(0)\n", 894 | " model_perm = HyperNNLayer(input_dim, output_dim)\n", 895 | "\n", 896 | " out_model_perm = model_perm(perm_x, A_perm)\n", 897 | "\n", 898 | " assert (torch.allclose(perm_out, out_model_perm, atol=1e-6)), \"🤔 Something is wrong in the model! You are not permutation equivariant anymore 🥺\"\n", 899 | " print(\"All good!\")\n", 900 | "\n", 901 | "testing_hnn()\n", 902 | "np.random.seed(None)\n", 903 | "torch.random.manual_seed(datetime.now().timestamp())" 904 | ] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": { 909 | "id": "rNS6ZwSIzzm3" 910 | }, 911 | "source": [ 912 | "Now that you have written your hypergraph neural network, it is time to train it." 
913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": null, 918 | "metadata": { 919 | "id": "sea01kWnuum9" 920 | }, 921 | "outputs": [], 922 | "source": [ 923 | "NUM_EPOCHS = 50 #@param {type:\"integer\"}\n", 924 | "LR = 0.001 #@param {type:\"number\"}\n", 925 | "num_runs = 3\n", 926 | "\n", 927 | "\n" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": null, 933 | "metadata": { 934 | "id": "9zBTcR6jYGcT" 935 | }, 936 | "outputs": [], 937 | "source": [ 938 | "def quick_accuracy(y_hat, y):\n", 939 | " \"\"\"\n", 940 | " Args :\n", 941 | " y_hat : logits predicted by model [n, num_classes]\n", 942 | " y : ground truth labels [n]\n", 943 | " returns :\n", 944 | " average accuracy\n", 945 | " \"\"\"\n", 946 | " n = y.shape[0]\n", 947 | " y_hat = torch.argmax(y_hat, dim=-1)\n", 948 | " accuracy = (y_hat==y).sum().data.item()\n", 949 | " return accuracy/n\n", 950 | "\n" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "metadata": { 957 | "id": "pOMspsxhpQlb" 958 | }, 959 | "outputs": [], 960 | "source": [ 961 | "def trainCoCora(hypergraph, model, mask, optimiser):\n", 962 | " model.train()\n", 963 | " y = hypergraph.y[mask]\n", 964 | " optimiser.zero_grad()\n", 965 | " y_hat = model(hypergraph)[mask] # only keep predictions for the nodes whose labels we know\n", 966 | " loss = F.cross_entropy(y_hat, y)\n", 967 | " loss.backward()\n", 968 | " optimiser.step()\n", 969 | " return loss.data\n", 970 | "\n", 971 | "def evalCoCora(hypergraph, model, mask):\n", 972 | " model.eval()\n", 973 | " y = hypergraph.y[mask]\n", 974 | " y_hat = model(hypergraph)[mask]\n", 975 | " accuracy = quick_accuracy(y_hat, y)\n", 976 | " return accuracy\n", 977 | "\n", 978 | "def train_eval_loop_CoCora(model, hypergraph, train_mask,\n", 979 | " valid_mask, test_mask):\n", 980 | " optimiser = optim.Adam(model.parameters(), lr=LR)\n", 981 | " training_stats = None\n", 982 | " # Training loop\n", 983 | " for epoch 
in range(NUM_EPOCHS):\n", 984 | " train_loss = trainCoCora(hypergraph,model, train_mask, optimiser)\n", 985 | " train_acc = evalCoCora(hypergraph, model,train_mask)\n", 986 | " valid_acc = evalCoCora(hypergraph, model, valid_mask)\n", 987 | " if epoch % 10 == 0:\n", 988 | " print(f\"Epoch {epoch} with train loss: {train_loss:.3f} train accuracy: {train_acc:.3f} validation accuracy: {valid_acc:.3f}\")\n", 989 | " # store the loss and the accuracy for the final plot\n", 990 | " epoch_stats = {'train_acc': train_acc, 'val_acc': valid_acc, 'epoch':epoch}\n", 991 | " training_stats = update_stats(training_stats, epoch_stats)\n", 992 | " # Lets look at our final test performance\n", 993 | " test_acc = evalCoCora(hypergraph, model, test_mask)\n", 994 | " print(f\"Our final test accuracy for this model is: {test_acc:.3f}\")\n", 995 | " return training_stats\n" 996 | ] 997 | }, 998 | { 999 | "cell_type": "markdown", 1000 | "metadata": { 1001 | "id": "Z6MqjGHV1DaV" 1002 | }, 1003 | "source": [ 1004 | "🍵 **Tea break**: This will take a while to train, grab a cup of tea." 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "execution_count": null, 1010 | "metadata": { 1011 | "id": "EUasHKNA4dAx" 1012 | }, 1013 | "outputs": [], 1014 | "source": [ 1015 | "HyperNNModel = HyperNN(1433, 7, 128, 2)\n", 1016 | "# Data.hyper_graph stores our hypergraph structure\n", 1017 | "hyperNNModelOut = train_eval_loop_CoCora(HyperNNModel, Data.hyper_graph,\n", 1018 | " Data.train_mask, Data.val_mask,\n", 1019 | " Data.test_mask)\n", 1020 | "plot_stats(hyperNNModelOut)" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "markdown", 1025 | "metadata": { 1026 | "id": "o2LbqKjwEnY9" 1027 | }, 1028 | "source": [ 1029 | "# ↩ **Part 2**: From hypergraphs back to graphs\n", 1030 | "\n", 1031 | "We want to compare our new hypergraph model to the standard graph processing. 
To do so, we first need to convert a hypergraph to a graph.\n", 1032 | "\n", 1033 | "We will present two ways of doing it here, known as **clique expansion** and **weighted clique expansion**.\n", 1034 | "\n", 1035 | "In clique expansion, we turn each hyperedge into a fully connected graph.\n", 1036 | "\n", 1037 | "In weighted clique expansion, we also turn each hyperedge into a fully connected graph. Unlike plain clique expansion, however, the weight on the edge between vertex $i$ and vertex $j$ is the number of hyperedges they have in common. This reduces the loss of information. Let $A$ be the adjacency matrix of the graph created by weighted clique expansion. Then notice that $$ A = HH^T - D_v.$$ Therefore, HGNN is equivalent to [GCN](https://arxiv.org/abs/1609.02907) if we transform the hypergraph to a graph using weighted clique expansion, provided we add weighted self-loops.\n", 1038 | "\n", 1039 | "❓Convince yourself that the above statement is indeed true.\n", 1040 | "
\n", 1041 | "\n", 1042 | "
\n", 1043 | "\n" 1044 | ] 1045 | }, 1046 | { 1047 | "cell_type": "markdown", 1048 | "metadata": { 1049 | "id": "gxt-OjGK2I2B" 1050 | }, 1051 | "source": [ 1052 | "To store information about each graph, we create the following `Graph` class. Note that, instead of storing an entire adjacency matrix to describe the graph structure, we will store it more efficiently as a list of edges of shape `[2, num_edges]`, where for each edge we store the indices of the source and destination node." 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "code", 1057 | "execution_count": null, 1058 | "metadata": { 1059 | "id": "ohJNBXkmkU0F" 1060 | }, 1061 | "outputs": [], 1062 | "source": [ 1063 | "##graph object\n", 1064 | "\n", 1065 | "class Graph(object):\n", 1066 | " def __init__(self, edge_index, x, y, weighted = None):\n", 1067 | " \"\"\" Graph structure\n", 1068 | " for a mini-batch it will store a big (sparse) graph\n", 1069 | " representing the entire batch\n", 1070 | " Args:\n", 1071 | " x: node features [num_nodes x num_feats]\n", 1072 | " y: graph labels [num_graphs]\n", 1073 | " edge_index: list of edges [2 x num_edges]\n", 1074 | " \"\"\"\n", 1075 | " self.edge_index = edge_index\n", 1076 | " self.x = x.to(torch.float32)\n", 1077 | " self.y = y\n", 1078 | " self.num_nodes = self.x.shape[0]\n", 1079 | " if weighted is None:\n", 1080 | " self.values = torch.ones(self.edge_index.shape[1])\n", 1081 | " else:\n", 1082 | " self.values = weighted\n", 1083 | "\n", 1084 | " #ignore this for now, it will be useful for batching\n", 1085 | " def set_batch(self, batch):\n", 1086 | " \"\"\" list of ints that maps each node to the graph it belongs to\n", 1087 | " e.g. 
for batch = [0,0,0,1,1,1,1]: the first 3 nodes belong to graph_0 while\n", 1088 | " the last 4 belong to graph_1\n", 1089 | " \"\"\"\n", 1090 | " self.batch = batch\n", 1091 | "\n", 1092 | " # this function returns a sparse tensor\n", 1093 | " def get_adjacency_matrix(self):\n", 1094 | " \"\"\" from the list of edges create\n", 1095 | " a num_nodes x num_nodes sparse adjacency matrix\n", 1096 | " \"\"\"\n", 1097 | " return torch.sparse.LongTensor(self.edge_index,\n", 1098 | " # we work with a binary adj containing 1 if an edge exist\n", 1099 | " self.values,\n", 1100 | " torch.Size((self.num_nodes, self.num_nodes))\n", 1101 | " ).to_dense()" 1102 | ] 1103 | }, 1104 | { 1105 | "cell_type": "code", 1106 | "execution_count": null, 1107 | "metadata": { 1108 | "id": "oiTUsWRlmo_l" 1109 | }, 1110 | "outputs": [], 1111 | "source": [ 1112 | "class GCNLayer(nn.Module):\n", 1113 | " def __init__(self,input_dim, output_dim):\n", 1114 | " super(GCNLayer, self).__init__()\n", 1115 | " self.input_dim =input_dim\n", 1116 | " self.output_dim = output_dim\n", 1117 | " \"\"\"GCN layer to be implemented by students of practical\n", 1118 | "\n", 1119 | " Args:\n", 1120 | " input_dim (int): Dimensionality of the input feature vectors\n", 1121 | " output_dim (int): Dimensionality of the output softmax distribution\n", 1122 | " A (torch.Tensor): 2-D adjacency matrix\n", 1123 | " \"\"\"\n", 1124 | "\n", 1125 | "\n", 1126 | " self.linear = nn.Linear(input_dim, output_dim)\n", 1127 | "\n", 1128 | " def forward(self, x, A):\n", 1129 | "\n", 1130 | " D = torch.nan_to_num(torch.pow(torch.diag(torch.sum(A, dim=0)), -0.5), nan=0, posinf=0, neginf=0)\n", 1131 | " self.adj_norm = D@A@D\n", 1132 | " x = self.adj_norm@x\n", 1133 | " x=self.linear(x)\n", 1134 | " return x\n", 1135 | "\n", 1136 | "class GNN(nn.Module):\n", 1137 | " def __init__(self, input_dim, hidden_dim, output_dim, num_layers):\n", 1138 | " \"\"\"\n", 1139 | " Graph convolutional network containing num_layers GCNLayer\n", 1140 | 
"\n", 1141 | " Args:\n", 1142 | " input_dim : number of features of each node in graph\n", 1143 | " output_dim : number of output features\n", 1144 | " hidden_dim : hidden dimension\n", 1145 | " num_layers : number of layers\n", 1146 | " \"\"\"\n", 1147 | " super(GNN, self).__init__()\n", 1148 | " self.input_dim = input_dim\n", 1149 | " self.output_dim = output_dim\n", 1150 | " if num_layers > 1:\n", 1151 | " self.gcn_layers = [GCNLayer(input_dim, hidden_dim)]\n", 1152 | " self.gcn_layers += [GCNLayer(hidden_dim, hidden_dim) for i in range(num_layers-2)]\n", 1153 | " self.gcn_layers += [GCNLayer(hidden_dim, output_dim)]\n", 1154 | " else:\n", 1155 | " self.gcn_layers = [GCNLayer(input_dim, output_dim)]\n", 1156 | "\n", 1157 | " self.gcn_layers = nn.ModuleList(self.gcn_layers)\n", 1158 | " self.num_gcn_layers = num_layers\n", 1159 | "\n", 1160 | " def forward(self, graph):\n", 1161 | " \"\"\"\n", 1162 | " Args:\n", 1163 | " graph : input graph stored as Graph class\n", 1164 | " returns:\n", 1165 | " y_hat : logits for each node [num_nodes, output_dim]\n", 1166 | " \"\"\"\n", 1167 | " x = graph.x.to(torch.float32)\n", 1168 | " A = graph.get_adjacency_matrix()\n", 1169 | " for j in range(self.num_gcn_layers-1):\n", 1170 | " x = self.gcn_layers[j](x, A)\n", 1171 | " x = F.relu(x)\n", 1172 | " x = self.gcn_layers[-1](x,A)\n", 1173 | " y_hat = x\n", 1174 | " return y_hat" 1175 | ] 1176 | }, 1177 | { 1178 | "cell_type": "markdown", 1179 | "source": [ 1180 | "# ❓**Task 2**: Instantiate the model" 1181 | ], 1182 | "metadata": { 1183 | "id": "ayVfaI2Qqsum" 1184 | } 1185 | }, 1186 | { 1187 | "cell_type": "markdown", 1188 | "metadata": { 1189 | "id": "6urVwmal1asz" 1190 | }, 1191 | "source": [ 1192 | "🍵 **Tea break**: This will take a while to train, grab a cup of tea." 
1193 | ] 1194 | }, 1195 | { 1196 | "cell_type": "code", 1197 | "execution_count": null, 1198 | "metadata": { 1199 | "id": "kfeZGLCHvaEd" 1200 | }, 1201 | "outputs": [], 1202 | "source": [ 1203 | "\n", 1204 | "graph_model = GNN(1433, 128, 7, 2)\n", 1205 | "\n", 1206 | "# ============ YOUR CODE HERE =============\n", 1207 | "# apply the train-eval pipeline for the graph structure\n", 1208 | "# Data.graph stores clique expansion associated to hypergraph\n", 1209 | "#\n", 1210 | "# graph_model_out = train_eval_loop_CoCora(...)\n", 1211 | "# =========================================\n", 1212 | "\n", 1213 | "plot_stats(graph_model_out)" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "markdown", 1218 | "source": [ 1219 | "# ❓ **Task 3**: Instantiate the model" 1220 | ], 1221 | "metadata": { 1222 | "id": "ifbUCV3sq2d3" 1223 | } 1224 | }, 1225 | { 1226 | "cell_type": "markdown", 1227 | "metadata": { 1228 | "id": "--qqQe851cBd" 1229 | }, 1230 | "source": [ 1231 | "🍵 **Tea break**: This will take a while to train, grab a cup of tea." 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "code", 1236 | "execution_count": null, 1237 | "metadata": { 1238 | "id": "jiVKeKGHYQRg" 1239 | }, 1240 | "outputs": [], 1241 | "source": [ 1242 | "w_graph_model = GNN(1433, 128, 7, 2)\n", 1243 | "\n", 1244 | "# ============ YOUR CODE HERE =============\n", 1245 | "# apply the train-eval pipeline for the graph structure\n", 1246 | "# Data.wgraph stores weighted clique expansion associated to hypergraph\n", 1247 | "#\n", 1248 | "# graph_model_out = train_eval_loop_CoCora(...)\n", 1249 | "# =========================================\n", 1250 | "\n", 1251 | "plot_stats(graph_model_out)" 1252 | ] 1253 | }, 1254 | { 1255 | "cell_type": "markdown", 1256 | "metadata": { 1257 | "id": "0HjXsf621dww" 1258 | }, 1259 | "source": [ 1260 | "❓**Task**: Can you think of any other ways to convert a hypergraph to a graph?" 
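Before looking for other conversions, it can help to check on a toy example what the weighted clique expansion from the start of Part 2 actually computes. The sketch below is pure Python on a small made-up incidence matrix (it does not use the lab's `HyperGraph` or `Graph` classes, and the variable names are illustrative only). It builds $A = HH^T - D_v$ and shows that each off-diagonal entry counts the hyperedges two nodes share.

```python
# Toy incidence matrix H: 4 nodes (rows) x 2 hyperedges (columns).
# Hyperedge 0 = {0, 1, 2}, hyperedge 1 = {1, 2, 3}.
H = [
    [1, 0],
    [1, 1],
    [1, 1],
    [0, 1],
]
n, m = len(H), len(H[0])

# (H H^T)[i][j] counts the hyperedges that contain both node i and node j.
HHt = [[sum(H[i][e] * H[j][e] for e in range(m)) for j in range(n)]
       for i in range(n)]

# Node degrees: number of hyperedges each node belongs to (diagonal of D_v).
deg = [sum(row) for row in H]

# A = H H^T - D_v removes the diagonal, leaving only edges between distinct nodes.
A = [[HHt[i][j] - (deg[i] if i == j else 0) for j in range(n)]
     for i in range(n)]

for row in A:
    print(row)
# [0, 1, 1, 0]
# [1, 0, 2, 1]
# [1, 2, 0, 1]
# [0, 1, 1, 0]
```

Nodes 1 and 2 appear together in both hyperedges, so their edge gets weight 2, whereas plain clique expansion would only record that an edge exists.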
1261 | ] 1262 | }, 1263 | { 1264 | "cell_type": "markdown", 1265 | "source": [ 1266 | "What do your results tell you? The final accuracy depends on the hyperparameters you selected: sometimes the graph models will do better, and other times the hypergraph model will. The [Hypergraph Neural Networks](https://arxiv.org/abs/1809.09401) paper reports that the hypergraph approach performs better in its experiments." 1267 | ], 1268 | "metadata": { 1269 | "id": "WYnGtrwKFsM6" 1270 | } 1271 | }, 1272 | { 1273 | "cell_type": "markdown", 1274 | "metadata": { 1275 | "id": "DEQqjFOGkLJL" 1276 | }, 1277 | "source": [ 1278 | "# 🧬 **Part 3**: Hypergraph-level prediction and batching\n", 1279 | "\n", 1280 | "**Our Dataset**\n", 1281 | "\n", 1282 | "In this part, we are going to show that hypergraph models can also be a powerful tool for hypergraph-level prediction.\n", 1283 | "\n", 1284 | "Unfortunately, as we are writing this, we are not aware of a benchmarking dataset for hypergraph-level prediction. To address this, we are going to design and build our own dataset.\n", 1285 | "\n", 1286 | "The dataset `Hypergraph_Dataset` is a collection of 1000 hypergraphs. Each node of a hypergraph is characterised by 6 features, which have been generated randomly. It is our model's task to predict whether the hypergraph is connected or not. We say that a hypergraph $\\mathcal{H} = (\\mathcal{V}, \\mathcal{E})$ is **connected** if for any distinct $v_0, v_n \\in \\mathcal{V}$, there is a sequence $v_0, v_1,\\dots, v_{n-1}, v_n$ such that, for every $i$, $v_{i}$ and $v_{i+1}$ belong to the same hyperedge.\n", 1287 | "\n", 1288 | "\n", 1289 | "Each hypergraph has between 5 and 10 nodes and between 1 and 9 hyperedges (the upper bound of `torch.randint` is exclusive). To determine whether the hypergraph is connected or not, we are using the `is_connected()` function from the `hypernetx` package, an extremely useful package for all things hypergraphs. 
Documentation can be found [here](https://pnnl.github.io/HyperNetX/index.html).\n", 1290 | "\n", 1291 | "\n" 1292 | ] 1293 | }, 1294 | { 1295 | "cell_type": "code", 1296 | "execution_count": null, 1297 | "metadata": { 1298 | "id": "LELL7Lt8IUsx", 1299 | "cellView": "form" 1300 | }, 1301 | "outputs": [], 1302 | "source": [ 1303 | "#@title [RUN] Generate dataset\n", 1304 | "torch.random.manual_seed(0)\n", 1305 | "\n", 1306 | "Hypergraph_Dataset = []\n", 1307 | "number_connected = 0\n", 1308 | "number_not_connected = 0\n", 1309 | "number_of_hgraphs = 1000\n", 1310 | "for j in range(number_of_hgraphs):\n", 1311 | " num_vertices = torch.randint(5, 11, (1,1)).item() # select number of nodes uniformly from 5 to 10 (upper bound 11 is exclusive)\n", 1312 | " num_edges = torch.randint(1,10, (1,1)).item() # select number of hyper edges uniformly from 1 to 9 (upper bound 10 is exclusive)\n", 1313 | " nnz = torch.randint(num_edges, num_vertices*num_edges, (1,1)).item() # number of nnz in incidence matrix\n", 1314 | "\n", 1315 | " X = torch.rand((num_vertices, 6)) # randomly generate features\n", 1316 | " y = torch.rand(1) # dummy label\n", 1317 | "\n", 1318 | " #randomly generate incidence matrix\n", 1319 | " hyper_edge_index = torch.concat((torch.randint(0, num_vertices, (1,nnz)),\n", 1320 | " torch.randint(0, num_edges, (1,nnz))), dim = 0)\n", 1321 | "\n", 1322 | " #create hypergraph as hnxHyperGraph to check connectivity\n", 1323 | " hypergraph = hnxHyperGraph(HyperGraph(hyper_edge_index, X, y))\n", 1324 | "\n", 1325 | " #check if hypergraph is connected or not\n", 1326 | " if hypergraph.is_connected() == True:\n", 1327 | " number_connected +=1\n", 1328 | " # assign label of 1 if connected\n", 1329 | " Hypergraph_Dataset.append(HyperGraph(hyper_edge_index, X, torch.tensor(1)))\n", 1330 | " else:\n", 1331 | " number_not_connected +=1\n", 1332 | " # assign label of 0 if not connected\n", 1333 | " Hypergraph_Dataset.append(HyperGraph(hyper_edge_index, X, torch.tensor(0)))\n", 1334 | "\n", 1335 | "\n", 1336 | "print(f'Number of 
connected hypergraphs in our dataset = {number_connected} out of {number_of_hgraphs} hypergraphs')\n" 1337 | ] 1338 | }, 1339 | { 1340 | "cell_type": "markdown", 1341 | "metadata": { 1342 | "id": "0MADSDtRW8h0" 1343 | }, 1344 | "source": [ 1345 | "Let us visualise the 11th hypergraph in our dataset together with its label." 1346 | ] 1347 | }, 1348 | { 1349 | "cell_type": "code", 1350 | "execution_count": null, 1351 | "metadata": { 1352 | "id": "bo7CVGckFSCm" 1353 | }, 1354 | "outputs": [], 1355 | "source": [ 1356 | "visualise(Hypergraph_Dataset[10])\n", 1357 | "print(Hypergraph_Dataset[10].incidence_matrix().shape)" 1358 | ] 1359 | }, 1360 | { 1361 | "cell_type": "code", 1362 | "execution_count": null, 1363 | "metadata": { 1364 | "id": "OAUT0cRsuhtp" 1365 | }, 1366 | "outputs": [], 1367 | "source": [ 1368 | "# Hypergraph_Dataset\n", 1369 | "\n", 1370 | "## We now need to split our dataset into training, validation and testing subsets\n", 1371 | "\n", 1372 | "train_data = Hypergraph_Dataset[0:700]\n", 1373 | "validation_data = Hypergraph_Dataset[700:850]\n", 1374 | "test_data = Hypergraph_Dataset[850:]" 1375 | ] 1376 | }, 1377 | { 1378 | "cell_type": "markdown", 1379 | "metadata": { 1380 | "id": "PTH6xfvCs3Mn" 1381 | }, 1382 | "source": [ 1383 | "⛵ **Batching**\n", 1384 | "\n", 1385 | "As in any other deep learning field, when doing hypergraph-level prediction it is good practice to process the hypergraphs in batches. But how do we build these batches?\n", 1386 | "\n", 1387 | "One solution for this is to create a single *sparse* hypergraph as the union of all the hypergraphs in the batch as follows:\n", 1388 | "\n", 1389 | "1. stack the features $x$ for all the nodes in all the hypergraphs\n", 1390 | "2. stack the labels $y$ of all the hypergraphs (one label per hypergraph)\n", 1391 | "3. 
stack all the incidence matrices $H_i$ as diagonal blocks in the new incidence matrix\n", 1392 | "\n", 1393 | "This way, we will obtain a new hypergraph containing $\\sum_{i=1}^{B}|V_i|$ nodes, where $B$ is the batch_size and by $|V_i|$ we denote the number of nodes in hypergraph $i$. Note that since **no** hyperedges connect nodes from different hypergraphs, the information propagation will not be affected by the way we store it.\n", 1394 | "\n", 1395 | "
\n", 1396 | "\n", 1397 | "
\n", 1398 | "\n", 1399 | "\n" 1400 | ] 1401 | }, 1402 | { 1403 | "cell_type": "markdown", 1404 | "source": [ 1405 | "# ❓**Task 4**: Complete the `create_mini_batch` function" 1406 | ], 1407 | "metadata": { 1408 | "id": "2xmy_ry9q4-K" 1409 | } 1410 | }, 1411 | { 1412 | "cell_type": "code", 1413 | "execution_count": null, 1414 | "metadata": { 1415 | "id": "SQYi9MF9UW0w" 1416 | }, 1417 | "outputs": [], 1418 | "source": [ 1419 | "def create_mini_batch(hgraph_list) -> HyperGraph:\n", 1420 | " \"\"\" Build a sparse hypergraph from a batch of hypergraphs\n", 1421 | " Args:\n", 1422 | " hgraph_list: list of HyperGraph objects in a batch\n", 1423 | " Returns:\n", 1424 | " a big (sparse) HyperGraph representing the entire batch\n", 1425 | " \"\"\"\n", 1426 | " #insert first graph into the structure\n", 1427 | " batch_edge_index = hgraph_list[0].hyper_edge_index\n", 1428 | "\n", 1429 | " # ============ YOUR CODE HERE =============\n", 1430 | " # Insert the features and the labels for the first graph\n", 1431 | " #\n", 1432 | " # batch_x = ...\n", 1433 | " # batch_y = ...\n", 1434 | " #\n", 1435 | " # You might need additional variables here\n", 1436 | " # ...\n", 1437 | " # =========================================\n", 1438 | "\n", 1439 | " # ============ YOUR CODE HERE =============\n", 1440 | " # Insert the incidence matrix corresponding to the first graph\n", 1441 | " # Don't forget that the incidence matrix is stored as a list of indices!\n", 1442 | " #\n", 1443 | " # batch_edge_index = ...\n", 1444 | " # =========================================\n", 1445 | "\n", 1446 | " # the batch indexes are computed for you\n", 1447 | " batch_batch = torch.zeros((hgraph_list[0].num_nodes), dtype=torch.int64)\n", 1448 | "\n", 1449 | " # append the rest of the graphs to the structure\n", 1450 | " for idx, graph in enumerate(hgraph_list[1:]):\n", 1451 | "\n", 1452 | " # ============ YOUR CODE HERE =============\n", 1453 | " # Append the features and the labels for the current graph\n", 1454 | " #\n", 1455 | " # batch_x = 
...\n", 1456 | " # batch_y = ...\n", 1457 | " # =========================================\n", 1458 | "\n", 1459 | " # ============ YOUR CODE HERE =============\n", 1460 | " # Concat the incidence matrix as a block matrix\n", 1461 | " # Keep in mind that we only store the indexes for the incidence matrix!\n", 1462 | " #\n", 1463 | " # ...\n", 1464 | " # batch_edge_index = ...\n", 1465 | " # =========================================\n", 1466 | "\n", 1467 | " # create the array of indexes mapping nodes in the batch-graph\n", 1468 | " # to the graph they belong to\n", 1469 | " # specify the mapping between the new nodes and the graph they belong to (idx+1)\n", 1470 | " batch_batch = torch.concat([batch_batch, (idx+1)*torch.ones([graph.num_nodes]).to(torch.int64)])\n", 1471 | "\n", 1472 | " pass\n", 1473 | "\n", 1474 | " # [SANITY CHECK]\n", 1475 | " # If everything is correct, we expect the following shapes\n", 1476 | " # for a batch of k hypergraphs with N=n1+..+nk nodes and K incidence entries in total\n", 1477 | " #\n", 1478 | " # batch_x -> [N, num_features]\n", 1479 | " # batch_y -> [k] (one label per hypergraph)\n", 1480 | " # batch_edge_index -> [2, K]\n", 1481 | "\n", 1482 | " # create the big sparse hypergraph\n", 1483 | " batch_graph = HyperGraph(batch_edge_index, batch_x, batch_y)\n", 1484 | " # attach the index array to the HyperGraph structure\n", 1485 | " batch_graph.set_batch(batch_batch)\n", 1486 | " return batch_graph" 1487 | ] 1488 | }, 1489 | { 1490 | "cell_type": "markdown", 1491 | "metadata": { 1492 | "id": "LAJIBSQcvHln" 1493 | }, 1494 | "source": [ 1495 | "Let's visualise a few of the graphs turned into a larger batch of graphs." 
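For reference, the block-diagonal layout described in the batching steps can be pictured on small dense matrices first. The sketch below is pure Python with a hypothetical helper `block_diagonal`; it is only an illustration of the layout and deliberately works on dense matrices, whereas the lab stores the incidence matrix as a list of indices.

```python
def block_diagonal(blocks):
    """Place each matrix in `blocks` on the diagonal of a larger zero matrix."""
    total_rows = sum(len(b) for b in blocks)
    total_cols = sum(len(b[0]) for b in blocks)
    out = [[0] * total_cols for _ in range(total_rows)]
    row_offset = col_offset = 0
    for b in blocks:
        for i, row in enumerate(b):
            for j, value in enumerate(row):
                out[row_offset + i][col_offset + j] = value
        # the next block starts after the rows and columns of this one
        row_offset += len(b)
        col_offset += len(b[0])
    return out

H1 = [[1, 0], [1, 1]]   # hypergraph 0: 2 nodes, 2 hyperedges
H2 = [[1], [1], [1]]    # hypergraph 1: 3 nodes, 1 hyperedge
H_batch = block_diagonal([H1, H2])

# one entry per node, naming the hypergraph that node came from
batch = [0] * len(H1) + [1] * len(H2)

for row in H_batch:
    print(row)
# [1, 0, 0]
# [1, 1, 0]
# [0, 0, 1]
# [0, 0, 1]
# [0, 0, 1]
```

No column of `H_batch` touches rows from both hypergraphs, which is why message passing on the batch behaves the same as on each hypergraph separately.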
1496 | ] 1497 | }, 1498 | { 1499 | "cell_type": "code", 1500 | "execution_count": null, 1501 | "metadata": { 1502 | "id": "ihKRhmAFVnAB" 1503 | }, 1504 | "outputs": [], 1505 | "source": [ 1506 | "smaller_list = Hypergraph_Dataset[0:3]\n", 1507 | "print(smaller_list[0].incidence_matrix().shape)\n", 1508 | "print(smaller_list[1].incidence_matrix().shape)\n", 1509 | "print(smaller_list[2].incidence_matrix().shape)\n", 1510 | "\n", 1511 | "batch_graph = create_mini_batch(smaller_list)\n", 1512 | "visualise(batch_graph)\n", 1513 | "print(batch_graph.incidence_matrix().shape)\n", 1514 | "\n", 1515 | "\n" 1516 | ] 1517 | }, 1518 | { 1519 | "cell_type": "markdown", 1520 | "metadata": { 1521 | "id": "U_y20A0BvQX8" 1522 | }, 1523 | "source": [ 1524 | "A simple way of aggregating information from node-level representation to obtain hypergraph-level predictions is by (max/mean/sum) pooling. This can be efficiently obtained using the [`torch_scatter`](https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/mean.html) library containing operations such as `scatter_mean`, `scatter_max`, `scatter_sum`.\n", 1525 | "\n", 1526 | " `scatter_*` receives as input a tensor and an array of indices and pools the information in the tensor stored at the indices specified in the array.\n", 1527 | "\n", 1528 | " In order to do this, we need to know for every node, what hypergraph it belongs to. This is stored in `batch_batch` variable of our `create_mini_batch` function." 
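To make the pooling semantics concrete, here is a minimal pure-Python sketch of what a sum-pooling scatter computes. It is an illustration only: the helper name `scatter_sum_reference` is hypothetical, and the code does not call the `torch_scatter` library or mirror its exact signature.

```python
def scatter_sum_reference(values, index, num_groups):
    """Sum the rows of `values` that share the same entry in `index`."""
    dim = len(values[0])
    out = [[0.0] * dim for _ in range(num_groups)]
    for row, group in zip(values, index):
        for k in range(dim):
            out[group][k] += row[k]
    return out

# five nodes with two features each
node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
# the first three nodes belong to hypergraph 0, the last two to hypergraph 1
batch = [0, 0, 0, 1, 1]

pooled = scatter_sum_reference(node_features, batch, num_groups=2)
print(pooled)  # [[9.0, 12.0], [16.0, 18.0]]
```

The result has one row per hypergraph, which is the `[batch_size, output_dim]` shape needed for hypergraph-level prediction. Mean or max pooling follows the same pattern with a different reduction.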
1529 | ] 1530 | }, 1531 | { 1532 | "cell_type": "markdown", 1533 | "metadata": { 1534 | "id": "PDAKzl3Ew1th" 1535 | }, 1536 | "source": [ 1537 | "Visualisation for `scatter_sum(array, index)`:\n", 1538 | "\n", 1539 | "\\\\\n", 1540 | "\n", 1541 | "\n", 1542 | "\n", 1543 | "" 1544 | ] 1545 | }, 1546 | { 1547 | "cell_type": "markdown", 1548 | "source": [ 1549 | "# ❓ **Task 5**: Modify the hypergraph neural network to do hypergraph-level prediction" 1550 | ], 1551 | "metadata": { 1552 | "id": "D1x1KeijrEYw" 1553 | } 1554 | }, 1555 | { 1556 | "cell_type": "code", 1557 | "execution_count": null, 1558 | "metadata": { 1559 | "id": "ReTdjkZXxtUl" 1560 | }, 1561 | "outputs": [], 1562 | "source": [ 1563 | "class GHyperNNLayer(nn.Module):\n", 1564 | " def __init__(self, input_dim, output_dim):\n", 1565 | " \"\"\"\n", 1566 | " One layer of hypergraph neural network\n", 1567 | "\n", 1568 | " Args:\n", 1569 | " input_dim : number of features of each node in hypergraph\n", 1570 | " output_dim : number of output features\n", 1571 | " \"\"\"\n", 1572 | " super(GHyperNNLayer, self).__init__()\n", 1573 | " self.input_dim = input_dim\n", 1574 | " self.output_dim = output_dim\n", 1575 | " self.Linear = nn.Linear(input_dim,output_dim)\n", 1576 | "\n", 1577 | " def forward(self, x,H):\n", 1578 | " \"\"\"\n", 1579 | " Args:\n", 1580 | " x : feature matrix [num_nodes, input_dim]\n", 1581 | " H : incidence matrix [num_nodes, num_hyper_edges]\n", 1582 | " returns:\n", 1583 | " x : output of one layer of hypergraph neural network [num_nodes, output_dim]\n", 1584 | " \"\"\"\n", 1585 | "\n", 1586 | " # compute degree of nodes (D_v)^-0.5\n", 1587 | " degree_of_nodes = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=-1)), -0.5), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 1588 | " # compute degree of hyper edges (D_e)^-1\n", 1589 | " degree_of_edges = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=0)), -1.0), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 1590 | "\n", 1591 | 
" # ============ YOUR CODE HERE =============\n", 1592 | " # Compute D_v^{-0.5} H D_e^{-1} H^T D_v^{-0.5} x\n", 1593 | " #\n", 1594 | " # x = ...\n", 1595 | " # =========================================\n", 1596 | "\n", 1597 | " # apply a linear layer\n", 1598 | " x = self.Linear(x)\n", 1599 | " return x\n", 1600 | "\n", 1601 | "class GHyperNN(nn.Module):\n", 1602 | " def __init__(self, input_dim, output_dim, hidden_dim, num_layers):\n", 1603 | " \"\"\"\n", 1604 | " Hypergraph neural network containing num_layers GHyperNNLayer for hypergraph\n", 1605 | " level prediction\n", 1606 | "\n", 1607 | " Args:\n", 1608 | " input_dim : number of features of each node in hyergraph\n", 1609 | " output_dim : number of output features\n", 1610 | " hidden_dim : hidden dimension\n", 1611 | " num_layers : number of layers\n", 1612 | " \"\"\"\n", 1613 | " super(GHyperNN, self).__init__()\n", 1614 | " self.input_dim = input_dim\n", 1615 | " self.output_dim = output_dim\n", 1616 | " self.hidden_dim=hidden_dim\n", 1617 | "\n", 1618 | " if num_layers > 1:\n", 1619 | " self.hnn_layers = [GHyperNNLayer(input_dim, hidden_dim)]\n", 1620 | " self.hnn_layers+= [GHyperNNLayer(hidden_dim, hidden_dim) for i in range(num_layers-2)]\n", 1621 | " self.hnn_layers+= [GHyperNNLayer(hidden_dim, output_dim)]\n", 1622 | " else:\n", 1623 | " self.hnn_layers = [GHyperNNLayer(input_dim, output_dim)]\n", 1624 | "\n", 1625 | " self.hnn_layers = nn.ModuleList(self.hnn_layers)\n", 1626 | " self.num_layers = num_layers\n", 1627 | "\n", 1628 | " def forward(self,hgraph):\n", 1629 | " \"\"\"\n", 1630 | " Args:\n", 1631 | " hgraph : input hypergraph stored as HyperGraph class formed as a batch\n", 1632 | " returns:\n", 1633 | " y_hat : logits for each hypergraph in batch [batch_size, output_dim]\n", 1634 | " \"\"\"\n", 1635 | "\n", 1636 | " H = hgraph.incidence_matrix()\n", 1637 | " x = hgraph.x.to(torch.float32)\n", 1638 | "\n", 1639 | " # ============ YOUR CODE HERE =============\n", 1640 | " # Apply 
(self.num_layers) HyperNNLayer(s), with ReLU nonlinearity\n", 1641 | " #\n", 1642 | " # ...\n", 1643 | " # x = ...\n", 1644 | " # =========================================\n", 1645 | "\n", 1646 | " # ============ YOUR CODE HERE =============\n", 1647 | " # Aggregate the node information to obtain a hypergraph-level prediction -> [batch_size, output_dim]\n", 1648 | " #\n", 1649 | " # remember that information about which hypergraph each node belongs to\n", 1650 | " # is stored in hgraph.batch\n", 1651 | " # y_hat = ...\n", 1652 | " # =========================================\n", 1653 | " return y_hat" 1654 | ] 1655 | }, 1656 | { 1657 | "cell_type": "code", 1658 | "execution_count": null, 1659 | "metadata": { 1660 | "cellView": "form", 1661 | "id": "YE3w_hIk4IlQ" 1662 | }, 1663 | "outputs": [], 1664 | "source": [ 1665 | "# @title ✅ [RUN] **Please run this unit test to validate your code. You might still have bugs but this is a good sanity check.**\n", 1666 | "def testing_hnn():\n", 1667 | " torch.random.manual_seed(0)\n", 1668 | " np.random.seed(0)\n", 1669 | "\n", 1670 | " input_dim = 6\n", 1671 | " output_dim = 2\n", 1672 | " hidden_dim = 128\n", 1673 | " hypergraph = Hypergraph_Dataset[0]\n", 1674 | " model = GHyperNN(input_dim, output_dim, hidden_dim, 3)\n", 1675 | " # visualise(hypergraph)\n", 1676 | " out = model(create_mini_batch([hypergraph]))\n", 1677 | "\n", 1678 | " assert(out.shape[-1] == output_dim), \"Oups! 
🤭 Output shape is wrong\"\n", 1679 | "\n", 1680 | " np.random.seed(0)\n", 1681 | " perm_x = torch.tensor(np.random.permutation(hypergraph.x.numpy()))\n", 1682 | "\n", 1683 | "\n", 1684 | " H = hypergraph.incidence_matrix()\n", 1685 | "\n", 1686 | "\n", 1687 | " np.random.seed(0)\n", 1688 | " A_perm = torch.tensor(np.random.permutation(H.numpy()))\n", 1689 | " A_perm = incidence_to_edgeindex(A_perm)\n", 1690 | " perm_hypergraph = HyperGraph(A_perm, perm_x, hypergraph.y)\n", 1691 | " torch.random.manual_seed(0)\n", 1692 | "\n", 1693 | " # visualise(perm_hypergraph)\n", 1694 | " # visualise(hypergraph)\n", 1695 | " out_model_perm = model(create_mini_batch([perm_hypergraph]))\n", 1696 | "\n", 1697 | "\n", 1698 | " assert (torch.allclose(out, out_model_perm, atol=1e-4)), \"🤔 Something is wrong in the model! You are not permutation invariant anymore 🥺\"\n", 1699 | " print(\"All good!\")\n", 1700 | "\n", 1701 | "testing_hnn()\n", 1702 | "np.random.seed(None)\n", 1703 | "torch.random.manual_seed(datetime.now().timestamp())" 1704 | ] 1705 | }, 1706 | { 1707 | "cell_type": "code", 1708 | "execution_count": null, 1709 | "metadata": { 1710 | "id": "9KmDtzO4dg-1" 1711 | }, 1712 | "outputs": [], 1713 | "source": [ 1714 | "BATCH_SIZE = 50 #@param {type:\"integer\"}\n", 1715 | "learning_rate = 0.001 #@param {type: \"number\"}\n", 1716 | "num_epochs = 10 #@param {type: \"integer\"}\n" 1717 | ] 1718 | }, 1719 | { 1720 | "cell_type": "code", 1721 | "execution_count": null, 1722 | "metadata": { 1723 | "id": "u-Q6KjVAc8ly" 1724 | }, 1725 | "outputs": [], 1726 | "source": [ 1727 | "def train(dataset, model, optimiser, epoch, loss_fct, metric_fct, print_every=50):\n", 1728 | " \"\"\" Train model for one epoch\n", 1729 | " \"\"\"\n", 1730 | " model.train()\n", 1731 | " num_iter = int(len(dataset)/BATCH_SIZE)\n", 1732 | " for i in range(num_iter):\n", 1733 | " batch_list = dataset[i*BATCH_SIZE:(i+1)*BATCH_SIZE]\n", 1734 | " batch = create_mini_batch(batch_list)\n", 1735 | " 
optimiser.zero_grad()\n", 1736 | " y_hat = model(batch)\n", 1737 | " loss = loss_fct(y_hat, batch.y)\n", 1738 | " metric = metric_fct(y_hat, batch.y)\n", 1739 | " loss.backward()\n", 1740 | " optimiser.step()\n", 1741 | " if (i+1) % print_every == 0:\n", 1742 | " print(f\"Epoch {epoch} Iter {i}/{num_iter}\",\n", 1743 | " f\"Loss train {loss}; Metric train {metric}\")\n", 1744 | " return loss, metric\n", 1745 | "\n", 1746 | "def evaluate(dataset, model, loss_fct, metrics_fct):\n", 1747 | " \"\"\" Evaluate model on dataset\n", 1748 | " \"\"\"\n", 1749 | " model.eval()\n", 1750 | " # be careful in practice: batching this way drops some\n", 1751 | " # examples from the validation split when len(dataset) % BATCH_SIZE != 0;\n", 1752 | " # think about how you can fix this!\n", 1753 | " num_iter = int(len(dataset)/BATCH_SIZE)\n", 1754 | " metrics_eval = 0\n", 1755 | " loss_eval = 0\n", 1756 | " for i in range(num_iter):\n", 1757 | " batch_list = dataset[i*BATCH_SIZE:(i+1)*BATCH_SIZE]\n", 1758 | " batch = create_mini_batch(batch_list)\n", 1759 | "\n", 1760 | "\n", 1761 | " y_hat = model(batch).to(torch.float32)\n", 1762 | "\n", 1763 | " metrics = metrics_fct(y_hat, batch.y)\n", 1764 | " loss = loss_fct(y_hat, batch.y)\n", 1765 | "\n", 1766 | " metrics_eval += metrics\n", 1767 | " loss_eval += loss.detach()\n", 1768 | " metrics_eval /= num_iter\n", 1769 | " loss_eval /= num_iter\n", 1770 | " return loss_eval, metrics_eval" 1771 | ] 1772 | }, 1773 | { 1774 | "cell_type": "code", 1775 | "execution_count": null, 1776 | "metadata": { 1777 | "id": "7UeZlb8md8hA" 1778 | }, 1779 | "outputs": [], 1780 | "source": [ 1781 | "def train_eval(model, train_dataset, val_dataset, test_dataset,\n", 1782 | " loss_fct, metric_fct, print_every=1):\n", 1783 | " \"\"\" Train the model for num_epochs epochs\n", 1784 | " \"\"\"\n", 1785 | " # Instantiate our optimiser\n", 1786 | " optimiser = optim.Adam(model.parameters(), lr=learning_rate)\n", 1787 | " training_stats = None\n", 1788 | "\n", 
1789 | " #initial evaluation (before training)\n", 1790 | " val_loss, val_metric = evaluate(val_dataset, model, loss_fct, metric_fct)\n", 1791 | " train_loss, train_metric = evaluate(train_dataset[:BATCH_SIZE], model,\n", 1792 | " loss_fct, metric_fct)\n", 1793 | " epoch_stats = {'train_loss': train_loss.detach(), 'val_loss': val_loss.detach(),\n", 1794 | " 'train_metric': train_metric, 'val_metric': val_metric,\n", 1795 | " 'epoch':0}\n", 1796 | " training_stats = update_stats(training_stats, epoch_stats)\n", 1797 | "\n", 1798 | " for epoch in range(num_epochs):\n", 1799 | " if isinstance(train_dataset, list):\n", 1800 | " random.shuffle(train_dataset)\n", 1801 | " else:\n", 1802 | " train_dataset.shuffle()\n", 1803 | " train_loss, train_metric = train(train_dataset, model, optimiser, epoch,\n", 1804 | " loss_fct, metric_fct, print_every)\n", 1805 | " val_loss, val_metric = evaluate(val_dataset, model, loss_fct, metric_fct)\n", 1806 | " print(f\"[Epoch {epoch+1}]\",\n", 1807 | " f\"train loss: {train_loss:.3f} val loss: {val_loss:.3f}\",\n", 1808 | " f\"train metric: {train_metric:.3f} val metric: {val_metric:.3f}\"\n", 1809 | " )\n", 1810 | " # store the loss and the computed metric for the final plot\n", 1811 | " epoch_stats = {'train_loss': train_loss.detach(), 'val_loss': val_loss.detach(),\n", 1812 | " 'train_metric': train_metric, 'val_metric': val_metric,\n", 1813 | " 'epoch':epoch+1}\n", 1814 | " training_stats = update_stats(training_stats, epoch_stats)\n", 1815 | "\n", 1816 | " test_loss, test_metric = evaluate(test_dataset, model, loss_fct, metric_fct)\n", 1817 | " print(f\"Test metric: {test_metric:.3f}\")\n", 1818 | " return training_stats" 1819 | ] 1820 | }, 1821 | { 1822 | "cell_type": "code", 1823 | "execution_count": null, 1824 | "metadata": { 1825 | "id": "Q5u_wmlKz60x" 1826 | }, 1827 | "outputs": [], 1828 | "source": [ 1829 | "model = GHyperNN(6, 2, 10, 3)\n", 1830 | "\n", 1831 | "graph_level_out = train_eval(model, train_data, validation_data, 
test_data, loss_fct=F.cross_entropy, metric_fct=quick_accuracy, print_every=140)\n", 1832 | "plot_stats(graph_level_out)" 1833 | ] 1834 | }, 1835 | { 1836 | "cell_type": "markdown", 1837 | "metadata": { 1838 | "id": "6Pr0VN9dvkID" 1839 | }, 1840 | "source": [ 1841 | "# 🛑 **Part 4**: Adding attention\n", 1842 | "\n", 1843 | "Attention has proved beneficial for several architectures such as CNNs, RNNs and GNNs. In this final section, we will demonstrate a technique that allows us to integrate attention into our [hypergraph model](https://arxiv.org/abs/1901.08150).\n", 1844 | "\n", 1845 | "So far in our HGNN, we have weighted equally the contributions from all hyperedges that contain a given node. The goal of adding attention is to let the model learn a weight for each of them, allowing larger contributions from more 'important' hyperedges and ignoring irrelevant ones.\n", 1846 | "\n", 1847 | "To predict this weighted incidence matrix, we need a score for each pair consisting of the representation of a node and the representation of an incident hyperedge. To do this we use a technique known as broadcasting to create a tensor $B$ of size number_of_nodes $\\times$ number_of_hyperedges $\\times$ (number_of_node_features + number_of_hyperedge_features) with $$ B_{i,j} = x_i || e_j $$ where $x_i$ is the feature vector of node $i$, $e_j$ is the feature vector of hyperedge $j$, and $||$ denotes concatenation.\n", 1848 | "\n", 1849 | "\n", 1850 | "\n", 1851 | "
\n", 1852 | "\n", 1853 | "
\n", 1854 | "\n", 1855 | "\n", 1856 | "\n", 1857 | "Let's see how it works.\n", 1858 | "\n", 1859 | "\n", 1860 | "\n" 1861 | ] 1862 | }, 1863 | { 1864 | "cell_type": "code", 1865 | "source": [ 1866 | "# Given 2 tensors a and b, we want to generate (a_i || b_j) for all (i,j) pairs\n", 1867 | "# a: [6, 5]\n", 1868 | "# b: [10, 7]\n", 1869 | "# broadcast_a_b -> [6, 10, 5+7]\n", 1870 | "\n", 1871 | "a = torch.rand((6,5))\n", 1872 | "b = torch.rand((10,7))\n", 1873 | "\n", 1874 | "# expand first tensor on the 2nd dimension and second tensor on the 1st dimension\n", 1875 | "a = a.unsqueeze(1) # a: [6,1,5]\n", 1876 | "b = b.unsqueeze(0) # b: [1,10,7]\n", 1877 | "\n", 1878 | "# repeat the expanded dimensions to create tensors having the same size everywhere\n", 1879 | "# except for the dim where the concatenation will happen (last one in our case)\n", 1880 | "a = a.repeat(1, b.shape[1], 1) # a: [6,10,5]\n", 1881 | "b = b.repeat(a.shape[0], 1, 1) # b: [6,10,7]\n", 1882 | "\n", 1883 | "# concatenate the 2 volumes on the last dimension\n", 1884 | "broadcast_a_b = torch.concat((a, b), -1) # broadcast_a_b: [6,10,12]\n", 1885 | "\n", 1886 | "print('Output shape is', broadcast_a_b.shape)" 1887 | ], 1888 | "metadata": { 1889 | "id": "mSE8A-T6y5tV" 1890 | }, 1891 | "execution_count": null, 1892 | "outputs": [] 1893 | }, 1894 | { 1895 | "cell_type": "markdown", 1896 | "source": [ 1897 | "We then project each of these concatenated vectors, through learnable linear layers, into a scalar and use the result as a weighted incidence matrix $H_{ij}$. Note that the broadcast operation also assigns non-zero scores to the non-incident (node, hyperedge) pairs. To avoid that, we multiply the resulting weighted incidence matrix elementwise with the original one.\n", 1898 | "\n", 1899 | "Datasets often come without hyperedge features, but this does not mean you cannot apply attention to your model. You will need to infer hyperedge features from the node features. 
You can do this with any permutation-invariant operation, such as sum, max, min or mean." 1900 | ], 1901 | "metadata": { 1902 | "id": "XYO3gM_ny03H" 1903 | } 1904 | }, 1905 | { 1906 | "cell_type": "markdown", 1907 | "source": [ 1908 | "# ❓ **Task 6**: Implement the hypergraph attention model" 1909 | ], 1910 | "metadata": { 1911 | "id": "6QROCMsBrTcP" 1912 | } 1913 | }, 1914 | { 1915 | "cell_type": "code", 1916 | "execution_count": null, 1917 | "metadata": { 1918 | "id": "vbngXRuPbEno" 1919 | }, 1920 | "outputs": [], 1921 | "source": [ 1922 | "class HyperGatLayer(nn.Module):\n", 1923 | " def __init__(self, input_dim, output_dim):\n", 1924 | " \"\"\"\n", 1925 | " One layer of a hypergraph attention neural network\n", 1926 | "\n", 1927 | " Args:\n", 1928 | " input_dim : number of features of each node in hypergraph\n", 1929 | " output_dim : number of output features\n", 1930 | " \"\"\"\n", 1931 | " super(HyperGatLayer, self).__init__()\n", 1932 | " self.input_dim = input_dim\n", 1933 | " self.output_dim = output_dim\n", 1934 | " self.final_Linear = nn.Linear(input_dim, output_dim)\n", 1935 | " self.mlp1 = nn.Linear(2*input_dim, 2*input_dim)\n", 1936 | " self.mlp2 = nn.Linear(2*input_dim, 1)\n", 1937 | "\n", 1938 | " def forward(self, H, x):\n", 1939 | " \"\"\"\n", 1940 | " Args:\n", 1941 | " x : feature matrix [num_nodes, input_dim]\n", 1942 | " H : incidence matrix [num_nodes, num_hyper_edges]\n", 1943 | " returns:\n", 1944 | " x : output of one layer of hypergraph neural network [num_nodes, output_dim]\n", 1945 | " \"\"\"\n", 1946 | " num_nodes = x.shape[0]\n", 1947 | " num_edges = H.shape[-1]\n", 1948 | "\n", 1949 | " # create features for each hyperedge by summing the features of its nodes\n", 1950 | " edge_features = torch.transpose(H, 0, 1) @ x\n", 1951 | "\n", 1952 | " # ============ YOUR CODE HERE =============\n", 1953 | " # Use the broadcasting technique explained above to create all possible (x_i||e_j) pairs\n", 1954 | " #\n", 1955 | " # from x: [num_nodes, input_dim] and edge_features: 
[num_edges, input_dim]\n", 1956 | " # create concat_features: [num_nodes, num_edges, 2*input_dim]\n", 1957 | " # you might need some preprocessing to achieve that\n", 1958 | " #\n", 1959 | " # ...\n", 1960 | " # concat_features = ...\n", 1961 | " # =========================================\n", 1962 | "\n", 1963 | "\n", 1964 | " # ============ YOUR CODE HERE =============\n", 1965 | " # Step 1: Project the concat_features into scalars by using\n", 1966 | " # self.mlp1->relu->self.mlp2->sigmoid\n", 1967 | " # The expected shape should be [num_nodes, num_edges]\n", 1968 | " #\n", 1969 | " # H_tilda = ...\n", 1970 | " #\n", 1971 | " # Step 2: Apply the masking to set the NON-incident (node, hyperedge) pairs to 0\n", 1972 | " #\n", 1973 | " # H_tilda = ...\n", 1974 | " # =========================================\n", 1975 | "\n", 1976 | " # compute degree of nodes (D_v)^-0.5\n", 1977 | " degree_of_nodes = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=-1)), -0.5), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 1978 | " # compute degree of hyperedges (D_e)^-1\n", 1979 | " degree_of_edges = torch.nan_to_num(torch.pow(torch.diag(torch.sum(H, dim=0)), -1.0), nan=0, posinf=0, neginf=0).to(torch.float32)\n", 1980 | "\n", 1981 | " # ============ YOUR CODE HERE =============\n", 1982 | " # Compute D_v^{-0.5} H_tilda D_e^{-1} H_tilda^T D_v^{-0.5} x\n", 1983 | " #\n", 1984 | " # x = ...\n", 1985 | " # =========================================\n", 1986 | "\n", 1987 | " # apply the final linear layer\n", 1988 | " x = self.final_Linear(x)\n", 1989 | " return x\n", 1990 | "\n", 1991 | "\n", 1992 | "class HyperGat(nn.Module):\n", 1993 | " def __init__(self, input_dim, output_dim, hidden_dim, num_layers):\n", 1994 | " \"\"\"\n", 1995 | " Hypergraph attention network containing num_layers HyperGatLayer for hypergraph\n", 1996 | " level prediction\n", 1997 | "\n", 1998 | " Args:\n", 1999 | " input_dim : number of features of each node in hypergraph\n", 2000 | " output_dim : number of output 
features\n", 2001 | " hidden_dim : hidden dimension\n", 2002 | " num_layers : number of layers\n", 2003 | " \"\"\"\n", 2004 | " super(HyperGat, self).__init__()\n", 2005 | " self.input_dim = input_dim\n", 2006 | " self.output_dim = output_dim\n", 2007 | " self.hidden_dim=hidden_dim\n", 2008 | "\n", 2009 | "\n", 2010 | " # ============ YOUR CODE HERE =============\n", 2011 | " # Create a list of (num_layers) HyperGatLayer(s) with hidden dimension hidden_dim\n", 2012 | " #\n", 2013 | " # if num_layers > 1:\n", 2014 | " # ...\n", 2015 | " # ...\n", 2016 | " # self.hnn_layers = ...\n", 2017 | " # else:\n", 2018 | " # self.hnn_layers = ...\n", 2019 | " # =========================================\n", 2020 | "\n", 2021 | " self.hnn_layers = nn.ModuleList(self.hnn_layers)\n", 2022 | " self.num_layers = num_layers\n", 2023 | "\n", 2024 | " def forward(self,hgraph):\n", 2025 | " \"\"\"\n", 2026 | " Args:\n", 2027 | " hgraph : input hypergraph stored as HyperGraph class formed as a batch\n", 2028 | " returns:\n", 2029 | " y_hat : logits for each hypergraph in batch [batch_size, output_dim]\n", 2030 | " \"\"\"\n", 2031 | " x = hgraph.x\n", 2032 | " y = hgraph.y\n", 2033 | " H = hgraph.incidence_matrix()\n", 2034 | " batch = hgraph.batch\n", 2035 | " # ============ YOUR CODE HERE =============\n", 2036 | " # Apply (self.num_layers) HyperGatLayer(s), with ReLU nonlinearity\n", 2037 | " #\n", 2038 | " # ...\n", 2039 | " # x = ...\n", 2040 | " # =========================================\n", 2041 | "\n", 2042 | " # ============ YOUR CODE HERE =============\n", 2043 | " # Aggregate the node information to obtain a hypergraph-level prediction -> [batch_size, output_dim]\n", 2044 | " #\n", 2045 | " # remember that information about which hypergraph each node belongs to\n", 2046 | " # is stored in hgraph.batch\n", 2047 | " # y_hat = ...\n", 2048 | " # =========================================\n", 2049 | " return y_hat" 2050 | ] 2051 | }, 2052 | { 2053 | "cell_type": "code", 2054 | 
"execution_count": null, 2055 | "metadata": { 2056 | "cellView": "form", 2057 | "id": "dNNW2qU53b_B" 2058 | }, 2059 | "outputs": [], 2060 | "source": [ 2061 | "# @title ✅ [RUN] **Please run this unit test to validate your code. You might still have bugs but this is a good sanity check.**\n", 2062 | "def testing_hnn():\n", 2063 | " torch.random.manual_seed(0)\n", 2064 | " np.random.seed(0)\n", 2065 | "\n", 2066 | " input_dim = 6\n", 2067 | " output_dim = 2\n", 2068 | " hidden_dim = 128\n", 2069 | " hypergraph = Hypergraph_Dataset[0]\n", 2070 | " model = HyperGat(input_dim, output_dim, hidden_dim, 3)\n", 2071 | " # visualise(hypergraph)\n", 2072 | " out = model(create_mini_batch([hypergraph]))\n", 2073 | "\n", 2074 | " assert(out.shape[-1] == output_dim), \"Oups! 🤭 Output shape is wrong\"\n", 2075 | "\n", 2076 | " np.random.seed(0)\n", 2077 | " perm_x = torch.tensor(np.random.permutation(hypergraph.x.numpy()))\n", 2078 | "\n", 2079 | "\n", 2080 | " H = hypergraph.incidence_matrix()\n", 2081 | "\n", 2082 | "\n", 2083 | " np.random.seed(0)\n", 2084 | " A_perm = torch.tensor(np.random.permutation(H.numpy()))\n", 2085 | " A_perm = incidence_to_edgeindex(A_perm)\n", 2086 | " perm_hypergraph = HyperGraph(A_perm, perm_x, hypergraph.y)\n", 2087 | " torch.random.manual_seed(0)\n", 2088 | "\n", 2089 | " # visualise(perm_hypergraph)\n", 2090 | " # visualise(hypergraph)\n", 2091 | " out_model_perm = model(create_mini_batch([perm_hypergraph]))\n", 2092 | "\n", 2093 | "\n", 2094 | " assert (torch.allclose(out, out_model_perm, atol=1e-4)), \"🤔 Something is wrong in the model! 
You are not permutation invariant anymore 🥺\"\n", 2095 | " print(\"All good!\")\n", 2096 | "\n", 2097 | "testing_hnn()\n", 2098 | "np.random.seed(None)\n", 2099 | "torch.random.manual_seed(datetime.now().timestamp())" 2100 | ] 2101 | }, 2102 | { 2103 | "cell_type": "code", 2104 | "execution_count": null, 2105 | "metadata": { 2106 | "id": "9hmFOUnxU4Tl" 2107 | }, 2108 | "outputs": [], 2109 | "source": [ 2110 | "attention_model = HyperGat(6, 2, 8, 3)\n", 2111 | "attention_model_out = train_eval(attention_model, train_data, validation_data,\n", 2112 | " test_data, loss_fct=F.cross_entropy,\n", 2113 | " metric_fct=quick_accuracy, print_every=140)\n", 2114 | "plot_stats(attention_model_out)" 2115 | ] 2116 | }, 2117 | { 2118 | "cell_type": "markdown", 2119 | "metadata": { 2120 | "id": "snqwnrsIz29i" 2121 | }, 2122 | "source": [ 2123 | "# ⭐️ Further reading\n", 2124 | "\n", 2125 | "Hypergraph processing represents a new, fast-growing field of deep learning. This lab aims to help you understand and become familiar with the basic concepts in hypergraph learning. 
However, researchers continue to develop new, more powerful tools for creating hypergraph representations.\n", 2126 | "\n", 2127 | "If you are interested in finding out more about hypergraphs, here is a selection of recent papers that you might find useful:\n", 2128 | "\n", 2129 | "[1] *HyperGCN: A New Method of Training Graph Convolutional Networks on Hypergraphs*, Yadati et al., NeurIPS 2019 \\\\\n", 2130 | "[2] *HNHN: Hypergraph Networks with Hyperedge Neurons*, Dong et al., GRLW 2020 \\\\\n", 2131 | "[3] *UniGNN: A Unified Framework for Graph and Hypergraph Neural Networks*, Huang et al., IJCAI 2021 \\\\\n", 2132 | "[4] *Nonlinear Higher-Order Label Spreading*, Tudisco et al., WWW 2021 \\\\\n", 2133 | "[5] *You are AllSet: A Multiset Function Framework for Hypergraph Neural Networks*, Chien et al., ICLR 2022 \\\\\n", 2134 | "[6] *EvolveHypergraph: Group-Aware Dynamic Relational Reasoning for Trajectory Prediction*, Li et al., CVPR 2022 \\\\\n", 2135 | "[7] *Equivariant Hypergraph Diffusion Neural Operators*, Wang et al., ICLR 2023 \\\\\n", 2136 | "\n", 2137 | "\n" 2138 | ] 2139 | } 2140 | ], 2141 | "metadata": { 2142 | "colab": { 2143 | "provenance": [], 2144 | "include_colab_link": true 2145 | }, 2146 | "gpuClass": "standard", 2147 | "kernelspec": { 2148 | "display_name": "Python 3", 2149 | "name": "python3" 2150 | }, 2151 | "language_info": { 2152 | "name": "python" 2153 | } 2154 | }, 2155 | "nbformat": 4, 2156 | "nbformat_minor": 0 2157 | } --------------------------------------------------------------------------------
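As a small illustration of the hyperedge-feature step described in Part 4 (inferring hyperedge features from node features with a permutation-invariant operation), the sketch below mean-pools the features of the nodes in each hyperedge and checks that reordering the nodes leaves the result unchanged. It is written in plain Python on nested lists so that it runs without PyTorch. `hyperedge_mean_features` is a hypothetical helper, not part of the lab code, and the tiny incidence matrix and node features are made up for the example.

```python
# Hypothetical helper (not part of the lab): mean-pool node features per hyperedge.
def hyperedge_mean_features(H, x):
    """H: incidence matrix as nested lists [num_nodes][num_edges] with 0/1 entries.
    x: node features as nested lists [num_nodes][num_features].
    Returns hyperedge features as nested lists [num_edges][num_features]."""
    num_nodes, num_edges = len(H), len(H[0])
    num_features = len(x[0])
    edge_features = []
    for j in range(num_edges):
        # indices of the nodes that belong to hyperedge j
        members = [i for i in range(num_nodes) if H[i][j]]
        if not members:
            # empty hyperedge: fall back to a zero vector
            edge_features.append([0.0] * num_features)
            continue
        edge_features.append([
            sum(x[i][f] for i in members) / len(members)
            for f in range(num_features)
        ])
    return edge_features


# Made-up toy hypergraph: 3 nodes, 2 hyperedges, 2 features per node.
H = [[1, 0],
     [1, 1],
     [0, 1]]
x = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]

edge_features = hyperedge_mean_features(H, x)

# Reorder the nodes (rows of H and x together); the pooled features should not change.
order = [2, 0, 1]
H_perm = [H[i] for i in order]
x_perm = [x[i] for i in order]
edge_features_perm = hyperedge_mean_features(H_perm, x_perm)

print(edge_features)
print(edge_features_perm == edge_features)
```

In the notebook itself the sum variant is computed as `torch.transpose(H, 0, 1) @ x`; dividing each resulting row by the number of nodes in that hyperedge would give the mean used in this sketch.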