├── images_readme
│   ├── paper_model.png
│   └── pyramidal_model.png
├── LICENSE
├── README.md
├── pretrained_glove_embedding_script.ipynb
├── test_dataset_script.ipynb
├── .gitignore
├── glove_training_preparation.ipynb
├── dataset_script.ipynb
├── paper_network.ipynb
└── pyramidal_network.ipynb
/images_readme/paper_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enricivi/adversarial_training_methods/HEAD/images_readme/paper_model.png
--------------------------------------------------------------------------------
/images_readme/pyramidal_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enricivi/adversarial_training_methods/HEAD/images_readme/pyramidal_model.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Enrico Civitelli
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Training Methods
2 | Adversarial and virtual adversarial training methods for semi-supervised text classification.
3 |
4 | Based on the paper:
 5 | [Adversarial training methods for semi-supervised text classification, ICLR 2017, Miyato T., Dai A., Goodfellow I.](https://arxiv.org/abs/1605.07725)
 6 |
7 |
 8 | **Note: the pre-training phase described in the paper is not implemented.**
9 |
10 | ## Requirements
11 |
12 | Package | Version
13 | :-------: | :-------:
14 | [Python](https://www.python.org/downloads/) | 3.5.4
15 | [Jupyter](http://jupyter.org/install) | 1.0.0
16 | [Tensorflow](https://www.tensorflow.org/versions/r1.5/) | r1.5
17 | [GloVe](https://nlp.stanford.edu/projects/glove/) | 1.2
18 | [NLTK](http://www.nltk.org/install.html) | 3.2.5
19 | [ProgressBar2](https://pypi.python.org/pypi/progressbar2) | 3.34.3
20 | [Matplotlib](https://matplotlib.org/2.1.2/index.html) | 2.1.2
21 | [Argparse](https://pypi.python.org/pypi/argparse/1.1) | 1.1
22 |
23 |
24 | ## Dataset creation
25 |
26 | Download the [IMDB](http://ai.stanford.edu/~amaas/data/sentiment/) dataset.
27 |
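A minimal download-and-extract sketch (not part of the repository scripts); the tarball name `aclImdb_v1.tar.gz` comes from the dataset page, and `./dataset/imdb/` is the folder layout the notebooks expect:

```python
import os
import shutil
import tarfile
import urllib.request

# tarball name assumed from the dataset page linked above
url = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
urllib.request.urlretrieve(url, "aclImdb_v1.tar.gz")

with tarfile.open("aclImdb_v1.tar.gz") as tar:
    tar.extractall(".")                         # extracts to ./aclImdb/

os.makedirs("./dataset", exist_ok=True)
shutil.move("./aclImdb", "./dataset/imdb")      # rename to the path used by the notebooks
```
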
28 | Then perform the following steps:
29 | * prepare the dataset for GloVe training using _glove_training_preparation.ipynb_
30 | * in the GloVe training script, set the vector length to 256 and remove words appearing in fewer than 3 reviews
31 | * run the GloVe training script
32 | * create the embedding matrix (used by the network) and the dictionary (used to convert reviews into sequences of indices) using _pretrained_glove_embedding_script.ipynb_
33 | * convert reviews into sequences of indices using _dataset_script.ipynb_ and _test_dataset_script.ipynb_ (a condensed sketch of this step follows)
34 |
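A condensed sketch of the conversion step, mirroring what _dataset_script.ipynb_ and _test_dataset_script.ipynb_ do (the unknown-token fallback index used here is illustrative; the notebooks look up a dedicated dictionary entry instead):

```python
import numpy as np
from nltk import RegexpTokenizer

tokenizer = RegexpTokenizer(r'\w+')
# word -> index mapping produced by pretrained_glove_embedding_script.ipynb
dictionary = np.load("./dataset/imdb/nltk_dictionary.npy").item()

def review_to_indices(text):
    text = text.replace('<br />', ' ')               # IMDB reviews contain HTML line breaks
    tokens = tokenizer.tokenize(text.lower())
    return [dictionary.get(t, 0) for t in tokens]    # 0 = all-zero row of the embedding matrix

indices = review_to_indices("This movie was great!<br />Loved it.")
```
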
35 | ## Training
36 |
37 | 5% of the training set is used for validation.
38 |
39 | Use _paper_network.ipynb_ or _pyramidal_network.ipynb_ depending on the network architecture you want to train (a launch example follows the table below).
40 |
41 | Paper model | Pyramidal model
42 | :---------: | :-------------:
43 | ![Paper model](images_readme/paper_model.png) | ![Pyramidal model](images_readme/pyramidal_model.png)
44 |
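Each notebook ends with a call to `main`, which builds the network, trains it (printing validation accuracy after every epoch), and evaluates on the test set. The baseline paper model on 400-token sequences is launched as in the last cells of _paper_network.ipynb_:

```python
# loss type: 'none' = baseline, 'adv' = adversarial, 'v_adv' = virtual adversarial
main(data='../dataset/imdb/', n_epochs=10, n_ex=64, ex_len=400, lt='none', pm=0.02)

# adversarial / virtual adversarial training only change lt (and reuse p_mult), e.g.
# main(data='../dataset/imdb/', n_epochs=10, n_ex=64, ex_len=400, lt='adv', pm=0.02)
```
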
45 | ## Results
46 |
47 | Sequences are truncated at 1200 tokens (pyramidal model), 600, and 400 to test the model's sensitivity to review length. In particular, reviews are cut off, or zero-padded, at the beginning, as in the sketch below.
48 |
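Truncation and padding are handled by Keras' `pad_sequences`, which by default pads and truncates at the start of each sequence (a small illustration, not taken from the notebooks):

```python
from tensorflow import keras as K

seqs = [[5, 8, 2], [7, 1, 4, 9, 3, 6, 2]]
# maxlen=5: short reviews get zeros prepended, long reviews lose their initial tokens
print(K.preprocessing.sequence.pad_sequences(seqs, maxlen=5))
# [[0 0 5 8 2]
#  [4 9 3 6 2]]
```
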
49 | Method | Seq. Length | Epochs | Accuracy
50 | :------: | :-----------: | :------: | :--------:
51 | baseline | 400 | 10 | 0.906
52 | adversarial | 400 | 10 | 0.914
53 | virtual adv. | 400 | 10 | **0.921**
54 | baseline | 600 | 10 | 0.904
55 | adversarial | 600 | 10 | 0.912
56 | pyramidal baseline | 1200 | 10 | 0.910
57 | pyramidal adversarial | 1200 | 10 | 0.916
58 |
--------------------------------------------------------------------------------
/pretrained_glove_embedding_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {
18 | "collapsed": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "# building dictionary word:word_embedding\n",
23 | "\n",
24 | "word_embedding = dict()\n",
25 | "dictionary = dict()\n",
26 | "vectors = open(\"./GloVe-1.2/vectors.txt\", \"r\")\n",
27 | "for i, line in enumerate(vectors):\n",
28 | " sline = line.split(' ')\n",
29 | " try:\n",
30 | " word = sline[0]\n",
31 | " word_embedding[word] = np.asarray(sline[1:], dtype='float32')\n",
32 | " dictionary[word] = i+1\n",
33 | " except:\n",
34 | " print(\"error at index {}\".format(i))"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "collapsed": true,
42 | "scrolled": true
43 | },
44 | "outputs": [],
45 | "source": [
46 | "# building embedding matrix (used by the network)\n",
47 | "embedding = np.zeros( shape=(len(dictionary.keys())+1, len(list(word_embedding.values())[0])))\n",
48 | "for word in dictionary.keys():\n",
49 | " idx = dictionary[word]\n",
50 | " embedding[idx] = word_embedding[word]"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "collapsed": true
58 | },
59 | "outputs": [],
60 | "source": [
61 | "np.save(\"./dataset/imdb/nltk_embedding_matrix.npy\", embedding)\n",
62 | "np.save(\"./dataset/imdb/nltk_dictionary.npy\", dictionary)\n",
63 | "np.save(\"./dataset/imdb/nltk_word_embedding.npy\", word_embedding)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": []
74 | }
75 | ],
76 | "metadata": {
77 | "kernelspec": {
78 | "display_name": "Python 3",
79 | "language": "python",
80 | "name": "python3"
81 | },
82 | "language_info": {
83 | "codemirror_mode": {
84 | "name": "ipython",
85 | "version": 3
86 | },
87 | "file_extension": ".py",
88 | "mimetype": "text/x-python",
89 | "name": "python",
90 | "nbconvert_exporter": "python",
91 | "pygments_lexer": "ipython3",
92 | "version": "3.6.2"
93 | }
94 | },
95 | "nbformat": 4,
96 | "nbformat_minor": 2
97 | }
98 |
--------------------------------------------------------------------------------
/test_dataset_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 10,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "from os import listdir\n",
13 | "from os.path import isfile, join\n",
14 | "from nltk import word_tokenize\n",
15 | "from nltk import RegexpTokenizer\n",
16 | "tokenizer = RegexpTokenizer(r'\\w+')"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 11,
22 | "metadata": {
23 | "collapsed": true
24 | },
25 | "outputs": [],
26 | "source": [
27 | "dictionary = np.load(\"./dataset/imdb/nltk_dictionary.npy\").item()"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 12,
33 | "metadata": {
34 | "collapsed": true
35 | },
36 | "outputs": [],
37 | "source": [
38 | "pos_files = [f for f in listdir(\"./dataset/imdb/test/pos\") if ( isfile(join(\"./dataset/imdb/test/pos\", f)) and (f[0] != '.') ) ]\n",
39 | "neg_files = [f for f in listdir(\"./dataset/imdb/test/neg\") if ( isfile(join(\"./dataset/imdb/test/neg\", f)) and (f[0] != '.') ) ]"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 13,
45 | "metadata": {
46 | "collapsed": true
47 | },
48 | "outputs": [],
49 | "source": [
50 | "xtest = list()\n",
51 | "ytest = list()\n",
52 | "for p in pos_files:\n",
53 | " pos = open(\"./dataset/imdb/test/pos/\"+p, 'r')\n",
54 | " for line in pos:\n",
55 | " if line:\n",
56 | "                line = line.replace('<br />', ' ')\n",
57 | " tokens = tokenizer.tokenize(line.lower())\n",
58 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n",
59 | " \n",
60 | " xtest.append(tokens)\n",
61 | " ytest.append(1)\n",
62 | " \n",
63 | " else:\n",
64 | " print(\"empty line: {}.\".format(line))\n",
65 | " pos.close()\n",
66 | " \n",
67 | "for n in neg_files:\n",
68 | " neg = open(\"./dataset/imdb/test/neg/\"+n, 'r')\n",
69 | " for line in neg:\n",
70 | " if line:\n",
71 | "                line = line.replace('<br />', ' ')\n",
72 | " tokens = tokenizer.tokenize(line.lower())\n",
73 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n",
74 | " \n",
75 | " xtest.append(tokens)\n",
76 | " ytest.append(0)\n",
77 | " \n",
78 | " else:\n",
79 | " print(\"empty line: {}.\".format(line))\n",
80 | " neg.close()"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 14,
86 | "metadata": {
87 | "collapsed": true
88 | },
89 | "outputs": [],
90 | "source": [
91 | "xtest = np.asarray(xtest)\n",
92 | "ytest = np.asarray(ytest)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 15,
98 | "metadata": {
99 | "collapsed": true
100 | },
101 | "outputs": [],
102 | "source": [
103 | "np.save(\"./dataset/imdb/nltk_xtest.npy\", xtest)\n",
104 | "np.save(\"./dataset/imdb/nltk_ytest.npy\", ytest)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {
111 | "collapsed": true
112 | },
113 | "outputs": [],
114 | "source": [
115 | "max( [len(s) for s in x] )"
116 | ]
117 | }
118 | ],
119 | "metadata": {
120 | "kernelspec": {
121 | "display_name": "Python 3",
122 | "language": "python",
123 | "name": "python3"
124 | },
125 | "language_info": {
126 | "codemirror_mode": {
127 | "name": "ipython",
128 | "version": 3
129 | },
130 | "file_extension": ".py",
131 | "mimetype": "text/x-python",
132 | "name": "python",
133 | "nbconvert_exporter": "python",
134 | "pygments_lexer": "ipython3",
135 | "version": "3.6.2"
136 | }
137 | },
138 | "nbformat": 4,
139 | "nbformat_minor": 2
140 | }
141 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/linux,macos,python,windows,pycharm
2 |
3 | ### Linux ###
4 | *~
5 |
6 | # temporary files which can be created if a process still has a handle open of a deleted file
7 | .fuse_hidden*
8 |
9 | # KDE directory preferences
10 | .directory
11 |
12 | # Linux trash folder which might appear on any partition or disk
13 | .Trash-*
14 |
15 | # .nfs files are created when an open file is removed but is still being accessed
16 | .nfs*
17 |
18 | ### macOS ###
19 | *.DS_Store
20 | .AppleDouble
21 | .LSOverride
22 |
23 | # Icon must end with two \r
24 | Icon
25 |
26 | # Thumbnails
27 | ._*
28 |
29 | # Files that might appear in the root of a volume
30 | .DocumentRevisions-V100
31 | .fseventsd
32 | .Spotlight-V100
33 | .TemporaryItems
34 | .Trashes
35 | .VolumeIcon.icns
36 | .com.apple.timemachine.donotpresent
37 |
38 | # Directories potentially created on remote AFP share
39 | .AppleDB
40 | .AppleDesktop
41 | Network Trash Folder
42 | Temporary Items
43 | .apdisk
44 |
45 | ### PyCharm ###
46 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
47 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
48 |
49 | # User-specific stuff:
50 | .idea/**/workspace.xml
51 | .idea/**/tasks.xml
52 | .idea/dictionaries
53 |
54 | # Sensitive or high-churn files:
55 | .idea/**/dataSources/
56 | .idea/**/dataSources.ids
57 | .idea/**/dataSources.xml
58 | .idea/**/dataSources.local.xml
59 | .idea/**/sqlDataSources.xml
60 | .idea/**/dynamic.xml
61 | .idea/**/uiDesigner.xml
62 |
63 | # Gradle:
64 | .idea/**/gradle.xml
65 | .idea/**/libraries
66 |
67 | # CMake
68 | cmake-build-debug/
69 |
70 | # Mongo Explorer plugin:
71 | .idea/**/mongoSettings.xml
72 |
73 | ## File-based project format:
74 | *.iws
75 |
76 | ## Plugin-specific files:
77 |
78 | # IntelliJ
79 | /out/
80 |
81 | # mpeltonen/sbt-idea plugin
82 | .idea_modules/
83 |
84 | # JIRA plugin
85 | atlassian-ide-plugin.xml
86 |
87 | # Cursive Clojure plugin
88 | .idea/replstate.xml
89 |
90 | # Ruby plugin and RubyMine
91 | /.rakeTasks
92 |
93 | # Crashlytics plugin (for Android Studio and IntelliJ)
94 | com_crashlytics_export_strings.xml
95 | crashlytics.properties
96 | crashlytics-build.properties
97 | fabric.properties
98 |
99 | ### PyCharm Patch ###
100 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
101 |
102 | # *.iml
103 | # modules.xml
104 | # .idea/misc.xml
105 | # *.ipr
106 |
107 | # Sonarlint plugin
108 | .idea/sonarlint
109 |
110 | ### Python ###
111 | # Byte-compiled / optimized / DLL files
112 | __pycache__/
113 | *.py[cod]
114 | *$py.class
115 |
116 | # C extensions
117 | *.so
118 |
119 | # Distribution / packaging
120 | .Python
121 | build/
122 | develop-eggs/
123 | dist/
124 | downloads/
125 | eggs/
126 | .eggs/
127 | lib/
128 | lib64/
129 | parts/
130 | sdist/
131 | var/
132 | wheels/
133 | *.egg-info/
134 | .installed.cfg
135 | *.egg
136 |
137 | # PyInstaller
138 | # Usually these files are written by a python script from a template
139 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
140 | *.manifest
141 | *.spec
142 |
143 | # Installer logs
144 | pip-log.txt
145 | pip-delete-this-directory.txt
146 |
147 | # Unit test / coverage reports
148 | htmlcov/
149 | .tox/
150 | .coverage
151 | .coverage.*
152 | .cache
153 | .pytest_cache/
154 | nosetests.xml
155 | coverage.xml
156 | *.cover
157 | .hypothesis/
158 |
159 | # Translations
160 | *.mo
161 | *.pot
162 |
163 | # Flask stuff:
164 | instance/
165 | .webassets-cache
166 |
167 | # Scrapy stuff:
168 | .scrapy
169 |
170 | # Sphinx documentation
171 | docs/_build/
172 |
173 | # PyBuilder
174 | target/
175 |
176 | # Jupyter Notebook
177 | .ipynb_checkpoints
178 |
179 | # pyenv
180 | .python-version
181 |
182 | # celery beat schedule file
183 | celerybeat-schedule.*
184 |
185 | # SageMath parsed files
186 | *.sage.py
187 |
188 | # Environments
189 | .env
190 | .venv
191 | env/
192 | venv/
193 | ENV/
194 | env.bak/
195 | venv.bak/
196 |
197 | # Spyder project settings
198 | .spyderproject
199 | .spyproject
200 |
201 | # Rope project settings
202 | .ropeproject
203 |
204 | # mkdocs documentation
205 | /site
206 |
207 | # mypy
208 | .mypy_cache/
209 |
210 | ### Windows ###
211 | # Windows thumbnail cache files
212 | Thumbs.db
213 | ehthumbs.db
214 | ehthumbs_vista.db
215 |
216 | # Folder config file
217 | Desktop.ini
218 |
219 | # Recycle Bin used on file shares
220 | $RECYCLE.BIN/
221 |
222 | # Windows Installer files
223 | *.cab
224 | *.msi
225 | *.msm
226 | *.msp
227 |
228 | # Windows shortcuts
229 | *.lnk
230 |
231 |
232 | # End of https://www.gitignore.io/api/linux,macos,python,windows,pycharm
233 |
--------------------------------------------------------------------------------
/glove_training_preparation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# TAKE ALL REVIEWS(POS NEG AND UNSUP) AND MERGE THEM IN A SINGLE TEXT FILE WITH ONE REVIEW PER LINE (merged.txt)\n",
12 | "# PERFORM TEXT TOKENIZATION\n",
13 | "\n",
14 | "# REWRITE ALL THE TOKENIZED CORPUS AS A SINGLE WORDS SEQUENCE SEPARATED BY A BLANK SPACE (new_merged).\n",
15 | "# This file is used as input for GloVe training."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {
22 | "collapsed": true
23 | },
24 | "outputs": [],
25 | "source": [
26 | "from os import listdir\n",
27 | "from os.path import isfile, join\n",
28 | "from nltk import RegexpTokenizer\n",
29 | "tokenizer = RegexpTokenizer(r'\\w+')"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {
36 | "collapsed": true
37 | },
38 | "outputs": [],
39 | "source": [
40 | "# new_path = \"./\"\n",
41 | "old_path = \"./dataset/imdb/train/\""
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {
48 | "collapsed": true
49 | },
50 | "outputs": [],
51 | "source": [
52 | "# support file \n",
53 | "merged = open(\"./merged.txt\", \"w\")"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "collapsed": true
61 | },
62 | "outputs": [],
63 | "source": [
64 | "current_directory = old_path + 'pos'\n",
65 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n",
66 | "for f in files:\n",
67 | " t = open(current_directory+'/'+f, 'r')\n",
68 | " for line in t:\n",
69 | " if line:\n",
70 | "            line = line.replace('<br />', ' ')\n",
71 | " merged.write(line+'\\n')\n",
72 | " t.close()\n",
73 | "\n",
74 | "current_directory = old_path + 'neg'\n",
75 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n",
76 | "for f in files:\n",
77 | " t = open(current_directory+'/'+f, 'r')\n",
78 | " for line in t:\n",
79 | " if line:\n",
80 | "            line = line.replace('<br />', ' ')\n",
81 | " merged.write(line+'\\n')\n",
82 | " t.close()\n",
83 | "\n",
84 | "current_directory = old_path + 'unsup'\n",
85 | "files = [f for f in listdir(current_directory) if ( isfile(join(current_directory, f)) and (f[0] != '.') ) ]\n",
86 | "for f in files:\n",
87 | " t = open(current_directory+'/'+f, 'r')\n",
88 | " for line in t:\n",
89 | " if line:\n",
90 | "            line = line.replace('<br />', ' ')\n",
91 | " merged.write(line+'\\n')\n",
92 | " t.close()\n",
93 | "\n",
94 | "merged.close()"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {
101 | "collapsed": true
102 | },
103 | "outputs": [],
104 | "source": [
105 | "corpus = open(\"./merged.txt\", \"r\")\n",
106 | "new_corpus = open(\"./GloVe-1.2/new_merged\", \"w\")"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "collapsed": true
114 | },
115 | "outputs": [],
116 | "source": [
117 | "lines = corpus.readlines()\n",
118 | "text = ''\n",
119 | "\n",
120 | "# put everithing in one line\n",
121 | "for line in lines:\n",
122 | " text = text + line\n",
123 | "corpus.close()\n",
124 | "\n",
125 | "# perform TOKENIZATION, returns a vector of words\n",
126 | "words = tokenizer.tokenize(text.lower())\n",
127 | "\n",
128 | "# rewrite this vector in a file in which words are separated by blank spaces\n",
129 | "for w in words:\n",
130 | " new_corpus.seek(0, 2)\n",
131 | " new_corpus.write(w + ' ')\n",
132 | "new_corpus.close()"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {
139 | "collapsed": true
140 | },
141 | "outputs": [],
142 | "source": []
143 | }
144 | ],
145 | "metadata": {
146 | "kernelspec": {
147 | "display_name": "Python 3",
148 | "language": "python",
149 | "name": "python3"
150 | },
151 | "language_info": {
152 | "codemirror_mode": {
153 | "name": "ipython",
154 | "version": 3
155 | },
156 | "file_extension": ".py",
157 | "mimetype": "text/x-python",
158 | "name": "python",
159 | "nbconvert_exporter": "python",
160 | "pygments_lexer": "ipython3",
161 | "version": "3.6.2"
162 | }
163 | },
164 | "nbformat": 4,
165 | "nbformat_minor": 2
166 | }
167 |
--------------------------------------------------------------------------------
/dataset_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "from os import listdir\n",
13 | "from os.path import isfile, join\n",
14 | "from nltk import word_tokenize\n",
15 | "from nltk import RegexpTokenizer\n",
16 | "tokenizer = RegexpTokenizer(r'\\w+')"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {
23 | "collapsed": true
24 | },
25 | "outputs": [],
26 | "source": [
27 | "dictionary = np.load(\"./dataset/imdb/nltk_dictionary.npy\").item()"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "collapsed": true
35 | },
36 | "outputs": [],
37 | "source": [
38 | "pos_files = [f for f in listdir(\"./dataset/imdb/train/pos\") if ( isfile(join(\"./dataset/imdb/train/pos\", f)) and (f[0] != '.') ) ]\n",
39 | "neg_files = [f for f in listdir(\"./dataset/imdb/train/neg\") if ( isfile(join(\"./dataset/imdb/train/neg\", f)) and (f[0] != '.') ) ]\n",
40 | "unl_files = [f for f in listdir(\"./dataset/imdb/train/unsup\") if ( isfile(join(\"./dataset/imdb/train/unsup\", f)) and (f[0] != '.') ) ]"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "print( \"Labeled...\" )\n",
50 | "xtrain = list()\n",
51 | "ytrain = list()\n",
52 | "for p in pos_files:\n",
53 | " pos = open(\"./dataset/imdb/train/pos/\"+p, 'r')\n",
54 | " for line in pos:\n",
55 | " if line:\n",
56 | "                line = line.replace('<br />', ' ')\n",
57 | " tokens = tokenizer.tokenize(line.lower())\n",
58 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n",
59 | " \n",
60 | " xtrain.append(tokens)\n",
61 | " ytrain.append(1) \n",
62 | " else:\n",
63 | " print(\"empty line: {}.\".format(line))\n",
64 | " pos.close()\n",
65 | " \n",
66 | "for n in neg_files:\n",
67 | " neg = open(\"./dataset/imdb/train/neg/\"+n, 'r')\n",
68 | " for line in neg:\n",
69 | " if line:\n",
70 | "                line = line.replace('<br />', ' ')\n",
71 | " tokens = tokenizer.tokenize(line.lower())\n",
72 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n",
73 | " \n",
74 | " xtrain.append(tokens)\n",
75 | " ytrain.append(0) \n",
76 | " else:\n",
77 | " print(\"empty line: {}.\".format(line))\n",
78 | " neg.close()\n",
79 | "\n",
80 | "print( \"Unlabeled...\" )\n",
81 | "unlab = list()\n",
82 | "for u in unl_files:\n",
83 | " unl = open(\"./dataset/imdb/train/unsup/\"+u, 'r')\n",
84 | " for line in unl:\n",
85 | " if line:\n",
86 | "                line = line.replace('<br />', ' ')\n",
87 | " tokens = tokenizer.tokenize(line.lower())\n",
88 | " tokens = [ (dictionary[t] if (t in dictionary) else dictionary['']) for t in tokens ] \n",
89 | " \n",
90 | " unlab.append(tokens) \n",
91 | " else:\n",
92 | " print(\"empty line: {}.\".format(line))\n",
93 | " unl.close()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {
100 | "collapsed": true
101 | },
102 | "outputs": [],
103 | "source": [
104 | "xtrain = np.asarray(xtrain)\n",
105 | "ytrain = np.asarray(ytrain)\n",
106 | "unlab = np.asarray(unlab)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "collapsed": true
114 | },
115 | "outputs": [],
116 | "source": [
117 | "np.save(\"./dataset/imdb/nltk_xtrain.npy\", xtrain)\n",
118 | "np.save(\"./dataset/imdb/nltk_ytrain.npy\", ytrain)\n",
119 | "np.save(\"./dataset/imdb/nltk_ultrain.npy\", unlab)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {
126 | "collapsed": true
127 | },
128 | "outputs": [],
129 | "source": []
130 | }
131 | ],
132 | "metadata": {
133 | "kernelspec": {
134 | "display_name": "Python 3",
135 | "language": "python",
136 | "name": "python3"
137 | },
138 | "language_info": {
139 | "codemirror_mode": {
140 | "name": "ipython",
141 | "version": 3
142 | },
143 | "file_extension": ".py",
144 | "mimetype": "text/x-python",
145 | "name": "python",
146 | "nbconvert_exporter": "python",
147 | "pygments_lexer": "ipython3",
148 | "version": "3.6.2"
149 | }
150 | },
151 | "nbformat": 4,
152 | "nbformat_minor": 2
153 | }
154 |
--------------------------------------------------------------------------------
/paper_network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "from tensorflow import keras as K\n",
13 | "import numpy as np\n",
14 | "import argparse\n",
15 | "from progressbar import ProgressBar\n",
16 | "import os\n",
17 | "import matplotlib\n",
18 | "matplotlib.use('Agg')\n",
19 | "import matplotlib.pyplot as plt"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "collapsed": true
27 | },
28 | "outputs": [],
29 | "source": [
30 | "class Network:\n",
31 | " def __init__(self, session, dict_weight, dropout=0.2, lstm_units=1024, dense_units=30):\n",
32 | " self.sess = session\n",
33 | " K.backend.set_session(self.sess)\n",
34 | " #defining layers\n",
35 | " dict_shape = dict_weight.shape\n",
36 | " self.emb = K.layers.Embedding(dict_shape[0], dict_shape[1], weights=[dict_weight], trainable=False, name='embedding')\n",
37 | " self.drop = K.layers.Dropout(rate=dropout, seed=91, name='dropout')\n",
38 | " self.lstm = K.layers.LSTM(lstm_units, stateful=False, return_sequences=False, name='lstm')\n",
39 | " self.dense = K.layers.Dense(dense_units, activation='relu', name='dense')\n",
40 | " self.p = K.layers.Dense(1, activation='sigmoid', name='p')\n",
41 | " #defining optimizer\n",
42 | " self.optimizer = tf.train.AdamOptimizer(learning_rate=0.0005)\n",
43 | " # self.optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)\n",
44 | "\n",
45 | " def __call__(self, batch, perturbation=None):\n",
46 | " embedding = self.emb(batch) \n",
47 | " drop = self.drop(embedding)\n",
48 | " if (perturbation is not None):\n",
49 | " drop += perturbation\n",
50 | " lstm = self.lstm(drop)\n",
51 | " dense = self.dense(lstm)\n",
52 | " return self.p(dense), embedding\n",
53 | " \n",
54 | " def get_minibatch(self, x, y, ul, batch_shape=(64, 400)):\n",
55 | " x = K.preprocessing.sequence.pad_sequences(x, maxlen=batch_shape[1])\n",
56 | " permutations = np.random.permutation( len(y) )\n",
57 | " ul_permutations = None\n",
58 | " len_ratio = None\n",
59 | " if (ul is not None):\n",
60 | " ul = K.preprocessing.sequence.pad_sequences(ul, maxlen=batch_shape[1])\n",
61 | " ul_permutations = np.random.permutation( len(ul) )\n",
62 | " len_ratio = len(ul)/len(y)\n",
63 | " for s in range(0, len(y), batch_shape[0]):\n",
64 | " perm = permutations[s:s+batch_shape[0]]\n",
65 | " minibatch = {'x': x[perm], 'y': y[perm]}\n",
66 | " if (ul is not None):\n",
67 | " ul_perm = ul_permutations[int(np.floor(len_ratio*s)):int(np.floor(len_ratio*(s+batch_shape[0])))]\n",
68 | " minibatch.update( {'ul': np.concatenate((ul[ul_perm], x[perm]), axis=0)} )\n",
69 | " yield minibatch\n",
70 | " \n",
71 | " def get_loss(self, batch, labels):\n",
72 | " pred, emb = self(batch)\n",
73 | " loss = K.losses.binary_crossentropy(labels, pred)\n",
74 | " return tf.reduce_mean( loss ), emb\n",
75 | " \n",
76 | " def get_adv_loss(self, batch, labels, loss, emb, p_mult):\n",
77 | " gradient = tf.gradients(loss, emb, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n",
78 | " p_adv = p_mult * tf.nn.l2_normalize(tf.stop_gradient(gradient), dim=1)\n",
79 | " adv_loss = K.losses.binary_crossentropy(labels, self(batch, p_adv)[0])\n",
80 | " return tf.reduce_mean( adv_loss )\n",
81 | " \n",
82 | " def get_v_adv_loss(self, ul_batch, p_mult, power_iterations=1):\n",
83 | " bernoulli = tf.distributions.Bernoulli\n",
84 | " prob, emb = self(ul_batch)\n",
85 | " prob = tf.clip_by_value(prob, 1e-7, 1.-1e-7)\n",
86 | " prob_dist = bernoulli(probs=prob)\n",
87 | " #generate virtual adversarial perturbation\n",
88 | " d = tf.random_uniform(shape=tf.shape(emb), dtype=tf.float32)\n",
89 | " for _ in range( power_iterations ):\n",
90 | " d = (0.02) * tf.nn.l2_normalize(d, dim=1)\n",
91 | " p_prob = tf.clip_by_value(self(ul_batch, d)[0], 1e-7, 1.-1e-7)\n",
92 | " kl = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False)\n",
93 | " gradient = tf.gradients(kl, [d], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n",
94 | " d = tf.stop_gradient(gradient)\n",
95 | " d = p_mult * tf.nn.l2_normalize(d, dim=1)\n",
96 | " tf.stop_gradient(prob)\n",
97 | " #virtual adversarial loss\n",
98 | " p_prob = tf.clip_by_value(self(ul_batch, d)[0], 1e-7, 1.-1e-7)\n",
99 | " v_adv_loss = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False)\n",
100 | " return tf.reduce_mean( v_adv_loss )\n",
101 | "\n",
102 | " def validation(self, x, y, batch_shape=(64, 400)):\n",
103 | " print( 'Validation...' )\n",
104 | " \n",
105 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='validation_labels')\n",
106 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='validation_batch')\n",
107 | "\n",
108 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
109 | " \n",
110 | " accuracies = list()\n",
111 | " minibatch = self.get_minibatch(x, y, ul=None, batch_shape=batch_shape)\n",
112 | " for val_batch in minibatch:\n",
113 | " fd = {batch: val_batch['x'], labels: val_batch['y'], K.backend.learning_phase(): 0} #test mode\n",
114 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n",
115 | " \n",
116 | " print( \"Average accuracy on validation is {:.3f}\".format(np.asarray(accuracies).mean()) )\n",
117 | " \n",
118 | " def train(self, dataset, batch_shape=(64, 400), epochs=10, loss_type='none', p_mult=0.02, init=None, save=None):\n",
119 | " print( 'Training...' )\n",
120 | " xtrain = np.load( \"{}nltk_xtrain.npy\".format(dataset) )\n",
121 | " ytrain = np.load( \"{}nltk_ytrain.npy\".format(dataset) )\n",
122 | " ultrain = np.load( \"{}nltk_ultrain.npy\".format(dataset) ) if (loss_type == 'v_adv') else None\n",
123 | " \n",
124 | " # defining validation set\n",
125 | " xval = list()\n",
126 | " yval = list()\n",
127 | " for _ in range( int(len(ytrain)*0.025) ):\n",
128 | " xval.append( xtrain[0] ); xval.append( xtrain[-1] )\n",
129 | " yval.append( ytrain[0] ); yval.append( ytrain[-1] )\n",
130 | " xtrain = np.delete(xtrain, 0); xtrain = np.delete(xtrain, -1)\n",
131 | " ytrain = np.delete(ytrain, 0); ytrain = np.delete(ytrain, -1)\n",
132 | " xval = np.asarray(xval)\n",
133 | " yval = np.asarray(yval)\n",
134 | " print( '{} elements in validation set'.format(len(yval)) )\n",
135 | " # ---\n",
136 | " yval = np.reshape(yval, newshape=(yval.shape[0], 1))\n",
137 | " ytrain = np.reshape(ytrain, newshape=(ytrain.shape[0], 1))\n",
138 | " \n",
139 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='train_labels')\n",
140 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='train_batch')\n",
141 | " ul_batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='ul_batch')\n",
142 | " \n",
143 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
144 | " loss, emb = self.get_loss(batch, labels)\n",
145 | " if (loss_type == 'adv'):\n",
146 | " loss += self.get_adv_loss(batch, labels, loss, emb, p_mult)\n",
147 | " elif (loss_type == 'v_adv'):\n",
148 | " loss += self.get_v_adv_loss(ul_batch, p_mult)\n",
149 | "\n",
150 | " opt = self.optimizer.minimize( loss )\n",
151 | " #initializing parameters\n",
152 | " if (init is None):\n",
153 | " self.sess.run( [var.initializer for var in tf.global_variables() if not('embedding' in var.name)] )\n",
154 | " print( 'Random initialization' )\n",
155 | " else:\n",
156 | " saver = tf.train.Saver()\n",
157 | " saver.restore(self.sess, init)\n",
158 | " print( 'Restored value' )\n",
159 | " \n",
160 | " _losses = list()\n",
161 | " _accuracies = list()\n",
162 | " list_ratio = (len(ultrain)/len(ytrain)) if (ultrain is not None) else None\n",
163 | " for epoch in range(epochs):\n",
164 | " losses = list()\n",
165 | " accuracies = list()\n",
166 | " validation = list()\n",
167 | " \n",
168 | " bar = ProgressBar(max_value=np.floor(len(ytrain)/batch_shape[0]).astype('i'))\n",
169 | " minibatch = enumerate(self.get_minibatch(xtrain, ytrain, ultrain, batch_shape=batch_shape))\n",
170 | " for i, train_batch in minibatch:\n",
171 | " fd = {batch: train_batch['x'], labels: train_batch['y'], K.backend.learning_phase(): 1} #training mode\n",
172 | " if (loss_type == 'v_adv'):\n",
173 | " fd.update( {ul_batch: train_batch['ul']} )\n",
174 | " \n",
175 | " _, acc_val, loss_val = self.sess.run([opt, accuracy, loss], feed_dict=fd)\n",
176 | " \n",
177 | " accuracies.append( acc_val )\n",
178 | " losses.append( loss_val )\n",
179 | " bar.update(i)\n",
180 | " \n",
181 | " #saving accuracies and losses\n",
182 | " _accuracies.append( accuracies )\n",
183 | " _losses.append(losses)\n",
184 | " \n",
185 | " log_msg = \"\\nEpoch {} of {} -- average accuracy is {:.3f} (train) -- average loss is {:.3f}\"\n",
186 | " print( log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean()) )\n",
187 | " \n",
188 | " # validation log\n",
189 | " self.validation(xval, yval, batch_shape=batch_shape)\n",
190 | " \n",
191 | " #saving model\n",
192 | " if (save is not None) and (epoch == (epochs-1)):\n",
193 | " saver = tf.train.Saver()\n",
194 | " saver.save(self.sess, save)\n",
195 | " print( 'model saved' )\n",
196 | " \n",
197 | " #plotting value\n",
198 | " #plt.plot([l for loss in _losses for l in loss], color='magenta', linestyle='dashed', marker='s', linewidth=1)\n",
199 | " plt.plot([np.asarray(l).mean() for l in _losses], color='red', linestyle='solid', marker='o', linewidth=2)\n",
200 | " #plt.plot([a for acc in _accuracies for a in acc], color='cyan', linestyle='dashed', marker='s', linewidth=1)\n",
201 | " plt.plot([np.asarray(a).mean() for a in _accuracies], color='blue', linestyle='solid', marker='o', linewidth=2)\n",
202 | " plt.savefig('./train_{}_e{}_m{}_l{}.png'.format(loss_type, epochs, batch_shape[0], batch_shape[1]))\n",
203 | " \n",
204 | " def test(self, dataset, batch_shape=(64, 400)):\n",
205 | " print( 'Test...' )\n",
206 | " xtest = np.load( \"{}nltk_xtest.npy\".format(dataset) )\n",
207 | " ytest = np.load( \"{}nltk_ytest.npy\".format(dataset) )\n",
208 | " ytest = np.reshape(ytest, newshape=(ytest.shape[0], 1))\n",
209 | " \n",
210 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='test_labels')\n",
211 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='test_batch')\n",
212 | "\n",
213 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
214 | " \n",
215 | " accuracies = list()\n",
216 | " bar = ProgressBar(max_value=np.floor(len(ytest)/batch_shape[0]).astype('i'))\n",
217 | " minibatch = enumerate(self.get_minibatch(xtest, ytest, ul=None, batch_shape=batch_shape))\n",
218 | " for i, test_batch in minibatch:\n",
219 | " fd = {batch: test_batch['x'], labels: test_batch['y'], K.backend.learning_phase(): 0} #test mode\n",
220 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n",
221 | " \n",
222 | " bar.update(i)\n",
223 | " \n",
224 | " print( \"\\nAverage accuracy is {:.3f}\".format(np.asarray(accuracies).mean()) )\n"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "metadata": {
231 | "collapsed": true
232 | },
233 | "outputs": [],
234 | "source": [
235 | "def main(data, n_epochs, n_ex, ex_len, lt, pm):\n",
236 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
237 | " config = tf.ConfigProto(log_device_placement=True)\n",
238 | " config.gpu_options.allow_growth = True\n",
239 | " session = tf.Session(config=config)\n",
240 | "\n",
241 | " embedding_weights = np.load( \"{}nltk_embedding_matrix.npy\".format(data) )\n",
242 | "\n",
243 | " net = Network(session, embedding_weights)\n",
244 | " net.train(data, batch_shape=(n_ex, ex_len), epochs=n_epochs, loss_type=lt, p_mult=pm, init=None, save=None)\n",
245 | " net.test(data, batch_shape=(n_ex, ex_len))\n",
246 | " \n",
247 | " K.backend.clear_session()"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "scrolled": false
255 | },
256 | "outputs": [],
257 | "source": [
258 | "main(data='../dataset/imdb/', n_epochs=10, n_ex=64, ex_len=400, lt='none', pm=0.02)"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": [
269 | "'''\n",
270 | "if (__name__ == '__main__'):\n",
271 | " parser = argparse.ArgumentParser()\n",
272 | " parser.add_argument('path', help='path of the folder that contains the data (train, test, emb. matrix)', type=str)\n",
273 | " parser.add_argument('epochs', help='number of training epochs', type=int)\n",
274 | " parser.add_argument('--nex', help='Number of EXamples per batch', default=64, type=int)\n",
275 | "    parser.add_argument('--exlen', help='LENgth of each EXample', default=400, type=int)\n",
276 | " parser.add_argument('--loss', help='define the loss type', choices=['none', 'adv', 'v_adv'], default='none', type=str)\n",
277 | " parser.add_argument('--p_mult', help='Perturbation MULTiplier (used with adv and v_adv)', default=0.02, type=float)\n",
278 | " args = parser.parse_args()\n",
279 | " \n",
280 | " main(args.path, args.epochs, args.nex, args.exlen, args.loss, args.p_mult)\n",
281 | "'''"
282 | ]
283 | }
284 | ],
285 | "metadata": {
286 | "kernelspec": {
287 | "display_name": "Python 3",
288 | "language": "python",
289 | "name": "python3"
290 | },
291 | "language_info": {
292 | "codemirror_mode": {
293 | "name": "ipython",
294 | "version": 3
295 | },
296 | "file_extension": ".py",
297 | "mimetype": "text/x-python",
298 | "name": "python",
299 | "nbconvert_exporter": "python",
300 | "pygments_lexer": "ipython3",
301 | "version": "3.6.4"
302 | }
303 | },
304 | "nbformat": 4,
305 | "nbformat_minor": 2
306 | }
307 |
--------------------------------------------------------------------------------
/pyramidal_network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "from tensorflow import keras as K\n",
13 | "import numpy as np\n",
14 | "import argparse\n",
15 | "from progressbar import ProgressBar\n",
16 | "import os\n",
17 | "\n",
18 | "'''\n",
19 | "import matplotlib\n",
20 | "matplotlib.use('Agg')\n",
21 | "import matplotlib.pyplot as plt\n",
22 | "'''\n",
23 | "\n",
24 | "from time import sleep"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "class Network:\n",
36 | " def __init__(self, session, dict_weight, dropout=0.3, lstm_units=1024, dense_units=60):\n",
37 | " self.sess = session\n",
38 | " K.backend.set_session(self.sess)\n",
39 | " #defining layers\n",
40 | " dict_shape = dict_weight.shape\n",
41 | " self.emb = K.layers.Embedding(dict_shape[0], dict_shape[1], weights=[dict_weight], trainable=False, name='embedding')\n",
42 | " self.drop = K.layers.Dropout(rate=dropout, seed=91, name='dropout')\n",
43 | " self.lstm = K.layers.LSTM(lstm_units, stateful=False, return_sequences=False, name='lstm')\n",
44 | " \n",
45 | " self.dense = K.layers.Dense(dense_units, activation='relu', name='dense')\n",
46 | " self.p = K.layers.Dense(1, activation='sigmoid', name='p')\n",
47 | " #defining optimizer\n",
48 | " self.optimizer = tf.train.AdamOptimizer(learning_rate=0.0005)\n",
49 | "\n",
50 | " def __call__(self, batch, perturbation=None): \n",
51 | " batch1 = batch[:,0:400]\n",
52 | " batch2 = batch[:,400:800]\n",
53 | " batch3 = batch[:,800:1200]\n",
54 | " \n",
55 | " embedding1 = self.emb(batch1)\n",
56 | " embedding2 = self.emb(batch2)\n",
57 | " embedding3 = self.emb(batch3)\n",
58 | " \n",
59 | " drop1 = self.drop(embedding1)\n",
60 | " drop2 = self.drop(embedding2)\n",
61 | " drop3 = self.drop(embedding3)\n",
62 | " \n",
63 | " if (perturbation is not None):\n",
64 | " drop1 += perturbation[0]\n",
65 | " drop2 += perturbation[1]\n",
66 | " drop3 += perturbation[2]\n",
67 | " \n",
68 | " lstm1 = self.lstm(drop1)\n",
69 | " lstm2 = self.lstm(drop2)\n",
70 | " lstm3 = self.lstm(drop3)\n",
71 | " lstm = tf.concat([lstm1, lstm2, lstm3], axis=1)\n",
72 | " dense = self.dense(lstm)\n",
73 | " \n",
74 | " return self.p(dense), (embedding1, embedding2, embedding3)\n",
75 | " \n",
76 | " def get_minibatch(self, x, y, ul, batch_shape=(16, 1200)):\n",
77 | " x = K.preprocessing.sequence.pad_sequences(x, maxlen=batch_shape[1])\n",
78 | " permutations = np.random.permutation( len(y) )\n",
79 | " ul_permutations = None\n",
80 | " len_ratio = None\n",
81 | " if (ul is not None):\n",
82 | " ul = K.preprocessing.sequence.pad_sequences(ul, maxlen=batch_shape[1])\n",
83 | " ul_permutations = np.random.permutation(len(ul))\n",
84 | " len_ratio = len(ul)/len(y)\n",
85 | " for s in range(0, len(y), batch_shape[0]):\n",
86 | " perm = permutations[s:s+batch_shape[0]]\n",
87 | " minibatch = {'x': x[perm], 'y': y[perm]}\n",
88 | " if (ul is not None):\n",
89 | " ul_perm = ul_permutations[int(np.floor(len_ratio*s)):int(np.floor(len_ratio*(s+batch_shape[0])))]\n",
90 | " minibatch.update( {'ul': np.concatenate((ul[ul_perm], x[perm]), axis=0)} ) \n",
91 | " yield minibatch\n",
92 | " \n",
93 | " def get_loss(self, batch, labels):\n",
94 | " pred, emb = self(batch)\n",
95 | " loss = K.losses.binary_crossentropy(labels, pred)\n",
96 | " return tf.reduce_mean( loss ), emb\n",
97 | " \n",
98 | " def get_adv_loss(self, batch, labels, loss, emb, p_mult):\n",
99 | " g1 = tf.gradients(loss, emb[0], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n",
100 | " g2 = tf.gradients(loss, emb[1], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n",
101 | " g3 = tf.gradients(loss, emb[2], aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)[0]\n",
102 | " p_adv = list()\n",
103 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g1), dim=1) )\n",
104 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g2), dim=1) )\n",
105 | " p_adv.append( p_mult * tf.nn.l2_normalize(tf.stop_gradient(g3), dim=1) )\n",
106 | " adv_loss = K.losses.binary_crossentropy(labels, self(batch, p_adv)[0])\n",
107 | " return tf.reduce_mean( adv_loss )\n",
108 | " \n",
109 | " def validate(self, xval, yval, batch_shape=(64, 1200), log_path=None):\n",
110 | " print( 'Validation...' )\n",
111 | " \n",
112 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n",
113 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n",
114 | "\n",
115 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
116 | " accuracies = list()\n",
117 | " bar = ProgressBar(max_value=np.floor(len(yval)/batch_shape[0]).astype('i'))\n",
118 | " minibatch = enumerate(self.get_minibatch(xval, yval, None, batch_shape=batch_shape))\n",
119 | " for i, val_batch in minibatch:\n",
120 | " fd = {batch: val_batch['x'], labels: val_batch['y'], K.backend.learning_phase(): 0} #test mode\n",
121 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n",
122 | " bar.update(i)\n",
123 | " \n",
124 | " log_msg = \"\\nValidation Average accuracy is {:.3f} -- batch shape {}\"\n",
125 | " print( log_msg.format(np.asarray(accuracies).mean(), batch_shape) )\n",
126 | " log = None\n",
127 | " if(log_path is not None):\n",
128 | " log = open(log_path, 'a')\n",
129 | " log.write(log_msg.format(np.asarray(accuracies).mean(), batch_shape)+'\\n')\n",
130 | " log.close()\n",
131 | " \n",
132 | " def train(self, dataset, batch_shape=(64, 1200), epochs=10, loss_type='none', p_mult=0.02, save=None, log_path=None):\n",
133 | " print( 'Training...' )\n",
134 | " xtrain = np.load( \"{}nltk_xtrain.npy\".format(dataset) )\n",
135 | " ytrain = np.load( \"{}nltk_ytrain.npy\".format(dataset) )\n",
136 | " \n",
137 | " xval = list()\n",
138 | " yval = list()\n",
139 | " for i in range(int(len(ytrain)*0.025)):\n",
140 | " xval.append( xtrain[0] ); xval.append( xtrain[-1] )\n",
141 | " yval.append( ytrain[0] ); yval.append( ytrain[-1] )\n",
142 | " xtrain = np.delete(xtrain, 0); xtrain = np.delete(xtrain, -1)\n",
143 | " ytrain = np.delete(ytrain, 0); ytrain = np.delete(ytrain, -1)\n",
144 | " xval = np.asarray(xval)\n",
145 | " yval = np.asarray(yval)\n",
146 | " \n",
147 | " yval = np.reshape(yval, newshape=(yval.shape[0],1))\n",
148 | " ytrain = np.reshape(ytrain, newshape=(ytrain.shape[0], 1))\n",
149 | " \n",
150 | " ultrain = np.load( \"{}nltk_ultrain.npy\".format(dataset) ) if (loss_type == 'v_adv') else None \n",
151 | " \n",
152 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n",
153 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n",
154 | " ul_batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='ul_batch')\n",
155 | " \n",
156 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
157 | " loss, emb = self.get_loss(batch, labels)\n",
158 | " if (loss_type == 'adv'):\n",
159 | " loss += self.get_adv_loss(batch, labels, loss, emb, p_mult)\n",
160 | " elif (loss_type == 'v_adv'):\n",
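"            # NOTE: get_v_adv_loss is not defined for this class (it exists in paper_network.ipynb);\n",
"            # the argparse stub below therefore only offers 'none' and 'adv' for the pyramidal model.\n",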
161 | " loss += self.get_v_adv_loss(ul_batch, p_mult)\n",
162 | "\n",
163 | " opt = self.optimizer.minimize( loss )\n",
164 | " #initializing parameters\n",
165 | " self.sess.run( [var.initializer for var in tf.global_variables() if not('embedding' in var.name)] )\n",
166 | " \n",
167 | " _losses = list()\n",
168 | " _accuracies = list()\n",
169 | " log = None\n",
170 | " \n",
171 | " list_ratio = (len(ultrain)/len(ytrain)) if (ultrain is not None) else None\n",
172 | " for epoch in range(epochs):\n",
173 | " losses = list()\n",
174 | " accuracies = list()\n",
175 | " \n",
176 | " bar = ProgressBar(max_value=np.floor(len(ytrain)/batch_shape[0]).astype('i'))\n",
177 | " minibatch = enumerate(self.get_minibatch(xtrain, ytrain, ultrain, batch_shape=batch_shape))\n",
178 | " for i, train_batch in minibatch:\n",
179 | " fd = {batch: train_batch['x'], labels: train_batch['y'], K.backend.learning_phase(): 1} #training mode\n",
180 | " if (loss_type == 'v_adv'):\n",
181 | " fd.update( {ul_batch: train_batch['ul']} )\n",
182 | " \n",
183 | " _, acc_val, loss_val = self.sess.run([opt, accuracy, loss], feed_dict=fd)\n",
184 | " \n",
185 | " accuracies.append( acc_val )\n",
186 | " losses.append( loss_val )\n",
187 | " bar.update(i)\n",
188 | " \n",
189 | " _losses.append(losses)\n",
190 | " _accuracies.append(accuracies)\n",
191 | " \n",
192 | " log_msg = \"\\nEpoch {} of {} -- average accuracy is {:.3f} -- average loss is {:.3f}\"\n",
193 | " print( log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean()) )\n",
194 | " if(log_path is not None):\n",
195 | " log = open(log_path, 'a')\n",
196 | " log.write(log_msg.format(epoch+1, epochs, np.asarray(accuracies).mean(), np.asarray(losses).mean())+'\\n')\n",
197 | " log.close()\n",
198 | " \n",
199 | " #validation\n",
200 | " self.validate(xval, yval, batch_shape=batch_shape, log_path=log_path)\n",
201 | " \n",
202 | " def test(self, dataset, batch_shape=(64, 1200), log_path=None):\n",
203 | " print( 'Test...' )\n",
204 | " xtest = np.load( \"{}nltk_xtest.npy\".format(dataset) )\n",
205 | " ytest = np.load( \"{}nltk_ytest.npy\".format(dataset) )\n",
206 | " ytest = np.reshape(ytest, newshape=(ytest.shape[0], 1))\n",
207 | " \n",
208 | " labels = tf.placeholder(tf.float32, shape=(None, 1), name='labels')\n",
209 | " batch = tf.placeholder(tf.float32, shape=(None, batch_shape[1]), name='batch')\n",
210 | "\n",
211 | " accuracy = tf.reduce_mean( K.metrics.binary_accuracy(labels, self(batch)[0]) )\n",
212 | " \n",
213 | " accuracies = list()\n",
214 | " bar = ProgressBar(max_value=np.floor(len(ytest)/batch_shape[0]).astype('i'))\n",
215 | " minibatch = enumerate(self.get_minibatch(xtest, ytest, None, batch_shape=batch_shape))\n",
216 | " for i, test_batch in minibatch:\n",
217 | " fd = {batch: test_batch['x'], labels: test_batch['y'], K.backend.learning_phase(): 0} #test mode\n",
218 | " accuracies.append( self.sess.run(accuracy, feed_dict=fd) )\n",
219 | " bar.update(i)\n",
220 | " \n",
221 | " log_msg = \"\\nTest Average accuracy is {:.3f} -- batch shape {}\"\n",
222 | " print( log_msg.format(np.asarray(accuracies).mean(), batch_shape) )\n",
223 | " \n",
224 | " log = None\n",
225 | " if(log_path is not None):\n",
226 | " log = open(log_path, 'a')\n",
227 | " log.write(log_msg.format(np.asarray(accuracies).mean(), batch_shape)+'\\n')\n",
228 | " log.close()"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {
235 | "collapsed": true
236 | },
237 | "outputs": [],
238 | "source": [
239 | "def main(data, n_epochs, n_ex, ex_len, lt, pm):\n",
240 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
241 | " config = tf.ConfigProto(log_device_placement=True)\n",
242 | " config.gpu_options.allow_growth = True\n",
243 | " session = tf.Session(config=config)\n",
244 | "\n",
245 | " embedding_weights = np.load( \"{}nltk_embedding_matrix.npy\".format(data) )\n",
246 | "\n",
247 | " # log text file\n",
248 | " net_tp = 'baseline' if (lt == 'none') else lt\n",
249 | " log_path = './pyramidal_{}_bs_{}_ep_{}_sl_{}.txt'.format(net_tp, n_ex, n_epochs, ex_len)\n",
250 | " log = open(log_path, 'w')\n",
251 | " log.write('Fancy Network with {} loss, {} epochs, {} batch size, {} maximum string length \\n'.format(net_tp, n_epochs, n_ex, ex_len))\n",
252 | " log.close()\n",
253 | " \n",
254 | " net = Network(session, embedding_weights)\n",
255 | " net.train(data, batch_shape=(n_ex, ex_len), epochs=n_epochs, loss_type=lt, p_mult=pm, log_path=log_path)\n",
256 | " net.test(data, batch_shape=(n_ex, ex_len), log_path=log_path)\n",
257 | " \n",
258 | " K.backend.clear_session()"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": [
269 | "main(data='../dataset/imdb/', n_epochs=10, n_ex=16, ex_len=1200, lt='none', pm=0.02) # works only with ex_len=1200"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {
276 | "collapsed": true
277 | },
278 | "outputs": [],
279 | "source": [
280 | "'''\n",
281 | "if (__name__ == '__main__'):\n",
282 | " parser = argparse.ArgumentParser()\n",
283 | " parser.add_argument('path', help='path of the folder that contains the data (train, test, emb. matrix)', type=str)\n",
284 | " parser.add_argument('epochs', help='number of training epochs', type=int)\n",
285 | " parser.add_argument('--nex', help='Number of EXamples per batch', default=64, type=int)\n",
286 | "    parser.add_argument('--exlen', help='LENgth of each EXample', default=1200, type=int)\n",
287 | " parser.add_argument('--loss', help='define the loss type', choices=['none', 'adv'], default='none', type=str)\n",
288 | " parser.add_argument('--p_mult', help='Perturbation MULTiplier (used with adv)', default=0.02, type=float)\n",
289 | " args = parser.parse_args()\n",
290 | " \n",
291 | " main(args.path, args.epochs, args.nex, args.exlen, args.loss, args.p_mult)\n",
292 | "'''"
293 | ]
294 | }
295 | ],
296 | "metadata": {
297 | "kernelspec": {
298 | "display_name": "Python 3",
299 | "language": "python",
300 | "name": "python3"
301 | },
302 | "language_info": {
303 | "codemirror_mode": {
304 | "name": "ipython",
305 | "version": 3
306 | },
307 | "file_extension": ".py",
308 | "mimetype": "text/x-python",
309 | "name": "python",
310 | "nbconvert_exporter": "python",
311 | "pygments_lexer": "ipython3",
312 | "version": "3.6.4"
313 | }
314 | },
315 | "nbformat": 4,
316 | "nbformat_minor": 2
317 | }
318 |
--------------------------------------------------------------------------------