├── LICENSE
├── Pipfile
├── Pipfile.lock
├── README.md
├── data
│   └── tweets.csv
├── img
│   └── cnn-text-classification.jpg
├── main.py
└── src
    ├── __init__.py
    ├── model
    │   ├── __init__.py
    │   ├── model.py
    │   └── run.py
    ├── parameters
    │   ├── __init__.py
    │   └── parameters.py
    └── preprocessing
        ├── __init__.py
        └── preprocessing.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Fernando
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [packages]
9 | numpy = "*"
10 | torch = "*"
11 | sklearn = "*"
12 | pandas = "*"
13 | nltk = "*"
14 |
15 | [requires]
16 | python_version = "3.7"
17 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "fb1c598019717c48cd36fc80879942cf73024f9358e44c0b6a328be357811549"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.7"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {
19 | "click": {
20 | "hashes": [
21 | "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a",
22 | "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc"
23 | ],
24 | "version": "==7.1.2"
25 | },
26 | "future": {
27 | "hashes": [
28 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"
29 | ],
30 | "version": "==0.18.2"
31 | },
32 | "joblib": {
33 | "hashes": [
34 | "sha256:8f52bf24c64b608bf0b2563e0e47d6fcf516abc8cfafe10cfd98ad66d94f92d6",
35 | "sha256:d348c5d4ae31496b2aa060d6d9b787864dd204f9480baaa52d18850cb43e9f49"
36 | ],
37 | "version": "==0.16.0"
38 | },
39 | "nltk": {
40 | "hashes": [
41 | "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35"
42 | ],
43 | "index": "pypi",
44 | "version": "==3.5"
45 | },
46 | "numpy": {
47 | "hashes": [
48 | "sha256:082f8d4dd69b6b688f64f509b91d482362124986d98dc7dc5f5e9f9b9c3bb983",
49 | "sha256:1bc0145999e8cb8aed9d4e65dd8b139adf1919e521177f198529687dbf613065",
50 | "sha256:309cbcfaa103fc9a33ec16d2d62569d541b79f828c382556ff072442226d1968",
51 | "sha256:3673c8b2b29077f1b7b3a848794f8e11f401ba0b71c49fbd26fb40b71788b132",
52 | "sha256:480fdd4dbda4dd6b638d3863da3be82873bba6d32d1fc12ea1b8486ac7b8d129",
53 | "sha256:56ef7f56470c24bb67fb43dae442e946a6ce172f97c69f8d067ff8550cf782ff",
54 | "sha256:5a936fd51049541d86ccdeef2833cc89a18e4d3808fe58a8abeb802665c5af93",
55 | "sha256:5b6885c12784a27e957294b60f97e8b5b4174c7504665333c5e94fbf41ae5d6a",
56 | "sha256:667c07063940e934287993366ad5f56766bc009017b4a0fe91dbd07960d0aba7",
57 | "sha256:7ed448ff4eaffeb01094959b19cbaf998ecdee9ef9932381420d514e446601cd",
58 | "sha256:8343bf67c72e09cfabfab55ad4a43ce3f6bf6e6ced7acf70f45ded9ebb425055",
59 | "sha256:92feb989b47f83ebef246adabc7ff3b9a59ac30601c3f6819f8913458610bdcc",
60 | "sha256:935c27ae2760c21cd7354402546f6be21d3d0c806fffe967f745d5f2de5005a7",
61 | "sha256:aaf42a04b472d12515debc621c31cf16c215e332242e7a9f56403d814c744624",
62 | "sha256:b12e639378c741add21fbffd16ba5ad25c0a1a17cf2b6fe4288feeb65144f35b",
63 | "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69",
64 | "sha256:b8456987b637232602ceb4d663cb34106f7eb780e247d51a260b84760fd8f491",
65 | "sha256:b9792b0ac0130b277536ab8944e7b754c69560dac0415dd4b2dbd16b902c8954",
66 | "sha256:c9591886fc9cbe5532d5df85cb8e0cc3b44ba8ce4367bd4cf1b93dc19713da72",
67 | "sha256:cf1347450c0b7644ea142712619533553f02ef23f92f781312f6a3553d031fc7",
68 | "sha256:de8b4a9b56255797cbddb93281ed92acbc510fb7b15df3f01bd28f46ebc4edae",
69 | "sha256:e1b1dc0372f530f26a03578ac75d5e51b3868b9b76cd2facba4c9ee0eb252ab1",
70 | "sha256:e45f8e981a0ab47103181773cc0a54e650b2aef8c7b6cd07405d0fa8d869444a",
71 | "sha256:e4f6d3c53911a9d103d8ec9518190e52a8b945bab021745af4939cfc7c0d4a9e",
72 | "sha256:ed8a311493cf5480a2ebc597d1e177231984c818a86875126cfd004241a73c3e",
73 | "sha256:ef71a1d4fd4858596ae80ad1ec76404ad29701f8ca7cdcebc50300178db14dfc"
74 | ],
75 | "index": "pypi",
76 | "version": "==1.19.1"
77 | },
78 | "pandas": {
79 | "hashes": [
80 | "sha256:01b1e536eb960822c5e6b58357cad8c4b492a336f4a5630bf0b598566462a578",
81 | "sha256:0246c67cbaaaac8d25fed8d4cf2d8897bd858f0e540e8528a75281cee9ac516d",
82 | "sha256:0366150fe8ee37ef89a45d3093e05026b5f895e42bbce3902ce3b6427f1b8471",
83 | "sha256:16ae070c47474008769fc443ac765ffd88c3506b4a82966e7a605592978896f9",
84 | "sha256:1acc2bd7fc95e5408a4456897c2c2a1ae7c6acefe108d90479ab6d98d34fcc3d",
85 | "sha256:391db82ebeb886143b96b9c6c6166686c9a272d00020e4e39ad63b792542d9e2",
86 | "sha256:41675323d4fcdd15abde068607cad150dfe17f7d32290ee128e5fea98442bd09",
87 | "sha256:53328284a7bb046e2e885fd1b8c078bd896d7fc4575b915d4936f54984a2ba67",
88 | "sha256:57c5f6be49259cde8e6f71c2bf240a26b071569cabc04c751358495d09419e56",
89 | "sha256:84c101d0f7bbf0d9f1be9a2f29f6fcc12415442558d067164e50a56edfb732b4",
90 | "sha256:88930c74f69e97b17703600233c0eaf1f4f4dd10c14633d522724c5c1b963ec4",
91 | "sha256:8c9ec12c480c4d915e23ee9c8a2d8eba8509986f35f307771045c1294a2e5b73",
92 | "sha256:a81c4bf9c59010aa3efddbb6b9fc84a9b76dc0b4da2c2c2d50f06a9ef6ac0004",
93 | "sha256:d9644ac996149b2a51325d48d77e25c911e01aa6d39dc1b64be679cd71f683ec",
94 | "sha256:e4b6c98f45695799990da328e6fd7d6187be32752ed64c2f22326ad66762d179",
95 | "sha256:fe6f1623376b616e03d51f0dd95afd862cf9a33c18cf55ce0ed4bbe1c4444391"
96 | ],
97 | "index": "pypi",
98 | "version": "==1.1.1"
99 | },
100 | "python-dateutil": {
101 | "hashes": [
102 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c",
103 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"
104 | ],
105 | "version": "==2.8.1"
106 | },
107 | "pytz": {
108 | "hashes": [
109 | "sha256:a494d53b6d39c3c6e44c3bec237336e14305e4f29bbf800b599253057fbb79ed",
110 | "sha256:c35965d010ce31b23eeb663ed3cc8c906275d6be1a34393a1d73a41febf4a048"
111 | ],
112 | "version": "==2020.1"
113 | },
114 | "regex": {
115 | "hashes": [
116 | "sha256:0dc64ee3f33cd7899f79a8d788abfbec168410be356ed9bd30bbd3f0a23a7204",
117 | "sha256:1269fef3167bb52631ad4fa7dd27bf635d5a0790b8e6222065d42e91bede4162",
118 | "sha256:14a53646369157baa0499513f96091eb70382eb50b2c82393d17d7ec81b7b85f",
119 | "sha256:3a3af27a8d23143c49a3420efe5b3f8cf1a48c6fc8bc6856b03f638abc1833bb",
120 | "sha256:46bac5ca10fb748d6c55843a931855e2727a7a22584f302dd9bb1506e69f83f6",
121 | "sha256:4c037fd14c5f4e308b8370b447b469ca10e69427966527edcab07f52d88388f7",
122 | "sha256:51178c738d559a2d1071ce0b0f56e57eb315bcf8f7d4cf127674b533e3101f88",
123 | "sha256:5ea81ea3dbd6767873c611687141ec7b06ed8bab43f68fad5b7be184a920dc99",
124 | "sha256:6961548bba529cac7c07af2fd4d527c5b91bb8fe18995fed6044ac22b3d14644",
125 | "sha256:75aaa27aa521a182824d89e5ab0a1d16ca207318a6b65042b046053cfc8ed07a",
126 | "sha256:7a2dd66d2d4df34fa82c9dc85657c5e019b87932019947faece7983f2089a840",
127 | "sha256:8a51f2c6d1f884e98846a0a9021ff6861bdb98457879f412fdc2b42d14494067",
128 | "sha256:9c568495e35599625f7b999774e29e8d6b01a6fb684d77dee1f56d41b11b40cd",
129 | "sha256:9eddaafb3c48e0900690c1727fba226c4804b8e6127ea409689c3bb492d06de4",
130 | "sha256:bbb332d45b32df41200380fff14712cb6093b61bd142272a10b16778c418e98e",
131 | "sha256:bc3d98f621898b4a9bc7fecc00513eec8f40b5b83913d74ccb445f037d58cd89",
132 | "sha256:c11d6033115dc4887c456565303f540c44197f4fc1a2bfb192224a301534888e",
133 | "sha256:c50a724d136ec10d920661f1442e4a8b010a4fe5aebd65e0c2241ea41dbe93dc",
134 | "sha256:d0a5095d52b90ff38592bbdc2644f17c6d495762edf47d876049cfd2968fbccf",
135 | "sha256:d6cff2276e502b86a25fd10c2a96973fdb45c7a977dca2138d661417f3728341",
136 | "sha256:e46d13f38cfcbb79bfdb2964b0fe12561fe633caf964a77a5f8d4e45fe5d2ef7"
137 | ],
138 | "version": "==2020.7.14"
139 | },
140 | "scikit-learn": {
141 | "hashes": [
142 | "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca",
143 | "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc",
144 | "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea",
145 | "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480",
146 | "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d",
147 | "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2",
148 | "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d",
149 | "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab",
150 | "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b",
151 | "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de",
152 | "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25",
153 | "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028",
154 | "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce",
155 | "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec",
156 | "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10",
157 | "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8"
158 | ],
159 | "version": "==0.23.2"
160 | },
161 | "scipy": {
162 | "hashes": [
163 | "sha256:066c513d90eb3fd7567a9e150828d39111ebd88d3e924cdfc9f8ce19ab6f90c9",
164 | "sha256:07e52b316b40a4f001667d1ad4eb5f2318738de34597bd91537851365b6c61f1",
165 | "sha256:0a0e9a4e58a4734c2eba917f834b25b7e3b6dc333901ce7784fd31aefbd37b2f",
166 | "sha256:1c7564a4810c1cd77fcdee7fa726d7d39d4e2695ad252d7c86c3ea9d85b7fb8f",
167 | "sha256:315aa2165aca31375f4e26c230188db192ed901761390be908c9b21d8b07df62",
168 | "sha256:6e86c873fe1335d88b7a4bfa09d021f27a9e753758fd75f3f92d714aa4093768",
169 | "sha256:8e28e74b97fc8d6aa0454989db3b5d36fc27e69cef39a7ee5eaf8174ca1123cb",
170 | "sha256:92eb04041d371fea828858e4fff182453c25ae3eaa8782d9b6c32b25857d23bc",
171 | "sha256:a0afbb967fd2c98efad5f4c24439a640d39463282040a88e8e928db647d8ac3d",
172 | "sha256:a785409c0fa51764766840185a34f96a0a93527a0ff0230484d33a8ed085c8f8",
173 | "sha256:cca9fce15109a36a0a9f9cfc64f870f1c140cb235ddf27fe0328e6afb44dfed0",
174 | "sha256:d56b10d8ed72ec1be76bf10508446df60954f08a41c2d40778bc29a3a9ad9bce",
175 | "sha256:dac09281a0eacd59974e24525a3bc90fa39b4e95177e638a31b14db60d3fa806",
176 | "sha256:ec5fe57e46828d034775b00cd625c4a7b5c7d2e354c3b258d820c6c72212a6ec",
177 | "sha256:eecf40fa87eeda53e8e11d265ff2254729d04000cd40bae648e76ff268885d66",
178 | "sha256:fc98f3eac993b9bfdd392e675dfe19850cc8c7246a8fd2b42443e506344be7d9"
179 | ],
180 | "version": "==1.5.2"
181 | },
182 | "six": {
183 | "hashes": [
184 | "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",
185 | "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"
186 | ],
187 | "version": "==1.15.0"
188 | },
189 | "sklearn": {
190 | "hashes": [
191 | "sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31"
192 | ],
193 | "index": "pypi",
194 | "version": "==0.0"
195 | },
196 | "threadpoolctl": {
197 | "hashes": [
198 | "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725",
199 | "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b"
200 | ],
201 | "version": "==2.1.0"
202 | },
203 | "torch": {
204 | "hashes": [
205 | "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5",
206 | "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b",
207 | "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c",
208 | "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8",
209 | "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76",
210 | "sha256:8d66053ed53574b15b4ecddd4f9ca0f3868fe792088be3e728bac36fa998f7ce"
211 | ],
212 | "index": "pypi",
213 | "version": "==1.6.0"
214 | },
215 | "tqdm": {
216 | "hashes": [
217 | "sha256:1a336d2b829be50e46b84668691e0a2719f26c97c62846298dd5ae2937e4d5cf",
218 | "sha256:564d632ea2b9cb52979f7956e093e831c28d441c11751682f84c86fc46e4fd21"
219 | ],
220 | "version": "==4.48.2"
221 | }
222 | },
223 | "develop": {}
224 | }
225 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
9 | [![Medium][medium-shield]][medium-url]
10 | [![Twitter][twitter-shield]][twitter-url]
11 | [![Linkedin][linkedin-shield]][linkedin-url]
12 |
13 | # Text Classification with CNNs in PyTorch
14 | The aim of this repository is to show a baseline model for text classification using convolutional neural networks (CNNs) in PyTorch. The architecture implemented here was inspired by the one proposed in the paper *Convolutional Neural Networks for Sentence Classification*.
15 |
16 | If you want to understand the details of how this model was built, take a look at the accompanying article: Text Classification with CNNs in PyTorch.
17 |
18 |
19 | ## Table of Contents
20 |
21 | * [The model](#the-model)
22 | * [Files](#files)
23 | * [How to use](#how-to-use)
24 | * [Contributing](#contributing)
25 | * [Contact](#contact)
26 | * [License](#license)
27 |
28 |
29 | ## 1. The model
30 | The model is composed of 4 convolutional layers that generate 32 filters each. The output of each convolution is passed through a ``max pooling`` layer, and the pooled outputs are then concatenated into a single vector. Finally, the concatenation is passed through a fully connected layer (a worked calculation of its input size is sketched below). The following image describes the model architecture:
31 |
32 | ![CNN text classification architecture](img/cnn-text-classification.jpg)
33 |
34 |
35 |
36 |
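With the default parameters (``embedding_size=64``, ``out_size=32``, ``stride=2``) and the kernel sizes hard-coded in ``src/model/model.py`` (2, 3, 4 and 5), the input of the fully connected layer works out to 1856 features. The snippet below is a minimal sketch that reproduces the arithmetic of ``in_features_fc``:

```PY
import math

# Default values from src/parameters/parameters.py plus the kernel sizes
# hard-coded in src/model/model.py
embedding_size, out_size, stride = 64, 32, 2
kernels = (2, 3, 4, 5)

pooled_sizes = []
for k in kernels:
    conv = math.floor((embedding_size - (k - 1) - 1) / stride) + 1  # after Conv1d
    pool = math.floor((conv - (k - 1) - 1) / stride) + 1            # after MaxPool1d
    pooled_sizes.append(pool)

# (16 + 15 + 14 + 13) pooled positions, each with 32 output channels -> 1856
print(sum(pooled_sizes) * out_size)
```
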
37 | ## 2. Files
38 | * **Pipfile**: Here you will find the dependencies needed to run the model.
39 |
40 | * **main.py**: It contains the controller that orchestrates the pipelines (preprocessing and training).
41 |
42 | * **src**: It contains three directories, which are: ``model``, ``parameters`` and ``preprocessing``.
43 |
44 | * **src/model**: It contains two files, ``model.py`` and ``run.py``, which handle the model definition and the training/evaluation phase, respectively.
45 |
46 | * **src/parameters**: It contains a ``dataclass`` which stores the parameters used to preprocess the text, define and train the model.
47 |
48 | * **src/preprocessing**: It contains the functions implemented to load, clean and tokenize the text.
49 |
50 | * **data**: It contains the dataset (``tweets.csv``) used to train the model; an illustrative example of the expected format is shown below.
51 |
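The preprocessing step (``src/preprocessing/preprocessing.py``) drops the ``id``, ``keyword`` and ``location`` columns and keeps only ``text`` and ``target``. The sketch below shows the column layout the pipeline expects; the rows are made up for illustration:

```PY
import pandas as pd

# Illustrative only: the column names are the ones handled by the
# preprocessing step; the rows are invented examples.
sample = pd.DataFrame({
    'id':       [1, 2],
    'keyword':  ['fire', None],
    'location': [None, None],
    'text':     ['A disaster-related tweet goes here', 'A regular tweet goes here'],
    'target':   [1, 0],
})
print(sample)
```
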
52 |
53 | ## 3. How to use
54 | First you will need to install the dependencies and then launch the ``pipenv`` virtual environment. To install the dependencies, type:
55 |
56 | ```SH
57 | pipenv install
58 | ```
59 |
60 | Right after that, launch the virtual environment:
61 |
62 | ```SH
63 | pipenv shell
64 | ```
65 |
66 | Then you can execute the preprocessing and training/evaluation pipelines by typing:
67 |
68 | ```SH
69 | python main.py
70 | ```
71 |
72 | If you want to modify some of the parameters, edit the ``dataclass`` located at ``src/parameters/parameters.py``, which has the following form:
73 |
74 | ```PY
75 | @dataclass
76 | class Parameters:
77 |
78 |     # Preprocessing parameters
79 |     seq_len: int = 35
80 |     num_words: int = 2000
81 |
82 |     # Model parameters
83 |     embedding_size: int = 64
84 |     out_size: int = 32
85 |     stride: int = 2
86 |
87 |     # Training parameters
88 |     epochs: int = 10
89 |     batch_size: int = 12
90 |     learning_rate: float = 0.001
91 | ```
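
These values are plain class attributes, so the controller in ``main.py`` reads them directly. The trimmed sketch below shows how they flow into each stage; note that ``seq_len`` is also used as the number of input channels of the convolutions, so every padded sentence must contain exactly ``seq_len`` token indices:

```PY
from src import Parameters, Preprocessing, TextClassifier

# Preprocessing consumes num_words and seq_len ...
preprocessing = Preprocessing(Parameters.num_words, Parameters.seq_len)

# ... while the model (and, later, Run().train) reads the remaining fields.
model = TextClassifier(Parameters)
```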
92 |
93 | ## 4. Contributing
94 | Feel free to fork the repository and add your own suggestions.
95 |
96 | 1. Fork the Project
97 | 2. Create your Feature Branch (`git checkout -b feature/YourGreatFeature`)
98 | 3. Commit your Changes (`git commit -m 'Add some YourGreatFeature'`)
99 | 4. Push to the Branch (`git push origin feature/YourGreatFeature`)
100 | 5. Open a Pull Request
101 |
102 |
103 | ## 5. Contact
104 | If you have any questions, feel free to reach out to me at:
105 | * [Twitter](https://twitter.com/Fernando_LpzV)
106 | * [Medium](https://medium.com/@fer.neutron)
107 | * [LinkedIn](https://www.linkedin.com/in/fernando-lopezvelasco/)
108 | * Email: fer.neutron@gmail.com
109 |
110 |
111 | ## 6. License
112 | Distributed under the MIT License. See ``LICENSE`` for more information.
113 |
114 |
115 |
116 |
117 | [medium-shield]: https://img.shields.io/badge/medium-%2312100E.svg?&style=for-the-badge&logo=medium&logoColor=white
118 | [medium-url]: https://medium.com/@fer.neutron
119 | [twitter-shield]: https://img.shields.io/badge/twitter-%231DA1F2.svg?&style=for-the-badge&logo=twitter&logoColor=white
120 | [twitter-url]: https://twitter.com/Fernando_LpzV
121 | [linkedin-shield]: https://img.shields.io/badge/linkedin-%230077B5.svg?&style=for-the-badge&logo=linkedin&logoColor=white
122 | [linkedin-url]: https://www.linkedin.com/in/fernando-lopezvelasco/
--------------------------------------------------------------------------------
/img/cnn-text-classification.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FernandoLpz/Text-Classification-CNN-PyTorch/eab9d9eaa0cd986047079a24a3a91fa55c42848d/img/cnn-text-classification.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from src import Parameters
2 | from src import Preprocessing
3 | from src import TextClassifier
4 | from src import Run
5 |
6 |
7 | class Controller(Parameters):
8 |
9 | def __init__(self):
10 | # Preprocessing pipeline
11 | self.data = self.prepare_data(Parameters.num_words, Parameters.seq_len)
12 |
13 | # Initialize the model
14 | self.model = TextClassifier(Parameters)
15 |
16 | # Training - Evaluation pipeline
17 | Run().train(self.model, self.data, Parameters)
18 |
19 |
20 | @staticmethod
21 | def prepare_data(num_words, seq_len):
22 | # Preprocessing pipeline
23 | pr = Preprocessing(num_words, seq_len)
24 | pr.load_data()
25 | pr.clean_text()
26 | pr.text_tokenization()
27 | pr.build_vocabulary()
28 | pr.word_to_idx()
29 | pr.padding_sentences()
30 | pr.split_data()
31 |
32 | return {'x_train': pr.x_train, 'y_train': pr.y_train, 'x_test': pr.x_test, 'y_test': pr.y_test}
33 |
34 | if __name__ == '__main__':
35 | controller = Controller()
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .parameters import Parameters
2 | from .preprocessing import Preprocessing
3 | from .model import TextClassifier
4 | from .model import Run
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import TextClassifier
2 | from .run import Run
--------------------------------------------------------------------------------
/src/model/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class TextClassifier(nn.ModuleList):
7 |
8 | def __init__(self, params):
9 | super(TextClassifier, self).__init__()
10 |
11 | # Parameters regarding text preprocessing
12 | self.seq_len = params.seq_len
13 | self.num_words = params.num_words
14 | self.embedding_size = params.embedding_size
15 |
16 | # Dropout definition
17 | self.dropout = nn.Dropout(0.25)
18 |
19 | # CNN parameters definition
20 | # Kernel sizes
21 | self.kernel_1 = 2
22 | self.kernel_2 = 3
23 | self.kernel_3 = 4
24 | self.kernel_4 = 5
25 |
26 | # Output size for each convolution
27 | self.out_size = params.out_size
28 | # Number of strides for each convolution
29 | self.stride = params.stride
30 |
31 | # Embedding layer definition
32 | self.embedding = nn.Embedding(self.num_words + 1, self.embedding_size, padding_idx=0)
33 |
34 | # Convolution layers definition
35 | self.conv_1 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_1, self.stride)
36 | self.conv_2 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_2, self.stride)
37 | self.conv_3 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_3, self.stride)
38 | self.conv_4 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_4, self.stride)
39 |
40 | # Max pooling layers definition
41 | self.pool_1 = nn.MaxPool1d(self.kernel_1, self.stride)
42 | self.pool_2 = nn.MaxPool1d(self.kernel_2, self.stride)
43 | self.pool_3 = nn.MaxPool1d(self.kernel_3, self.stride)
44 | self.pool_4 = nn.MaxPool1d(self.kernel_4, self.stride)
45 |
46 | # Fully connected layer definition
47 | self.fc = nn.Linear(self.in_features_fc(), 1)
48 |
49 |
50 | def in_features_fc(self):
51 | '''Calculates the number of output features after Convolution + Max pooling
52 |
53 | Convolved_Features = ((embedding_size + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1
54 |         Pooled_Features = ((convolved_features + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1
55 |
56 | source: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
57 | '''
58 |         # Calculate the size of the convolved/pooled features for convolution_1/max_pooling_1
59 | out_conv_1 = ((self.embedding_size - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1
60 | out_conv_1 = math.floor(out_conv_1)
61 | out_pool_1 = ((out_conv_1 - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1
62 | out_pool_1 = math.floor(out_pool_1)
63 |
64 |         # Calculate the size of the convolved/pooled features for convolution_2/max_pooling_2
65 | out_conv_2 = ((self.embedding_size - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1
66 | out_conv_2 = math.floor(out_conv_2)
67 | out_pool_2 = ((out_conv_2 - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1
68 | out_pool_2 = math.floor(out_pool_2)
69 |
70 |         # Calculate the size of the convolved/pooled features for convolution_3/max_pooling_3
71 | out_conv_3 = ((self.embedding_size - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1
72 | out_conv_3 = math.floor(out_conv_3)
73 | out_pool_3 = ((out_conv_3 - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1
74 | out_pool_3 = math.floor(out_pool_3)
75 |
76 |         # Calculate the size of the convolved/pooled features for convolution_4/max_pooling_4
77 | out_conv_4 = ((self.embedding_size - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1
78 | out_conv_4 = math.floor(out_conv_4)
79 | out_pool_4 = ((out_conv_4 - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1
80 | out_pool_4 = math.floor(out_pool_4)
81 |
82 | # Returns "flattened" vector (input for fully connected layer)
83 | return (out_pool_1 + out_pool_2 + out_pool_3 + out_pool_4) * self.out_size
84 |
85 |
86 |
87 | def forward(self, x):
88 |
89 |         # The sequence of tokens is passed through the embedding layer
90 | x = self.embedding(x)
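        # At this point x has shape (batch_size, seq_len, embedding_size);
        # the convolutions below treat seq_len as the channel dimension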
91 |
92 | # Convolution layer 1 is applied
93 | x1 = self.conv_1(x)
94 | x1 = torch.relu(x1)
95 | x1 = self.pool_1(x1)
96 |
97 | # Convolution layer 2 is applied
98 | x2 = self.conv_2(x)
99 |         x2 = torch.relu(x2)
100 | x2 = self.pool_2(x2)
101 |
102 | # Convolution layer 3 is applied
103 | x3 = self.conv_3(x)
104 | x3 = torch.relu(x3)
105 | x3 = self.pool_3(x3)
106 |
107 | # Convolution layer 4 is applied
108 | x4 = self.conv_4(x)
109 | x4 = torch.relu(x4)
110 | x4 = self.pool_4(x4)
111 |
112 | # The output of each convolutional layer is concatenated into a unique vector
113 | union = torch.cat((x1, x2, x3, x4), 2)
114 | union = union.reshape(union.size(0), -1)
115 |
116 | # The "flattened" vector is passed through a fully connected layer
117 | out = self.fc(union)
118 | # Dropout is applied
119 | out = self.dropout(out)
120 | # Activation function is applied
121 | out = torch.sigmoid(out)
122 |
123 | return out.squeeze()
124 |
--------------------------------------------------------------------------------
/src/model/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.nn.functional as F
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 | class DatasetMapper(Dataset):
8 |
9 | def __init__(self, x, y):
10 | self.x = x
11 | self.y = y
12 |
13 | def __len__(self):
14 | return len(self.x)
15 |
16 | def __getitem__(self, idx):
17 | return self.x[idx], self.y[idx]
18 |
19 | class Run:
20 | '''Training, evaluation and metrics calculation'''
21 |
22 | @staticmethod
23 | def train(model, data, params):
24 |
25 |         # Initialize dataset mappers
26 |         train = DatasetMapper(data['x_train'], data['y_train'])
27 |         test = DatasetMapper(data['x_test'], data['y_test'])
28 |
29 | # Initialize loaders
30 | loader_train = DataLoader(train, batch_size=params.batch_size)
31 | loader_test = DataLoader(test, batch_size=params.batch_size)
32 |
33 | # Define optimizer
34 | optimizer = optim.RMSprop(model.parameters(), lr=params.learning_rate)
35 |
36 | # Starts training phase
37 | for epoch in range(params.epochs):
38 |             # Set model in training mode
39 | model.train()
40 | predictions = []
41 | # Starts batch training
42 | for x_batch, y_batch in loader_train:
43 |
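                # binary_cross_entropy expects float targets, so the labels are cast here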
44 | y_batch = y_batch.type(torch.FloatTensor)
45 |
46 | # Feed the model
47 | y_pred = model(x_batch)
48 |
49 | # Loss calculation
50 | loss = F.binary_cross_entropy(y_pred, y_batch)
51 |
52 |                 # Clear gradients
53 | optimizer.zero_grad()
54 |
55 | # Gradients calculation
56 | loss.backward()
57 |
58 |                 # Update parameters
59 | optimizer.step()
60 |
61 | # Save predictions
62 | predictions += list(y_pred.detach().numpy())
63 |
64 | # Evaluation phase
65 | test_predictions = Run.evaluation(model, loader_test)
66 |
67 | # Metrics calculation
68 |             train_accuracy = Run.calculate_accuracy(data['y_train'], predictions)
69 |             test_accuracy = Run.calculate_accuracy(data['y_test'], test_predictions)
70 |             print("Epoch: %d, loss: %.5f, Train accuracy: %.5f, Test accuracy: %.5f" % (epoch+1, loss.item(), train_accuracy, test_accuracy))
71 |
72 | @staticmethod
73 | def evaluation(model, loader_test):
74 |
75 | # Set the model in evaluation mode
76 | model.eval()
77 | predictions = []
78 |
79 |         # Start evaluation phase
80 | with torch.no_grad():
81 | for x_batch, y_batch in loader_test:
82 | y_pred = model(x_batch)
83 | predictions += list(y_pred.detach().numpy())
84 | return predictions
85 |
86 | @staticmethod
87 |     def calculate_accuracy(ground_truth, predictions):
88 | # Metrics calculation
89 | true_positives = 0
90 | true_negatives = 0
91 |         for true, pred in zip(ground_truth, predictions):
92 | if (pred >= 0.5) and (true == 1):
93 | true_positives += 1
94 | elif (pred < 0.5) and (true == 0):
95 | true_negatives += 1
96 | else:
97 | pass
98 | # Return accuracy
99 |         return (true_positives+true_negatives) / len(ground_truth)
100 |
--------------------------------------------------------------------------------
/src/parameters/__init__.py:
--------------------------------------------------------------------------------
1 | from .parameters import Parameters
--------------------------------------------------------------------------------
/src/parameters/parameters.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | @dataclass
4 | class Parameters:
5 |     # Preprocessing parameters
6 |     seq_len: int = 35
7 |     num_words: int = 2000
8 |
9 |     # Model parameters
10 |     embedding_size: int = 64
11 |     out_size: int = 32
12 |     stride: int = 2
13 |
14 |     # Training parameters
15 |     epochs: int = 10
16 |     batch_size: int = 12
17 |     learning_rate: float = 0.001
--------------------------------------------------------------------------------
/src/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocessing import Preprocessing
--------------------------------------------------------------------------------
/src/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | import re
2 | import nltk
3 | import numpy as np
4 | import pandas as pd
5 | from sklearn.model_selection import train_test_split
6 | from nltk.tokenize import word_tokenize
7 |
8 | class Preprocessing:
9 |
10 | def __init__(self, num_words, seq_len):
11 |         self.data = 'data/tweets.csv'  # path relative to the repository root
12 | self.num_words = num_words
13 | self.seq_len = seq_len
14 | self.vocabulary = None
15 | self.x_tokenized = None
16 | self.x_padded = None
17 | self.x_raw = None
18 | self.y = None
19 |
20 | self.x_train = None
21 | self.x_test = None
22 | self.y_train = None
23 | self.y_test = None
24 |
25 | def load_data(self):
26 |         # Reads the raw csv file and splits it into
27 |         # sentences (x) and target (y)
28 |
29 | df = pd.read_csv(self.data)
30 | df.drop(['id','keyword','location'], axis=1, inplace=True)
31 |
32 | self.x_raw = df['text'].values
33 | self.y = df['target'].values
34 |
35 | def clean_text(self):
36 |         # Lowercases the text and removes any
37 |         # character that is not a letter
38 |
39 | self.x_raw = [x.lower() for x in self.x_raw]
40 | self.x_raw = [re.sub(r'[^A-Za-z]+', ' ', x) for x in self.x_raw]
41 |
42 | def text_tokenization(self):
43 | # Tokenizes each sentence by implementing the nltk tool
44 | self.x_raw = [word_tokenize(x) for x in self.x_raw]
45 |
46 | def build_vocabulary(self):
47 |         # Builds the vocabulary and keeps the num_words most frequent words
48 | self.vocabulary = dict()
49 | fdist = nltk.FreqDist()
50 |
51 | for sentence in self.x_raw:
52 | for word in sentence:
53 | fdist[word] += 1
54 |
55 | common_words = fdist.most_common(self.num_words)
56 |
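        # Indices start at 1 so that 0 stays reserved for padding
        # (see padding_sentences and the embedding layer's padding_idx)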
57 | for idx, word in enumerate(common_words):
58 | self.vocabulary[word[0]] = (idx+1)
59 |
60 | def word_to_idx(self):
61 |         # Using the dictionary (vocabulary), each token is
62 |         # transformed into its index-based representation
63 |
64 | self.x_tokenized = list()
65 |
66 | for sentence in self.x_raw:
67 | temp_sentence = list()
68 | for word in sentence:
69 |                 if word in self.vocabulary:
70 | temp_sentence.append(self.vocabulary[word])
71 | self.x_tokenized.append(temp_sentence)
72 |
73 | def padding_sentences(self):
74 |         # Each sentence that does not reach the required length is padded
75 |         # with the index 0; longer sentences are truncated to seq_len
76 |
77 | pad_idx = 0
78 | self.x_padded = list()
79 |
80 |         for sentence in self.x_tokenized:
81 |             while len(sentence) < self.seq_len:
82 |                 sentence.append(pad_idx)
83 |             self.x_padded.append(sentence[:self.seq_len])
84 |
85 | self.x_padded = np.array(self.x_padded)
86 |
87 | def split_data(self):
88 | self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(self.x_padded, self.y, test_size=0.25, random_state=42)
--------------------------------------------------------------------------------