├── .gitignore
├── README.md
├── notebook.ipynb
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

data/*
cooking.*
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Text classification experiment using fastText

## Goal

The goal of text classification is to assign documents (such as emails, posts or text messages) to one or more categories (review scores, spam vs. non-spam, topics, etc.). The dominant approach to building such classifiers is machine learning (ML): learning the classification rules from labeled examples instead of writing them by hand.

## Data

To build such classifiers, we need labeled data, which consists of documents and their corresponding categories. In this example, we build a classifier that automatically assigns Stack Exchange questions about cooking to tags such as `pot`, `bowl` or `baking`.

## Classification engine

In [August 2016](https://code.facebook.com/posts/1438652669495149/fair-open-sources-fasttext/), the Facebook AI Research (FAIR) lab [open-sourced](https://github.com/facebookresearch/fastText) fastText, a library designed to help build scalable solutions for text representation and classification. fastText combines some of the most successful concepts introduced by the natural language processing and machine learning communities over the last few decades, such as word embeddings, bag-of-n-grams features and linear classifiers.
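To make "text representation and classification" a little more concrete: at its core, a fastText supervised model averages the vectors of the words (and, optionally, word n-grams) in a document and feeds that average to a linear classifier with a softmax output. The toy sketch below illustrates only that forward pass in plain Python. The function name, the hand-written two-dimensional vectors and the weights are all hypothetical; this is not fastText's implementation, which also learns the vectors and weights during training and can use a hierarchical softmax.

```python
import math


def fasttext_style_scores(tokens, embeddings, weights, labels):
    """Toy forward pass (hypothetical): average word vectors, linear layer, softmax.

    embeddings: dict mapping a word to its vector (list of floats)
    weights: one row of floats per label, each row as long as a word vector
    """
    dim = len(weights[0])
    known = [embeddings[t] for t in tokens if t in embeddings]
    if known:
        # Document representation = mean of the known word vectors
        hidden = [sum(vec[i] for vec in known) / len(known) for i in range(dim)]
    else:
        # No known words: fall back to a zero vector, which gives uniform scores
        hidden = [0.0] * dim
    # One linear score (logit) per label
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in weights]
    # Numerically stable softmax over the label scores
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


# Hypothetical 2-dimensional vectors and weights, chosen by hand
embeddings = {"knife": [1.0, 0.0], "oven": [0.0, 1.0]}
weights = [[2.0, 0.0], [0.0, 2.0]]
labels = ["knives", "baking"]

scores = fasttext_style_scores(["sharp", "knife"], embeddings, weights, labels)
print(max(scores, key=scores.get))  # knives
```

In this sketch unknown words such as `sharp` are simply skipped, and the vectors and weights are fixed by hand, whereas training would learn them from the labeled examples.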

## Setup

To run the experiment, you need a Python 3 environment and fastText.

Install fastText ([detailed instructions here](https://github.com/facebookresearch/fastText/blob/master/README.md#requirements)):

```
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ make
```

Clone the project:

```
$ git clone https://github.com/mpuig/textclassification
$ cd textclassification
```

Create a virtual environment, activate it, install the Python packages and download the NLTK English stopword list used by the notebook's preprocessing step:

```
$ python3.6 -m venv venv
$ source venv/bin/activate
$ pip install Cython
$ pip install -r requirements.txt
$ python -m nltk.downloader stopwords
```

Create the output directory for the classification models:

```
$ mkdir models
```

## Getting and preparing the data

As mentioned in the introduction, we need labeled data to train our supervised classifier. In this tutorial, we build a classifier that automatically recognizes the topic of a Stack Exchange question about cooking. Let's download example questions from [the cooking section of Stack Exchange](http://cooking.stackexchange.com/), together with their associated tags:

```
$ mkdir data
$ wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/cooking.stackexchange.tar.gz
$ tar xvzf cooking.stackexchange.tar.gz -C data
$ head data/cooking.stackexchange.txt
```

Each line of the text file contains a list of labels, followed by the corresponding document. All labels start with the `__label__` prefix, which is how fastText tells labels apart from words. The model is then trained to predict the labels given the words in the document.

Before training our first classifier, we need to split the data into a training set and a validation set. We will use the validation set to evaluate how well the learned classifier performs on new data.
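To make the `__label__` line format concrete, here is a minimal Python sketch that splits one line into its labels and its document text. The function name and the example line are hypothetical, and the sketch assumes, as described for this dataset, that all labels come first on the line; it does not reproduce fastText's own parser.

```python
LABEL_PREFIX = "__label__"


def parse_fasttext_line(line):
    """Split a fastText-format line into (labels, text). Hypothetical helper.

    Assumes every label precedes the document text.
    """
    tokens = line.split()
    labels = []
    # Consume leading tokens for as long as they carry the label prefix
    while tokens and tokens[0].startswith(LABEL_PREFIX):
        labels.append(tokens.pop(0)[len(LABEL_PREFIX):])
    return labels, " ".join(tokens)


labels, text = parse_fasttext_line(
    "__label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?\n"
)
print(labels)  # ['sauce', 'cheese']
print(text)    # How much does potato starch affect a cheese sauce recipe?
```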

```
$ wc data/cooking.stackexchange.txt
15404 169582 1401900 data/cooking.stackexchange.txt
```

Our full dataset contains 15404 examples. Let's split it into a training set of 12404 examples and a validation set of 3000 examples:

```
$ head -n 12404 data/cooking.stackexchange.txt > data/cooking.train
$ tail -n 3000 data/cooking.stackexchange.txt > data/cooking.test
```

## Run the notebook

```
$ jupyter notebook notebook.ipynb
```

Then [open the notebook in your browser](http://localhost:8888/notebooks/notebook.ipynb).

## Possible improvements

- Use [Gensim Phrases](https://radimrehurek.com/gensim/models/phrases.html#module-gensim.models.phrases) to detect multi-word expressions
- Use bigrams
- Apply ideas from [this blog post on a fastText-based hybrid recommender](https://blog.lateral.io/2016/09/fasttext-based-hybrid-recommender/)
- Apply ideas from [this tutorial on text classification with NLTK and scikit-learn](https://bbengfort.github.io/tutorials/2016/05/19/text-classification-nltk-sckit-learn.html)

## Thanks

Some ideas behind the code come from:
- reading and writing files with Python generators: [Francesco Bruni's code](https://github.com/brunifrancesco/nltk_base/blob/master/2nd.ipynb)
- text preprocessing with NLTK: [Gensim examples](https://github.com/RaRe-Technologies/gensim/blob/master/docs/notebooks/doc2vec-IMDB.ipynb)
--------------------------------------------------------------------------------

/notebook.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# fastText supervised learning example\n",
    "\n",
    "This notebook is inspired by the [Supervised Learning fastText tutorial](https://github.com/facebookresearch/fastText/blob/master/tutorials/supervised-learning.md)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def read_data(filename):\n",
    "    \"\"\"\n",
    "    Read a text file line by line, using a generator.\n",
    "    Generators make it easier to process BIG text files,\n",
    "    since the whole file never has to fit in memory.\n",
    "\n",
    "    :param filename: path of the file to read\n",
    "    \"\"\"\n",
    "    with open(filename, 'r', encoding='utf-8') as input_file:\n",
    "        for line in input_file:\n",
    "            yield line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def write_data(filename, data):\n",
    "    \"\"\"\n",
    "    Append one line of text to a file.\n",
    "\n",
    "    :param filename: path of the output file\n",
    "    :param data: the string to be written (a newline is added)\n",
    "    \"\"\"\n",
    "    with open(filename, 'a', encoding='utf-8') as output:\n",
    "        output.write('{}\\n'.format(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from string import punctuation\n",
    "from nltk.corpus import stopwords\n",
    "\n",
    "# Load the stopword list once, as a set for fast membership tests\n",
    "STOPWORDS = set(stopwords.words('english'))\n",
    "\n",
    "def preprocess(data):\n",
    "    \"\"\"\n",
    "    Preprocess data: split it into tokens, lowercase them and\n",
    "    filter out stopwords and punctuation.\n",
    "\n",
    "    :param data: the string data to be processed\n",
    "    \"\"\"\n",
    "    # Pad punctuation with spaces on both sides so it becomes separate tokens\n",
    "    for char in ['.', '\"', ',', '(', ')', '!', '?', ';', ':']:\n",
    "        data = data.replace(char, ' ' + char + ' ')\n",
    "    split_chunks = data.split()\n",
    "    lowered_chunks = (item.lower() for item in split_chunks)\n",
    "    chunks_without_punctuation = (chunk for chunk in lowered_chunks if chunk not in punctuation)\n",
    "    chunks_without_stopwords = (chunk for chunk in chunks_without_punctuation if chunk not in STOPWORDS)\n",
    "    return \" \".join(chunks_without_stopwords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "\n",
    "def pipeline(input_filename, output_filename, limit=None):\n",
    "    \"\"\"\n",
    "    Iterate over the rows and apply the text preprocessing.\n",
    "\n",
    "    :param input_filename: name of the input file\n",
    "    :param output_filename: name of the output file\n",
    "    :param limit: process only the first N rows (all rows if None)\n",
    "    \"\"\"\n",
    "    # Truncate the output file first, because write_data() appends to it\n",
    "    open(output_filename, 'w').close()\n",
    "    for row in islice(read_data(input_filename), 0, limit):\n",
    "        data = preprocess(row)\n",
    "        if data:\n",
    "            write_data(output_filename, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def test_model(model, test_data):\n",
    "    \"\"\"\n",
    "    Evaluate a model on a labeled test file and print precision and recall at 1.\n",
    "    \"\"\"\n",
    "    result = model.test(test_data)\n",
    "    print('Precision@1:', result.precision)\n",
    "    print('Recall@1:', result.recall)\n",
    "    print('Number of examples:', result.nexamples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from os import path\n",
    "\n",
    "# Data paths, relative to the directory the notebook is run from\n",
    "data_dir = 'data'\n",
    "cooking_input = path.join(data_dir, 'cooking.train')\n",
    "cooking_input_norm = path.join(data_dir, 'cooking.train_norm')\n",
    "cooking_test = path.join(data_dir, 'cooking.test')\n",
    "cooking_test_norm = path.join(data_dir, 'cooking.test_norm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pipeline(cooking_input, cooking_input_norm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using fastText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import fasttext as ft\n",
    "\n",
    "# Output paths for the trained models (fastText adds the .bin extension);\n",
    "# separate names keep the second model from overwriting the first\n",
    "model_dir = 'models'\n",
    "cooking_output = path.join(model_dir, 'cooking')\n",
    "cooking_norm_output = path.join(model_dir, 'cooking_norm')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Raw (not normalized) input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cooking_model = ft.supervised(cooking_input, cooking_output, lr=1.0, epoch=10, silent=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision@1: 0.5473333333333333\n",
      "Recall@1: 0.2367017442698573\n",
      "Number of examples: 3000\n"
     ]
    }
   ],
   "source": [
    "test_model(cooking_model, cooking_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normalized input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cooking_norm_model = ft.supervised(cooking_input_norm, cooking_norm_output, lr=1.0, epoch=10, silent=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision@1: 0.5926666666666667\n",
      "Recall@1: 0.25630676084762866\n",
      "Number of examples: 3000\n"
     ]
    }
   ],
   "source": [
    "pipeline(cooking_test, cooking_test_norm)\n",
    "test_model(cooking_norm_model, cooking_test_norm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load an existing model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Uncomment to load a model saved by ft.supervised() instead of retraining it\n",
    "# cooking_model_filename = path.join(model_dir, 'cooking.bin')\n",
    "# model = ft.load_model(cooking_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['knives']]\n"
     ]
    }
   ],
   "source": [
    "texts = ['Why not put knives in the dishwasher?']\n",
    "\n",
    "labels = cooking_model.predict(texts)\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['knives']]\n"
     ]
    }
   ],
   "source": [
    "# Note: these texts are passed as-is, without the preprocess() step used for training\n",
    "labels = cooking_norm_model.predict(texts)\n",
    "print(labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
jupyter==1.0.0
fasttext==0.8.3
nltk==3.2.4
--------------------------------------------------------------------------------