├── images_readme ├── paper_model.png └── pyramidal_model.png ├── LICENSE ├── README.md ├── pretrained_glove_embedding_script.ipynb ├── test_dataset_script.ipynb ├── .gitignore ├── glove_training_preparation.ipynb ├── dataset_script.ipynb ├── paper_network.ipynb └── pyramidal_network.ipynb /images_readme/paper_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricivi/adversarial_training_methods/HEAD/images_readme/paper_model.png -------------------------------------------------------------------------------- /images_readme/pyramidal_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricivi/adversarial_training_methods/HEAD/images_readme/pyramidal_model.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Enrico Civitelli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Training Methods 2 | Adversarial and virtual adversarial training methods for semi-supervised text classification. 3 | 4 | Based on the paper: 5 | [Adversarial training methods for semi-supervised text classification, ICLR 2017, Miyato T., Dai A., Goodfellow I. 6 | ](https://arxiv.org/abs/1605.07725) 7 | 8 | **Without the pre-training phase** 9 | 10 | ## Requirements 11 | 12 | Package | Version 13 | :-------: | :-------: 14 | [Python](https://www.python.org/downloads/) | 3.5.4 15 | [Jupyter](http://jupyter.org/install) | 1.0.0 16 | [Tensorflow](https://www.tensorflow.org/versions/r1.5/) | r1.5 17 | [GloVe](https://nlp.stanford.edu/projects/glove/) | 1.2 18 | [NLTK](http://www.nltk.org/install.html) | 3.2.5 19 | [ProgressBar2](https://pypi.python.org/pypi/progressbar2) | 3.34.3 20 | [Matplotlib](https://matplotlib.org/2.1.2/index.html) | 2.1.2 21 | [Argparse](https://pypi.python.org/pypi/argparse/1.1) | 1.1 22 | 23 | 24 | ## Dataset creation 25 | 26 | Download the [IMDB](http://ai.stanford.edu/~amaas/data/sentiment/) dataset. 
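For reference, a minimal Python sketch for fetching and unpacking the dataset; the archive URL is assumed to be the `aclImdb_v1.tar.gz` file hosted on the page linked above, and the `./dataset/imdb/` target layout is an assumption inferred from the paths used in the notebooks:

```python
# Hypothetical helper: download and extract the IMDB reviews so the notebooks'
# paths (./dataset/imdb/train/{pos,neg,unsup}, ./dataset/imdb/test/{pos,neg}) resolve.
import os
import tarfile
import urllib.request

URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"  # assumed archive URL
archive = "./dataset/aclImdb_v1.tar.gz"

os.makedirs("./dataset", exist_ok=True)
if not os.path.exists(archive):
    urllib.request.urlretrieve(URL, archive)  # large download (~80 MB)

with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("./dataset")  # extracts to ./dataset/aclImdb

# rename so the ./dataset/imdb/... paths used by the scripts point at the data
if not os.path.exists("./dataset/imdb"):
    os.rename("./dataset/aclImdb", "./dataset/imdb")
```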
27 | 28 | Then perform the following steps: 29 | * prepare the dataset for GloVe training using _glove_training_preparation.ipynb_ 30 | * in the GloVe training script, set the vector length to 256 and remove words appearing in fewer than 3 reviews 31 | * run the GloVe training script 32 | * create the embedding matrix (used by the network) and the dictionary (used to convert reviews into sequences of indices) using _pretrained_glove_embedding_script.ipynb_ 33 | * convert the reviews into sequences of indices using _dataset_script.ipynb_ and _test_dataset_script.ipynb_ 34 | 35 | ## Training 36 | 37 | 5% of the training set is used for validation. 38 | 39 | Use _paper_network.ipynb_ or _pyramidal_network.ipynb_, depending on the network architecture you want to train. 40 | 41 | Paper model | Pyramidal model 42 | :---------: | :-------------: 43 | ![alt text](images_readme/paper_model.png "Paper model") | ![alt text](images_readme/pyramidal_model.png "Pyramidal model") 44 | 45 | ## Results 46 | 47 | Sequences are truncated to 1200 tokens (pyramidal model), 600, and 400 to test the model's sensitivity to review length. In particular, reviews are cut off, or zero-padded, at the beginning. 48 | 49 | Method | Seq. Length | Epochs | Accuracy 50 | :------: | :-----------: | :------: | :--------: 51 | baseline | 400 | 10 | 0.906 52 | adversarial | 400 | 10 | 0.914 53 | virtual adv. | 400 | 10 | **0.921** 54 | baseline | 600 | 10 | 0.904 55 | adversarial | 600 | 10 | 0.912 56 | pyramidal baseline | 1200 | 10 | 0.910 57 | pyramidal adversarial | 1200 | 10 | 0.916 58 | -------------------------------------------------------------------------------- /pretrained_glove_embedding_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "collapsed": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "# building dictionary word:word_embedding\n", 23 | "\n", 24 | "word_embedding = dict()\n", 25 | "dictionary = dict()\n", 26 | "vectors = open(\"./GloVe-1.2/vectors.txt\", \"r\")\n", 27 | "for i, line in enumerate(vectors):\n", 28 | " sline = line.split(' ')\n", 29 | " try:\n", 30 | " word = sline[0]\n", 31 | " word_embedding[word] = np.asarray(sline[1:], dtype='float32')\n", 32 | " dictionary[word] = i+1\n", 33 | " except:\n", 34 | " print(\"error at index {}\".format(i))" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "collapsed": true, 42 | "scrolled": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "# building embedding matrix (used by the network)\n", 47 | "embedding = np.zeros( shape=(len(dictionary.keys())+1, len(list(word_embedding.values())[0])))\n", 48 | "for word in dictionary.keys():\n", 49 | " idx = dictionary[word]\n", 50 | " embedding[idx] = word_embedding[word]" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "collapsed": true 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "np.save(\"./dataset/imdb/nltk_embedding_matrix.npy\", embedding)\n", 62 | "np.save(\"./dataset/imdb/nltk_dictionary.npy\", dictionary)\n", 63 | "np.save(\"./dataset/imdb/nltk_word_embedding.npy\", word_embedding)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 
null, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [] 74 | } 75 | ], 76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "Python 3", 79 | "language": "python", 80 | "name": "python3" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": "ipython3", 92 | "version": "3.6.2" 93 | } 94 | }, 95 | "nbformat": 4, 96 | "nbformat_minor": 2 97 | } 98 | -------------------------------------------------------------------------------- /test_dataset_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "from os import listdir\n", 13 | "from os.path import isfile, join\n", 14 | "from nltk import word_tokenize\n", 15 | "from nltk import RegexpTokenizer\n", 16 | "tokenizer = RegexpTokenizer(r'\\w+')" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 11, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "dictionary = np.load(\"./dataset/imdb/nltk_dictionary.npy\").item()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 12, 33 | "metadata": { 34 | "collapsed": true 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "pos_files = [f for f in listdir(\"./dataset/imdb/test/pos\") if ( isfile(join(\"./dataset/imdb/test/pos\", f)) and (f[0] != '.') ) ]\n", 39 | "neg_files = [f for f in listdir(\"./dataset/imdb/test/neg\") if ( isfile(join(\"./dataset/imdb/test/neg\", f)) and (f[0] != '.') ) ]" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 13, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "xtest = list()\n", 51 | "ytest = list()\n", 52 | "for p in pos_files:\n", 53 | " pos = open(\"./dataset/imdb/test/pos/\"+p, 'r')\n", 54 | " for line in pos:\n", 55 | " if line:\n", 56 | " line = line.replace('
<br /><br />
', ' ')\n", 57 | " tokens = tokenizer.tokenize(line.lower())\n", 58 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n", 59 | " \n", 60 | " xtest.append(tokens)\n", 61 | " ytest.append(1)\n", 62 | " \n", 63 | " else:\n", 64 | " print(\"empty line: {}.\".format(line))\n", 65 | " pos.close()\n", 66 | " \n", 67 | "for n in neg_files:\n", 68 | " neg = open(\"./dataset/imdb/test/neg/\"+n, 'r')\n", 69 | " for line in neg:\n", 70 | " if line:\n", 71 | " line = line.replace('
<br /><br />
', ' ')\n", 72 | " tokens = tokenizer.tokenize(line.lower())\n", 73 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n", 74 | " \n", 75 | " xtest.append(tokens)\n", 76 | " ytest.append(0)\n", 77 | " \n", 78 | " else:\n", 79 | " print(\"empty line: {}.\".format(line))\n", 80 | " neg.close()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 14, 86 | "metadata": { 87 | "collapsed": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "xtest = np.asarray(xtest)\n", 92 | "ytest = np.asarray(ytest)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 15, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "np.save(\"./dataset/imdb/nltk_xtest.npy\", xtest)\n", 104 | "np.save(\"./dataset/imdb/nltk_ytest.npy\", ytest)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "max( [len(s) for s in x] )" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.2" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/linux,macos,python,windows,pycharm 2 | 3 | ### Linux ### 4 | *~ 5 | 6 | # temporary files which can be created if a process still has a handle open of a deleted file 7 | .fuse_hidden* 8 | 9 | # KDE directory preferences 10 | .directory 11 | 12 | # Linux trash folder which might appear on any partition or disk 13 | .Trash-* 14 | 15 | # .nfs files are created when an open file is removed but is still being accessed 16 | .nfs* 17 | 18 | ### macOS ### 19 | *.DS_Store 20 | .AppleDouble 21 | .LSOverride 22 | 23 | # Icon must end with two \r 24 | Icon 25 | 26 | # Thumbnails 27 | ._* 28 | 29 | # Files that might appear in the root of a volume 30 | .DocumentRevisions-V100 31 | .fseventsd 32 | .Spotlight-V100 33 | .TemporaryItems 34 | .Trashes 35 | .VolumeIcon.icns 36 | .com.apple.timemachine.donotpresent 37 | 38 | # Directories potentially created on remote AFP share 39 | .AppleDB 40 | .AppleDesktop 41 | Network Trash Folder 42 | Temporary Items 43 | .apdisk 44 | 45 | ### PyCharm ### 46 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 47 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 48 | 49 | # User-specific stuff: 50 | .idea/**/workspace.xml 51 | .idea/**/tasks.xml 52 | .idea/dictionaries 53 | 54 | # Sensitive or high-churn files: 55 | .idea/**/dataSources/ 56 | .idea/**/dataSources.ids 57 | .idea/**/dataSources.xml 58 | .idea/**/dataSources.local.xml 59 | .idea/**/sqlDataSources.xml 60 | .idea/**/dynamic.xml 61 | .idea/**/uiDesigner.xml 62 | 63 | # Gradle: 64 | .idea/**/gradle.xml 65 | .idea/**/libraries 66 | 67 | # CMake 68 | cmake-build-debug/ 69 | 70 | # Mongo Explorer plugin: 71 | .idea/**/mongoSettings.xml 72 | 73 | 
## File-based project format: 74 | *.iws 75 | 76 | ## Plugin-specific files: 77 | 78 | # IntelliJ 79 | /out/ 80 | 81 | # mpeltonen/sbt-idea plugin 82 | .idea_modules/ 83 | 84 | # JIRA plugin 85 | atlassian-ide-plugin.xml 86 | 87 | # Cursive Clojure plugin 88 | .idea/replstate.xml 89 | 90 | # Ruby plugin and RubyMine 91 | /.rakeTasks 92 | 93 | # Crashlytics plugin (for Android Studio and IntelliJ) 94 | com_crashlytics_export_strings.xml 95 | crashlytics.properties 96 | crashlytics-build.properties 97 | fabric.properties 98 | 99 | ### PyCharm Patch ### 100 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 101 | 102 | # *.iml 103 | # modules.xml 104 | # .idea/misc.xml 105 | # *.ipr 106 | 107 | # Sonarlint plugin 108 | .idea/sonarlint 109 | 110 | ### Python ### 111 | # Byte-compiled / optimized / DLL files 112 | __pycache__/ 113 | *.py[cod] 114 | *$py.class 115 | 116 | # C extensions 117 | *.so 118 | 119 | # Distribution / packaging 120 | .Python 121 | build/ 122 | develop-eggs/ 123 | dist/ 124 | downloads/ 125 | eggs/ 126 | .eggs/ 127 | lib/ 128 | lib64/ 129 | parts/ 130 | sdist/ 131 | var/ 132 | wheels/ 133 | *.egg-info/ 134 | .installed.cfg 135 | *.egg 136 | 137 | # PyInstaller 138 | # Usually these files are written by a python script from a template 139 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 140 | *.manifest 141 | *.spec 142 | 143 | # Installer logs 144 | pip-log.txt 145 | pip-delete-this-directory.txt 146 | 147 | # Unit test / coverage reports 148 | htmlcov/ 149 | .tox/ 150 | .coverage 151 | .coverage.* 152 | .cache 153 | .pytest_cache/ 154 | nosetests.xml 155 | coverage.xml 156 | *.cover 157 | .hypothesis/ 158 | 159 | # Translations 160 | *.mo 161 | *.pot 162 | 163 | # Flask stuff: 164 | instance/ 165 | .webassets-cache 166 | 167 | # Scrapy stuff: 168 | .scrapy 169 | 170 | # Sphinx documentation 171 | docs/_build/ 172 | 173 | # PyBuilder 174 | target/ 175 | 176 | # Jupyter Notebook 177 | .ipynb_checkpoints 178 | 179 | # pyenv 180 | .python-version 181 | 182 | # celery beat schedule file 183 | celerybeat-schedule.* 184 | 185 | # SageMath parsed files 186 | *.sage.py 187 | 188 | # Environments 189 | .env 190 | .venv 191 | env/ 192 | venv/ 193 | ENV/ 194 | env.bak/ 195 | venv.bak/ 196 | 197 | # Spyder project settings 198 | .spyderproject 199 | .spyproject 200 | 201 | # Rope project settings 202 | .ropeproject 203 | 204 | # mkdocs documentation 205 | /site 206 | 207 | # mypy 208 | .mypy_cache/ 209 | 210 | ### Windows ### 211 | # Windows thumbnail cache files 212 | Thumbs.db 213 | ehthumbs.db 214 | ehthumbs_vista.db 215 | 216 | # Folder config file 217 | Desktop.ini 218 | 219 | # Recycle Bin used on file shares 220 | $RECYCLE.BIN/ 221 | 222 | # Windows Installer files 223 | *.cab 224 | *.msi 225 | *.msm 226 | *.msp 227 | 228 | # Windows shortcuts 229 | *.lnk 230 | 231 | 232 | # End of https://www.gitignore.io/api/linux,macos,python,windows,pycharm 233 | -------------------------------------------------------------------------------- /glove_training_preparation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# TAKE ALL REVIEWS(POS NEG AND UNSUP) AND MERGE THEM IN A SINGLE TEXT FILE WITH ONE REVIEW PER LINE (merged.txt)\n", 12 | "# PERFORM TEXT TOKENIZATION\n", 13 | "\n", 14 | "# REWRITE ALL THE 
TOKENIZED CORPUS AS A SINGLE WORDS SEQUENCE SEPARATED BY A BLANK SPACE (new_merged).\n", 15 | "# This file is used as input for GloVe training." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "collapsed": true 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "from os import listdir\n", 27 | "from os.path import isfile, join\n", 28 | "from nltk import RegexpTokenizer\n", 29 | "tokenizer = RegexpTokenizer(r'\\w+')" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "collapsed": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "# new_path = \"./\"\n", 41 | "old_path = \"./dataset/imdb/train/\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "collapsed": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "# support file \n", 53 | "merged = open(\"./merged.txt\", \"w\")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "current_directory = old_path + 'pos'\n", 65 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n", 66 | "for f in files:\n", 67 | " t = open(current_directory+'/'+f, 'r')\n", 68 | " for line in t:\n", 69 | " if line:\n", 70 | " line = line.replace('
<br /><br />
', ' ')\n", 71 | " merged.write(line+'\\n')\n", 72 | " t.close()\n", 73 | "\n", 74 | "current_directory = old_path + 'neg'\n", 75 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n", 76 | "for f in files:\n", 77 | " t = open(current_directory+'/'+f, 'r')\n", 78 | " for line in t:\n", 79 | " if line:\n", 80 | " line = line.replace('
<br /><br />
', ' ')\n", 81 | " merged.write(line+'\\n')\n", 82 | " t.close()\n", 83 | "\n", 84 | "current_directory = old_path + 'unsup'\n", 85 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n", 86 | "for f in files:\n", 87 | " t = open(current_directory+'/'+f, 'r')\n", 88 | " for line in t:\n", 89 | " if line:\n", 90 | " line = line.replace('
<br /><br />
', ' ')\n", 91 | " merged.write(line+'\\n')\n", 92 | " t.close()\n", 93 | "\n", 94 | "merged.close()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "collapsed": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "corpus = open(\"./merged.txt\", \"r\")\n", 106 | "new_corpus = open(\"./GloVe-1.2/new_merged\", \"w\")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "collapsed": true 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "lines = corpus.readlines()\n", 118 | "text = ''\n", 119 | "\n", 120 | "# put everithing in one line\n", 121 | "for line in lines:\n", 122 | " text = text + line\n", 123 | "corpus.close()\n", 124 | "\n", 125 | "# perform TOKENIZATION, returns a vector of words\n", 126 | "words = tokenizer.tokenize(text.lower())\n", 127 | "\n", 128 | "# rewrite this vector in a file in which words are separated by blank spaces\n", 129 | "for w in words:\n", 130 | " new_corpus.seek(0, 2)\n", 131 | " new_corpus.write(w + ' ')\n", 132 | "new_corpus.close()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.6.2" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 2 166 | } 167 | -------------------------------------------------------------------------------- /dataset_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "from os import listdir\n", 13 | "from os.path import isfile, join\n", 14 | "from nltk import word_tokenize\n", 15 | "from nltk import RegexpTokenizer\n", 16 | "tokenizer = RegexpTokenizer(r'\\w+')" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "dictionary = np.load(\"./dataset/imdb/nltk_dictionary.npy\").item()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "collapsed": true 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "pos_files = [f for f in listdir(\"./dataset/imdb/train/pos\") if ( isfile(join(\"./dataset/imdb/train/pos\", f)) and (f[0] != '.') ) ]\n", 39 | "neg_files = [f for f in listdir(\"./dataset/imdb/train/neg\") if ( isfile(join(\"./dataset/imdb/train/neg\", f)) and (f[0] != '.') ) ]\n", 40 | "unl_files = [f for f in listdir(\"./dataset/imdb/train/unsup\") if ( isfile(join(\"./dataset/imdb/train/unsup\", f)) and (f[0] != '.') ) ]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "print( \"Labeled...\" )\n", 50 | "xtrain = list()\n", 51 | "ytrain = list()\n", 52 | "for p in pos_files:\n", 53 | " pos = open(\"./dataset/imdb/train/pos/\"+p, 'r')\n", 54 | " for 
line in pos:\n", 55 | " if line:\n", 56 | " line = line.replace('
<br /><br />
', ' ')\n", 57 | " tokens = tokenizer.tokenize(line.lower())\n", 58 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n", 59 | " \n", 60 | " xtrain.append(tokens)\n", 61 | " ytrain.append(1) \n", 62 | " else:\n", 63 | " print(\"empty line: {}.\".format(line))\n", 64 | " pos.close()\n", 65 | " \n", 66 | "for n in neg_files:\n", 67 | " neg = open(\"./dataset/imdb/train/neg/\"+n, 'r')\n", 68 | " for line in neg:\n", 69 | " if line:\n", 70 | " line = line.replace('
<br /><br />
', ' ')\n", 71 | " tokens = tokenizer.tokenize(line.lower())\n", 72 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n", 73 | " \n", 74 | " xtrain.append(tokens)\n", 75 | " ytrain.append(0) \n", 76 | " else:\n", 77 | " print(\"empty line: {}.\".format(line))\n", 78 | " neg.close()\n", 79 | "\n", 80 | "print( \"Unlabeled...\" )\n", 81 | "unlab = list()\n", 82 | "for u in unl_files:\n", 83 | " unl = open(\"./dataset/imdb/train/unsup/\"+u, 'r')\n", 84 | " for line in unl:\n", 85 | " if line:\n", 86 | " line = line.replace('
<br /><br />
', ' ')\n", 87 | " tokens = tokenizer.tokenize(line.lower())\n", 88 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n", 89 | " \n", 90 | " unlab.append(tokens) \n", 91 | " else:\n", 92 | " print(\"empty line: {}.\".format(line))\n", 93 | " unl.close()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "xtrain = np.asarray(xtrain)\n", 105 | "ytrain = np.asarray(ytrain)\n", 106 | "unlab = np.asarray(unlab)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "collapsed": true 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "np.save(\"./dataset/imdb/nltk_xtrain.npy\", xtrain)\n", 118 | "np.save(\"./dataset/imdb/nltk_ytrain.npy\", ytrain)\n", 119 | "np.save(\"./dataset/imdb/nltk_ultrain.npy\", unlab)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "collapsed": true 127 | }, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.6.2" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /paper_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "from tensorflow import keras as K\n", 13 | "import numpy as np\n", 14 | "import argparse\n", 15 | "from progressbar import ProgressBar\n", 16 | "import os\n", 17 | "import matplotlib\n", 18 | "matplotlib.use('Agg')\n", 19 | "import matplotlib.pyplot as plt" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "collapsed": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "class Network:\n", 31 | " def __init__(self, session, dict_weight, dropout=0.2, lstm_units=1024, dense_units=30):\n", 32 | " self.sess = session\n", 33 | " K.backend.set_session(self.sess)\n", 34 | " #defining layers\n", 35 | " dict_shape = dict_weight.shape\n", 36 | " self.emb = K.layers.Embedding(dict_shape[0], dict_shape[1], weights=[dict_weight], trainable=False, name='embedding')\n", 37 | " self.drop = K.layers.Dropout(rate=dropout, seed=91, name='dropout')\n", 38 | " self.lstm = K.layers.LSTM(lstm_units, stateful=False, return_sequences=False, name='lstm')\n", 39 | " self.dense = K.layers.Dense(dense_units, activation='relu', name='dense')\n", 40 | " self.p = K.layers.Dense(1, activation='sigmoid', name='p')\n", 41 | " #defining optimizer\n", 42 | " self.optimizer = tf.train.AdamOptimizer(learning_rate=0.0005)\n", 43 | " # self.optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)\n", 44 | "\n", 45 | " def __call__(self, batch, perturbation=None):\n", 46 | " embedding = self.emb(batch) \n", 47 | " drop = self.drop(embedding)\n", 48 | " if (perturbation is not None):\n", 
49 | " drop += perturbation\n", 50 | " lstm = self.lstm(drop)\n", 51 | " dense = self.dense(lstm)\n", 52 | " return self.p(dense), embedding\n", 53 | " \n", 54 | " def get_minibatch(self, x, y, ul, batch_shape=(64, 400)):\n", 55 | " x = K.preprocessing.sequence.pad_sequences(x, maxlen=batch_shape[1])\n", 56 | " permutations = np.random.permutation( len(y) )\n", 57 | " ul_permutations = None\n", 58 | " len_ratio = None\n", 59 | " if (ul is not None):\n", 60 | " ul = K.preprocessing.sequence.pad_sequences(ul, maxlen=batch_shape[1])\n", 61 | " ul_permutations = np.random.permutation( len(ul) )\n", 62 | " len_ratio = len(ul)/len(y)\n", 63 | " for s in range(0, len(y), batch_shape[0]):\n", 64 | " perm = permutations[s:s+batch_shape[0]]\n", 65 | " minibatch = {'x': x[perm], 'y': y[perm]}\n", 66 | " if (ul is not None):\n", 67 | " ul_perm = ul_permutations[int(np.floor(len_ratio*s)):int(np.floor(len_ratio*(s+batch_shape[0])))]\n", 68 | " minibatch.update( {'ul': np.concatenate((ul[ul_perm], x[perm]), axis=0)} )\n", 69 | " yield minibatch\n", 70 | " \n", 71 | " def get_loss(self, batch, labels):\n", 72 | " pred, emb = self(batch)\n", 73 | " loss = K.losses.binary_crossentropy(labels, pred)\n", 74 | " return tf.reduce_mean( loss ), emb\n", 75 | " \n", 76 | " def get_adv_loss(self, batch, labels, loss, emb, p_mult):\n", 77 | " gradient = tf.gradients(loss, emb, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n", 78 | " p_adv = p_mult * tf.nn.l2_normalize(tf.stop_gradient(gradient), dim=1)\n", 79 | " adv_loss = K.losses.binary_crossentropy(labels, self(batch, p_adv)[0])\n", 80 | " return tf.reduce_mean( adv_loss )\n", 81 | " \n", 82 | " def get_v_adv_loss(self, ul_batch, p_mult, power_iterations=1):\n", 83 | " bernoulli = tf.distributions.Bernoulli\n", 84 | " prob, emb = self(ul_batch)\n", 85 | " prob = tf.clip_by_value(prob, 1e-7, 1.-1e-7)\n", 86 | " prob_dist = bernoulli(probs=prob)\n", 87 | " #generate virtual adversarial perturbation\n", 88 | " d = tf.random_uniform(shape=tf.shape(emb), dtype=tf.float32)\n", 89 | " for _ in range( power_iterations ):\n", 90 | " d = (0.02) * tf.nn.l2_normalize(d, dim=1)\n", 91 | " p_prob = tf.clip_by_value(self(ul_batch, d)[0], 1e-7, 1.-1e-7)\n", 92 | " kl = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False)\n", 93 | " gradient = tf.gradients(kl, [d], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n", 94 | " d = tf.stop_gradient(gradient)\n", 95 | " d = p_mult * tf.nn.l2_normalize(d, dim=1)\n", 96 | " tf.stop_gradient(prob)\n", 97 | " #virtual adversarial loss\n", 98 | " p_prob = tf.clip_by_value(self(ul_batch, d)[0], 1e-7, 1.-1e-7)\n", 99 | " v_adv_loss = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False)\n", 100 | " return tf.reduce_mean( v_adv_loss )\n", 101 | "\n", 102 | " def validation(self, x, y, batch_shape=(64, 400)):\n", 103 | " print( 'Validation...' 
)\n", 104 | " \n", 105 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='validation_labels')\n", 106 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='validation_batch')\n", 107 | "\n", 108 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 109 | " \n", 110 | " accuracies = list()\n", 111 | " minibatch = self.get_minibatch(x, y, ul=None, batch_shape=batch_shape)\n", 112 | " for val_batch in minibatch:\n", 113 | " fd = {batch: val_batch['x'], labels: val_batch['y'], K.backend.learning_phase(): 0} #test mode\n", 114 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n", 115 | " \n", 116 | " print( \"Average accuracy on validation is {:.3f}\".format(np.asarray(accuracies).mean()) )\n", 117 | " \n", 118 | " def train(self, dataset, batch_shape=(64, 400), epochs=10, loss_type='none', p_mult=0.02, init=None, save=None):\n", 119 | " print( 'Training...' )\n", 120 | " xtrain = np.load( \"{}nltk_xtrain.npy\".format(dataset) )\n", 121 | " ytrain = np.load( \"{}nltk_ytrain.npy\".format(dataset) )\n", 122 | " ultrain = np.load( \"{}nltk_ultrain.npy\".format(dataset) ) if (loss_type == 'v_adv') else None\n", 123 | " \n", 124 | " # defining validation set\n", 125 | " xval = list()\n", 126 | " yval = list()\n", 127 | " for _ in range( int(len(ytrain)*0.025) ):\n", 128 | " xval.append( xtrain[0] ); xval.append( xtrain[-1] )\n", 129 | " yval.append( ytrain[0] ); yval.append( ytrain[-1] )\n", 130 | " xtrain = np.delete(xtrain, 0); xtrain = np.delete(xtrain, -1)\n", 131 | " ytrain = np.delete(ytrain, 0); ytrain = np.delete(ytrain, -1)\n", 132 | " xval = np.asarray(xval)\n", 133 | " yval = np.asarray(yval)\n", 134 | " print( '{} elements in validation set'.format(len(yval)) )\n", 135 | " # ---\n", 136 | " yval = np.reshape(yval, newshape=(yval.shape[0], 1))\n", 137 | " ytrain = np.reshape(ytrain, newshape=(ytrain.shape[0], 1))\n", 138 | " \n", 139 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='train_labels')\n", 140 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='train_batch')\n", 141 | " ul_batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='ul_batch')\n", 142 | " \n", 143 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 144 | " loss, emb = self.get_loss(batch, labels)\n", 145 | " if (loss_type == 'adv'):\n", 146 | " loss += self.get_adv_loss(batch, labels, loss, emb, p_mult)\n", 147 | " elif (loss_type == 'v_adv'):\n", 148 | " loss += self.get_v_adv_loss(ul_batch, p_mult)\n", 149 | "\n", 150 | " opt = self.optimizer.minimize( loss )\n", 151 | " #initializing parameters\n", 152 | " if (init is None):\n", 153 | " self.sess.run( [var.initializer for var in tf.global_variables() if not('embedding' in var.name)] )\n", 154 | " print( 'Random initialization' )\n", 155 | " else:\n", 156 | " saver = tf.train.Saver()\n", 157 | " saver.restore(self.sess, init)\n", 158 | " print( 'Restored value' )\n", 159 | " \n", 160 | " _losses = list()\n", 161 | " _accuracies = list()\n", 162 | " list_ratio = (len(ultrain)/len(ytrain)) if (ultrain is not None) else None\n", 163 | " for epoch in range(epochs):\n", 164 | " losses = list()\n", 165 | " accuracies = list()\n", 166 | " validation = list()\n", 167 | " \n", 168 | " bar = ProgressBar(max_value=np.floor(len(ytrain)/batch_shape[0]).astype('i'))\n", 169 | " minibatch = enumerate(self.get_minibatch(xtrain, ytrain, ultrain, batch_shape=batch_shape))\n", 170 | " for i, train_batch in 
minibatch:\n", 171 | " fd = {batch: train_batch['x'], labels: train_batch['y'], K.backend.learning_phase(): 1} #training mode\n", 172 | " if (loss_type == 'v_adv'):\n", 173 | " fd.update( {ul_batch: train_batch['ul']} )\n", 174 | " \n", 175 | " _, acc_val, loss_val = self.sess.run([opt, accuracy, loss], feed_dict=fd)\n", 176 | " \n", 177 | " accuracies.append( acc_val )\n", 178 | " losses.append( loss_val )\n", 179 | " bar.update(i)\n", 180 | " \n", 181 | " #saving accuracies and losses\n", 182 | " _accuracies.append( accuracies )\n", 183 | " _losses.append(losses)\n", 184 | " \n", 185 | " log_msg = \"\\nEpoch {} of {} -- average accuracy is {:.3f} (train) -- average loss is {:.3f}\"\n", 186 | " print( log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean()) )\n", 187 | " \n", 188 | " # validation log\n", 189 | " self.validation(xval, yval, batch_shape=batch_shape)\n", 190 | " \n", 191 | " #saving model\n", 192 | " if (save is not None) and (epoch == (epochs-1)):\n", 193 | " saver = tf.train.Saver()\n", 194 | " saver.save(self.sess, save)\n", 195 | " print( 'model saved' )\n", 196 | " \n", 197 | " #plotting value\n", 198 | " #plt.plot([l for loss in _losses for l in loss], color='magenta', linestyle='dashed', marker='s', linewidth=1)\n", 199 | " plt.plot([np.asarray(l).mean() for l in _losses], color='red', linestyle='solid', marker='o', linewidth=2)\n", 200 | " #plt.plot([a for acc in _accuracies for a in acc], color='cyan', linestyle='dashed', marker='s', linewidth=1)\n", 201 | " plt.plot([np.asarray(a).mean() for a in _accuracies], color='blue', linestyle='solid', marker='o', linewidth=2)\n", 202 | " plt.savefig('./train_{}_e{}_m{}_l{}.png'.format(loss_type, epochs, batch_shape[0], batch_shape[1]))\n", 203 | " \n", 204 | " def test(self, dataset, batch_shape=(64, 400)):\n", 205 | " print( 'Test...' 
)\n", 206 | " xtest = np.load( \"{}nltk_xtest.npy\".format(dataset) )\n", 207 | " ytest = np.load( \"{}nltk_ytest.npy\".format(dataset) )\n", 208 | " ytest = np.reshape(ytest, newshape=(ytest.shape[0], 1))\n", 209 | " \n", 210 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='test_labels')\n", 211 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='test_batch')\n", 212 | "\n", 213 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 214 | " \n", 215 | " accuracies = list()\n", 216 | " bar = ProgressBar(max_value=np.floor(len(ytest)/batch_shape[0]).astype('i'))\n", 217 | " minibatch = enumerate(self.get_minibatch(xtest, ytest, ul=None, batch_shape=batch_shape))\n", 218 | " for i, test_batch in minibatch:\n", 219 | " fd = {batch: test_batch['x'], labels: test_batch['y'], K.backend.learning_phase(): 0} #test mode\n", 220 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n", 221 | " \n", 222 | " bar.update(i)\n", 223 | " \n", 224 | " print( \"\\nAverage accuracy is {:.3f}\".format(np.asarray(accuracies).mean()) )\n" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": { 231 | "collapsed": true 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "def main(data, n_epochs, n_ex, ex_len, lt, pm):\n", 236 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 237 | " config = tf.ConfigProto(log_device_placement=True)\n", 238 | " config.gpu_options.allow_growth = True\n", 239 | " session = tf.Session(config=config)\n", 240 | "\n", 241 | " embedding_weights = np.load( \"{}nltk_embedding_matrix.npy\".format(data) )\n", 242 | "\n", 243 | " net = Network(session, embedding_weights)\n", 244 | " net.train(data, batch_shape=(n_ex, ex_len), epochs=n_epochs, loss_type=lt, p_mult=pm, init=None, save=None)\n", 245 | " net.test(data, batch_shape=(n_ex, ex_len))\n", 246 | " \n", 247 | " K.backend.clear_session()" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "scrolled": false 255 | }, 256 | "outputs": [], 257 | "source": [ 258 | " main(data='../dataset/imdb/', n_epochs=10, n_ex=64, ex_len=400, lt='none', pm=0.02)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "'''\n", 270 | "if (__name__ == '__main__'):\n", 271 | " parser = argparse.ArgumentParser()\n", 272 | " parser.add_argument('path', help='path of the folder that contains the data (train, test, emb. 
matrix)', type=str)\n", 273 | " parser.add_argument('epochs', help='number of training epochs', type=int)\n", 274 | " parser.add_argument('--nex', help='Number of EXamples per batch', default=64, type=int)\n", 275 | " parser.add_argument('--exlen', help='LENght of each EXample', default=400, type=int)\n", 276 | " parser.add_argument('--loss', help='define the loss type', choices=['none', 'adv', 'v_adv'], default='none', type=str)\n", 277 | " parser.add_argument('--p_mult', help='Perturbation MULTiplier (used with adv and v_adv)', default=0.02, type=float)\n", 278 | " args = parser.parse_args()\n", 279 | " \n", 280 | " main(args.path, args.epochs, args.nex, args.exlen, args.loss, args.p_mult)\n", 281 | "'''" 282 | ] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "Python 3", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.6.4" 302 | } 303 | }, 304 | "nbformat": 4, 305 | "nbformat_minor": 2 306 | } 307 | -------------------------------------------------------------------------------- /pyramidal_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "from tensorflow import keras as K\n", 13 | "import numpy as np\n", 14 | "import argparse\n", 15 | "from progressbar import ProgressBar\n", 16 | "import os\n", 17 | "\n", 18 | "'''\n", 19 | "import matplotlib\n", 20 | "matplotlib.use('Agg')\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "'''\n", 23 | "\n", 24 | "from time import sleep" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "class Network:\n", 36 | " def __init__(self, session, dict_weight, dropout=0.3, lstm_units=1024, dense_units=60):\n", 37 | " self.sess = session\n", 38 | " K.backend.set_session(self.sess)\n", 39 | " #defining layers\n", 40 | " dict_shape = dict_weight.shape\n", 41 | " self.emb = K.layers.Embedding(dict_shape[0], dict_shape[1], weights=[dict_weight], trainable=False, name='embedding')\n", 42 | " self.drop = K.layers.Dropout(rate=dropout, seed=91, name='dropout')\n", 43 | " self.lstm = K.layers.LSTM(lstm_units, stateful=False, return_sequences=False, name='lstm')\n", 44 | " \n", 45 | " self.dense = K.layers.Dense(dense_units, activation='relu', name='dense')\n", 46 | " self.p = K.layers.Dense(1, activation='sigmoid', name='p')\n", 47 | " #defining optimizer\n", 48 | " self.optimizer = tf.train.AdamOptimizer(learning_rate=0.0005)\n", 49 | "\n", 50 | " def __call__(self, batch, perturbation=None): \n", 51 | " batch1 = batch[:,0:400]\n", 52 | " batch2 = batch[:,400:800]\n", 53 | " batch3 = batch[:,800:1200]\n", 54 | " \n", 55 | " embedding1 = self.emb(batch1)\n", 56 | " embedding2 = self.emb(batch2)\n", 57 | " embedding3 = self.emb(batch3)\n", 58 | " \n", 59 | " drop1 = self.drop(embedding1)\n", 60 | " drop2 = self.drop(embedding2)\n", 61 | " drop3 = self.drop(embedding3)\n", 62 | " \n", 63 | " if (perturbation is not None):\n", 64 | " drop1 += perturbation[0]\n", 65 | " drop2 += 
perturbation[1]\n", 66 | " drop3 += perturbation[2]\n", 67 | " \n", 68 | " lstm1 = self.lstm(drop1)\n", 69 | " lstm2 = self.lstm(drop2)\n", 70 | " lstm3 = self.lstm(drop3)\n", 71 | " lstm = tf.concat([lstm1, lstm2, lstm3], axis=1)\n", 72 | " dense = self.dense(lstm)\n", 73 | " \n", 74 | " return self.p(dense), (embedding1, embedding2, embedding3)\n", 75 | " \n", 76 | " def get_minibatch(self, x, y, ul, batch_shape=(16, 1200)):\n", 77 | " x = K.preprocessing.sequence.pad_sequences(x, maxlen=batch_shape[1])\n", 78 | " permutations = np.random.permutation( len(y) )\n", 79 | " ul_permutations = None\n", 80 | " len_ratio = None\n", 81 | " if (ul is not None):\n", 82 | " ul = K.preprocessing.sequence.pad_sequences(ul, maxlen=batch_shape[1])\n", 83 | " ul_permutations = np.random.permutation(len(ul))\n", 84 | " len_ratio = len(ul)/len(y)\n", 85 | " for s in range(0, len(y), batch_shape[0]):\n", 86 | " perm = permutations[s:s+batch_shape[0]]\n", 87 | " minibatch = {'x': x[perm], 'y': y[perm]}\n", 88 | " if (ul is not None):\n", 89 | " ul_perm = ul_permutations[int(np.floor(len_ratio*s)):int(np.floor(len_ratio*(s+batch_shape[0])))]\n", 90 | " minibatch.update( {'ul': np.concatenate((ul[ul_perm], x[perm]), axis=0)} ) \n", 91 | " yield minibatch\n", 92 | " \n", 93 | " def get_loss(self, batch, labels):\n", 94 | " pred, emb = self(batch)\n", 95 | " loss = K.losses.binary_crossentropy(labels, pred)\n", 96 | " return tf.reduce_mean( loss ), emb\n", 97 | " \n", 98 | " def get_adv_loss(self, batch, labels, loss, emb, p_mult):\n", 99 | " g1 = tf.gradients(loss, emb[0], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n", 100 | " g2 = tf.gradients(loss, emb[1], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n", 101 | " g3 = tf.gradients(loss, emb[2], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n", 102 | " p_adv = list()\n", 103 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g1), dim=1) )\n", 104 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g2), dim=1) )\n", 105 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g3), dim=1) )\n", 106 | " adv_loss = K.losses.binary_crossentropy(labels, self(batch, p_adv)[0])\n", 107 | " return tf.reduce_mean( adv_loss )\n", 108 | " \n", 109 | " def validate(self, xval, yval, batch_shape=(64, 1200), log_path=None):\n", 110 | " print( 'Validation...' 
)\n", 111 | " \n", 112 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n", 113 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n", 114 | "\n", 115 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 116 | " accuracies = list()\n", 117 | " bar = ProgressBar(max_value=np.floor(len(yval)/batch_shape[0]).astype('i'))\n", 118 | " minibatch = enumerate(self.get_minibatch(xval, yval, None, batch_shape=batch_shape))\n", 119 | " for i, val_batch in minibatch:\n", 120 | " fd = {batch: val_batch['x'], labels: val_batch['y'], K.backend.learning_phase(): 0} #test mode\n", 121 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n", 122 | " bar.update(i)\n", 123 | " \n", 124 | " log_msg = \"\\nValidation Average accuracy is {:.3f} -- batch shape {}\"\n", 125 | " print( log_msg.format(np.asarray(accuracies).mean(), batch_shape) )\n", 126 | " log = None\n", 127 | " if(log_path is not None):\n", 128 | " log = open(log_path, 'a')\n", 129 | " log.write(log_msg.format(np.asarray(accuracies).mean(), batch_shape)+'\\n')\n", 130 | " log.close()\n", 131 | " \n", 132 | " def train(self, dataset, batch_shape=(64, 1200), epochs=10, loss_type='none', p_mult=0.02, save=None, log_path=None):\n", 133 | " print( 'Training...' )\n", 134 | " xtrain = np.load( \"{}nltk_xtrain.npy\".format(dataset) )\n", 135 | " ytrain = np.load( \"{}nltk_ytrain.npy\".format(dataset) )\n", 136 | " \n", 137 | " xval = list()\n", 138 | " yval = list()\n", 139 | " for i in range(int(len(ytrain)*0.025)):\n", 140 | " xval.append( xtrain[0] ); xval.append( xtrain[-1] )\n", 141 | " yval.append( ytrain[0] ); yval.append( ytrain[-1] )\n", 142 | " xtrain = np.delete(xtrain, 0); xtrain = np.delete(xtrain, -1)\n", 143 | " ytrain = np.delete(ytrain, 0); ytrain = np.delete(ytrain, -1)\n", 144 | " xval = np.asarray(xval)\n", 145 | " yval = np.asarray(yval)\n", 146 | " \n", 147 | " yval = np.reshape(yval, newshape=(yval.shape[0],1))\n", 148 | " ytrain = np.reshape(ytrain, newshape=(ytrain.shape[0], 1))\n", 149 | " \n", 150 | " ultrain = np.load( \"{}nltk_ultrain.npy\".format(dataset) ) if (loss_type == 'v_adv') else None \n", 151 | " \n", 152 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n", 153 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n", 154 | " ul_batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='ul_batch')\n", 155 | " \n", 156 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 157 | " loss, emb = self.get_loss(batch, labels)\n", 158 | " if (loss_type == 'adv'):\n", 159 | " loss += self.get_adv_loss(batch, labels, loss, emb, p_mult)\n", 160 | " elif (loss_type == 'v_adv'):\n", 161 | " loss += self.get_v_adv_loss(ul_batch, p_mult)\n", 162 | "\n", 163 | " opt = self.optimizer.minimize( loss )\n", 164 | " #initializing parameters\n", 165 | " self.sess.run( [var.initializer for var in tf.global_variables() if not('embedding' in var.name)] )\n", 166 | " \n", 167 | " _losses = list()\n", 168 | " _accuracies = list()\n", 169 | " log = None\n", 170 | " \n", 171 | " list_ratio = (len(ultrain)/len(ytrain)) if (ultrain is not None) else None\n", 172 | " for epoch in range(epochs):\n", 173 | " losses = list()\n", 174 | " accuracies = list()\n", 175 | " \n", 176 | " bar = ProgressBar(max_value=np.floor(len(ytrain)/batch_shape[0]).astype('i'))\n", 177 | " minibatch = enumerate(self.get_minibatch(xtrain, ytrain, ultrain, 
batch_shape=batch_shape))\n", 178 | " for i, train_batch in minibatch:\n", 179 | " fd = {batch: train_batch['x'], labels: train_batch['y'], K.backend.learning_phase(): 1} #training mode\n", 180 | " if (loss_type == 'v_adv'):\n", 181 | " fd.update( {ul_batch: train_batch['ul']} )\n", 182 | " \n", 183 | " _, acc_val, loss_val = self.sess.run([opt, accuracy, loss], feed_dict=fd)\n", 184 | " \n", 185 | " accuracies.append( acc_val )\n", 186 | " losses.append( loss_val )\n", 187 | " bar.update(i)\n", 188 | " \n", 189 | " _losses.append(losses)\n", 190 | " _accuracies.append(accuracies)\n", 191 | " \n", 192 | " log_msg = \"\\nEpoch {} of {} -- average accuracy is {:.3f} -- average loss is {:.3f}\"\n", 193 | " print( log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean()) )\n", 194 | " if(log_path is not None):\n", 195 | " log = open(log_path, 'a')\n", 196 | " log.write(log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean())+'\\n')\n", 197 | " log.close()\n", 198 | " \n", 199 | " #validation\n", 200 | " self.validate(xval, yval, batch_shape=batch_shape, log_path=log_path)\n", 201 | " \n", 202 | " def test(self, dataset, batch_shape=(64, 1200), log_path=None):\n", 203 | " print( 'Test...' )\n", 204 | " xtest = np.load( \"{}nltk_xtest.npy\".format(dataset) )\n", 205 | " ytest = np.load( \"{}nltk_ytest.npy\".format(dataset) )\n", 206 | " ytest = np.reshape(ytest, newshape=(ytest.shape[0], 1))\n", 207 | " \n", 208 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n", 209 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n", 210 | "\n", 211 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n", 212 | " \n", 213 | " accuracies = list()\n", 214 | " bar = ProgressBar(max_value=np.floor(len(ytest)/batch_shape[0]).astype('i'))\n", 215 | " minibatch = enumerate(self.get_minibatch(xtest, ytest, None, batch_shape=batch_shape))\n", 216 | " for i, test_batch in minibatch:\n", 217 | " fd = {batch: test_batch['x'], labels: test_batch['y'], K.backend.learning_phase(): 0} #test mode\n", 218 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n", 219 | " bar.update(i)\n", 220 | " \n", 221 | " log_msg = \"\\nTest Average accuracy is {:.3f} -- batch shape {}\"\n", 222 | " print( log_msg.format(np.asarray(accuracies).mean(), batch_shape) )\n", 223 | " \n", 224 | " log = None\n", 225 | " if(log_path is not None):\n", 226 | " log = open(log_path, 'a')\n", 227 | " log.write(log_msg.format(np.asarray(accuracies).mean(), batch_shape)+'\\n')\n", 228 | " log.close()" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "collapsed": true 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "def main(data, n_epochs, n_ex, ex_len, lt, pm):\n", 240 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 241 | " config = tf.ConfigProto(log_device_placement=True)\n", 242 | " config.gpu_options.allow_growth = True\n", 243 | " session = tf.Session(config=config)\n", 244 | "\n", 245 | " embedding_weights = np.load( \"{}nltk_embedding_matrix.npy\".format(data) )\n", 246 | "\n", 247 | " # log text file\n", 248 | " net_tp = 'baseline' if (lt == 'none') else lt\n", 249 | " log_path = './pyramidal_{}_bs_{}_ep_{}_sl_{}.txt'.format(net_tp, n_ex, n_epochs, ex_len)\n", 250 | " log = open(log_path, 'w')\n", 251 | " log.write('Fancy Network with {} loss, {} epochs, {} batch size, {} maximum string length \\n'.format(net_tp, 
n_epochs, n_ex, ex_len))\n", 252 | " log.close()\n", 253 | " \n", 254 | " net = Network(session, embedding_weights)\n", 255 | " net.train(data, batch_shape=(n_ex, ex_len), epochs=n_epochs, loss_type=lt, p_mult=pm, log_path=log_path)\n", 256 | " net.test(data, batch_shape=(n_ex, ex_len), log_path=log_path)\n", 257 | " \n", 258 | " K.backend.clear_session()" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "main(data='../dataset/imdb/', n_epochs=10, n_ex=16, ex_len=1200, lt='none', pm=0.02) # works only with ex_len=1200" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": { 276 | "collapsed": true 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "'''\n", 281 | "if (__name__ == '__main__'):\n", 282 | " parser = argparse.ArgumentParser()\n", 283 | " parser.add_argument('path', help='path of the folder that contains the data (train, test, emb. matrix)', type=str)\n", 284 | " parser.add_argument('epochs', help='number of training epochs', type=int)\n", 285 | " parser.add_argument('--nex', help='Number of EXamples per batch', default=64, type=int)\n", 286 | " parser.add_argument('--exlen', help='LENght of each EXample', default=1200, type=int)\n", 287 | " parser.add_argument('--loss', help='define the loss type', choices=['none', 'adv'], default='none', type=str)\n", 288 | " parser.add_argument('--p_mult', help='Perturbation MULTiplier (used with adv)', default=0.02, type=float)\n", 289 | " args = parser.parse_args()\n", 290 | " \n", 291 | " main(args.path, args.epochs, args.nex, args.exlen, args.loss, args.p_mult)\n", 292 | "'''" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.6.4" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 2 317 | } 318 | --------------------------------------------------------------------------------