├── .gitignore ├── LICENSE ├── README.md ├── data_utils.py ├── email_utils.py ├── main.py ├── spam_detect_char ├── spam_detection.ipynb └── spam_email.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Dhoomil B Sheta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spam-detection-using-deep-learning 2 | Detecting Spam Emails using CNN 3 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from keras.preprocessing import text 3 | import pickle 4 | import os 5 | from bs4 import BeautifulSoup 6 | from email.parser import Parser 7 | 8 | parser = Parser() 9 | 10 | 11 | # Load Data 12 | def process_dataset(): 13 | data = pd.read_csv("data/enron.csv") 14 | print(f"Total emails: {len(data)}") 15 | emails = data['msg'].values 16 | labels = [1 if x == "spam" else 0 for x in data['label'].values] 17 | 18 | # Pre-process Data 19 | # tokenizer = text.Tokenizer(char_level=True) 20 | # tokenizer.fit_on_texts(emails) 21 | # sequences = tokenizer.texts_to_sequences(emails) 22 | # word2index = tokenizer.word_index 23 | # num_words = len(word2index) 24 | # print(f"Found {num_words} unique tokens") 25 | alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:" 26 | char2index = {} 27 | for i, c in enumerate(alphabet): 28 | char2index[c] = i + 1 29 | 30 | sequences = [] 31 | for email in emails: 32 | seq = [] 33 | for c in email: 34 | if c in char2index: 35 | seq.append(char2index[c]) 36 | sequences.append(seq) 37 | 38 | with open("data/dataset.pkl", 'wb') as f: 39 | pickle.dump([sequences, labels, char2index], f) 40 | 41 | 42 | process_dataset() 43 | 44 | 45 | def process_email(filename): 46 | with open(filename) as f: 47 | email = parser.parse(f) 48 | cleantext = "" 49 | if email.is_multipart(): 50 | for part in email.get_payload(): 51 | soup = BeautifulSoup(part.as_string(maxheaderlen=1)) 52 | txt = soup.get_text() 53 | txt = ' '.join(txt.split()) 54 | i = txt.find("Content-Transfer-Encoding") 55 | txt = txt[i + len("Content-Transfer-Encoding"):].split(maxsplit=2)[2] 56 | cleantext += txt 57 | 58 | else: 59 | soup = BeautifulSoup(email.get_payload()) 60 | txt = soup.get_text() 61 | txt = ' '.join(txt.split()) 62 | i = txt.find("Content-Transfer-Encoding") 63 | txt = txt[i + len("Content-Transfer-Encoding"):].split(maxsplit=2)[2] 64 | cleantext += txt 65 | print(cleantext) 66 | 67 | # for filename in os.listdir("data/infy_spam_emails"): 68 | # process_email(f"data/infy_spam_emails/{filename}") 69 | -------------------------------------------------------------------------------- /email_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # FileName: Subsampling.py 3 | # Version 1.0 by Tao Ban, 2010.5.26 4 | # This function extract all the contents, ie subject and first part from the .eml file 5 | # and store it in a new file with the same name in the dst dir. 6 | 7 | import email.parser 8 | import os 9 | import stat 10 | import sys 11 | 12 | 13 | def ExtractSubPayload(filename): 14 | ''' Extract the subject and payload from the .eml file. 15 | 16 | ''' 17 | if not os.path.exists(filename): # dest path doesnot exist 18 | print("ERROR: input file does not exist:", filename) 19 | os.exit(1) 20 | fp = open(filename) 21 | msg = email.message_from_file(fp) 22 | payload = msg.get_payload() 23 | if type(payload) == type(list()): 24 | payload = payload[0] # only use the first part of payload 25 | sub = msg.get('subject') 26 | sub = str(sub) 27 | if type(payload) != type(''): 28 | payload = str(payload) 29 | 30 | return sub + payload 31 | 32 | 33 | def ExtractBodyFromDir(srcdir, dstdir): 34 | '''Extract the body information from all .eml files in the srcdir and 35 | 36 | save the file to the dstdir with the same name.''' 37 | if not os.path.exists(dstdir): # dest path doesnot exist 38 | os.makedirs(dstdir) 39 | files = os.listdir(srcdir) 40 | for file in files: 41 | srcpath = os.path.join(srcdir, file) 42 | dstpath = os.path.join(dstdir, file) 43 | src_info = os.stat(srcpath) 44 | if stat.S_ISDIR(src_info.st_mode): # for subfolders, recurse 45 | ExtractBodyFromDir(srcpath, dstpath) 46 | else: # copy the file 47 | body = ExtractSubPayload(srcpath) 48 | dstfile = open(dstpath, 'w') 49 | dstfile.write(body) 50 | dstfile.close() 51 | 52 | 53 | ################################################################### 54 | # main function start here 55 | # srcdir is the directory where the .eml are stored 56 | print('Input source directory: ') # ask for source and dest dirs 57 | srcdir = input() 58 | if not os.path.exists(srcdir): 59 | print('The source directory %s does not exist, exit...' % (srcdir)) 60 | sys.exit() 61 | # dstdir is the directory where the content .eml are stored 62 | print('Input destination directory: ') # ask for source and dest dirs 63 | dstdir = input() 64 | if not os.path.exists(dstdir): 65 | print('The destination directory is newly created.') 66 | os.makedirs(dstdir) 67 | 68 | ################################################################### 69 | ExtractBodyFromDir(srcdir, dstdir) 70 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from keras.utils import to_categorical 2 | from keras.preprocessing import sequence 3 | from mxnet import gluon 4 | from sklearn.model_selection import train_test_split 5 | from keras.models import Model, load_model 6 | from keras.layers import Conv1D, GlobalMaxPooling1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten 7 | from keras.callbacks import ModelCheckpoint 8 | import numpy as np 9 | import pickle 10 | 11 | MAX_WORDS_IN_SEQ = 1000 12 | EMBED_DIM = 100 13 | MODEL_PATH = "models/spam_detect" 14 | 15 | # Load Data 16 | with open("data/dataset.pkl", 'rb') as f: 17 | sequences, labels, word2index = pickle.load(f) 18 | 19 | num_words = len(word2index) 20 | print(f"Found {num_words} unique tokens") 21 | 22 | data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post') 23 | print(labels[:10]) 24 | labels = to_categorical(labels) 25 | print(labels[:10]) 26 | 27 | print('Shape of data tensor:', data.shape) 28 | print('Shape of label tensor:', labels.shape) 29 | 30 | x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2) 31 | 32 | # Building the model 33 | input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32') 34 | embed_seq = Embedding(num_words + 1, EMBED_DIM, embeddings_initializer='glorot_normal', input_length=MAX_WORDS_IN_SEQ)( 35 | input_seq) 36 | conv_1 = Conv1D(128, 5, activation='relu')(embed_seq) 37 | conv_1 = MaxPooling1D(pool_size=5)(conv_1) 38 | conv_2 = Conv1D(128, 5, activation='relu')(conv_1) 39 | conv_2 = MaxPooling1D(pool_size=5)(conv_2) 40 | conv_3 = Conv1D(128, 5, activation='relu')(conv_2) 41 | conv_3 = MaxPooling1D(pool_size=35)(conv_3) 42 | flat = Flatten()(conv_3) 43 | flat = Dropout(0.25)(flat) 44 | fc1 = Dense(128, activation='relu')(flat) 45 | dense_1 = Dropout(0.25)(flat) 46 | fc2 = Dense(2, activation='softmax')(fc1) 47 | 48 | model = Model(input_seq, fc2) 49 | model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc']) 50 | 51 | # Train the model 52 | model.fit( 53 | x_train, 54 | y_train, 55 | batch_size=128, 56 | epochs=2, 57 | callbacks=[ModelCheckpoint(MODEL_PATH, save_best_only=True)], 58 | validation_data=[x_test, y_test] 59 | ) 60 | 61 | model.save(MODEL_PATH) 62 | 63 | 64 | class CnnClassifierModel(gluon.HybridBlock): 65 | def __init__(self, **kwargs): 66 | super(CnnClassifierModel, self).__init__(**kwargs) 67 | with self.name_scope(): 68 | self.conv1 = gluon.nn.Conv1D() 69 | 70 | def hybrid_forward(self, F, x, *args, **kwargs): 71 | pass 72 | -------------------------------------------------------------------------------- /spam_detect_char: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbsheta/spam-detection-using-deep-learning/872b5086bfa72f2e6a901924c875cf76fe5b7cdb/spam_detect_char -------------------------------------------------------------------------------- /spam_detection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Spam Detection Using CNN" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "Using TensorFlow backend.\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "from keras.utils import to_categorical\n", 25 | "from keras.preprocessing import sequence, text\n", 26 | "from sklearn.model_selection import train_test_split\n", 27 | "from keras.models import Model, load_model\n", 28 | "from keras.layers import Conv1D, GlobalMaxPooling1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten\n", 29 | "from keras.callbacks import ModelCheckpoint\n", 30 | "import numpy as np\n", 31 | "import pandas as pd\n", 32 | "\n", 33 | "MAX_WORDS_IN_SEQ = 1000\n", 34 | "EMBED_DIM = 100\n", 35 | "MODEL_NAME = \"/model/spam_detect\"" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Load Data" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Total emails: 33716\n" 55 | ] 56 | }, 57 | { 58 | "data": { 59 | "text/html": [ 60 | "
\n", 61 | "\n", 74 | "\n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | "
labelindexmsgdatasetfile
0spam0Subject: dobmeos with hgh my energy level has ...1enron1/spam/0006.2003-12-18.GP.spam.txt
1spam1Subject: your prescription is ready . . oxwq s...1enron1/spam/0008.2003-12-18.GP.spam.txt
2ham2Subject: christmas tree farm pictures1enron1/ham/0001.1999-12-10.farmer.ham.txt
3ham3Subject: vastar resources , inc .gary , produc...1enron1/ham/0002.1999-12-13.farmer.ham.txt
4ham4Subject: calpine daily gas nomination- calpine...1enron1/ham/0003.1999-12-14.farmer.ham.txt
\n", 128 | "
" 129 | ], 130 | "text/plain": [ 131 | " label index msg dataset \\\n", 132 | "0 spam 0 Subject: dobmeos with hgh my energy level has ... 1 \n", 133 | "1 spam 1 Subject: your prescription is ready . . oxwq s... 1 \n", 134 | "2 ham 2 Subject: christmas tree farm pictures 1 \n", 135 | "3 ham 3 Subject: vastar resources , inc .gary , produc... 1 \n", 136 | "4 ham 4 Subject: calpine daily gas nomination- calpine... 1 \n", 137 | "\n", 138 | " file \n", 139 | "0 enron1/spam/0006.2003-12-18.GP.spam.txt \n", 140 | "1 enron1/spam/0008.2003-12-18.GP.spam.txt \n", 141 | "2 enron1/ham/0001.1999-12-10.farmer.ham.txt \n", 142 | "3 enron1/ham/0002.1999-12-13.farmer.ham.txt \n", 143 | "4 enron1/ham/0003.1999-12-14.farmer.ham.txt " 144 | ] 145 | }, 146 | "execution_count": 3, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "data = pd.read_csv(\"~/Development/datasets/enron.csv\")\n", 153 | "print(f\"Total emails: {len(data)}\")\n", 154 | "data.head()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "emails = data['msg'].values\n", 164 | "labels = [1 if x == \"spam\" else 0 for x in data['label'].values]" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 5, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "text/plain": [ 175 | "226609" 176 | ] 177 | }, 178 | "execution_count": 5, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "max_len = max(map(lambda x: len(x), emails))\n", 185 | "max_len" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "## Pre-Process Data" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 89, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "Found 309362 unique tokens\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "tokenizer = text.Tokenizer()\n", 210 | "tokenizer.fit_on_texts(emails)\n", 211 | "sequences = tokenizer.texts_to_sequences(emails)\n", 212 | "word2index = tokenizer.word_index\n", 213 | "num_words = len(word2index)\n", 214 | "print(f\"Found {num_words} unique tokens\")\n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 90, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "[1, 1, 0, 0, 0, 0, 0, 1, 0, 1]\n", 227 | "[[ 0. 1.]\n", 228 | " [ 0. 1.]\n", 229 | " [ 1. 0.]\n", 230 | " [ 1. 0.]\n", 231 | " [ 1. 0.]\n", 232 | " [ 1. 0.]\n", 233 | " [ 1. 0.]\n", 234 | " [ 0. 1.]\n", 235 | " [ 1. 0.]\n", 236 | " [ 0. 1.]]\n", 237 | "Shape of data tensor: (33716, 1000)\n", 238 | "Shape of label tensor: (33716, 2)\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post')\n", 244 | "print(labels[:10])\n", 245 | "labels = to_categorical(labels)\n", 246 | "print(labels[:10])\n", 247 | "\n", 248 | "print('Shape of data tensor:', data.shape)\n", 249 | "print('Shape of label tensor:', labels.shape)\n", 250 | "\n", 251 | "x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "## Building the Model: Basic CNN" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 91, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32')\n", 268 | "embed_seq = Embedding(num_words, EMBED_DIM, embeddings_initializer='glorot_uniform', input_length=MAX_WORDS_IN_SEQ)(\n", 269 | " input_seq)\n", 270 | "conv_1 = Conv1D(128, 5, activation='relu')(embed_seq)\n", 271 | "conv_1 = MaxPooling1D(pool_size=5)(conv_1)\n", 272 | "conv_2 = Conv1D(128, 5, activation='relu')(conv_1)\n", 273 | "conv_2 = MaxPooling1D(pool_size=5)(conv_2)\n", 274 | "conv_3 = Conv1D(128, 5, activation='relu')(conv_2)\n", 275 | "conv_3 = MaxPooling1D(pool_size=35)(conv_3)\n", 276 | "flat = Flatten()(conv_3)\n", 277 | "# flat = Dropout(0.25)(flat)\n", 278 | "fc1 = Dense(128, activation='relu')(flat)\n", 279 | "# dense_1 = Dropout(0.25)(flat)\n", 280 | "fc2 = Dense(2, activation='softmax')(fc1)\n", 281 | "\n", 282 | "model = Model(input_seq, fc2)\n", 283 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 92, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "_________________________________________________________________\n", 296 | "Layer (type) Output Shape Param # \n", 297 | "=================================================================\n", 298 | "input_8 (InputLayer) (None, 1000) 0 \n", 299 | "_________________________________________________________________\n", 300 | "embedding_7 (Embedding) (None, 1000, 100) 30936200 \n", 301 | "_________________________________________________________________\n", 302 | "conv1d_19 (Conv1D) (None, 996, 128) 64128 \n", 303 | "_________________________________________________________________\n", 304 | "max_pooling1d_19 (MaxPooling (None, 199, 128) 0 \n", 305 | "_________________________________________________________________\n", 306 | "conv1d_20 (Conv1D) (None, 195, 128) 82048 \n", 307 | "_________________________________________________________________\n", 308 | "max_pooling1d_20 (MaxPooling (None, 39, 128) 0 \n", 309 | "_________________________________________________________________\n", 310 | "conv1d_21 (Conv1D) (None, 35, 128) 82048 \n", 311 | "_________________________________________________________________\n", 312 | "max_pooling1d_21 (MaxPooling (None, 1, 128) 0 \n", 313 | "_________________________________________________________________\n", 314 | "flatten_7 (Flatten) (None, 128) 0 \n", 315 | "_________________________________________________________________\n", 316 | "dense_13 (Dense) (None, 128) 16512 \n", 317 | "_________________________________________________________________\n", 318 | "dense_14 (Dense) (None, 2) 258 \n", 319 | "=================================================================\n", 320 | "Total params: 31,181,194\n", 321 | "Trainable params: 31,181,194\n", 322 | "Non-trainable params: 0\n", 323 | "_________________________________________________________________\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "# Testing ---------------------------------------\n", 329 | "model.summary()" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "model.fit()" 339 | ] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python 3", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.6.1" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 1 363 | } 364 | -------------------------------------------------------------------------------- /spam_email.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "from keras.preprocessing import sequence\n", 18 | "from keras.utils import to_categorical\n", 19 | "from keras.models import Model, load_model\n", 20 | "from keras.layers import Conv1D, Dropout, Dense, Input, Embedding, MaxPooling1D, Flatten, BatchNormalization, Activation\n", 21 | "from keras.callbacks import ModelCheckpoint\n", 22 | "from sklearn.model_selection import train_test_split\n", 23 | "\n", 24 | "import mxnet as mx\n", 25 | "from mxnet import gluon\n", 26 | "from mxnet import autograd\n", 27 | "\n", 28 | "import pickle\n", 29 | "import numpy as np\n", 30 | "import time\n", 31 | "import math" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def time_since(start):\n", 41 | " now = time.time()\n", 42 | " s = now - start\n", 43 | " m = math.floor(s / 60)\n", 44 | " s -= m * 60\n", 45 | " return '%dm %ds' % (m, s)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 20, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "MAX_WORDS_IN_SEQ = 3000\n", 55 | "EMBED_DIM = 32\n", 56 | "MODEL_PATH = \"model/spam_detect_char\"\n", 57 | "ctx = mx.cpu()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Found 43 unique tokens\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "with open(\"data/dataset.pkl\", 'rb') as f:\n", 75 | " sequences, labels, word2index = pickle.load(f)\n", 76 | " \n", 77 | "num_words = len(word2index)\n", 78 | "print(f\"Found {num_words} unique tokens\")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "data = sequence.pad_sequences(sequences, maxlen=MAX_WORDS_IN_SEQ, padding='post', truncating='post')\n", 88 | "targets = to_categorical(labels)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Shape of data tensor: (33716, 3000)\n", 101 | "Shape of label tensor: (33716, 2)\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "print('Shape of data tensor:', data.shape)\n", 107 | "print('Shape of label tensor:', targets.shape)\n", 108 | "x_train, x_test, y_train, y_test = train_test_split(data, targets, test_size=0.25)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 15, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stderr", 118 | "output_type": "stream", 119 | "text": [ 120 | "/Users/dhoomilbsheta/deepl/lib/python3.6/site-packages/tensorflow/python/util/tf_inspect.py:45: DeprecationWarning: inspect.getargspec() is deprecated, use inspect.signature() or inspect.getfullargspec()\n", 121 | " if d.decorator_argspec is not None), _inspect.getargspec(target))\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "input_seq = Input(shape=[MAX_WORDS_IN_SEQ, ], dtype='int32')\n", 127 | "embed_seq = Embedding(num_words + 1, EMBED_DIM, input_length=MAX_WORDS_IN_SEQ)(\n", 128 | " input_seq)\n", 129 | "conv_1 = Conv1D(128, 5)(embed_seq)\n", 130 | "conv_1 = BatchNormalization()(conv_1)\n", 131 | "conv_1 = Activation(activation='relu')(conv_1)\n", 132 | "conv_1 = MaxPooling1D(pool_size=5)(conv_1)\n", 133 | "\n", 134 | "conv_2 = Conv1D(128, 5)(conv_1)\n", 135 | "conv_2 = BatchNormalization()(conv_2)\n", 136 | "conv_2 = Activation(activation='relu')(conv_2)\n", 137 | "conv_2 = MaxPooling1D(pool_size=5)(conv_2)\n", 138 | "\n", 139 | "conv_3 = Conv1D(128, 5)(conv_2)\n", 140 | "conv_3 = BatchNormalization()(conv_3)\n", 141 | "conv_3 = Activation(activation='relu')(conv_3)\n", 142 | "conv_3 = MaxPooling1D(pool_size=35)(conv_3)\n", 143 | "\n", 144 | "flat = Flatten()(conv_3)\n", 145 | "flat = Dropout(0.25)(flat)\n", 146 | "fc1 = Dense(128, activation='relu')(flat)\n", 147 | "dense_1 = Dropout(0.25)(flat)\n", 148 | "fc2 = Dense(2, activation='softmax')(fc1)\n", 149 | "\n", 150 | "model = Model(input_seq, fc2)\n", 151 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stderr", 161 | "output_type": "stream", 162 | "text": [ 163 | "/Users/dhoomilbsheta/deepl/lib/python3.6/site-packages/tensorflow/python/util/tf_inspect.py:45: DeprecationWarning: inspect.getargspec() is deprecated, use inspect.signature() or inspect.getfullargspec()\n", 164 | " if d.decorator_argspec is not None), _inspect.getargspec(target))\n" 165 | ] 166 | }, 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Train on 25287 samples, validate on 8429 samples\n", 172 | "Epoch 1/5\n", 173 | "25287/25287 [==============================] - 1085s - loss: 0.1626 - acc: 0.9380 - val_loss: 0.9829 - val_acc: 0.5105\n", 174 | "Epoch 2/5\n", 175 | "25287/25287 [==============================] - 1086s - loss: 0.1038 - acc: 0.9605 - val_loss: 0.2687 - val_acc: 0.8602\n", 176 | "Epoch 3/5\n", 177 | "25287/25287 [==============================] - 1083s - loss: 0.0777 - acc: 0.9713 - val_loss: 0.1434 - val_acc: 0.9470\n", 178 | "Epoch 4/5\n", 179 | "25287/25287 [==============================] - 1054s - loss: 0.0598 - acc: 0.9786 - val_loss: 0.9278 - val_acc: 0.7392\n", 180 | "Epoch 5/5\n", 181 | "25216/25287 [============================>.] - ETA: 2s - loss: 0.0395 - acc: 0.9864" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "model = load_model(MODEL_PATH)\n", 187 | "model.fit(\n", 188 | " x_train,\n", 189 | " y_train,\n", 190 | " batch_size=128,\n", 191 | " epochs=5,\n", 192 | " callbacks=[ModelCheckpoint(MODEL_PATH, save_best_only=True)],\n", 193 | " validation_data=[x_test, y_test]\n", 194 | ")\n", 195 | "\n", 196 | "model.save(MODEL_PATH)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## MXNET Implementation" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "class MxModel(gluon.HybridBlock):\n", 213 | " def __init__(self, **kwargs):\n", 214 | " super(MxModel, self).__init__(**kwargs)\n", 215 | " with self.name_scope():\n", 216 | " self.embed = gluon.nn.Embedding(input_dim=num_words + 1, output_dim=EMBED_DIM)\n", 217 | " \n", 218 | " self.conv1 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n", 219 | " self.conv2 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n", 220 | " self.conv3 = gluon.nn.Conv1D(channels=128, kernel_size=5)\n", 221 | " \n", 222 | " self.bnorm1 = gluon.nn.BatchNorm()\n", 223 | " self.bnorm2 = gluon.nn.BatchNorm()\n", 224 | " self.bnorm3 = gluon.nn.BatchNorm()\n", 225 | " \n", 226 | " self.fc1 = gluon.nn.Dense(units=128)\n", 227 | " self.fc2 = gluon.nn.Dense(units=2)\n", 228 | " \n", 229 | " self.dropout = gluon.nn.Dropout(rate=0.25)\n", 230 | " def hybrid_forward(self, F, x, *args, **kwargs):\n", 231 | " x = self.embed(x)\n", 232 | " x = F.relu(self.bnorm1(self.conv1(x)))\n", 233 | " x = F.relu(self.bnorm2(self.conv2(x)))\n", 234 | " x = F.relu(self.bnorm3(self.conv3(x)))\n", 235 | " x = F.relu(self.dropout(self.fc1(x)))\n", 236 | " x = self.dropout(self.fc2(x))\n", 237 | " return x\n", 238 | " " 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 8, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "mx_model = MxModel()\n", 248 | "mx_model.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 9, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)\n", 258 | "trainer = gluon.Trainer(mx_model.collect_params(), 'adam', {'learning_rate': 0.001})\n", 259 | "acc = mx.metric.Accuracy()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "train_data = mx.io.NDArrayIter(data=x_train, label=y_train, batch_size=128, shuffle=True)\n", 269 | "test_data = mx.io.NDArrayIter(data=x_test, label=y_test, batch_size=128, shuffle=False)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 14, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "def evaluate_accuracy(data_iterator, net):\n", 279 | " data_iterator.reset()\n", 280 | " acc_test = mx.metric.Accuracy()\n", 281 | " for batch in data_iterator:\n", 282 | " data = batch.data[0].as_in_context(ctx)\n", 283 | " label = batch.label[0].as_in_context(ctx)\n", 284 | " output = net(data)\n", 285 | " acc_test.update(preds=output, labels=label)\n", 286 | " return acc_test.get()[1]" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 13, 292 | "metadata": { 293 | "scrolled": true 294 | }, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "Epoch 1--------------\n", 301 | "loss: 0.7064136266708374 acc:0.5\n", 302 | "loss: 0.6858754754066467 acc:0.5\n", 303 | "loss: 0.6862083077430725 acc:0.5000386757425742\n", 304 | "loss: 0.6491576433181763 acc:0.5009054221854304\n", 305 | "val acc: 0.5065104166666666\n", 306 | "14m 19s\n", 307 | "Epoch 2--------------\n", 308 | "loss: 0.6357454061508179 acc:0.50390625\n", 309 | "loss: 0.6397137641906738 acc:0.5003063725490197\n", 310 | "loss: 0.6163472533226013 acc:0.49176206683168316\n", 311 | "loss: 0.5043906569480896 acc:0.47775248344370863\n", 312 | "val acc: nan\n", 313 | "27m 20s\n", 314 | "27m 20s\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "epochs = 2\n", 320 | "smoothing_constant = .01\n", 321 | "mx_model.hybridize()\n", 322 | "\n", 323 | "start = time.time()\n", 324 | "\n", 325 | "for e in range(epochs):\n", 326 | " print(f\"Epoch {e+1}--------------\")\n", 327 | " i = 0\n", 328 | " train_data.reset()\n", 329 | " for batch in train_data:\n", 330 | " data = batch.data[0].as_in_context(ctx)\n", 331 | " label = batch.label[0].as_in_context(ctx)\n", 332 | " with autograd.record():\n", 333 | " output = mx_model(data)\n", 334 | " loss = softmax_cross_entropy(output, label)\n", 335 | " loss.backward()\n", 336 | " trainer.step(data.shape[0])\n", 337 | "\n", 338 | " ##########################\n", 339 | " # Keep a moving average of the losses\n", 340 | " ##########################\n", 341 | " curr_loss = mx.nd.mean(loss).asscalar()\n", 342 | " acc.update(preds=output, labels=label)\n", 343 | " if i % 50 == 0:\n", 344 | " print(f\"loss: {curr_loss} acc:{acc.get()[1]}\")\n", 345 | " i += 1\n", 346 | " print(f\"val acc: {evaluate_accuracy(test_data, mx_model)}\")\n", 347 | " print(time_since(start))\n", 348 | " acc.reset()\n", 349 | " \n", 350 | "print(time_since(start))\n", 351 | "mx_model.save_params(\"data/mx_model\")" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 21, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [] 374 | } 375 | ], 376 | "metadata": { 377 | "kernelspec": { 378 | "display_name": "Python 3", 379 | "language": "python", 380 | "name": "python3" 381 | }, 382 | "language_info": { 383 | "codemirror_mode": { 384 | "name": "ipython", 385 | "version": 3 386 | }, 387 | "file_extension": ".py", 388 | "mimetype": "text/x-python", 389 | "name": "python", 390 | "nbconvert_exporter": "python", 391 | "pygments_lexer": "ipython3", 392 | "version": "3.6.1" 393 | } 394 | }, 395 | "nbformat": 4, 396 | "nbformat_minor": 2 397 | } 398 | --------------------------------------------------------------------------------