├── README.md ├── LICENSE ├── pyg_1.ipynb ├── pyg_2.ipynb └── pyg_3.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # PyG_Study 2 | 3 | 中文: 4 | 5 | 这是一个关于PyTorch Geometric --> "用于几何深度学习的PyTorch扩展库" 的教学仓库 6 | 基本上用的是 https://pytorch-geometric.readthedocs.io/en/latest/ 的内容 7 | 不定期更新 8 | 9 | 尝试过录视频上传B站, 但是... 以我现在的时间, 精力, 能力, B站视频讲解可能会.... 有点吃力. 遂暂时放弃录制视频(但是有时间还是会上传内容的). 10 | 11 | @wmf1997 12 | 13 | English: 14 | 15 | This is a teaching & study repository for PyTorch Geometric (*Geometric Deep Learning Extension Library for PyTorch*). 16 | It is based mainly on the official documentation at https://pytorch-geometric.readthedocs.io/en/latest/. 17 | Updated irregularly. 18 | 19 | I recorded some videos and uploaded them to Bilibili; however, it has become hard for me to keep teaching by recording videos, so I have stopped recording for now. If I do have time, I will keep updating this repo. 20 | 21 | @wmf1997 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 MingFei Wang (王明非, Wang MingFei, WMF) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyg_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 5, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "Data(edge_index=[2, 3])\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import torch_geometric\n", 29 | "import torch_geometric.nn as gnn\n", 30 | "import torch_geometric.data as gdata\n", 31 | "\n", 32 | "edge_index = torch.tensor([[0,2],[1,0],[2,1]]).transpose(0, 1)\n", 33 | "g = gdata.Data(edge_index=edge_index)\n", 34 | "print (g)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 6, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stderr", 44 | "output_type": "stream", 45 | "text": [ 46 | "/home/wmf997/anaconda3/lib/python3.7/site-packages/torch_geometric/data/data.py:184: UserWarning: The number of nodes in your data object can only be inferred by its edge indices, and hence may result in unexpected batch-wise behavior, e.g., in case there exists isolated nodes. 
Please consider explicitly setting the number of nodes for this data object by assigning it to data.num_nodes.\n", 47 | " warnings.warn(__num_nodes_warn_msg__.format('edge'))\n" 48 | ] 49 | }, 50 | { 51 | "data": { 52 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAE/CAYAAACXV7AVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEFJJREFUeJzt3W+IVelhx/HftW501GZQMAkNXVtSQmiihURxE8rWJVLQNWV3w+6SlFK1SYjdkFfuCyGUQkINwZAXSSqFbjJL/2A0Cy4FJaRv3CRCcMu62j/ZlOzitnkRJ1Hcbh2b2r19cR1XxzvjvTPn3vuccz8fGJh77rnnPvPqy3PmnOe02u12OwDASC0b9QAAAEEGgCIIMgAUQJABoACCDAAFEGQAKIAgA0ABBBkACiDIAFAAQQaAAggyABRAkAGgAIIMAAUQZAAogCADQAEEGQAKIMgAUABBBoACCDIAFECQAaAAggwABRBkACjA8lEPoNYuXkymppJz55IrV5LJyWTTpmTPnmT9+lGPDoAaabXb7faoB1E7Z84kBw8mJ092Xl+79uZ7ExNJu53s2JEcOJBs2TKaMQJQK4Lcr8OHk/37k5mZTnjn02p14nzoULJv3/DGB0AtOWXdj9kYX716933b7c5++/d3XosyAAswQ+7VmTPJtm13xPhrSaaSnE/ysRu/32HVquTUqWTz5sGOEYDacpV1rw4e7JymnuPXknwuyd6FPjsz0/k8AMzDDLkXFy8mGzbcfvHWHJ9L8p+ZZ4acJCtXJq++6uprALoyQ+7F1NTSj9FqVXMcABpJkHtx7tyCs+OezMwk589XMx4AGkeQe3HlSjXHuXy5muMA0DiC3IvJyWqOs3ZtNccBoHEEuRebNnUuyuriepJrSf7vxs+1G9vmuprk6RdeyJEjR/L6668PaqQA1JQg92L37nnf+kKSiSRfTPK3N37/Qpf9JlasyD2f+ESefvrpvPOd78xHP/pRcQbgJrc99eqRR5LjxxdeLnM+rVby8MPJM88kSS5dupTjx4/n2LFjOX36dLZv355HH300u3btypo1ayoeOAB1IMi9mmelrp4ssFKXOAOQCHJ/+lnLetaqVT0/YEKcAcaXIPdrSE97EmeA8SLIi/H88521qU+c6IT31jWuZ5+HvHNn53nIFTxQQpwBmk+Ql2J6Opmayn+cPJl//cEP8vuPPZbWpk2dq7IHtGa1OAM0kyBX4P3vf39euHGP8eOPPz607xVngOYQ5CX64Q9/mG3btuXatWu5995788orr2TZsuHf3i3OAPUmyEt0//3353vf+16SZPXq1XnqqaeGOkvuRpwB6keQl2B6ejpvf/vbs3z58ly/fj3Lli3LAw88kO9+97ujHtpN4gxQD4K8RK+99lpeeOGFPPHEE3nuueeyevXqrFixYtTD6kqcAcplLesleutb35rJycksX74869atKzbGSbJu3brs3bs3J0+ezCuvvJIHH3zQ2toAhRDkMSXOAGURZMQZoACCzG3EGWA0BJl5iTPA8AgyPRFngMESZPomzgDVE2SWRJwBqiHIVEacARZPkBkIcQbojyAzcOIMcHeCzFCJM0B3gszIiDPAmwSZIogzMO4EmeKIMzCOBJmiiTMwLgSZ2hBnoMkEmVoSZ6BpBJnaE2egCQSZRhFnoK4EmcYSZ6BOBJmxIM5A6QSZsSPOQIkEmbEmzkApBBluEGdglAQZuhBnYNgEGe5CnIFhEGTogzgDgy
LIsEjiDFRJkKEC4gwslSBDxcQZWAxBhgESZ6BXggxDIs7AQgQZRkCcgbkEGUZMnIFEkKEo4gzjS5ChUOIM40WQoQbEGZpPkKFmxBmaSZChxsQZmkOQoSHEGepNkKGBxBnqR5Ch4cQZ6kGQYYzcLc7f+ta3xBlGRJBhTHWL89TUlDjDiAgyIM5QAEEGbiPOMBqCDMxLnGF4BBnoiTjDYAky0DdxhuoJMrAk4gzVEGSgMuIMiyfIwECIM/RHkIGBE2e4O0EGhkqcoTtBBkZGnOFNggwUQZwZd4IMFEecGUeCDBRNnBkXggzUhjjTZIIM1JI40zSCDNSeONMEggw0ijhTV4IMNJY4UyeCDIwFcaZ0ggyMHXGmRIIMjDVxphSCDHCDODNKggzQhTgzbIIMcBfizDAIMkAfxJlBEWSARRJnqiTIABWYG+ddu3aJM30RZICKrVu3Lnv27BFn+iLIAAMkzvRKkAGGRJxZiCADjIA4M5cgA4yYOJMIMkBRxHl8CTJAocR5vAgyQA2Ic/MJMkDNiHMzCTJAjYlzcwgyQEOIc70JMkADiXP9CDJAw4lzPQgywBgR53IJMsCYEueyCDIA4lwAQQbgNuI8GoIMwLzEeXgEGYCeiPNgCTIAfRPn6gkyAEsiztUQZAAqI86L12q32+1RD6Luzp49m927d+fs2bOjHgpAkS5dupRnn302R48ezenTp7N9+/Y89thjefDBB7NmzZpqvuTixWRqKjl3LrlyJZmcTDZtSvbsSdavr+Y7BkiQKyDIAL2rPM5nziQHDyYnT3ZeX7v25nsTE0m7nezYkRw4kGzZUs0fMQBOWQMwVP2e1v7MZz6Tz372s+k6fzx8ONm2LTl+vBPiW2OcJDMznW3Hj3f2O3x4oH/bUiwf9QAAGF+zcd6zZ8/NmfPU1FQ+9alPZfv27XnkkUfyzW9+M+12O9evX8/Xv/71tFqtzocPH07270+uXr37F7Xbnf327++83rdvcH/UIpkhA1CEbjPnr3zlK7l69WpmZmbyjW98I5/85Cc7M+UzZ7rG+FKSh5OsTrIhyd/P/ZLZKD///FD+pn4IMgDFmY3zu971rpvbfvnLX+app57Kpz/96c7/jGdm7vjcE0nekuRnSf4uyb4k/zJ3p5mZzucL45Q1AMVqtVrZunVr3ve+92Xjxo15xzvekfvf857kvvs6p6Fv8d9Jnknyz0nWJPndJH+Q5G+SfPHWHdvt5MSJZHq6qKuvBRmAYh05cuTOjV/6Utd9f5zkV5K8+5Ztv5PkVLedW63OLVJPPrnUIVbGKWsA6uXcuTuvpk7yepLJOdsmk/xXt2PMzCTnz1c/tiUQZADq5cqVrpvXJHltzrbXkvzqfMe5fLm6MVVAkAGol8m58+COdye5nuTfb9n2YpL3znectWsrHdZS+R8yALVw4cKFHDt2LCtOnconkkzMeX91kkeS/FmSv05yNsmzSU53O9jERLJx4yCH2zczZACKdeHChRw6dChbt27N5s2b89JLL2Xjl7+clStWdN3/L5PMJHlbko8lOZx5ZsjtdrJ794BGvThmyAAUZXYmfOzYsbz88st56KGH8vnPfz4PPPBA7rnnns5OR450lsOcc+vTuiTH7/YFrVayc2dRtzwlggxAAXqK8K0OHEi+853els2ca2Ki8/nCCDIAI9F3hG+1ZUty6FDva1nPWrWq87nNm5c2+AEQZACGZkkRnmv2ARH793fuK17oacKtVmdmfOhQkQ+WSAQZgAGrNMJz7dvXmS0fPNhZDrPVun2N69nnIe/c2TlNXeDMeJYgA1C5gUZ4rs2bk2ee6axNPTXVWYHr8uXOfcYbN3aupi7sAq5uBBmASgw1wt2sX1/U2tT9EmQAFm3kEW4QQQagLyI8GIIMwF2J8OAJMgBdifBwCTIAN4nw6AgywJgT4TIIMsAYEuHyCDLAmBDhsgkyQIOJcH0IMkDDiHA9CTJAA4hw/QkyQE2JcLMIMkCNiHBzCTJA4U
R4PAgyQIFEePwIMkAhRHi8CTLACIkwswQZYMhEmG4EGWAIRJi7EWSAARFh+iHIABUSYRZLkAGWSISpgiADLIIIUzVBBuiRCDNIggywABFmWAQZYA4RZhQEGSAizOgJMjC2RJiSCDIwVkSYUgky0HgiTB0IMtBIIkzdCDLQGCJMnQkyUGsiTFMIMlA7IkwTCTJQCyJM0wkyUCwRZpwIMlAUEWZcCTIwciIMggyMiAjD7QQZGBoRhvkJMjBQIgy9EWSgciIM/RNkoBIiDEsjyMCiiTBUR5CBvogwDIYgA3clwjB4ggx0deHChXz729/O0aNHRRiGQJCBm0QYRkeQYcyJMJRBkGEMiTCUR5BhTIgwlE2QocFEGOpDkKFhRBjqSZChAUQY6k+QoaZEGJpFkKFGRBiaS5ChcCIM40GQoUAiDONHkKEQIgzjTZBhhEQYmCXIMGQiDHQjyDAEIgzcjSDDgIgw0A9BhgqJMLBYggxLJMJAFQQZFkGEgaoJMvRIhIFBEmRYgAgDwyLIMIcIA6MgyBARBkZPkBlbIgyURJAZKyIMlEqQaTwRBupAkGkkEQbqRpBpDBEG6kyQqTURBppCkKkdEQaaSJCphVsj/JOf/CQPP/ywCAONIsgUS4SBcSLIFEWEgXElyIycCAMIMiMiwgC3E2SGRoQB5ifIDJQIA/RGkKmcCAP0T5CphAgDLI0gs2giDFAdQaYvIgwwGILMXYkwwOAJMl2JMMBwCTI3iTDA6AjymBNhgDII8hgSYYDyCPKYEGGAsglyg4kwQH0IcsOIMEA9CXIDiDBA/QlyTYkwQLMIco2IMEBztdrtdnvUg6itixeTqalcOnUq57///fzeRz6SbNqU7NmTrF9fyVd0i/Cjjz4qwgANI8iLceZMcvBgcvJk5/W1a2++NzGRtNvJjh3JgQPJli19H16EAcaPIPfr8OFk//5kZqYT3vm0Wp04HzqU7Nt321vT09PZsWNHvvrVr+aDH/xgEhEGGHeC3I/ZGF+92vtnVq26LcrT09PZunVrLly4kIceeigf+tCHRBgAQe7ZmTPJtm23xfh/kvxpkn9McinJbyX5iyQ75n521ark1KlMb9iQD3zgA/npT3+aN954I0myd+/ePP744yIMMOZcZd2rgwc7p6lvcT3Jryc5leTeJCeSPJbkfJLfuHXHmZnk4MH89nPP5ec///nNzStWrMjHP/7xfPjDHx7s2AEoniD34uLFzgVcc04mrE7y57e83pXkN5P8U+YEud1OTpzIPxw9mh9fvpyXX345L774Yl566aX84he/GOzYAagFQe7F1FRPu/0syY+TvLfbm61W7vvRj3Lfk09WNy4AGmPZqAdQC+fO3X5rUxf/m+QPk/xxkvd022FmJjl/vvqxAdAIgtyLK1cWfPuNJH+U5C1JvrbQjpcvVzcmABrFKeteTE7O+1Y7yZ+kc7r6RJIFr5Neu7bSYQHQHGbIvdi0KVm5sutb+5L8W5J/SDKx0DEmJpKNG6sfGwCN4D7kXly8mGzYcMf/kS+kczX1itx+quGv0vl/8m1WrkxefbWyNa4BaBYz5F687W2dtalbrds2b0jnlPW1JK/f8nNHjFutZOdOMQZgXmbIveqyUlfPbqzUlc2bKx8WAM1ghtyrLVs6a1KvWtXf52bXshZjABbgKut+zD61aYlPewKAuZyyXoznn++sbX3iRCe8t65xPfs85J07O89DNjMGoAeCvBTT051lNc+f7yz6sXZt59am3btdwAVAXwQZAArgoi4AKIAgA0ABBBkACiDIAFAAQQaAAggyABRAkAGgAIIMAAUQZAAogCADQAEEGQAKIMgAUABBBoACCDIAFECQAaAAggwABRBkACiAIANAAQQZAAogyABQAEEGgAIIMgAUQJABoACCDAAFEGQAKIAgA0ABBBkACiDIAFAAQQaAAggyABRAkAGgAIIMAAUQZAAogCADQAEEGQAKIMgAUABBBoACCDIAFOD/Ada4b+
isOPL0AAAAAElFTkSuQmCC\n", 53 | "text/plain": [ 54 | "
" 55 | ] 56 | }, 57 | "metadata": {}, 58 | "output_type": "display_data" 59 | } 60 | ], 61 | "source": [ 62 | "import networkx as nx\n", 63 | "import torch_geometric.utils as gutils\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "\n", 66 | "g_nx = gutils.to_networkx(g)\n", 67 | "nx.draw_kamada_kawai(g_nx, with_labels=True)\n", 68 | "plt.show()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.7.3" 96 | } 97 | }, 98 | "nbformat": 4, 99 | "nbformat_minor": 2 100 | } 101 | -------------------------------------------------------------------------------- /pyg_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch_geometric\n", 11 | "import torch_geometric.data as gdata\n", 12 | "\n", 13 | "edge_index = torch.tensor([[0,1],[1,2],[2,0],[2,1]], dtype=torch.int64).transpose(1,0)\n", 14 | "g = gdata.Data(edge_index=edge_index)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 6, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "Data(edge_index=[2, 4], x=[3, 10])\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "x = torch.randn(3, 10)\n", 32 | "g.x = x\n", 33 | "print (g)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 7, 39 | "metadata": {}, 40 | 
"outputs": [], 41 | "source": [ 42 | "import networkx as nx\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import torch_geometric.utils as gutils" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 9, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAE/CAYAAAADsRnnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEP5JREFUeJzt3V2IXOdhx+H/yJKl1cqIrZGcNLa0cWq7JLFsBSnJRYudSymkkNRycJtgqQ4BGXxlpSAa6jgl2V7oKiQVvSgoGAqRbNxiImMoOLJxnSAnWBJO83kRqxgsEcsbpbvyx2p6MV5rtJpdze7OnDnnzPPAgnbmzJl3dPPjfffMexrNZrMZAKDvVgx6AAAwLEQXAAoiugBQENEFgIKILgAURHQBoCCiCwAFEV0AKIjoAkBBRBcACiK6AFAQ0QWAgoguABREdAGgIKILAAURXQAoiOgCQEFEFwAKIroAUBDRBYCCiC4AFER0AaAgKwc9gEo4cyY5dCg5eTKZnEzWr0+2bEn27Ek2bBj06ACoiEaz2WwOehCldfx4MjGRPP106/cLFy49NzKSNJvJjh3J/v3J9u2DGSMAlSG68zl4MNm3L5mebsV1Po1GK8AHDiR79xY3PgAqx/JyJ7PBnZq6+rHNZuu4fftavwsvAPMw053r+PHk7rsvC+53kxxKcirJfe/9u6O1a5Njx5Jt2/o7RgAqydXLc01MtJaU2/xpkq8n+burvXZ6uvV6AOjATLfdmTPJ5s2XXzDV5utJ/jcLzHSTZM2a5NVXXdUMwBXMdNsdOrT8czQavTkPALUjuu1Onpx3ltu16enk1KnejAeAWhHddpOTvTnPuXO9OQ8AtSK67dav7815xsaSJM1mM6+//npefvnl+NM5AKLbbsuW1oVQc7yb5EKSmfd+Lrz3WCcXV6/OPx89mg996EMZGRnJpk2b8olPfCLnz5/v27ABqAZXL7eb5+rlbyR5dM6hj7z3+FzN1atz28hIfv3mm+8/dsstt+SXv/xlGo1Gb8cLQKWY6bbbuLG1l/KcOH4jSXPOzzc6vb7RSOOzn82Lv/lNbrnllqxatSorV67M5ORkxsfH8/DDD+cnP/mJpWaAISW6c+3f39pLeSlGRpL9+3P99dfnxRdfzPj4eC5evJiXX345P/zhDzM6Oprdu3cLMMCQsrzcyWL2Xp61du0VNz34/e9/n6eeeiq7d+9+/7Fms5lXXnklhw8fzpEjRzI1NZV77rkn9957bz75yU9aggaoMdGdTwF3GRJggOEiugt56aXWXspHj7bi2r4n8+z9dHfubC1JL/MmBwIMUH+i242zZ5NDh/L4I4/khtWr85ef+1xy++3J7t192WNZgAHqSXS7dPr06WzevDnXXHNNzp07l3Xr1hXyvgIMUB+uXu7SI488kmazmWazme985zuFvW+j0cjHP/7xfPOb38zPf/5zV0EDVJiZbhdOnz6dW2+9NRfe2zTjuuuuy2uvvVbYbLcTM2CA6jHT7cIPfvCDvPXWW2k0GlmxYkXOnz+fZ555ZqBjMgMGqB4z3S7MzMxkcnIyd911VzZt2pTHHnssY2NjpZxNmgEDlJfoLsLWrVszPj6eJ598ctBD6YoAA5SL5eUaswQNUC6iOyQEGGDwRHcICTDAYIjukBNggOKILu8TYID+El06E
mCA3hNdrkqAAXpDdFkUAQZYOtFlyQQYYHFEl54QYICrE116ToABOhNd+kqAAS4RXQojwMCwE10GQoCBYSS6DJwAA8NCdCkVAQbqTHQprU4BXrdu3fsB3rdvnwADlSK6VMJsgB999NGOM2ABBqpAdKkcAQaqSnSpNAEGqkR0qQ0BBspOdKklAQbKSHSpPQEGykJ0GSoCDAyS6DK0BBgomuhCBBgohujCHAIM9IvowgIEGOgl0YUuCTCwXKILSyDAwFKILiyTAAPdEl3oIQEGFiK60CcCDMwlulAAAQYS0YXCCTAML9GFARJgGC6iCyUhwFB/ogslJMBQT6ILJSfAUB+iCxUiwFBtogsVJcBQPaILNSDAUA2iCzUjwFBeogs1JsBQLqILQ0KAYfBEF4aQAMNgiC4MOQGG4ogu8D4Bhv4SXaAjAYbeE13gqgQYekN0gUURYFg60QWWTIBhcUQX6AkBhqsTXaDnBBg6E12grwQYLhFdoDACzLATXWAgBJhhJLrAwAkww0J0gVIRYOpMdIHSEmDqRnSBShBg6kB0gcoRYKpKdIFKE2CqRHSB2rhagB9++GEBZqBEF6ilTgFet26dADNQogvUngBTFqILDBUBZpBEFxhaAkzRRBcgAkwxRBdgDgGmX0QXYAECTC+JLkCXBJjlEl2AJRBglkJ0AZZJgOmW6AL0kACzENEF6BMBZi7RBSiAAJOILkDhBHh4iS7AAAnwcBFdgJIQ4PoTXYASEuB6El2AkhPg+hBdgArpJsA//vGPBbikRBegouYL8J49ewS4pEQXoAYEuBpEF6BmBLi8RBegxgS4XEQXYEgI8OCJLsAQEuDBEF2AISfAxRFdAN4nwP0lugB0JMC9J7oAXJUA94boArAoArx0jab/la5t3bo14+PjefLJJwc9FIDSaTabeeWVV3LkyJEcPnw4U1NTueeee7Jr16586lOfSqPR6M0bnTmTHDqUnDyZTE4m69cnW7Yke/YkGzb05j36RHQXQXQButOXAB8/nkxMJE8/3fr9woVLz42MJM1msmNHsn9/sn17bz5Ij4nuIoguwOL1JMAHDyb79iXT0624zqfRaAX4wIFk797efYge8TddAPpq2X8Dng3u1NTCwU1az09NtY4/eLD3H2aZRBeAwiw6wMePXwpumzeSfD7JaJLNSf597hvNhvell/r+mRbD8vIiWF4G6I/5lqD/4ac/zdhzz6UxJ1X3JbmY5N+SvJzks0n+O8nH2g9qNJLPfz554omCPsXVmekCMHCdZsA3NBoZ7RDc/0vyRJJ/SrIuyV8k+askj809abOZHD2anD1bwCfojugCUCqzAf77jRtz7erVVzz/qyTXJLm17bE7krzS+WStrxeVhOgCUE4nT6bR/rWg9/wxyfo5j61Pcr7TOaank1Onej+2JRJdAMppcrLjw+uS/GHOY39Ict185zl3rndjWibRBWCgZmZmcuLEibzxxhuXP7F+7ny25dYk7yb5ddtjJzLnIqp2Y2PLHmOviC4AA/Xb3/42d955Zz74wQ9mdHQ0mzZtykc/+tH84cMfTtasueL40SRfSPKPaV1U9UKS/0zy5U4nHxlJbr+9j6NfHNEFYKBuvvnmjI2N5e23387U1FROnz6d8+fPZ9VXvjLva/4lyXSSjWl9fehg5pnpNpvJ7t19GPXSiC4AhZuZmclzzz2Xhx56KDfddFOuueaarFixItdee222b9+eX/ziFxnZvLm1l3KHbSL/JMl/pDXTfTXJ33R6k0Yj2bmzVDdBEF0ACtEe2htvvDEPPfRQPvCBD+TYsWN55plncvHixdxxxx159tlnMzo62nrR/v2tJeKlGBlpvb5EVg56AADU18zMTF544YUcOXIkjz/+eDZu3Jh77703x44dy623XvqmbbPZzPe+973cf//9l4KbtO4WdOBAx60gF7R2bet127b18NMsn+gC0FPdhrZdo9HIgw8+2PmEs3cLqsFdhkQXgGVbSmgXZe/e1
qx3YqK1tWOj0QrwrNn76e7c2VpSLtkMd5boArAkfQ/tXNu2tW5ecPZsa2vHU6daG1+MjbW+FrR7d6kumupEdAHoWuGh7WTDhuRrXyvmvXpMdAFYUClCWxOiC8AVhLY/RBeAJEJbBNEFGGJCWyzRBRgyQjs4ogswBIS2HEQXoKaEtnxEF6BGhLbcRBeg4oS2OkQXoIKEtppEF6AihLb6RBegxIS2XkQXoGSEtr5EF6AEhHY4iC7AgAjt8BFdgAIJ7XATXYA+E1pmiS5AHwgtnYguQI8ILVcjugDLILQshugCLJLQslSiC9AFoaUXRBdgHkJLr4kuQBuhpZ9EFxh6QktRRBcYSkLLIIguMDSElkETXaDWhJYyEV2gdoSWshJdoBaElioQXaCyhJaqEV2gUoSWKhNdoPSElroQXaCUhJY6El2gNISWuhNdYKCElmEiukDhhJZhJbpAIYQWRBfoI6GFy4ku0FNCC/MTXWDZhBa6I7rAkggtLJ7oAl0TWlge0QUWJLTQO6ILXEFooT9EF0gitFAE0YUhJrRQLNGFISO0MDiiC0NAaKEcRBdqSmihfEQXamS+0P7oRz/KbbfdNujhwdATXag4M1qoDtGFChJaqCbRhYoQWqg+0YUSE1qoF9GFkhFaqC/RhRIQWhgOogsDIrQwfEQXCiS0MNxEF/pMaIFZogt9ILRAJ6ILPSK0wNWILiyD0AKLIbqwSEILLJXoQheEFugF0YV5CC3Qa6ILbYQW6CfRZegJLVAU0WUoCS0wCKLL0BBaYNBEl1oTWqBMRJfaEVqgrESXWhBaoApEl8oSWqBqRJdKEVqgykSX0hNaoC5El1ISWqCORJfSEFqg7kSXgZoN7eHDh/PEE08ILVBrokvh5ob2hhtuyK5du4QWqD3RpRBCCyC69JHQAlxOdOkpoQWYn+iybEIL0B3RZUmEFmDxRJeuCS3A8oguCxJagN4RXa4gtAD9IbokEVqAIojuEBNagGKJ7pARWoDBEd0hILQA5SC6NSW0AOUjul04ceJEjh49mtdffz1vv/12JiYmct9992V8fHzQQ7uM0AKUW6PZbDYHPYiy+/73v58HHnggMzMzSZJGo5Fnn302d91114BHNn9od+3aJbQAJSO6XXjnnXdy44035syZM0mSO++8Mz/72c/SaDQGMh6hBagmy8tdWLVqVb71rW/lq1/9alasWJEDBw4UHlxLxwDVZ6bbpXfeeScjIyMZHR3Nm2++WUh0zWgB6sVMt0urVq3KZz7zmdx88819Da4ZLUB9iW43zpxJDh3Kt0+fzvWvvZZ86UvJli3Jnj3Jhg3LPr3QAgwHy8sLOX48mZhInn669fuFC5eeGxlJms1kx45k//5k+/ZFndrSMcDwEd35HDyY7NuXTE+34jqfRqMV4AMHkr17Fzyl0AIMN8vLncwGd2rq6sc2m63j9u1r/T4nvJaOAZhlpjvX8ePJ3XdfFty3kjyY5L+SvJHkz5J8O8mOua9duzY5diwzW7ea0QJwBTPduSYmWkvKbd5NclOSY0k2JTma5N4kp5KMtx3XnJ7OiS9+MTumpsxoAbiCmW67M2eSzZsvv2BqHluSPJLkr+c8/u7Klfnd88/nI5/+dD9GCECFrRj0AErl0KGuDns9ya+SfKzDcytXrcpHnn++h4MCoC5Et93Jk1ed5b6T5G+T3J/kzzsdMD2dnDrV+7EBUHmi225ycsGnLyb5cpJrk3x3oQPPnevdmACoDRdStVu/ft6nmkkeSGtp+WiSVQudZ2ysp8MCoB7MdNtt2ZKsWdPxqb1J/ifJU0lGFjrHyEhy++29HxsAlefq5XbzXL38u7S+GrQ6ly8N/Gtaf9+9zJo1yauv9mRPZgDqxUy33caNrb2U59xFaHNay8sXkvyx7eeK4DYayc6dggtAR2a6c3XYkapr7+1IlW3bej4sAKrPTHeu7dtbNy9Yu3Zxr1u7tvU6wQVgHq5e7mT2pgU9vssQAMPN8vJCXnqpt
Rfz0aOtuLbvyTx7P92dO1v30zXDBeAqRLcbZ8+2tog8daq18cXYWOtrQbt3u2gKgK6JLgAUxIVUAFAQ0QWAgoguABREdAGgIKILAAURXQAoiOgCQEFEFwAKIroAUBDRBYCCiC4AFER0AaAgogsABRFdACiI6AJAQUQXAAoiugBQENEFgIKILgAURHQBoCCiCwAFEV0AKIjoAkBBRBcACiK6AFAQ0QWAgoguABREdAGgIKILAAURXQAoiOgCQEFEFwAKIroAUBDRBYCCiC4AFER0AaAgogsABRFdACjI/wOmsBUwy8dcxwAAAABJRU5ErkJggg==\n", 55 | "text/plain": [ 56 | "
" 57 | ] 58 | }, 59 | "metadata": {}, 60 | "output_type": "display_data" 61 | } 62 | ], 63 | "source": [ 64 | "g_nx = gutils.to_networkx(g)\n", 65 | "nx.draw_kamada_kawai(g_nx, with_labels=True)\n", 66 | "plt.show()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 11, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "torch.Size([3, 5])\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "import torch\n", 84 | "import torch_geometric\n", 85 | "import torch_geometric.nn as gnn\n", 86 | "import torch_geometric.data as gdata\n", 87 | "\n", 88 | "x = torch.randn(3, 10)\n", 89 | "edge_index = torch.tensor([[0,1],[1,2],[2,0],[2,1]], dtype=torch.int64).transpose(1,0)\n", 90 | "g = gdata.Data(x=x, edge_index=edge_index)\n", 91 | "gcn = gnn.GCNConv(10, 5)\n", 92 | "y = gcn(x=g.x, edge_index=g.edge_index) # it can be further modified. \n", 93 | "print (y.shape)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [] 102 | } 103 | ], 104 | "metadata": { 105 | "kernelspec": { 106 | "display_name": "Python 3", 107 | "language": "python", 108 | "name": "python3" 109 | }, 110 | "language_info": { 111 | "codemirror_mode": { 112 | "name": "ipython", 113 | "version": 3 114 | }, 115 | "file_extension": ".py", 116 | "mimetype": "text/x-python", 117 | "name": "python", 118 | "nbconvert_exporter": "python", 119 | "pygments_lexer": "ipython3", 120 | "version": "3.7.3" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 2 125 | } 126 | -------------------------------------------------------------------------------- /pyg_3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import 
torch.nn.functional as F" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch_geometric\n", 21 | "import torch_geometric.nn as gnn\n", 22 | "import torch_geometric.data as gdata\n", 23 | "import torch_geometric.datasets as gdatasets" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "output_type": "stream", 33 | "name": "stdout", 34 | "text": "cora()\nData(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])\n" 35 | } 36 | ], 37 | "source": [ 38 | "cora_path = \"pyg_datasets/cora\"\n", 39 | "cora = gdatasets.Planetoid(cora_path, \"cora\")\n", 40 | "print (cora)\n", 41 | "print (cora[0])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 5, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "output_type": "stream", 51 | "name": "stdout", 52 | "text": "ENZYMES(600)\nData(edge_index=[2, 168], x=[37, 3], y=[1])\nData(edge_index=[2, 102], x=[23, 3], y=[1])\n" 53 | } 54 | ], 55 | "source": [ 56 | "enzymes_path = \"pyg_datasets/enzymes\"\n", 57 | "enzymes = gdatasets.TUDataset(enzymes_path, \"ENZYMES\")\n", 58 | "print (enzymes)\n", 59 | "print (enzymes[0])\n", 60 | "print (enzymes[1])" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 7, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# enzymes_train = enzymes[:540]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 7, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "output_type": "stream", 79 | "name": "stdout", 80 | "text": "epoch: 0\t, loss: 1.9417, test_acc: 0.4460\nepoch: 1\t, loss: 1.8264, test_acc: 0.4910\nepoch: 2\t, loss: 1.6927, test_acc: 0.5060\nepoch: 3\t, loss: 1.5415, test_acc: 0.4960\nepoch: 4\t, loss: 1.4023, test_acc: 0.5200\nepoch: 5\t, loss: 1.2664, test_acc: 0.5720\nepoch: 6\t, loss: 1.1343, 
test_acc: 0.6290\nepoch: 7\t, loss: 1.0119, test_acc: 0.6630\nepoch: 8\t, loss: 0.8993, test_acc: 0.6850\nepoch: 9\t, loss: 0.7953, test_acc: 0.7020\nepoch: 10\t, loss: 0.6980, test_acc: 0.7170\nepoch: 11\t, loss: 0.6072, test_acc: 0.7320\nepoch: 12\t, loss: 0.5234, test_acc: 0.7560\nepoch: 13\t, loss: 0.4471, test_acc: 0.7670\nepoch: 14\t, loss: 0.3788, test_acc: 0.7720\nepoch: 15\t, loss: 0.3184, test_acc: 0.7860\nepoch: 16\t, loss: 0.2660, test_acc: 0.7910\nepoch: 17\t, loss: 0.2213, test_acc: 0.7920\nepoch: 18\t, loss: 0.1838, test_acc: 0.7960\nepoch: 19\t, loss: 0.1527, test_acc: 0.7980\nepoch: 20\t, loss: 0.1270, test_acc: 0.7970\nepoch: 21\t, loss: 0.1058, test_acc: 0.7950\nepoch: 22\t, loss: 0.0883, test_acc: 0.7950\nepoch: 23\t, loss: 0.0739, test_acc: 0.7940\nepoch: 24\t, loss: 0.0621, test_acc: 0.7940\nepoch: 25\t, loss: 0.0524, test_acc: 0.7950\nepoch: 26\t, loss: 0.0445, test_acc: 0.7940\nepoch: 27\t, loss: 0.0380, test_acc: 0.7950\nepoch: 28\t, loss: 0.0326, test_acc: 0.7960\nepoch: 29\t, loss: 0.0282, test_acc: 0.7930\nepoch: 30\t, loss: 0.0245, test_acc: 0.7900\nepoch: 31\t, loss: 0.0214, test_acc: 0.7900\nepoch: 32\t, loss: 0.0188, test_acc: 0.7880\nepoch: 33\t, loss: 0.0166, test_acc: 0.7860\nepoch: 34\t, loss: 0.0148, test_acc: 0.7870\nepoch: 35\t, loss: 0.0133, test_acc: 0.7910\nepoch: 36\t, loss: 0.0120, test_acc: 0.7910\nepoch: 37\t, loss: 0.0109, test_acc: 0.7900\nepoch: 38\t, loss: 0.0099, test_acc: 0.7880\nepoch: 39\t, loss: 0.0090, test_acc: 0.7870\nepoch: 40\t, loss: 0.0083, test_acc: 0.7860\nepoch: 41\t, loss: 0.0077, test_acc: 0.7820\nepoch: 42\t, loss: 0.0071, test_acc: 0.7830\nepoch: 43\t, loss: 0.0066, test_acc: 0.7840\nepoch: 44\t, loss: 0.0062, test_acc: 0.7850\nepoch: 45\t, loss: 0.0058, test_acc: 0.7850\nepoch: 46\t, loss: 0.0055, test_acc: 0.7840\nepoch: 47\t, loss: 0.0052, test_acc: 0.7840\nepoch: 48\t, loss: 0.0049, test_acc: 0.7840\nepoch: 49\t, loss: 0.0046, test_acc: 0.7840\nepoch: 50\t, loss: 0.0044, test_acc: 
0.7840\nepoch: 51\t, loss: 0.0042, test_acc: 0.7830\nepoch: 52\t, loss: 0.0040, test_acc: 0.7830\nepoch: 53\t, loss: 0.0039, test_acc: 0.7830\nepoch: 54\t, loss: 0.0037, test_acc: 0.7830\nepoch: 55\t, loss: 0.0036, test_acc: 0.7830\nepoch: 56\t, loss: 0.0035, test_acc: 0.7820\nepoch: 57\t, loss: 0.0034, test_acc: 0.7820\nepoch: 58\t, loss: 0.0032, test_acc: 0.7820\nepoch: 59\t, loss: 0.0032, test_acc: 0.7810\nepoch: 60\t, loss: 0.0031, test_acc: 0.7820\nepoch: 61\t, loss: 0.0030, test_acc: 0.7820\nepoch: 62\t, loss: 0.0029, test_acc: 0.7820\nepoch: 63\t, loss: 0.0028, test_acc: 0.7820\nepoch: 64\t, loss: 0.0028, test_acc: 0.7820\nepoch: 65\t, loss: 0.0027, test_acc: 0.7820\nepoch: 66\t, loss: 0.0026, test_acc: 0.7820\nepoch: 67\t, loss: 0.0026, test_acc: 0.7820\nepoch: 68\t, loss: 0.0025, test_acc: 0.7840\nepoch: 69\t, loss: 0.0025, test_acc: 0.7850\nepoch: 70\t, loss: 0.0024, test_acc: 0.7850\nepoch: 71\t, loss: 0.0024, test_acc: 0.7850\nepoch: 72\t, loss: 0.0024, test_acc: 0.7850\nepoch: 73\t, loss: 0.0023, test_acc: 0.7850\nepoch: 74\t, loss: 0.0023, test_acc: 0.7850\nepoch: 75\t, loss: 0.0022, test_acc: 0.7860\nepoch: 76\t, loss: 0.0022, test_acc: 0.7850\nepoch: 77\t, loss: 0.0022, test_acc: 0.7850\nepoch: 78\t, loss: 0.0021, test_acc: 0.7850\nepoch: 79\t, loss: 0.0021, test_acc: 0.7850\nepoch: 80\t, loss: 0.0021, test_acc: 0.7860\nepoch: 81\t, loss: 0.0020, test_acc: 0.7860\nepoch: 82\t, loss: 0.0020, test_acc: 0.7860\nepoch: 83\t, loss: 0.0020, test_acc: 0.7860\nepoch: 84\t, loss: 0.0020, test_acc: 0.7860\nepoch: 85\t, loss: 0.0019, test_acc: 0.7860\nepoch: 86\t, loss: 0.0019, test_acc: 0.7860\nepoch: 87\t, loss: 0.0019, test_acc: 0.7860\nepoch: 88\t, loss: 0.0019, test_acc: 0.7870\nepoch: 89\t, loss: 0.0018, test_acc: 0.7870\nepoch: 90\t, loss: 0.0018, test_acc: 0.7880\nepoch: 91\t, loss: 0.0018, test_acc: 0.7880\nepoch: 92\t, loss: 0.0018, test_acc: 0.7860\nepoch: 93\t, loss: 0.0017, test_acc: 0.7860\nepoch: 94\t, loss: 0.0017, test_acc: 0.7860\nepoch: 
95\t, loss: 0.0017, test_acc: 0.7850\nepoch: 96\t, loss: 0.0017, test_acc: 0.7850\nepoch: 97\t, loss: 0.0017, test_acc: 0.7850\nepoch: 98\t, loss: 0.0017, test_acc: 0.7850\nepoch: 99\t, loss: 0.0016, test_acc: 0.7850\n" 81 | } 82 | ], 83 | "source": [ 84 | "# this is an example of using gnn.GCNConv (gnn = torch_geometric.nn, imported in an earlier cell)\n", 85 | "import torch\n", 86 | "import torch.nn as nn\n", 87 | "import torch.nn.functional as F\n", 88 | "import torch.optim as optim\n", 89 | "import torch_geometric\n", 90 | "import torch_geometric.data as gdata\n", 91 | "import torch_geometric.datasets as gdatasets\n", 92 | "\n", 93 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 94 | "cora_path = \"pyg_datasets/cora\"\n", 95 | "cora = gdatasets.Planetoid(cora_path, \"cora\")\n", 96 | "cora = cora[0]\n", 97 | "\n", 98 | "class Net(nn.Module):\n", 99 | " def __init__(self, in_feature, hid_feature, out_feature):\n", 100 | " nn.Module.__init__(self)\n", 101 | " self.gcn_1 = gnn.GCNConv(in_feature, hid_feature)\n", 102 | " self.gcn_2 = gnn.GCNConv(hid_feature, out_feature)\n", 103 | " \n", 104 | " def forward(self, x, edge_index):\n", 105 | " x = self.gcn_1(x, edge_index)\n", 106 | " x = F.relu(x)\n", 107 | " x = self.gcn_2(x, edge_index)\n", 108 | " return F.log_softmax(x, dim=1)\n", 109 | "\n", 110 | "net, cora = Net(1433, 16, 7).to(device), cora.to(device)\n", 111 | "optimizer = optim.Adam(net.parameters(), lr=1e-2)\n", 112 | "\n", 113 | "for i in range(100):\n", 114 | " optimizer.zero_grad()\n", 115 | " logit = net(cora.x, cora.edge_index)[cora.train_mask]\n", 116 | " loss = F.nll_loss(logit, cora.y[cora.train_mask])\n", 117 | " loss.backward()\n", 118 | " optimizer.step()\n", 119 | " pred = net(cora.x, cora.edge_index)[cora.test_mask]\n", 120 | " pred = pred.max(1)[1]\n", 121 | " acc = pred.eq(cora.y[cora.test_mask]).sum().item() / cora.test_mask.sum().item()\n", 122 | " print (\"epoch: {}\\t, loss: {:.4f}, test_acc: {:.4f}\".format(i, loss, acc))\n" 123 | ] 124 | }, 125 | {
| { 126 | "cell_type": "code", 127 | "execution_count": 6, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "output_type": "stream", 132 | "name": "stdout", 133 | "text": "epoch: 0\t, loss: 1.9335, test_acc: 0.1000\nepoch: 1\t, loss: 1.8521, test_acc: 0.1667\nepoch: 2\t, loss: 1.8398, test_acc: 0.0833\nepoch: 3\t, loss: 1.8276, test_acc: 0.1500\nepoch: 4\t, loss: 1.8382, test_acc: 0.1833\nepoch: 5\t, loss: 1.8182, test_acc: 0.2000\nepoch: 6\t, loss: 1.7966, test_acc: 0.1667\nepoch: 7\t, loss: 1.8075, test_acc: 0.2167\nepoch: 8\t, loss: 1.7327, test_acc: 0.2000\nepoch: 9\t, loss: 1.7697, test_acc: 0.2333\nepoch: 10\t, loss: 1.7340, test_acc: 0.3667\nepoch: 11\t, loss: 1.7433, test_acc: 0.2833\nepoch: 12\t, loss: 1.7725, test_acc: 0.2333\nepoch: 13\t, loss: 1.6981, test_acc: 0.2333\nepoch: 14\t, loss: 1.7049, test_acc: 0.3667\nepoch: 15\t, loss: 1.7954, test_acc: 0.1833\nepoch: 16\t, loss: 1.5727, test_acc: 0.2333\nepoch: 17\t, loss: 1.6031, test_acc: 0.3333\nepoch: 18\t, loss: 1.7465, test_acc: 0.3500\nepoch: 19\t, loss: 1.5707, test_acc: 0.3000\nepoch: 20\t, loss: 1.7060, test_acc: 0.3167\nepoch: 21\t, loss: 1.7750, test_acc: 0.2500\nepoch: 22\t, loss: 1.5671, test_acc: 0.3000\nepoch: 23\t, loss: 1.5639, test_acc: 0.3333\nepoch: 24\t, loss: 1.6557, test_acc: 0.3667\nepoch: 25\t, loss: 1.6018, test_acc: 0.2500\nepoch: 26\t, loss: 1.6749, test_acc: 0.2500\nepoch: 27\t, loss: 1.6249, test_acc: 0.2667\nepoch: 28\t, loss: 1.6330, test_acc: 0.3167\nepoch: 29\t, loss: 1.6527, test_acc: 0.3167\nepoch: 30\t, loss: 1.5577, test_acc: 0.3500\nepoch: 31\t, loss: 1.6763, test_acc: 0.3667\nepoch: 32\t, loss: 1.6061, test_acc: 0.3667\nepoch: 33\t, loss: 1.5117, test_acc: 0.3833\nepoch: 34\t, loss: 1.7553, test_acc: 0.3833\nepoch: 35\t, loss: 1.5468, test_acc: 0.3167\nepoch: 36\t, loss: 1.6121, test_acc: 0.3500\nepoch: 37\t, loss: 1.5593, test_acc: 0.3667\nepoch: 38\t, loss: 1.5578, test_acc: 0.3167\nepoch: 39\t, loss: 1.4911, test_acc: 0.3333\nepoch: 40\t, loss: 1.6122, 
test_acc: 0.3833\nepoch: 41\t, loss: 1.5523, test_acc: 0.3500\nepoch: 42\t, loss: 1.4960, test_acc: 0.2833\nepoch: 43\t, loss: 1.5295, test_acc: 0.3667\nepoch: 44\t, loss: 1.6140, test_acc: 0.3667\nepoch: 45\t, loss: 1.5666, test_acc: 0.3833\nepoch: 46\t, loss: 1.5833, test_acc: 0.4333\nepoch: 47\t, loss: 1.4843, test_acc: 0.3500\nepoch: 48\t, loss: 1.5178, test_acc: 0.4833\nepoch: 49\t, loss: 1.4627, test_acc: 0.4000\nepoch: 50\t, loss: 1.7181, test_acc: 0.3833\nepoch: 51\t, loss: 1.5530, test_acc: 0.3833\nepoch: 52\t, loss: 1.5469, test_acc: 0.4167\nepoch: 53\t, loss: 1.4480, test_acc: 0.3833\nepoch: 54\t, loss: 1.4740, test_acc: 0.4000\nepoch: 55\t, loss: 1.5525, test_acc: 0.3833\nepoch: 56\t, loss: 1.6229, test_acc: 0.4000\nepoch: 57\t, loss: 1.5572, test_acc: 0.4000\nepoch: 58\t, loss: 1.4583, test_acc: 0.4333\nepoch: 59\t, loss: 1.4862, test_acc: 0.4333\nepoch: 60\t, loss: 1.5586, test_acc: 0.4833\nepoch: 61\t, loss: 1.4724, test_acc: 0.3833\nepoch: 62\t, loss: 1.4478, test_acc: 0.4333\nepoch: 63\t, loss: 1.3095, test_acc: 0.4000\nepoch: 64\t, loss: 1.4990, test_acc: 0.4000\nepoch: 65\t, loss: 1.3114, test_acc: 0.4000\nepoch: 66\t, loss: 1.3036, test_acc: 0.4333\nepoch: 67\t, loss: 1.4126, test_acc: 0.4833\nepoch: 68\t, loss: 1.3047, test_acc: 0.3667\nepoch: 69\t, loss: 1.3002, test_acc: 0.4500\nepoch: 70\t, loss: 1.4454, test_acc: 0.3833\nepoch: 71\t, loss: 1.3319, test_acc: 0.3667\nepoch: 72\t, loss: 1.2904, test_acc: 0.4167\nepoch: 73\t, loss: 1.1735, test_acc: 0.3667\nepoch: 74\t, loss: 1.4170, test_acc: 0.3833\nepoch: 75\t, loss: 1.1892, test_acc: 0.3500\nepoch: 76\t, loss: 1.4599, test_acc: 0.3833\nepoch: 77\t, loss: 1.2196, test_acc: 0.3833\nepoch: 78\t, loss: 1.3578, test_acc: 0.4167\nepoch: 79\t, loss: 1.3342, test_acc: 0.3833\nepoch: 80\t, loss: 1.3283, test_acc: 0.4167\nepoch: 81\t, loss: 1.3372, test_acc: 0.4333\nepoch: 82\t, loss: 1.1629, test_acc: 0.3333\nepoch: 83\t, loss: 1.2680, test_acc: 0.4000\nepoch: 84\t, loss: 1.2172, test_acc: 
0.3833\nepoch: 85\t, loss: 1.2192, test_acc: 0.3500\nepoch: 86\t, loss: 1.3005, test_acc: 0.4000\nepoch: 87\t, loss: 1.3010, test_acc: 0.4500\nepoch: 88\t, loss: 1.2899, test_acc: 0.3833\nepoch: 89\t, loss: 1.4239, test_acc: 0.4333\nepoch: 90\t, loss: 1.2186, test_acc: 0.3667\nepoch: 91\t, loss: 1.3733, test_acc: 0.3833\nepoch: 92\t, loss: 1.2635, test_acc: 0.3333\nepoch: 93\t, loss: 1.3786, test_acc: 0.3833\nepoch: 94\t, loss: 1.3367, test_acc: 0.4167\nepoch: 95\t, loss: 1.0872, test_acc: 0.4500\nepoch: 96\t, loss: 1.1411, test_acc: 0.4000\nepoch: 97\t, loss: 1.1328, test_acc: 0.4500\nepoch: 98\t, loss: 1.2630, test_acc: 0.4333\nepoch: 99\t, loss: 1.0744, test_acc: 0.3833\n" 134 | } 135 | ], 136 | "source": [ 137 | "import torch\n", 138 | "import torch.nn as nn\n", 139 | "import torch.nn.functional as F\n", 140 | "import torch.optim as optim\n", 141 | "import torch_geometric\n", 142 | "import torch_geometric.nn as gnn\n", 143 | "import torch_geometric.data as gdata\n", 144 | "import torch_geometric.datasets as gdatasets\n", 145 | "\n", 146 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 147 | "enzymes_path= \"pyg_datasets/enzymes\"\n", 148 | "enzymes = gdatasets.TUDataset(enzymes_path, \"ENZYMES\")\n", 149 | "enzymes = enzymes.shuffle() # shuffle the dataset before the fixed train/test split below\n", 150 | "enzymes_train = enzymes[:540]\n", 151 | "enzymes_test = enzymes[540:]\n", 152 | "enzymes_trainloader = gdata.DataLoader(enzymes_train, batch_size=60, shuffle=True)\n", 153 | "enzymes_testloader = gdata.DataLoader(enzymes_test, batch_size=60, shuffle=False)\n", 154 | "\n", 155 | "class Net(nn.Module):\n", 156 | " def __init__(self, num_features=3, num_classes=7):\n", 157 | " super().__init__()\n", 158 | " self.conv1 = gnn.GraphConv(num_features, 128)\n", 159 | " self.pool1 = gnn.TopKPooling(128, ratio=0.8)\n", 160 | " self.conv2 = gnn.GraphConv(128, 128)\n", 161 | " self.pool2 = gnn.TopKPooling(128, ratio=0.8)\n", 162 | " self.conv3 = gnn.GraphConv(128, 128)\n", 163 | " self.pool3 = 
gnn.TopKPooling(128, ratio=0.8)\n", 164 | "\n", 165 | " self.lin1 = nn.Linear(256, 128) # 256 = global max pool (128) concatenated with global mean pool (128)\n", 166 | " self.lin2 = nn.Linear(128, 64)\n", 167 | " self.lin3 = nn.Linear(64, num_classes)\n", 168 | "\n", 169 | " def forward(self, x, edge_index, batch):\n", 170 | " x = F.relu(self.conv1(x, edge_index))\n", 171 | " x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)\n", 172 | " x1 = torch.cat([gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)\n", 173 | "\n", 174 | " x = F.relu(self.conv2(x, edge_index))\n", 175 | " x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)\n", 176 | " x2 = torch.cat([gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)\n", 177 | " \n", 178 | " x = F.relu(self.conv3(x, edge_index))\n", 179 | " x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)\n", 180 | " x3 = torch.cat([gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)\n", 181 | " \n", 182 | " x = x1 + x2 + x3 # sum the readouts of the three conv/pool stages\n", 183 | "\n", 184 | " x = F.relu(self.lin1(x))\n", 185 | " x = F.dropout(x, p=0.5, training=self.training)\n", 186 | " x = F.relu(self.lin2(x))\n", 187 | " x = F.log_softmax(self.lin3(x), dim=-1)\n", 188 | "\n", 189 | " return x\n", 190 | "\n", 191 | "net = Net(enzymes.num_node_features, enzymes.num_classes).to(device) # size input/output layers from the dataset\n", 192 | "optimizer = optim.Adam(net.parameters(), lr=5e-4)\n", 193 | "\n", 194 | "for i in range(100): # 100 epochs; the loop body runs one epoch\n", 195 | " net.train()\n", 196 | " for one_train_batch in enzymes_trainloader:\n", 197 | " optimizer.zero_grad()\n", 198 | " one_train_batch = one_train_batch.to(device)\n", 199 | " logit = net(one_train_batch.x, one_train_batch.edge_index, one_train_batch.batch)\n", 200 | " loss = F.nll_loss(logit, one_train_batch.y)\n", 201 | " loss.backward()\n", 202 | " optimizer.step()\n", 203 | " net.eval()\n", 204 | " correct = 0\n", 205 | " with 
torch.no_grad(): # evaluation needs no autograd graph\n", 206 | " for one_test_batch in enzymes_testloader:\n", 207 | " one_test_batch = one_test_batch.to(device)\n", 208 | " pred = net(one_test_batch.x, one_test_batch.edge_index, one_test_batch.batch).max(1)[1]\n", 
209 | " correct += pred.eq(one_test_batch.y).sum().item()\n", 210 | " acc = correct / len(enzymes_test) # accuracy over the whole test split, not only the last batch\n", 211 | " print (\"epoch: {}\\t, loss: {:.4f}, test_acc: {:.4f}\".format(i, loss.item(), acc))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.7.3-final" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } --------------------------------------------------------------------------------