├── .gitignore ├── README.md ├── data_processing ├── __init__.py ├── data_analysis.ipynb ├── data_extraction.py ├── data_generation.py ├── data_util.py └── text_util.py ├── graph_pb2.py ├── models ├── __init__.py ├── full_model.py ├── gat_encoder.py ├── gcn_encoder.py ├── graph_attention_layer.py ├── graph_convolutional_layer.py ├── lstm_decoder.py └── lstm_encoder.py └── training ├── __init__.py ├── evaluation_util.py ├── train.py └── train_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | ### JetBrains template 108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 110 | 111 | # User-specific stuff 112 | .idea/**/workspace.xml 113 | .idea/**/tasks.xml 114 | .idea/**/dictionaries 115 | .idea/**/shelf 116 | 117 | # Sensitive or high-churn files 118 | .idea/**/dataSources/ 119 | .idea/**/dataSources.ids 120 | .idea/**/dataSources.local.xml 121 | .idea/**/sqlDataSources.xml 122 | .idea/**/dynamic.xml 123 | .idea/**/uiDesigner.xml 124 | .idea/**/dbnavigator.xml 125 | 126 | # Gradle 127 | .idea/**/gradle.xml 128 | .idea/**/libraries 129 | 130 | # CMake 131 | cmake-build-debug/ 132 | cmake-build-release/ 133 | 134 | # Mongo Explorer plugin 135 | .idea/**/mongoSettings.xml 136 | 137 | # File-based project format 138 | *.iws 139 | 140 | # IntelliJ 141 | out/ 142 | 143 | # mpeltonen/sbt-idea plugin 144 | .idea_modules/ 145 | 146 | # JIRA plugin 147 | atlassian-ide-plugin.xml 148 | 149 | # Cursive Clojure plugin 150 | .idea/replstate.xml 151 | 152 | # Crashlytics plugin (for Android Studio and IntelliJ) 153 | com_crashlytics_export_strings.xml 154 | 
crashlytics.properties
155 | crashlytics-build.properties
156 | fabric.properties
157 | 
158 | # Editor-based Rest Client
159 | .idea/httpRequests
160 | 
161 | data/
162 | 
163 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structured Neural Summarization
2 | 
3 | ## Extracting the Dataset
4 | To extract the features from the corpus proto files, run:
5 | 
6 | `python data_processing/data_generation.py`
7 | 
8 | For the command to succeed, a directory
9 | _corpus/r252-corpus-features_ containing the protos of the corpus must exist. Alternatively, the
10 | already-extracted dataset can be downloaded at https://drive.google.com/file/d/14k4AgOVws4_TfPtDGefXzPn3x2Ph083h/view?usp=sharing. After placing the downloaded file
11 | under the _data/_ directory (which needs to be created), the model can be trained and
12 | evaluated.
13 | 
14 | ## Running the Models
15 | To train and evaluate a model, run:
16 | 
17 | `python training/train.py --model_name="lstm_gcn_to_lstm_attention"
18 | --print_every=10000 --attention=True --graph=True --iterations=500000`
19 | 
20 | All available options can be listed by running:
21 | 
22 | `python training/train.py --help`
23 | 
24 | ## Pretrained Models
25 | A pretrained version of the best-performing model (as a state dictionary) can be downloaded at
26 | https://drive.google.com/file/d/1fm7hGzr-tziNhUMh8duc8s4j5gWW3uKm/view?usp=sharing
27 | 
28 | ## High-Level Code Structure
29 | - data_processing/: contains the code for extracting, storing, analysing and processing data
30 |     - data_analysis.ipynb: notebook containing analysis of the extracted data
31 |     - data_extraction.py: contains the logic to extract the feature data from the proto files of
32 |       the corpus
33 |     - data_generation.py: script to be run to generate the feature data
34 |     - data_util.py: contains utilities to work with data
35 |     - text_util.py: contains utilities to work with text
36 | - models/: contains all the code for the different models
37 |     - full_model.py: class for the complete methodNaming model
38 |     - gat_encoder.py: class for the Graph Attention Network encoder
39 |     - gcn_encoder.py: class for the Graph Convolutional Network encoder
40 |     - graph_attention_layer.py: class for the Graph Attention Layer used by the Graph Attention
41 |       Network
42 |     - graph_convolutional_layer.py: class for the Graph Convolutional Layer used by the Graph
43 |       Convolutional Network
44 |     - lstm_decoder.py: class for the LSTM sequence decoder
45 |     - lstm_encoder.py: class for the LSTM sequence encoder
46 | - training/: contains code to train and evaluate the models
47 |     - evaluation_util.py: contains utilities to compute evaluation metrics
48 |     - train.py: entry-point for training the models
49 |     - train_model.py: contains logic to train the models
50 | 
51 | 
--------------------------------------------------------------------------------
/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emalgorithm/structured-neural-summarization-replication/e9b672c0b76ea075ff1a3ec8a5c3fc88afc521b9/data_processing/__init__.py
--------------------------------------------------------------------------------
/data_processing/data_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": 
[ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "import pickle\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Methods Analysis" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 8, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "data = pickle.load(open('../data/methods_tokens_data.pkl', 'rb'))\n", 40 | "methods_source = data['methods_source']\n", 41 | "methods_names = data['methods_names']\n", 42 | "# method_graphs = data['method_graphs']" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 9, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "The repo contains 88479 methods\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "print(\"The repo contains {} methods\".format(len(methods_names)))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 12, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEKCAYAAADaa8itAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHkBJREFUeJzt3XuUFeWZ7/HvT/CeKKiEQS4BTccc4hlRO0riZcgYETUjmrgUViYSxyPxqBNzmZVBJys6yXiGTDRGE4eIyhEy8Rq8MIohQLyMc6JclOGmhlYxNiIwohIvwaDP+aPejUWzd/duqN2b3fw+a9XaVU+9VfW+u7Qf3rdqVykiMDMzK8Iu9a6AmZl1H04qZmZWGCcVMzMrjJOKmZkVxknFzMwK46RiZmaFcVIxM7PCOKmYmVlhnFTMzKwwPetdga52wAEHxODBg+tdDTOzhrJw4cL/jog+HZXb6ZLK4MGDWbBgQb2rYWbWUCS9WE05D3+ZmVlhnFTMzKwwTipmZlYYJxUzMyuMk4qZmRXGScXMzApTs6QiaaCkhyQtl7RM0iUpvp+k2ZJWpM/eKS5J10lqkbRY0hG5fY1L5VdIGpeLHylpSdrmOkmqVXvMzKxjteypbAK+FRFDgeHARZKGAhOAuRHRBMxNywAnA01pGg9MgiwJAZcDRwNHAZeXElEqc35uu1E1bI+ZmXWgZkklIlZHxJNp/g/A00B/YDQwNRWbCpye5kcD0yLzONBLUj/gJGB2RKyPiNeA2cCotG6fiHg8IgKYltuXmZnVQZf8ol7SYOBw4Amgb0SsTqteAfqm+f7AS7nNWlOsvXhrmXi5448n6/0waNCgbW9IzuAJD2yeXznx1EL2aWbW6Gp+oV7Sh4DpwNcjYkN+XephRK3rEBGTI6I5Ipr79Onw0TVmZraNappUJO1KllB+ERF3p/CaNHRF+lyb4quAgbnNB6RYe/EBZeJmZlYntbz7S8DNwNMR8aPcqhlA6Q6uccB9ufg56S6w4cAbaZhsFjBSUu90gX4kMCut2yBpeDrWObl9mZlZHdTymsoxwJeBJZIWpdhlwETgTknnAS8CZ6V1M4FTgBbgbeBcgIhYL+n7wPxU7nsRsT7NXwjcAuwJPJgmMzOrk5ollYh4DKj0u5ETypQP4KIK+5oCTCkTXwAcuh3VNDOzAvkX9WZmVhgnFTMzK4yTipmZFcZJxczMCuOkYmZmhXFSMTOzwjipmJlZYZxUzMysME4qZmZWGCcVMzMrjJOKmZkVxknFzMwK46RiZmaFcVIxM7PCOKmYmVlhnFTMzKwwTipmZlaYWr6jfoqktZKW5mJ3SFqUppWl1wxLGizpndy6n+W2OVLSEkktkq5L76NH0n6SZktakT5716otZmZWnVr2VG4BRuUDEXF2RAyLiGHAdODu3OrnSusi4oJcfBJwPtCUptI+JwBzI6IJmJuWzcysjmqWVCLiUWB9uXWpt3EWcFt7+5DUD9gnIh5P77CfBpyeVo8Gpqb5qbm4mZnVSb2uqRwHrImIFbnYEElPSXpE0nEp1h9ozZVpTTGAvhGxOs2/AvStdDBJ4yUtkLRg3bp1BTXBzMzaqldSGcuWvZTVwKCIOBz4JnCrpH2q3VnqxUQ76ydHRHNENPfp02db62xmZh3o2dUHlNQT+AJwZCkWERuBjWl+oaTngI8Dq4ABuc0HpBjAGkn9ImJ1GiZb2xX1NzOzyurRU/kc8ExEbB7WktRHUo80fxDZBfnn0/DWBknD03WYc4D70mYzgHFpflwubmZmdVLLW4pvA34LHCKpVdJ5adUYtr5AfzywON1i/EvggogoXeS/ELgJaAGeAx5M8YnAiZJWkCWqibVqi5mZVadmw18RMbZC/CtlYtPJbjEuV34BcGiZ+KvACdtXSzMzK5J/UW9mZoVxUjEzs8I4qZiZWWGcVMzMrDBOKmZmVhgnFTMzK4yTipmZFcZJxczMCuOkYmZmhXFSMTOzwjipmJlZYZxUzMysME4qZmZWGCcVMzMrjJOKmZkVxknFzMwKU8s3P06RtFbS0lzsCkmrJC1K0ym5dZdKapH0rKSTcvFRKdYiaUIuPkTSEyl+h
6TdatUWMzOrTi17KrcAo8rEr4mIYWmaCSBpKNlrhj+ZtvlXST3Se+uvB04GhgJjU1mAH6R9fQx4DTiv7YHMzKxr1SypRMSjwPoOC2ZGA7dHxMaIeIHsffRHpaklIp6PiHeB24HRkgT8Jdn77AGmAqcX2gAzM+u0elxTuVjS4jQ81jvF+gMv5cq0plil+P7A6xGxqU3czMzqqKuTyiTgYGAYsBq4uisOKmm8pAWSFqxbt64rDmlmtlPq0qQSEWsi4r2IeB+4kWx4C2AVMDBXdECKVYq/CvSS1LNNvNJxJ0dEc0Q09+nTp5jGmJnZVro0qUjql1s8AyjdGTYDGCNpd0lDgCZgHjAfaEp3eu1GdjF/RkQE8BBwZtp+HHBfV7TBzMwq69lxkW0j6TZgBHCApFbgcmCEpGFAACuBrwJExDJJdwLLgU3ARRHxXtrPxcAsoAcwJSKWpUP8PXC7pH8CngJurlVbzMysOjVLKhExtky44h/+iLgSuLJMfCYws0z8eT4YPjMzsx2Af1FvZmaFcVIxM7PCOKmYmVlhnFTMzKwwTipmZlYYJxUzMyuMk4qZmRXGScXMzArjpGJmZoVxUjEzs8I4qZiZWWGcVMzMrDBOKmZmVhgnFTMzK4yTipmZFcZJxczMCuOkYmZmhakqqUj6n53dsaQpktZKWpqL/VDSM5IWS7pHUq8UHyzpHUmL0vSz3DZHSloiqUXSdZKU4vtJmi1pRfrs3dk6mplZsartqfyrpHmSLpS0b5Xb3AKMahObDRwaEX8O/A64NLfuuYgYlqYLcvFJwPlAU5pK+5wAzI2IJmBuWjYzszqqKqlExHHAl4CBwEJJt0o6sYNtHgXWt4n9OiI2pcXHgQHt7UNSP2CfiHg8IgKYBpyeVo8Gpqb5qbm4mZnVSc9qC0bECknfARYA1wGHp6GoyyLi7m049t8Ad+SWh0h6CtgAfCci/gPoD7TmyrSmGEDfiFid5l8B+m5DHQoxeMIDm+dXTjy1XtUwM6u7qpKKpD8HzgVOJRvC+quIeFLSgcBvgU4lFUn/AGwCfpFCq4FBEfGqpCOBeyV9str9RURIinaONx4YDzBo0KDOVNXMzDqh2msqPwGeBA6LiIsi4kmAiHgZ+E5nDijpK8DngS+lIS0iYmNEvJrmFwLPAR8HVrHlENmAFANYk4bHSsNkaysdMyImR0RzRDT36dOnM9U1M7NOqDapnArcGhHvAEjaRdJeABHx82oPJmkU8G3gtIh4OxfvI6lHmj+I7IL882l4a4Ok4Wmo7RzgvrTZDGBcmh+Xi5uZWZ1Um1TmAHvmlvdKsYok3UY2NHaIpFZJ5wE/BT4MzG5z6/DxwGJJi4BfAhdEROki/4XATUALWQ/mwRSfCJwoaQXwubRsZmZ1VO2F+j0i4s3SQkS8WeqpVBIRY8uEb65QdjowvcK6BcChZeKvAie0VwczM+ta1fZU3pJ0RGkhXUx/pzZVMjOzRlVtT+XrwF2SXgYE/Blwds1qZWZmDamqpBIR8yV9AjgkhZ6NiD/VrlpmZtaIqv7xI/ApYHDa5ghJRMS0mtTKzMwaUrU/fvw5cDCwCHgvhUuPTTEzMwOq76k0A0NLP1Y0MzMrp9q7v5aSXZw3MzOrqNqeygHAcknzgI2lYEScVpNamZlZQ6o2qVxRy0qYmVn3UO0txY9I+ijQFBFz0q/pe9S2amZm1miqfZ3w+WTP5LohhfoD99aqUmZm1piqvVB/EXAM2Qu0iIgVwEdqVSkzM2tM1SaVjRHxbmlBUk+y36mYmZltVm1SeUTSZcCe6d30dwH/XrtqmZlZI6o2qUwA1gFLgK8CM+nkGx/NzKz7q/bur/eBG9NkZmZWVrXP/nqBMtdQIuKgwmtkZmYNq9rhr2aypxR/CjgOuA74t442kjRF0lpJS3Ox/STNlrQiffZOcUm6TlKLpMVtXgo2LpVfIWlcLn6kpCVpm+vSe+zNzKxOqkoqEfFqbloVET8GTq1i01uAUW1iE4C5EdEEzE3LACcDTWkaD0yCLAkBlwNHA0cBl5cSUSpzfm67tscyM7MuVO3w1xG5xV3Iei4dbhsRj0oa3CY8GhiR5qcCDwN/n+LT0pOQH5fUS1K/VHZ2RKxPdZkNjJL0MLBPRDye4tOA04EHq2mTmZkVr9pnf12dm98ErATO2sZj9o2I1Wn+FaBvmu8PvJQr15pi7cVby8TNzKxOqr3767O1OHhEhKSa/4hS0niyITUGDRpU68OZme20qh3++mZ76yPiR5045hpJ/SJidRreWpviq4CBuXIDUmwVHwyXleIPp/iAMuXL1W8yMBmgubnZTwIwM6uRztz99b/5YNjpAuAI4MNp6owZQOkOrnHAfbn4OekusOHAG2mYbBYwUlLvdIF+JDArrdsgaXi66+uc3L7MzKwOqr2mMgA4IiL+ACDpCuCBiPjr9jaSdBtZL+MASa1kd3FNBO6UdB7wIh9cm5kJnAK0AG8D5wJExHpJ3wfmp3LfK120By4ku8NsT7IL9L5Ib2ZWR9Umlb7Au7nld/ngAntFETG2wqoTypQNsqchl9vPFGBKmfgC4NCO6mFmZl2j2qQyDZgn6Z60fDrZ7cBmZmabVXv315WSHiT7NT3AuRHxVO2qZWZmjajaC/UAewEbIuJaoFXSkBrVyczMGlS1rxO+nOxX75em0K5U8ewvMzPbuVTbUzkDOA14CyAiXqbztxKbmVk3V21SeTfdnRUAkvauXZXMzKxRVZtU7pR0A9BL0vnAHPzCLjMza6Pau7+uSu+m3wAcAnw3ImbXtGZmZtZwOkwqknoAc9JDJZ1IzMysog6HvyLiPeB9Sft2QX3MzKyBVfuL+jeBJekFWW+VghHxtZrUyszMGlK1SeXuNJmZmVXUblKRNCgifh8Rfs6XmZl1qKNrKveWZiRNr3FdzMyswXWUVJSbP6iWFTEzs8bXUVKJCvNmZmZb6ehC/WGSNpD1WPZM86TliIh9alo7MzNrKO0mlYjoUfQBJR0C3JELHQR8F+gFnA+sS/HLImJm2uZS4DzgPeBrETErxUcB1wI9gJsiYmLR9e2swRMe2Dy/cuKpdayJmVnXq/aW4sJExLPAMNj8a/1VwD1k76S/JiKuypeXNBQYA3wSOBCYI+njafX1wIlAKzBf0oyIWN4lDTEzs610eVJp4wTguYh4UVKlMqOB2yNiI/CCpBbgqLSuJSKeB5B0eyrrpGJmViedefNjLYwBbsstXyxpsaQpknqnWH/gpVyZ1hSrFN+KpPGSFkhasG7dunJFzMysAHVLKpJ2I3vx110pNAk4mGxobDVwdVHHiojJEdEcEc19+vQpardmZtZGPYe/TgaejIg1AKVPAEk3AvenxVXAwNx2A1KMduJmZlYH9Rz+Gktu6EtSv9y6M4ClaX4GMEbS7pKGAE3APGA+0CRpSOr1jEllzcysTurSU0mvIz4R+Gou/C+ShpH9yHJlaV1ELJN0J9kF+E3ARelx/Ei6GJhFdkvxlIhY1mWNMDOzrdQlqUTEW8D+bWJfbqf8lcCVZeIzgZmFV9DMzLZJve/+MjOz
bsRJxczMCuOkYmZmhXFSMTOzwjipmJlZYZxUzMysME4qZmZWGCcVMzMrjJOKmZkVxknFzMwK46RiZmaFcVIxM7PCOKmYmVlhnFTMzKwwTipmZlYYJxUzMyuMk4qZmRWmbklF0kpJSyQtkrQgxfaTNFvSivTZO8Ul6TpJLZIWSzoit59xqfwKSePq1R4zM6t/T+WzETEsIprT8gRgbkQ0AXPTMsDJQFOaxgOTIEtCwOXA0cBRwOWlRGRmZl2v3kmlrdHA1DQ/FTg9F58WmceBXpL6AScBsyNifUS8BswGRnV1pc3MLNOzjscO4NeSArghIiYDfSNidVr/CtA3zfcHXspt25pileJbkDSerIfDoEGDimxDuwZPeGDz/MqJp3bZcc3M6qWeSeXYiFgl6SPAbEnP5FdGRKSEs91SwpoM0NzcXMg+zcxsa3Ub/oqIVelzLXAP2TWRNWlYi/S5NhVfBQzMbT4gxSrFzcysDuqSVCTtLenDpXlgJLAUmAGU7uAaB9yX5mcA56S7wIYDb6RhslnASEm90wX6kSlmZmZ1UK/hr77APZJKdbg1In4laT5wp6TzgBeBs1L5mcApQAvwNnAuQESsl/R9YH4q972IWN91zTAzs7y6JJWIeB44rEz8VeCEMvEALqqwrynAlKLraGZmnbej3VJsZmYNzEnFzMwK46RiZmaFcVIxM7PC1PPHjzsV/7rezHYG7qmYmVlh3FPphHxvw8zMtuaeipmZFcZJxczMCuOkYmZmhXFSMTOzwvhCfR349mIz667cUzEzs8I4qZiZWWGcVMzMrDBOKmZmVpguTyqSBkp6SNJyScskXZLiV0haJWlRmk7JbXOppBZJz0o6KRcflWItkiZ0dVvMzGxL9bj7axPwrYh4Mr2nfqGk2WndNRFxVb6wpKHAGOCTwIHAHEkfT6uvB04EWoH5kmZExPIuaYWZmW2ly5NKRKwGVqf5P0h6GujfziajgdsjYiPwgqQW4Ki0riW9mhhJt6eyDZVUfHuxmXUndb2mImkwcDjwRApdLGmxpCmSeqdYf+Cl3GatKVYpbmZmdVK3pCLpQ8B04OsRsQGYBBwMDCPryVxd4LHGS1ogacG6deuK2q2ZmbVRl1/US9qVLKH8IiLuBoiINbn1NwL3p8VVwMDc5gNSjHbiW4iIycBkgObm5iigCTXhoTAza3T1uPtLwM3A0xHxo1y8X67YGcDSND8DGCNpd0lDgCZgHjAfaJI0RNJuZBfzZ3RFG8zMrLx69FSOAb4MLJG0KMUuA8ZKGgYEsBL4KkBELJN0J9kF+E3ARRHxHoCki4FZQA9gSkQs68qGmJnZlupx99djgMqsmtnONlcCV5aJz2xvOzMz61p+SvEOytdXzKwR+TEtZmZWGCcVMzMrjIe/GkB+KAw8HGZmOy4nlQbk6y1mtqPy8JeZmRXGPZUG516Lme1InFS6EScYM6s3D3+ZmVlh3FPpptreMVbiHoyZ1ZJ7KmZmVhj3VHZivgZjZkVzUtnJVBoWc4IxsyI4qdhWfD3GzLaVk4pVzb0ZM+uIk4ptk0q9mTwnHrOdj5OK1YwTj9nOp+GTiqRRwLVkrxS+KSIm1rlK1gnVJJ48JyGzHVtDJxVJPYDrgROBVmC+pBkRsby+NbNa6WwSqpaTlVkxGjqpAEcBLRHxPICk24HRgJOKdUqtklVJPmn57jrrzho9qfQHXsottwJH16kuZhVVk7Rqndh2Ru0l80rrqolbZY2eVKoiaTwwPi2+KenZbdzVAcB/F1OrhuE27xy6ZZv1g3bXlW1zpW3a21eD2N5z/NFqCjV6UlkFDMwtD0ixLUTEZGDy9h5M0oKIaN7e/TQSt3nn4DZ3f13V3kZ/oOR8oEnSEEm7AWOAGXWuk5nZTquheyoRsUnSxcAssluKp0TEsjpXy8xsp9XQSQUgImYCM7vocNs9hNaA3Oadg9vc/XVJexURXXEcMzPbCTT6NRUzM9uBOKlUSdIoSc9KapE0od71KYKkgZIekrRc0jJJl6T4fpJmS1qRPnunuCRdl76DxZKOqG8Ltp2kHpKeknR/Wh4i6YnUtjvSjR9I2j0tt6T1g+tZ720lqZekX0p6RtLTkj7d3c+zpG+k/66XSrpN0h7d7TxLmiJpraSluVinz6ukcan8CknjtqdOTipVyD0O5mRgKDBW0tD61qoQm4BvRcRQYDhwUWrXBGBuRDQBc9MyZO1vStN4YFLXV7kwlwBP55Z/AFwTER8DXgPOS/HzgNdS/JpUrhFdC/wqIj4BHEbW9m57niX1B74GNEfEoWQ38oyh+53nW4BRbWKdOq+S9gMuJ/vh+FHA5aVEtE0iwlMHE/BpYFZu+VLg0nrXqwbtvI/sOWrPAv1SrB/wbJq/ARibK7+5XCNNZL9nmgv8JXA/ILIfhfVse77J7iz8dJrvmcqp3m3oZHv3BV5oW+/ufJ754Gkb+6Xzdj9wUnc8z8BgYOm2nldgLHBDLr5Fuc5O7qlUp9zjYPrXqS41kbr7hwNPAH0jYnVa9QrQN813l+/hx8C3gffT8v7A6xGxKS3n27W5zWn9G6l8IxkCrAP+bxryu0nS3nTj8xwRq4CrgN8Dq8nO20K693ku6ex5LfR8O6kYkj4ETAe+HhEb8usi+6dLt7lFUNLngbURsbDedelCPYEjgEkRcTjwFh8MiQDd8jz3Jnu47BDgQGBvth4m6vbqcV6dVKpT1eNgGpGkXckSyi8i4u4UXiOpX1rfD1ib4t3hezgGOE3SSuB2siGwa4Fekkq/28q3a3Ob0/p9gVe7ssIFaAVaI+KJtPxLsiTTnc/z54AXImJdRPwJuJvs3Hfn81zS2fNa6Pl2UqlOt3wcjCQBNwNPR8SPcqtmAKU7QMaRXWspxc9Jd5EMB97IdbMbQkRcGhEDImIw2Xn8TUR8CXgIODMVa9vm0ndxZirfUP+ij4hXgJckHZJCJ5C9HqLbnmeyYa/hkvZK/52X2txtz3NOZ8/rLGCkpN6phzcyxbZNvS8yNcoEnAL8DngO+Id616egNh1L1jVeDCxK0ylkY8lzgRXAHGC/VF5kd8E9Bywhu7Om7u3YjvaPAO5P8wcB84AW4C5g9xTfIy23pPUH1bve29jWYcCCdK7vBXp39/MM/CPwDLAU+Dmwe3c7z8BtZNeM/kTWIz1vW84r8Dep7S3AudtTJ/+i3szMCuPhLzMzK4yTipmZFcZJxczMCuOkYmZmhXFSMTOzwjipmJlZYZxUrCFI+mdJn5V0uqRL26z7saTjt2GfI0qPvi+zbk7ukeG9JF24PfsrmqQLJJ3TFcdKxxsh6TO55VskndneNh3sb7u2tx2Xk4o1iqOBx4G/AB4tBSXtDwyPiEcrbbiNfg6UEkmv3PwOISJ+FhHTuvCQI4DPdFTIzEnFdmiSfihpMfAp4LfA/wImSfpuKvJF4Fe58hOVvXRssaSrUmyLfxVLejN3iH0kPaDsBWw/k1T6f2IG2SPBASYCB0talOqj9LlU0hJJZ5ep96fSE4EPlrR3epnSvBQbncp8RdLdkn6VXo70LyneI9W5tP9vlNn
/FZL+Ls0/LOkHaf+/k3RcmfIjJD0i6T5Jz6fv6UtpmyWSDk7l+kiaLml+mo5R9gTrC4BvpO+gtP/jJf2/tL8z0/Zlv5sU/2n6nucAHyl7wq3x1fsxA548dTSRJZSfALsC/9lm3VTgr9L8/mTviCg9KaJX+rwFODO3zZvpcwTwR7JHd/QAZrcptyLtczBbvq/ii6lsD7LHiv+e7L0UI8je2/EZssesD0rl/w/w16U6kT3uZ2/gK8DzZA8v3AN4kezBfkcCs3PH61XmO7kC+Ls0/zBwdZo/BZhTpvwI4PVUz93JHhj4j2ndJcCP0/ytwLFpfhDZc+G2OF7uO72L7B+mQ4GWDr6bL+TiB6a6nNm2np4af3JPxRrBEcB/AZ9gy7c1QvYHa12af4MsSdws6QvA21Xse15EPB8R75E9R+nY3Lq1ZH8A2zoWuC0i3ouINcAjZIkP4H8Ak8kS3e9TbCQwQdIisgSwB9kfbMje0PdGRPyR7IGHHyVLNAdJ+omkUcAWryOooPSE6YVkSbCc+RGxOiI2kj3/6dcpviS3zeeAn6a6ziDryX2owv7ujYj3I2I5H7yzo9J3c3wu/jLwmyraZA2oZ8dFzOpD0jCyfxEPIHsT315ZWIvI3tL3DvAO2R9pImKTpKPInkh7JnAx2aPtN5GGetPw1m65w7R9+F1+eY+0/85YnbY7HHi51BTgixHxbJv2HQ1szIXeI3sr4WuSDiN7U+EFwFlkD/xrT2k/71H5/+v8sd7PLb+f22YXsmtUf2xT1472V7aA7XzcU7EdVkQsiohhZMNFQ8n+dXtSRAxLCQWynsvHYPPLxvaNiJnAN8jexQ6wkmxICeA0smG0kqOUvdJgF+Bs4LG0LwF/lrb9A/Dh3Db/AZydrn30IftX+Ly07nXgVOCfJY1IsVnA36Z9Iunw9tot6QBgl4iYDnyHrKfWVX4N/G2uLsPSbNvvoJJK382juXg/4LPFVtt2FE4qtkNLf5hei4j3gU+koZa8B8iuF0D2R+/+dGH/MeCbKX4j8BeS/ovsveRv5bafD/yULDm9ANyT4kcCj0fEpoh4FfjPdPH5h6nMYrIhud8A347snSUApGGfzwPXp97I98kS2WJJy9Jye/oDD6ce2b8Bl3ZQvkhfA5rTjQ7LyXpKAP8OnNHmQn05lb6be8iuUS0HppHddGHdkB99bw1P0mPA5yPi9QL3eS0wIyLmFrVPs52BeyrWHXyLDy58F2WpE4pZ57mnYmZmhXFPxczMCuOkYmZmhXFSMTOzwjipmJlZYZxUzMysMP8f8CEAk4N+Kt4AAAAASUVORK5CYII=\n", 70 | "text/plain": [ 71 | "
" 72 | ] 73 | }, 74 | "metadata": { 75 | "needs_background": "light" 76 | }, 77 | "output_type": "display_data" 78 | } 79 | ], 80 | "source": [ 81 | "methods_length = [len(method) for method in methods_source]\n", 82 | "plt.xlabel('#(sub)tokens in method')\n", 83 | "plt.ylabel('Frequency')\n", 84 | "a = plt.hist(methods_length, bins=100, range=(0, 1000), )" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 13, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "The average number of sub-identifiers in a method name is: 2.807434532487935\n" 97 | ] 98 | }, 99 | { 100 | "data": { 101 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEKCAYAAADaa8itAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHGVJREFUeJzt3X+4VnWZ7/H3R9D8VYpBDAK2zZg81IxIO7VMx3QGEU+BHY9pnWTUE9MJp2yaM2F1pWnO6DmTztCkk41cQqlomsokHkSznBp/gIr8MmNLmCACCYqmaeB9/ljfjcvt8+z9wP6u/exHPq/rWtez1r2+a617PWz2vdev71JEYGZmlsMuzU7AzMzePFxUzMwsGxcVMzPLxkXFzMyycVExM7NsXFTMzCwbFxUzM8vGRcXMzLJxUTEzs2wGNjuBvjZ48OBoa2trdhpmZi3lwQcf/G1EDOmp3U5XVNra2li4cGGz0zAzaymSnmiknU9/mZlZNi4qZmaWjYuKmZllU1lRkbS7pAckPSJpmaRvpPiBku6X1CHpekm7pfhb0nRHmt9WWte5Kf6YpONL8fEp1iFpWlX7YmZmjanySOVl4NiIOAQYA4yXdARwCXBZRLwb2AScldqfBWxK8ctSOySNBk4F3guMBy6XNEDSAOA7wAnAaOC01NbMzJqksqIShRfS5K5pCOBY4MYUnwlMSuMT0zRp/nGSlOKzI+LliPg10AEcloaOiFgZEa8As1NbMzNrkkqvqaQjikXAemA+8DjwbERsSU1WA8PT+HDgSYA0/zng7eV4l2Xqxc3MrEkqLSoRsTUixgAjKI4sDq5ye/VImiJpoaSFGzZsaEYKZmY7hT65+ysingXuBj4I7Cup86HLEcCaNL4GGAmQ5u8DPFOOd1mmXrzW9q+MiPaIaB8ypMcHQs3MbAdV9kS9pCHAHyLiWUl7AH9BcfH9buBkimsgk4Fb0yJz0vS9af5PIiIkzQGulXQpsD8wCngAEDBK0oEUxeRU4JNV7U8ubdNu6/U6Vl18YoZMzMzyq7KblmHAzHSX1i7ADRHxY0nLgdmSvgk8DFyV2l8FfF9SB7CRokgQEcsk3QAsB7YAUyNiK4Cks4F5wABgRkQsq3B/zMysB5UVlYhYDBxaI76S4vpK1/jvgf9eZ10XARfViM8F5vY6WTMzy8JP1JuZWTYuKmZmlo2LipmZZeOiYmZm2biomJlZNi4qZmaWjYuKmZll46JiZmbZuKiYmVk2LipmZpaNi4qZmWXjomJmZtm4qJiZWTYuKmZmlo2LipmZZeOiYmZm2biomJlZNi4qZmaWjYuKmZll46JiZmbZuKiYmVk2LipmZpaNi4qZmWXjomJmZtm4qJiZWTYuKmZmlk1lRUXSSEl3S1ouaZmkL6T4+ZLWSFqUhgmlZc6V1CHpMUnHl+LjU6xD0rRS/EBJ96f49ZJ2q2p/zMysZ1UeqWwBvhQRo4EjgKmSRqd5l0XEmDTMBUjzTgXeC4wHLpc0QNIA4DvACcBo4LTSei5J63o3sAk4q8L9MTOzHlRWVCJibUQ8lMafBx4FhnezyERgdkS8HBG/BjqAw9LQERErI+IVYDYwUZKAY4Eb0/IzgUnV7I2ZmTWiT66pSGoDDgXuT6GzJS2WNEPSoBQbDjxZWmx1itWLvx14NiK2dImbmVmTVF5UJO0N3AScExGbgSuAg4AxwFrgW32QwxRJCyUt3LBhQ9WbMzPbaVVaVCTtSlFQromIHwFExLqI2BoRrwLfozi9BbAGGFlafESK1Ys/A+wraWCX+BtExJUR0R4R7UOGDMmzc2Zm9gZV3v0l4Crg0Yi4tBQfVmp2ErA0jc8BTpX0FkkHAqOAB4AFwKh0p9duFBfz50REAHcDJ6flJwO3VrU/ZmbWs4E9N9lhRwKfBpZIWpRiX6G4e2sMEMAq4K8AImKZpBuA5RR3jk2NiK0Aks4G5gEDgBkRsSyt78vAbEnfBB6mKGJmZtYklRWViPg5oBqz5nazzEXARTXic2stFxEree30mZmZNZmfqDczs2xcVMzMLBsXFTMzy8ZFxczMsnFRMTOzbFxUzMwsGxcVMzPLxkXFzMyycVExM7NsXFTMzCwbFxUzM8vGRcXMzLJxUTEzs2xcVMzMLBsXFTMzy8ZFxczMsnFRMTOzbFxUzMwsGxcVMzPLxkXFzMyycVExM7NsXFTMzCwbFxUzM8tmYLMTsO3XNu22Xq9j1cUnZsjEzOz1fKRiZmbZuKiYmVk2lRUVSSMl3S1puaRlkr6Q4vtJmi9pRfoclOKSNF1Sh6TFksaW1jU5tV8haXIp/n5JS9Iy0yWpqv0xM7OeVXmksgX4UkSMBo4ApkoaDUwD7oqIUcBdaRrgBGBUGqYAV0BRhIDzgMOBw4DzOgtRavOZ0nLjK9wfMzPrQWVFJSLWRsRDafx54FFgODARmJmazQQmpfGJwKwo3AfsK2kYcDwwPyI2RsQmYD4wPs17W0TcFxEBzCqty8zMmqBPrqlIagMOBe4HhkbE2jTraWBoGh8OPFlabHWKdRdfXSNuZmZNUnlRkbQ3cBNwTkRsLs9LRxjRBzlMkbRQ0sINGzZUvTkzs51WpUVF0q4UBeWaiPhRCq9Lp65In+tTfA0wsrT4iBTrLj6iRvwNIuLKiGiPiPYhQ4b0bqfMzKyuKu/+EnAV8GhEXFqaNQfovINrMnBrKX56ugvsCOC5dJpsHjBO0qB0gX4cMC/N2yzpiLSt00vrMjOzJqjyifojgU8DSyQtSrGvABcDN0g6C3gCOCXNmwtMADqAF4EzACJio6QLgQWp3QURsTGNfw64GtgDuD0NZmbWJJUVlYj4OVDvu
ZHjarQPYGqddc0AZtSILwTe14s0zcwsIz9Rb2Zm2TRUVCT9SdWJmJlZ62v0SOVySQ9I+pykfSrNyMzMWlZDRSUijgI+RXFr74OSrpX0F5VmZmZmLafhayoRsQL4GvBl4M+A6ZJ+KenjVSVnZmatpdFrKn8q6TKK/ruOBT4aEf8ljV9WYX5mZtZCGr2l+NvAvwFfiYiXOoMR8ZSkr1WSmZmZtZxGi8qJwEsRsRVA0i7A7hHxYkR8v7LszMyspTR6TeVOiqfWO+2ZYmZmZts0WlR2j4gXOifS+J7VpGRmZq2q0aLyuy6v930/8FI37c3MbCfU6DWVc4AfSnqKoj+vPwI+UVlWZmbWkhoqKhGxQNLBwHtS6LGI+EN1aZmZWSvanl6KPwC0pWXGSiIiZlWSlZmZtaSGioqk7wMHAYuArSkcgIuKmZlt0+iRSjswOr3zxMzMrKZGi8pSiovzayvMpd9rm3Zbs1MwM+vXGi0qg4Hlkh4AXu4MRsTHKsnKzMxaUqNF5fwqkzAzszeHRm8p/pmkdwKjIuJOSXsCA6pNzczMWk2jXd9/BrgR+G4KDQduqSopMzNrTY120zIVOBLYDNte2PWOqpIyM7PW1GhReTkiXumckDSQ4jkVMzOzbRotKj+T9BVgj/Ru+h8C/15dWmZm1ooaLSrTgA3AEuCvgLkU76s3MzPbptG7v14FvpcGMzOzmhq9++vXklZ2HXpYZoak9ZKWlmLnS1ojaVEaJpTmnSupQ9Jjko4vxcenWIekaaX4gZLuT/HrJe22fbtuZma5NXr6q52il+IPAEcB04Ef9LDM1cD4GvHLImJMGuYCSBoNnAq8Ny1zuaQBkgYA3wFOAEYDp6W2AJekdb0b2ASc1eC+mJlZRRoqKhHxTGlYExH/BJzYwzL3ABsbzGMiMDsiXo6IXwMdwGFp6IiIlenus9nAREkCjqV4dgZgJjCpwW2ZmVlFGu36fmxpcheKI5fteRdL2dmSTgcWAl+KiE0UD1PeV2qzOsUAnuwSPxx4O/BsRGyp0b5W/lOAKQAHHHDADqZtZmY9abQwfKs0vgVYBZyyA9u7AriQ4hmXC9N6z9yB9WyXiLgSuBKgvb3dz9eYmVWk0bu/PpJjYxGxrnNc0veAH6fJNcDIUtMRKUad+DPAvpIGpqOVcnszM2uSRk9//U138yPi0gbXMywiOt/JchLFe1oA5gDXSroU2B8YBTwACBgl6UCKonEq8MmICEl3AydTXGeZDNzaSA5mZlad7Xnz4wcofvkDfJTil/6KegtIug44BhgsaTVwHnCMpDEUp79WUTxISUQsk3QDsJzi9NrUiNia1nM2MI+iV+QZEbEsbeLLwGxJ3wQeBq5qcF/MzKwijRaVEcDYiHgeiudNgNsi4n/UWyAiTqsRrvuLPyIuAi6qEZ9L8QR/1/hKirvDzMysn2j0OZWhwCul6VdSzMzMbJtGj1RmAQ9IujlNT6J4NsTMzGybRu/+ukjS7RRP0wOcEREPV5eWmZm1okZPfwHsCWyOiH8GVqc7sszMzLZptEPJ8yjutjo3hXal576/zMxsJ9PokcpJwMeA3wFExFPAW6tKyszMWlOjReWViAjSK4Ql7VVdSmZm1qoaLSo3SPouRdconwHuxC/sMjOzLhq9++sf07vpNwPvAb4eEfMrzczMzFpOj0UlvSjrztSppAuJmZnV1ePpr9QH16uS9umDfMzMrIU1+kT9C8ASSfNJd4ABRMTnK8nKzMxaUqNF5UdpMDMzq6vboiLpgIj4TUS4ny8zM+tRT9dUbukckXRTxbmYmVmL66moqDT+rioTMTOz1tdTUYk642ZmZm/Q04X6QyRtpjhi2SONk6YjIt5WaXZmZtZSui0qETGgrxIxM7PWtz3vUzEzM+uWi4qZmWXjomJmZtm4qJiZWTYuKmZmlo2LipmZZeOiYmZm2VRWVCTNkLRe0tJSbD9J8yWtSJ+DUlySpkvqkLRY0tjSMpNT+xWSJpfi75e0JC0zXZIwM7OmqvJI5WpgfJfYNOCuiBgF3JWmAU4ARqVhCnAFFEUIOA84HDgMOK+zEKU2nykt13VbZmbWxyorKhFxD7CxS3gi0NmN/kxgUik+Kwr3AftKGgYcD8yPiI0RsYnidcbj07y3RcR9ERHArNK6zMysSfr6msrQiFibxp8Ghqbx4cCTpXarU6y7+Ooa8ZokTZG0UNLCDRs29G4PzMysrqZdqE9HGH3S83FEXBkR7RHRPmTIkL7YpJnZTqnR1wnnsk7SsIhYm05hrU/xNcDIUrsRKbYGOKZL/KcpPqJGe2tQ27Tber2OVRefmCETM3sz6esjlTlA5x1ck4FbS/HT011gRwDPpdNk84BxkgalC/TjgHlp3mZJR6S7vk4vrcvMzJqksiMVSddRHGUMlrSa4i6ui4EbJJ0FPAGckprPBSYAHcCLwBkAEbFR0oXAgtTugojovPj/OYo7zPYAbk+DmZk1UWVFJSJOqzPruBptA5haZz0zgBk14guB9/UmRzMzy8tP1JuZWTYuKmZmlo2LipmZZeOiYmZm2biomJlZNi4qZmaWjYuKmZll46JiZmbZuKiYmVk2LipmZpaNi4qZmWXjomJmZtm4qJiZWTYuKmZmlo2LipmZZeOiYmZm2biomJlZNi4qZmaWjYuKmZll46JiZmbZuKiYmVk2LipmZpaNi4qZmWXjomJmZtm4qJiZWTZNKSqSVklaImmRpIUptp+k+ZJWpM9BKS5J0yV1SFosaWxpPZNT+xWSJjdjX8zM7DXNPFL5SESMiYj2ND0NuCsiRgF3pWmAE4BRaZgCXAFFEQLOAw4HDgPO6yxEZmbWHP3p9NdEYGYanwlMKsVnReE+YF9Jw4DjgfkRsTEiNgHzgfF9nbSZmb2mWUUlgDskPShpSooNjYi1afxpYGgaHw48WVp2dYrVi5uZWZMMbNJ2PxwRayS9A5gv6ZflmRERkiLXxlLhmgJwwAEH5FqtmZl10ZQjlYhYkz7XAzdTXBNZl05rkT7Xp+ZrgJGlxUekWL14re1dGRHtEdE+ZMiQnLtiZmYlfX6kImkvYJeIeD6NjwMuAOYAk4GL0+etaZE5wNmSZlNclH8uItZKmgf8feni/Djg3D7clZ1e27TberX8qotPzJSJmfUXzTj9NRS4WVLn9q+NiP8naQFwg6SzgCeAU1L7ucAEoAN4ETgDICI2SroQWJDaXRARG/tuN8zMrKs+LyoRsRI4pEb8GeC4GvEAptZZ1wxgRu4czcxsx/SnW4rNzKzFuaiYmVk2LipmZpaNi4qZmWXjomJmZtm4qJiZWTYuKmZmlo2LipmZZeOiYmZm2biomJlZNi4qZmaWTbPep2LW616OwT0dm/U3PlIxM7NsXFTMzCwbFxUzM8vGRcXMzLJxUTEzs2xcVMzMLBsXFTMzy8bPqVhL87MuZv2Lj1TMzCwbFxUzM8vGRcXMzLLxNRXb6fm6jFk+PlIxM7NsXFTMzCyblj/9JWk88M/AAODfIuLiJqdkOyGfQjMrtPSRiqQBwHeAE4DRwGmSRjc3KzOznVerH6kc
BnRExEoASbOBicDypmZltgN8tGNvBq1eVIYDT5amVwOHNykXs6bLUZhycHHbebV6UWmIpCnAlDT5gqTHmplPDwYDv212Eg1qlVydZ1495qlL+iiTnr1pvtN+4J2NNGr1orIGGFmaHpFirxMRVwJX9lVSvSFpYUS0NzuPRrRKrs4zr1bJE1on11bJsxEtfaEeWACMknSgpN2AU4E5Tc7JzGyn1dJHKhGxRdLZwDyKW4pnRMSyJqdlZrbTaumiAhARc4G5zc4jo5Y4TZe0Sq7OM69WyRNaJ9dWybNHiohm52BmZm8SrX5NxczM+hEXlSaQNFLS3ZKWS1om6Qs12hwj6TlJi9Lw9SblukrSkpTDwhrzJWm6pA5JiyWNbVKe7yl9V4skbZZ0Tpc2TflOJc2QtF7S0lJsP0nzJa1In4PqLDs5tVkhaXIT8vy/kn6Z/m1vlrRvnWW7/Tnpo1zPl7Sm9O87oc6y4yU9ln5mpzUhz+tLOa6StKjOsn36nWYTER76eACGAWPT+FuBXwGju7Q5BvhxP8h1FTC4m/kTgNsBAUcA9/eDnAcATwPv7A/fKXA0MBZYWor9H2BaGp8GXFJjuf2AlelzUBof1Md5jgMGpvFLauXZyM9JH+V6PvC3DfxsPA68C9gNeKTr/72q8+wy/1vA1/vDd5pr8JFKE0TE2oh4KI0/DzxK0TtAK5oIzIrCfcC+koY1OafjgMcj4okm5wFARNwDbOwSngjMTOMzgUk1Fj0emB8RGyNiEzAfGN+XeUbEHRGxJU3eR/EsWNPV+U4bsa1rp4h4Bejs2qkS3eUpScApwHVVbb8ZXFSaTFIbcChwf43ZH5T0iKTbJb23TxN7TQB3SHow9UzQVa2ucppdIE+l/n/U/vCdAgyNiLVp/GlgaI02/e27PZPiqLSWnn5O+srZ6VTdjDqnFPvTd3oUsC4iVtSZ31++0+3iotJEkvYGbgLOiYjNXWY/RHH65hDg28AtfZ1f8uGIGEvRE/RUSUc3KY+GpIdgPwb8sMbs/vKdvk4U5zr69W2Ykr4KbAGuqdOkP/ycXAEcBIwB1lKcWurPTqP7o5T+8J1uNxeVJpG0K0VBuSYiftR1fkRsjogX0vhcYFdJg/s4TSJiTfpcD9xMcfqgrKGucvrQCcBDEbGu64z+8p0m6zpPE6bP9TXa9IvvVtJfAv8V+FQqgG/QwM9J5SJiXURsjYhXge/VyaG/fKcDgY8D19dr0x++0x3hotIE6VzqVcCjEXFpnTZ/lNoh6TCKf6tn+i5LkLSXpLd2jlNctF3apdkc4PR0F9gRwHOl0zrNUPevv/7wnZbMATrv5poM3FqjzTxgnKRB6VTOuBTrMypegvd3wMci4sU6bRr5Oalcl2t5J9XJob907fTnwC8jYnWtmf3lO90hzb5TYGccgA9TnO5YDCxKwwTgs8BnU5uzgWUUd6fcB3yoCXm+K23/kZTLV1O8nKcoXpT2OLAEaG/i97oXRZHYpxRr+ndKUeTWAn+gOId/FvB24C5gBXAnsF9q207xBtPOZc8EOtJwRhPy7KC4BtH5c/qvqe3+wNzufk6akOv308/gYopCMaxrrml6AsUdl49XnWutPFP86s6fy1Lbpn6nuQY/UW9mZtn49JeZmWXjomJmZtm4qJiZWTYuKmZmlo2LipmZZeOiYtlI+gdJH5E0SdK5O7iOtnKPrt2021/SjXXm/VTSDr3vO/Vk/KHS9GclnZ7GD049xj4s6SBJ/7kj2+hh+9u21xdq7O/Vkk7uxfp6tby1PhcVy+lwiuc//gy4p8oNRcRTEVHFL69jgG2/ZCPiXyNiVpqcBNwYEYdGxOMR8aFaK6glPRza4/+3LtvrC8dQ2l+z3nJRsV5T8c6NxcAHgHuB/wlcofS+EkmfV/HumMWSZqfY+ZL+trSOpalzTYCBkq6R9KikGyXtWWOb245oJO0haXZqfzOwR6ndOEn3SnpI0g9Tf2ud76r4RoovSUchbRQPS34xHZEc1ZmnindznAP8L0l3p3W8UNrO/5a0IO3jN0o5PiZpFsXT0CPTX/JL0za/WGO/tn0v6YjrEkkPSPqVpKNqtD9G0s8k3SpppaSLJX0qLbNE0kGp3RBJN6UcF0g6stb+ptUeLek/0/pOTssr/Tt35v6JUvxf0n7eCbyjzs9IzX1J39F/pH+HhzqPmnqzX7W2b32o2U9fenhzDBQF5dvArsAvusx7CnhLGt83fZ5P6d0XFL9029IQwJEpPoMa78hI7Zam8b8BZqTxP6Xo+LAdGExxxLRXmvdl0rsrKN5V8ddp/HOkp9hr5LVtusa8F9LnOIp3jIviD7UfU7xHow14FTgitXs/RVf2lL+LLvtV3t5PgW+l8QnAnTXaHwM8S/GOnrdQ9GP1jTTvC8A/pfFrKTooBDiAoougWvt0NUVnnLsAoym6iQf4bxRd7w+g6FH5N2mbHy/F90+5nFwjz5r7AuwJ7J7GRwELc+yXh+YNAzHLYyxFlxIHU7wfpmwxcI2kW2isZ+AnI+IXafwHwOeBf+ym/dHAdICIWJyOmqB4adho4BcquvzajeJIqlNnR54PUvxy3FHj0vBwmt6b4hfkb4AnonjPDBQv2XqXpG8DtwF3NLDuco5tddosiNTfmqTHS+tdAnwkjf85MDp9DwBv6zxqq+GWKDplXC6ps0v+DwPXRcRWis4wf0bxh8TRpfhTkn6ynfuyK/AvksYAW4E/zrFfkToOtb7nomK9kn4ZXE3R2+tvKf7ylIpXpH4wIl4CTqT45fNR4KuS/oTiaKJ8+nX30njXvoNC0uHAd9P01ykKVY/pURwZnFZn/svpcyu9+78g4B8i4ruvCxanl37XOR0RmyQdQvHyrc9SvKDpzB7W3UiOL5fGXy1Nv1paZheKI6bfd8mxp/XVbLCDau3LF4F1wCEpx9/XaA/buV/WPL6mYr0SEYsiYgzplcjAT4DjI2JMRLyk4uL0yIi4m+L00z4Uf8mvoji6QcV77Q8srfYASR9M458Efh4R96d1jomIrr3K3pPaIel9FKfAoLhp4EhJ707z9pL0x3TveYpXPG+PecCZpes1wyW94dqCim72d4mIm4Cvkfa/j9wB/HUplzFptNH9/Q/gE5IGSBpC8UfCAxTffWd8GK8dQTRqH2BtOjL6NMVptO1Rb7+sSVxUrNfSL5lN6RfDwRGxvDR7APADSUsoTg9Nj4hnKd4ls5+kZRS9B/+qtMxjFC8lepTi3exX9JDCFcDeqf0FFKdXiIgNwF8C16VTYvdSnJ7rzr8DJ3W5cN2tiLiD4tz+vWk/b6T2L+rhwE/TUdwPgB267XoHfR5oTzcSLKc4UoLG9/dmiqPDRyj+cPi7iHg6xVcAy4FZvP70YiMuByZL6jx1+rse2ndVb7+sSdxLsZmZZeMjFTMzy8ZFxczMsnFRMTOzbFxUzMwsGxcVMzPLxkXFzMyycVExM7NsXFTMzCyb/w90G9bIKLUthQAAAABJRU5ErkJggg==\n", 102 | "text/plain": [ 103 | "
" 104 | ] 105 | }, 106 | "metadata": { 107 | "needs_background": "light" 108 | }, 109 | "output_type": "display_data" 110 | } 111 | ], 112 | "source": [ 113 | "methods_names_length = [len(method_name) for method_name in methods_names]\n", 114 | "print(\"The average number of sub-identifiers in a method name is: {}\".format(np.mean(methods_names_length)))\n", 115 | "plt.xlabel('#sub-identifiers in method name')\n", 116 | "plt.ylabel('Frequency')\n", 117 | "a = plt.hist(methods_names_length, bins=range(1, 20))\n" 118 | ] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "r252", 124 | "language": "python", 125 | "name": "r252" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.7.2" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- /data_processing/data_extraction.py: -------------------------------------------------------------------------------- 1 | from graph_pb2 import FeatureNode, FeatureEdge 2 | from data_processing.text_util import split_identifier_into_parts 3 | from pathlib import Path 4 | from graph_pb2 import Graph 5 | import networkx as nx 6 | import sys 7 | import numpy as np 8 | 9 | sys.setrecursionlimit(10000) 10 | 11 | 12 | def get_dataset_from_dir(dir="../corpus/r252-corpus-features/"): 13 | """ 14 | Extract methods source code, names and graphs structure. 15 | :param dir: directory where to look for proto files 16 | :return: (methods_source, methods_names, methods_graphs) 17 | """ 18 | methods_source = [] 19 | methods_names = [] 20 | methods_graphs = [] 21 | 22 | proto_files = list(Path(dir).rglob("*.proto")) 23 | print("A total of {} files have been found".format(len(proto_files))) 24 | 25 | for i, file in enumerate(proto_files): 26 | file_methods_source, file_methods_names, file_methods_graph = get_file_methods_data( 27 | file) 28 | methods_source += file_methods_source 29 | methods_names += file_methods_names 30 | methods_graphs += file_methods_graph 31 | 32 | return methods_source, methods_names, methods_graphs 33 | 34 | 35 | def get_file_methods_data(file): 36 | """ 37 | Extract the source code tokens, identifier names and graph for methods in a source file. 38 | Identifier tokens are split into subtokens. Constructors are not included in the methods. 39 | :param file: file 40 | :return: (methods_source, methods_names, methods_graph) where methods_source[i] is a list of the tokens for 41 | the source of ith method in the file, methods_names[i] is a list of tokens for name of the 42 | ith method in the file, and methods_graph[i] is the subtree of the file parse tree starting 43 | from the method node. 
44 |     """
45 |     adj_list, nodes, edges = get_file_graph(file)
46 | 
47 |     with file.open('rb') as f:
48 |         class_name = file.name.split('.')[0]  # name before the extensions; without [0] this was a list and the constructor check below never matched
49 | 
50 |         g = Graph()
51 |         g.ParseFromString(f.read())
52 |         methods_source = []
53 |         methods_names = []
54 |         methods_graph = []
55 |         # class_name_node = get_class_name_node(g)
56 | 
57 |         for node in g.node:
58 |             if node.contents == "METHOD":
59 |                 method_name_node = get_method_name_node(g, node)
60 | 
61 |                 # If the method name matches the class name, the method is a
62 |                 # constructor, so discard it
63 |                 if method_name_node.contents == class_name:
64 |                     continue
65 | 
66 |                 method_edges, method_nodes, non_tokens_nodes_features = get_method_edges(node.id, adj_list, nodes)
67 |                 methods_graph.append((method_edges, non_tokens_nodes_features))
68 |                 methods_names.append(split_identifier_into_parts(method_name_node.contents))
69 | 
70 |                 method_source = []
71 | 
72 |                 for other_node in method_nodes.values():
73 |                     if other_node.id == method_name_node.id:
74 |                         # Replace the method name with '_' in the method source code
75 |                         method_source.append('_')
76 |                     elif other_node.type == FeatureNode.TOKEN or other_node.type == \
77 |                             FeatureNode.IDENTIFIER_TOKEN:
78 |                         method_source.append(other_node.contents)
79 | 
80 |                 methods_source.append(method_source)
81 | 
82 |     return methods_source, methods_names, methods_graph
83 | 
84 | 
85 | def get_file_graph(file):
86 |     """
87 |     Compute the graph for the given file.
88 |     """
89 |     with file.open('rb') as f:
90 |         g = Graph()
91 |         g.ParseFromString(f.read())
92 |         node_ids = [node.id for node in g.node]
93 |         edges = [(e.sourceId, e.destinationId, e.type) for e in g.edge]
94 | 
95 |         adj_list = {node: [] for node in node_ids}
96 |         for edge in edges:
97 |             adj_list[edge[0]].append({'destination': edge[1], 'edge_type': edge[2]})
98 | 
99 |         nodes = {node.id: node for node in g.node}
100 | 
101 |     return adj_list, nodes, edges
102 | 
103 | 
104 | def get_method_edges(method_node_id, file_adj_list, file_nodes):
105 |     """
106 |     Compute the edges of the method graph rooted at the node 'method_node_id'.
107 |     """
108 |     method_nodes_ids = []
109 | 
110 |     get_method_nodes_rec(method_node_id, method_nodes_ids, file_adj_list)
111 |     methods_edges = []
112 | 
113 |     for node in method_nodes_ids:
114 |         for edge in file_adj_list[node]:
115 |             if edge['destination'] in method_nodes_ids:
116 |                 methods_edges.append((node, edge['destination']))
117 | 
118 |     method_nodes = {node_id: node for node_id, node in file_nodes.items() if node_id in
119 |                     method_nodes_ids}
120 | 
121 |     methods_edges, non_tokens_nodes_features = remap_edges(methods_edges, method_nodes)
122 | 
123 |     return methods_edges, method_nodes, non_tokens_nodes_features
124 | 
125 | 
126 | def get_method_nodes_rec(node_id, method_nodes_ids, file_adj_list):
127 |     """
128 |     Recursively collect the ids of all nodes in a method graph (NEXT_TOKEN edges are not followed).
129 |     """
130 |     method_nodes_ids.append(node_id)
131 | 
132 |     for edge in file_adj_list[node_id]:
133 |         if edge['edge_type'] != FeatureEdge.NEXT_TOKEN and edge['destination'] not in method_nodes_ids:
134 |             get_method_nodes_rec(edge['destination'], method_nodes_ids, file_adj_list)
135 | 
136 | 
137 | def remap_edges(edges, nodes):
138 |     """
139 |     Remap edges so that node ids start from 0 and are consecutive, with token nodes numbered first.
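    For example (illustrative ids): if token nodes 12, 17 and 30 are remapped to 0, 1 and 2,
    the edges (12, 17) and (17, 30) become (0, 1) and (1, 2).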
140 | """ 141 | old_id_to_new_id = {} 142 | i = 0 143 | nodes_values = sorted(nodes.values(), key=lambda node: node.id) 144 | new_edges = [] 145 | 146 | # Set new ids for tokens 147 | for node_value in nodes_values: 148 | if is_token(node_value): 149 | old_id_to_new_id[node_value.id] = i 150 | i += 1 151 | 152 | non_tokens_nodes_features = np.zeros((len(nodes_values) - len(old_id_to_new_id), 11)) 153 | j = i 154 | # Set new ids for other nodes 155 | for node_value in nodes_values: 156 | if not is_token(node_value): 157 | old_id_to_new_id[node_value.id] = i 158 | non_tokens_nodes_features[i - j][node_value.type - 1] = 1 159 | i += 1 160 | 161 | for edge in edges: 162 | new_edges.append((old_id_to_new_id[edge[0]], old_id_to_new_id[edge[1]])) 163 | 164 | return new_edges, non_tokens_nodes_features 165 | 166 | 167 | def is_token(node_value): 168 | return node_value.type == FeatureNode.TOKEN or node_value.type == FeatureNode.IDENTIFIER_TOKEN 169 | 170 | 171 | def get_method_name_node(g, method_node): 172 | """ 173 | Return the node corresponding to the name of a method. 174 | """ 175 | method_id = method_node.id 176 | method_name_node_id = 0 177 | 178 | for edge in g.edge: 179 | if edge.sourceId == method_id and edge.type == FeatureEdge.ASSOCIATED_TOKEN: 180 | method_name_node_id = edge.destinationId 181 | break 182 | 183 | for node in g.node: 184 | if node.id == method_name_node_id: 185 | return node 186 | 187 | 188 | def get_class_name_node(g): 189 | """ 190 | :param g: graph representing the file 191 | :return: the node corresponding to the class identifier token 192 | """ 193 | class_node = [node for node in g.node if node.contents == "CLASS"][0] 194 | class_associated_nodes_ids = [edge.destinationId for edge in g.edge if edge.sourceId == 195 | class_node.id and edge.type == FeatureEdge.ASSOCIATED_TOKEN] 196 | class_associated_nodes = [node for node in g.node if node.id in class_associated_nodes_ids] 197 | 198 | return class_associated_nodes[1] 199 | 200 | 201 | def get_nx_graph(file): 202 | """ 203 | Get networkx graph corresponding to a file. 
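    Each edge is annotated with the name of its FeatureEdge type (e.g. 'AST_CHILD'),
    which makes the graph convenient to inspect and visualise.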
204 | """ 205 | nx_graph = nx.DiGraph() 206 | with file.open('rb') as f: 207 | g = Graph() 208 | g.ParseFromString(f.read()) 209 | 210 | for edge in g.edge: 211 | edge_type = [name for name, value in list(vars(FeatureEdge).items())[8:] if value == 212 | edge.type][0] 213 | nx_graph.add_edge(edge.sourceId, edge.destinationId, edge_type=edge_type) 214 | return nx_graph 215 | -------------------------------------------------------------------------------- /data_processing/data_generation.py: -------------------------------------------------------------------------------- 1 | from data_processing.data_extraction import get_dataset_from_dir 2 | import pickle 3 | 4 | # Generate data 5 | methods_source, methods_names, methods_graphs = get_dataset_from_dir( 6 | "../corpus/r252-corpus-features/") 7 | 8 | # Store data 9 | pickle.dump({'methods_source': methods_source, 'methods_names': methods_names, 'methods_graphs': 10 | methods_graphs}, open('data/methods_tokens_graphs.pkl', 'wb')) 11 | -------------------------------------------------------------------------------- /data_processing/data_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals, print_function, division 2 | from io import open 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import pylab 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | np.random.seed(42) 11 | 12 | 13 | class TokenLang: 14 | """ 15 | Language of all tokens in a dataset. 16 | """ 17 | def __init__(self, name): 18 | self.name = name 19 | self.word2index = {} 20 | self.word2count = {} 21 | self.index2word = {0: "SOS", 1: "EOS"} 22 | self.n_words = 2 # Count SOS and EOS 23 | self.SOS_token = 0 24 | self.EOS_token = 1 25 | 26 | def add_sentence(self, sentence): 27 | for word in sentence: 28 | self.add_word(word) 29 | 30 | def add_word(self, word): 31 | if word not in self.word2index: 32 | self.word2index[word] = self.n_words 33 | self.word2count[word] = 1 34 | self.index2word[self.n_words] = word 35 | self.n_words += 1 36 | else: 37 | self.word2count[word] += 1 38 | 39 | def to_tokens(self, idxs): 40 | return np.array([self.index2word[idx] for idx in idxs]) 41 | 42 | 43 | def read_data(): 44 | """ 45 | Read data and return X, y pairs. 46 | """ 47 | data = pickle.load(open('data/methods_tokens_graphs2.pkl', 'rb')) 48 | methods_source = data['methods_source'] 49 | methods_graphs = data['methods_graphs'] 50 | methods_names = data['methods_names'] 51 | 52 | pairs = [((methods_source[i], methods_graphs[i]), methods_names[i]) for i in range(len( 53 | methods_source))] 54 | np.random.shuffle(pairs) 55 | 56 | return pairs 57 | 58 | 59 | def read_tokens(): 60 | """ 61 | Read data and return X, y pairs without graph information. 62 | """ 63 | data = pickle.load(open('data/methods_tokens_graphs2.pkl', 'rb')) 64 | methods_source = data['methods_source'] 65 | methods_names = data['methods_names'] 66 | 67 | pairs = [(methods_source[i], methods_names[i]) for i in range(len(methods_source))] 68 | np.random.shuffle(pairs) 69 | 70 | return pairs 71 | 72 | 73 | def prepare_tokens(num_samples=None): 74 | """ 75 | Prepare data and return language and X, y pairs. 
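    This is the token-only variant: each X is a list of source sub-tokens and carries no
    graph information (see prepare_data for the graph-aware variant).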
76 |     """
77 |     lang = TokenLang('code')
78 |     pairs = read_tokens()
79 |     pairs = pairs if not num_samples else pairs[:num_samples]
80 |     print("Read %s sentence pairs" % len(pairs))
81 |     for pair in pairs:
82 |         lang.add_sentence(pair[0])
83 |         lang.add_sentence(pair[1])
84 |     print("Counted words:")
85 |     print(lang.name, lang.n_words)
86 |     return lang, pairs
87 | 
88 | 
89 | def prepare_data(num_samples=None):
90 |     """
91 |     Prepare data and return language and X, y pairs, where each X also carries graph information.
92 |     """
93 |     lang = TokenLang('code')
94 |     pairs = read_data()
95 |     pairs = pairs if not num_samples else pairs[:num_samples]
96 |     print("Read %s sentence pairs" % len(pairs))
97 |     for pair in pairs:
98 |         lang.add_sentence(pair[0][0])
99 |         lang.add_sentence(pair[1])
100 |     print("Counted words:")
101 |     print(lang.name, lang.n_words)
102 |     return lang, pairs
103 | 
104 | 
105 | def indexes_from_sentence_tokens(lang, sentence):
106 |     return [lang.word2index[word] for word in sentence]
107 | 
108 | 
109 | def tensor_from_sentence_tokens(lang, sentence):
110 |     indexes = indexes_from_sentence_tokens(lang, sentence)
111 |     indexes.append(lang.EOS_token)
112 |     return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
113 | 
114 | 
115 | def tensors_from_pair_tokens(pair, lang):
116 |     input_tensor = tensor_from_sentence_tokens(lang, pair[0])
117 |     target_tensor = tensor_from_sentence_tokens(lang, pair[1])
118 |     return input_tensor, target_tensor
119 | 
120 | 
121 | def sparse_adj_from_edges(edges):
122 |     """
123 |     Return a sparse PyTorch adjacency matrix given a list of edges.
124 |     """
125 |     f = [e[0] for e in edges]
126 |     t = [e[1] for e in edges]
127 |     n_nodes = max(f + t) + 1
128 |     idxs = torch.LongTensor(edges)
129 |     values = torch.ones(len(edges))
130 | 
131 |     adj = torch.sparse.FloatTensor(idxs.t(), values, torch.Size([n_nodes, n_nodes]))
132 |     return adj
133 | 
134 | 
135 | def tensors_from_pair_tokens_graph(pair, lang):
136 |     """
137 |     Build the training tensors for an X, y pair that includes graph information.
138 |     """
139 |     input_tensor = tensor_from_sentence_tokens(lang, pair[0][0])
140 |     input_adj = sparse_adj_from_edges(pair[0][1][0])
141 |     node_features = torch.tensor(pair[0][1][1])
142 |     target_tensor = tensor_from_sentence_tokens(lang, pair[1])
143 |     return (input_tensor, input_adj, node_features), target_tensor
144 | 
145 | 
146 | def plot_loss(train_losses, val_losses, file_path='plots/loss.jpg'):
147 |     """
148 |     Plot the train and validation loss.
149 |     """
150 |     plt.clf()
151 |     plt.plot(train_losses)
152 |     plt.plot(val_losses)
153 |     plt.legend(('train loss', 'validation loss'), loc='upper right')
154 |     plt.title('Losses during training of LSTM->LSTM Model')
155 |     plt.xlabel('#epochs')
156 |     plt.ylabel('cross-entropy loss')
157 |     pylab.savefig(file_path)
--------------------------------------------------------------------------------
/data_processing/text_util.py:
--------------------------------------------------------------------------------
1 | def split_camelcase(camel_case_identifier):
2 |     """
3 |     Split camelCase identifiers.
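    For example (illustrative input): "parseHTTPResponse" -> ["parse", "HTTP", "Response"].
    Case is preserved here; lowercasing happens in split_identifier_into_parts.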
4 | """ 5 | if not len(camel_case_identifier): 6 | return [] 7 | 8 | # split into words based on adjacent cases being the same 9 | result = [] 10 | current = str(camel_case_identifier[0]) 11 | prev_upper = camel_case_identifier[0].isupper() 12 | prev_digit = camel_case_identifier[0].isdigit() 13 | prev_special = not camel_case_identifier[0].isalnum() 14 | for c in camel_case_identifier[1:]: 15 | upper = c.isupper() 16 | digit = c.isdigit() 17 | special = not c.isalnum() 18 | new_upper_word = upper and not prev_upper 19 | new_digit_word = digit and not prev_digit 20 | new_special_word = special and not prev_special 21 | if new_digit_word or new_upper_word or new_special_word: 22 | result.append(current) 23 | current = c 24 | elif not upper and prev_upper and len(current) > 1: 25 | result.append(current[:-1]) 26 | current = current[-1] + c 27 | elif not digit and prev_digit: 28 | result.append(current) 29 | current = c 30 | elif not special and prev_special: 31 | result.append(current) 32 | current = c 33 | else: 34 | current += c 35 | prev_digit = digit 36 | prev_upper = upper 37 | prev_special = special 38 | result.append(current) 39 | return result 40 | 41 | 42 | def split_identifier_into_parts(identifier): 43 | """ 44 | Split a single identifier into parts on snake_case and camelCase 45 | """ 46 | snake_case = identifier.split("_") 47 | 48 | identifier_parts = [] 49 | for i in range(len(snake_case)): 50 | part = snake_case[i] 51 | if len(part) > 0: 52 | identifier_parts.extend(s.lower() for s in split_camelcase(part)) 53 | if len(identifier_parts) == 0: 54 | return [identifier] 55 | return identifier_parts 56 | -------------------------------------------------------------------------------- /graph_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
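# (Regenerate with the protobuf compiler if graph.proto changes: protoc --python_out=. graph.proto)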
2 | # source: graph.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='graph.proto', 19 | package='protobuf', 20 | syntax='proto2', 21 | serialized_options=_b('\n$uk.ac.cam.acr31.features.javac.protoB\013GraphProtos'), 22 | serialized_pb=_b('\n\x0bgraph.proto\x12\x08protobuf\"\xf8\x02\n\x0b\x46\x65\x61tureNode\x12\n\n\x02id\x18\x01 \x01(\x03\x12,\n\x04type\x18\x02 \x01(\x0e\x32\x1e.protobuf.FeatureNode.NodeType\x12\x10\n\x08\x63ontents\x18\x03 \x01(\t\x12\x15\n\rstartPosition\x18\x04 \x01(\x05\x12\x13\n\x0b\x65ndPosition\x18\x05 \x01(\x05\x12\x17\n\x0fstartLineNumber\x18\x06 \x01(\x05\x12\x15\n\rendLineNumber\x18\x07 \x01(\x05\"\xc0\x01\n\x08NodeType\x12\t\n\x05TOKEN\x10\x01\x12\x0f\n\x0b\x41ST_ELEMENT\x10\x02\x12\x10\n\x0c\x43OMMENT_LINE\x10\x03\x12\x11\n\rCOMMENT_BLOCK\x10\x04\x12\x13\n\x0f\x43OMMENT_JAVADOC\x10\x05\x12\x14\n\x10IDENTIFIER_TOKEN\x10\x07\x12\x0c\n\x08\x46\x41KE_AST\x10\x08\x12\n\n\x06SYMBOL\x10\t\x12\x0e\n\nSYMBOL_TYP\x10\n\x12\x0e\n\nSYMBOL_VAR\x10\x0b\x12\x0e\n\nSYMBOL_MTH\x10\x0c\"\xe9\x02\n\x0b\x46\x65\x61tureEdge\x12\x10\n\x08sourceId\x18\x01 \x01(\x03\x12\x15\n\rdestinationId\x18\x02 \x01(\x03\x12,\n\x04type\x18\x03 \x01(\x0e\x32\x1e.protobuf.FeatureEdge.EdgeType\"\x82\x02\n\x08\x45\x64geType\x12\x14\n\x10\x41SSOCIATED_TOKEN\x10\x01\x12\x0e\n\nNEXT_TOKEN\x10\x02\x12\r\n\tAST_CHILD\x10\x03\x12\x08\n\x04NONE\x10\x04\x12\x0e\n\nLAST_WRITE\x10\x05\x12\x0c\n\x08LAST_USE\x10\x06\x12\x11\n\rCOMPUTED_FROM\x10\x07\x12\x0e\n\nRETURNS_TO\x10\x08\x12\x13\n\x0f\x46ORMAL_ARG_NAME\x10\t\x12\x0e\n\nGUARDED_BY\x10\n\x12\x17\n\x13GUARDED_BY_NEGATION\x10\x0b\x12\x14\n\x10LAST_LEXICAL_USE\x10\x0c\x12\x0b\n\x07\x43OMMENT\x10\r\x12\x15\n\x11\x41SSOCIATED_SYMBOL\x10\x0e\"\xba\x01\n\x05Graph\x12#\n\x04node\x18\x01 \x03(\x0b\x32\x15.protobuf.FeatureNode\x12#\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32\x15.protobuf.FeatureEdge\x12\x12\n\nsourceFile\x18\x03 \x01(\t\x12*\n\x0b\x66irst_token\x18\x04 \x01(\x0b\x32\x15.protobuf.FeatureNode\x12\'\n\x08\x61st_root\x18\x05 \x01(\x0b\x32\x15.protobuf.FeatureNodeB3\n$uk.ac.cam.acr31.features.javac.protoB\x0bGraphProtos') 23 | ) 24 | 25 | 26 | 27 | _FEATURENODE_NODETYPE = _descriptor.EnumDescriptor( 28 | name='NodeType', 29 | full_name='protobuf.FeatureNode.NodeType', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='TOKEN', index=0, number=1, 35 | serialized_options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='AST_ELEMENT', index=1, number=2, 39 | serialized_options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='COMMENT_LINE', index=2, number=3, 43 | serialized_options=None, 44 | type=None), 45 | _descriptor.EnumValueDescriptor( 46 | name='COMMENT_BLOCK', index=3, number=4, 47 | serialized_options=None, 48 | type=None), 49 | _descriptor.EnumValueDescriptor( 50 | name='COMMENT_JAVADOC', index=4, number=5, 51 | serialized_options=None, 52 | type=None), 53 | _descriptor.EnumValueDescriptor( 54 | name='IDENTIFIER_TOKEN', index=5, number=7, 55 | serialized_options=None, 56 | type=None), 57 | 
_descriptor.EnumValueDescriptor( 58 | name='FAKE_AST', index=6, number=8, 59 | serialized_options=None, 60 | type=None), 61 | _descriptor.EnumValueDescriptor( 62 | name='SYMBOL', index=7, number=9, 63 | serialized_options=None, 64 | type=None), 65 | _descriptor.EnumValueDescriptor( 66 | name='SYMBOL_TYP', index=8, number=10, 67 | serialized_options=None, 68 | type=None), 69 | _descriptor.EnumValueDescriptor( 70 | name='SYMBOL_VAR', index=9, number=11, 71 | serialized_options=None, 72 | type=None), 73 | _descriptor.EnumValueDescriptor( 74 | name='SYMBOL_MTH', index=10, number=12, 75 | serialized_options=None, 76 | type=None), 77 | ], 78 | containing_type=None, 79 | serialized_options=None, 80 | serialized_start=210, 81 | serialized_end=402, 82 | ) 83 | _sym_db.RegisterEnumDescriptor(_FEATURENODE_NODETYPE) 84 | 85 | _FEATUREEDGE_EDGETYPE = _descriptor.EnumDescriptor( 86 | name='EdgeType', 87 | full_name='protobuf.FeatureEdge.EdgeType', 88 | filename=None, 89 | file=DESCRIPTOR, 90 | values=[ 91 | _descriptor.EnumValueDescriptor( 92 | name='ASSOCIATED_TOKEN', index=0, number=1, 93 | serialized_options=None, 94 | type=None), 95 | _descriptor.EnumValueDescriptor( 96 | name='NEXT_TOKEN', index=1, number=2, 97 | serialized_options=None, 98 | type=None), 99 | _descriptor.EnumValueDescriptor( 100 | name='AST_CHILD', index=2, number=3, 101 | serialized_options=None, 102 | type=None), 103 | _descriptor.EnumValueDescriptor( 104 | name='NONE', index=3, number=4, 105 | serialized_options=None, 106 | type=None), 107 | _descriptor.EnumValueDescriptor( 108 | name='LAST_WRITE', index=4, number=5, 109 | serialized_options=None, 110 | type=None), 111 | _descriptor.EnumValueDescriptor( 112 | name='LAST_USE', index=5, number=6, 113 | serialized_options=None, 114 | type=None), 115 | _descriptor.EnumValueDescriptor( 116 | name='COMPUTED_FROM', index=6, number=7, 117 | serialized_options=None, 118 | type=None), 119 | _descriptor.EnumValueDescriptor( 120 | name='RETURNS_TO', index=7, number=8, 121 | serialized_options=None, 122 | type=None), 123 | _descriptor.EnumValueDescriptor( 124 | name='FORMAL_ARG_NAME', index=8, number=9, 125 | serialized_options=None, 126 | type=None), 127 | _descriptor.EnumValueDescriptor( 128 | name='GUARDED_BY', index=9, number=10, 129 | serialized_options=None, 130 | type=None), 131 | _descriptor.EnumValueDescriptor( 132 | name='GUARDED_BY_NEGATION', index=10, number=11, 133 | serialized_options=None, 134 | type=None), 135 | _descriptor.EnumValueDescriptor( 136 | name='LAST_LEXICAL_USE', index=11, number=12, 137 | serialized_options=None, 138 | type=None), 139 | _descriptor.EnumValueDescriptor( 140 | name='COMMENT', index=12, number=13, 141 | serialized_options=None, 142 | type=None), 143 | _descriptor.EnumValueDescriptor( 144 | name='ASSOCIATED_SYMBOL', index=13, number=14, 145 | serialized_options=None, 146 | type=None), 147 | ], 148 | containing_type=None, 149 | serialized_options=None, 150 | serialized_start=508, 151 | serialized_end=766, 152 | ) 153 | _sym_db.RegisterEnumDescriptor(_FEATUREEDGE_EDGETYPE) 154 | 155 | 156 | _FEATURENODE = _descriptor.Descriptor( 157 | name='FeatureNode', 158 | full_name='protobuf.FeatureNode', 159 | filename=None, 160 | file=DESCRIPTOR, 161 | containing_type=None, 162 | fields=[ 163 | _descriptor.FieldDescriptor( 164 | name='id', full_name='protobuf.FeatureNode.id', index=0, 165 | number=1, type=3, cpp_type=2, label=1, 166 | has_default_value=False, default_value=0, 167 | message_type=None, enum_type=None, containing_type=None, 168 | 
is_extension=False, extension_scope=None, 169 | serialized_options=None, file=DESCRIPTOR), 170 | _descriptor.FieldDescriptor( 171 | name='type', full_name='protobuf.FeatureNode.type', index=1, 172 | number=2, type=14, cpp_type=8, label=1, 173 | has_default_value=False, default_value=1, 174 | message_type=None, enum_type=None, containing_type=None, 175 | is_extension=False, extension_scope=None, 176 | serialized_options=None, file=DESCRIPTOR), 177 | _descriptor.FieldDescriptor( 178 | name='contents', full_name='protobuf.FeatureNode.contents', index=2, 179 | number=3, type=9, cpp_type=9, label=1, 180 | has_default_value=False, default_value=_b("").decode('utf-8'), 181 | message_type=None, enum_type=None, containing_type=None, 182 | is_extension=False, extension_scope=None, 183 | serialized_options=None, file=DESCRIPTOR), 184 | _descriptor.FieldDescriptor( 185 | name='startPosition', full_name='protobuf.FeatureNode.startPosition', index=3, 186 | number=4, type=5, cpp_type=1, label=1, 187 | has_default_value=False, default_value=0, 188 | message_type=None, enum_type=None, containing_type=None, 189 | is_extension=False, extension_scope=None, 190 | serialized_options=None, file=DESCRIPTOR), 191 | _descriptor.FieldDescriptor( 192 | name='endPosition', full_name='protobuf.FeatureNode.endPosition', index=4, 193 | number=5, type=5, cpp_type=1, label=1, 194 | has_default_value=False, default_value=0, 195 | message_type=None, enum_type=None, containing_type=None, 196 | is_extension=False, extension_scope=None, 197 | serialized_options=None, file=DESCRIPTOR), 198 | _descriptor.FieldDescriptor( 199 | name='startLineNumber', full_name='protobuf.FeatureNode.startLineNumber', index=5, 200 | number=6, type=5, cpp_type=1, label=1, 201 | has_default_value=False, default_value=0, 202 | message_type=None, enum_type=None, containing_type=None, 203 | is_extension=False, extension_scope=None, 204 | serialized_options=None, file=DESCRIPTOR), 205 | _descriptor.FieldDescriptor( 206 | name='endLineNumber', full_name='protobuf.FeatureNode.endLineNumber', index=6, 207 | number=7, type=5, cpp_type=1, label=1, 208 | has_default_value=False, default_value=0, 209 | message_type=None, enum_type=None, containing_type=None, 210 | is_extension=False, extension_scope=None, 211 | serialized_options=None, file=DESCRIPTOR), 212 | ], 213 | extensions=[ 214 | ], 215 | nested_types=[], 216 | enum_types=[ 217 | _FEATURENODE_NODETYPE, 218 | ], 219 | serialized_options=None, 220 | is_extendable=False, 221 | syntax='proto2', 222 | extension_ranges=[], 223 | oneofs=[ 224 | ], 225 | serialized_start=26, 226 | serialized_end=402, 227 | ) 228 | 229 | 230 | _FEATUREEDGE = _descriptor.Descriptor( 231 | name='FeatureEdge', 232 | full_name='protobuf.FeatureEdge', 233 | filename=None, 234 | file=DESCRIPTOR, 235 | containing_type=None, 236 | fields=[ 237 | _descriptor.FieldDescriptor( 238 | name='sourceId', full_name='protobuf.FeatureEdge.sourceId', index=0, 239 | number=1, type=3, cpp_type=2, label=1, 240 | has_default_value=False, default_value=0, 241 | message_type=None, enum_type=None, containing_type=None, 242 | is_extension=False, extension_scope=None, 243 | serialized_options=None, file=DESCRIPTOR), 244 | _descriptor.FieldDescriptor( 245 | name='destinationId', full_name='protobuf.FeatureEdge.destinationId', index=1, 246 | number=2, type=3, cpp_type=2, label=1, 247 | has_default_value=False, default_value=0, 248 | message_type=None, enum_type=None, containing_type=None, 249 | is_extension=False, extension_scope=None, 250 | 
serialized_options=None, file=DESCRIPTOR), 251 | _descriptor.FieldDescriptor( 252 | name='type', full_name='protobuf.FeatureEdge.type', index=2, 253 | number=3, type=14, cpp_type=8, label=1, 254 | has_default_value=False, default_value=1, 255 | message_type=None, enum_type=None, containing_type=None, 256 | is_extension=False, extension_scope=None, 257 | serialized_options=None, file=DESCRIPTOR), 258 | ], 259 | extensions=[ 260 | ], 261 | nested_types=[], 262 | enum_types=[ 263 | _FEATUREEDGE_EDGETYPE, 264 | ], 265 | serialized_options=None, 266 | is_extendable=False, 267 | syntax='proto2', 268 | extension_ranges=[], 269 | oneofs=[ 270 | ], 271 | serialized_start=405, 272 | serialized_end=766, 273 | ) 274 | 275 | 276 | _GRAPH = _descriptor.Descriptor( 277 | name='Graph', 278 | full_name='protobuf.Graph', 279 | filename=None, 280 | file=DESCRIPTOR, 281 | containing_type=None, 282 | fields=[ 283 | _descriptor.FieldDescriptor( 284 | name='node', full_name='protobuf.Graph.node', index=0, 285 | number=1, type=11, cpp_type=10, label=3, 286 | has_default_value=False, default_value=[], 287 | message_type=None, enum_type=None, containing_type=None, 288 | is_extension=False, extension_scope=None, 289 | serialized_options=None, file=DESCRIPTOR), 290 | _descriptor.FieldDescriptor( 291 | name='edge', full_name='protobuf.Graph.edge', index=1, 292 | number=2, type=11, cpp_type=10, label=3, 293 | has_default_value=False, default_value=[], 294 | message_type=None, enum_type=None, containing_type=None, 295 | is_extension=False, extension_scope=None, 296 | serialized_options=None, file=DESCRIPTOR), 297 | _descriptor.FieldDescriptor( 298 | name='sourceFile', full_name='protobuf.Graph.sourceFile', index=2, 299 | number=3, type=9, cpp_type=9, label=1, 300 | has_default_value=False, default_value=_b("").decode('utf-8'), 301 | message_type=None, enum_type=None, containing_type=None, 302 | is_extension=False, extension_scope=None, 303 | serialized_options=None, file=DESCRIPTOR), 304 | _descriptor.FieldDescriptor( 305 | name='first_token', full_name='protobuf.Graph.first_token', index=3, 306 | number=4, type=11, cpp_type=10, label=1, 307 | has_default_value=False, default_value=None, 308 | message_type=None, enum_type=None, containing_type=None, 309 | is_extension=False, extension_scope=None, 310 | serialized_options=None, file=DESCRIPTOR), 311 | _descriptor.FieldDescriptor( 312 | name='ast_root', full_name='protobuf.Graph.ast_root', index=4, 313 | number=5, type=11, cpp_type=10, label=1, 314 | has_default_value=False, default_value=None, 315 | message_type=None, enum_type=None, containing_type=None, 316 | is_extension=False, extension_scope=None, 317 | serialized_options=None, file=DESCRIPTOR), 318 | ], 319 | extensions=[ 320 | ], 321 | nested_types=[], 322 | enum_types=[ 323 | ], 324 | serialized_options=None, 325 | is_extendable=False, 326 | syntax='proto2', 327 | extension_ranges=[], 328 | oneofs=[ 329 | ], 330 | serialized_start=769, 331 | serialized_end=955, 332 | ) 333 | 334 | _FEATURENODE.fields_by_name['type'].enum_type = _FEATURENODE_NODETYPE 335 | _FEATURENODE_NODETYPE.containing_type = _FEATURENODE 336 | _FEATUREEDGE.fields_by_name['type'].enum_type = _FEATUREEDGE_EDGETYPE 337 | _FEATUREEDGE_EDGETYPE.containing_type = _FEATUREEDGE 338 | _GRAPH.fields_by_name['node'].message_type = _FEATURENODE 339 | _GRAPH.fields_by_name['edge'].message_type = _FEATUREEDGE 340 | _GRAPH.fields_by_name['first_token'].message_type = _FEATURENODE 341 | _GRAPH.fields_by_name['ast_root'].message_type = _FEATURENODE 342 | 
DESCRIPTOR.message_types_by_name['FeatureNode'] = _FEATURENODE 343 | DESCRIPTOR.message_types_by_name['FeatureEdge'] = _FEATUREEDGE 344 | DESCRIPTOR.message_types_by_name['Graph'] = _GRAPH 345 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 346 | 347 | FeatureNode = _reflection.GeneratedProtocolMessageType('FeatureNode', (_message.Message,), dict( 348 | DESCRIPTOR = _FEATURENODE, 349 | __module__ = 'graph_pb2' 350 | # @@protoc_insertion_point(class_scope:protobuf.FeatureNode) 351 | )) 352 | _sym_db.RegisterMessage(FeatureNode) 353 | 354 | FeatureEdge = _reflection.GeneratedProtocolMessageType('FeatureEdge', (_message.Message,), dict( 355 | DESCRIPTOR = _FEATUREEDGE, 356 | __module__ = 'graph_pb2' 357 | # @@protoc_insertion_point(class_scope:protobuf.FeatureEdge) 358 | )) 359 | _sym_db.RegisterMessage(FeatureEdge) 360 | 361 | Graph = _reflection.GeneratedProtocolMessageType('Graph', (_message.Message,), dict( 362 | DESCRIPTOR = _GRAPH, 363 | __module__ = 'graph_pb2' 364 | # @@protoc_insertion_point(class_scope:protobuf.Graph) 365 | )) 366 | _sym_db.RegisterMessage(Graph) 367 | 368 | 369 | DESCRIPTOR._options = None 370 | # @@protoc_insertion_point(module_scope) 371 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emalgorithm/structured-neural-summarization-replication/e9b672c0b76ea075ff1a3ec8a5c3fc88afc521b9/models/__init__.py -------------------------------------------------------------------------------- /models/full_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FullModel(nn.Module): 7 | """ 8 | Complete methodNaming model. 9 | """ 10 | def __init__(self, encoder, decoder, device, graph_encoder=None, graph=False): 11 | super().__init__() 12 | 13 | self.encoder = encoder 14 | self.graph_encoder = graph_encoder 15 | self.decoder = decoder 16 | self.device = device 17 | self.graph = graph 18 | self.combine = nn.Linear(2 * encoder.hidden_size, encoder.hidden_size) 19 | 20 | assert encoder.hidden_size == decoder.hidden_size, "Hidden dimensions of encoder and decoder " \ 21 | "must be equal!"
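# Shape conventions for forward(), inferred from the call sites rather than documented upstream: `sequence` and `target` are 1-D LongTensors of token ids, `adj` is an [n_nodes, n_nodes] adjacency matrix whose first n_tokens nodes are the token nodes of `sequence`, and `node_features` holds features of shape [n_nodes - n_tokens, hidden_size] for the remaining graph nodes.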
22 | 23 | def forward(self, sequence, target, adj=None, node_features=None): 24 | batch_size = 1 25 | max_len = target.shape[0] 26 | target_vocab_size = self.decoder.output_size 27 | 28 | # tensor to store decoder outputs 29 | outputs = torch.zeros(max_len, batch_size, target_vocab_size).to(self.device) 30 | 31 | # last hidden state of the encoder is used as the initial hidden state of the decoder 32 | # hidden contains last hidden state of encoder 33 | # output contains the hidden states for all input elements 34 | encoder_output, hidden = self.encoder(sequence) 35 | 36 | # graph encoder 37 | if self.graph: 38 | # graph_hidden has shape [hidden_size] and contains a single graph representation 39 | n_nodes = adj.size(0) 40 | n_tokens = sequence.size(0) 41 | x = torch.zeros(n_nodes, encoder_output.size(2)).to(self.device) 42 | x[:n_tokens, :] = encoder_output.view(encoder_output.size(1), encoder_output.size(2)) 43 | x[n_tokens:, :] = node_features 44 | graph_hidden = self.graph_encoder(x=x, adj=adj) 45 | 46 | new_hidden = self.combine(torch.cat((graph_hidden, torch.squeeze(hidden[1])))) 47 | new_hidden = F.relu(new_hidden) 48 | 49 | hidden = (new_hidden.view(1, 1, new_hidden.size(0)), hidden[1]) 50 | 51 | # first input to the decoder is the start-of-sequence token (index 0) 52 | input = torch.tensor([[0]], device=self.device) 53 | 54 | # sequence decoder 55 | for t in range(1, max_len): 56 | output, hidden = self.decoder(input, hidden, encoder_output) 57 | outputs[t] = output 58 | top1 = output.max(1)[1] 59 | input = top1 60 | 61 | return outputs 62 | 63 | -------------------------------------------------------------------------------- /models/gat_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.graph_convolutional_layer import GraphConvolution 3 | from models.graph_attention_layer import GraphAttentionLayer 4 | 5 | 6 | class GATEncoder(nn.Module): 7 | """ 8 | Graph encoder using a Graph Attention Network. 9 | """ 10 | def __init__(self, num_features, hidden_size, dropout=0): 11 | super(GATEncoder, self).__init__() 12 | 13 | # gc1 and gc2 are unused by forward(); kept so that previously saved state dicts still load 14 | self.gc1 = GraphConvolution(num_features, hidden_size) 15 | self.gc2 = GraphConvolution(hidden_size, hidden_size) 16 | 17 | self.attention1 = GraphAttentionLayer(num_features, hidden_size) 18 | # concat=False: the output layer returns raw features instead of applying ELU 19 | self.attention2 = GraphAttentionLayer(hidden_size, hidden_size, concat=False) 20 | 21 | self.dropout = dropout 22 | 23 | def forward(self, x, adj): 24 | x = self.attention1(x, adj) 25 | x = self.attention2(x, adj) 26 | 27 | return x.mean(dim=0) 28 | 29 | -------------------------------------------------------------------------------- /models/gcn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from models.graph_convolutional_layer import GraphConvolution 4 | 5 | 6 | class GCNEncoder(nn.Module): 7 | """ 8 | Graph encoder using a Graph Convolutional Network.
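Two stacked GraphConvolution layers (with ReLU and dropout in between) produce per-node embeddings, which are then mean-pooled over all nodes into a single graph representation of size hidden_size.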
9 | """ 10 | def __init__(self, num_features, hidden_size, dropout=0): 11 | super(GCNEncoder, self).__init__() 12 | 13 | self.gc1 = GraphConvolution(num_features, hidden_size) 14 | self.gc2 = GraphConvolution(hidden_size, hidden_size) 15 | self.dropout = dropout 16 | 17 | def forward(self, x, adj): 18 | x = F.relu(self.gc1(x, adj)) 19 | x = F.dropout(x, self.dropout, training=self.training) 20 | x = self.gc2(x, adj) 21 | 22 | return x.mean(dim=0) 23 | 24 | -------------------------------------------------------------------------------- /models/graph_attention_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GraphAttentionLayer(nn.Module): 7 | """ 8 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 9 | """ 10 | 11 | def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True): 12 | super(GraphAttentionLayer, self).__init__() 13 | self.dropout = dropout 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.alpha = alpha 17 | self.concat = concat 18 | 19 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 20 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 21 | self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) 22 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 23 | 24 | self.leakyrelu = nn.LeakyReLU(self.alpha) 25 | 26 | def forward(self, input, adj): 27 | h = torch.mm(input, self.W) 28 | N = h.size(0) 29 | 30 | a_input = torch.cat((h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)), dim=1).view(N, -1, 31 | 2 * self.out_features) 32 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 33 | 34 | zero_vec = -9e15*torch.ones_like(e) 35 | attention = torch.where(adj > 0, e, zero_vec) 36 | attention = F.softmax(attention, dim=1) 37 | attention = F.dropout(attention, self.dropout, training=self.training) 38 | h_prime = torch.matmul(attention, h) 39 | 40 | if self.concat: 41 | return F.elu(h_prime) 42 | else: 43 | return h_prime 44 | -------------------------------------------------------------------------------- /models/graph_convolutional_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torch.nn.parameter import Parameter 6 | from torch.nn.modules.module import Module 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class GraphConvolution(Module): 12 | """ 13 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 14 | """ 15 | 16 | def __init__(self, in_features, out_features, bias=True): 17 | super(GraphConvolution, self).__init__() 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)).to(device) 21 | if bias: 22 | self.bias = Parameter(torch.FloatTensor(out_features)).to(device) 23 | else: 24 | self.register_parameter('bias', None) 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | stdv = 1. 
28 | stdv = 1. / math.sqrt(self.weight.size(1)) 29 | self.weight.data.uniform_(-stdv, stdv) 30 | if self.bias is not None: 31 | self.bias.data.uniform_(-stdv, stdv) 32 | 33 | def forward(self, input, adj): 34 | support = torch.mm(input, self.weight).to(device)  # X @ W (feature transform) 35 | output = torch.spmm(adj, support).to(device)  # A @ (X @ W): aggregate transformed features over neighbours 36 | if self.bias is not None: 37 | return output + self.bias 38 | else: 39 | return output 40 | 41 | def __repr__(self): 42 | return self.__class__.__name__ + ' (' \ 43 | + str(self.in_features) + ' -> ' \ 44 | + str(self.out_features) + ')' 45 | -------------------------------------------------------------------------------- /models/lstm_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class LSTMDecoder(nn.Module): 7 | """ 8 | Sequence decoder which makes use of a single-layer LSTM. 9 | """ 10 | def __init__(self, hidden_size, output_size, device, attention=False): 11 | super(LSTMDecoder, self).__init__() 12 | self.hidden_size = hidden_size 13 | self.output_size = output_size 14 | self.attention = attention 15 | self.device = device 16 | 17 | self.embedding = nn.Embedding(output_size, hidden_size).to(device) 18 | self.lstm = nn.LSTM(hidden_size, hidden_size).to(device) 19 | self.out = nn.Linear(hidden_size, output_size).to(device) 20 | self.softmax = nn.LogSoftmax(dim=1) 21 | self.attention_layer = nn.Linear(hidden_size * 2, 1).to(device) 22 | self.attention_combine = nn.Linear(hidden_size * 2, hidden_size).to(device) 23 | 24 | def forward(self, input, hidden, encoder_hiddens, input_seq=None): 25 | # encoder_hiddens has shape [batch_size, seq_len, hidden_dim] 26 | output = self.embedding(input).view(1, 1, -1) 27 | 28 | # Compute attention 29 | if self.attention: 30 | # Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last 31 | # dimension is a concatenation of the ith encoder hidden state and the current decoder 32 | # hidden 33 | hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)), 34 | dim=2) 35 | 36 | # attention_coeff initially has shape [batch_size, seq_len, 1]; after the two 37 | # squeezes and the softmax it is a [seq_len] vector of attention coefficients, 38 | # one per encoder hidden state 39 | attention_coeff = self.attention_layer(hiddens) 40 | attention_coeff = torch.squeeze(attention_coeff, dim=2) 41 | attention_coeff = torch.squeeze(attention_coeff, dim=0) 42 | attention_coeff = F.softmax(attention_coeff, dim=0) 43 | 44 | # Make encoder_hiddens of shape [hidden_dim, seq_len], valid as long as batch size is 1 45 | encoder_hiddens = torch.squeeze(encoder_hiddens, dim=0).t() 46 | 47 | context = torch.matmul(encoder_hiddens, attention_coeff).view(1, 1, -1) 48 | output = torch.cat((output, context), 2) 49 | output = self.attention_combine(output) 50 | 51 | output = F.relu(output) 52 | output, hidden = self.lstm(output, hidden) 53 | output = self.softmax(self.out(output[0])) 54 | 55 | return output, hidden 56 | -------------------------------------------------------------------------------- /models/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class LSTMEncoder(nn.Module): 6 | """ 7 | Sequence encoder which makes use of a single-layer LSTM.
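Token ids are embedded and run through a batch-first LSTM initialised with zero states; forward() returns the per-token hidden states (shape [1, seq_len, hidden_size]) together with the final (hidden, cell) state pair.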
8 | """ 9 | def __init__(self, input_size, hidden_size, device): 10 | super(LSTMEncoder, self).__init__() 11 | self.hidden_size = hidden_size 12 | 13 | self.embedding = nn.Embedding(input_size, hidden_size).to(device) 14 | self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True).to(device) 15 | self.device = device 16 | 17 | def forward(self, input): 18 | hidden = self.init_hidden() 19 | embedded = self.embedding(input).view(1, -1, self.hidden_size) 20 | output = embedded 21 | output, hidden = self.lstm(output, hidden) 22 | return output, hidden 23 | 24 | def init_hidden(self): 25 | return (torch.zeros(1, 1, self.hidden_size).to(self.device), 26 | torch.zeros(1, 1, self.hidden_size).to(self.device)) 27 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emalgorithm/structured-neural-summarization-replication/e9b672c0b76ea075ff1a3ec8a5c3fc88afc521b9/training/__init__.py -------------------------------------------------------------------------------- /training/evaluation_util.py: -------------------------------------------------------------------------------- 1 | from rouge import Rouge 2 | 3 | 4 | def compute_rouge_scores(pred_seq, target_seq): 5 | """ 6 | :param pred_seq: Predicted sequence 7 | :param target_seq: Target sequence 8 | :return: a pair (rouge_2, rouge_l) containing the rouge-2 and rouge-l scores given pred_seq 9 | and target_seq 10 | """ 11 | rouge = Rouge() 12 | pred_seq_str = ' '.join([str(x) for x in pred_seq]) 13 | target_seq = ' '.join([str(x) for x in target_seq]) 14 | 15 | scores = rouge.get_scores(pred_seq_str, target_seq) 16 | 17 | return scores[0]['rouge-2']['f'], scores[0]['rouge-l']['f'] 18 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from models.full_model import FullModel 5 | from training.train_model import train_iters 6 | from models.lstm_encoder import LSTMEncoder 7 | from models.lstm_decoder import LSTMDecoder 8 | from data_processing.data_util import prepare_tokens, prepare_data 9 | from models.gat_encoder import GATEncoder 10 | from models.gcn_encoder import GCNEncoder 11 | 12 | 13 | def main(): 14 | """ 15 | Entry-point for running the models. 
16 | """ 17 | 18 | # Create directory for saving results 19 | model_dir = '../results/{}/'.format(opt.model_name) 20 | if not os.path.exists(model_dir): 21 | os.makedirs(model_dir) 22 | 23 | # Store hyperparams 24 | with open(model_dir + 'hyperparams.txt', 'w') as f: 25 | f.write(str(opt)) 26 | 27 | # Prepare data 28 | if opt.graph: 29 | lang, pairs = prepare_data(num_samples=opt.n_samples) 30 | pairs = [pair for pair in pairs if len(pair[0][1][0]) > 0] 31 | else: 32 | lang, pairs = prepare_tokens(num_samples=opt.n_samples) 33 | 34 | # Create model 35 | hidden_size = 256 36 | encoder = LSTMEncoder(lang.n_words, hidden_size, opt.device).to(opt.device) 37 | 38 | decoder = LSTMDecoder(hidden_size, lang.n_words, opt.device, attention=opt.attention).to( 39 | opt.device) 40 | if opt.graph: 41 | if opt.gat: 42 | graph_encoder = GATEncoder(hidden_size, hidden_size) 43 | else: 44 | graph_encoder = GCNEncoder(hidden_size, hidden_size) 45 | model = FullModel(encoder=encoder, graph_encoder=graph_encoder, decoder=decoder, 46 | device=opt.device) 47 | else: 48 | model = FullModel(encoder=encoder, decoder=decoder, device=opt.device) 49 | 50 | # Train model 51 | train_iters(model, opt.iterations, pairs, print_every=opt.print_every, model_dir=model_dir, 52 | lang=lang, graph=opt.graph) 53 | 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--attention', type=bool, default=False, help='whether to use attention') 57 | parser.add_argument('--model_name', default="test10", help='model name') 58 | parser.add_argument('--device', default="cpu", help='cpu or cuda') 59 | parser.add_argument('--n_samples', type=int, default=None, help='Number of samples to train on') 60 | parser.add_argument('--print_every', type=int, default=1000, help='Print results after a fixed ' 61 | 'number of iterations') 62 | parser.add_argument('--iterations', type=int, default=100, help='Number of iterations to train for') 63 | parser.add_argument('--graph', type=bool, default=False, help='Whether to use a graph encoder') 64 | parser.add_argument('--gat', type=bool, default=False, help='Whether to use GAT or GCN') 65 | 66 | opt = parser.parse_args() 67 | print(opt) 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /training/train_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals, print_function, division 2 | import random 3 | from data_processing.data_util import tensors_from_pair_tokens, plot_loss, \ 4 | tensors_from_pair_tokens_graph 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | from sklearn.metrics import f1_score 10 | import numpy as np 11 | from training.evaluation_util import compute_rouge_scores 12 | import pickle 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False): 19 | """ 20 | Evaluate model and return metrics. 
21 | """ 22 | with torch.no_grad(): 23 | loss = 0 24 | f1 = 0 25 | rouge_2 = 0 26 | rouge_l = 0 27 | for i in range(len(eval_pairs)): 28 | if graph: 29 | eval_pair = eval_pairs[i] 30 | input_tensor = eval_pair[0][0].to(device) 31 | adj_tensor = eval_pair[0][1].to(device) 32 | node_features = eval_pair[0][2].to(device) 33 | target_tensor = eval_pair[1].to(device) 34 | 35 | output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor, 36 | target=target_tensor.view(-1), node_features=node_features) 37 | else: 38 | eval_pair = eval_pairs[i] 39 | input_tensor = eval_pair[0] 40 | target_tensor = eval_pair[1] 41 | 42 | output = seq2seq_model(sequence=input_tensor.view(-1), target=target_tensor.view( 43 | -1)) 44 | 45 | loss += criterion(output.view(-1, output.size(2)), target_tensor.view(-1)) 46 | pred = output.view(-1, output.size(2)).argmax(1).cpu().numpy() 47 | 48 | y_true = target_tensor.cpu().numpy().reshape(-1) 49 | f1 += f1_score(y_true, pred, average='micro') 50 | rouge_2_temp, rouge_l_temp = compute_rouge_scores(pred, y_true) 51 | rouge_2 += rouge_2_temp 52 | rouge_l += rouge_l_temp 53 | 54 | loss /= len(eval_pairs) 55 | f1 /= len(eval_pairs) 56 | rouge_2 /= len(eval_pairs) 57 | rouge_l /= len(eval_pairs) 58 | 59 | print('{} loss: {}'.format(eval, loss)) 60 | print('{} f1_score: {}'.format(eval, f1)) 61 | print('{} rouge_2_score: {}'.format(eval, rouge_2)) 62 | print('{} rouge_l_score: {}'.format(eval, rouge_l)) 63 | 64 | return loss.item(), f1, rouge_2, rouge_l 65 | 66 | 67 | def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph, 68 | adj_tensor=None, node_features=None): 69 | """ 70 | Train model for a single iteration. 71 | """ 72 | optimizer.zero_grad() 73 | 74 | if graph: 75 | output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor, 76 | target=target_tensor.view(-1), node_features=node_features) 77 | else: 78 | output = seq2seq_model(sequence=input_tensor.view(-1), target=target_tensor.view(-1)) 79 | 80 | loss = criterion(output.view(-1, output.size(2)), target_tensor.view(-1)) 81 | pred = output.view(-1, output.size(2)).argmax(1).cpu().numpy() 82 | 83 | loss.backward() 84 | 85 | optimizer.step() 86 | 87 | return loss.item(), pred 88 | 89 | 90 | def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001, 91 | model_dir=None, lang=None, graph=False): 92 | """ 93 | Run complete training of the model. 
94 | """ 95 | train_losses = [] 96 | val_losses = [] 97 | 98 | val_f1_scores = [] 99 | val_rouge_2_scores = [] 100 | val_rouge_l_scores = [] 101 | 102 | print_loss_total = 0 # Reset every print_every 103 | plot_loss_total = 0 # Reset every plot_every 104 | f1 = 0 105 | rouge_2 = 0 106 | rouge_l = 0 107 | 108 | train_pairs, val_pairs, test_pairs = np.split(pairs, 109 | [int(.8 * len(pairs)), int(.9 * len(pairs))]) 110 | 111 | optimizer = optim.Adam(seq2seq_model.parameters(), lr=learning_rate) 112 | 113 | # Prepare data 114 | if graph: 115 | training_pairs = [tensors_from_pair_tokens_graph(random.choice(train_pairs), lang) 116 | for i in range(n_iters)] 117 | val_tensor_pairs = [tensors_from_pair_tokens_graph(val_pair, lang) for val_pair in val_pairs] 118 | else: 119 | training_pairs = [tensors_from_pair_tokens(random.choice(train_pairs), lang) 120 | for i in range(n_iters)] 121 | val_tensor_pairs = [tensors_from_pair_tokens(val_pair, lang) for val_pair in val_pairs] 122 | 123 | # test_tensor_pairs = [tensors_from_pair_tokens(test_pair, lang) for test_pair in test_pairs] 124 | criterion = nn.NLLLoss() 125 | 126 | # Train 127 | for iter in range(1, n_iters + 1): 128 | training_pair = training_pairs[iter - 1] 129 | if graph: 130 | input_tensor = training_pair[0][0].to(device) 131 | adj_tensor = training_pair[0][1].to(device) 132 | node_features = training_pair[0][2].to(device) 133 | target_tensor = training_pair[1].to(device) 134 | 135 | loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer, 136 | criterion, adj_tensor=adj_tensor, graph=graph, node_features=node_features) 137 | else: 138 | input_tensor = training_pair[0] 139 | target_tensor = training_pair[1] 140 | 141 | loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, 142 | graph=graph) 143 | 144 | print_loss_total += loss 145 | plot_loss_total += loss 146 | 147 | y_true = target_tensor.cpu().numpy().reshape(-1) 148 | 149 | if len(y_true) < len(pred): 150 | y_true = np.pad(y_true, (0, len(pred) - len(y_true)), mode='constant') 151 | else: 152 | pred = np.pad(pred, (0, len(y_true) - len(pred)), mode='constant') 153 | 154 | f1 += f1_score(y_true, pred, average='micro') 155 | rouge_2_temp, rouge_l_temp = compute_rouge_scores(pred, y_true) 156 | rouge_2 += rouge_2_temp 157 | rouge_l += rouge_l_temp 158 | 159 | if iter % print_every == 0: 160 | # Evaluate 161 | print_loss_avg = print_loss_total / print_every 162 | print_loss_total = 0 163 | print('train (%d %d%%) %.4f' % (iter, iter / n_iters * 100, print_loss_avg)) 164 | print('train f1_score: {}'.format(f1 / iter)) 165 | print('train rouge_2_score: {}'.format(rouge_2 / iter)) 166 | print('train rouge_l_score: {}'.format(rouge_l / iter)) 167 | 168 | train_loss = print_loss_avg 169 | val_loss, val_f1, val_rouge_2, val_rouge_l = evaluate(seq2seq_model, val_tensor_pairs, 170 | criterion, graph=graph) 171 | 172 | if not val_losses or val_loss < min(val_losses): 173 | torch.save(seq2seq_model.state_dict(), model_dir + 'model.pt') 174 | print("Saved updated model") 175 | 176 | train_losses.append(train_loss) 177 | val_losses.append(val_loss) 178 | # test_losses.append(test_loss) 179 | 180 | val_f1_scores.append(val_f1) 181 | val_rouge_2_scores.append(val_rouge_2) 182 | val_rouge_l_scores.append(val_rouge_l) 183 | 184 | # Store results 185 | results = {'train_losses': train_losses, 186 | 'val_losses': val_losses, 187 | 'val_f1_scores': val_f1_scores, 188 | 'val_rouge_2_scores': val_rouge_2_scores, 189 | 'val_rouge_l_scores': val_rouge_l_scores} 190 
191 | with open(model_dir + 'results.txt', 'w') as f: 192 | f.write(str(results)) 193 | with open(model_dir + 'results.pkl', 'wb') as f: 194 | pickle.dump(results, f) 195 | 196 | plot_loss(train_losses, val_losses, file_path=model_dir + 'loss.jpg') 197 | --------------------------------------------------------------------------------