├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── data └── tweets.csv ├── img └── cnn-text-classification.jpg ├── main.py └── src ├── __init__.py ├── model ├── __init__.py ├── model.py └── run.py ├── parameters ├── __init__.py └── parameters.py └── preprocessing ├── __init__.py └── preprocessing.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Fernando 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | numpy = "*" 10 | torch = "*" 11 | sklearn = "*" 12 | pandas = "*" 13 | nltk = "*" 14 | 15 | [requires] 16 | python_version = "3.7" 17 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "fb1c598019717c48cd36fc80879942cf73024f9358e44c0b6a328be357811549" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "click": { 20 | "hashes": [ 21 | "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a", 22 | "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc" 23 | ], 24 | "version": "==7.1.2" 25 | }, 26 | "future": { 27 | "hashes": [ 28 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d" 29 | ], 30 | "version": "==0.18.2" 31 | }, 32 | "joblib": { 33 | "hashes": [ 34 | "sha256:8f52bf24c64b608bf0b2563e0e47d6fcf516abc8cfafe10cfd98ad66d94f92d6", 35 | "sha256:d348c5d4ae31496b2aa060d6d9b787864dd204f9480baaa52d18850cb43e9f49" 36 | ], 37 | "version": "==0.16.0" 38 | }, 39 | "nltk": { 40 | "hashes": [ 41 | "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35" 42 | ], 43 | "index": "pypi", 44 | "version": "==3.5" 45 | }, 46 | "numpy": { 47 | "hashes": [ 48 | "sha256:082f8d4dd69b6b688f64f509b91d482362124986d98dc7dc5f5e9f9b9c3bb983", 49 | "sha256:1bc0145999e8cb8aed9d4e65dd8b139adf1919e521177f198529687dbf613065", 50 | 
"sha256:309cbcfaa103fc9a33ec16d2d62569d541b79f828c382556ff072442226d1968", 51 | "sha256:3673c8b2b29077f1b7b3a848794f8e11f401ba0b71c49fbd26fb40b71788b132", 52 | "sha256:480fdd4dbda4dd6b638d3863da3be82873bba6d32d1fc12ea1b8486ac7b8d129", 53 | "sha256:56ef7f56470c24bb67fb43dae442e946a6ce172f97c69f8d067ff8550cf782ff", 54 | "sha256:5a936fd51049541d86ccdeef2833cc89a18e4d3808fe58a8abeb802665c5af93", 55 | "sha256:5b6885c12784a27e957294b60f97e8b5b4174c7504665333c5e94fbf41ae5d6a", 56 | "sha256:667c07063940e934287993366ad5f56766bc009017b4a0fe91dbd07960d0aba7", 57 | "sha256:7ed448ff4eaffeb01094959b19cbaf998ecdee9ef9932381420d514e446601cd", 58 | "sha256:8343bf67c72e09cfabfab55ad4a43ce3f6bf6e6ced7acf70f45ded9ebb425055", 59 | "sha256:92feb989b47f83ebef246adabc7ff3b9a59ac30601c3f6819f8913458610bdcc", 60 | "sha256:935c27ae2760c21cd7354402546f6be21d3d0c806fffe967f745d5f2de5005a7", 61 | "sha256:aaf42a04b472d12515debc621c31cf16c215e332242e7a9f56403d814c744624", 62 | "sha256:b12e639378c741add21fbffd16ba5ad25c0a1a17cf2b6fe4288feeb65144f35b", 63 | "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69", 64 | "sha256:b8456987b637232602ceb4d663cb34106f7eb780e247d51a260b84760fd8f491", 65 | "sha256:b9792b0ac0130b277536ab8944e7b754c69560dac0415dd4b2dbd16b902c8954", 66 | "sha256:c9591886fc9cbe5532d5df85cb8e0cc3b44ba8ce4367bd4cf1b93dc19713da72", 67 | "sha256:cf1347450c0b7644ea142712619533553f02ef23f92f781312f6a3553d031fc7", 68 | "sha256:de8b4a9b56255797cbddb93281ed92acbc510fb7b15df3f01bd28f46ebc4edae", 69 | "sha256:e1b1dc0372f530f26a03578ac75d5e51b3868b9b76cd2facba4c9ee0eb252ab1", 70 | "sha256:e45f8e981a0ab47103181773cc0a54e650b2aef8c7b6cd07405d0fa8d869444a", 71 | "sha256:e4f6d3c53911a9d103d8ec9518190e52a8b945bab021745af4939cfc7c0d4a9e", 72 | "sha256:ed8a311493cf5480a2ebc597d1e177231984c818a86875126cfd004241a73c3e", 73 | "sha256:ef71a1d4fd4858596ae80ad1ec76404ad29701f8ca7cdcebc50300178db14dfc" 74 | ], 75 | "index": "pypi", 76 | "version": "==1.19.1" 77 | }, 78 | "pandas": { 79 | "hashes": [ 80 | "sha256:01b1e536eb960822c5e6b58357cad8c4b492a336f4a5630bf0b598566462a578", 81 | "sha256:0246c67cbaaaac8d25fed8d4cf2d8897bd858f0e540e8528a75281cee9ac516d", 82 | "sha256:0366150fe8ee37ef89a45d3093e05026b5f895e42bbce3902ce3b6427f1b8471", 83 | "sha256:16ae070c47474008769fc443ac765ffd88c3506b4a82966e7a605592978896f9", 84 | "sha256:1acc2bd7fc95e5408a4456897c2c2a1ae7c6acefe108d90479ab6d98d34fcc3d", 85 | "sha256:391db82ebeb886143b96b9c6c6166686c9a272d00020e4e39ad63b792542d9e2", 86 | "sha256:41675323d4fcdd15abde068607cad150dfe17f7d32290ee128e5fea98442bd09", 87 | "sha256:53328284a7bb046e2e885fd1b8c078bd896d7fc4575b915d4936f54984a2ba67", 88 | "sha256:57c5f6be49259cde8e6f71c2bf240a26b071569cabc04c751358495d09419e56", 89 | "sha256:84c101d0f7bbf0d9f1be9a2f29f6fcc12415442558d067164e50a56edfb732b4", 90 | "sha256:88930c74f69e97b17703600233c0eaf1f4f4dd10c14633d522724c5c1b963ec4", 91 | "sha256:8c9ec12c480c4d915e23ee9c8a2d8eba8509986f35f307771045c1294a2e5b73", 92 | "sha256:a81c4bf9c59010aa3efddbb6b9fc84a9b76dc0b4da2c2c2d50f06a9ef6ac0004", 93 | "sha256:d9644ac996149b2a51325d48d77e25c911e01aa6d39dc1b64be679cd71f683ec", 94 | "sha256:e4b6c98f45695799990da328e6fd7d6187be32752ed64c2f22326ad66762d179", 95 | "sha256:fe6f1623376b616e03d51f0dd95afd862cf9a33c18cf55ce0ed4bbe1c4444391" 96 | ], 97 | "index": "pypi", 98 | "version": "==1.1.1" 99 | }, 100 | "python-dateutil": { 101 | "hashes": [ 102 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", 103 | 
"sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" 104 | ], 105 | "version": "==2.8.1" 106 | }, 107 | "pytz": { 108 | "hashes": [ 109 | "sha256:a494d53b6d39c3c6e44c3bec237336e14305e4f29bbf800b599253057fbb79ed", 110 | "sha256:c35965d010ce31b23eeb663ed3cc8c906275d6be1a34393a1d73a41febf4a048" 111 | ], 112 | "version": "==2020.1" 113 | }, 114 | "regex": { 115 | "hashes": [ 116 | "sha256:0dc64ee3f33cd7899f79a8d788abfbec168410be356ed9bd30bbd3f0a23a7204", 117 | "sha256:1269fef3167bb52631ad4fa7dd27bf635d5a0790b8e6222065d42e91bede4162", 118 | "sha256:14a53646369157baa0499513f96091eb70382eb50b2c82393d17d7ec81b7b85f", 119 | "sha256:3a3af27a8d23143c49a3420efe5b3f8cf1a48c6fc8bc6856b03f638abc1833bb", 120 | "sha256:46bac5ca10fb748d6c55843a931855e2727a7a22584f302dd9bb1506e69f83f6", 121 | "sha256:4c037fd14c5f4e308b8370b447b469ca10e69427966527edcab07f52d88388f7", 122 | "sha256:51178c738d559a2d1071ce0b0f56e57eb315bcf8f7d4cf127674b533e3101f88", 123 | "sha256:5ea81ea3dbd6767873c611687141ec7b06ed8bab43f68fad5b7be184a920dc99", 124 | "sha256:6961548bba529cac7c07af2fd4d527c5b91bb8fe18995fed6044ac22b3d14644", 125 | "sha256:75aaa27aa521a182824d89e5ab0a1d16ca207318a6b65042b046053cfc8ed07a", 126 | "sha256:7a2dd66d2d4df34fa82c9dc85657c5e019b87932019947faece7983f2089a840", 127 | "sha256:8a51f2c6d1f884e98846a0a9021ff6861bdb98457879f412fdc2b42d14494067", 128 | "sha256:9c568495e35599625f7b999774e29e8d6b01a6fb684d77dee1f56d41b11b40cd", 129 | "sha256:9eddaafb3c48e0900690c1727fba226c4804b8e6127ea409689c3bb492d06de4", 130 | "sha256:bbb332d45b32df41200380fff14712cb6093b61bd142272a10b16778c418e98e", 131 | "sha256:bc3d98f621898b4a9bc7fecc00513eec8f40b5b83913d74ccb445f037d58cd89", 132 | "sha256:c11d6033115dc4887c456565303f540c44197f4fc1a2bfb192224a301534888e", 133 | "sha256:c50a724d136ec10d920661f1442e4a8b010a4fe5aebd65e0c2241ea41dbe93dc", 134 | "sha256:d0a5095d52b90ff38592bbdc2644f17c6d495762edf47d876049cfd2968fbccf", 135 | "sha256:d6cff2276e502b86a25fd10c2a96973fdb45c7a977dca2138d661417f3728341", 136 | "sha256:e46d13f38cfcbb79bfdb2964b0fe12561fe633caf964a77a5f8d4e45fe5d2ef7" 137 | ], 138 | "version": "==2020.7.14" 139 | }, 140 | "scikit-learn": { 141 | "hashes": [ 142 | "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca", 143 | "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc", 144 | "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea", 145 | "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480", 146 | "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d", 147 | "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2", 148 | "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d", 149 | "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab", 150 | "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b", 151 | "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de", 152 | "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25", 153 | "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028", 154 | "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce", 155 | "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec", 156 | "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10", 157 | "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8" 158 | ], 159 | "version": "==0.23.2" 160 | }, 161 | "scipy": { 162 | 
"hashes": [ 163 | "sha256:066c513d90eb3fd7567a9e150828d39111ebd88d3e924cdfc9f8ce19ab6f90c9", 164 | "sha256:07e52b316b40a4f001667d1ad4eb5f2318738de34597bd91537851365b6c61f1", 165 | "sha256:0a0e9a4e58a4734c2eba917f834b25b7e3b6dc333901ce7784fd31aefbd37b2f", 166 | "sha256:1c7564a4810c1cd77fcdee7fa726d7d39d4e2695ad252d7c86c3ea9d85b7fb8f", 167 | "sha256:315aa2165aca31375f4e26c230188db192ed901761390be908c9b21d8b07df62", 168 | "sha256:6e86c873fe1335d88b7a4bfa09d021f27a9e753758fd75f3f92d714aa4093768", 169 | "sha256:8e28e74b97fc8d6aa0454989db3b5d36fc27e69cef39a7ee5eaf8174ca1123cb", 170 | "sha256:92eb04041d371fea828858e4fff182453c25ae3eaa8782d9b6c32b25857d23bc", 171 | "sha256:a0afbb967fd2c98efad5f4c24439a640d39463282040a88e8e928db647d8ac3d", 172 | "sha256:a785409c0fa51764766840185a34f96a0a93527a0ff0230484d33a8ed085c8f8", 173 | "sha256:cca9fce15109a36a0a9f9cfc64f870f1c140cb235ddf27fe0328e6afb44dfed0", 174 | "sha256:d56b10d8ed72ec1be76bf10508446df60954f08a41c2d40778bc29a3a9ad9bce", 175 | "sha256:dac09281a0eacd59974e24525a3bc90fa39b4e95177e638a31b14db60d3fa806", 176 | "sha256:ec5fe57e46828d034775b00cd625c4a7b5c7d2e354c3b258d820c6c72212a6ec", 177 | "sha256:eecf40fa87eeda53e8e11d265ff2254729d04000cd40bae648e76ff268885d66", 178 | "sha256:fc98f3eac993b9bfdd392e675dfe19850cc8c7246a8fd2b42443e506344be7d9" 179 | ], 180 | "version": "==1.5.2" 181 | }, 182 | "six": { 183 | "hashes": [ 184 | "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259", 185 | "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced" 186 | ], 187 | "version": "==1.15.0" 188 | }, 189 | "sklearn": { 190 | "hashes": [ 191 | "sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31" 192 | ], 193 | "index": "pypi", 194 | "version": "==0.0" 195 | }, 196 | "threadpoolctl": { 197 | "hashes": [ 198 | "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725", 199 | "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b" 200 | ], 201 | "version": "==2.1.0" 202 | }, 203 | "torch": { 204 | "hashes": [ 205 | "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5", 206 | "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b", 207 | "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c", 208 | "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8", 209 | "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76", 210 | "sha256:8d66053ed53574b15b4ecddd4f9ca0f3868fe792088be3e728bac36fa998f7ce" 211 | ], 212 | "index": "pypi", 213 | "version": "==1.6.0" 214 | }, 215 | "tqdm": { 216 | "hashes": [ 217 | "sha256:1a336d2b829be50e46b84668691e0a2719f26c97c62846298dd5ae2937e4d5cf", 218 | "sha256:564d632ea2b9cb52979f7956e093e831c28d441c11751682f84c86fc46e4fd21" 219 | ], 220 | "version": "==4.48.2" 221 | } 222 | }, 223 | "develop": {} 224 | } 225 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 9 | [![Medium][medium-shield]][medium-url] 10 | [![Twitter][twitter-shield]][twitter-url] 11 | [![Linkedin][linkedin-shield]][linkedin-url] 12 | 13 | # Text Classification with CNNs in PyTorch 14 | The aim of this repository is to show a baseline model for text classification through convolutional neural networks in the PyTorch framework. 
The architecture implemented in this model was inspired by the one proposed in the paper: Convolutional Neural Networks for Sentence Classification. 15 | 16 | If you want to understand the details about how this model was created, take a look at this clear and detailed explanation: Text Classification with CNNs in PyTorch 17 | 18 | 19 | ## Table of Contents 20 | 21 | * [The model](#the-model) 22 | * [Files](#files) 23 | * [How to use](#how-to-use) 24 | * [Contributing](#contributing) 25 | * [Contact](#contact) 26 | * [License](#license) 27 | 28 | 29 | ## 1. The model 30 | The model is composed of 4 convolutional layers, each of which produces 32 filters. The output of each convolution is passed through a ``max pooling`` layer, the four pooled outputs are concatenated, and the resulting vector is finally passed through a fully connected layer. The following image describes the model architecture: 31 | 32 |

33 | ![CNN text classification model architecture](img/cnn-text-classification.jpg) 34 |
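As a quick, illustrative sanity check of the architecture described above (this snippet is not part of the repository; it assumes you run it from the project root inside the ``pipenv`` environment), you can instantiate the classifier with the default ``Parameters`` and push a dummy batch of token indices through it:

```PY
import torch

from src import Parameters, TextClassifier

# Build the model from the default hyperparameters
model = TextClassifier(Parameters)

# Dummy batch: 4 "sentences" of seq_len random token indices (index 0 is reserved for padding)
dummy_batch = torch.randint(1, Parameters.num_words, (4, Parameters.seq_len))

with torch.no_grad():
    scores = model(dummy_batch)

print(scores.shape)  # torch.Size([4]) -> one sigmoid score per sentence
```

Each entry of ``scores`` is the sigmoid output that the training loop in ``src/model/run.py`` later thresholds at 0.5 when computing accuracy.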

35 | 36 | 37 | ## 2. Files 38 | * **Pipfile**: Here you will find the dependencies that the model needs in order to run. 39 | 40 | * **main.py**: It contains the controller that runs the pipelines (preprocessing and training). 41 | 42 | * **src**: It contains three directories: ``model``, ``parameters`` and ``preprocessing``. 43 | 44 | * **src/model**: It contains two files, ``model.py`` and ``run.py``, which handle the model definition and the training/evaluation phase, respectively. 45 | 46 | * **src/parameters**: It contains a ``dataclass`` which stores the parameters used to preprocess the text and to define and train the model. 47 | 48 | * **src/preprocessing**: It contains the functions implemented to load, clean and tokenize the text. 49 | 50 | * **data**: It contains the data used to train the depicted model. 51 | 52 | 53 | ## 3. How to use 54 | First you will need to install the dependencies and then launch the ``pipenv`` virtual environment. To install the dependencies, type: 55 | 56 | ```SH 57 | pipenv install 58 | ``` 59 | 60 | Right after that, launch the virtual environment with: 61 | 62 | ```SH 63 | pipenv shell 64 | ``` 65 | 66 | Then you can execute the preprocessing and training/evaluation pipelines by typing: 67 | 68 | ```SH 69 | python main.py 70 | ``` 71 | Note that the preprocessing step relies on ``nltk.word_tokenize``, so you may need to download the NLTK ``punkt`` tokenizer data beforehand (for example with ``python -c "import nltk; nltk.download('punkt')"``). 72 | If you want to modify some of the parameters, you can edit the ``dataclass`` located at ``src/parameters/parameters.py``, which has the following form: 73 | 74 | ```PY 75 | @dataclass 76 | class Parameters: 77 | 78 | # Preprocessing parameters 79 | seq_len: int = 35 80 | num_words: int = 2000 81 | 82 | # Model parameters 83 | embedding_size: int = 64 84 | out_size: int = 32 85 | stride: int = 2 86 | 87 | # Training parameters 88 | epochs: int = 10 89 | batch_size: int = 12 90 | learning_rate: float = 0.001 91 | ``` 92 | 93 | ## 4. Contributing 94 | Feel free to fork the model and add your own suggestions. 95 | 96 | 1. Fork the Project 97 | 2. Create your Feature Branch (`git checkout -b feature/YourGreatFeature`) 98 | 3. Commit your Changes (`git commit -m 'Add some YourGreatFeature'`) 99 | 4. Push to the Branch (`git push origin feature/YourGreatFeature`) 100 | 5. Open a Pull Request 101 | 102 | 103 | ## 5. Contact 104 | If you have any questions, feel free to reach out to me at: 105 | * Twitter 106 | * Medium 107 | * Linkedin 108 | * Email: fer.neutron@gmail.com 109 | 110 | 111 | ## 6. License 112 | Distributed under the MIT License. See ``LICENSE`` for more information.
113 | 114 | 115 | 116 | 117 | [medium-shield]: https://img.shields.io/badge/medium-%2312100E.svg?&style=for-the-badge&logo=medium&logoColor=white 118 | [medium-url]: https://medium.com/@fer.neutron 119 | [twitter-shield]: https://img.shields.io/badge/twitter-%231DA1F2.svg?&style=for-the-badge&logo=twitter&logoColor=white 120 | [twitter-url]: https://twitter.com/Fernando_LpzV 121 | [linkedin-shield]: https://img.shields.io/badge/linkedin-%230077B5.svg?&style=for-the-badge&logo=linkedin&logoColor=white 122 | [linkedin-url]: https://www.linkedin.com/in/fernando-lopezvelasco/ -------------------------------------------------------------------------------- /img/cnn-text-classification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FernandoLpz/Text-Classification-CNN-PyTorch/eab9d9eaa0cd986047079a24a3a91fa55c42848d/img/cnn-text-classification.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from src import Parameters 2 | from src import Preprocessing 3 | from src import TextClassifier 4 | from src import Run 5 | 6 | 7 | class Controller(Parameters): 8 | 9 | def __init__(self): 10 | # Preprocessing pipeline 11 | self.data = self.prepare_data(Parameters.num_words, Parameters.seq_len) 12 | 13 | # Initialize the model 14 | self.model = TextClassifier(Parameters) 15 | 16 | # Training - Evaluation pipeline 17 | Run().train(self.model, self.data, Parameters) 18 | 19 | 20 | @staticmethod 21 | def prepare_data(num_words, seq_len): 22 | # Preprocessing pipeline 23 | pr = Preprocessing(num_words, seq_len) 24 | pr.load_data() 25 | pr.clean_text() 26 | pr.text_tokenization() 27 | pr.build_vocabulary() 28 | pr.word_to_idx() 29 | pr.padding_sentences() 30 | pr.split_data() 31 | 32 | return {'x_train': pr.x_train, 'y_train': pr.y_train, 'x_test': pr.x_test, 'y_test': pr.y_test} 33 | 34 | if __name__ == '__main__': 35 | controller = Controller() -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .parameters import Parameters 2 | from .preprocessing import Preprocessing 3 | from .model import TextClassifier 4 | from .model import Run -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import TextClassifier 2 | from .run import Run -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class TextClassifier(nn.ModuleList): 7 | 8 | def __init__(self, params): 9 | super(TextClassifier, self).__init__() 10 | 11 | # Parameters regarding text preprocessing 12 | self.seq_len = params.seq_len 13 | self.num_words = params.num_words 14 | self.embedding_size = params.embedding_size 15 | 16 | # Dropout definition 17 | self.dropout = nn.Dropout(0.25) 18 | 19 | # CNN parameters definition 20 | # Kernel sizes 21 | self.kernel_1 = 2 22 | self.kernel_2 = 3 23 | self.kernel_3 = 4 24 | self.kernel_4 = 5 25 | 26 | # Output size for each convolution 27 | self.out_size = params.out_size 28 | # Number of strides for each 
convolution 29 | self.stride = params.stride 30 | 31 | # Embedding layer definition 32 | self.embedding = nn.Embedding(self.num_words + 1, self.embedding_size, padding_idx=0) 33 | 34 | # Convolution layers definition 35 | self.conv_1 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_1, self.stride) 36 | self.conv_2 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_2, self.stride) 37 | self.conv_3 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_3, self.stride) 38 | self.conv_4 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_4, self.stride) 39 | 40 | # Max pooling layers definition 41 | self.pool_1 = nn.MaxPool1d(self.kernel_1, self.stride) 42 | self.pool_2 = nn.MaxPool1d(self.kernel_2, self.stride) 43 | self.pool_3 = nn.MaxPool1d(self.kernel_3, self.stride) 44 | self.pool_4 = nn.MaxPool1d(self.kernel_4, self.stride) 45 | 46 | # Fully connected layer definition 47 | self.fc = nn.Linear(self.in_features_fc(), 1) 48 | 49 | 50 | def in_features_fc(self): 51 | '''Calculates the number of output features after Convolution + Max pooling 52 | 53 | Convolved_Features = ((embedding_size + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1 54 | Pooled_Features = ((convolved_features + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1 55 | With the default parameters (embedding_size=64, stride=2, kernel sizes 2-5, out_size=32) this evaluates to (16 + 15 + 14 + 13) * 32 = 1856 input features. 56 | source: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html 57 | ''' 58 | # Calculate size of convolved/pooled features for convolution_1/max_pooling_1 features 59 | out_conv_1 = ((self.embedding_size - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1 60 | out_conv_1 = math.floor(out_conv_1) 61 | out_pool_1 = ((out_conv_1 - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1 62 | out_pool_1 = math.floor(out_pool_1) 63 | 64 | # Calculate size of convolved/pooled features for convolution_2/max_pooling_2 features 65 | out_conv_2 = ((self.embedding_size - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1 66 | out_conv_2 = math.floor(out_conv_2) 67 | out_pool_2 = ((out_conv_2 - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1 68 | out_pool_2 = math.floor(out_pool_2) 69 | 70 | # Calculate size of convolved/pooled features for convolution_3/max_pooling_3 features 71 | out_conv_3 = ((self.embedding_size - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1 72 | out_conv_3 = math.floor(out_conv_3) 73 | out_pool_3 = ((out_conv_3 - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1 74 | out_pool_3 = math.floor(out_pool_3) 75 | 76 | # Calculate size of convolved/pooled features for convolution_4/max_pooling_4 features 77 | out_conv_4 = ((self.embedding_size - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1 78 | out_conv_4 = math.floor(out_conv_4) 79 | out_pool_4 = ((out_conv_4 - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1 80 | out_pool_4 = math.floor(out_pool_4) 81 | 82 | # Returns "flattened" vector (input for fully connected layer) 83 | return (out_pool_1 + out_pool_2 + out_pool_3 + out_pool_4) * self.out_size 84 | 85 | 86 | 87 | def forward(self, x): 88 | 89 | # Sequence of tokens is passed through the embedding layer 90 | x = self.embedding(x) 91 | 92 | # Convolution layer 1 is applied 93 | x1 = self.conv_1(x) 94 | x1 = torch.relu(x1) 95 | x1 = self.pool_1(x1) 96 | 97 | # Convolution layer 2 is applied 98 | x2 = self.conv_2(x) 99 | x2 = torch.relu(x2) 100 | x2 = self.pool_2(x2) 101 | 102 | # Convolution layer 3 is applied 103 | x3 = self.conv_3(x) 104 | x3 = torch.relu(x3) 105 | x3 = self.pool_3(x3) 106 | 107 | # Convolution layer 4 is applied 108 | x4 = self.conv_4(x) 109 | x4 = torch.relu(x4) 110 | x4 = self.pool_4(x4) 111 | 112 | # The
output of each convolutional layer is concatenated into a single vector 113 | union = torch.cat((x1, x2, x3, x4), 2) 114 | union = union.reshape(union.size(0), -1) 115 | 116 | # The "flattened" vector is passed through a fully connected layer 117 | out = self.fc(union) 118 | # Dropout is applied 119 | out = self.dropout(out) 120 | # Activation function is applied 121 | out = torch.sigmoid(out) 122 | 123 | return out.squeeze() 124 | -------------------------------------------------------------------------------- /src/model/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | class DatasetMapper(Dataset): 8 | 9 | def __init__(self, x, y): 10 | self.x = x 11 | self.y = y 12 | 13 | def __len__(self): 14 | return len(self.x) 15 | 16 | def __getitem__(self, idx): 17 | return self.x[idx], self.y[idx] 18 | 19 | class Run: 20 | '''Training, evaluation and metrics calculation''' 21 | 22 | @staticmethod 23 | def train(model, data, params): 24 | 25 | # Initialize dataset mappers 26 | train = DatasetMapper(data['x_train'], data['y_train']) 27 | test = DatasetMapper(data['x_test'], data['y_test']) 28 | 29 | # Initialize loaders 30 | loader_train = DataLoader(train, batch_size=params.batch_size) 31 | loader_test = DataLoader(test, batch_size=params.batch_size) 32 | 33 | # Define optimizer 34 | optimizer = optim.RMSprop(model.parameters(), lr=params.learning_rate) 35 | 36 | # Starts training phase 37 | for epoch in range(params.epochs): 38 | # Set model in training mode 39 | model.train() 40 | predictions = [] 41 | # Starts batch training 42 | for x_batch, y_batch in loader_train: 43 | 44 | y_batch = y_batch.type(torch.FloatTensor) 45 | 46 | # Feed the model 47 | y_pred = model(x_batch) 48 | 49 | # Loss calculation 50 | loss = F.binary_cross_entropy(y_pred, y_batch) 51 | 52 | # Clean gradients 53 | optimizer.zero_grad() 54 | 55 | # Gradients calculation 56 | loss.backward() 57 | 58 | # Gradients update 59 | optimizer.step() 60 | 61 | # Save predictions 62 | predictions += list(y_pred.detach().numpy()) 63 | 64 | # Evaluation phase 65 | test_predictions = Run.evaluation(model, loader_test) 66 | 67 | # Metrics calculation 68 | train_accuracy = Run.calculate_accuracy(data['y_train'], predictions) 69 | test_accuracy = Run.calculate_accuracy(data['y_test'], test_predictions) 70 | print("Epoch: %d, loss: %.5f, Train accuracy: %.5f, Test accuracy: %.5f" % (epoch+1, loss.item(), train_accuracy, test_accuracy)) 71 | 72 | @staticmethod 73 | def evaluation(model, loader_test): 74 | 75 | # Set the model in evaluation mode 76 | model.eval() 77 | predictions = [] 78 | 79 | # Starts evaluation phase 80 | with torch.no_grad(): 81 | for x_batch, y_batch in loader_test: 82 | y_pred = model(x_batch) 83 | predictions += list(y_pred.detach().numpy()) 84 | return predictions 85 | 86 | @staticmethod 87 | def calculate_accuracy(ground_truth, predictions): 88 | # Metrics calculation 89 | true_positives = 0 90 | true_negatives = 0 91 | for true, pred in zip(ground_truth, predictions): 92 | if (pred >= 0.5) and (true == 1): 93 | true_positives += 1 94 | elif (pred < 0.5) and (true == 0): 95 | true_negatives += 1 96 | else: 97 | pass 98 | # Return accuracy 99 | return (true_positives+true_negatives) / len(ground_truth) 100 | -------------------------------------------------------------------------------- /src/parameters/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .parameters import Parameters -------------------------------------------------------------------------------- /src/parameters/parameters.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class Parameters: 5 | # Preprocessing parameters 6 | seq_len: int = 35 7 | num_words: int = 2000 8 | 9 | # Model parameters 10 | embedding_size: int = 64 11 | out_size: int = 32 12 | stride: int = 2 13 | 14 | # Training parameters 15 | epochs: int = 10 16 | batch_size: int = 12 17 | learning_rate: float = 0.001 -------------------------------------------------------------------------------- /src/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import Preprocessing -------------------------------------------------------------------------------- /src/preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import nltk 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | from nltk.tokenize import word_tokenize 7 | 8 | class Preprocessing: 9 | 10 | def __init__(self, num_words, seq_len): 11 | self.data = 'data/tweets.csv' 12 | self.num_words = num_words 13 | self.seq_len = seq_len 14 | self.vocabulary = None 15 | self.x_tokenized = None 16 | self.x_padded = None 17 | self.x_raw = None 18 | self.y = None 19 | 20 | self.x_train = None 21 | self.x_test = None 22 | self.y_train = None 23 | self.y_test = None 24 | 25 | def load_data(self): 26 | # Reads the raw csv file and splits it into 27 | # sentences (x) and target (y) 28 | 29 | df = pd.read_csv(self.data) 30 | df.drop(['id','keyword','location'], axis=1, inplace=True) 31 | 32 | self.x_raw = df['text'].values 33 | self.y = df['target'].values 34 | 35 | def clean_text(self): 36 | # Removes special symbols and keeps only 37 | # lowercase alphabetic words 38 | 39 | self.x_raw = [x.lower() for x in self.x_raw] 40 | self.x_raw = [re.sub(r'[^A-Za-z]+', ' ', x) for x in self.x_raw] 41 | 42 | def text_tokenization(self): 43 | # Tokenizes each sentence with the nltk tokenizer 44 | self.x_raw = [word_tokenize(x) for x in self.x_raw] 45 | 46 | def build_vocabulary(self): 47 | # Builds the vocabulary and keeps the num_words most frequent words 48 | self.vocabulary = dict() 49 | fdist = nltk.FreqDist() 50 | 51 | for sentence in self.x_raw: 52 | for word in sentence: 53 | fdist[word] += 1 54 | 55 | common_words = fdist.most_common(self.num_words) 56 | 57 | for idx, word in enumerate(common_words): 58 | self.vocabulary[word[0]] = (idx+1) 59 | 60 | def word_to_idx(self): 61 | # Using the vocabulary dictionary, each token is 62 | # transformed into its index-based representation 63 | 64 | self.x_tokenized = list() 65 | 66 | for sentence in self.x_raw: 67 | temp_sentence = list() 68 | for word in sentence: 69 | if word in self.vocabulary.keys(): 70 | temp_sentence.append(self.vocabulary[word]) 71 | self.x_tokenized.append(temp_sentence) 72 | 73 | def padding_sentences(self): 74 | # Each sentence that does not reach the required length 75 | # is padded with the index 0; longer sentences are truncated to seq_len 76 | 77 | pad_idx = 0 78 | self.x_padded = list() 79 | 80 | for sentence in self.x_tokenized: 81 | while len(sentence) < self.seq_len: 82 |
sentence.insert(len(sentence), pad_idx) 83 | self.x_padded.append(sentence[:self.seq_len]) 84 | 85 | self.x_padded = np.array(self.x_padded) 86 | 87 | def split_data(self): 88 | self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(self.x_padded, self.y, test_size=0.25, random_state=42) --------------------------------------------------------------------------------
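The repository's pipeline trains and evaluates in one go (``python main.py``) and does not ship an inference helper. As a purely illustrative sketch (not part of the repository, and assuming you keep references to the fitted ``Preprocessing`` object and the trained ``TextClassifier``, for example by adapting ``Controller.prepare_data`` to also return the ``Preprocessing`` instance), a single new sentence could be scored like this:

```PY
# Hypothetical helper, NOT part of the repository: scores a new sentence with a
# trained TextClassifier, reusing the vocabulary built by Preprocessing.
# Assumes `pr` is a fitted Preprocessing instance and `model` a trained model.
import torch
from nltk.tokenize import word_tokenize

def predict_sentence(model, pr, sentence):
    # Simplified cleaning: lowercase only (the repo's pipeline also strips non-letter characters)
    tokens = word_tokenize(sentence.lower())
    # Map known tokens to their vocabulary indices, then pad/truncate to seq_len
    indices = [pr.vocabulary[t] for t in tokens if t in pr.vocabulary]
    indices = (indices + [0] * pr.seq_len)[:pr.seq_len]
    batch = torch.tensor([indices], dtype=torch.long)  # shape: (1, seq_len)

    model.eval()
    with torch.no_grad():
        return model(batch).item()  # sigmoid score in [0, 1]
```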