├── .gitignore
├── LICENSE
├── README.md
├── data_utils.py
├── email_utils.py
├── main.py
├── spam_detect_char
├── spam_detection.ipynb
└── spam_email.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Dhoomil B Sheta
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # spam-detection-using-deep-learning
2 | Detecting spam emails with convolutional neural networks (Keras, plus an experimental MXNet/Gluon port).
3 |
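A minimal quick-start sketch of how the pieces fit together, assuming the Enron CSV has been placed at `data/enron.csv` and that a `models/` directory exists for checkpoints (the scripts reference both, but neither ships with the repo):

```python
# Hedged quick-start: data/enron.csv and models/ are assumptions, referenced
# by the scripts but not included in the repository.
from data_utils import process_dataset

# Build character-level sequences and labels, written to data/dataset.pkl.
process_dataset()

# Training is a script rather than a function; run it as a separate step.
# It loads data/dataset.pkl, trains the CNN and saves it to models/spam_detect:
#   python main.py
```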
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from keras.preprocessing import text
3 | import pickle
4 | import os
5 | from bs4 import BeautifulSoup
6 | from email.parser import Parser
7 |
8 | parser = Parser()
9 |
10 |
11 | # Load Data
12 | def process_dataset():
13 | data = pd.read_csv("data/enron.csv")
14 | print(f"Total emails: {len(data)}")
15 | emails = data['msg'].values
16 | labels = [1 if x == "spam" else 0 for x in data['label'].values]
17 |
18 | # Pre-process Data
19 | # tokenizer = text.Tokenizer(char_level=True)
20 | # tokenizer.fit_on_texts(emails)
21 | # sequences = tokenizer.texts_to_sequences(emails)
22 | # word2index = tokenizer.word_index
23 | # num_words = len(word2index)
24 | # print(f"Found {num_words} unique tokens")
25 | alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:"
26 | char2index = {}
27 | for i, c in enumerate(alphabet):
28 | char2index[c] = i + 1
29 |
30 | sequences = []
31 | for email in emails:
32 | seq = []
33 | for c in email:
34 | if c in char2index:
35 | seq.append(char2index[c])
36 | sequences.append(seq)
37 |
38 | with open("data/dataset.pkl", 'wb') as f:
39 | pickle.dump([sequences, labels, char2index], f)
40 |
41 |
42 | if __name__ == "__main__":
43 |     process_dataset()
43 |
44 |
45 | def process_email(filename):
46 | with open(filename) as f:
47 | email = parser.parse(f)
48 | cleantext = ""
49 | if email.is_multipart():
50 | for part in email.get_payload():
51 | soup = BeautifulSoup(part.as_string(maxheaderlen=1))
52 | txt = soup.get_text()
53 | txt = ' '.join(txt.split())
54 | i = txt.find("Content-Transfer-Encoding")
55 | txt = txt[i + len("Content-Transfer-Encoding"):].split(maxsplit=2)[2]
56 | cleantext += txt
57 |
58 | else:
59 | soup = BeautifulSoup(email.get_payload())
60 | txt = soup.get_text()
61 | txt = ' '.join(txt.split())
62 | i = txt.find("Content-Transfer-Encoding")
63 | txt = txt[i + len("Content-Transfer-Encoding"):].split(maxsplit=2)[2]
64 | cleantext += txt
65 | print(cleantext)
66 |
67 | # for filename in os.listdir("data/infy_spam_emails"):
68 | # process_email(f"data/infy_spam_emails/{filename}")
69 |
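The pickle written by `process_dataset()` is what `main.py` and the notebooks load. A small, purely illustrative sketch of reading it back and decoding one encoded email (the reverse-lookup helper is an addition for illustration, not part of the repo):

```python
# Illustrative only: inspect data/dataset.pkl produced by process_dataset().
import pickle

with open("data/dataset.pkl", "rb") as f:
    sequences, labels, char2index = pickle.load(f)

# Invert the character vocabulary so an encoded sequence can be read back.
# Characters outside the 43-symbol alphabet (including spaces) were dropped
# during encoding, so the decoded text is a compacted version of the email.
index2char = {i: c for c, i in char2index.items()}
print(labels[0], "".join(index2char[i] for i in sequences[0])[:80])
```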
--------------------------------------------------------------------------------
/email_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # FileName: Subsampling.py
3 | # Version 1.0 by Tao Ban, 2010.5.26
4 | # This script extracts the contents, i.e. the subject and the first body part, from each .eml file
5 | # and stores them in a new file with the same name in the dst dir.
6 |
7 | import email.parser
8 | import os
9 | import stat
10 | import sys
11 |
12 |
13 | def ExtractSubPayload(filename):
14 | ''' Extract the subject and payload from the .eml file.
15 |
16 | '''
17 |     if not os.path.exists(filename):  # input file does not exist
18 |         print("ERROR: input file does not exist:", filename)
19 |         sys.exit(1)
20 |     with open(filename) as fp:
21 |         msg = email.message_from_file(fp)
22 |     payload = msg.get_payload()
23 |     if isinstance(payload, list):
24 | payload = payload[0] # only use the first part of payload
25 | sub = msg.get('subject')
26 | sub = str(sub)
27 |     if not isinstance(payload, str):
28 | payload = str(payload)
29 |
30 | return sub + payload
31 |
32 |
33 | def ExtractBodyFromDir(srcdir, dstdir):
34 | '''Extract the body information from all .eml files in the srcdir and
35 |
36 | save the file to the dstdir with the same name.'''
37 |     if not os.path.exists(dstdir):  # dest path does not exist
38 | os.makedirs(dstdir)
39 | files = os.listdir(srcdir)
40 | for file in files:
41 | srcpath = os.path.join(srcdir, file)
42 | dstpath = os.path.join(dstdir, file)
43 | src_info = os.stat(srcpath)
44 | if stat.S_ISDIR(src_info.st_mode): # for subfolders, recurse
45 | ExtractBodyFromDir(srcpath, dstpath)
46 | else: # copy the file
47 | body = ExtractSubPayload(srcpath)
48 | dstfile = open(dstpath, 'w')
49 | dstfile.write(body)
50 | dstfile.close()
51 |
52 |
53 | ###################################################################
54 | # main section starts here
55 | if __name__ == '__main__':
56 |     # srcdir is the directory where the .eml files are stored
57 |     print('Input source directory: ')  # ask for source and dest dirs
58 |     srcdir = input()
59 |     if not os.path.exists(srcdir):
60 |         print('The source directory %s does not exist, exit...' % srcdir)
61 |         sys.exit()
62 |     # dstdir is the directory where the extracted contents are stored
63 |     print('Input destination directory: ')
64 |     dstdir = input()
65 |     if not os.path.exists(dstdir):
66 |         print('The destination directory is newly created.')
67 |         os.makedirs(dstdir)
68 |
69 |     ExtractBodyFromDir(srcdir, dstdir)
70 |
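The block above is interactive; with the `__main__` guard it can also be imported and driven from other code. A minimal sketch, where the two directory names are placeholders rather than paths used anywhere in the repo:

```python
# Placeholder directories: substitute wherever your .eml files actually live.
from email_utils import ExtractBodyFromDir

# Walks data/raw_eml recursively and writes the subject + first body part of
# each message to a same-named file under data/clean_eml.
ExtractBodyFromDir("data/raw_eml", "data/clean_eml")
```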
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from keras.utils import to_categorical
2 | from keras.preprocessing import sequence
3 | from mxnet import gluon
4 | from sklearn.model_selection import train_test_split
5 | from keras.models import Model, load_model
6 | from keras.layers import Conv1D, GlobalMaxPooling1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten
7 | from keras.callbacks import ModelCheckpoint
8 | import numpy as np
9 | import pickle
10 |
11 | MAX_WORDS_IN_SEQ = 1000
12 | EMBED_DIM = 100
13 | MODEL_PATH = "models/spam_detect"
14 |
15 | # Load Data
16 | with open("data/dataset.pkl", 'rb') as f:
17 | sequences, labels, word2index = pickle.load(f)
18 |
19 | num_words = len(word2index)
20 | print(f"Found {num_words} unique tokens")
21 |
22 | data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post')
23 | print(labels[:10])
24 | labels = to_categorical(labels)
25 | print(labels[:10])
26 |
27 | print('Shape of data tensor:', data.shape)
28 | print('Shape of label tensor:', labels.shape)
29 |
30 | x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)
31 |
32 | # Building the model
33 | input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32')
34 | embed_seq = Embedding(num_words + 1, EMBED_DIM, embeddings_initializer='glorot_normal', input_length=MAX_WORDS_IN_SEQ)(
35 | input_seq)
36 | conv_1 = Conv1D(128, 5, activation='relu')(embed_seq)
37 | conv_1 = MaxPooling1D(pool_size=5)(conv_1)
38 | conv_2 = Conv1D(128, 5, activation='relu')(conv_1)
39 | conv_2 = MaxPooling1D(pool_size=5)(conv_2)
40 | conv_3 = Conv1D(128, 5, activation='relu')(conv_2)
41 | conv_3 = MaxPooling1D(pool_size=35)(conv_3)
42 | flat = Flatten()(conv_3)
43 | flat = Dropout(0.25)(flat)
44 | fc1 = Dense(128, activation='relu')(flat)
45 | fc1 = Dropout(0.25)(fc1)
46 | fc2 = Dense(2, activation='softmax')(fc1)
47 |
48 | model = Model(input_seq, fc2)
49 | model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
50 |
51 | # Train the model
52 | model.fit(
53 | x_train,
54 | y_train,
55 | batch_size=128,
56 | epochs=2,
57 | callbacks=[ModelCheckpoint(MODEL_PATH, save_best_only=True)],
58 | validation_data=[x_test, y_test]
59 | )
60 |
61 | model.save(MODEL_PATH)
62 |
63 |
64 | class CnnClassifierModel(gluon.HybridBlock):
65 | def __init__(self, **kwargs):
66 | super(CnnClassifierModel, self).__init__(**kwargs)
67 | with self.name_scope():
68 |             self.conv1 = gluon.nn.Conv1D(channels=128, kernel_size=5)  # incomplete stub; the full Gluon model lives in spam_email.ipynb
69 |
70 | def hybrid_forward(self, F, x, *args, **kwargs):
71 | pass
72 |
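Once trained, the saved model can be applied to new text with the same character vocabulary built by `data_utils.py`. A hedged inference sketch (the helper function, the example string and the reuse of the 1000-character cutoff mirror the training setup above but are not part of the repo):

```python
# Illustrative inference sketch: reuses data/dataset.pkl and models/spam_detect.
import pickle

from keras.models import load_model
from keras.preprocessing import sequence

with open("data/dataset.pkl", "rb") as f:
    _, _, char2index = pickle.load(f)

model = load_model("models/spam_detect")

def spam_probability(text):
    # Encode exactly as process_dataset() does: keep only known characters.
    seq = [char2index[c] for c in text if c in char2index]
    padded = sequence.pad_sequences([seq], maxlen=1000, padding='post', truncating='post')
    return model.predict(padded)[0][1]  # softmax output is [P(ham), P(spam)]

print(spam_probability("subject: you have won a free prescription , claim now !"))
```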
--------------------------------------------------------------------------------
/spam_detect_char:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dbsheta/spam-detection-using-deep-learning/872b5086bfa72f2e6a901924c875cf76fe5b7cdb/spam_detect_char
--------------------------------------------------------------------------------
/spam_detection.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Spam Detection Using CNN"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stderr",
17 | "output_type": "stream",
18 | "text": [
19 | "Using TensorFlow backend.\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "from keras.utils import to_categorical\n",
25 | "from keras.preprocessing import sequence, text\n",
26 | "from sklearn.model_selection import train_test_split\n",
27 | "from keras.models import Model, load_model\n",
28 | "from keras.layers import Conv1D, GlobalMaxPooling1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten\n",
29 | "from keras.callbacks import ModelCheckpoint\n",
30 | "import numpy as np\n",
31 | "import pandas as pd\n",
32 | "\n",
33 | "MAX_WORDS_IN_SEQ = 1000\n",
34 | "EMBED_DIM = 100\n",
35 | "MODEL_NAME = \"/model/spam_detect\""
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Load Data"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "Total emails: 33716\n"
55 | ]
56 | },
57 | {
58 | "data": {
130 | "text/plain": [
131 | " label index msg dataset \\\n",
132 | "0 spam 0 Subject: dobmeos with hgh my energy level has ... 1 \n",
133 | "1 spam 1 Subject: your prescription is ready . . oxwq s... 1 \n",
134 | "2 ham 2 Subject: christmas tree farm pictures 1 \n",
135 | "3 ham 3 Subject: vastar resources , inc .gary , produc... 1 \n",
136 | "4 ham 4 Subject: calpine daily gas nomination- calpine... 1 \n",
137 | "\n",
138 | " file \n",
139 | "0 enron1/spam/0006.2003-12-18.GP.spam.txt \n",
140 | "1 enron1/spam/0008.2003-12-18.GP.spam.txt \n",
141 | "2 enron1/ham/0001.1999-12-10.farmer.ham.txt \n",
142 | "3 enron1/ham/0002.1999-12-13.farmer.ham.txt \n",
143 | "4 enron1/ham/0003.1999-12-14.farmer.ham.txt "
144 | ]
145 | },
146 | "execution_count": 3,
147 | "metadata": {},
148 | "output_type": "execute_result"
149 | }
150 | ],
151 | "source": [
152 | "data = pd.read_csv(\"~/Development/datasets/enron.csv\")\n",
153 | "print(f\"Total emails: {len(data)}\")\n",
154 | "data.head()"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 4,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "emails = data['msg'].values\n",
164 | "labels = [1 if x == \"spam\" else 0 for x in data['label'].values]"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 5,
170 | "metadata": {},
171 | "outputs": [
172 | {
173 | "data": {
174 | "text/plain": [
175 | "226609"
176 | ]
177 | },
178 | "execution_count": 5,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "max_len = max(map(lambda x: len(x), emails))\n",
185 | "max_len"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "## Pre-Process Data"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 89,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "name": "stdout",
202 | "output_type": "stream",
203 | "text": [
204 | "Found 309362 unique tokens\n"
205 | ]
206 | }
207 | ],
208 | "source": [
209 | "tokenizer = text.Tokenizer()\n",
210 | "tokenizer.fit_on_texts(emails)\n",
211 | "sequences = tokenizer.texts_to_sequences(emails)\n",
212 | "word2index = tokenizer.word_index\n",
213 | "num_words = len(word2index)\n",
214 | "print(f\"Found {num_words} unique tokens\")\n"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 90,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "name": "stdout",
224 | "output_type": "stream",
225 | "text": [
226 | "[1, 1, 0, 0, 0, 0, 0, 1, 0, 1]\n",
227 | "[[ 0. 1.]\n",
228 | " [ 0. 1.]\n",
229 | " [ 1. 0.]\n",
230 | " [ 1. 0.]\n",
231 | " [ 1. 0.]\n",
232 | " [ 1. 0.]\n",
233 | " [ 1. 0.]\n",
234 | " [ 0. 1.]\n",
235 | " [ 1. 0.]\n",
236 | " [ 0. 1.]]\n",
237 | "Shape of data tensor: (33716, 1000)\n",
238 | "Shape of label tensor: (33716, 2)\n"
239 | ]
240 | }
241 | ],
242 | "source": [
243 | "data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post')\n",
244 | "print(labels[:10])\n",
245 | "labels = to_categorical(labels)\n",
246 | "print(labels[:10])\n",
247 | "\n",
248 | "print('Shape of data tensor:', data.shape)\n",
249 | "print('Shape of label tensor:', labels.shape)\n",
250 | "\n",
251 | "x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "## Building the Model: Basic CNN"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 91,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32')\n",
268 | "embed_seq = Embedding(num_words, EMBED_DIM, embeddings_initializer='glorot_uniform', input_length=MAX_WORDS_IN_SEQ)(\n",
269 | " input_seq)\n",
270 | "conv_1 = Conv1D(128, 5, activation='relu')(embed_seq)\n",
271 | "conv_1 = MaxPooling1D(pool_size=5)(conv_1)\n",
272 | "conv_2 = Conv1D(128, 5, activation='relu')(conv_1)\n",
273 | "conv_2 = MaxPooling1D(pool_size=5)(conv_2)\n",
274 | "conv_3 = Conv1D(128, 5, activation='relu')(conv_2)\n",
275 | "conv_3 = MaxPooling1D(pool_size=35)(conv_3)\n",
276 | "flat = Flatten()(conv_3)\n",
277 | "# flat = Dropout(0.25)(flat)\n",
278 | "fc1 = Dense(128, activation='relu')(flat)\n",
279 | "# dense_1 = Dropout(0.25)(flat)\n",
280 | "fc2 = Dense(2, activation='softmax')(fc1)\n",
281 | "\n",
282 | "model = Model(input_seq, fc2)\n",
283 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 92,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "_________________________________________________________________\n",
296 | "Layer (type) Output Shape Param # \n",
297 | "=================================================================\n",
298 | "input_8 (InputLayer) (None, 1000) 0 \n",
299 | "_________________________________________________________________\n",
300 | "embedding_7 (Embedding) (None, 1000, 100) 30936200 \n",
301 | "_________________________________________________________________\n",
302 | "conv1d_19 (Conv1D) (None, 996, 128) 64128 \n",
303 | "_________________________________________________________________\n",
304 | "max_pooling1d_19 (MaxPooling (None, 199, 128) 0 \n",
305 | "_________________________________________________________________\n",
306 | "conv1d_20 (Conv1D) (None, 195, 128) 82048 \n",
307 | "_________________________________________________________________\n",
308 | "max_pooling1d_20 (MaxPooling (None, 39, 128) 0 \n",
309 | "_________________________________________________________________\n",
310 | "conv1d_21 (Conv1D) (None, 35, 128) 82048 \n",
311 | "_________________________________________________________________\n",
312 | "max_pooling1d_21 (MaxPooling (None, 1, 128) 0 \n",
313 | "_________________________________________________________________\n",
314 | "flatten_7 (Flatten) (None, 128) 0 \n",
315 | "_________________________________________________________________\n",
316 | "dense_13 (Dense) (None, 128) 16512 \n",
317 | "_________________________________________________________________\n",
318 | "dense_14 (Dense) (None, 2) 258 \n",
319 | "=================================================================\n",
320 | "Total params: 31,181,194\n",
321 | "Trainable params: 31,181,194\n",
322 | "Non-trainable params: 0\n",
323 | "_________________________________________________________________\n"
324 | ]
325 | }
326 | ],
327 | "source": [
328 | "# Testing ---------------------------------------\n",
329 | "model.summary()"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "model.fit(x_train, y_train, batch_size=128, epochs=2, validation_data=(x_test, y_test))"
339 | ]
340 | }
341 | ],
342 | "metadata": {
343 | "kernelspec": {
344 | "display_name": "Python 3",
345 | "language": "python",
346 | "name": "python3"
347 | },
348 | "language_info": {
349 | "codemirror_mode": {
350 | "name": "ipython",
351 | "version": 3
352 | },
353 | "file_extension": ".py",
354 | "mimetype": "text/x-python",
355 | "name": "python",
356 | "nbconvert_exporter": "python",
357 | "pygments_lexer": "ipython3",
358 | "version": "3.6.1"
359 | }
360 | },
361 | "nbformat": 4,
362 | "nbformat_minor": 1
363 | }
364 |
--------------------------------------------------------------------------------
/spam_email.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "Using TensorFlow backend.\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "from keras.preprocessing import sequence\n",
18 | "from keras.utils import to_categorical\n",
19 | "from keras.models import Model, load_model\n",
20 | "from keras.layers import Conv1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten, BatchNormalization, Activation\n",
21 | "from keras.callbacks import ModelCheckpoint\n",
22 | "from sklearn.model_selection import train_test_split\n",
23 | "\n",
24 | "import mxnet as mx\n",
25 | "from mxnet import gluon\n",
26 | "from mxnet import autograd\n",
27 | "\n",
28 | "import pickle\n",
29 | "import numpy as np\n",
30 | "import time\n",
31 | "import math"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "def time_since(start):\n",
41 | " now = time.time()\n",
42 | " s = now - start\n",
43 | " m = math.floor(s / 60)\n",
44 | " s -= m * 60\n",
45 | " return '%dm %ds' % (m, s)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 20,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "MAX_WORDS_IN_SEQ = 3000\n",
55 | "EMBED_DIM = 32\n",
56 | "MODEL_PATH = \"model/spam_detect_char\"\n",
57 | "ctx = mx.cpu()"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "Found 43 unique tokens\n"
70 | ]
71 | }
72 | ],
73 | "source": [
74 | "with open(\"data/dataset.pkl\", 'rb') as f:\n",
75 | " sequences, labels, word2index = pickle.load(f)\n",
76 | " \n",
77 | "num_words = len(word2index)\n",
78 | "print(f\"Found {num_words} unique tokens\")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 5,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post')\n",
88 | "targets = to_categorical(labels)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 6,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "Shape of data tensor: (33716, 3000)\n",
101 | "Shape of label tensor: (33716, 2)\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "print('Shape of data tensor:', data.shape)\n",
107 | "print('Shape of label tensor:', targets.shape)\n",
108 | "x_train, x_test, y_train, y_test = train_test_split(data, targets, test_size=0.25)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 15,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "name": "stderr",
118 | "output_type": "stream",
119 | "text": [
120 | "/Users/dhoomilbsheta/deepl/lib/python3.6/site-packages/tensorflow/python/util/tf_inspect.py:45: DeprecationWarning: inspect.getargspec() is deprecated, use inspect.signature() or inspect.getfullargspec()\n",
121 | " if d.decorator_argspec is not None), _inspect.getargspec(target))\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32')\n",
127 | "embed_seq = Embedding(num_words + 1, EMBED_DIM, input_length=MAX_WORDS_IN_SEQ)(\n",
128 | " input_seq)\n",
129 | "conv_1 = Conv1D(128, 5)(embed_seq)\n",
130 | "conv_1 = BatchNormalization()(conv_1)\n",
131 | "conv_1 = Activation(activation='relu')(conv_1)\n",
132 | "conv_1 = MaxPooling1D(pool_size=5)(conv_1)\n",
133 | "\n",
134 | "conv_2 = Conv1D(128, 5)(conv_1)\n",
135 | "conv_2 = BatchNormalization()(conv_2)\n",
136 | "conv_2 = Activation(activation='relu')(conv_2)\n",
137 | "conv_2 = MaxPooling1D(pool_size=5)(conv_2)\n",
138 | "\n",
139 | "conv_3 = Conv1D(128, 5)(conv_2)\n",
140 | "conv_3 = BatchNormalization()(conv_3)\n",
141 | "conv_3 = Activation(activation='relu')(conv_3)\n",
142 | "conv_3 = MaxPooling1D(pool_size=35)(conv_3)\n",
143 | "\n",
144 | "flat = Flatten()(conv_3)\n",
145 | "flat = Dropout(0.25)(flat)\n",
146 | "fc1 = Dense(128, activation='relu')(flat)\n",
147 | "fc1 = Dropout(0.25)(fc1)\n",
148 | "fc2 = Dense(2, activation='softmax')(fc1)\n",
149 | "\n",
150 | "model = Model(input_seq, fc2)\n",
151 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "name": "stderr",
161 | "output_type": "stream",
162 | "text": [
163 | "/Users/dhoomilbsheta/deepl/lib/python3.6/site-packages/tensorflow/python/util/tf_inspect.py:45: DeprecationWarning: inspect.getargspec() is deprecated, use inspect.signature() or inspect.getfullargspec()\n",
164 | " if d.decorator_argspec is not None), _inspect.getargspec(target))\n"
165 | ]
166 | },
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "Train on 25287 samples, validate on 8429 samples\n",
172 | "Epoch 1/5\n",
173 | "25287/25287 [==============================] - 1085s - loss: 0.1626 - acc: 0.9380 - val_loss: 0.9829 - val_acc: 0.5105\n",
174 | "Epoch 2/5\n",
175 | "25287/25287 [==============================] - 1086s - loss: 0.1038 - acc: 0.9605 - val_loss: 0.2687 - val_acc: 0.8602\n",
176 | "Epoch 3/5\n",
177 | "25287/25287 [==============================] - 1083s - loss: 0.0777 - acc: 0.9713 - val_loss: 0.1434 - val_acc: 0.9470\n",
178 | "Epoch 4/5\n",
179 | "25287/25287 [==============================] - 1054s - loss: 0.0598 - acc: 0.9786 - val_loss: 0.9278 - val_acc: 0.7392\n",
180 | "Epoch 5/5\n",
181 | "25216/25287 [============================>.] - ETA: 2s - loss: 0.0395 - acc: 0.9864"
182 | ]
183 | }
184 | ],
185 | "source": [
186 | "model = load_model(MODEL_PATH)\n",
187 | "model.fit(\n",
188 | " x_train,\n",
189 | " y_train,\n",
190 | " batch_size=128,\n",
191 | " epochs=5,\n",
192 | " callbacks=[ModelCheckpoint(MODEL_PATH, save_best_only=True)],\n",
193 | " validation_data=[x_test, y_test]\n",
194 | ")\n",
195 | "\n",
196 | "model.save(MODEL_PATH)"
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "metadata": {},
202 | "source": [
203 | "## MXNET Implementation"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 7,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "class MxModel(gluon.HybridBlock):\n",
213 | " def __init__(self, **kwargs):\n",
214 | " super(MxModel, self).__init__(**kwargs)\n",
215 | " with self.name_scope():\n",
216 | " self.embed = gluon.nn.Embedding(input_dim=num_words + 1, output_dim=EMBED_DIM)\n",
217 | " \n",
218 | " self.conv1 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n",
219 | " self.conv2 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n",
220 | " self.conv3 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n",
221 | " \n",
222 | " self.bnorm1 = gluon.nn.BatchNorm()\n",
223 | " self.bnorm2 = gluon.nn.BatchNorm()\n",
224 | " self.bnorm3 = gluon.nn.BatchNorm()\n",
225 | " \n",
226 | " self.fc1 = gluon.nn.Dense(units=128)\n",
227 | " self.fc2 = gluon.nn.Dense(units=2)\n",
228 | " \n",
229 | " self.dropout = gluon.nn.Dropout(rate=0.25)\n",
230 | " def hybrid_forward(self, F, x, *args, **kwargs):\n",
231 | " x = self.embed(x)\n",
232 | " x = F.relu(self.bnorm1(self.conv1(x)))\n",
233 | " x = F.relu(self.bnorm2(self.conv2(x)))\n",
234 | " x = F.relu(self.bnorm3(self.conv3(x)))\n",
235 | " x = F.relu(self.dropout(self.fc1(x)))\n",
236 | " x = self.dropout(self.fc2(x))\n",
237 | " return x\n",
238 | " "
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 8,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "mx_model = MxModel()\n",
248 | "mx_model.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 9,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)\n",
258 | "trainer = gluon.Trainer(mx_model.collect_params(), 'adam', {'learning_rate': 0.001})\n",
259 | "acc = mx.metric.Accuracy()"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 10,
265 | "metadata": {},
266 | "outputs": [],
267 | "source": [
268 | "train_data = mx.io.NDArrayIter(data=x_train, label=y_train, batch_size=128, shuffle=True)\n",
269 | "test_data = mx.io.NDArrayIter(data=x_test, label=y_test, batch_size=128, shuffle=False)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 14,
275 | "metadata": {},
276 | "outputs": [],
277 | "source": [
278 | "def evaluate_accuracy(data_iterator, net):\n",
279 | " data_iterator.reset()\n",
280 | " acc_test = mx.metric.Accuracy()\n",
281 | " for batch in data_iterator:\n",
282 | " data = batch.data[0].as_in_context(ctx)\n",
283 | " label = batch.label[0].as_in_context(ctx)\n",
284 | " output = net(data)\n",
285 | " acc_test.update(preds=output, labels=label)\n",
286 | " return acc_test.get()[1]"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 13,
292 | "metadata": {
293 | "scrolled": true
294 | },
295 | "outputs": [
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "Epoch 1--------------\n",
301 | "loss: 0.7064136266708374 acc:0.5\n",
302 | "loss: 0.6858754754066467 acc:0.5\n",
303 | "loss: 0.6862083077430725 acc:0.5000386757425742\n",
304 | "loss: 0.6491576433181763 acc:0.5009054221854304\n",
305 | "val acc: 0.5065104166666666\n",
306 | "14m 19s\n",
307 | "Epoch 2--------------\n",
308 | "loss: 0.6357454061508179 acc:0.50390625\n",
309 | "loss: 0.6397137641906738 acc:0.5003063725490197\n",
310 | "loss: 0.6163472533226013 acc:0.49176206683168316\n",
311 | "loss: 0.5043906569480896 acc:0.47775248344370863\n",
312 | "val acc: nan\n",
313 | "27m 20s\n",
314 | "27m 20s\n"
315 | ]
316 | }
317 | ],
318 | "source": [
319 | "epochs = 2\n",
320 | "smoothing_constant = .01\n",
321 | "mx_model.hybridize()\n",
322 | "\n",
323 | "start = time.time()\n",
324 | "\n",
325 | "for e in range(epochs):\n",
326 | " print(f\"Epoch {e+1}--------------\")\n",
327 | " i = 0\n",
328 | " train_data.reset()\n",
329 | " for batch in train_data:\n",
330 | " data = batch.data[0].as_in_context(ctx)\n",
331 | " label = batch.label[0].as_in_context(ctx)\n",
332 | " with autograd.record():\n",
333 | " output = mx_model(data)\n",
334 | " loss = softmax_cross_entropy(output, label)\n",
335 | " loss.backward()\n",
336 | " trainer.step(data.shape[0])\n",
337 | "\n",
338 | " ##########################\n",
339 | " # Keep a moving average of the losses\n",
340 | " ##########################\n",
341 | " curr_loss = mx.nd.mean(loss).asscalar()\n",
342 | " acc.update(preds=output, labels=label)\n",
343 | " if i % 50 == 0:\n",
344 | " print(f\"loss: {curr_loss} acc:{acc.get()[1]}\")\n",
345 | " i += 1\n",
346 | " print(f\"val acc: {evaluate_accuracy(test_data, mx_model)}\")\n",
347 | " print(time_since(start))\n",
348 | " acc.reset()\n",
349 | " \n",
350 | "print(time_since(start))\n",
351 | "mx_model.save_params(\"data/mx_model\")"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "metadata": {},
358 | "outputs": [],
359 | "source": []
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 21,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": []
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "metadata": {},
372 | "outputs": [],
373 | "source": []
374 | }
375 | ],
376 | "metadata": {
377 | "kernelspec": {
378 | "display_name": "Python 3",
379 | "language": "python",
380 | "name": "python3"
381 | },
382 | "language_info": {
383 | "codemirror_mode": {
384 | "name": "ipython",
385 | "version": 3
386 | },
387 | "file_extension": ".py",
388 | "mimetype": "text/x-python",
389 | "name": "python",
390 | "nbconvert_exporter": "python",
391 | "pygments_lexer": "ipython3",
392 | "version": "3.6.1"
393 | }
394 | },
395 | "nbformat": 4,
396 | "nbformat_minor": 2
397 | }
398 |
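One thing worth flagging in the Gluon `MxModel` above: `gluon.nn.Conv1D` defaults to `layout='NCW'`, while `gluon.nn.Embedding` produces `(batch, seq_len, embed_dim)`, i.e. NWC, so the convolutions end up sliding over the embedding axis rather than the sequence. A sketch of the forward pass with the transpose that would line the two up; this is one reading of the intended architecture (it also moves the final Dropout off the output layer, matching the Keras model), not a fix present in the repo:

```python
# Sketch only: assumes the same layers defined in MxModel.__init__ above.
def hybrid_forward(self, F, x, *args, **kwargs):
    x = self.embed(x)                    # (batch, seq_len, embed_dim) = NWC
    x = F.transpose(x, axes=(0, 2, 1))   # -> (batch, embed_dim, seq_len) = NCW
    x = F.relu(self.bnorm1(self.conv1(x)))
    x = F.relu(self.bnorm2(self.conv2(x)))
    x = F.relu(self.bnorm3(self.conv3(x)))
    x = self.dropout(F.relu(self.fc1(x)))  # Dense flattens its input by default
    return self.fc2(x)                     # raw scores; SoftmaxCrossEntropyLoss applies softmax
```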
--------------------------------------------------------------------------------