├── README.md ├── data └── .gitkeep ├── data_utils.py ├── datasets.py ├── epochs └── .gitkeep ├── main.py ├── model.py ├── results ├── .gitkeep ├── agnews.png ├── amazon.png ├── amazon_fine_grained.png ├── dbpedia.png ├── sogou.png ├── yahoo.png ├── yelp.png └── yelp_fine_grained.png ├── statistics └── .gitkeep ├── utils.py └── vis.py /README.md: -------------------------------------------------------------------------------- 1 | # CCCapsNet 2 | A PyTorch implementation of Compositional Coding Capsule Network based on PRL 2022 paper [Compositional Coding Capsule Network with K-Means Routing for Text Classification](https://www.sciencedirect.com/science/article/pii/S016786552200188X). 3 | 4 | ## Requirements 5 | * [Anaconda](https://www.anaconda.com/download/) 6 | * PyTorch 7 | ``` 8 | conda install pytorch torchvision -c pytorch 9 | ``` 10 | * PyTorchNet 11 | ``` 12 | pip install git+https://github.com/pytorch/tnt.git@master 13 | ``` 14 | * PyTorch-NLP 15 | ``` 16 | pip install pytorch-nlp 17 | ``` 18 | * capsule-layer 19 | ``` 20 | pip install git+https://github.com/leftthomas/CapsuleLayer.git@master 21 | ``` 22 | 23 | ## Datasets 24 | The original `AGNews`, `AmazonReview`, `DBPedia`, `YahooAnswers`, `SogouNews` and `YelpReview` datasets are coming from [here](http://goo.gl/JyCnZq). 25 | 26 | The original `Newsgroups`, `Reuters`, `Cade` and `WebKB` datasets can be found [here](http://ana.cachopo.org/datasets-for-single-label-text-categorization). 27 | 28 | The original `IMDB` dataset is downloaded by `PyTorch-NLP` automatically. 29 | 30 | We have uploaded all the original datasets into [BaiduYun](https://pan.baidu.com/s/16wBuNJiD0acgTHDeld9eDA)(access code:kddr) and 31 | [GoogleDrive](https://drive.google.com/open?id=10n_eZ2ZyRjhRWFjxky7_PhcGHecDjKJ2). 32 | The preprocessed datasets have been uploaded to [BaiduYun](https://pan.baidu.com/s/1hsIJAw54YZbVAqFiehEH6w)(access code:2kyd) and 33 | [GoogleDrive](https://drive.google.com/open?id=1KDE5NJKfgOwc6RNEf9_F0ZhLQZ3Udjx5). 34 | 35 | You needn't download the datasets by yourself, the code will download them automatically. 36 | If you encounter network issues, you can download all the datasets from the aforementioned cloud storage webs, 37 | and extract them into `data` directory. 38 | 39 | ## Usage 40 | 41 | ### Generate Preprocessed Data 42 | ``` 43 | python utils.py --data_type yelp --fine_grained 44 | optional arguments: 45 | --data_type dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 46 | 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon']) 47 | --fine_grained use fine grained class or not, it only works for reuters, yelp and amazon [default value is False] 48 | ``` 49 | This step is not required, and it takes a long time to execute. So I have generated the preprocessed data before, and 50 | uploaded them to the aforementioned cloud storage webs. You could skip this step, and just do the next step, the code will 51 | download the data automatically. 
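For reference, the preprocessing applied by `utils.py` is implemented in `text_preprocess` (`data_utils.py`): it lower-cases text, spaces out punctuation, splits number-letter tokens such as `2008year`, and spells out multi-digit numbers. A minimal sketch with a made-up input sentence:
```python
from data_utils import text_preprocess

# Made-up example sentence; the expected output follows the rules in data_utils.py.
print(text_preprocess('Bought 2 Items in 2008year for $789!', 'agnews'))
# -> 'bought 2 items in 2 0 0 8 year for $ 7 8 9 !'
```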
### Train Text Classification
```
visdom -logging_level WARNING & python main.py --data_type newsgroups --num_epochs 70
optional arguments:
--data_type               dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
--fine_grained            use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
--text_length             the number of words about the text to load [default value is 5000]
--routing_type            routing type, it only works for capsule classifier [default value is 'k_means'](choices:['k_means', 'dynamic'])
--loss_type               loss type [default value is 'mf'](choices:['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'])
--embedding_type          embedding type [default value is 'cwc'](choices:['cwc', 'cc', 'normal'])
--classifier_type         classifier type [default value is 'capsule'](choices:['capsule', 'linear'])
--embedding_size          embedding size [default value is 64]
--num_codebook            codebook number, it only works for cwc and cc embedding [default value is 8]
--num_codeword            codeword number, it only works for cwc and cc embedding [default value is None]
--hidden_size             hidden size [default value is 128]
--in_length               in capsule length, it only works for capsule classifier [default value is 8]
--out_length              out capsule length, it only works for capsule classifier [default value is 16]
--num_iterations          routing iterations number, it only works for capsule classifier [default value is 3]
--num_repeat              gumbel softmax repeat number, it only works for cc embedding [default value is 10]
--drop_out                drop_out rate of GRU layer [default value is 0.5]
--batch_size              train batch size [default value is 32]
--num_epochs              train epochs number [default value is 10]
--num_steps               test steps number [default value is 100]
--pre_model               pre-trained model weight, it only works for routing_type experiment [default value is None]
```
Visdom can now be accessed at `127.0.0.1:8097/env/$data_type` in your browser, where `$data_type` is the dataset type you are training on.

## Benchmarks
The Adam optimizer is used with learning rate scheduling. The models are trained for 10 epochs with a batch size of 32 on one NVIDIA Tesla V100 (32G) GPU.

Texts are preprocessed to keep only numbers and English words, with a maximum length of 5,000 words.
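The optimizer setup behind these numbers lives in `main.py`: the embedding and GRU parameters train at 10x the classifier's learning rate, and both are decayed by a factor of 10 at 50% and 70% of the epochs. A minimal sketch; the tiny `ModuleDict` stand-in is hypothetical and only keeps the snippet runnable (in the repository the three sub-modules come from `model.Model`):
```python
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

# Hypothetical stand-in with the same three sub-modules referenced in main.py.
model = nn.ModuleDict({'embedding': nn.Embedding(100, 64),
                       'features': nn.GRU(64, 128, batch_first=True),
                       'classifier': nn.Linear(128, 4, bias=False)})
num_epochs = 10
optimizer = Adam([{'params': model['embedding'].parameters(), 'lr': 1e-4 * 10},
                  {'params': model['features'].parameters(), 'lr': 1e-4 * 10},
                  {'params': model['classifier'].parameters(), 'lr': 1e-4}], lr=1e-4)
lr_scheduler = MultiStepLR(optimizer, milestones=[int(num_epochs * 0.5), int(num_epochs * 0.7)], gamma=0.1)
```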
Here are the dataset details:

| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Num. of Train Texts | 120,000 | 560,000 | 1,400,000 | 450,000 | 560,000 | 650,000 | 3,600,000 | 3,000,000 |
| Num. of Test Texts | 7,600 | 70,000 | 60,000 | 60,000 | 38,000 | 50,000 | 400,000 | 650,000 |
| Num. of Vocabulary | 62,535 | 548,338 | 771,820 | 106,385 | 200,790 | 216,985 | 931,271 | 835,818 |
| Num. of Classes | 4 | 14 | 10 | 5 | 2 | 5 | 2 | 5 |

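These statistics can be reproduced with the loaders in `datasets.py`, which print sample counts and text-length statistics when `verbose=True`. For example (this downloads the preprocessed `AGNews` archive into `data/`):
```python
from datasets import agnews_dataset

# Downloads the preprocessed AGNews data and prints split sizes and
# min/avg/max text lengths before returning the two splits.
train, test = agnews_dataset(preprocessing=True, verbose=True)
print(len(train), len(test))  # 120000 7600
```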
Here are the model parameter counts; the model names follow the pattern `embedding_type-classifier_type`:
| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Normal-Linear | 4,448,192 | 35,540,864 | 49,843,200 | 7,254,720 | 13,296,256 | 14,333,120 | 60,047,040 | 53,938,432 |
| CC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
| CWC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
| Normal-Capsule | 4,455,872 | 35,567,744 | 49,862,400 | 7,264,320 | 13,300,096 | 14,342,720 | 60,050,880 | 53,948,032 |
| CC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |
| CWC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |

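As a back-of-the-envelope check of the table above, the `Normal-Linear` count for `agnews` decomposes into the embedding matrix, the two-layer bidirectional GRU, and the bias-free linear classifier, assuming PyTorch's standard GRU parameterization and the default hyper-parameters from the usage section:
```python
vocab_size, num_class = 62535, 4          # agnews vocabulary and classes
embedding_size, hidden_size = 64, 128     # default --embedding_size / --hidden_size

def gru_direction(input_size, hidden_size):
    # weight_ih + weight_hh + bias_ih + bias_hh for one direction of one layer
    return 3 * hidden_size * (input_size + hidden_size) + 2 * 3 * hidden_size

embedding = vocab_size * embedding_size                    # 4,002,240
gru = 2 * (gru_direction(embedding_size, hidden_size)      # layer 1, both directions
           + gru_direction(2 * hidden_size, hidden_size))  # layer 2 reads the concatenated outputs
classifier = hidden_size * num_class                       # bias-free linear layer
print(embedding + gru + classifier)                        # 4,448,192
```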
Here is a comparison of loss functions, using the `AGNews` dataset and the `Normal-Linear` model:
| Loss Function | margin | focal | cross | margin+focal | margin+cross | focal+cross | margin+focal+cross |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Accuracy | 92.37% | 92.13% | 92.05% | 92.64% | 91.95% | 92.09% | 92.38% |

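The combined entries are simply sums of the individual criteria: `main.py` builds a list of loss modules from `--loss_type` and adds their values. A minimal sketch using the implementations in `utils.py` (the random tensors are stand-ins for real model outputs and labels, and the repository root is assumed to be on the Python path):
```python
import torch

from utils import FocalLoss, MarginLoss

num_class = 4
classes = torch.rand(8, num_class)              # stand-in model outputs for a batch of 8
labels = torch.randint(0, num_class, (8,))      # stand-in ground-truth labels

criteria = [MarginLoss(num_class), FocalLoss()]  # corresponds to --loss_type mf ('margin+focal')
loss = sum(criterion(classes, labels) for criterion in criteria)
print(loss.item())
```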
Here are the accuracy results. `margin+focal` is used as the loss function, the `capsule` classifier uses 3 routing iterations, and for the `CC` embedding the trailing number in the model name is `num_repeat`:
| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Normal-Linear | 92.64% | 98.84% | 74.13% | 97.37% | 96.69% | 66.23% | 95.09% | 60.78% |
| CC-Linear-10 | 73.11% | 92.66% | 48.01% | 93.50% | 87.81% | 50.33% | 83.20% | 45.77% |
| CC-Linear-30 | 81.05% | 95.29% | 53.50% | 94.65% | 91.33% | 55.22% | 87.37% | 50.00% |
| CC-Linear-50 | 83.13% | 96.06% | 57.87% | 95.20% | 92.37% | 56.66% | 89.04% | 51.30% |
| CWC-Linear | 91.93% | 98.83% | 73.58% | 97.37% | 96.35% | 65.11% | 94.90% | 60.29% |
| Normal-Capsule | 92.18% | 98.86% | 74.12% | 97.52% | 96.56% | 66.23% | 95.18% | 61.36% |
| CC-Capsule-10 | 73.53% | 93.04% | 50.52% | 94.44% | 87.98% | 54.14% | 83.64% | 47.44% |
| CC-Capsule-30 | 81.71% | 95.72% | 60.48% | 95.96% | 91.90% | 58.27% | 87.88% | 51.63% |
| CC-Capsule-50 | 84.05% | 96.27% | 60.31% | 96.00% | 92.82% | 59.48% | 89.07% | 52.06% |
| CWC-Capsule | 92.12% | 98.81% | 73.78% | 97.42% | 96.28% | 65.38% | 94.98% | 60.94% |

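Each row corresponds to a `Model` configuration from `model.py`. For example, the `CWC-Capsule` entry for `agnews` with the default hyper-parameters can be built directly; the printed value should reproduce the corresponding parameter count above, assuming the same `capsule-layer` version:
```python
from model import Model

# CWC embedding + capsule classifier with k-means routing for the 4-class agnews task.
model = Model(vocab_size=62535, embedding_size=64, num_codebook=8, num_codeword=None,
              hidden_size=128, in_length=8, out_length=16, num_class=4,
              routing_type='k_means', embedding_type='cwc', classifier_type='capsule',
              num_iterations=3, num_repeat=10, dropout=0.5)
print(sum(p.numel() for p in model.parameters()))  # compare with the CWC-Capsule / agnews entry
```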
Here are the parameter counts for the `CWC-Capsule` model; each row is named by its `num_codewords` configuration, one digit per dataset:
| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 57766677 | 2,957,592 | 31,184,624 | 43,691,424 | 5,565,232 | 10,090,528 | 10,874,032 | 52,604,296 | 47,265,072 |
| 68877788 | 3,458,384 | 35,571,840 | 49,866,496 | 6,416,824 | 11,697,360 | 12,610,424 | 60,054,976 | 53,952,128 |

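These configurations override `num_codeword` per dataset (one digit per dataset column, e.g. 5 codewords for `agnews` in `57766677`). When `--num_codeword` is left at `None`, `CompositionalEmbedding` in `model.py` derives a smaller default from the vocabulary size; a minimal sketch of that rule:
```python
import math

def default_num_codeword(vocab_size, num_codebook=8):
    # Mirrors CompositionalEmbedding in model.py when num_codeword is None.
    return math.ceil(math.pow(vocab_size, 1 / num_codebook))

print(default_num_codeword(62535))   # 4  (agnews)
print(default_num_codeword(548338))  # 6  (dbpedia)
```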
Here are the corresponding accuracy results:
| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 57766677 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
| 68877788 | 92.05% | 98.82% | 73.93% | 97.52% | 96.44% | 65.63% | 95.05% | 61.02% |

Here are the accuracy results with the `57766677` configuration; each row is named by `num_iterations`:
| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 92.28% | 98.82% | 73.93% | 97.25% | 96.58% | 65.60% | 95.00% | 61.08% |
| 3 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
| 5 | 92.21% | 98.88% | 73.85% | 97.38% | 96.38% | 65.36% | 95.05% | 61.23% |

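The iteration count is passed straight through to the k-means routing of `CapsuleLinear` (see `model.py`); for instance, the classifier used for a 4-class dataset with 5 iterations is built as:
```python
from capsule_layer import CapsuleLinear

# Same call as in model.py, with num_iterations raised from the default 3 to 5.
classifier = CapsuleLinear(out_capsules=4, in_length=8, out_length=16, in_capsules=None,
                           share_weight=True, routing_type='k_means', num_iterations=5, bias=False)
```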
535 | 536 | ## Results 537 | The train/test loss、accuracy and confusion matrix are showed with visdom. The pretrained models and more results can be 538 | found in [BaiduYun](https://pan.baidu.com/s/1mpIXTfuECiSFVxJcLR1j3A) (access code:xer4) and 539 | [GoogleDrive](https://drive.google.com/drive/folders/1hu8sA517kA5bowzE6TYK_xSLGBUCoanK?usp=sharing). 540 | 541 | **agnews** 542 | 543 | ![result](results/agnews.png) 544 | 545 | **dbpedia** 546 | 547 | ![result](results/dbpedia.png) 548 | 549 | **yahoo** 550 | 551 | ![result](results/yahoo.png) 552 | 553 | **sogou** 554 | 555 | ![result](results/sogou.png) 556 | 557 | **yelp** 558 | 559 | ![result](results/yelp.png) 560 | 561 | **yelp fine grained** 562 | 563 | ![result](results/yelp_fine_grained.png) 564 | 565 | **amazon** 566 | 567 | ![result](results/amazon.png) 568 | 569 | **amazon fine grained** 570 | 571 | ![result](results/amazon_fine_grained.png) 572 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/data/.gitkeep -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import warnings 5 | import zipfile 6 | from os import makedirs 7 | from os.path import exists 8 | from sys import stdout 9 | 10 | import requests 11 | 12 | re_letter_number = re.compile(r'[^a-zA-Z0-9 ]') 13 | re_number_letter = re.compile(r'^(\d+)([a-z]\w*)$') 14 | 15 | 16 | def text_preprocess(text, data_type): 17 | if data_type == 'sogou' or data_type == 'yahoo' or data_type == 'yelp': 18 | # Remove \\n character 19 | text = text.replace('\\n', ' ') 20 | if data_type == 'imdb' or data_type == 'yahoo': 21 | # Remove
<br /> character 22 | text = text.replace('<br />
', ' ') 23 | if data_type not in ['newsgroups', 'reuters', 'webkb', 'cade']: 24 | # Turn punctuation, foreign word, etc. into SPACES word SPACES). 25 | text = re_letter_number.sub(lambda m: ' ' + m.group(0) + ' ', text) 26 | # Turn all letters to lowercase. 27 | text = text.lower() 28 | # Turn the number-letter word to single number and word (such as turn 2008year into 2008 year). 29 | text = ' '.join(' '.join(w for w in re_number_letter.match(word).groups()) 30 | if re_number_letter.match(word) else word for word in text.split()) 31 | # Turn all numbers to single number (such as turn 789 into 7 8 9). 32 | text = ' '.join(' '.join(w for w in word) if word.isdigit() else word for word in text.split()) 33 | # Substitute multiple SPACES by a single SPACE. 34 | text = ' '.join(text.split()) 35 | return text 36 | 37 | 38 | class GoogleDriveDownloader: 39 | """ 40 | Minimal class to download shared files from Google Drive. 41 | """ 42 | 43 | CHUNK_SIZE = 32768 44 | DOWNLOAD_URL = "https://docs.google.com/uc?export=download" 45 | 46 | @staticmethod 47 | def download_file_from_google_drive(file_id, file_name, dest_path, overwrite=False): 48 | """ 49 | Downloads a shared file from google drive into a given folder. 50 | Optionally unzips it. 51 | 52 | Args: 53 | file_id (str): the file identifier. You can obtain it from the sherable link. 54 | file_name (str): the file name. 55 | dest_path (str): the destination where to save the downloaded file. 56 | overwrite (bool): optional, if True forces re-download and overwrite. 57 | 58 | Returns: 59 | None 60 | """ 61 | 62 | if not exists(dest_path): 63 | makedirs(dest_path) 64 | 65 | if not exists(os.path.join(dest_path, file_name)) or overwrite: 66 | 67 | session = requests.Session() 68 | 69 | print('Downloading {} into {}... '.format(file_name, dest_path), end='') 70 | stdout.flush() 71 | 72 | response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params={'id': file_id}, stream=True) 73 | 74 | token = GoogleDriveDownloader._get_confirm_token(response) 75 | if token: 76 | params = {'id': file_id, 'confirm': token} 77 | response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params=params, stream=True) 78 | 79 | GoogleDriveDownloader._save_response_content(response, os.path.join(dest_path, file_name)) 80 | print('Done.') 81 | 82 | try: 83 | print('Unzipping... 
', end='') 84 | stdout.flush() 85 | with zipfile.ZipFile(os.path.join(dest_path, file_name), 'r') as zip_file: 86 | for member in zip_file.namelist(): 87 | filename = os.path.basename(member) 88 | # skip directories 89 | if not filename: 90 | continue 91 | # copy file (taken from zipfile's extract) 92 | source = zip_file.open(member) 93 | target = open(os.path.join(dest_path, filename), 'wb') 94 | with source, target: 95 | shutil.copyfileobj(source, target) 96 | print('Done.') 97 | except zipfile.BadZipfile: 98 | warnings.warn('Ignoring `unzip` since "{}" does not look like a valid zip file'.format(file_name)) 99 | 100 | @staticmethod 101 | def _get_confirm_token(response): 102 | for key, value in response.cookies.items(): 103 | if key.startswith('download_warning'): 104 | return value 105 | return None 106 | 107 | @staticmethod 108 | def _save_response_content(response, destination): 109 | with open(destination, 'wb') as f: 110 | for chunk in response.iter_content(GoogleDriveDownloader.CHUNK_SIZE): 111 | if chunk: # filter out keep-alive new chunks 112 | f.write(chunk) 113 | 114 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from torchnlp.datasets.dataset import Dataset 7 | 8 | from data_utils import GoogleDriveDownloader as gdd 9 | from data_utils import text_preprocess 10 | 11 | 12 | def imdb_dataset(directory='data/', data_type='imdb', preprocessing=False, fine_grained=False, 13 | verbose=False, text_length=5000, share_id='1X7YI7nDpKEPio2J-eH7uWCiWoiw2jbSP'): 14 | """ 15 | Load the IMDB dataset (Large Movie Review Dataset v1.0). 16 | 17 | This is a dataset for binary sentiment classification containing substantially more data than 18 | previous benchmark datasets. Provided a set of 25,000 highly polar movie reviews for training, 19 | and 25,000 for testing. 20 | After preprocessing, the total number of training samples is 25,000 and testing samples 25,000. 21 | The min length of text about train data is 11, max length is 2,803, average length is 281; the 22 | min length of text about test data is 8, max length is 2,709, average length is 275. 23 | 24 | **Reference:** http://ai.stanford.edu/~amaas/data/sentiment/ 25 | 26 | Args: 27 | directory (str, optional): Directory to cache the dataset. 28 | data_type (str, optional): Which dataset to use. 29 | preprocessing (bool, optional): Whether to preprocess the original dataset. If preprocessing 30 | equals None, it will not download the preprocessed dataset, it will generate preprocessed 31 | dataset from the original dataset. 32 | fine_grained (bool, optional): Whether to use fine_grained dataset instead of polarity dataset. 33 | verbose (bool, optional): Whether to print the dataset details. 34 | text_length (int, optional): Only load the first text_length words, it only works when 35 | preprocessing is True. 36 | share_id (str, optional): Google Drive share ID about the original dataset to download. 37 | 38 | Returns: 39 | :class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the training dataset and test 40 | dataset. 
41 | 42 | Example: 43 | >>> train, test = imdb_dataset(preprocessing=True) 44 | >>> train[0:2] 45 | [{ 46 | 'label': 'pos', 47 | 'text': 'for a movie that gets no respect there sure are a lot of memorable quotes...'}, 48 | { 49 | 'label': 'pos', 50 | 'text': 'bizarre horror movie filled with famous faces but stolen by cristina raines...'}] 51 | >>> test[0:2] 52 | [{ 53 | 'label': 'pos', 54 | 'text': 'based on an actual story , john boorman shows the struggle of an american...'}, 55 | { 56 | 'label': 'pos', 57 | 'text': 'this is a gem as a film four production the anticipated quality was indeed...'}] 58 | """ 59 | 60 | # other dataset have been set before, only imdb should be set here 61 | if preprocessing and data_type == 'imdb': 62 | share_id = '1naVVErkRQNNJXTA6X_X6YrJY0jPOeuPh' 63 | 64 | if preprocessing: 65 | gdd.download_file_from_google_drive(share_id, data_type + '_preprocessed.zip', directory + data_type) 66 | if fine_grained: 67 | train_file, test_file = 'preprocessed_fine_grained_train.csv', 'preprocessed_fine_grained_test.csv' 68 | else: 69 | train_file, test_file = 'preprocessed_train.csv', 'preprocessed_test.csv' 70 | else: 71 | gdd.download_file_from_google_drive(share_id, data_type + '_original.zip', directory + data_type) 72 | if fine_grained: 73 | train_file, test_file = 'original_fine_grained_train.csv', 'original_fine_grained_test.csv' 74 | else: 75 | train_file, test_file = 'original_train.csv', 'original_test.csv' 76 | 77 | if verbose: 78 | min_train_length, avg_train_length, max_train_length = sys.maxsize, 0, 0 79 | min_test_length, avg_test_length, max_test_length = sys.maxsize, 0, 0 80 | 81 | ret = [] 82 | for file_name in [train_file, test_file]: 83 | csv_file = np.array(pd.read_csv(os.path.join(directory, data_type, file_name), header=None)).tolist() 84 | examples = [] 85 | for label, text in csv_file: 86 | label, text = str(label), str(text) 87 | if preprocessing: 88 | if len(text.split()) > text_length: 89 | text = ' '.join(text.split()[:text_length]) 90 | elif preprocessing is None: 91 | text = text_preprocess(text, data_type) 92 | if len(text.split()) == 0: 93 | continue 94 | if verbose: 95 | if file_name == train_file: 96 | avg_train_length += len(text.split()) 97 | if len(text.split()) > max_train_length: 98 | max_train_length = len(text.split()) 99 | if len(text.split()) < min_train_length: 100 | min_train_length = len(text.split()) 101 | if file_name == test_file: 102 | avg_test_length += len(text.split()) 103 | if len(text.split()) > max_test_length: 104 | max_test_length = len(text.split()) 105 | if len(text.split()) < min_test_length: 106 | min_test_length = len(text.split()) 107 | examples.append({'label': label, 'text': text}) 108 | ret.append(Dataset(examples)) 109 | 110 | if verbose: 111 | print('[!] train samples: {} length--(min: {}, avg: {}, max: {})'. 112 | format(len(ret[0]), min_train_length, round(avg_train_length / len(ret[0])), max_train_length)) 113 | print('[!] test samples: {} length--(min: {}, avg: {}, max: {})'. 114 | format(len(ret[1]), min_test_length, round(avg_test_length / len(ret[1])), max_test_length)) 115 | return tuple(ret) 116 | 117 | 118 | def newsgroups_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 119 | """ 120 | Load the 20 Newsgroups dataset (Version 'bydate'). 121 | 122 | The 20 Newsgroups data set is a collection of approximately 20,000 newsgroup documents, 123 | partitioned (nearly) evenly across 20 different newsgroups. 
The total number of training 124 | samples is 11,293 and testing 7,527. 125 | After preprocessing, the total number of training samples is 11,293 and testing samples 7,527. 126 | The min length of text about train data is 1, max length is 6,779, average length is 143; the 127 | min length of text about test data is 1, max length is 6,142, average length is 139. 128 | 129 | **Reference:** http://qwone.com/~jason/20Newsgroups/ 130 | 131 | Example: 132 | >>> train, test = newsgroups_dataset(preprocessing=True) 133 | >>> train[0:2] 134 | [{ 135 | 'label': 'alt.atheism', 136 | 'text': 'alt atheism faq atheist resourc archiv name atheism resourc alt atheism...'}, 137 | { 138 | 'label': 'alt.atheism', 139 | 'text': 'alt atheism faq introduct atheism archiv name atheism introduct alt...'}] 140 | >>> test[0:2] 141 | [{ 142 | 'label': 'alt.atheism', 143 | 'text': 'bibl quiz answer articl healta saturn wwc edu healta saturn wwc edu...'}, 144 | { 145 | 'label': 'alt.atheism', 146 | 'text': 'amus atheist and agnost articl timmbak mcl timmbak mcl ucsb edu clam bake...'}] 147 | """ 148 | 149 | share_id = '1y8M5yf0DD21ox3K76xJyoCkGIU1Zc4iq' if preprocessing else '18_p4_RnCd0OO2qNxteApbIQ9abrfMyjC' 150 | return imdb_dataset(directory, 'newsgroups', preprocessing, verbose=verbose, text_length=text_length, 151 | share_id=share_id) 152 | 153 | 154 | def reuters_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000): 155 | """ 156 | Load the Reuters-21578 R8 or Reuters-21578 R52 dataset (Version 'modApté'). 157 | 158 | The Reuters-21578 dataset considers only the documents with a single topic and the classes 159 | which still have at least one train and one test example, we have 8 of the 10 most frequent 160 | classes and 52 of the original 90. In total there are 5,485 trainig samples and 2,189 testing 161 | samples in R8 dataset. The total number of training samples is 6,532 and testing 2,568 in R52 162 | dataset. 163 | After preprocessing, the total number of training samples is 5,485 and testing samples 2,189. 164 | The min length of text about train data is 4, max length is 533, average length is 66; the min 165 | length of text about test data is 5, max length is 484, average length is 60. (R8) 166 | After preprocessing, the total number of training samples is 6,532 and testing samples 2,568. 167 | The min length of text about train data is 4, max length is 595, average length is 70; the min 168 | length of text about test data is 5, max length is 484, average length is 64. 
(R52) 169 | 170 | **Reference:** http://www.daviddlewis.com/resources/testcollections/reuters21578/ 171 | 172 | Example: 173 | >>> train, test = reuters_dataset(preprocessing=True) 174 | >>> train[0:2] 175 | [{ 176 | 'label': 'earn', 177 | 'text': 'champion product approv stock split champion product inc board director...'} 178 | { 179 | 'label': 'acq', 180 | 'text': 'comput termin system cpml complet sale comput termin system inc complet...'}] 181 | >>> test[0:2] 182 | [{ 183 | 'label': 'trade', 184 | 'text': 'asian export fear damag japan rift mount trade friction and japan rais...'}, 185 | { 186 | 'label': 'grain', 187 | 'text': 'china daili vermin eat pct grain stock survei provinc and citi show...'}] 188 | """ 189 | 190 | share_id = '1CY3W31rdagEJ8Kr5gHPeRgS1GVks-YVv' if preprocessing else '1coe-1WB4H7PBY2IG_CVeCl4-UNZRwy1I' 191 | return imdb_dataset(directory, 'reuters', preprocessing, fine_grained, verbose, text_length, share_id) 192 | 193 | 194 | def webkb_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 195 | """ 196 | Load the World Wide Knowledge Base (Web->Kb) dataset (Version 1). 197 | 198 | The World Wide Knowledge Base (Web->Kb) dataset is collected by the World Wide Knowledge Base 199 | (Web->Kb) project of the CMU text learning group. These pages were collected from computer 200 | science departments of various universities in 1997, manually classified into seven different 201 | classes: student, faculty, staff, department, course, project, and other. The classes Department 202 | and Staff is discarded, because there were only a few pages from each university. The class Other 203 | is discarded, because pages were very different among this class. The total number of training 204 | samples is 2,785 and testing 1,383. 205 | After preprocessing, the total number of training samples is 2,785 and testing samples 1,383. 206 | The min length of text about train data is 1, max length is 20,628, average length is 134; the min 207 | length of text about test data is 1, max length is 2,082, average length is 136. 208 | 209 | **Reference:** http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/ 210 | 211 | Example: 212 | >>> train, test = webkb_dataset(preprocessing=True) 213 | >>> train[0:2] 214 | [{ 215 | 'label': 'student', 216 | 'text': 'brian comput scienc depart univers wisconsin dayton street madison offic...'} 217 | { 218 | 'label': 'student', 219 | 'text': 'denni swanson web page mail pop uki offic hour comput lab offic anderson...'}] 220 | >>> test[0:2] 221 | [{ 222 | 'label': 'student', 223 | 'text': 'eric homepag eric wei tsinghua physic fudan genet'}, 224 | { 225 | 'label': 'course', 226 | 'text': 'comput system perform evalu model new sept assign due oct postscript text...'}] 227 | """ 228 | 229 | share_id = '1oqcl2N0kDoBlHo_hFgKc_MaSvs0ny1t7' if preprocessing else '12uR98xYZ44fXX0WUf9RjOM4GpBt3JAV8' 230 | return imdb_dataset(directory, 'webkb', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id) 231 | 232 | 233 | def cade_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 234 | """ 235 | Load the Cade12 dataset (Version 1). 236 | 237 | The Cade12 dataset is corresponding to a subset of web pages extracted from the CADÊ Web Directory, 238 | which points to Brazilian web pages classified by human experts. The total number of training 239 | samples is 27,322 and testing 13,661. 240 | After preprocessing, the total number of training samples is 27,322 and testing samples 13,661. 
241 | The min length of text about train data is 2, max length is 22,352, average length is 119; the min 242 | length of text about test data is 2, max length is 15,318, average length is 112. 243 | 244 | **Reference:** http://www.cade.com.br/ 245 | 246 | Example: 247 | >>> train, test = cade_dataset(preprocessing=True) 248 | >>> train[0:2] 249 | [{ 250 | 'label': '08_cultura', 251 | 'text': 'br br email arvores arvores http www apoio mascote natureza vida links foram...'} 252 | { 253 | 'label': '02_sociedade', 254 | 'text': 'page frames browser support virtual araraquara shop'}] 255 | >>> test[0:2] 256 | [{ 257 | 'label': '02_sociedade', 258 | 'text': 'dezembro envie mail br manutencao funcionarios funcionarios funcionarios...'}, 259 | { 260 | 'label': '07_internet', 261 | 'text': 'auto sao pagina br br computacao rede internet internet internet internet...'}] 262 | """ 263 | 264 | share_id = '13CwKytxKlvMP6FW9iOCOMvmKlm5YWD-k' if preprocessing else '1cWlJHAt5dhomDxoQQmXu3APaFkdjRg_P' 265 | return imdb_dataset(directory, 'cade', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id) 266 | 267 | 268 | def dbpedia_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 269 | """ 270 | Load the DBPedia Ontology Classification dataset (Version 2). 271 | 272 | The DBpedia ontology classification dataset is constructed by picking 14 non-overlapping classes 273 | from DBpedia 2014. They are listed in classes.txt. From each of these 14 ontology classes, we 274 | randomly choose 40,000 training samples and 5,000 testing samples. Therefore, the total size 275 | of the training dataset is 560,000 and testing dataset 70,000. 276 | After preprocessing, the total number of training samples is 560,000 and testing samples 70,000. 277 | The min length of text about train data is 3, max length is 2,780, average length is 64; the min 278 | length of text about test data is 4, max length is 930, average length is 64. 279 | 280 | **Reference:** http://dbpedia.org 281 | 282 | Example: 283 | >>> train, test = dbpedia_dataset(preprocessing=True) 284 | >>> train[0:2] 285 | [{ 286 | 'label': 'Company', 287 | 'text': 'e . d . abbott ltd abbott of farnham e d abbott limited was a british...'}, 288 | { 289 | 'label': 'Company', 290 | 'text': 'schwan - stabilo schwan - stabilo is a german maker of pens for writing...'}] 291 | >>> test[0:2] 292 | [{ 293 | 'label': 'Company', 294 | 'text': 'ty ku ty ku / ta ɪ ku ː / is an american alcoholic beverage company that...'}, 295 | { 296 | 'label': 'Company', 297 | 'text': 'odd lot entertainment oddlot entertainment founded in 2 0 0 1 by longtime...'}] 298 | """ 299 | 300 | share_id = '1egq6UCaaqeZOq7siitXEIfIFwYUjFjnP' if preprocessing else '1YEZP-ajK3fUEMhdgkATRmikeuI7EjI9X' 301 | return imdb_dataset(directory, 'dbpedia', preprocessing, verbose=verbose, text_length=text_length, 302 | share_id=share_id) 303 | 304 | 305 | def agnews_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 306 | """ 307 | Load the AG's News Topic Classification dataset (Version 3). 308 | 309 | The AG's news topic classification dataset is constructed by choosing 4 largest classes from 310 | the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. 311 | The total number of training samples is 120,000 and testing 7,600. 312 | After preprocessing, the total number of training samples is 120,000 and testing samples 7,600. 
313 | The min length of text about train data is 13, max length is 354, average length is 49; the min 314 | length of text about test data is 15, max length is 250, average length is 48. 315 | 316 | **Reference:** http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html 317 | 318 | Example: 319 | >>> train, test = agnews_dataset(preprocessing=True) 320 | >>> train[0:2] 321 | [{ 322 | 'label': 'Business', 323 | 'text': 'wall st . bears claw back into the black ( reuters ) reuters - short - sellers...'}, 324 | { 325 | 'label': 'Business', 326 | 'text': 'carlyle looks toward commercial aerospace ( reuters ) reuters - private investment...'}] 327 | >>> test[0:2] 328 | [{ 329 | 'label': 'Business', 330 | 'text': 'fears for t n pension after talks unions representing workers at turner newall...'}, 331 | { 332 | 'label': 'Sci/Tech', 333 | 'text': 'the race is on : second private team sets launch date for human spaceflight...'}] 334 | """ 335 | 336 | share_id = '153R49C-JY8NDmRwc7bikZvU3EEEjKRs2' if preprocessing else '1EKKHXTXwGJaitaPn8BYD2bpSGQCB76te' 337 | return imdb_dataset(directory, 'agnews', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id) 338 | 339 | 340 | def yahoo_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 341 | """ 342 | Load the Yahoo! Answers Topic Classification dataset (Version 2). 343 | 344 | The Yahoo! Answers topic classification dataset is constructed using 10 largest main categories. 345 | Each class contains 140,000 training samples and 6,000 testing samples. Therefore, the total number 346 | of training samples is 1,400,000 and testing samples 60,000 in this dataset. 347 | After preprocessing, the total number of training samples is 1,400,000 and testing samples 60,000. 348 | The min length of text about train data is 2, max length is 4,044, average length is 118; the min 349 | length of text about test data is 3, max length is 4,017, average length is 119. 350 | 351 | **Reference:** https://webscope.sandbox.yahoo.com/catalog.php?datatype=l 352 | 353 | Example: 354 | >>> train, test = yahoo_dataset(preprocessing=True) 355 | >>> train[0:2] 356 | [{ 357 | 'label': 'Computers & Internet', 358 | 'text': 'why doesn ' t an optical mouse work on a glass table ? or even on some surfaces...'}, 359 | { 360 | 'label': 'Sports', 361 | 'text': 'what is the best off - road motorcycle trail ? long - distance trail throughout...'}] 362 | >>> test[0:2] 363 | [{ 364 | 'label': 'Family & Relationships', 365 | 'text': 'what makes friendship click ? how does the spark keep going ? good communication...'}, 366 | { 367 | 'label': 'Science & Mathematics', 368 | 'text': 'why does zebras have stripes ? what is the purpose or those stripes ? who do they...'}] 369 | """ 370 | 371 | share_id = '1LS7iQM3qMofMCVlm08LfniyqXsdhFdnn' if preprocessing else '15xpGyKaQk2-WDrjzsz57TVKrgQbhM7Ct' 372 | return imdb_dataset(directory, 'yahoo', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id) 373 | 374 | 375 | def sogou_dataset(directory='data/', preprocessing=False, verbose=False, text_length=5000): 376 | """ 377 | Load the Sogou News Topic Classification dataset (Version 3). 378 | 379 | The Sogou news topic classification dataset is constructed by manually labeling each news article 380 | according to its URL, which represents roughly the categorization of news in their websites. We 381 | chose 5 largest categories for the dataset, each having 90,000 samples for training and 12,000 for 382 | testing. 
The Pinyin texts are converted using pypinyin combined with jieba Chinese segmentation 383 | system. In total there are 450,000 training samples and 60,000 testing samples. 384 | After preprocessing, the total number of training samples is 450,000 and testing samples 60,000. 385 | The min length of text about train data is 2, max length is 42,695, average length is 612; the min 386 | length of text about test data is 3, max length is 64,651, average length is 616. 387 | 388 | **Reference:** http://www.sogou.com/labs/dl/ca.html and http://www.sogou.com/labs/dl/cs.html 389 | 390 | Example: 391 | >>> train, test = sogou_dataset(preprocessing=True) 392 | >>> train[0:2] 393 | [{ 394 | 'label': 'automobile', 395 | 'text': '2 0 0 8 di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n me3i nv3 mo2 te4 2 0 0 8...'} 396 | { 397 | 'label': 'automobile', 398 | 'text': 'zho1ng hua2 ju4n jie2 frv ya4o shi tu2 we2i zho1ng hua2 ju4n jie2 frv ya4o shi .'}] 399 | >>> test[0:2] 400 | [{ 401 | 'label': 'sports', 402 | 'text': 'ti3 ca1o shi4 jie4 be1i : che2ng fe1i na2 pi2ng he2ng mu4 zi4 yo2u ca1o ji1n...'}, 403 | { 404 | 'label': 'automobile', 405 | 'text': 'da3o ha2ng du2 jia1 ti2 go1ng me3i ri4 ba4o jia4 re4 xia4n : 0 1 0 - 6 4 4 3...'}] 406 | """ 407 | 408 | share_id = '1HbJHzIacbQt7m-IRZzv8nRaSubSrYdip' if preprocessing else '1pvg0e3HSE_IeYdphYyJ8k52JZyGnae_U' 409 | return imdb_dataset(directory, 'sogou', preprocessing, verbose=verbose, text_length=text_length, share_id=share_id) 410 | 411 | 412 | def yelp_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000): 413 | """ 414 | Load the Yelp Review Full Star or Yelp Review Polarity dataset (Version 1). 415 | 416 | The Yelp reviews polarity dataset is constructed by considering stars 1 and 2 negative, and 3 417 | and 4 positive. For each polarity 280,000 training samples and 19,000 testing samples are take 418 | randomly. In total there are 560,000 training samples and 38,000 testing samples. Negative 419 | polarity is class 1, and positive class 2. 420 | The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples 421 | and 10,000 testing samples for each review star from 1 to 5. In total there are 650,000 training 422 | samples and 50,000 testing samples. 423 | After preprocessing, the total number of training samples is 560,000 and testing samples 38,000. 424 | The min length of text about train data is 1, max length is 1,491, average length is 162; the min 425 | length of text about test data is 1, max length is 1,311, average length is 162. (polarity) 426 | After preprocessing, the total number of training samples is 650,000 and testing samples 50,000. 427 | The min length of text about train data is 1, max length is 1,332, average length is 164; the min 428 | length of text about test data is 1, max length is 1,491, average length is 164. (full) 429 | 430 | **Reference:** http://www.yelp.com/dataset_challenge 431 | 432 | Example: 433 | >>> train, test = yelp_dataset(preprocessing=True) 434 | >>> train[0:2] 435 | [{ 436 | 'label': '1', 437 | 'text': 'unfortunately , the frustration of being dr . goldberg ' s patient is a repeat...'} 438 | { 439 | 'label': '2', 440 | 'text': 'been going to dr . goldberg for over 1 0 years . 
i think i was one of his 1...'}] 441 | >>> test[0:2] 442 | [{ 443 | 'label': '2', 444 | 'text': 'contrary to other reviews , i have zero complaints about the service or the prices...'}, 445 | { 446 | 'label': '1', 447 | 'text': 'last summer i had an appointment to get new tires and had to wait a super long time...'}] 448 | """ 449 | 450 | share_id = '1ecOuyAhT-MjXQiueRHqS9LnY0CV0HQYd' if preprocessing else '1yEF0Lnd4f8mDZiqeh2QmEKvp_gPdYgtH' 451 | return imdb_dataset(directory, 'yelp', preprocessing, fine_grained, verbose, text_length, share_id) 452 | 453 | 454 | def amazon_dataset(directory='data/', preprocessing=False, fine_grained=False, verbose=False, text_length=5000): 455 | """ 456 | Load the Amazon Review Full Score or Amazon Review Polaridy dataset (Version 3). 457 | 458 | The Amazon reviews polarity dataset is constructed by taking review score 1 and 2 as negative, 459 | and 4 and 5 as positive. For each polarity 1,800,000 training samples and 200,000 testing samples 460 | are take randomly. In total there are 3,600,000 training samples and 400,000 testing samples. 461 | Negative polarity is class 1, and positive class 2. 462 | The Amazon reviews full score dataset is constructed by randomly taking 600,000 training samples 463 | and 130,000 testing samples for each review score from 1 to 5. In total there are 3,000,000 464 | training samples and 650,000 testing samples. 465 | After preprocessing, the total number of training samples is 3,600,000 and testing samples 400,000. 466 | The min length of text about train data is 2, max length is 986, average length is 95; the min 467 | length of text about test data is 14, max length is 914, average length is 95. (polarity) 468 | After preprocessing, the total number of training samples is 3,000,000 and testing samples 650,000. 469 | The min length of text about train data is 2, max length is 781, average length is 97; the min 470 | length of text about test data is 12, max length is 931, average length is 97. (full) 471 | 472 | **Reference:** http://jmcauley.ucsd.edu/data/amazon/ 473 | 474 | Example: 475 | >>> train, test = amazon_dataset(preprocessing=True) 476 | >>> train[0:2] 477 | [{ 478 | 'label': '2', 479 | 'text': 'stuning even for the non - gamer this sound track was beautiful ! it paints...'} 480 | { 481 | 'label': '2', 482 | 'text': 'the best soundtrack ever to anything . i ' m reading a lot of reviews saying...'}] 483 | >>> test[0:2] 484 | [{ 485 | 'label': '2', 486 | 'text': 'great cd my lovely pat has one of the great voices of her generation . 
i have...'}, 487 | { 488 | 'label': '2', 489 | 'text': 'one of the best game music soundtracks - for a game i didn ' t really play...'}] 490 | """ 491 | 492 | share_id = '1BSqCU6DwIVD1jllbsz9ueudu3tSfomzY' if preprocessing else '11-l1T-_kBdtqrqqSfqNnJCLC6_NRM9Je' 493 | return imdb_dataset(directory, 'amazon', preprocessing, fine_grained, verbose, text_length, share_id) 494 | -------------------------------------------------------------------------------- /epochs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/epochs/.gitkeep -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torchnet as tnt 7 | from torch.nn import CrossEntropyLoss 8 | from torch.optim import Adam 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | from torch.utils.data import DataLoader 11 | from torchnet.logger import VisdomPlotLogger, VisdomLogger 12 | from torchnlp.samplers import BucketBatchSampler 13 | 14 | from model import Model 15 | from utils import load_data, MarginLoss, collate_fn, FocalLoss 16 | 17 | 18 | def reset_meters(): 19 | meter_accuracy.reset() 20 | meter_loss.reset() 21 | meter_confusion.reset() 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | parser = argparse.ArgumentParser(description='Train Text Classification') 27 | parser.add_argument('--data_type', default='imdb', type=str, 28 | choices=['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo', 29 | 'sogou', 'yelp', 'amazon'], help='dataset type') 30 | parser.add_argument('--fine_grained', action='store_true', help='use fine grained class or not, it only works for ' 31 | 'reuters, yelp and amazon') 32 | parser.add_argument('--text_length', default=5000, type=int, help='the number of words about the text to load') 33 | parser.add_argument('--routing_type', default='k_means', type=str, choices=['k_means', 'dynamic'], 34 | help='routing type, it only works for capsule classifier') 35 | parser.add_argument('--loss_type', default='mf', type=str, 36 | choices=['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'], help='loss type') 37 | parser.add_argument('--embedding_type', default='cwc', type=str, choices=['cwc', 'cc', 'normal'], 38 | help='embedding type') 39 | parser.add_argument('--classifier_type', default='capsule', type=str, choices=['capsule', 'linear'], 40 | help='classifier type') 41 | parser.add_argument('--embedding_size', default=64, type=int, help='embedding size') 42 | parser.add_argument('--num_codebook', default=8, type=int, 43 | help='codebook number, it only works for cwc and cc embedding') 44 | parser.add_argument('--num_codeword', default=None, type=int, 45 | help='codeword number, it only works for cwc and cc embedding') 46 | parser.add_argument('--hidden_size', default=128, type=int, help='hidden size') 47 | parser.add_argument('--in_length', default=8, type=int, 48 | help='in capsule length, it only works for capsule classifier') 49 | parser.add_argument('--out_length', default=16, type=int, 50 | help='out capsule length, it only works for capsule classifier') 51 | parser.add_argument('--num_iterations', default=3, type=int, 52 | help='routing iterations number, it only works for capsule classifier') 53 | 
parser.add_argument('--num_repeat', default=10, type=int, 54 | help='gumbel softmax repeat number, it only works for cc embedding') 55 | parser.add_argument('--drop_out', default=0.5, type=float, help='drop_out rate of GRU layer') 56 | parser.add_argument('--batch_size', default=32, type=int, help='train batch size') 57 | parser.add_argument('--num_epochs', default=10, type=int, help='train epochs number') 58 | parser.add_argument('--num_steps', default=100, type=int, help='test steps number') 59 | parser.add_argument('--pre_model', default=None, type=str, 60 | help='pre-trained model weight, it only works for routing_type experiment') 61 | 62 | opt = parser.parse_args() 63 | DATA_TYPE, FINE_GRAINED, TEXT_LENGTH = opt.data_type, opt.fine_grained, opt.text_length 64 | ROUTING_TYPE, LOSS_TYPE, EMBEDDING_TYPE = opt.routing_type, opt.loss_type, opt.embedding_type 65 | CLASSIFIER_TYPE, EMBEDDING_SIZE, NUM_CODEBOOK = opt.classifier_type, opt.embedding_size, opt.num_codebook 66 | NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH = opt.num_codeword, opt.hidden_size, opt.in_length 67 | OUT_LENGTH, NUM_ITERATIONS, DROP_OUT, BATCH_SIZE = opt.out_length, opt.num_iterations, opt.drop_out, opt.batch_size 68 | NUM_REPEAT, NUM_EPOCHS, NUM_STEPS, PRE_MODEL = opt.num_repeat, opt.num_epochs, opt.num_steps, opt.pre_model 69 | 70 | # prepare dataset 71 | sentence_encoder, label_encoder, train_dataset, test_dataset = load_data(DATA_TYPE, preprocessing=True, 72 | fine_grained=FINE_GRAINED, verbose=True, 73 | text_length=TEXT_LENGTH) 74 | VOCAB_SIZE, NUM_CLASS = sentence_encoder.vocab_size, label_encoder.vocab_size 75 | print("[!] vocab_size: {}, num_class: {}".format(VOCAB_SIZE, NUM_CLASS)) 76 | train_sampler = BucketBatchSampler(train_dataset, BATCH_SIZE, False, sort_key=lambda row: len(row['text'])) 77 | train_iterator = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn) 78 | test_sampler = BucketBatchSampler(test_dataset, BATCH_SIZE * 2, False, sort_key=lambda row: len(row['text'])) 79 | test_iterator = DataLoader(test_dataset, batch_sampler=test_sampler, collate_fn=collate_fn) 80 | 81 | model = Model(VOCAB_SIZE, EMBEDDING_SIZE, NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH, OUT_LENGTH, 82 | NUM_CLASS, ROUTING_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, NUM_ITERATIONS, NUM_REPEAT, DROP_OUT) 83 | if PRE_MODEL is not None: 84 | model_weight = torch.load('epochs/{}'.format(PRE_MODEL), map_location='cpu') 85 | model_weight.pop('classifier.weight') 86 | model.load_state_dict(model_weight, strict=False) 87 | 88 | if LOSS_TYPE == 'margin': 89 | loss_criterion = [MarginLoss(NUM_CLASS)] 90 | elif LOSS_TYPE == 'focal': 91 | loss_criterion = [FocalLoss()] 92 | elif LOSS_TYPE == 'cross': 93 | loss_criterion = [CrossEntropyLoss()] 94 | elif LOSS_TYPE == 'mf': 95 | loss_criterion = [MarginLoss(NUM_CLASS), FocalLoss()] 96 | elif LOSS_TYPE == 'mc': 97 | loss_criterion = [MarginLoss(NUM_CLASS), CrossEntropyLoss()] 98 | elif LOSS_TYPE == 'fc': 99 | loss_criterion = [FocalLoss(), CrossEntropyLoss()] 100 | else: 101 | loss_criterion = [MarginLoss(NUM_CLASS), FocalLoss(), CrossEntropyLoss()] 102 | if torch.cuda.is_available(): 103 | model, cudnn.benchmark = model.to('cuda'), True 104 | 105 | if PRE_MODEL is None: 106 | optim_configs = [{'params': model.embedding.parameters(), 'lr': 1e-4 * 10}, 107 | {'params': model.features.parameters(), 'lr': 1e-4 * 10}, 108 | {'params': model.classifier.parameters(), 'lr': 1e-4}] 109 | else: 110 | for param in model.embedding.parameters(): 111 | param.requires_grad = False 112 | 
for param in model.features.parameters(): 113 | param.requires_grad = False 114 | optim_configs = [{'params': model.classifier.parameters(), 'lr': 1e-4}] 115 | optimizer = Adam(optim_configs, lr=1e-4) 116 | lr_scheduler = MultiStepLR(optimizer, milestones=[int(NUM_EPOCHS * 0.5), int(NUM_EPOCHS * 0.7)], gamma=0.1) 117 | 118 | print("# trainable parameters:", sum(param.numel() if param.requires_grad else 0 for param in model.parameters())) 119 | # record statistics 120 | results = {'train_loss': [], 'train_accuracy': [], 'test_loss': [], 'test_accuracy': []} 121 | # record current best test accuracy 122 | best_acc = 0 123 | meter_loss = tnt.meter.AverageValueMeter() 124 | meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True) 125 | meter_confusion = tnt.meter.ConfusionMeter(NUM_CLASS, normalized=True) 126 | 127 | # config the visdom figures 128 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']: 129 | env_name = DATA_TYPE + '_fine_grained' 130 | else: 131 | env_name = DATA_TYPE 132 | loss_logger = VisdomPlotLogger('line', env=env_name, opts={'title': 'Loss'}) 133 | accuracy_logger = VisdomPlotLogger('line', env=env_name, opts={'title': 'Accuracy'}) 134 | train_confusion_logger = VisdomLogger('heatmap', env=env_name, opts={'title': 'Train Confusion Matrix'}) 135 | test_confusion_logger = VisdomLogger('heatmap', env=env_name, opts={'title': 'Test Confusion Matrix'}) 136 | 137 | current_step = 0 138 | for epoch in range(1, NUM_EPOCHS + 1): 139 | for data, target in train_iterator: 140 | current_step += 1 141 | label = target 142 | if torch.cuda.is_available(): 143 | data, label = data.to('cuda'), label.to('cuda') 144 | # train model 145 | model.train() 146 | optimizer.zero_grad() 147 | classes = model(data) 148 | loss = sum([criterion(classes, label) for criterion in loss_criterion]) 149 | loss.backward() 150 | optimizer.step() 151 | # save the metrics 152 | meter_loss.add(loss.detach().cpu().item()) 153 | meter_accuracy.add(classes.detach().cpu(), target) 154 | meter_confusion.add(classes.detach().cpu(), target) 155 | 156 | if current_step % NUM_STEPS == 0: 157 | # print the information about train 158 | loss_logger.log(current_step // NUM_STEPS, meter_loss.value()[0], name='train') 159 | accuracy_logger.log(current_step // NUM_STEPS, meter_accuracy.value()[0], name='train') 160 | train_confusion_logger.log(meter_confusion.value()) 161 | results['train_loss'].append(meter_loss.value()[0]) 162 | results['train_accuracy'].append(meter_accuracy.value()[0]) 163 | print('[Step %d] Training Loss: %.4f Accuracy: %.2f%%' % ( 164 | current_step // NUM_STEPS, meter_loss.value()[0], meter_accuracy.value()[0])) 165 | reset_meters() 166 | 167 | # test model periodically 168 | model.eval() 169 | with torch.no_grad(): 170 | for data, target in test_iterator: 171 | label = target 172 | if torch.cuda.is_available(): 173 | data, label = data.to('cuda'), label.to('cuda') 174 | classes = model(data) 175 | loss = sum([criterion(classes, label) for criterion in loss_criterion]) 176 | # save the metrics 177 | meter_loss.add(loss.detach().cpu().item()) 178 | meter_accuracy.add(classes.detach().cpu(), target) 179 | meter_confusion.add(classes.detach().cpu(), target) 180 | # print the information about test 181 | loss_logger.log(current_step // NUM_STEPS, meter_loss.value()[0], name='test') 182 | accuracy_logger.log(current_step // NUM_STEPS, meter_accuracy.value()[0], name='test') 183 | test_confusion_logger.log(meter_confusion.value()) 184 | results['test_loss'].append(meter_loss.value()[0]) 185 | 
results['test_accuracy'].append(meter_accuracy.value()[0]) 186 | 187 | # save best model 188 | if meter_accuracy.value()[0] > best_acc: 189 | best_acc = meter_accuracy.value()[0] 190 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']: 191 | torch.save(model.state_dict(), 'epochs/{}_{}_{}_{}.pth' 192 | .format(DATA_TYPE + '_fine-grained', EMBEDDING_TYPE, CLASSIFIER_TYPE, 193 | str(TEXT_LENGTH))) 194 | else: 195 | torch.save(model.state_dict(), 'epochs/{}_{}_{}_{}.pth' 196 | .format(DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, str(TEXT_LENGTH))) 197 | print('[Step %d] Testing Loss: %.4f Accuracy: %.2f%% Best Accuracy: %.2f%%' % ( 198 | current_step // NUM_STEPS, meter_loss.value()[0], meter_accuracy.value()[0], best_acc)) 199 | reset_meters() 200 | 201 | # save statistics 202 | data_frame = pd.DataFrame(data=results, index=range(1, current_step // NUM_STEPS + 1)) 203 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']: 204 | data_frame.to_csv('statistics/{}_{}_{}_results.csv'.format( 205 | DATA_TYPE + '_fine-grained', EMBEDDING_TYPE, CLASSIFIER_TYPE), index_label='step') 206 | else: 207 | data_frame.to_csv('statistics/{}_{}_{}_results.csv'.format( 208 | DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE), index_label='step') 209 | lr_scheduler.step(epoch) 210 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from capsule_layer import CapsuleLinear 6 | from torch import nn 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | class CompositionalEmbedding(nn.Module): 11 | r"""A simple compositional codeword and codebook that store embeddings. 
12 | 13 | Args: 14 | num_embeddings (int): size of the dictionary of embeddings 15 | embedding_dim (int): size of each embedding vector 16 | num_codebook (int): size of the codebook of embeddings 17 | num_codeword (int, optional): size of the codeword of embeddings 18 | weighted (bool, optional): weighted version of unweighted version 19 | return_code (bool, optional): return code or not 20 | 21 | Shape: 22 | - Input: (LongTensor): (N, W), W = number of indices to extract per mini-batch 23 | - Output: (Tensor): (N, W, embedding_dim) 24 | 25 | Attributes: 26 | - code (Tensor): the learnable weights of the module of shape 27 | (num_embeddings, num_codebook, num_codeword) 28 | - codebook (Tensor): the learnable weights of the module of shape 29 | (num_codebook, num_codeword, embedding_dim) 30 | 31 | Examples:: 32 | >>> m = CompositionalEmbedding(200, 64, 16, 32, weighted=False) 33 | >>> a = torch.randperm(128).view(16, -1) 34 | >>> output = m(a) 35 | >>> print(output.size()) 36 | torch.Size([16, 8, 64]) 37 | """ 38 | 39 | def __init__(self, num_embeddings, embedding_dim, num_codebook, num_codeword=None, num_repeat=10, weighted=True, 40 | return_code=False): 41 | super(CompositionalEmbedding, self).__init__() 42 | self.num_embeddings = num_embeddings 43 | self.embedding_dim = embedding_dim 44 | self.num_codebook = num_codebook 45 | self.num_repeat = num_repeat 46 | self.weighted = weighted 47 | self.return_code = return_code 48 | 49 | if num_codeword is None: 50 | num_codeword = math.ceil(math.pow(num_embeddings, 1 / num_codebook)) 51 | self.num_codeword = num_codeword 52 | self.code = Parameter(torch.Tensor(num_embeddings, num_codebook, num_codeword)) 53 | self.codebook = Parameter(torch.Tensor(num_codebook, num_codeword, embedding_dim)) 54 | 55 | nn.init.normal_(self.code) 56 | nn.init.normal_(self.codebook) 57 | 58 | def forward(self, input): 59 | batch_size = input.size(0) 60 | index = input.view(-1) 61 | code = self.code.index_select(dim=0, index=index) 62 | if self.weighted: 63 | # reweight, do softmax, make sure the sum of weight about each book to 1 64 | code = F.softmax(code, dim=-1) 65 | out = (code[:, :, None, :] @ self.codebook[None, :, :, :]).squeeze(dim=-2).sum(dim=1) 66 | else: 67 | # because Gumbel SoftMax works in a stochastic manner, needs to run several times to 68 | # get more accurate embedding 69 | code = (torch.sum(torch.stack([F.gumbel_softmax(code) for _ in range(self.num_repeat)]), dim=0)).argmax( 70 | dim=-1) 71 | out = [] 72 | for index in range(self.num_codebook): 73 | out.append(self.codebook[index, :, :].index_select(dim=0, index=code[:, index])) 74 | out = torch.sum(torch.stack(out), dim=0) 75 | code = F.one_hot(code, num_classes=self.num_codeword) 76 | 77 | out = out.view(batch_size, -1, self.embedding_dim) 78 | code = code.view(batch_size, -1, self.num_codebook, self.num_codeword) 79 | if self.return_code: 80 | return out, code 81 | else: 82 | return out 83 | 84 | def __repr__(self): 85 | return self.__class__.__name__ + ' (' + str(self.num_embeddings) + ', ' + str(self.embedding_dim) + ')' 86 | 87 | 88 | class Model(nn.Module): 89 | def __init__(self, vocab_size, embedding_size, num_codebook, num_codeword, hidden_size, in_length, out_length, 90 | num_class, routing_type, embedding_type, classifier_type, num_iterations, num_repeat, dropout): 91 | super().__init__() 92 | 93 | self.in_length, self.out_length = in_length, out_length 94 | self.hidden_size, self.classifier_type = hidden_size, classifier_type 95 | self.embedding_type = embedding_type 96 | 97 | if 
embedding_type == 'cwc': 98 | self.embedding = CompositionalEmbedding(vocab_size, embedding_size, num_codebook, num_codeword, 99 | weighted=True) 100 | elif embedding_type == 'cc': 101 | self.embedding = CompositionalEmbedding(vocab_size, embedding_size, num_codebook, num_codeword, num_repeat, 102 | weighted=False) 103 | else: 104 | self.embedding = nn.Embedding(vocab_size, embedding_size) 105 | self.features = nn.GRU(embedding_size, self.hidden_size, num_layers=2, dropout=dropout, batch_first=True, 106 | bidirectional=True) 107 | if classifier_type == 'capsule' and routing_type == 'k_means': 108 | self.classifier = CapsuleLinear(out_capsules=num_class, in_length=self.in_length, 109 | out_length=self.out_length, in_capsules=None, share_weight=True, 110 | routing_type='k_means', num_iterations=num_iterations, bias=False) 111 | elif classifier_type == 'capsule' and routing_type == 'dynamic': 112 | self.classifier = CapsuleLinear(out_capsules=num_class, in_length=self.in_length, 113 | out_length=self.out_length, in_capsules=None, share_weight=True, 114 | routing_type='dynamic', num_iterations=num_iterations, bias=False) 115 | else: 116 | self.classifier = nn.Linear(in_features=self.hidden_size, out_features=num_class, bias=False) 117 | 118 | def forward(self, x): 119 | embed = self.embedding(x) 120 | out, _ = self.features(embed) 121 | 122 | out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:] 123 | out = out.mean(dim=1).contiguous() 124 | if self.classifier_type == 'capsule': 125 | out = out.view(out.size(0), -1, self.in_length) 126 | out = self.classifier(out) 127 | classes = out.norm(dim=-1) 128 | else: 129 | out = out.view(out.size(0), -1) 130 | classes = self.classifier(out) 131 | return classes 132 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/.gitkeep -------------------------------------------------------------------------------- /results/agnews.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/agnews.png -------------------------------------------------------------------------------- /results/amazon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/amazon.png -------------------------------------------------------------------------------- /results/amazon_fine_grained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/amazon_fine_grained.png -------------------------------------------------------------------------------- /results/dbpedia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/dbpedia.png -------------------------------------------------------------------------------- /results/sogou.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/sogou.png -------------------------------------------------------------------------------- /results/yahoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yahoo.png -------------------------------------------------------------------------------- /results/yelp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yelp.png -------------------------------------------------------------------------------- /results/yelp_fine_grained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/results/yelp_fine_grained.png -------------------------------------------------------------------------------- /statistics/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CCCapsNet/76b16f71a344d3ada9fa335f5506c5b74769a4e9/statistics/.gitkeep -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torchnlp.encoders.label_encoder import LabelEncoder 9 | from torchnlp.encoders.text import WhitespaceEncoder 10 | from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_PADDING_TOKEN, DEFAULT_UNKNOWN_TOKEN 11 | from torchnlp.encoders.text.text_encoder import stack_and_pad_tensors 12 | from torchnlp.utils import datasets_iterator 13 | 14 | from datasets import imdb_dataset, agnews_dataset, amazon_dataset, dbpedia_dataset, newsgroups_dataset, reuters_dataset, \ 15 | webkb_dataset, yahoo_dataset, yelp_dataset, cade_dataset, sogou_dataset 16 | 17 | 18 | class MarginLoss(nn.Module): 19 | def __init__(self, num_class, size_average=True): 20 | super(MarginLoss, self).__init__() 21 | self.num_class = num_class 22 | self.size_average = size_average 23 | 24 | def forward(self, classes, labels): 25 | labels = F.one_hot(labels, self.num_class).float() 26 | left = F.relu(0.9 - classes, inplace=True) ** 2 27 | right = F.relu(classes - 0.1, inplace=True) ** 2 28 | loss = labels * left + 0.5 * (1 - labels) * right 29 | loss = loss.sum(dim=-1) 30 | if self.size_average: 31 | return loss.mean() 32 | else: 33 | return loss.sum() 34 | 35 | 36 | class FocalLoss(nn.Module): 37 | def __init__(self, alpha=0.25, gamma=2, size_average=True): 38 | super(FocalLoss, self).__init__() 39 | self.alpha = alpha 40 | self.gamma = gamma 41 | self.size_average = size_average 42 | 43 | def forward(self, classes, labels): 44 | log_pt = F.log_softmax(classes, dim=-1) 45 | log_pt = log_pt.gather(-1, labels.view(-1, 1)).view(-1) 46 | pt = log_pt.exp() 47 | loss = -self.alpha * (1 - pt) ** self.gamma * log_pt 48 | if self.size_average: 49 | return loss.mean() 50 | else: 51 | return loss.sum() 52 | 53 | 54 | def load_data(data_type, preprocessing=False, fine_grained=False, verbose=False, text_length=5000, encode=True): 55 | if data_type == 'imdb': 56 | train_data, test_data = 
imdb_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 57 | elif data_type == 'newsgroups': 58 | train_data, test_data = newsgroups_dataset(preprocessing=preprocessing, verbose=verbose, 59 | text_length=text_length) 60 | elif data_type == 'reuters': 61 | train_data, test_data = reuters_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose, 62 | text_length=text_length) 63 | elif data_type == 'webkb': 64 | train_data, test_data = webkb_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 65 | elif data_type == 'cade': 66 | train_data, test_data = cade_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 67 | elif data_type == 'dbpedia': 68 | train_data, test_data = dbpedia_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 69 | elif data_type == 'agnews': 70 | train_data, test_data = agnews_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 71 | elif data_type == 'yahoo': 72 | train_data, test_data = yahoo_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 73 | elif data_type == 'sogou': 74 | train_data, test_data = sogou_dataset(preprocessing=preprocessing, verbose=verbose, text_length=text_length) 75 | elif data_type == 'yelp': 76 | train_data, test_data = yelp_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose, 77 | text_length=text_length) 78 | elif data_type == 'amazon': 79 | train_data, test_data = amazon_dataset(preprocessing=preprocessing, fine_grained=fine_grained, verbose=verbose, 80 | text_length=text_length) 81 | else: 82 | raise ValueError('{} data type not supported.'.format(data_type)) 83 | 84 | if encode: 85 | sentence_corpus = [row['text'] for row in datasets_iterator(train_data, )] 86 | sentence_encoder = WhitespaceEncoder(sentence_corpus, 87 | reserved_tokens=[DEFAULT_PADDING_TOKEN, DEFAULT_UNKNOWN_TOKEN]) 88 | label_corpus = [row['label'] for row in datasets_iterator(train_data, )] 89 | label_encoder = LabelEncoder(label_corpus, reserved_labels=[]) 90 | 91 | # Encode 92 | for row in datasets_iterator(train_data, test_data): 93 | row['text'] = sentence_encoder.encode(row['text']) 94 | row['label'] = label_encoder.encode(row['label']) 95 | return sentence_encoder, label_encoder, train_data, test_data 96 | else: 97 | return train_data, test_data 98 | 99 | 100 | def collate_fn(batch): 101 | """ list of tensors to a batch tensors """ 102 | text_batch, _ = stack_and_pad_tensors([row['text'] for row in batch]) 103 | label_batch = [row['label'].unsqueeze(0) for row in batch] 104 | return [text_batch, torch.cat(label_batch)] 105 | 106 | 107 | if __name__ == '__main__': 108 | parser = argparse.ArgumentParser(description='Generate Preprocessed Data') 109 | parser.add_argument('--data_type', default='imdb', type=str, 110 | choices=['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo', 111 | 'sogou', 'yelp', 'amazon'], help='dataset type') 112 | parser.add_argument('--fine_grained', action='store_true', help='use fine grained class or not, it only works for ' 113 | 'reuters, yelp and amazon') 114 | opt = parser.parse_args() 115 | DATA_TYPE, FINE_GRAINED = opt.data_type, opt.fine_grained 116 | train_dataset, test_dataset = load_data(DATA_TYPE, preprocessing=None, fine_grained=FINE_GRAINED, encode=False) 117 | 118 | if FINE_GRAINED and DATA_TYPE in ['reuters', 'yelp', 'amazon']: 119 | train_file = os.path.join('data', DATA_TYPE, 
'preprocessed_fine_grained_train.csv') 120 | test_file = os.path.join('data', DATA_TYPE, 'preprocessed_fine_grained_test.csv') 121 | else: 122 | train_file = os.path.join('data', DATA_TYPE, 'preprocessed_train.csv') 123 | test_file = os.path.join('data', DATA_TYPE, 'preprocessed_test.csv') 124 | 125 | # save files 126 | print('Saving preprocessed {} dataset into {}... '.format(DATA_TYPE, os.path.join('data', DATA_TYPE)), end='') 127 | train_label, train_text, test_label, test_text = [], [], [], [] 128 | for data in train_dataset: 129 | train_label.append(data['label']) 130 | train_text.append(data['text']) 131 | for data in test_dataset: 132 | test_label.append(data['label']) 133 | test_text.append(data['text']) 134 | train_data_frame = pd.DataFrame({'label': train_label, 'text': train_text}) 135 | test_data_frame = pd.DataFrame({'label': test_label, 'text': test_text}) 136 | train_data_frame.to_csv(train_file, header=False, index=False) 137 | test_data_frame.to_csv(test_file, header=False, index=False) 138 | print('Done.') 139 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | 8 | from model import Model 9 | from utils import load_data 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='Vis Embedding and Code') 13 | parser.add_argument('--model_weight', default=None, type=str, help='saved model weight to load') 14 | parser.add_argument('--routing_type', default='k_means', type=str, choices=['k_means', 'dynamic'], 15 | help='routing type, it only works for capsule classifier') 16 | parser.add_argument('--embedding_size', default=64, type=int, help='embedding size') 17 | parser.add_argument('--num_codebook', default=8, type=int, 18 | help='codebook number, it only works for cwc and cc embedding') 19 | parser.add_argument('--num_codeword', default=None, type=int, 20 | help='codeword number, it only works for cwc and cc embedding') 21 | parser.add_argument('--hidden_size', default=128, type=int, help='hidden size') 22 | parser.add_argument('--in_length', default=8, type=int, 23 | help='in capsule length, it only works for capsule classifier') 24 | parser.add_argument('--out_length', default=16, type=int, 25 | help='out capsule length, it only works for capsule classifier') 26 | parser.add_argument('--num_iterations', default=3, type=int, 27 | help='routing iterations number, it only works for capsule classifier') 28 | parser.add_argument('--num_repeat', default=10, type=int, 29 | help='gumbel softmax repeat number, it only works for cc embedding') 30 | parser.add_argument('--drop_out', default=0.5, type=float, help='drop_out rate of GRU layer') 31 | 32 | opt = parser.parse_args() 33 | MODEL_WEIGHT, ROUTING_TYPE, EMBEDDING_SIZE = opt.model_weight, opt.routing_type, opt.embedding_size 34 | NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE = opt.num_codebook, opt.num_codeword, opt.hidden_size 35 | IN_LENGTH, OUT_LENGTH, NUM_ITERATIONS, DROP_OUT = opt.in_length, opt.out_length, opt.num_iterations, opt.drop_out 36 | NUM_REPEAT = opt.num_repeat 37 | configs = MODEL_WEIGHT.split('_') 38 | if len(configs) == 4: 39 | DATA_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, TEXT_LENGTH = configs 40 | FINE_GRAINED, TEXT_LENGTH = False, int(TEXT_LENGTH.split('.')[0]) 41 | else: 42 | DATA_TYPE, _, EMBEDDING_TYPE, 
CLASSIFIER_TYPE, TEXT_LENGTH = configs 43 | FINE_GRAINED, TEXT_LENGTH = True, int(TEXT_LENGTH.split('.')[0]) 44 | 45 | data_name = '{}_fine-grained'.format(DATA_TYPE) if FINE_GRAINED else DATA_TYPE 46 | 47 | print('Loading {} dataset'.format(data_name)) 48 | # get sentence encoder 49 | sentence_encoder, label_encoder, _, _ = load_data(DATA_TYPE, preprocessing=True, fine_grained=FINE_GRAINED, 50 | verbose=True, text_length=TEXT_LENGTH) 51 | VOCAB_SIZE, NUM_CLASS = sentence_encoder.vocab_size, label_encoder.vocab_size 52 | 53 | model = Model(VOCAB_SIZE, EMBEDDING_SIZE, NUM_CODEBOOK, NUM_CODEWORD, HIDDEN_SIZE, IN_LENGTH, OUT_LENGTH, 54 | NUM_CLASS, ROUTING_TYPE, EMBEDDING_TYPE, CLASSIFIER_TYPE, NUM_ITERATIONS, NUM_REPEAT, DROP_OUT) 55 | model.load_state_dict(torch.load('epochs/{}'.format(MODEL_WEIGHT), map_location='cpu')) 56 | if torch.cuda.is_available(): 57 | model, cudnn.benchmark = model.to('cuda'), True 58 | 59 | model.eval() 60 | print('Generating embedding and code for {} dataset'.format(data_name)) 61 | with torch.no_grad(): 62 | if EMBEDDING_TYPE == 'normal': 63 | vocabs = model.embedding.weight.detach().cpu().numpy() 64 | codes = torch.ones(1, 1, sentence_encoder.vocab_size) 65 | else: 66 | embedding = model.embedding 67 | embedding.return_code = True 68 | data = torch.arange(sentence_encoder.vocab_size).view(1, -1) 69 | if torch.cuda.is_available(): 70 | data = data.to('cuda') 71 | out, code = embedding(data) 72 | # [num_embeddings, embedding_dim], ([num_embeddings, num_codebook, num_codeword], [1, 1, num_embeddings]) 73 | vocabs, codes = out.squeeze(dim=0).detach().cpu().numpy(), code.squeeze(dim=0).detach().cpu() 74 | 75 | print('Plotting code usage for {} dataset'.format(data_name)) 76 | reduced_codes = codes.sum(dim=0).float() 77 | c_max, c_min = reduced_codes.max().item(), reduced_codes.min().item() 78 | f, ax = plt.subplots(figsize=(10, 5)) 79 | heat_map = sns.heatmap(reduced_codes.numpy(), vmin=c_min, vmax=c_max, annot=True, fmt='.2f', ax=ax) 80 | ax.set_title('Code usage of {} embedding for {} dataset'.format(EMBEDDING_TYPE, data_name)) 81 | ax.set_xlabel('codeword') 82 | ax.set_ylabel('codebook') 83 | f.savefig('results/{}_{}_code.jpg'.format(data_name, EMBEDDING_TYPE)) 84 | --------------------------------------------------------------------------------
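
The two modules defined in `model.py` above can be exercised in isolation. Below is a minimal, hypothetical sanity-check sketch (not part of the repository); the vocabulary size, class count and sequence length are illustrative assumptions, and it assumes `model.py` and the `capsule-layer` package are importable.

```
import torch

from model import CompositionalEmbedding, Model

# compositional embedding: 200-word vocabulary, 64-dim embeddings, 8 codebooks x 16 codewords
embedding = CompositionalEmbedding(num_embeddings=200, embedding_dim=64, num_codebook=8,
                                   num_codeword=16, weighted=True)
tokens = torch.randint(0, 200, (4, 10))  # a batch of 4 sequences with 10 token indices each
print(embedding(tokens).size())          # torch.Size([4, 10, 64])

# full classifier: cwc embedding -> 2-layer bidirectional GRU -> capsule layer with k-means routing
model = Model(vocab_size=200, embedding_size=64, num_codebook=8, num_codeword=16, hidden_size=128,
              in_length=8, out_length=16, num_class=4, routing_type='k_means', embedding_type='cwc',
              classifier_type='capsule', num_iterations=3, num_repeat=10, dropout=0.5)
print(model(tokens).size())              # torch.Size([4, 4]), one score per class
```

When `classifier_type` is `'capsule'`, the returned scores are capsule lengths (vector norms); with the `'linear'` classifier they are ordinary linear outputs.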
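
`utils.py` above also defines the `MarginLoss` and `FocalLoss` modules that back the `margin` and `focal` loss options. A small, hypothetical usage sketch with made-up shapes:

```
import torch

from utils import FocalLoss, MarginLoss

num_class = 4
classes = torch.rand(32, num_class)          # per-class scores (e.g. capsule lengths) for a batch of 32
labels = torch.randint(0, num_class, (32,))  # ground-truth class indices

margin_loss = MarginLoss(num_class)          # squared hinge on the per-class scores
focal_loss = FocalLoss(alpha=0.25, gamma=2)  # softmax cross entropy that down-weights easy examples
print(margin_loss(classes, labels).item(), focal_loss(classes, labels).item())
```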
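
Finally, `load_data` and `collate_fn` are the glue between the datasets and training. The wiring below with a standard PyTorch `DataLoader` is a hypothetical sketch, not the repository's training script; it assumes the preprocessed `imdb` data can be downloaded or is already present under `data/`.

```
from torch.utils.data import DataLoader

from utils import collate_fn, load_data

# returns the fitted text/label encoders together with the encoded train/test splits
sentence_encoder, label_encoder, train_data, test_data = load_data('imdb', preprocessing=True,
                                                                    fine_grained=False, text_length=5000)
print(sentence_encoder.vocab_size, label_encoder.vocab_size)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
texts, labels = next(iter(train_loader))  # texts: [32, padded_length] LongTensor, labels: [32]
print(texts.size(), labels.size())
```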