├── README.md
├── data
│   └── .gitkeep
├── data_utils.py
├── datasets.py
├── epochs
│   └── .gitkeep
├── main.py
├── model.py
├── results
│   ├── .gitkeep
│   ├── agnews.png
│   ├── amazon.png
│   ├── amazon_fine_grained.png
│   ├── dbpedia.png
│   ├── sogou.png
│   ├── yahoo.png
│   ├── yelp.png
│   └── yelp_fine_grained.png
├── statistics
│   └── .gitkeep
├── utils.py
└── vis.py
/README.md:
--------------------------------------------------------------------------------
1 | # CCCapsNet
2 | A PyTorch implementation of Compositional Coding Capsule Network based on PRL 2022 paper [Compositional Coding Capsule Network with K-Means Routing for Text Classification](https://www.sciencedirect.com/science/article/pii/S016786552200188X).
3 |
4 | ## Requirements
5 | * [Anaconda](https://www.anaconda.com/download/)
6 | * PyTorch
7 | ```
8 | conda install pytorch torchvision -c pytorch
9 | ```
10 | * PyTorchNet
11 | ```
12 | pip install git+https://github.com/pytorch/tnt.git@master
13 | ```
14 | * PyTorch-NLP
15 | ```
16 | pip install pytorch-nlp
17 | ```
18 | * capsule-layer
19 | ```
20 | pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
21 | ```
22 |
23 | ## Datasets
24 | The original `AGNews`, `AmazonReview`, `DBPedia`, `YahooAnswers`, `SogouNews` and `YelpReview` datasets come from [here](http://goo.gl/JyCnZq).
25 |
26 | The original `Newsgroups`, `Reuters`, `Cade` and `WebKB` datasets can be found [here](http://ana.cachopo.org/datasets-for-single-label-text-categorization).
27 |
28 | The original `IMDB` dataset is downloaded by `PyTorch-NLP` automatically.
29 |
30 | We have uploaded all the original datasets to [BaiduYun](https://pan.baidu.com/s/16wBuNJiD0acgTHDeld9eDA) (access code: kddr) and
31 | [GoogleDrive](https://drive.google.com/open?id=10n_eZ2ZyRjhRWFjxky7_PhcGHecDjKJ2).
32 | The preprocessed datasets have been uploaded to [BaiduYun](https://pan.baidu.com/s/1hsIJAw54YZbVAqFiehEH6w) (access code: 2kyd) and
33 | [GoogleDrive](https://drive.google.com/open?id=1KDE5NJKfgOwc6RNEf9_F0ZhLQZ3Udjx5).
34 |
35 | You don't need to download the datasets yourself; the code will download them automatically.
36 | If you encounter network issues, you can download all the datasets from the aforementioned cloud storage links
37 | and extract them into the `data` directory.
38 |
39 | ## Usage
40 |
41 | ### Generate Preprocessed Data
42 | ```
43 | python utils.py --data_type yelp --fine_grained
44 | optional arguments:
45 | --data_type dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb',
46 | 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
47 | --fine_grained use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
48 | ```
49 | This step is optional and takes a long time to execute. The preprocessed data has already been generated and
50 | uploaded to the aforementioned cloud storage links, so you can skip this step and go straight to the next one;
51 | the code will download the data automatically, as shown in the sketch below.
52 |
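For reference, here is a rough sketch of how `main.py` loads a preprocessed dataset through `load_data` from `utils.py` (the argument values are just the defaults; treat it as illustrative):

```
from utils import load_data

# downloads the preprocessed archive into data/ on first use and returns the
# text/label encoders plus the train and test datasets (same call as main.py)
sentence_encoder, label_encoder, train_dataset, test_dataset = load_data(
    'imdb', preprocessing=True, fine_grained=False, verbose=True, text_length=5000)
print('vocab size:', sentence_encoder.vocab_size)
print('num classes:', label_encoder.vocab_size)
```
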
53 | ### Train Text Classification
54 | ```
55 | visdom -logging_level WARNING & python main.py --data_type newsgroups --num_epochs 70
56 | optional arguments:
57 | --data_type dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb',
58 | 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
59 | --fine_grained use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
60 | --text_length the number of words about the text to load [default value is 5000]
61 | --routing_type routing type, it only works for capsule classifier [default value is 'k_means'](choices:['k_means', 'dynamic'])
62 | --loss_type loss type [default value is 'mf'](choices:['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'])
63 | --embedding_type embedding type [default value is 'cwc'](choices:['cwc', 'cc', 'normal'])
64 | --classifier_type classifier type [default value is 'capsule'](choices:['capsule', 'linear'])
65 | --embedding_size embedding size [default value is 64]
66 | --num_codebook codebook number, it only works for cwc and cc embedding [default value is 8]
67 | --num_codeword codeword number, it only works for cwc and cc embedding [default value is None]
68 | --hidden_size hidden size [default value is 128]
69 | --in_length in capsule length, it only works for capsule classifier [default value is 8]
70 | --out_length out capsule length, it only works for capsule classifier [default value is 16]
71 | --num_iterations routing iterations number, it only works for capsule classifier [default value is 3]
72 | --num_repeat gumbel softmax repeat number, it only works for cc embedding [default value is 10]
73 | --drop_out drop_out rate of GRU layer [default value is 0.5]
74 | --batch_size train batch size [default value is 32]
75 | --num_epochs train epochs number [default value is 10]
76 | --num_steps test steps number [default value is 100]
77 | --pre_model pre-trained model weight, it only works for routing_type experiment [default value is None]
78 | ```
79 | Visdom can now be accessed by going to `127.0.0.1:8097/env/$data_type` in your browser, where `$data_type` is the dataset
80 | type you are training on.
81 |
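If `--fine_grained` is used (it only works for `reuters`, `yelp` and `amazon`), the visdom environment name gets a `_fine_grained` suffix, mirroring the logic in `main.py`; for example:

```
visdom -logging_level WARNING & python main.py --data_type yelp --fine_grained --num_epochs 10
```

The results then appear under `127.0.0.1:8097/env/yelp_fine_grained`.
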
82 | ## Benchmarks
83 | The Adam optimizer is used with learning rate scheduling. The models are trained for 10 epochs with a batch size of 32 on one
84 | NVIDIA Tesla V100 (32G) GPU.
85 |
86 | The texts are preprocessed to keep only numbers and English words, with a maximum length of 5,000 words; an example is sketched below.
87 |
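As a rough illustration of what `text_preprocess` in `data_utils.py` does (the sample sentence below is made up):

```
from data_utils import text_preprocess

# punctuation is spaced out, letters are lowercased, '2008year' is split into
# '2008 year', and standalone numbers are split into single digits
print(text_preprocess('Best movie of 2008year, rated 9/10!', 'imdb'))
# -> 'best movie of 2 0 0 8 year , rated 9 / 1 0 !'
```
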
88 | Here are the dataset details:
89 |
90 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
91 | --- | --- | --- | --- | --- | --- | --- | --- | ---
92 | Num. of Train Texts | 120,000 | 560,000 | 1,400,000 | 450,000 | 560,000 | 650,000 | 3,600,000 | 3,000,000
93 | Num. of Test Texts | 7,600 | 70,000 | 60,000 | 60,000 | 38,000 | 50,000 | 400,000 | 650,000
94 | Num. of Vocabulary | 62,535 | 548,338 | 771,820 | 106,385 | 200,790 | 216,985 | 931,271 | 835,818
95 | Num. of Classes | 4 | 14 | 10 | 5 | 2 | 5 | 2 | 5
96 |
152 | Here are the model parameter details; the model names take the form `embedding_type-classifier_type`:
153 |
154 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
155 | --- | --- | --- | --- | --- | --- | --- | --- | ---
156 | Normal-Linear | 4,448,192 | 35,540,864 | 49,843,200 | 7,254,720 | 13,296,256 | 14,333,120 | 60,047,040 | 53,938,432
157 | CC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416
158 | CWC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416
159 | Normal-Capsule | 4,455,872 | 35,567,744 | 49,862,400 | 7,264,320 | 13,300,096 | 14,342,720 | 60,050,880 | 53,948,032
160 | CC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016
161 | CWC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016
237 |
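As a sanity check, a minimal sketch of how the `Normal-Linear` count for `agnews` can be reproduced with the `Model` class from `model.py` and the same counting expression as `main.py` (using the default hyper-parameters listed above):

```
from model import Model

# Normal-Linear on agnews: vocab_size=62,535, 4 classes, with the default
# embedding_size=64 and hidden_size=128 used in main.py
model = Model(vocab_size=62535, embedding_size=64, num_codebook=8, num_codeword=None,
              hidden_size=128, in_length=8, out_length=16, num_class=4,
              routing_type='k_means', embedding_type='normal', classifier_type='linear',
              num_iterations=3, num_repeat=10, dropout=0.5)
# prints 4,448,192 = 62,535 * 64 (embedding) + 445,440 (GRU) + 128 * 4 (classifier)
print(sum(param.numel() for param in model.parameters() if param.requires_grad))
```
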
238 | Here are the loss function details; we use the `AGNews` dataset and the `Normal-Linear` model to test the different loss functions:
239 |
240 | Loss Function | margin | focal | cross | margin+focal | margin+cross | focal+cross | margin+focal+cross
241 | --- | --- | --- | --- | --- | --- | --- | ---
242 | Accuracy | 92.37% | 92.13% | 92.05% | 92.64% | 91.95% | 92.09% | 92.38%
266 |
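The composite losses (`mf`, `mc`, `fc`, `mfc`) are simply the sum of the individual criteria, as in `main.py`. A minimal sketch of the `margin+focal` (`mf`) case, assuming the `MarginLoss` and `FocalLoss` classes from `utils.py` behave as they are called in `main.py`:

```
import torch
from utils import MarginLoss, FocalLoss

num_class = 4                                      # AGNews has 4 classes
criteria = [MarginLoss(num_class), FocalLoss()]    # the 'mf' loss of main.py

scores = torch.randn(32, num_class, requires_grad=True)  # stand-in for model(data)
labels = torch.randint(0, num_class, (32,))
loss = sum(criterion(scores, labels) for criterion in criteria)
loss.backward()
```
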
267 | Here are the accuracy details. We use `margin+focal` as the loss function; for the `capsule` classifier, 3 routing iterations are used;
268 | for the `CC` embedding, the model name is suffixed with the `num_repeat` value:
269 |
270 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
271 | --- | --- | --- | --- | --- | --- | --- | --- | ---
272 | Normal-Linear | 92.64% | 98.84% | 74.13% | 97.37% | 96.69% | 66.23% | 95.09% | 60.78%
273 | CC-Linear-10 | 73.11% | 92.66% | 48.01% | 93.50% | 87.81% | 50.33% | 83.20% | 45.77%
274 | CC-Linear-30 | 81.05% | 95.29% | 53.50% | 94.65% | 91.33% | 55.22% | 87.37% | 50.00%
275 | CC-Linear-50 | 83.13% | 96.06% | 57.87% | 95.20% | 92.37% | 56.66% | 89.04% | 51.30%
276 | CWC-Linear | 91.93% | 98.83% | 73.58% | 97.37% | 96.35% | 65.11% | 94.90% | 60.29%
277 | Normal-Capsule | 92.18% | 98.86% | 74.12% | 97.52% | 96.56% | 66.23% | 95.18% | 61.36%
278 | CC-Capsule-10 | 73.53% | 93.04% | 50.52% | 94.44% | 87.98% | 54.14% | 83.64% | 47.44%
279 | CC-Capsule-30 | 81.71% | 95.72% | 60.48% | 95.96% | 91.90% | 58.27% | 87.88% | 51.63%
280 | CC-Capsule-50 | 84.05% | 96.27% | 60.31% | 96.00% | 92.82% | 59.48% | 89.07% | 52.06%
281 | CWC-Capsule | 92.12% | 98.81% | 73.78% | 97.42% | 96.28% | 65.38% | 94.98% | 60.94%
397 |
398 | Here are the model parameter details for the `CWC-Capsule` model; each model name lists the `num_codeword` value used for
399 | each dataset (see the command sketch after the table):
400 |
401 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
402 | --- | --- | --- | --- | --- | --- | --- | --- | ---
403 | 57766677 | 2,957,592 | 31,184,624 | 43,691,424 | 5,565,232 | 10,090,528 | 10,874,032 | 52,604,296 | 47,265,072
404 | 68877788 | 3,458,384 | 35,571,840 | 49,866,496 | 6,416,824 | 11,697,360 | 12,610,424 | 60,054,976 | 53,952,128
440 |
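The digits appear to give the `num_codeword` value per dataset in the column order above; for example, the first digit of the `57766677` configuration presumably corresponds to training `agnews` with 5 codewords, roughly:

```
python main.py --data_type agnews --embedding_type cwc --num_codeword 5 --num_epochs 10
```
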
441 | Here are the accuracy details:
442 |
443 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
444 | --- | --- | --- | --- | --- | --- | --- | --- | ---
445 | 57766677 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98%
446 | 68877788 | 92.05% | 98.82% | 73.93% | 97.52% | 96.44% | 65.63% | 95.05% | 61.02%
482 |
483 | Here are the accuracy details for the `57766677` configuration; the model names correspond to the `num_iterations` value:
484 |
485 | Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained
486 | --- | --- | --- | --- | --- | --- | --- | --- | ---
487 | 1 | 92.28% | 98.82% | 73.93% | 97.25% | 96.58% | 65.60% | 95.00% | 61.08%
488 | 3 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98%
489 | 5 | 92.21% | 98.88% | 73.85% | 97.38% | 96.38% | 65.36% | 95.05% | 61.23%
535 |
536 | ## Results
537 | The train/test loss, accuracy and confusion matrix are shown with visdom. The pretrained models and more results can be
538 | found in [BaiduYun](https://pan.baidu.com/s/1mpIXTfuECiSFVxJcLR1j3A) (access code: xer4) and
539 | [GoogleDrive](https://drive.google.com/drive/folders/1hu8sA517kA5bowzE6TYK_xSLGBUCoanK?usp=sharing).
540 |
541 | **agnews**
542 |
543 | 
544 |
545 | **dbpedia**
546 |
547 | 
548 |
549 | **yahoo**
550 |
551 | 
552 |
553 | **sogou**
554 |
555 | 
556 |
557 | **yelp**
558 |
559 | 
560 |
561 | **yelp fine grained**
562 |
563 | 
564 |
565 | **amazon**
566 |
567 | 
568 |
569 | **amazon fine grained**
570 |
571 | 
572 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/data/.gitkeep
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import shutil
4 | import warnings
5 | import zipfile
6 | from os import makedirs
7 | from os.path import exists
8 | from sys import stdout
9 |
10 | import requests
11 |
12 | re_letter_number = re.compile(r'[^a-zA-Z0-9 ]')
13 | re_number_letter = re.compile(r'^(\d+)([a-z]\w*)$')
14 |
15 |
16 | def text_preprocess(text, data_type):
17 | if data_type == 'sogou' or data_type == 'yahoo' or data_type == 'yelp':
18 | # Remove \\n character
19 | text = text.replace('\\n', ' ')
20 | if data_type == 'imdb' or data_type == 'yahoo':
21 | # Remove <br /> character
22 | text = text.replace('<br />', ' ')
23 | if data_type not in ['newsgroups', 'reuters', 'webkb', 'cade']:
24 | # Surround punctuation, foreign words, etc. with spaces (e.g. turn 'word.' into 'word . ').
25 | text = re_letter_number.sub(lambda m: ' ' + m.group(0) + ' ', text)
26 | # Turn all letters to lowercase.
27 | text = text.lower()
28 | # Turn the number-letter word to single number and word (such as turn 2008year into 2008 year).
29 | text = ' '.join(' '.join(w for w in re_number_letter.match(word).groups())
30 | if re_number_letter.match(word) else word for word in text.split())
31 | # Turn all numbers to single number (such as turn 789 into 7 8 9).
32 | text = ' '.join(' '.join(w for w in word) if word.isdigit() else word for word in text.split())
33 | # Substitute multiple SPACES by a single SPACE.
34 | text = ' '.join(text.split())
35 | return text
36 |
37 |
38 | class GoogleDriveDownloader:
39 | """
40 | Minimal class to download shared files from Google Drive.
41 | """
42 |
43 | CHUNK_SIZE = 32768
44 | DOWNLOAD_URL = "https://docs.google.com/uc?export=download"
45 |
46 | @staticmethod
47 | def download_file_from_google_drive(file_id, file_name, dest_path, overwrite=False):
48 | """
49 | Downloads a shared file from Google Drive into a given folder.
50 | Optionally unzips it.
51 |
52 | Args:
53 | file_id (str): the file identifier. You can obtain it from the shareable link.
54 | file_name (str): the file name.
55 | dest_path (str): the destination where to save the downloaded file.
56 | overwrite (bool): optional, if True forces re-download and overwrite.
57 |
58 | Returns:
59 | None
60 | """
61 |
62 | if not exists(dest_path):
63 | makedirs(dest_path)
64 |
65 | if not exists(os.path.join(dest_path, file_name)) or overwrite:
66 |
67 | session = requests.Session()
68 |
69 | print('Downloading {} into {}... '.format(file_name, dest_path), end='')
70 | stdout.flush()
71 |
72 | response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params={'id': file_id}, stream=True)
73 |
74 | token = GoogleDriveDownloader._get_confirm_token(response)
75 | if token:
76 | params = {'id': file_id, 'confirm': token}
77 | response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params=params, stream=True)
78 |
79 | GoogleDriveDownloader._save_response_content(response, os.path.join(dest_path, file_name))
80 | print('Done.')
81 |
82 | try:
83 | print('Unzipping... ', end='')
84 | stdout.flush()
85 | with zipfile.ZipFile(os.path.join(dest_path, file_name), 'r') as zip_file:
86 | for member in zip_file.namelist():
87 | filename = os.path.basename(member)
88 | # skip directories
89 | if not filename:
90 | continue
91 | # copy file (taken from zipfile's extract)
92 | source = zip_file.open(member)
93 | target = open(os.path.join(dest_path, filename), 'wb')
94 | with source, target:
95 | shutil.copyfileobj(source, target)
96 | print('Done.')
97 | except zipfile.BadZipfile:
98 | warnings.warn('Ignoring `unzip` since "{}" does not look like a valid zip file'.format(file_name))
99 |
100 | @staticmethod
101 | def _get_confirm_token(response):
102 | for key, value in response.cookies.items():
103 | if key.startswith('download_warning'):
104 | return value
105 | return None
106 |
107 | @staticmethod
108 | def _save_response_content(response, destination):
109 | with open(destination, 'wb') as f:
110 | for chunk in response.iter_content(GoogleDriveDownloader.CHUNK_SIZE):
111 | if chunk: # filter out keep-alive new chunks
112 | f.write(chunk)
113 |
114 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from torchnlp.datasets.dataset import Dataset
7 |
8 | from data_utils import GoogleDriveDownloader as gdd
9 | from data_utils import text_preprocess
10 |
11 |
12 | def imdb_dataset(directory='data/', data_type='imdb', preprocessing=False, fine_grained=False,
13 | verbose=False, text_length=5000, share_id='1X7YI7nDpKEPio2J-eH7uWCiWoiw2jbSP'):
14 | """
15 | Load the IMDB dataset (Large Movie Review Dataset v1.0).
16 |
17 | This is a dataset for binary sentiment classification containing substantially more data than
18 | previous benchmark datasets. Provided a set of 25,000 highly polar movie reviews for training,
19 | and 25,000 for testing.
20 | After preprocessing, the total number of training samples is 25,000 and testing samples 25,000.
21 | The min length of text about train data is 11, max length is 2,803, average length is 281; the
22 | min length of text about test data is 8, max length is 2,709, average length is 275.
23 |
24 | **Reference:** http://ai.stanford.edu/~amaas/data/sentiment/
25 |
26 | Args:
27 | directory (str, optional): Directory to cache the dataset.
28 | data_type (str, optional): Which dataset to use.
29 | preprocessing (bool, optional): Whether to preprocess the original dataset. If preprocessing
30 | is None, the preprocessed dataset will not be downloaded; it will be generated from the
31 | original dataset instead.
32 | fine_grained (bool, optional): Whether to use fine_grained dataset instead of polarity dataset.
33 | verbose (bool, optional): Whether to print the dataset details.
34 | text_length (int, optional): Only load the first text_length words, it only works when
35 | preprocessing is True.
36 | share_id (str, optional): Google Drive share ID about the original dataset to download.
37 |
38 | Returns:
39 | :class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the training dataset and test
40 | dataset.
41 |
42 | Example:
43 | >>> train, test = imdb_dataset(preprocessing=True)
44 | >>> train[0:2]
45 | [{
46 | 'label': 'pos',
47 | 'text': 'for a movie that gets no respect there sure are a lot of memorable quotes...'},
48 | {
49 | 'label': 'pos',
50 | 'text': 'bizarre horror movie filled with famous faces but stolen by cristina raines...'}]
51 | >>> test[0:2]
52 | [{
53 | 'label': 'pos',
54 | 'text': 'based on an actual story , john boorman shows the struggle of an american...'},
55 | {
56 | 'label': 'pos',
57 | 'text': 'this is a gem as a film four production the anticipated quality was indeed...'}]
58 | """
59 |
60 | # the other datasets pass in their own share_id, only imdb needs to be set here
61 | if preprocessing and data_type == 'imdb':
62 | share_id = '1naVVErkRQNNJXTA6X_X6YrJY0jPOeuPh'
63 |
64 | if preprocessing:
65 | gdd.download_file_from_google_drive(share_id, data_type + '_preprocessed.zip', directory + data_type)
66 | if fine_grained:
67 | train_file, test_file = 'preprocessed_fine_grained_train.csv', 'preprocessed_fine_grained_test.csv'
68 | else:
69 | train_file, test_file = 'preprocessed_train.csv', 'preprocessed_test.csv'
70 | else:
71 | gdd.download_file_from_google_drive(share_id, data_type + '_original.zip', directory + data_type)
72 | if fine_grained:
73 | train_file, test_file = 'original_fine_grained_train.csv', 'original_fine_grained_test.csv'
74 | else:
75 | train_file, test_file = 'original_train.csv', 'original_test.csv'
76 |
77 | if verbose:
78 | min_train_length, avg_train_length, max_train_length = sys.maxsize, 0, 0
79 | min_test_length, avg_test_length, max_test_length = sys.maxsize, 0, 0
80 |
81 | ret = []
82 | for file_name in [train_file, test_file]:
83 | csv_file = np.array(pd.read_csv(os.path.join(directory, data_type, file_name), header=None)).tolist()
84 | examples = []
85 | for label, text in csv_file:
86 | label, text = str(label), str(text)
87 | if preprocessing:
88 | if len(text.split()) > text_length:
89 | text = ' '.join(text.split()[:text_length])
90 | elif preprocessing is None:
91 | text = text_preprocess(text, data_type)
92 | if len(text.split()) == 0:
93 | continue
94 | if verbose:
95 | if file_name == train_file:
96 | avg_train_length += len(text.split())
97 | if len(text.split()) > max_train_length:
98 | max_train_length = len(text.split())
99 | if len(text.split()) < min_train_length:
100 | min_train_length = len(text.split())
101 | if file_name == test_file:
102 | avg_test_length += len(text.split())
103 | if len(text.split()) > max_test_length:
104 | max_test_length = len(text.split())
105 | if len(text.split()) < min_test_length:
106 | min_test_length = len(text.split())
107 | examples.append({'label': label, 'text': text})
108 | ret.append(Dataset(examples))
109 |
110 | if verbose:
111 | print('[!] train samples: {} length--(min: {}, avg: {}, max: {})'.
112 | format(len(ret[0]), min_train_length, round(avg_train_length / len(ret[0])), max_train_length))
113 | print('[!] test samples: {} length--(min: {}, avg: {}, max: {})'.
114 | format(len(ret[1]), min_test_length, round(avg_test_length / len(ret[1])), max_test_length))
115 | return tuple(ret)
116 |
117 |
118 | def newsgroups_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
119 | """
120 | Load the 20 Newsgroups dataset (Version 'bydate').
121 |
122 | The 20 Newsgroups data set is a collection of approximately 20,000 newsgroup documents,
123 | partitioned (nearly) evenly across 20 different newsgroups. The total number of training
124 | samples is 11,293 and testing 7,527.
125 | After preprocessing, the total number of training samples is 11,293 and testing samples 7,527.
126 | The min length of text about train data is 1, max length is 6,779, average length is 143; the
127 | min length of text about test data is 1, max length is 6,142, average length is 139.
128 |
129 | **Reference:** http://qwone.com/~jason/20Newsgroups/
130 |
131 | Example:
132 | >>> train, test = newsgroups_dataset(preprocessing=True)
133 | >>> train[0:2]
134 | [{
135 | 'label': 'alt.atheism',
136 | 'text': 'alt atheism faq atheist resourc archiv name atheism resourc alt atheism...'},
137 | {
138 | 'label': 'alt.atheism',
139 | 'text': 'alt atheism faq introduct atheism archiv name atheism introduct alt...'}]
140 | >>> test[0:2]
141 | [{
142 | 'label': 'alt.atheism',
143 | 'text': 'bibl quiz answer articl healta saturn wwc edu healta saturn wwc edu...'},
144 | {
145 | 'label': 'alt.atheism',
146 | 'text': 'amus atheist and agnost articl timmbak mcl timmbak mcl ucsb edu clam bake...'}]
147 | """
148 |
149 | share_id = '1y8M5yf0DD21ox3K76xJyoCkGIU1Zc4iq' if preprocessing else '18_p4_RnCd0OO2qNxteApbIQ9abrfMyjC'
150 | return imdb_dataset(directory, 'newsgroups', preprocessing, verbose=verbose, text_length=text_length,
151 | share_id=share_id)
152 |
153 |
154 | def reuters_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000):
155 | """
156 | Load the Reuters-21578 R8 or Reuters-21578 R52 dataset (Version 'modApté').
157 |
158 | The Reuters-21578 dataset considers only the documents with a single topic and the classes
159 | which still have at least one train and one test example; this leaves 8 of the 10 most frequent
160 | classes (R8) and 52 of the original 90 (R52). In total there are 5,485 training samples and 2,189 testing
161 | samples in the R8 dataset. The total number of training samples is 6,532 and testing 2,568 in the R52
162 | dataset.
163 | After preprocessing, the total number of training samples is 5,485 and testing samples 2,189.
164 | The min length of text about train data is 4, max length is 533, average length is 66; the min
165 | length of text about test data is 5, max length is 484, average length is 60. (R8)
166 | After preprocessing, the total number of training samples is 6,532 and testing samples 2,568.
167 | The min length of text about train data is 4, max length is 595, average length is 70; the min
168 | length of text about test data is 5, max length is 484, average length is 64. (R52)
169 |
170 | **Reference:** http://www.daviddlewis.com/resources/testcollections/reuters21578/
171 |
172 | Example:
173 | >>> train, test = reuters_dataset(preprocessing=True)
174 | >>> train[0:2]
175 | [{
176 | 'label': 'earn',
177 | 'text': 'champion product approv stock split champion product inc board director...'}
178 | {
179 | 'label': 'acq',
180 | 'text': 'comput termin system cpml complet sale comput termin system inc complet...'}]
181 | >>> test[0:2]
182 | [{
183 | 'label': 'trade',
184 | 'text': 'asian export fear damag japan rift mount trade friction and japan rais...'},
185 | {
186 | 'label': 'grain',
187 | 'text': 'china daili vermin eat pct grain stock survei provinc and citi show...'}]
188 | """
189 |
190 | share_id = '1CY3W31rdagEJ8Kr5gHPeRgS1GVks-YVv' if preprocessing else '1coe-1WB4H7PBY2IG_CVeCl4-UNZRwy1I'
191 | return imdb_dataset(directory, 'reuters', preprocessing, fine_grained, verbose, text_length, share_id)
192 |
193 |
194 | def webkb_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
195 | """
196 | Load the World Wide Knowledge Base (Web->Kb) dataset (Version 1).
197 |
198 | The World Wide Knowledge Base (Web->Kb) dataset is collected by the World Wide Knowledge Base
199 | (Web->Kb) project of the CMU text learning group. These pages were collected from computer
200 | science departments of various universities in 1997, manually classified into seven different
201 | classes: student, faculty, staff, department, course, project, and other. The classes Department
202 | and Staff are discarded because there were only a few pages from each university. The class Other
203 | is discarded because its pages were very different from one another. The total number of training
204 | samples is 2,785 and testing 1,383.
205 | After preprocessing, the total number of training samples is 2,785 and testing samples 1,383.
206 | The min length of text about train data is 1, max length is 20,628, average length is 134; the min
207 | length of text about test data is 1, max length is 2,082, average length is 136.
208 |
209 | **Reference:** http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/
210 |
211 | Example:
212 | >>> train, test = webkb_dataset(preprocessing=True)
213 | >>> train[0:2]
214 | [{
215 | 'label': 'student',
216 | 'text': 'brian comput scienc depart univers wisconsin dayton street madison offic...'}
217 | {
218 | 'label': 'student',
219 | 'text': 'denni swanson web page mail pop uki offic hour comput lab offic anderson...'}]
220 | >>> test[0:2]
221 | [{
222 | 'label': 'student',
223 | 'text': 'eric homepag eric wei tsinghua physic fudan genet'},
224 | {
225 | 'label': 'course',
226 | 'text': 'comput system perform evalu model new sept assign due oct postscript text...'}]
227 | """
228 |
229 | share_id = '1oqcl2N0kDoBlHo_hFgKc_MaSvs0ny1t7' if preprocessing else '12uR98xYZ44fXX0WUf9RjOM4GpBt3JAV8'
230 | return imdb_dataset(directory, 'webkb', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id)
231 |
232 |
233 | def cade_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
234 | """
235 | Load the Cade12 dataset (Version 1).
236 |
237 | The Cade12 dataset corresponds to a subset of web pages extracted from the CADÊ Web Directory,
238 | which points to Brazilian web pages classified by human experts. The total number of training
239 | samples is 27,322 and testing 13,661.
240 | After preprocessing, the total number of training samples is 27,322 and testing samples 13,661.
241 | The min length of text about train data is 2, max length is 22,352, average length is 119; the min
242 | length of text about test data is 2, max length is 15,318, average length is 112.
243 |
244 | **Reference:** http://www.cade.com.br/
245 |
246 | Example:
247 | >>> train, test = cade_dataset(preprocessing=True)
248 | >>> train[0:2]
249 | [{
250 | 'label': '08_cultura',
251 | 'text': 'br br email arvores arvores http www apoio mascote natureza vida links foram...'}
252 | {
253 | 'label': '02_sociedade',
254 | 'text': 'page frames browser support virtual araraquara shop'}]
255 | >>> test[0:2]
256 | [{
257 | 'label': '02_sociedade',
258 | 'text': 'dezembro envie mail br manutencao funcionarios funcionarios funcionarios...'},
259 | {
260 | 'label': '07_internet',
261 | 'text': 'auto sao pagina br br computacao rede internet internet internet internet...'}]
262 | """
263 |
264 | share_id = '13CwKytxKlvMP6FW9iOCOMvmKlm5YWD-k' if preprocessing else '1cWlJHAt5dhomDxoQQmXu3APaFkdjRg_P'
265 | return imdb_dataset(directory, 'cade', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id)
266 |
267 |
268 | def dbpedia_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
269 | """
270 | Load the DBPedia Ontology Classification dataset (Version 2).
271 |
272 | The DBpedia ontology classification dataset is constructed by picking 14 non-overlapping classes
273 | from DBpedia 2014. They are listed in classes.txt. From each of these 14 ontology classes, we
274 | randomly choose 40,000 training samples and 5,000 testing samples. Therefore, the total size
275 | of the training dataset is 560,000 and testing dataset 70,000.
276 | After preprocessing, the total number of training samples is 560,000 and testing samples 70,000.
277 | The min length of text about train data is 3, max length is 2,780, average length is 64; the min
278 | length of text about test data is 4, max length is 930, average length is 64.
279 |
280 | **Reference:** http://dbpedia.org
281 |
282 | Example:
283 | >>> train, test = dbpedia_dataset(preprocessing=True)
284 | >>> train[0:2]
285 | [{
286 | 'label': 'Company',
287 | 'text': 'e . d . abbott ltd abbott of farnham e d abbott limited was a british...'},
288 | {
289 | 'label': 'Company',
290 | 'text': 'schwan - stabilo schwan - stabilo is a german maker of pens for writing...'}]
291 | >>> test[0:2]
292 | [{
293 | 'label': 'Company',
294 | 'text': 'ty ku ty ku / ta ɪ ku ː / is an american alcoholic beverage company that...'},
295 | {
296 | 'label': 'Company',
297 | 'text': 'odd lot entertainment oddlot entertainment founded in 2 0 0 1 by longtime...'}]
298 | """
299 |
300 | share_id = '1egq6UCaaqeZOq7siitXEIfIFwYUjFjnP' if preprocessing else '1YEZP-ajK3fUEMhdgkATRmikeuI7EjI9X'
301 | return imdb_dataset(directory, 'dbpedia', preprocessing, verbose=verbose, text_length=text_length,
302 | share_id=share_id)
303 |
304 |
305 | def agnews_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
306 | """
307 | Load the AG's News Topic Classification dataset (Version 3).
308 |
309 | The AG's news topic classification dataset is constructed by choosing 4 largest classes from
310 | the original corpus. Each class contains 30,000 training samples and 1,900 testing samples.
311 | The total number of training samples is 120,000 and testing 7,600.
312 | After preprocessing, the total number of training samples is 120,000 and testing samples 7,600.
313 | The min length of text about train data is 13, max length is 354, average length is 49; the min
314 | length of text about test data is 15, max length is 250, average length is 48.
315 |
316 | **Reference:** http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html
317 |
318 | Example:
319 | >>> train, test = agnews_dataset(preprocessing=True)
320 | >>> train[0:2]
321 | [{
322 | 'label': 'Business',
323 | 'text': 'wall st . bears claw back into the black ( reuters ) reuters - short - sellers...'},
324 | {
325 | 'label': 'Business',
326 | 'text': 'carlyle looks toward commercial aerospace ( reuters ) reuters - private investment...'}]
327 | >>> test[0:2]
328 | [{
329 | 'label': 'Business',
330 | 'text': 'fears for t n pension after talks unions representing workers at turner newall...'},
331 | {
332 | 'label': 'Sci/Tech',
333 | 'text': 'the race is on : second private team sets launch date for human spaceflight...'}]
334 | """
335 |
336 | share_id = '153R49C-JY8NDmRwc7bikZvU3EEEjKRs2' if preprocessing else '1EKKHXTXwGJaitaPn8BYD2bpSGQCB76te'
337 | return imdb_dataset(directory, 'agnews', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id)
338 |
339 |
340 | def yahoo_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
341 | """
342 | Load the Yahoo! Answers Topic Classification dataset (Version 2).
343 |
344 | The Yahoo! Answers topic classification dataset is constructed using 10 largest main categories.
345 | Each class contains 140,000 training samples and 6,000 testing samples. Therefore, the total number
346 | of training samples is 1,400,000 and testing samples 60,000 in this dataset.
347 | After preprocessing, the total number of training samples is 1,400,000 and testing samples 60,000.
348 | The min length of text about train data is 2, max length is 4,044, average length is 118; the min
349 | length of text about test data is 3, max length is 4,017, average length is 119.
350 |
351 | **Reference:** https://webscope.sandbox.yahoo.com/catalog.php?datatype=l
352 |
353 | Example:
354 | >>> train, test = yahoo_dataset(preprocessing=True)
355 | >>> train[0:2]
356 | [{
357 | 'label': 'Computers & Internet',
358 | 'text': 'why doesn ' t an optical mouse work on a glass table ? or even on some surfaces...'},
359 | {
360 | 'label': 'Sports',
361 | 'text': 'what is the best off - road motorcycle trail ? long - distance trail throughout...'}]
362 | >>> test[0:2]
363 | [{
364 | 'label': 'Family & Relationships',
365 | 'text': 'what makes friendship click ? how does the spark keep going ? good communication...'},
366 | {
367 | 'label': 'Science & Mathematics',
368 | 'text': 'why does zebras have stripes ? what is the purpose or those stripes ? who do they...'}]
369 | """
370 |
371 | share_id = '1LS7iQM3qMofMCVlm08LfniyqXsdhFdnn' if preprocessing else '15xpGyKaQk2-WDrjzsz57TVKrgQbhM7Ct'
372 | return imdb_dataset(directory, 'yahoo', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id)
373 |
374 |
375 | def sogou_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000):
376 | """
377 | Load the Sogou News Topic Classification dataset (Version 3).
378 |
379 | The Sogou news topic classification dataset is constructed by manually labeling each news article
380 | according to its URL, which represents roughly the categorization of news in their websites. We
381 | chose 5 largest categories for the dataset, each having 90,000 samples for training and 12,000 for
382 | testing. The Pinyin texts are converted using pypinyin combined with jieba Chinese segmentation
383 | system. In total there are 450,000 training samples and 60,000 testing samples.
384 | After preprocessing, the total number of training samples is 450,000 and testing samples 60,000.
385 | The min length of text about train data is 2, max length is 42,695, average length is 612; the min
386 | length of text about test data is 3, max length is 64,651, average length is 616.
387 |
388 | **Reference:** http://www.sogou.com/labs/dl/ca.html and http://www.sogou.com/labs/dl/cs.html
389 |
390 | Example:
391 | >>> train, test = sogou_dataset(preprocessing=True)
392 | >>> train[0:2]
393 | [{
394 | 'label': 'automobile',
395 | 'text': '2 0 0 8 di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n me3i nv3 mo2 te4 2 0 0 8...'}
396 | {
397 | 'label': 'automobile',
398 | 'text': 'zho1ng hua2 ju4n jie2 frv ya4o shi tu2 we2i zho1ng hua2 ju4n jie2 frv ya4o shi .'}]
399 | >>> test[0:2]
400 | [{
401 | 'label': 'sports',
402 | 'text': 'ti3 ca1o shi4 jie4 be1i : che2ng fe1i na2 pi2ng he2ng mu4 zi4 yo2u ca1o ji1n...'},
403 | {
404 | 'label': 'automobile',
405 | 'text': 'da3o ha2ng du2 jia1 ti2 go1ng me3i ri4 ba4o jia4 re4 xia4n : 0 1 0 - 6 4 4 3...'}]
406 | """
407 |
408 | share_id = '1HbJHzIacbQt7m-IRZzv8nRaSubSrYdip' if preprocessing else '1pvg0e3HSE_IeYdphYyJ8k52JZyGnae_U'
409 | return imdb_dataset(directory, 'sogou', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id)
410 |
411 |
412 | def yelp_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000):
413 | """
414 | Load the Yelp Review Full Star or Yelp Review Polarity dataset (Version 1).
415 |
416 | The Yelp reviews polarity dataset is constructed by considering stars 1 and 2 negative, and 3
417 | and 4 positive. For each polarity 280,000 training samples and 19,000 testing samples are taken
418 | randomly. In total there are 560,000 training samples and 38,000 testing samples. Negative
419 | polarity is class 1, and positive class 2.
420 | The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples
421 | and 10,000 testing samples for each review star from 1 to 5. In total there are 650,000 training
422 | samples and 50,000 testing samples.
423 | After preprocessing, the total number of training samples is 560,000 and testing samples 38,000.
424 | The min length of text about train data is 1, max length is 1,491, average length is 162; the min
425 | length of text about test data is 1, max length is 1,311, average length is 162. (polarity)
426 | After preprocessing, the total number of training samples is 650,000 and testing samples 50,000.
427 | The min length of text about train data is 1, max length is 1,332, average length is 164; the min
428 | length of text about test data is 1, max length is 1,491, average length is 164. (full)
429 |
430 | **Reference:** http://www.yelp.com/dataset_challenge
431 |
432 | Example:
433 | >>> train, test = yelp_dataset(preprocessing=True)
434 | >>> train[0:2]
435 | [{
436 | 'label': '1',
437 | 'text': 'unfortunately , the frustration of being dr . goldberg ' s patient is a repeat...'}
438 | {
439 | 'label': '2',
440 | 'text': 'been going to dr . goldberg for over 1 0 years . i think i was one of his 1...'}]
441 | >>> test[0:2]
442 | [{
443 | 'label': '2',
444 | 'text': 'contrary to other reviews , i have zero complaints about the service or the prices...'},
445 | {
446 | 'label': '1',
447 | 'text': 'last summer i had an appointment to get new tires and had to wait a super long time...'}]
448 | """
449 |
450 | share_id = '1ecOuyAhT-MjXQiueRHqS9LnY0CV0HQYd' if preprocessing else '1yEF0Lnd4f8mDZiqeh2QmEKvp_gPdYgtH'
451 | return imdb_dataset(directory, 'yelp', preprocessing, fine_grained, verbose, text_length, share_id)
452 |
453 |
454 | def amazon_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000):
455 | """
456 | Load the Amazon Review Full Score or Amazon Review Polarity dataset (Version 3).
457 |
458 | The Amazon reviews polarity dataset is constructed by taking review score 1 and 2 as negative,
459 | and 4 and 5 as positive. For each polarity 1,800,000 training samples and 200,000 testing samples
460 | are taken randomly. In total there are 3,600,000 training samples and 400,000 testing samples.
461 | Negative polarity is class 1, and positive class 2.
462 | The Amazon reviews full score dataset is constructed by randomly taking 600,000 training samples
463 | and 130,000 testing samples for each review score from 1 to 5. In total there are 3,000,000
464 | training samples and 650,000 testing samples.
465 | After preprocessing, the total number of training samples is 3,600,000 and testing samples 400,000.
466 | The min length of text about train data is 2, max length is 986, average length is 95; the min
467 | length of text about test data is 14, max length is 914, average length is 95. (polarity)
468 | After preprocessing, the total number of training samples is 3,000,000 and testing samples 650,000.
469 | The min length of text about train data is 2, max length is 781, average length is 97; the min
470 | length of text about test data is 12, max length is 931, average length is 97. (full)
471 |
472 | **Reference:** http://jmcauley.ucsd.edu/data/amazon/
473 |
474 | Example:
475 | >>> train, test = amazon_dataset(preprocessing=True)
476 | >>> train[0:2]
477 | [{
478 | 'label': '2',
479 | 'text': 'stuning even for the non - gamer this sound track was beautiful ! it paints...'}
480 | {
481 | 'label': '2',
482 | 'text': 'the best soundtrack ever to anything . i ' m reading a lot of reviews saying...'}]
483 | >>> test[0:2]
484 | [{
485 | 'label': '2',
486 | 'text': 'great cd my lovely pat has one of the great voices of her generation . i have...'},
487 | {
488 | 'label': '2',
489 | 'text': 'one of the best game music soundtracks - for a game i didn ' t really play...'}]
490 | """
491 |
492 | share_id = '1BSqCU6DwIVD1jllbsz9ueudu3tSfomzY' if preprocessing else '11-l1T-_kBdtqrqqSfqNnJCLC6_NRM9Je'
493 | return imdb_dataset(directory, 'amazon', preprocessing, fine_grained, verbose, text_length, share_id)
494 |
--------------------------------------------------------------------------------
/epochs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/epochs/.gitkeep
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torchnet as tnt
7 | from torch.nn import CrossEntropyLoss
8 | from torch.optim import Adam
9 | from torch.optim.lr_scheduler import MultiStepLR
10 | from torch.utils.data import DataLoader
11 | from torchnet.logger import VisdomPlotLogger, VisdomLogger
12 | from torchnlp.samplers import BucketBatchSampler
13 |
14 | from model import Model
15 | from utils import load_data, MarginLoss, collate_fn, FocalLoss
16 |
17 |
18 | def reset_meters():
19 | meter_accuracy.reset()
20 | meter_loss.reset()
21 | meter_confusion.reset()
22 |
23 |
24 | if __name__ == '__main__':
25 |
26 | parser = argparse.ArgumentParser(description='Train Text Classification')
27 | parser.add_argument('--data_type', default='imdb', type=str,
28 | choices=['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo',
29 | 'sogou', 'yelp', 'amazon'], help='dataset type')
30 | parser.add_argument('--fine_grained', action='store_true', help='use fine grained class or not, it only works for '
31 | 'reuters, yelp and amazon')
32 | parser.add_argument('--text_length', default=5000, type=int, help='the number of words about the text to load')
33 | parser.add_argument('--routing_type', default='k_means', type=str, choices=['k_means', 'dynamic'],
34 | help='routing type, it only works for capsule classifier')
35 | parser.add_argument('--loss_type', default='mf', type=str,
36 | choices=['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'], help='loss type')
37 | parser.add_argument('--embedding_type', default='cwc', type=str, choices=['cwc', 'cc', 'normal'],
38 | help='embedding type')
39 | parser.add_argument('--classifier_type', default='capsule', type=str, choices=['capsule', 'linear'],
40 | help='classifier type')
41 | parser.add_argument('--embedding_size', default=64, type=int, help='embedding size')
42 | parser.add_argument('--num_codebook', default=8, type=int,
43 | help='codebook number, it only works for cwc and cc embedding')
44 | parser.add_argument('--num_codeword', default=None, type=int,
45 | help='codeword number, it only works for cwc and cc embedding')
46 | parser.add_argument('--hidden_size', default=128, type=int, help='hidden size')
47 | parser.add_argument('--in_length', default=8, type=int,
48 | help='in capsule length, it only works for capsule classifier')
49 | parser.add_argument('--out_length', default=16, type=int,
50 | help='out capsule length, it only works for capsule classifier')
51 | parser.add_argument('--num_iterations', default=3, type=int,
52 | help='routing iterations number, it only works for capsule classifier')
53 | parser.add_argument('--num_repeat', default=10, type=int,
54 | help='gumbel softmax repeat number, it only works for cc embedding')
55 | parser.add_argument('--drop_out', default=0.5, type=float, help='drop_out rate of GRU layer')
56 | parser.add_argument('--batch_size', default=32, type=int, help='train batch size')
57 | parser.add_argument('--num_epochs', default=10, type=int, help='train epochs number')
58 | parser.add_argument('--num_steps', default=100, type=int, help='test steps number')
59 | parser.add_argument('--pre_model', default=None, type=str,
60 | help='pre-trained model weight, it only works for routing_type experiment')
61 |
62 | opt = parser.parse_args()
63 | DATA_TYPE, FINE_GRAINED, TEXT_LENGTH = opt.data_type, opt.fine_grained, opt.text_length
64 | ROUTING_TYPE, LOSS_TYPE, EMBEDDING_TYPE = opt.routing_type, opt.loss_type, opt.embedding_type
65 | CLASSIFIER_TYPE, EMBEDDING_SIZE, NUM_CODEBOOK = opt.classifier_type, opt.embedding_size, opt.num_codebook
66 | NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH = opt.num_codeword, opt.hidden_size, opt.in_length
67 | OUT_LENGTH, NUM_ITERATIONS, DROP_OUT, BATCH_SIZE = opt.out_length, opt.num_iterations, opt.drop_out, opt.batch_size
68 | NUM_REPEAT, NUM_EPOCHS, NUM_STEPS, PRE_MODEL = opt.num_repeat, opt.num_epochs, opt.num_steps, opt.pre_model
69 |
70 | # prepare dataset
71 | sentence_encoder, label_encoder, train_dataset, test_dataset = load_data(DATA_TYPE, preprocessing=True,
72 | fine_grained=FINE_GRAINED, verbose=True,
73 | text_length=TEXT_LENGTH)
74 | VOCAB_SIZE, NUM_CLASS = sentence_encoder.vocab_size, label_encoder.vocab_size
75 | print("[!] vocab_size: {}, num_class: {}".format(VOCAB_SIZE, NUM_CLASS))
76 | train_sampler = BucketBatchSampler(train_dataset, BATCH_SIZE, False, sort_key=lambda row: len(row['text']))
77 | train_iterator = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn)
78 | test_sampler = BucketBatchSampler(test_dataset, BATCH_SIZE * 2, False, sort_key=lambda row: len(row['text']))
79 | test_iterator = DataLoader(test_dataset, batch_sampler=test_sampler, collate_fn=collate_fn)
80 |
81 | model = Model(VOCAB_SIZE, EMBEDDING_SIZE, NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH, OUT_LENGTH,
82 | NUM_CLASS, ROUTING_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, NUM_ITERATIONS, NUM_REPEAT, DROP_OUT)
83 | if PRE_MODEL is not None:
84 | model_weight = torch.load('epochs/{}'.format(PRE_MODEL), map_location='cpu')
85 | model_weight.pop('classifier.weight')
86 | model.load_state_dict(model_weight, strict=False)
87 |
88 | if LOSS_TYPE == 'margin':
89 | loss_criterion = [MarginLoss(NUM_CLASS)]
90 | elif LOSS_TYPE == 'focal':
91 | loss_criterion = [FocalLoss()]
92 | elif LOSS_TYPE == 'cross':
93 | loss_criterion = [CrossEntropyLoss()]
94 | elif LOSS_TYPE == 'mf':
95 | loss_criterion = [MarginLoss(NUM_CLASS), FocalLoss()]
96 | elif LOSS_TYPE == 'mc':
97 | loss_criterion = [MarginLoss(NUM_CLASS), CrossEntropyLoss()]
98 | elif LOSS_TYPE == 'fc':
99 | loss_criterion = [FocalLoss(), CrossEntropyLoss()]
100 | else:
101 | loss_criterion = [MarginLoss(NUM_CLASS), FocalLoss(), CrossEntropyLoss()]
102 | if torch.cuda.is_available():
103 | model, cudnn.benchmark = model.to('cuda'), True
104 |
105 | if PRE_MODEL is None:
106 | optim_configs = [{'params': model.embedding.parameters(), 'lr': 1e-4 * 10},
107 | {'params': model.features.parameters(), 'lr': 1e-4 * 10},
108 | {'params': model.classifier.parameters(), 'lr': 1e-4}]
109 | else:
110 | for param in model.embedding.parameters():
111 | param.requires_grad = False
112 | for param in model.features.parameters():
113 | param.requires_grad = False
114 | optim_configs = [{'params': model.classifier.parameters(), 'lr': 1e-4}]
115 | optimizer = Adam(optim_configs, lr=1e-4)
116 | lr_scheduler = MultiStepLR(optimizer, milestones=[int(NUM_EPOCHS * 0.5), int(NUM_EPOCHS * 0.7)], gamma=0.1)
117 |
118 | print("# trainable parameters:", sum(param.numel() if param.requires_grad else 0 for param in model.parameters()))
119 | # record statistics
120 | results = {'train_loss': [], 'train_accuracy': [], 'test_loss': [], 'test_accuracy': []}
121 | # record current best test accuracy
122 | best_acc = 0
123 | meter_loss = tnt.meter.AverageValueMeter()
124 | meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
125 | meter_confusion = tnt.meter.ConfusionMeter(NUM_CLASS, normalized=True)
126 |
127 | # config the visdom figures
128 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']:
129 | env_name = DATA_TYPE + '_fine_grained'
130 | else:
131 | env_name = DATA_TYPE
132 | loss_logger = VisdomPlotLogger('line', env=env_name, opts={'title': 'Loss'})
133 | accuracy_logger = VisdomPlotLogger('line', env=env_name, opts={'title': 'Accuracy'})
134 | train_confusion_logger = VisdomLogger('heatmap', env=env_name, opts={'title': 'Train Confusion Matrix'})
135 | test_confusion_logger = VisdomLogger('heatmap', env=env_name, opts={'title': 'Test Confusion Matrix'})
136 |
137 | current_step = 0
138 | for epoch in range(1, NUM_EPOCHS + 1):
139 | for data, target in train_iterator:
140 | current_step += 1
141 | label = target
142 | if torch.cuda.is_available():
143 | data, label = data.to('cuda'), label.to('cuda')
144 | # train model
145 | model.train()
146 | optimizer.zero_grad()
147 | classes = model(data)
148 | loss = sum([criterion(classes, label) for criterion in loss_criterion])
149 | loss.backward()
150 | optimizer.step()
151 | # save the metrics
152 | meter_loss.add(loss.detach().cpu().item())
153 | meter_accuracy.add(classes.detach().cpu(), target)
154 | meter_confusion.add(classes.detach().cpu(), target)
155 |
156 | if current_step % NUM_STEPS == 0:
157 | # print the information about train
158 | loss_logger.log(current_step // NUM_STEPS, meter_loss.value()[0], name='train')
159 | accuracy_logger.log(current_step // NUM_STEPS, meter_accuracy.value()[0], name='train')
160 | train_confusion_logger.log(meter_confusion.value())
161 | results['train_loss'].append(meter_loss.value()[0])
162 | results['train_accuracy'].append(meter_accuracy.value()[0])
163 | print('[Step %d] Training Loss: %.4f Accuracy: %.2f%%' % (
164 | current_step // NUM_STEPS, meter_loss.value()[0], meter_accuracy.value()[0]))
165 | reset_meters()
166 |
167 | # test model periodically
168 | model.eval()
169 | with torch.no_grad():
170 | for data, target in test_iterator:
171 | label = target
172 | if torch.cuda.is_available():
173 | data, label = data.to('cuda'), label.to('cuda')
174 | classes = model(data)
175 | loss = sum([criterion(classes, label) for criterion in loss_criterion])
176 | # save the metrics
177 | meter_loss.add(loss.detach().cpu().item())
178 | meter_accuracy.add(classes.detach().cpu(), target)
179 | meter_confusion.add(classes.detach().cpu(), target)
180 | # print the information about test
181 | loss_logger.log(current_step // NUM_STEPS, meter_loss.value()[0], name='test')
182 | accuracy_logger.log(current_step // NUM_STEPS, meter_accuracy.value()[0], name='test')
183 | test_confusion_logger.log(meter_confusion.value())
184 | results['test_loss'].append(meter_loss.value()[0])
185 | results['test_accuracy'].append(meter_accuracy.value()[0])
186 |
187 | # save best model
188 | if meter_accuracy.value()[0] > best_acc:
189 | best_acc = meter_accuracy.value()[0]
190 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']:
191 | torch.save(model.state_dict(), 'epochs/{}_{}_{}_{}.pth'
192 | .format(DATA_TYPE + '_fine-grained', EMBEDDING_TYPE, CLASSIFIER_TYPE,
193 | str(TEXT_LENGTH)))
194 | else:
195 | torch.save(model.state_dict(), 'epochs/{}_{}_{}_{}.pth'
196 | .format(DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, str(TEXT_LENGTH)))
197 | print('[Step %d] Testing Loss: %.4f Accuracy: %.2f%% Best Accuracy: %.2f%%' % (
198 | current_step // NUM_STEPS, meter_loss.value()[0], meter_accuracy.value()[0], best_acc))
199 | reset_meters()
200 |
201 | # save statistics
202 | data_frame = pd.DataFrame(data=results, index=range(1, current_step // NUM_STEPS + 1))
203 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']:
204 | data_frame.to_csv('statistics/{}_{}_{}_results.csv'.format(
205 | DATA_TYPE + '_fine-grained', EMBEDDING_TYPE, CLASSIFIER_TYPE), index_label='step')
206 | else:
207 | data_frame.to_csv('statistics/{}_{}_{}_results.csv'.format(
208 | DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE), index_label='step')
209 | lr_scheduler.step(epoch)
210 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from capsule_layer import CapsuleLinear
6 | from torch import nn
7 | from torch.nn.parameter import Parameter
8 |
9 |
10 | class CompositionalEmbedding(nn.Module):
11 | r"""A simple compositional codeword and codebook that store embeddings.
12 |
13 | Args:
14 | num_embeddings (int): size of the dictionary of embeddings
15 | embedding_dim (int): size of each embedding vector
16 | num_codebook (int): size of the codebook of embeddings
17 | num_codeword (int, optional): size of the codeword of embeddings
18 | weighted (bool, optional): whether to use the weighted version or the unweighted (Gumbel softmax) version
19 | return_code (bool, optional): return code or not
20 |
21 | Shape:
22 | - Input: (LongTensor): (N, W), W = number of indices to extract per mini-batch
23 | - Output: (Tensor): (N, W, embedding_dim)
24 |
25 | Attributes:
26 | - code (Tensor): the learnable weights of the module of shape
27 | (num_embeddings, num_codebook, num_codeword)
28 | - codebook (Tensor): the learnable weights of the module of shape
29 | (num_codebook, num_codeword, embedding_dim)
30 |
31 | Examples::
32 | >>> m = CompositionalEmbedding(200, 64, 16, 32, weighted=False)
33 | >>> a = torch.randperm(128).view(16, -1)
34 | >>> output = m(a)
35 | >>> print(output.size())
36 | torch.Size([16, 8, 64])
37 | """
38 |
39 | def __init__(self, num_embeddings, embedding_dim, num_codebook, num_codeword=None, num_repeat=10, weighted=True,
40 | return_code=False):
41 | super(CompositionalEmbedding, self).__init__()
42 | self.num_embeddings = num_embeddings
43 | self.embedding_dim = embedding_dim
44 | self.num_codebook = num_codebook
45 | self.num_repeat = num_repeat
46 | self.weighted = weighted
47 | self.return_code = return_code
48 |
49 | if num_codeword is None:
50 | num_codeword = math.ceil(math.pow(num_embeddings, 1 / num_codebook))
51 | self.num_codeword = num_codeword
52 | self.code = Parameter(torch.Tensor(num_embeddings, num_codebook, num_codeword))
53 | self.codebook = Parameter(torch.Tensor(num_codebook, num_codeword, embedding_dim))
54 |
55 | nn.init.normal_(self.code)
56 | nn.init.normal_(self.codebook)
57 |
58 | def forward(self, input):
59 | batch_size = input.size(0)
60 | index = input.view(-1)
61 | code = self.code.index_select(dim=0, index=index)
62 | if self.weighted:
63 | # re-weight with softmax so that the codeword weights of each codebook sum to 1
64 | code = F.softmax(code, dim=-1)
65 | out = (code[:, :, None, :] @ self.codebook[None, :, :, :]).squeeze(dim=-2).sum(dim=1)
66 | else:
67 | # because Gumbel softmax is stochastic, run it several times and sum the samples to
68 | # obtain a more stable codeword assignment
69 | code = (torch.sum(torch.stack([F.gumbel_softmax(code) for _ in range(self.num_repeat)]), dim=0)).argmax(
70 | dim=-1)
71 | out = []
72 | for index in range(self.num_codebook):
73 | out.append(self.codebook[index, :, :].index_select(dim=0, index=code[:, index]))
74 | out = torch.sum(torch.stack(out), dim=0)
75 | code = F.one_hot(code, num_classes=self.num_codeword)
76 |
77 | out = out.view(batch_size, -1, self.embedding_dim)
78 | code = code.view(batch_size, -1, self.num_codebook, self.num_codeword)
79 | if self.return_code:
80 | return out, code
81 | else:
82 | return out
83 |
84 | def __repr__(self):
85 | return self.__class__.__name__ + ' (' + str(self.num_embeddings) + ', ' + str(self.embedding_dim) + ')'
86 |
87 |
88 | class Model(nn.Module):
89 | def __init__(self, vocab_size, embedding_size, num_codebook, num_codeword, hidden_size, in_length, out_length,
90 | num_class, routing_type, embedding_type, classifier_type, num_iterations, num_repeat, dropout):
91 | super().__init__()
92 |
93 | self.in_length, self.out_length = in_length, out_length
94 | self.hidden_size, self.classifier_type = hidden_size, classifier_type
95 | self.embedding_type = embedding_type
96 |
97 | if embedding_type == 'cwc':
98 | self.embedding = CompositionalEmbedding(vocab_size, embedding_size, num_codebook, num_codeword,
99 | weighted=True)
100 | elif embedding_type == 'cc':
101 | self.embedding = CompositionalEmbedding(vocab_size, embedding_size, num_codebook, num_codeword, num_repeat,
102 | weighted=False)
103 | else:
104 | self.embedding = nn.Embedding(vocab_size, embedding_size)
105 | self.features = nn.GRU(embedding_size, self.hidden_size, num_layers=2, dropout=dropout, batch_first=True,
106 | bidirectional=True)
107 | if classifier_type == 'capsule' and routing_type == 'k_means':
108 | self.classifier = CapsuleLinear(out_capsules=num_class, in_length=self.in_length,
109 | out_length=self.out_length, in_capsules=None, share_weight=True,
110 | routing_type='k_means', num_iterations=num_iterations, bias=False)
111 | elif classifier_type == 'capsule' and routing_type == 'dynamic':
112 | self.classifier = CapsuleLinear(out_capsules=num_class, in_length=self.in_length,
113 | out_length=self.out_length, in_capsules=None, share_weight=True,
114 | routing_type='dynamic', num_iterations=num_iterations, bias=False)
115 | else:
116 | self.classifier = nn.Linear(in_features=self.hidden_size, out_features=num_class, bias=False)
117 |
118 | def forward(self, x):
119 | embed = self.embedding(x)
120 | out, _ = self.features(embed)
121 |
122 | out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:]  # sum forward and backward GRU outputs
123 | out = out.mean(dim=1).contiguous()  # average over the sequence (time) dimension
124 | if self.classifier_type == 'capsule':
125 | out = out.view(out.size(0), -1, self.in_length)
126 | out = self.classifier(out)
127 | classes = out.norm(dim=-1)
128 | else:
129 | out = out.view(out.size(0), -1)
130 | classes = self.classifier(out)
131 | return classes
132 |
--------------------------------------------------------------------------------
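A minimal sketch of how `CompositionalEmbedding` composes an embedding in its weighted branch; the vocabulary and codebook sizes below are hypothetical and only chosen to keep the example small:
```
import torch
import torch.nn.functional as F

from model import CompositionalEmbedding

# Hypothetical sizes: a 1000-word vocabulary, 64-d embeddings, 8 codebooks of 16 codewords.
embedding = CompositionalEmbedding(num_embeddings=1000, embedding_dim=64, num_codebook=8,
                                    num_codeword=16, weighted=True)
tokens = torch.randint(0, 1000, (4, 32))   # a batch of 4 sequences, 32 token ids each
vectors = embedding(tokens)
print(vectors.size())                      # torch.Size([4, 32, 64])

# The same lookup done by hand for one token id, mirroring the weighted branch of forward():
code = F.softmax(embedding.code[tokens[0, 0]], dim=-1)                   # [8, 16] weights per codebook
manual = (code.unsqueeze(1) @ embedding.codebook).squeeze(1).sum(dim=0)  # mix codewords, sum codebooks
print(torch.allclose(manual, vectors[0, 0]))                             # True (up to float tolerance)
```
The manual lookup reproduces the weighted branch: a softmax over codewords, a per-codebook mixture of codewords, then a sum across codebooks.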
/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/.gitkeep
--------------------------------------------------------------------------------
/results/agnews.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/agnews.png
--------------------------------------------------------------------------------
/results/amazon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/amazon.png
--------------------------------------------------------------------------------
/results/amazon_fine_grained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/amazon_fine_grained.png
--------------------------------------------------------------------------------
/results/dbpedia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/dbpedia.png
--------------------------------------------------------------------------------
/results/sogou.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/sogou.png
--------------------------------------------------------------------------------
/results/yahoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yahoo.png
--------------------------------------------------------------------------------
/results/yelp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yelp.png
--------------------------------------------------------------------------------
/results/yelp_fine_grained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yelp_fine_grained.png
--------------------------------------------------------------------------------
/statistics/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/statistics/.gitkeep
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import pandas as pd
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from torchnlp.encoders.label_encoder import LabelEncoder
9 | from torchnlp.encoders.text import WhitespaceEncoder
10 | from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_PADDING_TOKEN, DEFAULT_UNKNOWN_TOKEN
11 | from torchnlp.encoders.text.text_encoder import stack_and_pad_tensors
12 | from torchnlp.utils import datasets_iterator
13 |
14 | from datasets import imdb_dataset, agnews_dataset, amazon_dataset, dbpedia_dataset, newsgroups_dataset, reuters_dataset, \
15 | webkb_dataset, yahoo_dataset, yelp_dataset, cade_dataset, sogou_dataset
16 |
17 |
18 | class MarginLoss(nn.Module):
19 | def __init__(self, num_class, size_average=True):
20 | super(MarginLoss, self).__init__()
21 | self.num_class = num_class
22 | self.size_average = size_average
23 |
24 | def forward(self, classes, labels):
25 | labels = F.one_hot(labels, self.num_class).float()
26 | left = F.relu(0.9 - classes, inplace=True) ** 2
27 | right = F.relu(classes - 0.1, inplace=True) ** 2
28 | loss = labels * left + 0.5 * (1 - labels) * right
29 | loss = loss.sum(dim=-1)
30 | if self.size_average:
31 | return loss.mean()
32 | else:
33 | return loss.sum()
34 |
35 |
36 | class FocalLoss(nn.Module):
37 | def __init__(self, alpha=0.25, gamma=2, size_average=True):
38 | super(FocalLoss, self).__init__()
39 | self.alpha = alpha
40 | self.gamma = gamma
41 | self.size_average = size_average
42 |
43 | def forward(self, classes, labels):
44 | log_pt = F.log_softmax(classes, dim=-1)
45 | log_pt = log_pt.gather(-1, labels.view(-1, 1)).view(-1)
46 | pt = log_pt.exp()
47 | loss = -self.alpha * (1 - pt) ** self.gamma * log_pt
48 | if self.size_average:
49 | return loss.mean()
50 | else:
51 | return loss.sum()
52 |
53 |
54 | def load_data(data_type, preprocessing=False, fine_grained=False, verbose=False, text_length=5000, encode=True):
55 | if data_type == 'imdb':
56 | train_data, test_data = imdb_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
57 | elif data_type == 'newsgroups':
58 | train_data, test_data = newsgroups_dataset(preprocessing=preprocessing, verbose=verbose,
59 | text_length=text_length)
60 | elif data_type == 'reuters':
61 | train_data, test_data = reuters_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose,
62 | text_length=text_length)
63 | elif data_type == 'webkb':
64 | train_data, test_data = webkb_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
65 | elif data_type == 'cade':
66 | train_data, test_data = cade_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
67 | elif data_type == 'dbpedia':
68 | train_data, test_data = dbpedia_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
69 | elif data_type == 'agnews':
70 | train_data, test_data = agnews_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
71 | elif data_type == 'yahoo':
72 | train_data, test_data = yahoo_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
73 | elif data_type == 'sogou':
74 | train_data, test_data = sogou_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length)
75 | elif data_type == 'yelp':
76 | train_data, test_data = yelp_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose,
77 | text_length=text_length)
78 | elif data_type == 'amazon':
79 | train_data, test_data = amazon_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose,
80 | text_length=text_length)
81 | else:
82 | raise ValueError('{} data type not supported.'.format(data_type))
83 |
84 | if encode:
85 | sentence_corpus = [row['text'] for row in datasets_iterator(train_data)]
86 | sentence_encoder = WhitespaceEncoder(sentence_corpus,
87 | reserved_tokens=[DEFAULT_PADDING_TOKEN, DEFAULT_UNKNOWN_TOKEN])
88 | label_corpus = [row['label'] for row in datasets_iterator(train_data)]
89 | label_encoder = LabelEncoder(label_corpus, reserved_labels=[])
90 |
91 | # Encode
92 | for row in datasets_iterator(train_data, test_data):
93 | row['text'] = sentence_encoder.encode(row['text'])
94 | row['label'] = label_encoder.encode(row['label'])
95 | return sentence_encoder, label_encoder, train_data, test_data
96 | else:
97 | return train_data, test_data
98 |
99 |
100 | def collate_fn(batch):
101 | """ list of tensors to a batch tensors """
102 | text_batch, _ = stack_and_pad_tensors([row['text'] for row in batch])
103 | label_batch = [row['label'].unsqueeze(0) for row in batch]
104 | return [text_batch, torch.cat(label_batch)]
105 |
106 |
107 | if __name__ == '__main__':
108 | parser = argparse.ArgumentParser(description='Generate Preprocessed Data')
109 | parser.add_argument('--data_type', default='imdb', type=str,
110 | choices=['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo',
111 | 'sogou', 'yelp', 'amazon'], help='dataset type')
112 | parser.add_argument('--fine_grained', action='store_true', help='use fine grained class or not, it only works for '
113 | 'reuters, yelp and amazon')
114 | opt = parser.parse_args()
115 | DATA_TYPE, FINE_GRAINED = opt.data_type, opt.fine_grained
116 | train_dataset, test_dataset = load_data(DATA_TYPE, preprocessing=None, fine_grained=FINE_GRAINED, encode=False)
117 |
118 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']:
119 | train_file = os.path.join('data', DATA_TYPE, 'preprocessed_fine_grained_train.csv')
120 | test_file = os.path.join('data', DATA_TYPE, 'preprocessed_fine_grained_test.csv')
121 | else:
122 | train_file = os.path.join('data', DATA_TYPE, 'preprocessed_train.csv')
123 | test_file = os.path.join('data', DATA_TYPE, 'preprocessed_test.csv')
124 |
125 | # save files
126 | print('Saving preprocessed {} dataset into {}... '.format(DATA_TYPE, os.path.join('data', DATA_TYPE)), end='')
127 | train_label, train_text, test_label, test_text = [], [], [], []
128 | for data in train_dataset:
129 | train_label.append(data['label'])
130 | train_text.append(data['text'])
131 | for data in test_dataset:
132 | test_label.append(data['label'])
133 | test_text.append(data['text'])
134 | train_data_frame = pd.DataFrame({'label': train_label, 'text': train_text})
135 | test_data_frame = pd.DataFrame({'label': test_label, 'text': test_text})
136 | train_data_frame.to_csv(train_file, header=False, index=False)
137 | test_data_frame.to_csv(test_file, header=False, index=False)
138 | print('Done.')
139 |
--------------------------------------------------------------------------------
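A minimal usage sketch for the two loss functions and `collate_fn` defined above; the dummy scores, labels, and batch are hypothetical:
```
import torch

from utils import FocalLoss, MarginLoss, collate_fn

# Hypothetical batch of 3 samples over 4 classes; capsule lengths in [0, 1] act as class scores.
classes = torch.rand(3, 4)
labels = torch.tensor([0, 2, 3])
print(MarginLoss(num_class=4)(classes, labels))  # mean margin loss over the batch
print(FocalLoss()(classes, labels))              # mean focal loss over the batch

# collate_fn pads variable-length encoded texts into a single batch tensor.
fake_batch = [{'text': torch.tensor([4, 8, 15]), 'label': torch.tensor(0)},
              {'text': torch.tensor([16, 23]), 'label': torch.tensor(1)}]
texts, label_batch = collate_fn(fake_batch)
print(texts.size(), label_batch)  # torch.Size([2, 3]) tensor([0, 1])
```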
/vis.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 |
8 | from model import Model
9 | from utils import load_data
10 |
11 | if __name__ == '__main__':
12 | parser = argparse.ArgumentParser(description='Vis Embedding and Code')
13 | parser.add_argument('--model_weight', default=None, type=str, help='saved model weight to load')
14 | parser.add_argument('--routing_type', default='k_means', type=str, choices=['k_means', 'dynamic'],
15 | help='routing type, it only works for capsule classifier')
16 | parser.add_argument('--embedding_size', default=64, type=int, help='embedding size')
17 | parser.add_argument('--num_codebook', default=8, type=int,
18 | help='codebook number, it only works for cwc and cc embedding')
19 | parser.add_argument('--num_codeword', default=None, type=int,
20 | help='codeword number, it only works for cwc and cc embedding')
21 | parser.add_argument('--hidden_size', default=128, type=int, help='hidden size')
22 | parser.add_argument('--in_length', default=8, type=int,
23 | help='in capsule length, it only works for capsule classifier')
24 | parser.add_argument('--out_length', default=16, type=int,
25 | help='out capsule length, it only works for capsule classifier')
26 | parser.add_argument('--num_iterations', default=3, type=int,
27 | help='routing iterations number, it only works for capsule classifier')
28 | parser.add_argument('--num_repeat', default=10, type=int,
29 | help='gumbel softmax repeat number, it only works for cc embedding')
30 | parser.add_argument('--drop_out', default=0.5, type=float, help='dropout rate of GRU layer')
31 |
32 | opt = parser.parse_args()
33 | MODEL_WEIGHT, ROUTING_TYPE, EMBEDDING_SIZE = opt.model_weight, opt.routing_type, opt.embedding_size
34 | NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE = opt.num_codebook, opt.num_codeword, opt.hidden_size
35 | IN_LENGTH, OUT_LENGTH, NUM_ITERATIONS, DROP_OUT = opt.in_length, opt.out_length, opt.num_iterations, opt.drop_out
36 | NUM_REPEAT = opt.num_repeat
37 | configs = MODEL_WEIGHT.split('_')
38 | if len(configs) == 4:
39 | DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, TEXT_LENGTH = configs
40 | FINE_GRAINED, TEXT_LENGTH = False, int(TEXT_LENGTH.split('.')[0])
41 | else:
42 | DATA_TYPE, _, EMBEDDING_TYPE, CLASSIFIER_TYPE, TEXT_LENGTH = configs
43 | FINE_GRAINED, TEXT_LENGTH = True, int(TEXT_LENGTH.split('.')[0])
44 |
45 | data_name = '{}_fine-grained'.format(DATA_TYPE) if FINE_GRAINED else DATA_TYPE
46 |
47 | print('Loading {} dataset'.format(data_name))
48 | # get sentence encoder
49 | sentence_encoder, label_encoder, _, _ = load_data(DATA_TYPE, preprocessing=True, fine_grained=FINE_GRAINED,
50 | verbose=True, text_length=TEXT_LENGTH)
51 | VOCAB_SIZE, NUM_CLASS = sentence_encoder.vocab_size, label_encoder.vocab_size
52 |
53 | model = Model(VOCAB_SIZE, EMBEDDING_SIZE, NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH, OUT_LENGTH,
54 | NUM_CLASS, ROUTING_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, NUM_ITERATIONS, NUM_REPEAT, DROP_OUT)
55 | model.load_state_dict(torch.load('epochs/{}'.format(MODEL_WEIGHT), map_location='cpu'))
56 | if torch.cuda.is_available():
57 | model, cudnn.benchmark = model.to('cuda'), True
58 |
59 | model.eval()
60 | print('Generating embedding and code for {} dataset'.format(data_name))
61 | with torch.no_grad():
62 | if EMBEDDING_TYPE == 'normal':
63 | vocabs = model.embedding.weight.detach().cpu().numpy()
64 | codes = torch.ones(1, 1, sentence_encoder.vocab_size)
65 | else:
66 | embedding = model.embedding
67 | embedding.return_code = True
68 | data = torch.arange(sentence_encoder.vocab_size).view(1, -1)
69 | if torch.cuda.is_available():
70 | data = data.to('cuda')
71 | out, code = embedding(data)
72 | # vocabs: [num_embeddings, embedding_dim], codes: [num_embeddings, num_codebook, num_codeword]
73 | vocabs, codes = out.squeeze(dim=0).detach().cpu().numpy(), code.squeeze(dim=0).detach().cpu()
74 |
75 | print('Plotting code usage for {} dataset'.format(data_name))
76 | reduced_codes = codes.sum(dim=0).float()
77 | c_max, c_min = reduced_codes.max().item(), reduced_codes.min().item()
78 | f, ax = plt.subplots(figsize=(10, 5))
79 | heat_map = sns.heatmap(reduced_codes.numpy(), vmin=c_min, vmax=c_max, annot=True, fmt='.2f', ax=ax)
80 | ax.set_title('Code usage of {} embedding for {} dataset'.format(EMBEDDING_TYPE, data_name))
81 | ax.set_xlabel('codeword')
82 | ax.set_ylabel('codebook')
83 | f.savefig('results/{}_{}_code.jpg'.format(data_name, EMBEDDING_TYPE))
84 |
--------------------------------------------------------------------------------
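For reference, a minimal invocation sketch for `vis.py`. The weight file names below are hypothetical; the script only requires that the file under `epochs/` splits on underscores into data type, embedding type, classifier type and text length, with an optional fine-grained marker as the second field, and that the architecture flags match the configuration the weights were trained with. The code-usage heatmap is written to `results/{data_name}_{embedding_type}_code.jpg`.
```
python vis.py --model_weight imdb_cc_capsule_5000.pth
python vis.py --model_weight yelp_fine-grained_cc_capsule_5000.pth --num_codebook 8
```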