├── .gitignore
├── LICENSE
├── README.md
├── notebooks
│   ├── data_exploration.ipynb
│   ├── dataset_testing.ipynb
│   ├── gat.ipynb
│   ├── graphsage.ipynb
│   └── simple_baseline.ipynb
├── results.md
├── src
│   ├── config.json
│   ├── config_gat.json
│   ├── config_graphsage.json
│   ├── datasets
│   │   └── link_prediction.py
│   ├── layers.py
│   ├── main.py
│   ├── models.py
│   └── utils.py
└── visualizations
    └── .DS_Store
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Raunak Kumar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphSAGE and Graph Attention Networks for Link Prediction 2 | This is a PyTorch implementation of GraphSAGE from the paper [Inductive Representation Learning on Large Graphs](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs) 3 | and of Graph Attention Networks from the paper [Graph Attention Networks](https://arxiv.org/pdf/1710.10903.pdf). The code in this repository focuses on the link prediction task. Although the models themselves do not make use of temporal information, the datasets that we use are temporal networks obtained from [SNAP](http://snap.stanford.edu/data/index.html#temporal) and [Network Repository](https://networkrepository.com/dynamic.php). 4 | 5 | This is just some code I have been playing around with - there may be issues or bugs. If you use this code and find them, let me know! 6 | 7 | ## Usage 8 | 9 | In the `src` directory, edit the `config_gat.json` or `config_graphsage.json` file to specify arguments and 10 | flags. Then run `python main.py --json config_{}.json`. 11 | 12 | ## Limitations 13 | 14 | Although a nearly identical implementation of Graph Attention Networks performs well on the node classification task, I had trouble training them for link prediction on these temporal networks. 
The GraphSAGE model seems to work pretty well, so there might be a bug in my GAT code, or my hyperparameter search may have been inadequate. 15 | 16 | ## References 17 | * [Inductive Representation Learning on Large Graphs](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs), Hamilton et al., NeurIPS 2017. 18 | * [Graph Attention Networks](https://arxiv.org/pdf/1710.10903.pdf), Velickovic et al., ICLR 2018. 19 | -------------------------------------------------------------------------------- /notebooks/data_exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import importlib\n", 10 | "import math\n", 11 | "import os\n", 12 | "import sys\n", 13 | "\n", 14 | "src_path = os.path.join(os.path.dirname(os.path.abspath('')), 'src')\n", 15 | "sys.path.append(src_path)\n", 16 | "\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "import datasets\n", 21 | "\n", 22 | "%matplotlib inline\n", 23 | "\n", 24 | "%load_ext autoreload\n", 25 | "%autoreload 2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "# How many edges in the test split are duplicates and repeats?\n", 33 | "\n", 34 | "For different train / test splits, we will look at the following quantities:\n", 35 | "* **Duplicates** - Within the test split, what fraction of edges are duplicates?\n", 36 | "* **Repeats** - What fraction of test edges already exist in the train split?"
37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 5, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "source = 'networkrepository' # SNAP or networkrepository\n", 46 | "\n", 47 | "if source == 'SNAP':\n", 48 | " data_dir = os.path.expanduser('~/Documents/Datasets/temporal-networks-snap')\n", 49 | " name = 'CollegeMsg.txt'\n", 50 | " path = os.path.join(data_dir, name)\n", 51 | " with open(path, 'r') as f:\n", 52 | " lines = f.readlines()\n", 53 | "\n", 54 | " if name.endswith('tsv'):\n", 55 | " separator = '\\t'\n", 56 | " lines = np.array([line.split(separator)[:3] for line in lines[1:]], dtype=str)\n", 57 | " vertex_id = {j : i for (i, j) in enumerate(lines.flatten())}\n", 58 | " lines = np.array([vertex_id[v] for v in lines.flatten()]).reshape(lines.shape)\n", 59 | " lines = np.array([[int(x) for x in line] for line in lines])\n", 60 | " else:\n", 61 | " if name.endswith('txt'):\n", 62 | " separator = ' '\n", 63 | " elif name.endswith('csv'):\n", 64 | " separator = ','\n", 65 | " lines = np.array([[int(x) for x in line.split(separator)[:3]] for line in lines])\n", 66 | "else:\n", 67 | " data_dir = os.path.expanduser('~/Documents/Datasets/temporal-networks-network-repository')\n", 68 | " name = 'fb-forum'\n", 69 | " path = os.path.join(data_dir, name, name+'.edges')\n", 70 | " with open(path, 'r') as f:\n", 71 | " lines = f.readlines()\n", 72 | " \n", 73 | " if name == 'soc-sign-bitcoinalpha':\n", 74 | " lines = np.loadtxt(path, delimiter=',', dtype=np.int64)\n", 75 | " lines = np.concatenate((lines[:, :2], lines[:, 3:]), axis=1)\n", 76 | " elif name == 'fb-forum' or name == 'ia-contacts_hypertext2009':\n", 77 | " lines = np.loadtxt(path, delimiter=',', dtype=np.float64)\n", 78 | " lines = np.array(lines, dtype=np.int64)\n", 79 | " elif name == 'ia-contact':\n", 80 | " lines = [line.split('\\t') for line in lines]\n", 81 | " lines = [[*line[0].split(), line[1].split()[1]] for line in lines]\n", 82 | " lines = 
np.array(lines, dtype=np.int64)\n", 83 | " elif name == 'ia-enron-employees':\n", 84 | " lines = [line.split() for line in lines]\n", 85 | " lines = [[line[0], line[1], line[3]] for line in lines]\n", 86 | " lines = np.array(lines, dtype=np.int64)\n", 87 | " elif name == 'ia-radoslaw-email' or name == 'soc-wiki-elec':\n", 88 | " lines = [line.split() for line in lines[2:]]\n", 89 | " lines = [[line[0], line[1], line[3]] for line in lines]\n", 90 | " lines = np.array(lines, dtype=np.int64)\n", 91 | " \n", 92 | "lines = lines[lines[:, 2].argsort()]\n", 93 | "m = len(lines)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Split: 0.05, duplicate: 0.790, repeat: 0.083\n", 106 | "Split: 0.10, duplicate: 0.787, repeat: 0.141\n", 107 | "Split: 0.15, duplicate: 0.786, repeat: 0.210\n", 108 | "Split: 0.20, duplicate: 0.784, repeat: 0.278\n", 109 | "Split: 0.25, duplicate: 0.779, repeat: 0.316\n", 110 | "Split: 0.30, duplicate: 0.775, repeat: 0.349\n", 111 | "Split: 0.35, duplicate: 0.770, repeat: 0.381\n", 112 | "Split: 0.40, duplicate: 0.767, repeat: 0.424\n", 113 | "Split: 0.45, duplicate: 0.761, repeat: 0.455\n", 114 | "Split: 0.50, duplicate: 0.758, repeat: 0.490\n", 115 | "Split: 0.55, duplicate: 0.758, repeat: 0.520\n", 116 | "Split: 0.60, duplicate: 0.755, repeat: 0.552\n", 117 | "Split: 0.65, duplicate: 0.752, repeat: 0.588\n", 118 | "Split: 0.70, duplicate: 0.746, repeat: 0.615\n", 119 | "Split: 0.75, duplicate: 0.738, repeat: 0.658\n", 120 | "Split: 0.80, duplicate: 0.739, repeat: 0.668\n", 121 | "Split: 0.85, duplicate: 0.718, repeat: 0.728\n", 122 | "Split: 0.90, duplicate: 0.643, repeat: 0.773\n", 123 | "Split: 0.95, duplicate: 0.525, repeat: 0.839\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "duplicates, repeats = [], []\n", 129 | "splits = np.linspace(0.05, 0.95, 19)\n", 130 | "static_train = set()\n", 
131 | "static_test = {(edge[0], edge[1]) for edge in lines}\n", 132 | "cur = 0\n", 133 | "for split in splits:\n", 134 | " idx = math.ceil(split*m)\n", 135 | " train, test = lines[:idx, :], lines[idx:, :]\n", 136 | " \n", 137 | " moving_edges = {(edge[0], edge[1]) for edge in lines[cur:idx]}\n", 138 | " staying_edges = {(edge[0], edge[1]) for edge in lines[idx:]}\n", 139 | " removals = moving_edges - staying_edges\n", 140 | " \n", 141 | " static_train.update(moving_edges)\n", 142 | " static_test = static_test - removals\n", 143 | " \n", 144 | " cur = idx\n", 145 | " \n", 146 | " # Within the test split, what fraction of edges are duplicates? \n", 147 | " duplicate = 1 - len(static_test) / len(test)\n", 148 | " duplicates.append(duplicate)\n", 149 | " \n", 150 | " # What fraction of edges in the test split already exist in the train split?\n", 151 | " repeat = np.sum([(edge in static_train) for edge in static_test]) / len(static_test)\n", 152 | " repeats.append(repeat)\n", 153 | " \n", 154 | " print('Split: {:.2f}, duplicate: {:.3f}, repeat: {:.3f}'.format(split, duplicate, repeat))" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 7, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "[]" 166 | ] 167 | }, 168 | "execution_count": 7, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | }, 172 | { 173 | "data": { 174 | "image/png": 
"[base64-encoded PNG omitted: line plot 'Duplicate edges within test set vs Split', x: Train split, y: Fraction of duplicates]\n", 175 | "text/plain": [ 176 | "
" 177 | ] 178 | }, 179 | "metadata": { 180 | "needs_background": "light" 181 | }, 182 | "output_type": "display_data" 183 | } 184 | ], 185 | "source": [ 186 | "plt.xlabel('Train split')\n", 187 | "plt.ylabel('Fraction of duplicates')\n", 188 | "plt.title('Duplicate edges within test set vs Split')\n", 189 | "plt.plot(splits, duplicates)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 8, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "[]" 201 | ] 202 | }, 203 | "execution_count": 8, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | }, 207 | { 208 | "data": { 209 | "image/png": "[base64-encoded PNG omitted: matplotlib figure]
EgB3/1PE/NcB3WMVj4hIYTZu381lIycxbflG7jr7EAb0TIt3SKUqliWCnkCGuy9y993AKKBfEfMPIOjAXkSk1KzatJPzn/iWmSs28+jAHkmXBCC29wgaA8sjhjOBXgXNaGbNgBbAZzGMR0RkL4vWbOXi/05k4/bdjLz0MI5sXS/eIcVFLBNBQZVrXsi8/YHX3T2nwBWZDQYGA6SlJV+2FpGS90PmJgY9OxGAUYOP4JAmteIcUfzEsmooE2gaMdwEyCpk3v4UUS3k7k+6e7q7p6emppZgiCKSjL5ZuJYBT02gcsUUXrsquZMAxDYRTALamFkLM6tEcLIfnX8mM2sH1AG+jWEsIiIAfDhzJYOemUSj2pV54+ojaZlaPd4hxV3MEoG7ZwPXAmOBOcCr7j7LzIaZWd+IWQcAo9y9sGojEZESMWriMn7/0vd0alyTV393BA1rlZ+G436NmL5Q5u5jgDH5xg3JNzw0ljGIiLg7j32+kHs+nMdxbVN57KIeZb4PgZKkPSEi5VpurvOvMXN4+qvF9O3aiPvO60qlCuW7yYh9pUQgIuXWnpxcbnpjBm9+v4JBRzZnyBkdk+Zt4X2hRCAi5cKenFyWrtvOwjVbg7/V25iRuZEFq7fy595tue7E1uWuH4GSokQgImXKph17whP9Vhau2fbTiX/Zuu1k5/78zEnDmpVpVb8aVx3XinMObRLHiBOfEoGIJKRNO/YwbfnG8ISf97eNNVt2/TRPxRSj+UHVaFu/Bqd2bkir1Oq0Sq1Oy9Rq1KhcMY7Rly1KBCKScN6bkcWQd2axfttuAGpVqUjr+tU5oV3qTyf7VvWr07ROFSqU874CSoMSgYgkjLVbdzHknZmM+WEVXZvU4uH+3elwcA3qVquk+v0YUiIQkbhzd96bsZIh78xk264cburTniuPaaGr/VKiRCAicbVmyy7+8fZMPpy1iq5Na3PfuV1oU047iU9USgQiEhfuzrszVnLbOzPZtjuHm09tzxVHqxQQD0oEIlLqVm/ZyT/ensnYWT/SrWlt7juvC63rqxQQL0oEIlJq3J3R07O4bfQstu/O4W+ntueKY1qSord946rYRGBmfwCeBbYATxP0K3yzu38U49hEpBxZvWUnt741k49m/0j3tNrce25XWtdXE9CJIJoSwWXu/pCZnQKkApcSJAYlAhEplrvzzrSgFLBzTw63nNaBy45uoVJAAokmEeT9b50GPOvu000P9IpIFFZv3snf35rJJ3N+pEdabe49ryut1BFMwokmEUwxs48IOpf/m5nVAHJjG5aIlGXuztvTVjB09Gx27snh1tM7cOlRKgUkqmgSweVAN2CRu283s4MIqodERPayOzuXD2et4vlvljB56QYObVaHe8/tou4gE1w0ieBjdz8pb8Dd15nZq8BJRSwjIklkxcYdvPzdMkZNWsbarbtpdlBVhvXrxMBezVQKKAMKTQRmVhmoCtQzszr8fK+gJtAompWbWR/gISAFeNrd7y5gnvOBoYAD0939wn35AiISH7m5ztcL1/LCt0v5ZM6POHBS+/pcfERzjmldTx3AlCFFlQh+B/yR4KQ/hZ8TwWZgRHErNrOUcL7eQCYwycxGu/vsiHnaAH8DjnL3DWZWf7++hYiUmk3b9/D695m8NGEpi9Zu46BqlbjquFYM6JlG07pV4x2e7IdCE4G7PwQ8ZGbXufvw/Vh3TyDD3RcBmNkooB8wO2KeK4ER7r4h3Obq/diOiJSCmSs28cK3S3ln+gp27smlR1pt/nNBN049pCEHVkiJd3jyKxR7j8Ddh5tZZ6AjUDli/PPFLNoYWB4xnAn0yjdPWwAz+5qg+miou3+Yf0VmNhgYDJCWllZcyCJSQnbuyWHMDyt5YcJSpi7bSJWKKfy2e2MG9mpG58a14h2elJBo3iy+DTieIBGMAU4FvgKKSwQFVRB6vuEKQJtw/U2AL82ss7tv3Gsh9yeBJwHS09Pzr0NEStjy9dt56btlvDp5Oeu3
7aZlvWoMOaMj5xzahFpV1PNXeRPNU0PnAl2Bqe5+qZk1IGhqojiZQNOI4SZAVgHzTHD3PcBiM5tHkBgmRbF+ESlhKzbu4N4P5/LO9CwM6N2xARcf3pyjWh+kjmHKsWgSwQ53zzWzbDOrCawGWkax3CSgjZm1AFYA/YH8TwS9DQwARppZPYKqokVRRy8iJWLbrmweG7+Qp74Mfn6Dj23JJUc0p1HtKnGOTEpDNIlgspnVBp4ieHpoKzCxuIXcPdvMrgXGEtT/P+Pus8xsGDDZ3UeH035jZrOBHOBGd1+3n99FRPZRTq7z+pTl3PfRfNZs2UW/bo34a5/2NFYCSCrmHn2Vu5k1B2q6+4xYBVSc9PR0nzx5crw2L1JufJOxljven8OclZvpkVabW8/oSI+0OvEOS2LEzKa4e3pB06K5WWzAQKCluw8zszQz6+nuxZYKRCTxLFqzlX+Nmcsnc36kce0qDB/QnTO6HKx7AEksmqqhRwkamTsRGEbQL8EbwGExjEtEStjG7bt56NMFvPDtUipXTOGvfdpx2VEtqFxR7wAku2gSQS9372FmUwHCN4ArxTguESkhe3JyeeHbpTz06QK27NzDBYel8efebUmtcWC8Q5MEEU0i2BM2F+EAZpaKmqEWSXjuzidzVnPXmDksWruNo1vX49YzOtC+Yc14hyYJJppE8DDwFtDAzP5J8F7BrTGNSkR+lVlZm/jn+3P4ZuE6WqZW45lB6ZzQrr7uA0iBomli4iUzm8LPzU6f5e5zYhuWiOyPlZt28J+PF/DqlOXUqlKR2/t24sJeaVRMOSDeoUkCi6ZEAEFz1HnVQ3rAWCTBLFu3ncc+X8gbUzJxnMuPasF1J7ahVlU1ByHFi+bx0SHAeQRPChnwrJm95u53xjo4ESlaxuqtPDo+g3emZZFixnnpTbjquFZqDlr2STQlggFAd3ffCWBmdwPfA0oEInEyZ+VmHhmXwZgfVnJghQO45IjmDD62JQ1rVS5+YZF8okkESwian94ZDh8ILIxVQCJSuGnLN/LIZxl8MudHqh9YgauOa8XlR7egXnU9Cir7L5pEsAuYZWYfE9wj6A18ZWYPA7j79TGMT0SAiYvXM/yzBXy5YC21qlTkTye3ZdCRzXUPQEpENIngrfAvz/jYhCIikdydrzLWMvyzDCYuXk+96pW4+dT2XHR4M6ofGO1zHiLFi+bx0efMrAqQ5u7zSiEmkaSW9yLYI+MymL58Iw1rVua2MzvS/7A0qlRScxBS8qJ5auhM4D6gEtDCzLoBw9y9b6yDE0kmObnOhzNXMfyzBcxdtYUmdarwr98ewjmHNlafwBJT0ZQvhxJ0RD8ewN2nhZ3NiEgJ2JOTy9tTV/DY5wtZtGYbLVOrcf95XenbrZFeBJNSEU0iyHb3TfleTVe/wSK/0s49ObwyaTlPfrGIFRt30OHgmjxyYXdO7XwwKQeoKQgpPdEkgplmdiGQYmZtgOuBb2Iblkj5tXnnHl6csJRnvlrM2q27ObRZHe48qzPHt0tVW0ASF9EkguuAWwgeI/0fQfeSUb1MZmZ9gIcImqd42t3vzjd9EHAvQZ/GAI+4+9NRRS5Sxqzftptnv17MyG+WsGVnNse2TeWa41vRs0VdJQCJqyITQdj89O3ufiNBMohauOwIgvcOMoFJZjba3Wfnm/UVd792X9YtUpas3LSDp75YzMsTl7FjTw59OjXkmhNac0iTWvEOTQQoJhG4e46ZHbqf6+4JZLj7IgAzGwX0A/InApFyacnabTz++ULe+D6TXId+3Rpx9XGtaNOgRrxDE9lLNFVDU81sNPAasC1vpLu/WcxyjYHlEcOZQK8C5jvHzI4F5gN/cvfl+Wcws8HAYIC0tLQoQhaJnzkrN/Po+IW8PyOLCikH0P+wNAYf21INwUnCiiYR1AXWEfRZnMeB4hJBQZWe+Z82ehd42d13mdlVwHP5thMs5P4k8CRAenq6nliShPT9sg08Oi6DT+asplqlFK48
tiWXH92C+jXUEJwktmjeLL50P9edCTSNGG4CZOVb97qIwaeAf+/ntkTiZlbWJu4bO49x89ZQu2pF/ty7LZccoXaApOyIZYMlk4A24ctnK4D+wIWRM5jZwe6+MhzsC6jnMykzlqzdxgMfz2f09CxqVanITX3a839HNKOa2gGSMiZmR6y7Z5vZtQSPm6YAz7j7LDMbBkx299HA9WbWF8gG1gODYhWPSEn5cfNOHv50Aa9MWk7FlAO45oRWDD62FbWqqAQgZZO5F1zlbmZ/cPeHzOwod/+6lOMqVHp6uk+ePDneYUgS2rR9D49/sZBnv15Mdo4zoGca153Ymvo1dQ9AEp+ZTXH39IKmFVUiuJTgZbDhQI9YBCZSFuzYncOz3yzm8fEL2bIrm35dG/Gn3m1pdlC1eIcmUiKKSgRzzGwJkGpmMyLGG+Du3iWmkYnE2Z6cXEZNWs7wTxewessuTmpfnxtOaUeHg2vGOzSRElVoInD3AWbWkKCOX01OS9LIzXXenZHFAx/PZ+m67RzWvA4jBvbgsOZ14x2aSEwU92bxKqCrmVUC2oaj57n7nphHJlLK3J3x89dwz4fzmLNyM+0b1uDZQYepMTgp96LpmOY44HmCTuwNaGpml7j7FzGOTaTUTF6ynns+nMfEJetJq1uVh/p348wujThAzUFLEojm8dEHgN/kdVNpZm2Bl4H9bYNIJGF8t2gdwz/L4KuMtaTWOJA7zurMBelNqVRBHcJI8ogmEVSM7KvY3eebmR6YljLrp07hP81g4pL11Kt+IH87tT0XH9GMqpX0Mpgkn2iO+slm9l/ghXB4IDAldiGJxIa789nc1Qz/LINpYafwQ8/sSP+eaVSuqD6BJXlFkwiuBq4h6JnMgC+AR2MZlEhJys11xs5axfDPMpi9crM6hRfJJ5pG53YR3Cd4IPbhiJScnFznvRlZPPJZBgtWb6VlvWrcd15X+qlTeJG9qEJUyp09Obm8NXUFj41fyOK122jboDoPD+jO6YeoU3iRgigRSLmxKzuH1yZn8tj4hazYuINOjWry+EU9+E3HhnoMVKQISgRS5u3YncPLE5fxxBcL+XHzLrqn1ebOszrrRTCRKEXzQllb4EagWeT87v6LnsREStsHP6xkyOhZrNmyi14t6vLA+d04stVBSgAi+yCaEsFrwOMEPYjlxDYckeis2bKL20bPZMwPq+jcuCYjLuxBzxZqC0hkf0STCLLd/bGYRyISBXdn9PQsho6exbZdOfy1TzsGH9OSCnoKSGS/RZMI3jWz3wNvAbvyRrr7+phFJVKAVZt2cuvbP/DJnNV0T6vNved2oXX9GvEOS6TMiyYRXBL+e2PEOAdaFregmfUh6NwmBXja3e8uZL5zCaqgDnN3dT8me3F3XpuSyR3vzWZ3di63nt6BS49qoUdBRUpINC+UtdifFZtZCjAC6A1kApPMbLS7z843Xw2Ct5a/25/tSPm2YuMObn5jBl8uWEvPFnX59zldaFFPPYOJlKRonhqqSNDMxLHhqPHAE1H0SdATyHD3ReF6RgH9gNn55rsDuAe4IfqwpbzLzXX+N3EZd42ZgwN39OvEwF7N9D6ASAxEUzX0GFCRn9sXujgcd0UxyzUGlkcMZwK9Imcws+5AU3d/z8wKTQRmNhgYDJCWlhZFyFKWLV23jZvemMGERes5unU97jr7EJrWrRrvsETKrWgSwWHu3jVi+DMzmx7FcgVduvlPE80OAB4EBhW3Ind/EngSID093YuZXcqonFznuW+WcO/YeVQ4wPj3OYdwfnpTvRMgEqOFVDYAABMLSURBVGPRJIIcM2vl7gsBzKwl0b1PkAk0jRhuAmRFDNcAOgPjwx96Q2C0mfXVDePkk7F6Kze9MYMpSzdwQrtU/nX2IRxcq0q8wxJJCtEkghuBcWa2iOAqvxlwaRTLTQLamFkLYAXQH7gwb6K7bwLq5Q2b2XjgBiWB5JKdk8tTXy7mwU/mU6ViCg9e0JWzujVWKUCk
FEXz1NCnZtYGaEeQCOaGTVMXt1y2mV0LjCV4fPQZd59lZsOAye4++lfGLmWYuzN+3hru/3geM1dspk+nhgw7qxP1a1SOd2giSafQRGBmJ7r7Z2Z2dr5JrcwMd3+zuJW7+xhgTL5xQwqZ9/go4pUyzt35ZM5qHv50AT+s2ESTOlV45MKgiWiVAkTio6gSwXHAZ8CZBUxzoNhEIJInN9f5aPYqHv406CUsrW5V7jmnC7/t0VidxIjEWaGJwN1vCz8Oc/fFkdPCen+RYuXmOh/MXMXwzxYwd9UWWtSrxv1hL2FqH0gkMURzs/gNoEe+ca8Dh5Z8OFJe5O8mslVqNf5zQTfO6HKwEoBIginqHkF7oBNQK999gpqA7uhJgbJzcnl3RhbDP8tg0Zqgm8jhA7pzmrqJFElYRZUI2gFnALXZ+z7BFuDKWAYlZc+enFzenrqCEeMyWLJuO+0b1uCxgT04pZO6iRRJdEXdI3gHeMfMjnD3b0sxJilDdmfn8tbUTEaMW8iy9dvp1KgmT1x8KL07NFACECkjorlHcJWZzXH3jQBmVge4390vi21oksh2Z+fy+pRMRozLYMXGHXRpUoshZ6RzUof6egxUpIyJJhF0yUsCAO6+IWwsTpJQdk4ub01dwUOfLiBzww66Na3Nnb/tzPFt1VG8SFkVTSI4wMzquPsGADOrG+VyUo7k5jrvzsjioU8WsGjtNjo3rskdZykBiJQH0ZzQ7we+MbPXw+HzgH/GLiRJJO7O2FmreODj+cz/cSvtGtTgiYsP5TcdGygBiJQT0bQ19LyZTQFOIGhr6Oz8vYxJ+ePujJu3mvs/ms+srM20TK3G8AFBUxC6CSxSvkRVxRM2FreG8P0BM0tz92UxjUziwt35OmMd9300j2nLN5JWt6reBBYp56LpqrIvQfVQI2A1QTPUcwheNpNy5LtF67j/4/lMXLyeRrUqc9fZh3DuoU3UFpBIORdNieAO4HDgE3fvbmYnAANiG5aUpqnLNvDAx/P5csFaUmscyO19O9G/Z1MOrJAS79BEpBREkwj2uPs6MzvAzA5w93Fm9u+YRyYxN3PFJh78eD6fzl1N3WqVuOW0Dlx0eDOqVFICEEkm0SSCjWZWHfgCeMnMVgPZsQ1LYilzw3bu+mAu789YSc3KFbjxlHZccmRzqh+op4JFklE0v/x+wA7gT8BAoBYwLJZBSWzs3JPDE58v4rHPMwC4/sTWXH5MS2pVqRjnyEQknopMBGaWArzj7icDucBz+7JyM+sDPETQVeXT7n53vulXAdcAOcBWYLAeTS157s6HM1dx5/tzWLFxB2d0OZi/ndaBxrXVObyIFJMI3D3HzLabWa2ws/mohUlkBNAbyAQmmdnofCf6/7n74+H8fYEHgD779A2kSPNWbeH2d2fxzcJ1tG9Yg1GDD+fwlgfFOywRSSDRVA3tBH4ws4+BbXkj3f36YpbrCWS4+yIAMxtFUM30UyJw980R81cj6AJTSsCm7Xt48JP5vDBhKdUPrMAd/ToxoGea3gUQkV+IJhG8H/7tq8bA8ojhTKBX/pnM7Brgz0Al4MSCVmRmg4HBAGlpafsRSvLIyXVembSc+z6ax8btuxnYqxl/7t2WOtUqxTs0EUlQRfVQlubuy9x9n+4LRK6igHG/uOJ39xHACDO7ELgVuKSAeZ4EngRIT09XqaEQk5es57bRs5iVtZmeLeoy9MxOdGxUM95hiUiCK6pE8DZhX8Vm9oa7n7OP684EmkYMNwGyiph/FPDYPm5DgFWbdnL3B3N4e1oWB9eqzPAB3Tmjy8FqFE5EolJUIog8i7Tcj3VPAtqYWQtgBdAfuHCvDZi1cfcF4eDpwAIkaruyc3j6y8WMGJdBdq5z3Ymtufr4VlStpPcBRCR6RZ0xvJDPUXH3bDO7FhhL8PjoM2HjdcOAye4+GrjWzE4G9gAbKKBaSH7J3fl0zmrueH82S9dt5zcdG3Dr6R1JO6hqvEMTkTLI3As+x5tZDsFT
QgZUAbbnTQLc3eNS+Zyenu6TJ0+Ox6YTQnZOLje8Np23p2XRun51bjuzI8e0SY13WCKS4MxsirunFzStqM7r1eBMgsnNdW564wfenpbFH09uwzUntFbLoCLyq6kyuYxwd4a+O4s3vs/kTye35Q8nt4l3SCJSTuhysoy4d+w8nv92KYOPbcn1J7WOdzgiUo4oEZQBI8Zl8Oj4hVzYK42/ndpej4WKSIlSIkhwz32zhHvHzuOsbo24s19nJQERKXFKBAnstcnLuW30LHp3bMC953VVp/EiEhNKBAlqzA8ruemNGRzTph6PXNhdTweJSMzo7JKAxs1dzR9GTaVHWh2euPhQ9R0sIjGlRJBgJixax1UvTqFdwxo8c+lhai5CRGJOiSCBTFu+kctHTiKtblWev6wXNSurC0kRiT0lggQxd9VmLnlmIgdVP5AXr+hFXfUfICKlRIkgASxas5WLnp5IlYopvHRFLxrUrBzvkEQkiSgRxNmKjTu46OnvcHdevKIXTeuqBVERKV1KBHG0estOBj41ga27snn+8p60rl893iGJSBLSIylxsnH7bi5+eiKrt+zixSt60alRrXiHJCJJSokgDrbs3MMlz0xk8bptjBx0GD3S6sQ7JBFJYqoaKmU7dudw+XOTmZW1mccG9uDI1vXiHZKIJLmYJgIz62Nm88wsw8xuLmD6n81stpnNMLNPzaxZLOOJt93Zufz+pSlMWrKeBy/oxkkdGsQ7JBGR2CUCM0sBRgCnAh2BAWbWMd9sU4F0d+8CvA7cE6t44i0n1/nLa9MZN28Nd/32EM7s2ijeIYmIALEtEfQEMtx9kbvvBkYB/SJncPdx7p7XF/IEoEkM44kbd2fIOzN5d3oWfz+tPf17psU7JBGRn8QyETQGlkcMZ4bjCnM58EFBE8xssJlNNrPJa9asKcEQS8e9Y+fx0nfL+P3xrRh8bKt4hyMispdYJoKCGs/3Amc0uwhIB+4taLq7P+nu6e6enpqaWoIhxt4Tny/k0fELGdgrjRtPaRfvcEREfiGWj49mAk0jhpsAWflnMrOTgVuA49x9VwzjKXWjJi7jrg/mcmbXRgxT72IikqBiWSKYBLQxsxZmVgnoD4yOnMHMugNPAH3dfXUMYyl1789Yyd/e+oHj26Vy/3ldSVHvYiKSoGKWCNw9G7gWGAvMAV5191lmNszM+oaz3QtUB14zs2lmNrqQ1ZUpX8xfwx9fmcqhaXV4bOChVKqg1zVEJHHF9M1idx8DjMk3bkjE55Njuf14mLJ0Pb97YQpt6tfgv4MOo0ol9S4mIolNl6olaM7KzVz67CQa1qrMc5f1pFYVdSwjIolPiaCELFm7jYv/O5GqlSrwwuU9Sa1xYLxDEhGJihJBCVi1aScX/fc7ct158YqeNKmjPgVEpOxQIviVNmzbzcX//Y6N2/fw3KU9aV2/RrxDEhHZJ2qG+lfYuiubQc9OZOn67Tx3aU8OaaI+BUSk7FGJYD/t3JPDlc9NZmbWZh69sAdHtDoo3iGJiOwXJYL9kJ2Ty3UvT+XbReu477wunNxRzUmLSNmlRLCPcnOdv74xg49n/8jQMzvy2+7lssFUEUkiSgT7wN0Z9t5s3vx+BX86uS2DjmoR75BERH41JYIouTt3fziXkd8s4bKjWnD9Sa3jHZKISInQU0NR2J2dy81vzODNqSsY2CuNW0/voJZERaTcUCIoxtZd2Vz94hS+XLCWv/Ruy7UntlYSEJFyRYmgCKu37OTSZycxd9UW7jm3C+enNy1+IRGRMkaJoBAL12zlkmcmsn7bbp6+JJ0T2tWPd0giIjGhRFCAKUs3cPlzk6hwgDFq8OF0aVI73iGJiMSMEkE+H8/+kWv/9z0Hh01JNzuoWrxDEhGJKSWCCC99t5R/vD2TQxrX4r+DDqNedTUlLSLlX0zfIzCzPmY2z8wyzOzmAqYfa2bfm1m2mZ0by1iK4u7c/9E8bnlrJse1TeXlwYcrCYhI0ohZicDMUoAR
QG8gE5hkZqPdfXbEbMuAQcANsYqjOHtycvn7mz/w2pRMLkhvyj9/25kKKXrPTkSSRyyrhnoCGe6+CMDMRgH9gJ8SgbsvCaflxjCOQm3blc3vX/qez+ev4Q8nteGPJ7fROwIiknRimQgaA8sjhjOBXjHc3j5Zu3UXl42cxMwVm7jr7EMY0DMt3iGJiMRFLBNBQZfWvl8rMhsMDAZIS/v1J+wla7dxybMT+XHzTp76v3RO6qBmpEUkecWyMjwTiHwVtwmQtT8rcvcn3T3d3dNTU1N/VVDTlm/k7Me+YcvObF6+8nAlARFJerFMBJOANmbWwswqAf2B0THcXrE+m/sjA56cQLUDU3j9qiPonlYnnuGIiCSEmCUCd88GrgXGAnOAV919lpkNM7O+AGZ2mJllAucBT5jZrFjF88aUTK58fgqt61fnzauPomVq9VhtSkSkTInpC2XuPgYYk2/ckIjPkwiqjGKu2UFVOal9fR68oBvVDtR7dCIieZLmjJjevC7pzevGOwwRkYSjN6dERJKcEoGISJJTIhARSXJKBCIiSU6JQEQkySkRiIgkOSUCEZEkp0QgIpLkzH2/GgSNGzNbAyyNdxwJoB6wNt5BJBDtj59pX+xN+yPQzN0LbLWzzCUCCZjZZHdPj3cciUL742faF3vT/iieqoZERJKcEoGISJJTIii7nox3AAlG++Nn2hd70/4ohu4RiIgkOZUIRESSnBKBiEiSUyJIcGbWx8zmmVmGmd1cwPQ/m9lsM5thZp+aWbN4xFkaitsXEfOda2ZuZuX6kcFo9oeZnR8eH7PM7H+lHWNpiuK3kmZm48xsavh7OS0ecSYkd9dfgv4BKcBCoCVQCZgOdMw3zwlA1fDz1cAr8Y47XvsinK8G8AUwAUiPd9xxPjbaAFOBOuFw/XjHHef98SRwdfi5I7Ak3nEnyp9KBImtJ5Dh7ovcfTcwCugXOYO7j3P37eHgBEqpD+g4KHZfhO4A7gF2lmZwcRDN/rgSGOHuGwDcfXUpx1iaotkfDtQMP9cCskoxvoSmRJDYGgPLI4Yzw3GFuRz4IKYRxU+x+8LMugNN3f290gwsTqI5NtoCbc3sazObYGZ9Si260hfN/hgKXGRmmcAY4LrSCS3xJU3n9WWUFTCuwOd9zewiIB04LqYRxU+R+8LMDgAeBAaVVkBxFs2xUYGgeuh4gpLil2bW2d03xji2eIhmfwwARrr7/WZ2BPBCuD9yYx9eYlOJILFlAk0jhptQQHHWzE4GbgH6uvuuUoqttBW3L2oAnYHxZrYEOBwYXY5vGEdzbGQC77j7HndfDMwjSAzlUTT743LgVQB3/xaoTNAgXdJTIkhsk4A2ZtbCzCoB/YHRkTOE1SFPECSB8lwHXOS+cPdN7l7P3Zu7e3OC+yV93X1yfMKNuWKPDeBtgocJMLN6BFVFi0o1ytITzf5YBpwEYGYdCBLBmlKNMkEpESQwd88GrgXGAnOAV919lpkNM7O+4Wz3AtWB18xsmpnlP/jLhSj3RdKIcn+MBdaZ2WxgHHCju6+LT8SxFeX++AtwpZlNB14GBnn4CFGyUxMTIiJJTiUCEZEkp0QgIpLklAhERJKcEoGISJJTIhARSXJKBJI0zOyg8BHbaWa2ysxWRAxXinIdz5pZuxjGmGlmtc0sxcy+DMe1NLP+sdqmiB4flaRkZkOBre5+X77xRvC7iEuzA2E7OHs1AxG+OX6tu58Vj5ik/FOJQJKembU2s5lm9jjwPXCwmT1pZpPDdvyHRMz7lZl1M7MKZrbRzO42s+lm9q2Z1S9g3SeG06eZ2fdmVs3MTg7bxX877CtgRJiAIperYGZ5yeBu4IRwHdfHcl9IclIiEAl0BP7r7t3dfQVws7unA12B3mbWsYBlagGfu3tX4FvgsgLmuREY7O7dgGP5uXnsXsAfgUOADhTcpHaem4Fx7t7N3R/ej+8mUiQlApHAQnefFDE8wMy+JyghdCBIFPntcPe8Zr+nAM0LmOdr
4D9mdh1Q091zwvET3H1JODwKOLokvoTI/lAiEAlsy/tgZm2APwAnunsX4EOCBsry2x3xOYcCmnV39zuB3xG0BzUpXDf8solk3ayTuFEiEPmlmsAWYLOZHQycsr8rMrNW7j7D3e8i6DYy74mjw8M+dFOA84GviljNFoJmtkViQolA5Je+B2YDM4GnCKp39tcN4Y3oGcBG4KNw/DfA/cAPwHx+2WRypKlASnjTWTeLpcTp8VGRUqbHQSXRqEQgIpLkVCIQEUlyKhGIiCQ5JQIRkSSnRCAikuSUCEREkpwSgYhIkvt/Nue33v+rKiEAAAAASUVORK5CYII=\n", 210 | "text/plain": [ 211 | "
" 212 | ] 213 | }, 214 | "metadata": { 215 | "needs_background": "light" 216 | }, 217 | "output_type": "display_data" 218 | } 219 | ], 220 | "source": [ 221 | "plt.xlabel('Train split')\n", 222 | "plt.ylabel('Fraction of repeats')\n", 223 | "plt.title('Repeated edges from train in test split vs Split')\n", 224 | "plt.plot(splits, repeats)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "# Graph statistics: degree distribution and related properties" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 3, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "name": "stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "--------------------------------\n", 244 | "Reading dataset from /Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contact/ia-contact.edges\n", 245 | "Finished reading data.\n", 246 | "Setting up graph.\n", 247 | "Finished setting up graph.\n", 248 | "Setting up examples.\n", 249 | "Finished setting up examples.\n", 250 | "Dataset properties:\n", 251 | "Mode: train\n", 252 | "Number of vertices: 274\n", 253 | "Number of static edges: 1686\n", 254 | "Number of temporal edges: 8473\n", 255 | "Number of examples/datapoints: 11298\n", 256 | "--------------------------------\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "duplicate_examples = True\n", 262 | "repeat_examples = True\n", 263 | "\n", 264 | "source = 'networkrepository' # SNAP or networkrepository\n", 265 | "\n", 266 | "if source == 'SNAP':\n", 267 | " data_dir = os.path.expanduser('~/Documents/Datasets/temporal-networks-snap')\n", 268 | " name = 'CollegeMsg.txt'\n", 269 | " dataset_name = 'CollegeMsg'\n", 270 | " path = os.path.join(data_dir, name)\n", 271 | "else:\n", 272 | " data_dir = os.path.expanduser('~/Documents/Datasets/temporal-networks-network-repository')\n", 273 | " name = 'ia-contact'\n", 274 | " dataset_name = 'IAContact'\n", 275 | " path = os.path.join(data_dir, 
name, name+'.edges')\n", 276 | " \n", 277 | "class_attr = getattr(importlib.import_module('datasets.link_prediction'), dataset_name)\n", 278 | "dataset = class_attr(path, duplicate_examples=duplicate_examples, repeat_examples=repeat_examples)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 7, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "15\n", 291 | "1\n", 292 | "4\n", 293 | "3\n", 294 | "8\n", 295 | "5\n", 296 | "7\n", 297 | "4\n", 298 | "2\n", 299 | "3\n", 300 | "2\n", 301 | "3\n", 302 | "2\n", 303 | "3\n", 304 | "2\n", 305 | "5\n", 306 | "4\n", 307 | "1\n", 308 | "6\n", 309 | "4\n", 310 | "4\n", 311 | "3\n", 312 | "1\n", 313 | "1\n", 314 | "6\n", 315 | "4\n", 316 | "1\n", 317 | "1\n", 318 | "2\n", 319 | "1\n", 320 | "2\n", 321 | "2\n", 322 | "1\n", 323 | "2\n", 324 | "2\n", 325 | "8\n", 326 | "4\n", 327 | "6\n", 328 | "3\n", 329 | "6\n", 330 | "15\n", 331 | "7\n", 332 | "3\n", 333 | "4\n", 334 | "4\n", 335 | "1\n", 336 | "3\n", 337 | "5\n", 338 | "4\n", 339 | "5\n", 340 | "7\n", 341 | "13\n", 342 | "6\n", 343 | "2\n", 344 | "4\n", 345 | "4\n", 346 | "3\n", 347 | "11\n", 348 | "9\n", 349 | "4\n", 350 | "5\n", 351 | "5\n", 352 | "3\n", 353 | "4\n", 354 | "3\n", 355 | "2\n", 356 | "1\n", 357 | "1\n", 358 | "3\n", 359 | "9\n", 360 | "9\n", 361 | "5\n", 362 | "5\n", 363 | "1\n", 364 | "2\n", 365 | "10\n", 366 | "2\n", 367 | "7\n", 368 | "3\n", 369 | "1\n", 370 | "3\n", 371 | "5\n", 372 | "3\n", 373 | "2\n", 374 | "4\n", 375 | "3\n", 376 | "2\n", 377 | "3\n", 378 | "3\n", 379 | "4\n", 380 | "3\n", 381 | "2\n", 382 | "3\n", 383 | "2\n", 384 | "3\n", 385 | "2\n", 386 | "3\n", 387 | "7\n", 388 | "7\n", 389 | "2\n", 390 | "2\n", 391 | "6\n", 392 | "1\n", 393 | "2\n", 394 | "1\n", 395 | "1\n", 396 | "1\n", 397 | "1\n", 398 | "2\n", 399 | "1\n", 400 | "1\n", 401 | "1\n", 402 | "6\n", 403 | "1\n", 404 | "2\n", 405 | "2\n", 406 | "3\n", 407 | "1\n", 408 | "1\n", 409 
| "6\n", [... lengthy stdout stream of per-example degree values, one small integer per line, omitted ...]
1582 | "1\n", 1583 | "7\n", 1584 | "4\n", 1585 | "9\n", 1586 | "4\n", 1587 | "10\n", 1588 | "8\n", 1589 | "6\n", 1590 | "2\n", 1591 | "9\n", 1592 | "2\n", 1593 | "8\n", 1594 | "1\n", 1595 | "7\n", 1596 | "11\n", 1597 | "2\n", 1598 | "5\n", 1599 | "17\n", 1600 | "11\n", 1601 | "10\n", 1602 | "3\n", 1603 | "1\n", 1604 | "13\n", 1605 | "6\n", 1606 | "2\n", 1607 | "2\n", 1608 | "4\n", 1609 | "3\n", 1610 | "4\n", 1611 | "4\n", 1612 | "5\n", 1613 | "4\n", 1614 | "5\n", 1615 | "43\n", 1616 | "2\n", 1617 | "2\n", 1618 | "1\n", 1619 | "1\n", 1620 | "1\n", 1621 | "3\n", 1622 | "4\n", 1623 | "1\n", 1624 | "1\n", 1625 | "2\n", 1626 | "1\n", 1627 | "2\n", 1628 | "1\n", 1629 | "1\n", 1630 | "1\n", 1631 | "3\n", 1632 | "3\n", 1633 | "7\n", 1634 | "6\n", 1635 | "4\n", 1636 | "7\n", 1637 | "6\n", 1638 | "8\n", 1639 | "6\n", 1640 | "3\n", 1641 | "2\n", 1642 | "2\n", 1643 | "2\n", 1644 | "4\n", 1645 | "4\n", 1646 | "4\n", 1647 | "2\n", 1648 | "6\n", 1649 | "5\n", 1650 | "4\n", 1651 | "4\n", 1652 | "4\n", 1653 | "4\n", 1654 | "1\n", 1655 | "6\n", 1656 | "4\n", 1657 | "2\n", 1658 | "3\n", 1659 | "4\n", 1660 | "3\n", 1661 | "2\n", 1662 | "5\n", 1663 | "3\n", 1664 | "3\n", 1665 | "1\n", 1666 | "1\n", 1667 | "5\n", 1668 | "1\n", 1669 | "1\n", 1670 | "2\n", 1671 | "1\n", 1672 | "1\n", 1673 | "1\n", 1674 | "13\n", 1675 | "5\n", 1676 | "8\n", 1677 | "4\n", 1678 | "5\n", 1679 | "7\n", 1680 | "14\n", 1681 | "5\n", 1682 | "4\n", 1683 | "6\n", 1684 | "4\n", 1685 | "6\n", 1686 | "3\n", 1687 | "7\n", 1688 | "5\n", 1689 | "6\n", 1690 | "3\n", 1691 | "4\n", 1692 | "10\n", 1693 | "3\n", 1694 | "6\n", 1695 | "5\n", 1696 | "13\n", 1697 | "5\n", 1698 | "4\n", 1699 | "1\n", 1700 | "9\n", 1701 | "4\n", 1702 | "3\n", 1703 | "6\n", 1704 | "1\n", 1705 | "2\n", 1706 | "1\n", 1707 | "1\n", 1708 | "5\n", 1709 | "2\n", 1710 | "1\n", 1711 | "1\n", 1712 | "4\n", 1713 | "8\n", 1714 | "9\n", 1715 | "2\n", 1716 | "3\n", 1717 | "1\n", 1718 | "4\n", 1719 | "3\n", 1720 | "3\n", 1721 | "2\n", 1722 | "4\n", 1723 | "2\n", 
1724 | "3\n", 1725 | "2\n", 1726 | "1\n", 1727 | "3\n", 1728 | "2\n", 1729 | "1\n", 1730 | "2\n", 1731 | "12\n", 1732 | "3\n", 1733 | "3\n", 1734 | "9\n", 1735 | "2\n", 1736 | "3\n", 1737 | "3\n", 1738 | "2\n", 1739 | "2\n", 1740 | "3\n", 1741 | "2\n", 1742 | "3\n", 1743 | "2\n", 1744 | "1\n", 1745 | "1\n", 1746 | "1\n", 1747 | "1\n", 1748 | "1\n", 1749 | "7\n", 1750 | "1\n", 1751 | "5\n", 1752 | "15\n", 1753 | "1\n", 1754 | "3\n", 1755 | "2\n", 1756 | "4\n", 1757 | "3\n", 1758 | "6\n", 1759 | "3\n", 1760 | "2\n", 1761 | "1\n", 1762 | "3\n", 1763 | "3\n", 1764 | "1\n", 1765 | "6\n", 1766 | "3\n", 1767 | "2\n", 1768 | "3\n", 1769 | "22\n", 1770 | "2\n", 1771 | "2\n", 1772 | "2\n", 1773 | "12\n", 1774 | "14\n", 1775 | "2\n", 1776 | "1\n", 1777 | "4\n", 1778 | "1\n", 1779 | "1\n", 1780 | "3\n", 1781 | "1\n", 1782 | "1\n", 1783 | "7\n", 1784 | "7\n", 1785 | "1\n", 1786 | "4\n", 1787 | "5\n", 1788 | "1\n", 1789 | "2\n", 1790 | "5\n", 1791 | "3\n", 1792 | "1\n", 1793 | "1\n", 1794 | "1\n", 1795 | "1\n", 1796 | "1\n", 1797 | "1\n", 1798 | "1\n", 1799 | "9\n", 1800 | "1\n", 1801 | "1\n", 1802 | "2\n", 1803 | "1\n", 1804 | "1\n", 1805 | "1\n", 1806 | "1\n", 1807 | "1\n", 1808 | "1\n", 1809 | "15\n", 1810 | "11\n", 1811 | "2\n", 1812 | "9\n", 1813 | "10\n", 1814 | "10\n", 1815 | "3\n", 1816 | "6\n", 1817 | "3\n", 1818 | "12\n", 1819 | "12\n", 1820 | "5\n", 1821 | "4\n", 1822 | "11\n", 1823 | "5\n", 1824 | "9\n", 1825 | "5\n", 1826 | "8\n", 1827 | "9\n", 1828 | "13\n", 1829 | "10\n", 1830 | "7\n", 1831 | "5\n", 1832 | "3\n", 1833 | "10\n", 1834 | "8\n", 1835 | "10\n", 1836 | "2\n", 1837 | "1\n", 1838 | "3\n", 1839 | "1\n", 1840 | "9\n", 1841 | "8\n", 1842 | "3\n", 1843 | "4\n", 1844 | "6\n", 1845 | "1\n", 1846 | "3\n", 1847 | "1\n", 1848 | "1\n", 1849 | "2\n", 1850 | "9\n", 1851 | "3\n", 1852 | "9\n", 1853 | "1\n", 1854 | "4\n", 1855 | "2\n", 1856 | "1\n", 1857 | "2\n", 1858 | "4\n", 1859 | "1\n", 1860 | "4\n", 1861 | "1\n", 1862 | "1\n", 1863 | "2\n", 1864 | "3\n", 1865 | 
"1\n", 1866 | "3\n", 1867 | "2\n", 1868 | "2\n", 1869 | "4\n", 1870 | "3\n", 1871 | "2\n", 1872 | "1\n", 1873 | "2\n", 1874 | "1\n", 1875 | "1\n", 1876 | "1\n", 1877 | "1\n", 1878 | "2\n", 1879 | "1\n", 1880 | "2\n", 1881 | "1\n", 1882 | "1\n", 1883 | "1\n", 1884 | "1\n", 1885 | "1\n", 1886 | "1\n", 1887 | "2\n", 1888 | "1\n", 1889 | "1\n", 1890 | "1\n", 1891 | "1\n", 1892 | "4\n", 1893 | "1\n", 1894 | "12\n", 1895 | "3\n", 1896 | "14\n", 1897 | "9\n", 1898 | "9\n", 1899 | "5\n", 1900 | "2\n", 1901 | "16\n", 1902 | "7\n", 1903 | "4\n", 1904 | "12\n", 1905 | "4\n", 1906 | "8\n", 1907 | "8\n", 1908 | "4\n", 1909 | "8\n", 1910 | "3\n", 1911 | "3\n", 1912 | "6\n", 1913 | "8\n", 1914 | "3\n", 1915 | "12\n", 1916 | "5\n", 1917 | "3\n", 1918 | "13\n", 1919 | "4\n", 1920 | "1\n", 1921 | "1\n", 1922 | "2\n", 1923 | "4\n", 1924 | "1\n", 1925 | "2\n", 1926 | "1\n", 1927 | "1\n", 1928 | "7\n", 1929 | "3\n", 1930 | "1\n", 1931 | "1\n", 1932 | "1\n", 1933 | "2\n", 1934 | "1\n", 1935 | "2\n", 1936 | "1\n", 1937 | "5\n", 1938 | "9\n", 1939 | "19\n", 1940 | "2\n", 1941 | "8\n", 1942 | "3\n", 1943 | "8\n", 1944 | "5\n", 1945 | "4\n", 1946 | "10\n", 1947 | "10\n", 1948 | "9\n", 1949 | "9\n", 1950 | "4\n", 1951 | "4\n", 1952 | "4\n", 1953 | "10\n", 1954 | "14\n", 1955 | "13\n", 1956 | "5\n", 1957 | "5\n", 1958 | "1\n", 1959 | "4\n", 1960 | "5\n", 1961 | "10\n", 1962 | "7\n", 1963 | "8\n", 1964 | "5\n", 1965 | "4\n", 1966 | "6\n", 1967 | "2\n", 1968 | "5\n", 1969 | "1\n", 1970 | "1\n", 1971 | "3\n", 1972 | "2\n", 1973 | "5\n", 1974 | "2\n", 1975 | "1\n" 1976 | ] 1977 | } 1978 | ], 1979 | "source": [ 1980 | "timestamps = dataset.timestamps\n", 1981 | "for u in timestamps.keys():\n", 1982 | " for v in timestamps[u].keys():\n", 1983 | " print(len(timestamps[u][v]))" 1984 | ] 1985 | }, 1986 | { 1987 | "cell_type": "code", 1988 | "execution_count": null, 1989 | "metadata": {}, 1990 | "outputs": [], 1991 | "source": [] 1992 | } 1993 | ], 1994 | "metadata": { 1995 | "kernelspec": { 1996 | 
"display_name": "Python 3", 1997 | "language": "python", 1998 | "name": "python3" 1999 | }, 2000 | "language_info": { 2001 | "codemirror_mode": { 2002 | "name": "ipython", 2003 | "version": 3 2004 | }, 2005 | "file_extension": ".py", 2006 | "mimetype": "text/x-python", 2007 | "name": "python", 2008 | "nbconvert_exporter": "python", 2009 | "pygments_lexer": "ipython3", 2010 | "version": "3.6.8" 2011 | } 2012 | }, 2013 | "nbformat": 4, 2014 | "nbformat_minor": 2 2015 | } 2016 | -------------------------------------------------------------------------------- /notebooks/dataset_testing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 16, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "The autoreload extension is already loaded. To reload it, use:\n", 20 | " %reload_ext autoreload\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "from math import ceil\n", 26 | "import json\n", 27 | "import os\n", 28 | "import sys\n", 29 | "\n", 30 | "src_path = os.path.join(os.path.dirname(os.path.abspath('')), 'src')\n", 31 | "sys.path.append(src_path)\n", 32 | "\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import numpy as np\n", 35 | "from torch.utils.data import DataLoader\n", 36 | "\n", 37 | "from datasets import link_prediction\n", 38 | "\n", 39 | "%load_ext autoreload\n", 40 | "%autoreload 2\n", 41 | "\n", 42 | "%matplotlib inline" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Set up arguments for datasets, models and training." 
50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 17, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "args = {\n", 59 | " \"task\" : \"link_prediction\",\n", 60 | " \n", 61 | " \"dataset\" : \"IAContact\",\n", 62 | " \"dataset_path\" : \"/Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contact/ia-contact.edges\",\n", 63 | " \"mode\" : \"train\",\n", 64 | " \"generate_neg_examples\" : False,\n", 65 | " \n", 66 | " \"duplicate_examples\" : True,\n", 67 | " \"repeat_examples\" : True,\n", 68 | " \n", 69 | " \"self_loop\" : False,\n", 70 | " \"normalize_adj\" : False,\n", 71 | " \n", 72 | " \"cuda\" : \"True\",\n", 73 | " \"model\" : \"GraphSAGE\",\n", 74 | " \"agg_class\" : \"MaxPoolAggregator\",\n", 75 | " \"hidden_dims\" : [64],\n", 76 | " \"dropout\" : 0.5,\n", 77 | " \"num_samples\" : -1,\n", 78 | " \n", 79 | " \"epochs\" : 3,\n", 80 | " \"batch_size\" : 32,\n", 81 | " \"lr\" : 1e-4,\n", 82 | " \"weight_decay\" : 5e-4,\n", 83 | " \"stats_per_batch\" : 3,\n", 84 | " \"visdom\" : False,\n", 85 | " \n", 86 | " \"load\" : False,\n", 87 | " \"save\" : False\n", 88 | "}\n", 89 | "config = args\n", 90 | "config['num_layers'] = len(config['hidden_dims']) + 1\n", 91 | "\n", 92 | "\n", 93 | "if config['cuda'] and torch.cuda.is_available():\n", 94 | " device = 'cuda:0'\n", 95 | "else:\n", 96 | " device = 'cpu'\n", 97 | "config['device'] = device" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "# Get the dataset, dataloader and model." 
105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 18, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "--------------------------------\n", 117 | "Reading dataset from /Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contact/ia-contact.edges\n", 118 | "Finished reading data.\n", 119 | "Setting up graph.\n", 120 | "Finished setting up graph.\n", 121 | "Setting up examples.\n", 122 | "Finished setting up examples.\n", 123 | "Dataset properties:\n", 124 | "Mode: train\n", 125 | "Number of vertices: 274\n", 126 | "Number of static edges: 1686\n", 127 | "Number of temporal edges: 8473\n", 128 | "Number of examples/datapoints: 11298\n", 129 | "--------------------------------\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "dataset_args = (config['task'], config['dataset'], config['dataset_path'],\n", 135 | " config['generate_neg_examples'], 'train',\n", 136 | " config['duplicate_examples'], config['repeat_examples'],\n", 137 | " config['num_layers'], config['self_loop'],\n", 138 | " config['normalize_adj'])\n", 139 | "dataset = utils.get_dataset(dataset_args)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 19, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "loader = DataLoader(dataset=dataset, batch_size=config['batch_size'],\n", 149 | " shuffle=True, collate_fn=dataset.collate_wrapper)\n", 150 | "input_dim, output_dim = dataset.get_dims()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "# Stuff" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 20, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "node = 0\n", 167 | "nbrs = dataset.nbrs_s[node]\n", 168 | "timestamps = dataset.timestamps[node]" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 22, 174 | "metadata": 
{}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "Last 2 times: 1.2308601250553888e-05, 1.2238703676506584e-05\n", 181 | "Last 2 times: 1.2327872085999236e-05, 1.2285766939001168e-05\n", 182 | "Last 2 times: 4.5479352374022196e-05, 1.3297872340425532e-05\n", 183 | "Last 2 times: 3.256904637832204e-05, 3.155868337172973e-05\n", 184 | "Last 2 times: 1.2606842994377348e-05, 1.2584156546907443e-05\n", 185 | "Last 2 times: 1.2461214469962243e-05, 1.2299366582620994e-05\n", 186 | "Last 2 times: 3.023705853894533e-05, 2.9787614309969916e-05\n", 187 | "Last 2 times: 4.376176097326157e-05, 1.3499831252109349e-05\n", 188 | "Last 2 times: 1.2889753934597388e-05, 1.2830217728794857e-05\n", 189 | "Last 2 times: 1.2326808342783887e-05, 1.2297097884899164e-05\n", 190 | "Last 2 times: 1.3010838028077389e-05, 1.2921232168699607e-05\n", 191 | "Last 2 times: 4.6448975800083605e-05, 4.540501271340356e-05\n", 192 | "Last 2 times: 1.2272197336933178e-05, 1.222135314822057e-05\n", 193 | "Last 2 times: 1.2272197336933178e-05, 1.2228377172065495e-05\n", 194 | "Last 2 times: 4.394831677946735e-05, 4.352367688022284e-05\n", 195 | "Last 2 times: 1.2606842994377348e-05, 1.229150534065907e-05\n", 196 | "Last 2 times: 1.229150534065907e-05, 1.222135314822057e-05\n", 197 | "Last 2 times: 4.2865103519225e-05, 1.2255502720721604e-05\n", 198 | "Last 2 times: 1.2329848096271454e-05, 1.2294830023974919e-05\n", 199 | "Last 2 times: 1.2710194847287009e-05, 1.2519248344329406e-05\n", 200 | "Last 2 times: 4.303111149360988e-05, 4.299965600275198e-05\n", 201 | "Last 2 times: 3.737200089692802e-05, 1.25208158563612e-05\n", 202 | "Last 2 times: 1.2437810945273631e-05, 1.2347813202281876e-05\n", 203 | "Last 2 times: 1.3998152243903804e-05, 1.2570552224359216e-05\n", 204 | "Last time: 4.3956043956043955e-05\n", 205 | "Last 2 times: 1.2300123001230013e-05, 1.2227031521287262e-05\n", 206 | "Last time: 3.914660403210022e-05\n", 207 | "Last time: 
3.8229222417616025e-05\n", 208 | "Last 2 times: 1.2786088735455825e-05, 1.2746322685905116e-05\n", 209 | "Last 2 times: 1.2743723716069835e-05, 1.2665281929175743e-05\n", 210 | "Last 2 times: 1.229150534065907e-05, 1.2281994595922379e-05\n", 211 | "Last 2 times: 1.2367207113617532e-05, 1.2328936012822093e-05\n", 212 | "Last 2 times: 1.355785135171778e-05, 1.3389031705227079e-05\n", 213 | "Last 2 times: 1.3091575571119984e-05, 1.2238703676506584e-05\n", 214 | "Last 2 times: 1.3343474373857465e-05, 1.2991230919129587e-05\n", 215 | "Last 2 times: 1.2386508614816741e-05, 1.2329848096271454e-05\n", 216 | "Last 2 times: 1.2442453651860147e-05, 1.234796567265543e-05\n", 217 | "Last 2 times: 1.273268990806998e-05, 1.2726694241170856e-05\n", 218 | "Last 2 times: 1.2554454948338417e-05, 1.2478474631261075e-05\n", 219 | "Last 2 times: 1.2328936012822093e-05, 1.2272197336933178e-05\n", 220 | "Last 2 times: 1.2952528981283595e-05, 1.2732365673542144e-05\n", 221 | "Last 2 times: 1.2329848096271454e-05, 1.230981338322911e-05\n", 222 | "Last 2 times: 1.2422514565398328e-05, 1.2329848096271454e-05\n", 223 | "Last 2 times: 1.2354830738818879e-05, 1.2326808342783887e-05\n", 224 | "Last time: 1.2554454948338417e-05\n", 225 | "Last 2 times: 1.2503907471084714e-05, 1.229150534065907e-05\n", 226 | "Last 2 times: 4.3521782652217434e-05, 3.7213456385829117e-05\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "for nbr in nbrs:\n", 232 | " times = timestamps[nbr]\n", 233 | " if len(times) == 1:\n", 234 | " print('Last time: {}'.format(times[-1]))\n", 235 | " else:\n", 236 | " print('Last 2 times: {}, {}'.format(times[-2], times[-1]))" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | 
"name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.6.8" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 2 268 | } 269 | -------------------------------------------------------------------------------- /notebooks/simple_baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import sys\n", 18 | "\n", 19 | "src_path = os.path.join(os.path.dirname(os.path.abspath('')), 'src')\n", 20 | "sys.path.append(src_path)\n", 21 | "\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np\n", 24 | "from sklearn.metrics import classification_report, roc_curve, roc_auc_score\n", 25 | "\n", 26 | "from datasets import link_prediction\n", 27 | "import utils\n", 28 | "\n", 29 | "%load_ext autoreload\n", 30 | "%autoreload 2\n", 31 | "\n", 32 | "%matplotlib inline" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "# Set up arguments for datasets." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 5, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "args = {\n", 49 | " \"task\" : \"link_prediction\",\n", 50 | " \n", 51 | " \"dataset\" : \"IAContact\",\n", 52 | " \"dataset_path\" : \"/Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contact/ia-contact.edges\",\n", 53 | " \"mode\" : \"train\",\n", 54 | " \"generate_neg_examples\" : False,\n", 55 | " \n", 56 | " \"duplicate_examples\" : False,\n", 57 | " \"repeat_examples\" : False,\n", 58 | " \n", 59 | " \"self_loop\" : False,\n", 60 | " \"normalize_adj\" : False,\n", 61 | " \n", 62 | " \"num_layers\" : 1\n", 63 | "}\n", 64 | "config = args" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# Get the dataset." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 6, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "--------------------------------\n", 84 | "Reading dataset from /Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contact/ia-contact.edges\n", 85 | "Finished reading data.\n", 86 | "Setting up graph.\n", 87 | "Finished setting up graph.\n", 88 | "Setting up examples.\n", 89 | "Finished setting up examples.\n", 90 | "Dataset properties:\n", 91 | "Mode: test\n", 92 | "Number of vertices: 274\n", 93 | "Number of static edges: 2472\n", 94 | "Number of temporal edges: 21183\n", 95 | "Number of examples/datapoints: 1348\n", 96 | "--------------------------------\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "dataset_args = (config['task'], config['dataset'], config['dataset_path'],\n", 102 | " config['generate_neg_examples'], 'test',\n", 103 | " config['duplicate_examples'], config['repeat_examples'],\n", 104 | " config['num_layers'], config['self_loop'],\n", 105 | " config['normalize_adj'])\n", 106 | "test_dataset = 
utils.get_dataset(dataset_args)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "# Evaluate on test set." 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "ROC-AUC score: 0.6061\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "y_true = test_dataset.y.numpy()\n", 131 | "y_scores = []\n", 132 | "y_pred = []\n", 133 | "for (x, y) in test_dataset:\n", 134 | "    # Exact (u, v) row match; numpy's 'in' only tests elementwise membership.\n", "    if (test_dataset.edges_s == x.numpy()).all(axis=1).any():\n", 135 | "        y_scores.append(1.0)\n", 136 | "    else:\n", 137 | "        y_scores.append(np.random.rand())\n", 138 | "# y_pred = np.array(y_pred)\n", 139 | "# report = classification_report(y_true, y_pred)\n", 140 | "# print('Classification report\\n', report)\n", 141 | "y_scores = np.array(y_scores)\n", 142 | "score = roc_auc_score(y_true, y_scores)\n", 143 | "print('ROC-AUC score: {:.4f}'.format(score))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.6.8" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /results.md: -------------------------------------------------------------------------------- 1 | | Dataset | Model | Temporal | Duplicates | Repeats | ROC-AUC Score | CTDNE | 2 | | ---------------------------- | :--------: | :------:
| :--------: | :-----: | :-----------: | ----: | 3 | | IAContact | GraphSAGE | No | True | True | 0.9416 | 0.913 | 4 | | IAContact | GraphSAGE | No | True | False | 0.8238 | NA | 5 | | IAContact | GraphSAGE | No | False | False | 0.7796 | NA | 6 | | IAContact | GraphSAGE | No | False | True | 0.8852 | NA | 7 | | IAContactsHypertext09 | GraphSAGE | No | True | True | 0.7069 | 0.671 | 8 | | IAContactsHypertext09 | GraphSAGE | No | True | False | 0.6847 | NA | 9 | | IAContactsHypertext09 | GraphSAGE | No | False | False | 0.5994 | NA | 10 | | IAContactsHypertext09 | GraphSAGE | No | False | True | 0.5861 | NA | 11 | | IAEnronEmployees | GraphSAGE | No | True | True | 0.7847 | 0.777 | 12 | | IAEnronEmployees | GraphSAGE | No | True | False | 0.7274 | NA | 13 | | IAEnronEmployees | GraphSAGE | No | False | False | 0.7000 | NA | 14 | | IAEnronEmployees | GraphSAGE | No | False | True | 0.7338 | NA | 15 | | IARadoslaw | GraphSAGE | No | True | True | 0.9312 | 0.811 | 16 | | IARadoslaw | GraphSAGE | No | True | False | 0.9162 | NA | 17 | | IARadoslaw | GraphSAGE | No | False | False | 0.8805 | NA | 18 | | IARadoslaw | GraphSAGE | No | False | True | 0.8945 | NA | 19 | | IAEmailEU | GraphSAGE | No | True | True | NA | 0.890 | 20 | | IAEmailEU | GraphSAGE | No | True | False | NA | NA | 21 | | IAEmailEU | GraphSAGE | No | False | False | NA | NA | 22 | | IAEmailEU | GraphSAGE | No | False | True | NA | NA | 23 | | FBForum | GraphSAGE | No | True | True | NA | 0.826 | 24 | | FBForum | GraphSAGE | No | True | False | NA | NA | 25 | | FBForum | GraphSAGE | No | False | False | NA | NA | 26 | | FBForum | GraphSAGE | No | False | True | NA | NA | 27 | | SOCBitcoinA | GraphSAGE | No | True | True | NA | 0.891 | 28 | | SOCBitcoinA | GraphSAGE | No | True | False | NA | NA | 29 | | SOCBitcoinA | GraphSAGE | No | False | False | NA | NA | 30 | | SOCBitcoinA | GraphSAGE | No | False | True | NA | NA | 31 | | SOCWikiElec | GraphSAGE | No | True | True | NA | 0.857 | 32 | | SOCWikiElec 
| GraphSAGE | No | True | False | NA | NA | 33 | | SOCWikiElec | GraphSAGE | No | False | False | NA | NA | 34 | | SOCWikiElec | GraphSAGE | No | False | True | NA | NA | -------------------------------------------------------------------------------- /src/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "stats_per_batch" : 3, 3 | "dataset" : "CollegeMsg", 4 | "dataset_path" : "/Users/raunak/Documents/Datasets/temporal-networks-snap/CollegeMsg.txt", 5 | "neg_examples_path" : "/Users/raunak/Documents/Datasets/temporal-networks-snap/CollegeMsg_neg_examples_train.txt", 6 | "mode" : "train", 7 | "task" : "link_prediction", 8 | "agg_class" : "MaxPoolAggregator", 9 | "cuda" : "True", 10 | "hidden_dims" : [64], 11 | "num_samples" : -1, 12 | "batch_size" : 16, 13 | "epochs" : 2, 14 | "lr" : 5e-2, 15 | "weight_decay" : 0e-3 16 | } 17 | -------------------------------------------------------------------------------- /src/config_gat.json: -------------------------------------------------------------------------------- 1 | { 2 | "task" : "link_prediction", 3 | 4 | "dataset" : "IAContactsHypertext", 5 | "dataset_path" : "/Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-contacts_hypertext2009/ia-contacts_hypertext2009.edges", 6 | "mode" : "train", 7 | "generate_neg_examples" : "False", 8 | 9 | "duplicate_examples" : "True", 10 | "repeat_examples" : "True", 11 | 12 | "self_loop" : "True", 13 | "normalize_adj" : "False", 14 | 15 | "cuda" : "True", 16 | "model" : "GAT", 17 | "num_heads" : [1, 1], 18 | "hidden_dims" : [64], 19 | "dropout" : 0, 20 | 21 | "epochs" : 3, 22 | "batch_size" : 32, 23 | "lr" : 5e-4, 24 | "weight_decay" : 5e-3, 25 | "stats_per_batch" : 3, 26 | "visdom" : "True", 27 | 28 | "load" : "False", 29 | "save" : "False" 30 | } 31 | -------------------------------------------------------------------------------- /src/config_graphsage.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "task" : "link_prediction", 3 | 4 | "dataset" : "IAEnronEmployees", 5 | "dataset_path" : "/Users/raunak/Documents/Datasets/temporal-networks-network-repository/ia-enron-employees/ia-enron-employees.edges", 6 | "mode" : "train", 7 | "generate_neg_examples" : "True", 8 | 9 | "duplicate_examples" : "False", 10 | "repeat_examples" : "True", 11 | 12 | "self_loop" : "True", 13 | "normalize_adj" : "False", 14 | 15 | "cuda" : "True", 16 | "model" : "GraphSAGE", 17 | "agg_class" : "MaxPoolAggregator", 18 | "hidden_dims" : [64], 19 | "dropout" : 0, 20 | "num_samples" : -1, 21 | 22 | "epochs" : 9, 23 | "batch_size" : 32, 24 | "lr" : 5e-4, 25 | "weight_decay" : 5e-4, 26 | "stats_per_batch" : 3, 27 | "visdom" : "True", 28 | 29 | "load" : "False", 30 | "save" : "False" 31 | } 32 | -------------------------------------------------------------------------------- /src/datasets/link_prediction.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | import os 3 | 4 | import numpy as np 5 | import scipy.sparse as sp 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | class TemporalNetworkDataset(Dataset): 10 | 11 | def __init__(self, path, generate_neg_examples=False, mode='train', 12 | duplicate_examples=False, repeat_examples=False, 13 | num_layers=2, self_loop=False, normalize_adj=False, 14 | data_split=[0.37, 0.25, 0.13, 0.25]): 15 | """ 16 | Parameters 17 | ---------- 18 | path : str 19 | Path to the dataset file. For example, CollegeMsg.txt, etc. 20 | generate_neg_examples : Boolean 21 | Whether to generate negative examples or read them from file. If True, negative examples are generated and saved in the same directory as the dataset file as ${dataset file name}_neg_examples_${mode}_${duplicate_examples}_${repeat_examples}.txt. If False, examples are read from a file that should be named and located as described.
Default: False. 22 | mode : str 23 | One of train, val or test. Default: train. 24 | duplicate_examples : Boolean 25 | Whether to keep multiple instances of an edge in the list of positive examples. Default: False. 26 | repeat_examples : Boolean 27 | Whether to keep a positive example that has already been used for graph construction or training. Default: False. 28 | num_layers : int 29 | Number of layers in the computation graph. Default: 2. 30 | self_loop : Boolean 31 | Whether to add self loops. Default: False. 32 | normalize_adj : Boolean 33 | Whether to use symmetric normalization on the adjacency matrix. Default: False. 34 | data_split : list 35 | Fraction of edges to use for graph construction / train / val / test. Default: [0.37, 0.25, 0.13, 0.25]. 36 | """ 37 | super().__init__() 38 | 39 | self.path = path 40 | self.generate_neg_examples = generate_neg_examples 41 | self.mode = mode 42 | self.duplicate_examples = duplicate_examples 43 | self.repeat_examples = repeat_examples 44 | self.num_layers = num_layers 45 | self.self_loop = self_loop 46 | self.normalize_adj = normalize_adj 47 | self.data_split = data_split 48 | 49 | print('--------------------------------') 50 | print('Reading dataset from {}'.format(path)) 51 | edges_all = self._read_from_file(path) 52 | edges_all = edges_all[edges_all[:, 2].argsort()] 53 | print('Finished reading data.') 54 | 55 | print('Setting up graph.') 56 | vertex_id = {j : i for (i, j) in enumerate(np.unique(edges_all[:, :2]))} 57 | idxs = [floor(v*edges_all.shape[0]) for v in np.cumsum(data_split)] 58 | if mode == 'train': 59 | idx1, idx2 = idxs[0], idxs[1] 60 | elif mode == 'val': 61 | idx1, idx2 = idxs[1], idxs[2] 62 | elif mode == 'test': 63 | idx1, idx2 = idxs[2], idxs[3] 64 | edges_t, pos_examples = edges_all[:idx1, :], edges_all[idx1:idx2, :] 65 | 66 | edges_t[:, :2] = np.array([vertex_id[u] for u in edges_t[:, :2].flatten()]).reshape(edges_t[:, :2].shape) 67 | edges_s = np.unique(edges_t[:, :2], axis=0) 68 |
69 | self.n = len(vertex_id) 70 | self.m_s, self.m_t = edges_s.shape[0], edges_t.shape[0] 71 | 72 | adj = sp.coo_matrix((np.ones(self.m_s), (edges_s[:, 0], edges_s[:, 1])), 73 | shape=(self.n,self.n), 74 | dtype=np.float32) 75 | # Symmetric. 76 | adj += adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 77 | if self_loop: 78 | adj += sp.eye(self.n) 79 | if normalize_adj: 80 | degrees = np.power(np.array(np.sum(adj, axis=1)), -0.5).flatten() 81 | degrees = sp.diags(degrees) 82 | adj = degrees.dot(adj.dot(degrees)) 83 | 84 | self.adj = adj.tolil() 85 | self.edges_s = edges_s 86 | self.nbrs_s = self.adj.rows 87 | self.features = torch.from_numpy(np.eye(self.n)).float() 88 | 89 | nbrs_t = [[] for _ in range(self.n)] 90 | for (u, v, t) in edges_t: 91 | nbrs_t[u].append((v, t)) 92 | # Symmetric. 93 | nbrs_t[v].append((u, t)) 94 | nbrs_t = np.array(nbrs_t) 95 | 96 | self.nbrs_t = nbrs_t 97 | 98 | # TODO : Modify to time since last interaction rather than just time? 99 | last_time = np.max(edges_t[:, -1]) + 1 100 | timestamps = dict() 101 | for (u, v, t) in edges_t: 102 | if u not in timestamps.keys(): 103 | timestamps[u] = dict() 104 | if v not in timestamps[u].keys(): 105 | timestamps[u][v] = [] 106 | timestamps[u][v].append(last_time - t) 107 | # Symmetric. 108 | if v not in timestamps.keys(): 109 | timestamps[v] = dict() 110 | if u not in timestamps[v].keys(): 111 | timestamps[v][u] = [] 112 | timestamps[v][u].append(last_time - t) 113 | for u in range(self.n): 114 | if u not in timestamps.keys(): 115 | timestamps[u] = dict() 116 | timestamps[u][u] = [1] 117 | self.timestamps = timestamps 118 | print('Finished setting up graph.') 119 | 120 | print('Setting up examples.') 121 | if repeat_examples: 122 | pos_seen = set() 123 | else: 124 | pos_seen = set(tuple([u,v]) for (u,v) in edges_s) 125 | # Symmetric. 
126 | pos_seen |= set(tuple([v,u]) for (u,v) in edges_s) 127 | pos_examples = pos_examples[:, :2] 128 | pos_examples = np.array([row for row in pos_examples \ 129 | if (row[0] < self.n) and 130 | (row[1] < self.n) and 131 | ((row[0], row[1]) not in pos_seen) and 132 | ((row[1], row[0]) not in pos_seen)]) 133 | if not duplicate_examples: 134 | pos_examples = np.unique(pos_examples, axis=0) 135 | 136 | neg_path = os.path.splitext(path)[0] + '_neg_examples_{}_{}_{}.txt'.format(mode, duplicate_examples, repeat_examples) 137 | if not generate_neg_examples: 138 | neg_examples = np.loadtxt(neg_path) 139 | else: 140 | num_neg_examples = pos_examples.shape[0] 141 | neg_examples = [] 142 | cur = 0 143 | n, _choice = self.n, np.random.choice 144 | neg_seen = set(tuple(e[:2]) for e in edges_all) 145 | while cur < num_neg_examples: 146 | u, v = _choice(n, 2, replace=False) 147 | if (u, v) in neg_seen: 148 | continue 149 | cur += 1 150 | neg_examples.append([u, v]) 151 | np.savetxt(neg_path, neg_examples) 152 | neg_examples = np.array(neg_examples, dtype=np.int64) 153 | 154 | # pos_examples, neg_examples = pos_examples[:1024], neg_examples[:1024] 155 | 156 | x = np.vstack((pos_examples, neg_examples)) 157 | y = np.concatenate((np.ones(pos_examples.shape[0]), 158 | np.zeros(neg_examples.shape[0]))) 159 | perm = np.random.permutation(x.shape[0]) 160 | x, y = x[perm, :], y[perm] 161 | x, y = torch.from_numpy(x).long(), torch.from_numpy(y).long() 162 | self.x, self.y = x, y 163 | print('Finished setting up examples.') 164 | 165 | print('Dataset properties:') 166 | print('Mode: {}'.format(self.mode)) 167 | print('Number of vertices: {}'.format(self.n)) 168 | print('Number of static edges: {}'.format(self.m_s)) 169 | print('Number of temporal edges: {}'.format(self.m_t)) 170 | print('Number of examples/datapoints: {}'.format(self.x.shape[0])) 171 | print('--------------------------------') 172 | 173 | 174 | def _read_from_file(self, path): 175 | raise NotImplementedError 176 | 177 | 
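The negative-example loop above rejection-samples vertex pairs that never appear as edges. A standalone sketch of the same idea (toy `neg_seen` set; the class uses `np.random.choice` directly, while the seeded generator here only keeps the sketch reproducible):

```python
import numpy as np

# Toy setting: 6 vertices, and the (u, v) pairs already observed as edges.
n = 6
neg_seen = {(0, 1), (1, 2), (2, 3), (3, 4)}
num_neg_examples = 4

rng = np.random.default_rng(0)  # seeded only to keep the sketch reproducible
neg_examples = []
while len(neg_examples) < num_neg_examples:
    u, v = rng.choice(n, 2, replace=False)  # two distinct vertices
    if (u, v) in neg_seen:                  # reject observed edges
        continue
    neg_examples.append([int(u), int(v)])
neg_examples = np.array(neg_examples, dtype=np.int64)
```

As in the class, only the ordered pair (u, v) is checked against the seen set, and a sampled pair may recur, so negatives are not guaranteed unique.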
def __len__(self): 178 | return len(self.x) 179 | 180 | def __getitem__(self, idx): 181 | return self.x[idx], self.y[idx] 182 | 183 | def _form_computation_graph(self, idx): 184 | """ 185 | Parameters 186 | ---------- 187 | idx : int or list 188 | Indices of the node for which the forward pass needs to be computed. 189 | 190 | Returns 191 | ------- 192 | node_layers : list of numpy array 193 | node_layers[i] is an array of the nodes in the ith layer of the 194 | computation graph. 195 | mappings : list of dictionary 196 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 197 | in node_layers[i] to its position in node_layers[i]. For example, 198 | if node_layers[i] = [2,5], then mappings[i][2] = 0 and 199 | mappings[i][5] = 1. 200 | """ 201 | _list, _set = list, set 202 | if type(idx) is int: 203 | node_layers = [np.array([idx], dtype=np.int64)] 204 | elif type(idx) is list: 205 | node_layers = [np.array(idx, dtype=np.int64)] 206 | 207 | for _ in range(self.num_layers): 208 | prev = node_layers[-1] 209 | arr = [node for node in prev] 210 | arr.extend([e[0] for node in arr for e in self.nbrs_t[node]]) 211 | arr = np.array(_list(_set(arr)), dtype=np.int64) 212 | node_layers.append(arr) 213 | node_layers.reverse() 214 | 215 | mappings = [{j : i for (i, j) in enumerate(arr)} for arr in node_layers] 216 | 217 | return node_layers, mappings 218 | 219 | def collate_wrapper(self, batch): 220 | """ 221 | Parameters 222 | ---------- 223 | batch : list 224 | A list of examples from this dataset. An example is (edge, label). 225 | 226 | Returns 227 | ------- 228 | edges : numpy array 229 | The edges in the batch. 230 | features : torch.FloatTensor 231 | An (n' x input_dim) tensor of input node features. 232 | node_layers : list of numpy array 233 | node_layers[i] is an array of the nodes in the ith layer of the 234 | computation graph. 
235 | mappings : list of dictionary 236 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 237 | in node_layers[i] to its position in node_layers[i]. For example, 238 | if node_layers[i] = [2,5], then mappings[i][2] = 0 and 239 | mappings[i][5] = 1. 240 | rows : numpy array 241 | Each row is the list of neighbors of nodes in node_layers[0]. 242 | labels : torch.LongTensor 243 | Labels (1 or 0) for the edges in the batch. 244 | """ 245 | idx = list(set([v.item() for sample in batch for v in sample[0][:2]])) 246 | 247 | node_layers, mappings = self._form_computation_graph(idx) 248 | 249 | rows = self.nbrs_s[node_layers[0]] 250 | features = self.features[node_layers[0], :] 251 | labels = torch.FloatTensor([sample[1] for sample in batch]) 252 | edges = np.array([sample[0].numpy() for sample in batch]) 253 | edges = np.array([mappings[-1][v] for v in edges.flatten()]).reshape(edges.shape) 254 | 255 | # TODO: Pin memory. Change type of node_layers, mappings and rows to 256 | # tensor? 
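`_form_computation_graph` above grows a layered node set by repeatedly adding temporal neighbours, then builds per-layer position mappings. A toy walk-through (the neighbour lists are made up; `sorted` is used here purely for a deterministic order, where the class keeps an unordered `set`):

```python
import numpy as np

# Toy temporal neighbourhoods: nbrs_t[u] is a list of (neighbour, time) pairs.
nbrs_t = [
    [(1, 10)],           # node 0
    [(0, 10), (2, 11)],  # node 1
    [(1, 11)],           # node 2
    [],                  # node 3 (no temporal edges)
]
num_layers = 2

# Layer 0 of the computation graph is the batch node itself.
node_layers = [np.array([0], dtype=np.int64)]
for _ in range(num_layers):
    prev = node_layers[-1]
    arr = [int(node) for node in prev]
    arr.extend([e[0] for node in arr for e in nbrs_t[node]])
    node_layers.append(np.array(sorted(set(arr)), dtype=np.int64))
node_layers.reverse()

# mappings[i] sends a node id to its position within node_layers[i].
mappings = [{int(v): i for (i, v) in enumerate(arr)} for arr in node_layers]
```

After the reverse, `node_layers[0]` is the widest layer (the 2-hop neighbourhood) and `node_layers[-1]` holds only the batch nodes, matching the order the model consumes.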
257 | 258 | return edges, features, node_layers, mappings, rows, labels 259 | 260 | def get_dims(self): 261 | return self.features.shape[0], 1 262 | 263 | class CollegeMsg(TemporalNetworkDataset): 264 | 265 | def _read_from_file(self, path): 266 | return np.loadtxt(path, dtype=np.int64) 267 | 268 | class BitcoinAlpha(TemporalNetworkDataset): 269 | 270 | def _read_from_file(self, path): 271 | edges_all = np.loadtxt(path, delimiter=',', dtype=np.int64) 272 | return np.concatenate((edges_all[:, :2], edges_all[:, 3:]), axis=1) 273 | 274 | class FBForum(TemporalNetworkDataset): 275 | 276 | def _read_from_file(self, path): 277 | with open(path, 'r') as f: 278 | lines = f.readlines() 279 | lines = [[int(float(x)) for x in line.strip().split(',')] for line in lines] 280 | return np.array(lines) 281 | # return np.loadtxt(path, delimiter=',', dtype=np.int64) 282 | 283 | class IAContact(TemporalNetworkDataset): 284 | 285 | def _read_from_file(self, path): 286 | with open(path, 'r') as f: 287 | lines = f.readlines() 288 | lines = [line.split('\t') for line in lines] 289 | lines = [[*line[0].split(), line[1].split()[1]] for line in lines] 290 | return np.array(lines, dtype=np.int64) 291 | 292 | class IAContactsHypertext(TemporalNetworkDataset): 293 | 294 | def _read_from_file(self, path): 295 | return np.loadtxt(path, delimiter=',', dtype=np.int64) 296 | 297 | class IAEnronEmployees(TemporalNetworkDataset): 298 | 299 | def _read_from_file(self, path): 300 | with open(path, 'r') as f: 301 | lines = f.readlines() 302 | lines = [line.split() for line in lines] 303 | lines = [[line[0], line[1], line[3]] for line in lines] 304 | return np.array(lines, dtype=np.int64) 305 | 306 | class IARadoslawEmail(TemporalNetworkDataset): 307 | 308 | def _read_from_file(self, path): 309 | with open(path, 'r') as f: 310 | lines = f.readlines() 311 | lines = [line.split() for line in lines[2:]] 312 | lines = [[line[0], line[1], line[3]] for line in lines] 313 | return np.array(lines, 
dtype=np.int64) 314 | 315 | class WikiElec(TemporalNetworkDataset): 316 | 317 | def _read_from_file(self, path): 318 | with open(path, 'r') as f: 319 | lines = f.readlines() 320 | lines = [line.split() for line in lines[2:]] 321 | lines = [[line[0], line[1], line[3]] for line in lines] 322 | return np.array(lines, dtype=np.int64) 323 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class Aggregator(nn.Module): 9 | 10 | def __init__(self, input_dim=None, output_dim=None, device='cpu'): 11 | """ 12 | Parameters 13 | ---------- 14 | input_dim : int or None 15 | Dimension of input node features. Used for defining fully 16 | connected layer in pooling aggregators. Default: None. 17 | output_dim : int or None 18 | Dimension of output node features. Used for defining fully 19 | connected layer in pooling aggregators. Currently only works when 20 | input_dim = output_dim. Default: None. 21 | """ 22 | # super(Aggregator, self).__init__() 23 | super().__init__() 24 | 25 | self.input_dim = input_dim 26 | self.output_dim = output_dim 27 | self.device = device 28 | 29 | def forward(self, features, nodes, mapping, rows, num_samples=25): 30 | """ 31 | Parameters 32 | ---------- 33 | features : torch.Tensor 34 | An (n' x input_dim) tensor of input node features. 35 | nodes : numpy array 36 | nodes is a numpy array of nodes in the current layer of the computation graph. 37 | mapping : dict 38 | mapping is a dictionary mapping node v (labelled 0 to |V|-1) to 39 | its position in the layer of nodes in the computation graph 40 | before nodes. For example, if the layer before nodes is [2,5], 41 | then mapping[2] = 0 and mapping[5] = 1.
42 | rows : numpy array 43 | rows[i] is an array of neighbors of node i which is present in nodes. 44 | num_samples : int 45 | Number of neighbors to sample while aggregating. Default: 25. 46 | 47 | Returns 48 | ------- 49 | out : torch.Tensor 50 | A (len(nodes) x output_dim) tensor of output node features. 51 | Currently only works when output_dim = input_dim. 52 | """ 53 | _choice, _len, _min = np.random.choice, len, min 54 | mapped_rows = [np.array([mapping[v] for v in row], dtype=np.int64) for row in rows] 55 | if num_samples == -1: 56 | sampled_rows = mapped_rows 57 | else: 58 | sampled_rows = [_choice(row, _min(_len(row), num_samples), _len(row) < num_samples) for row in mapped_rows] 59 | 60 | n = _len(nodes) 61 | if self.__class__.__name__ == 'LSTMAggregator': 62 | out = torch.zeros(n, 2*self.output_dim).to(self.device) 63 | else: 64 | out = torch.zeros(n, self.output_dim).to(self.device) 65 | for i in range(n): 66 | if _len(sampled_rows[i]) != 0: 67 | out[i, :] = self._aggregate(features[sampled_rows[i], :]) 68 | 69 | return out 70 | 71 | def _aggregate(self, features): 72 | """ 73 | Parameters 74 | ---------- 75 | 76 | Returns 77 | ------- 78 | """ 79 | raise NotImplementedError 80 | 81 | class MeanAggregator(Aggregator): 82 | 83 | def _aggregate(self, features): 84 | """ 85 | Parameters 86 | ---------- 87 | features : torch.Tensor 88 | Input features. 89 | 90 | Returns 91 | ------- 92 | Aggregated feature. 93 | """ 94 | return torch.mean(features, dim=0) 95 | 96 | class PoolAggregator(Aggregator): 97 | 98 | def __init__(self, input_dim, output_dim, device='cpu'): 99 | """ 100 | Parameters 101 | ---------- 102 | input_dim : int 103 | Dimension of input node features. Used for defining fully connected layer. 104 | output_dim : int 105 | Dimension of output node features. Used for defining fully connected layer. Currently only works when output_dim = input_dim.
""" 107 | # super(PoolAggregator, self).__init__(input_dim, output_dim, device) 108 | super().__init__(input_dim, output_dim, device) 109 | 110 | self.fc1 = nn.Linear(input_dim, output_dim) 111 | self.relu = nn.ReLU() 112 | 113 | def _aggregate(self, features): 114 | """ 115 | Parameters 116 | ---------- 117 | features : torch.Tensor 118 | Input features. 119 | 120 | Returns 121 | ------- 122 | Aggregated feature. 123 | """ 124 | out = self.relu(self.fc1(features)) 125 | return self._pool_fn(out) 126 | 127 | def _pool_fn(self, features): 128 | """ 129 | Parameters 130 | ---------- 131 | 132 | Returns 133 | ------- 134 | """ 135 | raise NotImplementedError 136 | 137 | class MaxPoolAggregator(PoolAggregator): 138 | 139 | def _pool_fn(self, features): 140 | """ 141 | Parameters 142 | ---------- 143 | features : torch.Tensor 144 | Input features. 145 | 146 | Returns 147 | ------- 148 | Aggregated feature. 149 | """ 150 | return torch.max(features, dim=0)[0] 151 | 152 | class MeanPoolAggregator(PoolAggregator): 153 | 154 | def _pool_fn(self, features): 155 | """ 156 | Parameters 157 | ---------- 158 | features : torch.Tensor 159 | Input features. 160 | 161 | Returns 162 | ------- 163 | Aggregated feature. 164 | """ 165 | return torch.mean(features, dim=0)  # torch.mean returns a tensor, not a (values, indices) pair 166 | 167 | class LSTMAggregator(Aggregator): 168 | 169 | def __init__(self, input_dim, output_dim, device='cpu'): 170 | """ 171 | Parameters 172 | ---------- 173 | input_dim : int 174 | Dimension of input node features. Used for defining LSTM layer. 175 | output_dim : int 176 | Dimension of output node features. Used for defining LSTM layer. Currently only works when output_dim = input_dim.
177 | 178 | """ 179 | # super(LSTMAggregator, self).__init__(input_dim, output_dim, device) 180 | super().__init__(input_dim, output_dim, device) 181 | 182 | self.lstm = nn.LSTM(input_dim, output_dim, bidirectional=True, batch_first=True) 183 | 184 | def _aggregate(self, features): 185 | """ 186 | Parameters 187 | ---------- 188 | features : torch.Tensor 189 | Input features. 190 | 191 | Returns 192 | ------- 193 | Aggregated feature. 194 | """ 195 | perm = np.random.permutation(np.arange(features.shape[0])) 196 | features = features[perm, :] 197 | features = features.unsqueeze(0) 198 | 199 | out, _ = self.lstm(features) 200 | out = out.squeeze(0) 201 | out = torch.sum(out, dim=0) 202 | 203 | return out 204 | 205 | class GraphAttention(nn.Module): 206 | 207 | def __init__(self, input_dim, output_dim, num_heads, dropout=0.5): 208 | """ 209 | Parameters 210 | ---------- 211 | input_dim : int 212 | Dimension of input node features. 213 | output_dim : int 214 | Dimension of output features after each attention head. 215 | num_heads : int 216 | Number of attention heads. 217 | dropout : float 218 | Dropout rate. Default: 0.5. 219 | """ 220 | super().__init__() 221 | 222 | self.input_dim = input_dim 223 | self.output_dim = output_dim 224 | self.num_heads = num_heads 225 | 226 | self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)]) 227 | self.a = nn.ModuleList([nn.Linear(2*output_dim, 1) for _ in range(num_heads)]) 228 | 229 | self.dropout = nn.Dropout(dropout) 230 | self.softmax = nn.Softmax(dim=0) 231 | self.leakyrelu = nn.LeakyReLU() 232 | 233 | def forward(self, features, nodes, mapping, rows): 234 | """ 235 | Parameters 236 | ---------- 237 | features : torch.Tensor 238 | An (n' x input_dim) tensor of input node features. 239 | nodes : numpy array 240 | nodes is a numpy array of nodes in the current layer of the computation graph. 
241 | mapping : dict 242 | mapping is a dictionary mapping node v (labelled 0 to |V|-1) to 243 | its position in the layer of nodes in the computation graph 244 | before nodes. For example, if the layer before nodes is [2,5], 245 | then mapping[2] = 0 and mapping[5] = 1. 246 | rows : numpy array 247 | rows[i] is an array of neighbors of node i which is present in nodes. 248 | 249 | Returns 250 | ------- 251 | out : list of torch.Tensor 252 | A list of (len(nodes) x input_dim) tensor of output node features. 253 | """ 254 | 255 | nprime = features.shape[0] 256 | rows = [np.array([mapping[v] for v in row], dtype=np.int64) for row in rows] 257 | sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows]))) 258 | mapped_nodes = [mapping[v] for v in nodes] 259 | indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t() 260 | 261 | out = [] 262 | for k in range(self.num_heads): 263 | h = self.fcs[k](features) 264 | 265 | nbr_h = torch.cat(tuple([h[row] for row in rows if len(row) > 0]), dim=0) 266 | self_h = torch.cat(tuple([h[mapping[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows) if len(row) > 0]), dim=0) 267 | attn_h = torch.cat((self_h, nbr_h), dim=1) 268 | 269 | e = self.leakyrelu(self.a[k](attn_h)) 270 | 271 | alpha = [self.softmax(e[lo : hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])] 272 | alpha = torch.cat(tuple(alpha), dim=0) 273 | alpha = alpha.squeeze(1) 274 | alpha = self.dropout(alpha) 275 | 276 | adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime])) 277 | out.append(torch.sparse.mm(adj, h)[mapped_nodes]) 278 | 279 | return out -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import json 3 | import os 4 | import sys 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from sklearn.metrics import roc_auc_score, 
roc_curve, classification_report 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | import visdom 14 | 15 | from datasets import link_prediction 16 | from layers import MeanAggregator, LSTMAggregator, MaxPoolAggregator, MeanPoolAggregator 17 | import models 18 | import utils 19 | 20 | def main(): 21 | 22 | # Set up arguments for datasets, models and training. 23 | config = utils.parse_args() 24 | config['num_layers'] = len(config['hidden_dims']) + 1 25 | 26 | if config['cuda'] and torch.cuda.is_available(): 27 | device = 'cuda:0' 28 | else: 29 | device = 'cpu' 30 | config['device'] = device 31 | 32 | # Get the dataset, dataloader and model. 33 | dataset_args = (config['task'], config['dataset'], config['dataset_path'], 34 | config['generate_neg_examples'], 'train', 35 | config['duplicate_examples'], config['repeat_examples'], 36 | config['num_layers'], config['self_loop'], 37 | config['normalize_adj']) 38 | dataset = utils.get_dataset(dataset_args) 39 | 40 | loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], 41 | shuffle=True, collate_fn=dataset.collate_wrapper) 42 | input_dim, output_dim = dataset.get_dims() 43 | 44 | if config['model'] == 'GraphSAGE': 45 | agg_class = utils.get_agg_class(config['agg_class']) 46 | model = models.GraphSAGE(input_dim, config['hidden_dims'], 47 | output_dim, config['dropout'], 48 | agg_class, config['num_samples'], 49 | config['device']) 50 | else: 51 | model = models.GAT(input_dim, config['hidden_dims'], 52 | output_dim, config['num_heads'], 53 | config['dropout'], config['device']) 54 | model.apply(models.init_weights) 55 | model.to(config['device']) 56 | print(model) 57 | 58 | # Compute ROC-AUC score for the untrained model. 
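Throughout `main.py`, the score of a candidate edge is the dot product of its endpoint embeddings, read out of the Gram matrix `torch.mm(out, out.t())`. A minimal sketch with explicit per-edge indexing (toy embeddings, not model output):

```python
import torch

# Toy 2-d node embeddings (illustrative), one row per node.
out = torch.tensor([[1., 0.],
                    [0., 1.],
                    [1., 1.]])
edges = torch.tensor([[0, 2],   # candidate edge (0, 2)
                      [1, 2]])  # candidate edge (1, 2)

all_pairs = torch.mm(out, out.t())            # all_pairs[i, j] = <out[i], out[j]>
scores = all_pairs[edges[:, 0], edges[:, 1]]  # one dot-product score per edge
# scores == tensor([1., 1.])
```

Higher dot products mean the two endpoints sit closer in embedding space, which the loss pushes to correlate with the edge actually appearing.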
59 | if not config['load']: 60 | print('--------------------------------') 61 | print('Computing ROC-AUC score for the training dataset before training.') 62 | y_true, y_scores = [], [] 63 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 64 | with torch.no_grad(): 65 | for (idx, batch) in enumerate(loader): 66 | edges, features, node_layers, mappings, rows, labels = batch 67 | features, labels = features.to(device), labels.to(device) 68 | out = model(features, node_layers, mappings, rows) 69 | all_pairs = torch.mm(out, out.t()) 70 | scores = all_pairs[edges.T] 71 | y_true.extend(labels.detach().cpu().numpy()) 72 | y_scores.extend(scores.detach().cpu().numpy()) 73 | print(' Batch {} / {}'.format(idx+1, num_batches)) 74 | y_true = np.array(y_true).flatten() 75 | y_scores = np.array(y_scores).flatten() 76 | area = roc_auc_score(y_true, y_scores) 77 | print('ROC-AUC score: {:.4f}'.format(area)) 78 | print('--------------------------------') 79 | 80 | # Train. 81 | if not config['load']: 82 | use_visdom = config['visdom'] 83 | if use_visdom: 84 | vis = visdom.Visdom() 85 | loss_window = None 86 | criterion = utils.get_criterion(config['task']) 87 | optimizer = optim.Adam(model.parameters(), lr=config['lr'], 88 | weight_decay=config['weight_decay']) 89 | epochs = config['epochs'] 90 | stats_per_batch = config['stats_per_batch'] 91 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 92 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.8) 93 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[300, 600], gamma=0.5) 94 | model.train() 95 | print('--------------------------------') 96 | print('Training.') 97 | for epoch in range(epochs): 98 | print('Epoch {} / {}'.format(epoch+1, epochs)) 99 | running_loss = 0.0 100 | for (idx, batch) in enumerate(loader): 101 | edges, features, node_layers, mappings, rows, labels = batch 102 | features, labels = features.to(device), labels.to(device) 103 | 
optimizer.zero_grad() 104 | out = model(features, node_layers, mappings, rows) 105 | all_pairs = torch.mm(out, out.t()) 106 | scores = all_pairs[edges.T] 107 | loss = criterion(scores, labels.float()) 108 | loss.backward() 109 | optimizer.step() 110 | with torch.no_grad(): 111 | running_loss += loss.item() 112 | if (idx + 1) % stats_per_batch == 0: 113 | running_loss /= stats_per_batch 114 | print(' Batch {} / {}: loss {:.4f}'.format( 115 | idx+1, num_batches, running_loss)) 116 | if (torch.sum(labels.long() == 0).item() > 0) and (torch.sum(labels.long() == 1).item() > 0): 117 | area = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy()) 118 | print(' ROC-AUC score: {:.4f}'.format(area)) 119 | running_loss = 0.0 120 | num_correct, num_examples = 0, 0 121 | if use_visdom: 122 | if loss_window is None: 123 | loss_window = vis.line( 124 | Y=[loss.item()], 125 | X=[epoch*num_batches+idx], 126 | opts=dict(xlabel='batch', ylabel='Loss', title='Training Loss', legend=['Loss'])) 127 | else: 128 | vis.line( 129 | [loss.item()], 130 | [epoch*num_batches+idx], 131 | win=loss_window, 132 | update='append') 133 | scheduler.step() 134 | if use_visdom: 135 | vis.close(win=loss_window) 136 | print('Finished training.') 137 | print('--------------------------------') 138 | 139 | if not config['load']: 140 | if config['save']: 141 | print('--------------------------------') 142 | directory = os.path.join(os.path.dirname(os.getcwd()), 143 | 'trained_models') 144 | if not os.path.exists(directory): 145 | os.makedirs(directory) 146 | fname = utils.get_fname(config) 147 | path = os.path.join(directory, fname) 148 | print('Saving model at {}'.format(path)) 149 | torch.save(model.state_dict(), path) 150 | print('Finished saving model.') 151 | print('--------------------------------') 152 | 153 | # Compute ROC-AUC score after training. 
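The evaluation passes reduce to `sklearn.metrics.roc_auc_score` on the flattened labels and raw scores; on a toy batch:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy labels and raw scores for five candidate edges (illustrative).
y_true = np.array([1, 1, 0, 0, 1])
y_scores = np.array([0.9, 0.7, 0.4, 0.2, 0.3])

# AUC is the probability that a random positive outscores a random
# negative: 5 of the 6 positive/negative pairs are ranked correctly here.
area = roc_auc_score(y_true, y_scores)  # 5/6
```

Because AUC only depends on the ranking, the raw dot-product scores can be used directly, without squashing them through a sigmoid first.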
154 | if not config['load']: 155 | print('--------------------------------') 156 | print('Computing ROC-AUC score for the training dataset after training.') 157 | y_true, y_scores = [], [] 158 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 159 | with torch.no_grad(): 160 | for (idx, batch) in enumerate(loader): 161 | edges, features, node_layers, mappings, rows, labels = batch 162 | features, labels = features.to(device), labels.to(device) 163 | out = model(features, node_layers, mappings, rows) 164 | all_pairs = torch.mm(out, out.t()) 165 | scores = all_pairs[edges.T] 166 | y_true.extend(labels.detach().cpu().numpy()) 167 | y_scores.extend(scores.detach().cpu().numpy()) 168 | print(' Batch {} / {}'.format(idx+1, num_batches)) 169 | y_true = np.array(y_true).flatten() 170 | y_scores = np.array(y_scores).flatten() 171 | area = roc_auc_score(y_true, y_scores) 172 | print('ROC-AUC score: {:.4f}'.format(area)) 173 | print('--------------------------------') 174 | 175 | # Plot the true positive rate and true negative rate vs threshold. 176 | if not config['load']: 177 | fpr, tpr, thresholds = roc_curve(y_true, y_scores)  # roc_curve returns (fpr, tpr, thresholds) 178 | tnr = 1 - fpr 179 | plt.plot(thresholds, tpr, label='tpr') 180 | plt.plot(thresholds, tnr, label='tnr') 181 | plt.xlabel('Threshold') 182 | plt.title('TPR / TNR vs Threshold') 183 | plt.legend() 184 | plt.show() 185 | 186 | # Choose an appropriate threshold and generate classification report on the train set.
187 | idx1 = np.where(tpr <= tnr)[0] 188 | idx2 = np.where(tpr >= tnr)[0] 189 | t = thresholds[idx1[-1]] 190 | total_correct, total_examples = 0, 0 191 | y_true, y_pred = [], [] 192 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 193 | with torch.no_grad(): 194 | for (idx, batch) in enumerate(loader): 195 | edges, features, node_layers, mappings, rows, labels = batch 196 | features, labels = features.to(device), labels.to(device) 197 | out = model(features, node_layers, mappings, rows) 198 | all_pairs = torch.mm(out, out.t()) 199 | scores = all_pairs[edges.T] 200 | predictions = (scores >= t).long() 201 | y_true.extend(labels.detach().cpu().numpy()) 202 | y_pred.extend(predictions.detach().cpu().numpy()) 203 | total_correct += torch.sum(predictions == labels.long()).item() 204 | total_examples += len(labels) 205 | print(' Batch {} / {}'.format(idx+1, num_batches)) 206 | print('Threshold: {:.4f}, accuracy: {:.4f}'.format(t, total_correct / total_examples)) 207 | y_true = np.array(y_true).flatten() 208 | y_pred = np.array(y_pred).flatten() 209 | report = classification_report(y_true, y_pred) 210 | print('Classification report\n', report) 211 | 212 | # Evaluate on the validation set. 
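The threshold chosen above is the largest one at which TPR has not yet overtaken TNR, i.e. near where the two curves cross. With toy ROC arrays standing in for `roc_curve` output:

```python
import numpy as np

# Toy ROC arrays; in main.py these come from
# sklearn.metrics.roc_curve(y_true, y_scores).
thresholds = np.array([1.9, 0.9, 0.7, 0.4, 0.2])
tpr = np.array([0.0, 0.4, 0.8, 0.9, 1.0])
tnr = np.array([1.0, 1.0, 0.9, 0.5, 0.0])  # tnr = 1 - fpr

# Largest threshold at which TPR is still <= TNR: the curves cross
# near here, roughly balancing the two error rates.
idx = np.where(tpr <= tnr)[0]
t = thresholds[idx[-1]]
# t == 0.7
```

Since `thresholds` is returned in decreasing order, the last index where `tpr <= tnr` holds gives the smallest such threshold still on the TPR-favourable side of the crossing.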
213 | if config['load']: 214 | directory = os.path.join(os.path.dirname(os.getcwd()), 215 | 'trained_models') 216 | fname = utils.get_fname(config) 217 | path = os.path.join(directory, fname) 218 | model.load_state_dict(torch.load(path)) 219 | dataset_args = (config['task'], config['dataset'], config['dataset_path'], 220 | config['generate_neg_examples'], 'val', 221 | config['duplicate_examples'], config['repeat_examples'], 222 | config['num_layers'], config['self_loop'], 223 | config['normalize_adj']) 224 | dataset = utils.get_dataset(dataset_args) 225 | loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], 226 | shuffle=False, collate_fn=dataset.collate_wrapper) 227 | criterion = utils.get_criterion(config['task']) 228 | stats_per_batch = config['stats_per_batch'] 229 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 230 | model.eval() 231 | print('--------------------------------') 232 | print('Computing ROC-AUC score for the validation dataset after training.') 233 | running_loss, total_loss = 0.0, 0.0 234 | num_correct, num_examples = 0, 0 235 | total_correct, total_examples = 0, 0 236 | y_true, y_scores, y_pred = [], [], [] 237 | for (idx, batch) in enumerate(loader): 238 | edges, features, node_layers, mappings, rows, labels = batch 239 | features, labels = features.to(device), labels.to(device) 240 | out = model(features, node_layers, mappings, rows) 241 | all_pairs = torch.mm(out, out.t()) 242 | scores = all_pairs[edges.T] 243 | loss = criterion(scores, labels.float()) 244 | running_loss += loss.item() 245 | total_loss += loss.item() 246 | predictions = (scores >= t).long() 247 | num_correct += torch.sum(predictions == labels.long()).item() 248 | total_correct += torch.sum(predictions == labels.long()).item() 249 | num_examples += len(labels) 250 | total_examples += len(labels) 251 | y_true.extend(labels.detach().cpu().numpy()) 252 | y_scores.extend(scores.detach().cpu().numpy()) 253 | 
y_pred.extend(predictions.detach().cpu().numpy()) 254 | if (idx + 1) % stats_per_batch == 0: 255 | running_loss /= stats_per_batch 256 | accuracy = num_correct / num_examples 257 | print(' Batch {} / {}: loss {:.4f}, accuracy {:.4f}'.format( 258 | idx+1, num_batches, running_loss, accuracy)) 259 | if (torch.sum(labels.long() == 0).item() > 0) and (torch.sum(labels.long() == 1).item() > 0): 260 | area = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy()) 261 | print(' ROC-AUC score: {:.4f}'.format(area)) 262 | running_loss = 0.0 263 | num_correct, num_examples = 0, 0 264 | total_loss /= num_batches 265 | total_accuracy = total_correct / total_examples 266 | print('Loss {:.4f}, accuracy {:.4f}'.format(total_loss, total_accuracy)) 267 | y_true = np.array(y_true).flatten() 268 | y_scores = np.array(y_scores).flatten() 269 | y_pred = np.array(y_pred).flatten() 270 | report = classification_report(y_true, y_pred) 271 | area = roc_auc_score(y_true, y_scores) 272 | print('ROC-AUC score: {:.4f}'.format(area)) 273 | print('Classification report\n', report) 274 | print('Finished validating.') 275 | print('--------------------------------') 276 | 277 | # Evaluate on test set. 
278 | if config['load']: 279 | directory = os.path.join(os.path.dirname(os.getcwd()), 280 | 'trained_models') 281 | fname = utils.get_fname(config) 282 | path = os.path.join(directory, fname) 283 | model.load_state_dict(torch.load(path)) 284 | dataset_args = (config['task'], config['dataset'], config['dataset_path'], 285 | config['generate_neg_examples'], 'test', 286 | config['duplicate_examples'], config['repeat_examples'], 287 | config['num_layers'], config['self_loop'], 288 | config['normalize_adj']) 289 | dataset = utils.get_dataset(dataset_args) 290 | loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], 291 | shuffle=False, collate_fn=dataset.collate_wrapper) 292 | criterion = utils.get_criterion(config['task']) 293 | stats_per_batch = config['stats_per_batch'] 294 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 295 | model.eval() 296 | print('--------------------------------') 297 | print('Computing ROC-AUC score for the test dataset after training.') 298 | running_loss, total_loss = 0.0, 0.0 299 | num_correct, num_examples = 0, 0 300 | total_correct, total_examples = 0, 0 301 | y_true, y_scores, y_pred = [], [], [] 302 | for (idx, batch) in enumerate(loader): 303 | edges, features, node_layers, mappings, rows, labels = batch 304 | features, labels = features.to(device), labels.to(device) 305 | out = model(features, node_layers, mappings, rows) 306 | all_pairs = torch.mm(out, out.t()) 307 | scores = all_pairs[edges.T] 308 | loss = criterion(scores, labels.float()) 309 | running_loss += loss.item() 310 | total_loss += loss.item() 311 | predictions = (scores >= t).long() 312 | num_correct += torch.sum(predictions == labels.long()).item() 313 | total_correct += torch.sum(predictions == labels.long()).item() 314 | num_examples += len(labels) 315 | total_examples += len(labels) 316 | y_true.extend(labels.detach().cpu().numpy()) 317 | y_scores.extend(scores.detach().cpu().numpy()) 318 | 
y_pred.extend(predictions.detach().cpu().numpy()) 319 | if (idx + 1) % stats_per_batch == 0: 320 | running_loss /= stats_per_batch 321 | accuracy = num_correct / num_examples 322 | print(' Batch {} / {}: loss {:.4f}, accuracy {:.4f}'.format( 323 | idx+1, num_batches, running_loss, accuracy)) 324 | if (torch.sum(labels.long() == 0).item() > 0) and (torch.sum(labels.long() == 1).item() > 0): 325 | area = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy()) 326 | print(' ROC-AUC score: {:.4f}'.format(area)) 327 | running_loss = 0.0 328 | num_correct, num_examples = 0, 0 329 | total_loss /= num_batches 330 | total_accuracy = total_correct / total_examples 331 | print('Loss {:.4f}, accuracy {:.4f}'.format(total_loss, total_accuracy)) 332 | y_true = np.array(y_true).flatten() 333 | y_scores = np.array(y_scores).flatten() 334 | y_pred = np.array(y_pred).flatten() 335 | report = classification_report(y_true, y_pred) 336 | area = roc_auc_score(y_true, y_scores) 337 | print('ROC-AUC score: {:.4f}'.format(area)) 338 | print('Classification report\n', report) 339 | print('Finished testing.') 340 | print('--------------------------------') 341 | 342 | if __name__ == '__main__': 343 | main() -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import layers 6 | from layers import MeanAggregator, LSTMAggregator, MaxPoolAggregator, MeanPoolAggregator 7 | 8 | def init_weights(m): 9 | if type(m) == nn.Linear: 10 | torch.nn.init.xavier_uniform_(m.weight, gain=1.414) 11 | 12 | class GraphSAGE(nn.Module): 13 | 14 | def __init__(self, input_dim, hidden_dims, output_dim, 15 | dropout=0.5, agg_class=MaxPoolAggregator, num_samples=25, 16 | device='cpu'): 17 | """ 18 | Parameters 19 | ---------- 20 | input_dim : int 21 | Dimension of input node features. 
22 | hidden_dims : list of ints 23 | Dimension of hidden layers. Must be non empty. 24 | output_dim : int 25 | Dimension of output node features. 26 | dropout : float 27 | Probability of setting an element to 0 in dropout layer. Default: 0.5. 28 | agg_class : An aggregator class. 29 | Aggregator. One of the aggregator classes imported at the top of 30 | this module. Default: MaxPoolAggregator. 31 | num_samples : int 32 | Number of neighbors to sample while aggregating. Default: 25. 33 | device : string 34 | 'cpu' or 'cuda:0'. Default: 'cpu'. 35 | """ 36 | super(GraphSAGE, self).__init__() 37 | 38 | self.input_dim = input_dim 39 | self.hidden_dims = hidden_dims 40 | self.output_dim = output_dim 41 | self.agg_class = agg_class 42 | self.num_samples = num_samples 43 | self.device = device 44 | self.num_layers = len(hidden_dims) + 1 45 | 46 | self.aggregators = nn.ModuleList([agg_class(input_dim, input_dim, device)]) 47 | self.aggregators.extend([agg_class(dim, dim, device) for dim in hidden_dims]) 48 | 49 | 50 | c = 3 if agg_class == LSTMAggregator else 2 51 | self.fcs = nn.ModuleList([nn.Linear(c*input_dim, hidden_dims[0])]) 52 | self.fcs.extend([nn.Linear(c*hidden_dims[i-1], hidden_dims[i]) for i in range(1, len(hidden_dims))]) 53 | self.fcs.extend([nn.Linear(c*hidden_dims[-1], output_dim)]) 54 | 55 | self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for hidden_dim in hidden_dims]) 56 | 57 | self.dropout = nn.Dropout(dropout) 58 | self.relu = nn.ReLU() 59 | 60 | def forward(self, features, node_layers, mappings, rows): 61 | """ 62 | Parameters 63 | ---------- 64 | features : torch.Tensor 65 | An (n' x input_dim) tensor of input node features. 66 | node_layers : list of numpy array 67 | node_layers[i] is an array of the nodes in the ith layer of the 68 | computation graph. 69 | mappings : list of dictionary 70 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 71 | in node_layers[i] to its position in node_layers[i]. 
For example,
72 |             if node_layers[i] = [2,5], then mappings[i][2] = 0 and
73 |             mappings[i][5] = 1.
74 |         rows : numpy array
75 |             rows[i] is an array of the neighbors of node i.
76 | 
77 |         Returns
78 |         -------
79 |         out : torch.Tensor
80 |             A (len(node_layers[-1]) x output_dim) tensor of output node features.
81 |         """
82 |         out = features
83 |         for k in range(self.num_layers):
84 |             nodes = node_layers[k+1]
85 |             mapping = mappings[k]
86 |             init_mapped_nodes = np.array([mappings[0][v] for v in nodes], dtype=np.int64)
87 |             cur_rows = rows[init_mapped_nodes]
88 |             aggregate = self.aggregators[k](out, nodes, mapping, cur_rows,
89 |                                             self.num_samples)
90 |             cur_mapped_nodes = np.array([mapping[v] for v in nodes], dtype=np.int64)
91 |             out = torch.cat((out[cur_mapped_nodes, :], aggregate), dim=1)
92 |             out = self.fcs[k](out)
93 |             if k+1 < self.num_layers:
94 |                 out = self.relu(out)
95 |                 out = self.bns[k](out)
96 |                 out = self.dropout(out)
97 |                 out = out.div(out.norm(dim=1, keepdim=True)+1e-6)
98 | 
99 |         return out
100 | 
101 | class GAT(nn.Module):
102 | 
103 |     def __init__(self, input_dim, hidden_dims, output_dim, num_heads,
104 |                  dropout=0.5, device='cpu'):
105 |         """
106 |         Parameters
107 |         ----------
108 |         input_dim : int
109 |             Dimension of input node features.
110 |         hidden_dims : list of ints
111 |             Dimensions of the hidden layers. Must be non-empty.
112 |         output_dim : int
113 |             Dimension of output node features.
114 |         num_heads : list of ints
115 |             Number of attention heads in each hidden layer and the output layer. Must be non-empty. Note that len(num_heads) = len(hidden_dims) + 1.
116 |         dropout : float
117 |             Dropout rate. Default: 0.5.
118 |         device : str
119 |             'cpu' or 'cuda:0'. Default: 'cpu'.
120 |         """
121 |         super().__init__()
122 | 
123 |         self.input_dim = input_dim
124 |         self.hidden_dims = hidden_dims
125 |         self.output_dim = output_dim
126 |         self.num_heads = num_heads
127 |         self.device = device
128 |         self.num_layers = len(hidden_dims) + 1
129 | 
130 |         dims = [input_dim] + [d*nh for (d, nh) in zip(hidden_dims, num_heads[:-1])] + [output_dim*num_heads[-1]]
131 |         in_dims = dims[:-1]
132 |         out_dims = [d // nh for (d, nh) in zip(dims[1:], num_heads)]
133 | 
134 |         self.attn = nn.ModuleList([layers.GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_heads)])
135 | 
136 |         self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
137 | 
138 |         self.dropout = nn.Dropout(dropout)
139 |         self.elu = nn.ELU()
140 | 
141 |     def forward(self, features, node_layers, mappings, rows):
142 |         """
143 |         Parameters
144 |         ----------
145 |         features : torch.Tensor
146 |             An (n' x input_dim) tensor of input node features.
147 |         node_layers : list of numpy arrays
148 |             node_layers[i] is an array of the nodes in the ith layer of the
149 |             computation graph.
150 |         mappings : list of dictionaries
151 |             mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1)
152 |             in node_layers[i] to its position in node_layers[i]. For example,
153 |             if node_layers[i] = [2,5], then mappings[i][2] = 0 and
154 |             mappings[i][5] = 1.
155 |         rows : numpy array
156 |             rows[i] is an array of the neighbors of node i.
157 | 
158 |         Returns
159 |         -------
160 |         out : torch.Tensor
161 |             A (len(node_layers[-1]) x output_dim) tensor of output node features.
162 |         """
163 |         out = features
164 |         for k in range(self.num_layers):
165 |             nodes = node_layers[k+1]
166 |             mapping = mappings[k]
167 |             init_mapped_nodes = np.array([mappings[0][v] for v in nodes], dtype=np.int64)
168 |             cur_rows = rows[init_mapped_nodes]
169 |             out = self.dropout(out)
170 |             out = self.attn[k](out, nodes, mapping, cur_rows)
171 |             if k+1 < self.num_layers:
172 |                 out = [self.elu(o) for o in out]
173 |                 out = torch.cat(tuple(out), dim=1)
174 |                 out = self.bns[k](out)
175 |             else:
176 |                 out = torch.cat(tuple([x.flatten().unsqueeze(0) for x in out]), dim=0)
177 |                 out = out.mean(dim=0).reshape(len(nodes), self.output_dim)
178 | 
179 |         return out
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import json
4 | import sys
5 | 
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from datasets import link_prediction
10 | from layers import MeanAggregator, LSTMAggregator, MaxPoolAggregator, MeanPoolAggregator
11 | import models
12 | 
13 | def get_agg_class(agg_class):
14 |     """
15 |     Parameters
16 |     ----------
17 |     agg_class : str
18 |         Name of the aggregator class.
19 | 
20 |     Returns
21 |     -------
22 |     layers.Aggregator
23 |         Aggregator class.
24 |     """
25 |     return getattr(sys.modules[__name__], agg_class)
26 | 
27 | def get_criterion(task):
28 |     """
29 |     Parameters
30 |     ----------
31 |     task : str
32 |         Name of the task.
33 | 
34 |     Returns
35 |     -------
36 |     criterion : torch.nn.modules.loss._Loss
37 |         Loss function for the task.
38 |     """
39 |     if task == 'link_prediction':
40 |         return nn.BCEWithLogitsLoss()
41 | 
42 |     raise ValueError('Unsupported task: {}'.format(task))
43 | 
44 | def get_dataset(args):
45 |     """
46 |     Parameters
47 |     ----------
48 |     args : tuple
49 |         Tuple of the task, dataset name, and other arguments required by the dataset constructor.
50 | 
51 |     Returns
52 |     -------
53 |     dataset : torch.utils.data.Dataset
54 |         The dataset.
55 |     """
56 |     task, dataset_name, *dataset_args = args
57 |     if task == 'link_prediction':
58 |         class_attr = getattr(importlib.import_module('datasets.link_prediction'), dataset_name)
59 |         return class_attr(*dataset_args)
60 | 
61 |     raise ValueError('Unsupported task: {}'.format(task))
62 | 
63 | def get_fname(config):
64 |     """
65 |     Parameters
66 |     ----------
67 |     config : dict
68 |         A dictionary with all the arguments and flags.
69 | 
70 |     Returns
71 |     -------
72 |     fname : str
73 |         The filename for the saved model.
74 |     """
75 |     duplicate_examples = config['duplicate_examples']
76 |     repeat_examples = config['repeat_examples']
77 |     agg_class = config['agg_class']
78 |     hidden_dims_str = '_'.join([str(x) for x in config['hidden_dims']])
79 |     num_samples = config['num_samples']
80 |     batch_size = config['batch_size']
81 |     epochs = config['epochs']
82 |     lr = config['lr']
83 |     weight_decay = config['weight_decay']
84 |     fname = 'graphsage_agg_class_{}_hidden_dims_{}_num_samples_{}_batch_size_{}_epochs_{}_lr_{}_weight_decay_{}_duplicate_{}_repeat_{}.pth'.format(
85 |         agg_class, hidden_dims_str, num_samples, batch_size, epochs, lr,
86 |         weight_decay, duplicate_examples, repeat_examples)
87 | 
88 |     return fname
89 | 
90 | # def get_model(config):
91 | #     """
92 | #     Parameters
93 | #     ----------
94 | #     config : dict
95 | #         A dictionary with all the arguments and flags.
96 | #
97 | #     Returns
98 | #     -------
99 | #     model : torch.nn.Module
100 | #         The model.
101 | #     """
102 | #     if config['model'] == 'GraphSAGE':
103 | #         agg_class = get_agg_class(config['agg_class'])
104 | #         model = models.GraphSAGE(config['input_dim'], config['hidden_dims'],
105 | #                                  config['output_dim'], config['dropout'],
106 | #                                  agg_class, config['num_samples'],
107 | #                                  config['device'])
108 | #     elif config['model'] == 'GAT':
109 | #         model = models.GAT(config['input_dim'], config['hidden_dims'],
110 | #                            config['output_dim'], config['num_heads'],
111 | #                            config['dropout'], config['device'])
112 | #
113 | #     model.to(config['device'])
114 | #
115 | #     return model
116 | 
117 | 
118 | def parse_args():
119 |     """
120 |     Returns
121 |     -------
122 |     config : dict
123 |         A dictionary with the required arguments and flags.
124 |     """
125 |     parser = argparse.ArgumentParser()
126 | 
127 |     parser.add_argument('--json', type=str, default='config.json',
128 |                         help='path to json file with arguments, default: config.json')
129 | 
130 |     parser.add_argument('--stats_per_batch', type=int, default=16,
131 |                         help='print loss and accuracy after how many batches, default: 16')
132 | 
133 |     parser.add_argument('--dataset', type=str,
134 |                         choices=['CollegeMsg', 'BitcoinAlpha'],
135 |                         # required=True,
136 |                         help='name of the dataset')
137 |     parser.add_argument('--dataset_path', type=str,
138 |                         # required=True,
139 |                         help='path to dataset')
140 |     parser.add_argument('--neg_examples_path', type=str,
141 |                         default='',
142 |                         help="path to file with negative examples, default: ''")
143 |     parser.add_argument('--duplicate_examples', action='store_true',
144 |                         help='whether to allow duplicate edges in the list of positive examples, default=False')
145 |     parser.add_argument('--repeat_examples', action='store_true',
146 |                         help='whether to use positive examples that have already been seen during graph construction or training, default=False')
147 |     parser.add_argument('--self_loop', action='store_true',
148 |                         help='whether to add self loops to adjacency matrix, default=False')
149 | 
parser.add_argument('--normalize_adj', action='store_true',
150 |                         help='whether to normalize the adjacency matrix as in GCN, default=False')
151 | 
152 |     parser.add_argument('--task', type=str,
153 |                         choices=['unsupervised', 'link_prediction'],
154 |                         default='link_prediction',
155 |                         help='type of task, default=link_prediction')
156 | 
157 |     parser.add_argument('--agg_class', type=str,
158 |                         choices=['MeanAggregator', 'LSTMAggregator', 'MaxPoolAggregator', 'MeanPoolAggregator'],
159 |                         default='MaxPoolAggregator',
160 |                         help='aggregator class, default: MaxPoolAggregator')
161 |     parser.add_argument('--cuda', action='store_true',
162 |                         help='whether to use GPU, default: False')
163 |     parser.add_argument('--dropout', type=float, default=0.5,
164 |                         help='dropout probability, default: 0.5')
165 |     parser.add_argument('--hidden_dims', type=int, nargs="*",
166 |                         help='dimensions of hidden layers; length should be num_layers - 1, specify through config.json')
167 |     parser.add_argument('--num_samples', type=int, default=-1,
168 |                         help='number of neighbors to sample, default=-1')
169 | 
170 |     parser.add_argument('--batch_size', type=int, default=32,
171 |                         help='training batch size, default=32')
172 |     parser.add_argument('--epochs', type=int, default=2,
173 |                         help='number of training epochs, default=2')
174 |     parser.add_argument('--lr', type=float, default=1e-4,
175 |                         help='learning rate, default=1e-4')
176 |     parser.add_argument('--weight_decay', type=float, default=5e-4,
177 |                         help='weight decay, default=5e-4')
178 | 
179 |     parser.add_argument('--save', action='store_true',
180 |                         help='whether to save model in trained_models/ directory, default: False')
181 |     parser.add_argument('--load', action='store_true',
182 |                         help='whether to load model from trained_models/ directory, default: False')
183 | 
184 |     args = parser.parse_args()
185 |     config = vars(args)
186 |     if config['json']:
187 |         with open(config['json']) as f:
188 |             json_dict = json.load(f)
189 |         config.update(json_dict)
190 | 
191 |     for (k, v) in
config.items(): 192 | if config[k] == 'True': 193 | config[k] = True 194 | elif config[k] == 'False': 195 | config[k] = False 196 | 197 | config['num_layers'] = len(config['hidden_dims']) + 1 198 | 199 | print('--------------------------------') 200 | print('Config:') 201 | for (k, v) in config.items(): 202 | print(" '{}': '{}'".format(k, v)) 203 | print('--------------------------------') 204 | 205 | return config -------------------------------------------------------------------------------- /visualizations/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raunakkmr/GraphSAGE-and-GAT-for-link-prediction/d264920416f32d8e0c2db206eb6ae4727e0b0ad5/visualizations/.DS_Store --------------------------------------------------------------------------------
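The head-count arithmetic in `GAT.__init__` (src/models.py, lines 130-132) is easy to misread: `hidden_dims` holds per-head widths, hidden layers concatenate their heads, and the output layer averages them. The following standalone sketch reproduces those three lines with illustrative sizes; the concrete numbers (128-dimensional inputs, two hidden layers, 4/4/6 heads) are assumptions for the example, not values taken from the repo's config files.

```python
# Standalone check of the dimension bookkeeping in GAT.__init__
# (src/models.py). The sizes below are illustrative assumptions.
input_dim, output_dim = 128, 32
hidden_dims = [64, 64]   # per-head output width of each hidden layer
num_heads = [4, 4, 6]    # len(num_heads) == len(hidden_dims) + 1

# Total width entering/leaving each attention layer: hidden layers
# concatenate their heads, so layer k emits hidden_dims[k] * num_heads[k].
dims = [input_dim] \
    + [d * nh for (d, nh) in zip(hidden_dims, num_heads[:-1])] \
    + [output_dim * num_heads[-1]]
in_dims = dims[:-1]
# Width produced by each individual head.
out_dims = [d // nh for (d, nh) in zip(dims[1:], num_heads)]

print(dims)      # [128, 256, 256, 192]
print(in_dims)   # [128, 256, 256]
print(out_dims)  # [64, 64, 32]
```

Note that `out_dims` recovers `hidden_dims + [output_dim]`, which is why the final layer can average its heads and reshape to `(len(nodes), output_dim)` in `GAT.forward`.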
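The tail of `utils.parse_args` merges the JSON file into the argparse namespace, coerces flags that arrive as the strings `'True'`/`'False'` into booleans, and derives `num_layers` from `hidden_dims`. This minimal sketch shows that post-processing on a hypothetical config dict; the sample keys and values are assumptions for illustration, not contents of the repo's config.json.

```python
# Minimal sketch of the config post-processing at the end of
# utils.parse_args (src/utils.py). Sample values are assumptions.
config = {
    'cuda': 'False',
    'self_loop': 'True',
    'hidden_dims': [64, 64],
}

# Flags loaded from JSON may be the strings 'True'/'False'; coerce them.
# Reassigning values of existing keys during iteration is safe in Python.
for (k, v) in config.items():
    if v == 'True':
        config[k] = True
    elif v == 'False':
        config[k] = False

# One layer per hidden dimension, plus the output layer.
config['num_layers'] = len(config['hidden_dims']) + 1

print(config)
# {'cuda': False, 'self_loop': True, 'hidden_dims': [64, 64], 'num_layers': 3}
```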