├── .gitignore ├── LICENSE ├── README.md ├── datasets.py ├── main_classification.py ├── main_reconstruction.py ├── model.py ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [WIP]textcnn-conv-deconv-pytorch 2 | Text convolution-deconvolution auto-encoder and classification model in PyTorch. 3 | PyTorch implementation of [Deconvolutional Paragraph Representation Learning](https://arxiv.org/abs/1708.04729v3) described in NIPS 2017. 4 | **This repository is still developing.** 5 | 6 | ## Requirement 7 | - Python 3 8 | - PyTorch >= 0.3 9 | - numpy 10 | 11 | ## Usage 12 | ### Train 13 | #### Paragraph reconstruction 14 | Download data. [Hotel reviews](https://drive.google.com/file/d/0B52eYWrYWqIpQzhBNkVxaV9mMjQ/view) 15 | Then, run following command. 16 | 17 | ```shell 18 | $ python main_reconstruction.py -data_path=path/to/hotel_reviews.p 19 | ``` 20 | 21 | Specify download data path by `-data_path`. 22 | 23 | About other parameters. 
24 | 25 | ``` 26 | usage: main_reconstruction.py [-h] [-lr LR] [-epochs EPOCHS] 27 | [-batch_size BATCH_SIZE] 28 | [-lr_decay_interval LR_DECAY_INTERVAL] 29 | [-log_interval LOG_INTERVAL] 30 | [-test_interval TEST_INTERVAL] 31 | [-save_interval SAVE_INTERVAL] 32 | [-save_dir SAVE_DIR] [-data_path DATA_PATH] 33 | [-shuffle SHUFFLE] [-sentence_len SENTENCE_LEN] 34 | [-embed_dim EMBED_DIM] 35 | [-kernel_sizes KERNEL_SIZES] [-tau TAU] 36 | [-use_cuda] [-enc_snapshot ENC_SNAPSHOT] 37 | [-dec_snapshot DEC_SNAPSHOT] 38 | 39 | text convolution-deconvolution auto-encoder model 40 | 41 | optional arguments: 42 | -h, --help show this help message and exit 43 | -lr LR initial learning rate 44 | -epochs EPOCHS number of epochs for train 45 | -batch_size BATCH_SIZE 46 | batch size for training 47 | -lr_decay_interval LR_DECAY_INTERVAL 48 | how many epochs to wait before decrease learning rate 49 | -log_interval LOG_INTERVAL 50 | how many steps to wait before logging training status 51 | -test_interval TEST_INTERVAL 52 | how many epochs to wait before testing 53 | -save_interval SAVE_INTERVAL 54 | how many epochs to wait before saving 55 | -save_dir SAVE_DIR where to save the snapshot 56 | -data_path DATA_PATH data path 57 | -shuffle SHUFFLE shuffle data every epoch 58 | -sentence_len SENTENCE_LEN 59 | how many tokens in a sentence 60 | -embed_dim EMBED_DIM number of embedding dimension 61 | -kernel_sizes KERNEL_SIZES 62 | kernel size to use for convolution 63 | -tau TAU temperature parameter 64 | -use_cuda whether using cuda 65 | -enc_snapshot ENC_SNAPSHOT 66 | filename of encoder snapshot 67 | -dec_snapshot DEC_SNAPSHOT 68 | filename of decoder snapshot 69 | ``` 70 | 71 | #### Semi-supervised sequence classification 72 | Run the following command. 73 | 74 | ```shell 75 | $ python main_classification.py -data_path=path/to/trainingdata -label_path=path/to/labeldata 76 | 77 | ``` 78 | Specify training data and label data by `-data_path` and `-label_path` arguments.
79 | Both data must have same lines and training data must be separated by blank. 80 | 81 | About other parameters. 82 | 83 | ``` 84 | usage: main_classification.py [-h] [-lr LR] [-epochs EPOCHS] 85 | [-batch_size BATCH_SIZE] 86 | [-lr_decay_interval LR_DECAY_INTERVAL] 87 | [-log_interval LOG_INTERVAL] 88 | [-test_interval TEST_INTERVAL] 89 | [-save_interval SAVE_INTERVAL] 90 | [-save_dir SAVE_DIR] [-data_path DATA_PATH] 91 | [-label_path LABEL_PATH] [-separated SEPARATED] 92 | [-shuffle SHUFFLE] [-sentence_len SENTENCE_LEN] 93 | [-mlp_out MLP_OUT] [-dropout DROPOUT] 94 | [-embed_dim EMBED_DIM] 95 | [-kernel_sizes KERNEL_SIZES] [-tau TAU] 96 | [-use_cuda] [-enc_snapshot ENC_SNAPSHOT] 97 | [-dec_snapshot DEC_SNAPSHOT] 98 | [-mlp_snapshot MLP_SNAPSHOT] 99 | 100 | text convolution-deconvolution auto-encoder model 101 | 102 | optional arguments: 103 | -h, --help show this help message and exit 104 | -lr LR initial learning rate 105 | -epochs EPOCHS number of epochs for train 106 | -batch_size BATCH_SIZE 107 | batch size for training 108 | -lr_decay_interval LR_DECAY_INTERVAL 109 | how many epochs to wait before decrease learning rate 110 | -log_interval LOG_INTERVAL 111 | how many steps to wait before logging training status 112 | -test_interval TEST_INTERVAL 113 | how many steps to wait before testing 114 | -save_interval SAVE_INTERVAL 115 | how many epochs to wait before saving 116 | -save_dir SAVE_DIR where to save the snapshot 117 | -data_path DATA_PATH data path 118 | -label_path LABEL_PATH 119 | label path 120 | -separated SEPARATED how separated text data is 121 | -shuffle SHUFFLE shuffle the data every epoch 122 | -sentence_len SENTENCE_LEN 123 | how many tokens in a sentence 124 | -mlp_out MLP_OUT number of classes 125 | -dropout DROPOUT the probability for dropout 126 | -embed_dim EMBED_DIM number of embedding dimension 127 | -kernel_sizes KERNEL_SIZES 128 | kernel size to use for convolution 129 | -tau TAU temperature parameter 130 | -use_cuda whether using 
from torch.utils.data import Dataset
import torch
import numpy as np

from collections import Counter
from copy import deepcopy

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws progress bars; degrade to a no-op wrapper
    def tqdm(iterable, *args, **kwargs):
        return iterable


def load_hotel_review_data(path, sentence_len):
    """
    Load Hotel Reviews data from the pickle distributed at
    https://drive.google.com/file/d/0B52eYWrYWqIpQzhBNkVxaV9mMjQ/view
    (published in https://github.com/dreasysnail/textCNN_public).

    :param path: path to the pickle file
    :param sentence_len: fixed length sentences are padded/truncated to
    :return: (train_dataset, test_dataset) pair of HotelReviewsDataset
    """
    import _pickle as cPickle
    with open(path, "rb") as f:
        data = cPickle.load(f, encoding="latin1")

    # data[0]/data[1]: train/test sentences; data[2]/data[3]: word2index/index2word.
    # The vocab dicts are deep-copied so each dataset can extend its own copy.
    train_data = HotelReviewsDataset(data[0], deepcopy(data[2]), deepcopy(data[3]),
                                     sentence_len, transform=ToTensor())
    test_data = HotelReviewsDataset(data[1], deepcopy(data[2]), deepcopy(data[3]),
                                    sentence_len, transform=ToTensor())
    return train_data, test_data


class HotelReviewsDataset(Dataset):
    """
    Hotel Reviews Dataset: index-encoded sentences padded/truncated to a fixed length.
    """
    # NOTE(review): the special-token literal appeared garbled as "" in the
    # original (angle brackets stripped); "<PAD>" restores a distinct key.
    PAD_TOKEN = "<PAD>"

    def __init__(self, data_list, word2index, index2word, sentence_len, transform=None):
        self.word2index = word2index
        self.index2word = index2word
        self.n_words = len(self.word2index)
        self.sentence_len = sentence_len
        self.transform = transform
        # Register the padding token as the last vocabulary id.
        self.word2index[self.PAD_TOKEN] = self.n_words
        self.index2word[self.n_words] = self.PAD_TOKEN
        self.n_words += 1
        pad_id = self.n_words - 1

        fixed_length = []
        for sentence in tqdm(data_list):
            if len(sentence) > self.sentence_len:
                # Truncate sentences longer than `sentence_len`.
                fixed_length.append(np.array(sentence[:self.sentence_len]))
            else:
                # Right-pad shorter sentences with the PAD id.
                # np.pad is the public spelling of the former np.lib.pad.
                fixed_length.append(np.pad(np.array(sentence),
                                           (0, self.sentence_len - len(sentence)),
                                           "constant",
                                           constant_values=(pad_id, pad_id)))
        self.data = np.array(fixed_length, dtype=np.int32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def vocab_length(self):
        """Vocabulary size, including the padding token."""
        return len(self.word2index)

    # Backward-compatible alias: existing callers use the misspelled name.
    vocab_lennght = vocab_length


class TextClassificationDataset(Dataset):
    """
    Text-classification dataset built from a tokenized text file and a label file.

    `tokenized` selects preprocessing: "mecab" maps words seen at most twice to
    the UNK token; "sentencepiece" strips the "▁" meta symbol. Sentences are
    padded/truncated to `sentence_len`; the PAD id is 0.
    """
    # NOTE(review): both token literals appeared garbled as "" in the original,
    # which collapsed PAD and UNK into a single dict entry; distinct keys fix that.
    PAD_TOKEN = "<PAD>"
    UNK_TOKEN = "<UNK>"

    def __init__(self, data_path, label_path, tokenized, sentence_len=60, transoform=None):
        # NOTE: `transoform` (sic) is kept for backward compatibility with
        # callers that pass it as a keyword argument.
        self.word2index = {self.PAD_TOKEN: 0, self.UNK_TOKEN: 1}
        self.index2word = {0: self.PAD_TOKEN, 1: self.UNK_TOKEN}
        self.n_words = 2
        self.sentence_len = sentence_len
        # Data load
        with open(data_path, encoding="utf-8") as f:
            data = [line.split() for line in f]

        if tokenized == "mecab":
            # Replace low-frequency words (freq <= 2) with the UNK token.
            cnt = Counter(word for sentence in data for word in sentence)
            rare_word = {word for word, freq in cnt.most_common() if freq <= 2}
            print("Rare word")
            print(len(rare_word))

            for sentence in data:
                for word in sentence:
                    if word in rare_word:
                        continue
                    elif word not in self.word2index:
                        self.word2index[word] = self.n_words
                        self.index2word[self.n_words] = word
                        self.n_words += 1
            unk_id = self.word2index[self.UNK_TOKEN]
            # Keep encoded sentences as plain (ragged) lists: np.array on a
            # ragged nested list is an error on modern NumPy.
            encoded = [[self.word2index[word] if word not in rare_word else unk_id
                        for word in sentence]
                       for sentence in tqdm(data)]
        elif tokenized == "sentencepiece":
            # remove meta symbol
            # TODO: this also removes blanks inside a sentence. Is there another method?
            strip_meta = lambda word: word.replace("▁", "")
            for sentence in data:
                for word in map(strip_meta, sentence):
                    if word not in self.word2index:
                        self.word2index[word] = self.n_words
                        self.index2word[self.n_words] = word
                        self.n_words += 1
            encoded = [[self.word2index[word] for word in map(strip_meta, sentence)]
                       for sentence in tqdm(data)]
        else:
            # Previously this fell through and failed later with a confusing error.
            raise ValueError("unknown tokenizer: {}".format(tokenized))

        fixed_length = []
        for sentence in encoded:
            if len(sentence) > self.sentence_len:
                # Truncate sentences longer than `sentence_len`.
                fixed_length.append(np.array(sentence[:self.sentence_len]))
            else:
                # Right-pad with the PAD id (0).
                fixed_length.append(np.pad(np.array(sentence),
                                           (0, self.sentence_len - len(sentence)),
                                           "constant",
                                           constant_values=(0, 0)))
        self.data = np.array(fixed_length, dtype=np.int32)
        with open(label_path, encoding="utf-8") as f:
            self.labels = np.array([np.array([int(label)]) for label in f], dtype=np.int32)
        self.transform = transoform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence = self.data[idx]
        label = self.labels[idx]
        sample = {"sentence": sentence, "label": label}

        if self.transform:
            sample = {"sentence": self.transform(sample["sentence"]),
                      "label": self.transform(sample["label"])}

        return sample

    def vocab_length(self):
        """Vocabulary size, including PAD and UNK."""
        return self.n_words


class ToTensor(object):
    """Convert ndarrays in a sample to torch LongTensors."""

    def __call__(self, data):
        return torch.from_numpy(data).type(torch.LongTensor)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import model
from datasets import TextClassificationDataset, ToTensor
from train import train_classification

import argparse


def _str2bool(value):
    """Parse boolean-ish CLI values so that `-shuffle False` really means False.

    The original `default=False` with no `type=` meant any supplied string
    (even "False") parsed as truthy.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', '1', 'yes', 'y')


def main():
    """Parse CLI options, build the dataset and models, and run classification training."""
    parser = argparse.ArgumentParser(description='text convolution-deconvolution auto-encoder model')
    # learning
    parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('-epochs', type=int, default=60, help='number of epochs for train')
    parser.add_argument('-batch_size', type=int, default=64, help='batch size for training')
    parser.add_argument('-lr_decay_interval', type=int, default=20,
                        help='how many epochs to wait before decrease learning rate')
    parser.add_argument('-log_interval', type=int, default=16,
                        help='how many steps to wait before logging training status')
    parser.add_argument('-test_interval', type=int, default=100,
                        help='how many steps to wait before testing')
    parser.add_argument('-save_interval', type=int, default=5,
                        help='how many epochs to wait before saving')
    parser.add_argument('-save_dir', type=str, default='snapshot', help='where to save the snapshot')
    # data
    parser.add_argument('-data_path', type=str, help='data path')
    parser.add_argument('-label_path', type=str, help='label path')
    parser.add_argument('-separated', type=str, default='sentencepiece', help='how separated text data is')
    # BUGFIX: parse the value as a boolean; previously "-shuffle False" enabled shuffling.
    parser.add_argument('-shuffle', type=_str2bool, default=False, help='shuffle the data every epoch')
    parser.add_argument('-sentence_len', type=int, default=60, help='how many tokens in a sentence')
    # model
    parser.add_argument('-mlp_out', type=int, default=7, help='number of classes')
    parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout')
    parser.add_argument('-embed_dim', type=int, default=300, help='number of embedding dimension')
    parser.add_argument('-kernel_sizes', type=int, default=2,
                        help='kernel size to use for convolution')
    parser.add_argument('-tau', type=float, default=0.01, help='temperature parameter')
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the CLI — confirm intended behavior.
    parser.add_argument('-use_cuda', action='store_true', default=True, help='whether using cuda')
    # option
    parser.add_argument('-enc_snapshot', type=str, default=None, help='filename of encoder snapshot ')
    parser.add_argument('-dec_snapshot', type=str, default=None, help='filename of decoder snapshot ')
    parser.add_argument('-mlp_snapshot', type=str, default=None, help='filename of mlp classifier snapshot ')
    args = parser.parse_args()

    dataset = TextClassificationDataset(args.data_path,
                                        args.label_path,
                                        args.separated,
                                        sentence_len=args.sentence_len,
                                        transoform=ToTensor())  # (sic) keyword name in datasets.py
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle)
    print("Vocab number")
    print(dataset.vocab_length())

    k = args.embed_dim
    v = dataset.vocab_length()
    if args.enc_snapshot is None or args.dec_snapshot is None or args.mlp_snapshot is None:
        print("Start from initial")
        embedding = nn.Embedding(v, k, max_norm=1.0, norm_type=2.0)

        # NOTE(review): ConvolutionEncoder/DeconvolutionDecoder in model.py take
        # additional size arguments (cf. main_reconstruction.py); this WIP call
        # site looks out of sync with those signatures — confirm before running.
        encoder = model.ConvolutionEncoder(embedding)
        decoder = model.DeconvolutionDecoder(embedding, args.tau)
        mlp = model.MLPClassifier(args.mlp_out, args.dropout)
    else:
        print("Restart from snapshot")
        encoder = torch.load(args.enc_snapshot)
        decoder = torch.load(args.dec_snapshot)
        mlp = torch.load(args.mlp_snapshot)

    # NOTE(review): the same loader is passed as both train and test loader.
    train_classification(data_loader, data_loader, encoder, decoder, mlp, args)


if __name__ == '__main__':
    main()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

import model
from datasets import TextClassificationDataset, ToTensor, load_hotel_review_data
from train import train_reconstruction

import argparse
import math


def _str2bool(value):
    """Parse boolean-ish CLI values so that `-shuffle False` really means False.

    The original `default=False` with no `type=` meant any supplied string
    (even "False") parsed as truthy.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', '1', 'yes', 'y')


def main():
    """Parse CLI options, load Hotel Reviews, and train the reconstruction auto-encoder."""
    parser = argparse.ArgumentParser(description='text convolution-deconvolution auto-encoder model')
    # learning
    parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('-epochs', type=int, default=10, help='number of epochs for train')
    parser.add_argument('-batch_size', type=int, default=16, help='batch size for training')
    parser.add_argument('-lr_decay_interval', type=int, default=4,
                        help='how many epochs to wait before decrease learning rate')
    parser.add_argument('-log_interval', type=int, default=256,
                        help='how many steps to wait before logging training status')
    parser.add_argument('-test_interval', type=int, default=2,
                        help='how many epochs to wait before testing')
    parser.add_argument('-save_interval', type=int, default=2,
                        help='how many epochs to wait before saving')
    parser.add_argument('-save_dir', type=str, default='rec_snapshot', help='where to save the snapshot')
    # data
    parser.add_argument('-data_path', type=str, help='data path')
    # BUGFIX: parse the value as a boolean; previously "-shuffle False" enabled shuffling.
    parser.add_argument('-shuffle', type=_str2bool, default=False, help='shuffle data every epoch')
    parser.add_argument('-sentence_len', type=int, default=253, help='how many tokens in a sentence')
    # model
    parser.add_argument('-embed_dim', type=int, default=300, help='number of embedding dimension')
    parser.add_argument('-filter_size', type=int, default=300, help='filter size of convolution')
    parser.add_argument('-filter_shape', type=int, default=5,
                        help='filter shape to use for convolution')
    parser.add_argument('-latent_size', type=int, default=900, help='size of latent variable')
    parser.add_argument('-tau', type=float, default=0.01, help='temperature parameter')
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the CLI — confirm intended behavior.
    parser.add_argument('-use_cuda', action='store_true', default=True, help='whether using cuda')
    # option
    parser.add_argument('-enc_snapshot', type=str, default=None, help='filename of encoder snapshot ')
    parser.add_argument('-dec_snapshot', type=str, default=None, help='filename of decoder snapshot ')
    args = parser.parse_args()

    train_data, test_data = load_hotel_review_data(args.data_path, args.sentence_len)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=args.shuffle)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=args.shuffle)

    k = args.embed_dim
    v = train_data.vocab_lennght()  # (sic) method name in datasets.py
    # Feature-map heights after each stride-2 convolution ("2" is the stride).
    # NOTE(review): t1 assumes (filter_shape - 1) padding on both ends of the
    # sentence — verify this matches the convolutions in model.py.
    t1 = args.sentence_len + 2 * (args.filter_shape - 1)
    t2 = int(math.floor((t1 - args.filter_shape) / 2) + 1)
    t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1) - 2
    if args.enc_snapshot is None or args.dec_snapshot is None:
        print("Start from initial")
        embedding = nn.Embedding(v, k, max_norm=1.0, norm_type=2.0)

        encoder = model.ConvolutionEncoder(embedding, t3, args.filter_size, args.filter_shape, args.latent_size)
        decoder = model.DeconvolutionDecoder(embedding, args.tau, t3, args.filter_size, args.filter_shape, args.latent_size)
    else:
        print("Restart from snapshot")
        encoder = torch.load(args.enc_snapshot)
        decoder = torch.load(args.dec_snapshot)

    train_reconstruction(train_loader, test_loader, encoder, decoder, args)


if __name__ == '__main__':
    main()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

class ConvolutionEncoder(nn.Module):
    """Encode a batch of token-id sentences into a latent vector.

    Three stride-2 Conv2d layers: the first collapses the embedding width,
    the last uses kernel height ``sentence_len`` so the output spatial size is
    (1, 1) when ``sentence_len`` matches the height left after the first two
    convolutions. Output shape: (batch, latent_size, 1, 1).
    """

    def __init__(self, embedding, sentence_len, filter_size, filter_shape, latent_size):
        super(ConvolutionEncoder, self).__init__()
        self.embed = embedding
        self.convs1 = nn.Conv2d(1, filter_size, (filter_shape, self.embed.weight.size()[1]), stride=2)
        self.bn1 = nn.BatchNorm2d(filter_size)
        self.convs2 = nn.Conv2d(filter_size, filter_size * 2, (filter_shape, 1), stride=2)
        self.bn2 = nn.BatchNorm2d(filter_size * 2)
        self.convs3 = nn.Conv2d(filter_size * 2, latent_size, (sentence_len, 1), stride=2)

        # He-style weight initialization for every conv layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    # BUG FIX: this was defined as ``__call__``, which bypasses nn.Module's
    # __call__ machinery (forward/backward hooks, tracing). Defining ``forward``
    # keeps ``encoder(x)`` behaving identically while restoring hook support.
    def forward(self, x):
        x = self.embed(x)

        # x.size() is (L, emb_dim) if batch_size is 1; add the batch axis.
        if len(x.size()) < 3:
            x = x.view(1, *x.size())
        # reshape to (batch, channel=1, length, emb_dim) for Conv2d
        x = x.view(x.size()[0], 1, x.size()[1], x.size()[2])

        h1 = F.relu(self.bn1(self.convs1(x)))
        h2 = F.relu(self.bn2(self.convs2(h1)))
        h3 = F.relu(self.convs3(h2))

        return h3


class DeconvolutionDecoder(nn.Module):
    """Decode a latent vector back into per-token log-probabilities.

    Mirrors ConvolutionEncoder with three stride-2 ConvTranspose2d layers, then
    scores each reconstructed embedding against the (shared) embedding matrix
    with a temperature-scaled cosine-style softmax. Returns log-probabilities of
    shape (batch, length, vocab).
    """

    def __init__(self, embedding, tau, sentence_len, filter_size, filter_shape, latent_size):
        super(DeconvolutionDecoder, self).__init__()
        self.tau = tau  # softmax temperature: smaller => sharper distribution
        self.embed = embedding
        self.deconvs1 = nn.ConvTranspose2d(latent_size, filter_size * 2, (sentence_len, 1), stride=2)
        self.bn1 = nn.BatchNorm2d(filter_size * 2)
        self.deconvs2 = nn.ConvTranspose2d(filter_size * 2, filter_size, (filter_shape, 1), stride=2)
        self.bn2 = nn.BatchNorm2d(filter_size)
        self.deconvs3 = nn.ConvTranspose2d(filter_size, 1, (filter_shape, self.embed.weight.size()[1]), stride=2)

        # He-style weight initialization for every conv_transpose layer.
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    # BUG FIX: renamed from ``__call__`` to ``forward`` (see ConvolutionEncoder).
    def forward(self, h3):
        h2 = F.relu(self.bn1(self.deconvs1(h3)))
        h1 = F.relu(self.bn2(self.deconvs2(h2)))
        x_hat = F.relu(self.deconvs3(h1))
        x_hat = x_hat.squeeze()

        # x_hat.size() is (L, emb_dim) if batch_size is 1; add the batch axis.
        if len(x_hat.size()) < 3:
            x_hat = x_hat.view(1, *x_hat.size())
        # L2-normalize each reconstructed token embedding.
        norm_x_hat = torch.norm(x_hat, 2, dim=2, keepdim=True)
        rec_x_hat = x_hat / norm_x_hat

        # Score against the embedding matrix (detached via Variable(.data))
        # and sharpen with the temperature tau.
        norm_w = Variable(self.embed.weight.data).t()
        prob_logits = torch.bmm(rec_x_hat, norm_w.unsqueeze(0)
                                .expand(rec_x_hat.size(0), *norm_w.size())) / self.tau
        log_prob = F.log_softmax(prob_logits, dim=2)
        return log_prob
class MLPClassifier(nn.Module):
    """Two-layer MLP head: 500-dim latent code -> class log-probabilities."""

    def __init__(self, output_dim, dropout):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(500, 300)
        self.out = nn.Linear(300, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # NOTE(review): there is no nonlinearity between fc1 and out, so the two
        # linear layers compose to one affine map (plus dropout); kept as-is.
        h = self.dropout(self.fc1(x))
        out = self.out(h)
        return F.log_softmax(out, dim=1)


# ----- train.py -----
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import pickle
from sumeval.metrics.rouge import RougeCalculator
from sumeval.metrics.bleu import BLEUCalculator
from hyperdash import Experiment

import util

def train_classification(data_loader, dev_iter, encoder, decoder, mlp, args):
    """Jointly train encoder/decoder/mlp on annealed reconstruction + supervised NLL.

    The reconstruction term is weighted by a sigmoid annealing schedule over
    epochs; models and the vocabulary mappings are saved periodically and at the
    end of training.
    """
    lr = args.lr
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
    mlp_opt = torch.optim.Adam(mlp.parameters(), lr=lr)

    encoder.train()
    decoder.train()
    mlp.train()
    steps = 0
    for epoch in range(1, args.epochs+1):
        # alpha decays over epochs, shifting weight from reconstruction to the
        # supervised objective.
        alpha = util.sigmoid_annealing_schedule(epoch, args.epochs)
        print("=======Epoch========")
        print(epoch)
        for batch in data_loader:
            feature, target = Variable(batch["sentence"]), Variable(batch["label"])
            if args.use_cuda:
                encoder.cuda()
                decoder.cuda()
                mlp.cuda()
                feature, target = feature.cuda(), target.cuda()

            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            mlp_opt.zero_grad()

            h = encoder(feature)
            prob = decoder(h)
            log_prob = mlp(h.squeeze())
            reconstruction_loss = compute_cross_entropy(prob, feature)
            supervised_loss = F.nll_loss(log_prob, target.view(target.size()[0]))
            loss = alpha * reconstruction_loss + supervised_loss
            loss.backward()
            encoder_opt.step()
            decoder_opt.step()
            mlp_opt.step()

            steps += 1
            print("Epoch: {}".format(epoch))
            print("Steps: {}".format(steps))
            print("Loss: {}".format(loss.data[0]))
            # check reconstructed sentence and classification
            if steps % args.log_interval == 0:
                print("Test!!")
                input_data = feature[0]
                input_label = target[0]
                single_data = prob[0]
                _, predict_index = torch.max(single_data, 1)
                input_sentence = util.transform_id2word(input_data.data, data_loader.dataset.index2word, lang="ja")
                predict_sentence = util.transform_id2word(predict_index.data, data_loader.dataset.index2word, lang="ja")
                print("Input Sentence:")
                print(input_sentence)
                print("Output Sentence:")
                print(predict_sentence)
                eval_classification(encoder, mlp, input_data, input_label)

        if epoch % args.lr_decay_interval == 0:
            # decrease learning rate (fresh optimizers with lr/5)
            lr = lr / 5
            encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
            decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
            mlp_opt = torch.optim.Adam(mlp.parameters(), lr=lr)
            encoder.train()
            decoder.train()
            mlp.train()

        if epoch % args.save_interval == 0:
            util.save_models(encoder, args.save_dir, "encoder", steps)
            util.save_models(decoder, args.save_dir, "decoder", steps)
            util.save_models(mlp, args.save_dir, "mlp", steps)

    # finalization: persist vocabulary mappings and final models
    with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
        pickle.dump(data_loader.dataset.word2index, w2i)
        pickle.dump(data_loader.dataset.index2word, i2w)

    util.save_models(encoder, args.save_dir, "encoder", "final")
    util.save_models(decoder, args.save_dir, "decoder", "final")
    util.save_models(mlp, args.save_dir, "mlp", "final")

    print("Finish!!!")


def train_reconstruction(train_loader, test_loader, encoder, decoder, args):
    """Train encoder/decoder on the reconstruction objective.

    Logs the per-token loss to hyperdash, periodically prints a reconstructed
    sentence, evaluates on the test set every ``args.test_interval`` epochs, and
    saves snapshots; the hyperdash experiment is closed even on failure.
    """
    exp = Experiment("Reconstruction Training")
    try:
        lr = args.lr
        encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
        decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)

        encoder.train()
        decoder.train()
        steps = 0
        for epoch in range(1, args.epochs+1):
            print("=======Epoch========")
            print(epoch)
            for batch in train_loader:
                feature = Variable(batch)
                if args.use_cuda:
                    encoder.cuda()
                    decoder.cuda()
                    feature = feature.cuda()

                encoder_opt.zero_grad()
                decoder_opt.zero_grad()

                h = encoder(feature)
                prob = decoder(h)
                reconstruction_loss = compute_cross_entropy(prob, feature)
                reconstruction_loss.backward()
                encoder_opt.step()
                decoder_opt.step()

                steps += 1
                print("Epoch: {}".format(epoch))
                print("Steps: {}".format(steps))
                print("Loss: {}".format(reconstruction_loss.data[0] / args.sentence_len))
                exp.metric("Loss", reconstruction_loss.data[0] / args.sentence_len)
                # check reconstructed sentence
                if steps % args.log_interval == 0:
                    print("Test!!")
                    input_data = feature[0]
                    single_data = prob[0]
                    _, predict_index = torch.max(single_data, 1)
                    input_sentence = util.transform_id2word(input_data.data, train_loader.dataset.index2word, lang="en")
                    predict_sentence = util.transform_id2word(predict_index.data, train_loader.dataset.index2word, lang="en")
                    print("Input Sentence:")
                    print(input_sentence)
                    print("Output Sentence:")
                    print(predict_sentence)

            # BUG FIX: -test_interval is documented as "how many epochs to wait
            # before testing" but was compared against the per-batch step counter
            # (``steps % args.test_interval``), which ran a full test-set
            # evaluation every 2 training steps. Evaluate once per interval of
            # epochs instead.
            if epoch % args.test_interval == 0:
                eval_reconstruction(encoder, decoder, test_loader, args)

            if epoch % args.lr_decay_interval == 0:
                # decrease learning rate (fresh optimizers with lr/5)
                lr = lr / 5
                encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
                decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
                encoder.train()
                decoder.train()

            if epoch % args.save_interval == 0:
                util.save_models(encoder, args.save_dir, "encoder", steps)
                util.save_models(decoder, args.save_dir, "decoder", steps)

        # finalization: persist vocabulary mappings and final models
        with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
            pickle.dump(train_loader.dataset.word2index, w2i)
            pickle.dump(train_loader.dataset.index2word, i2w)

        util.save_models(encoder, args.save_dir, "encoder", "final")
        util.save_models(decoder, args.save_dir, "decoder", "final")

        print("Finish!!!")
    finally:
        exp.end()


def compute_cross_entropy(log_prob, target):
    """Reconstruction loss: per-sentence NLL summed over tokens, averaged over batch."""
    loss = [F.nll_loss(sentence_emb_matrix, word_ids, size_average=False) for sentence_emb_matrix, word_ids in zip(log_prob, target)]
    average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
    return average_loss

def eval_classification(encoder, mlp, feature, label):
    """Print predicted vs. true label for one example; modules return to train mode."""
    encoder.eval()
    mlp.eval()
    h = encoder(feature)
    h = h.view(1, 500)
    out = mlp(h)
    # BUG FIX: the max was taken over dim 0 (the singleton batch axis), so the
    # reported "predicted label" was always index 0. Take the argmax over the
    # class axis (dim 1) instead.
    value, predicted = torch.max(out, 1)
    print("Input label: {}".format(label.data[0]))
    print("Predicted label: {}".format(predicted.data[0]))
    print("Predicted value: {}".format(value.data[0]))
    encoder.train()
    mlp.train()
def eval_reconstruction(encoder, decoder, data_iter, args):
    """Evaluate reconstruction quality on ``data_iter``.

    Reports the average per-token reconstruction loss and the mean per-sentence
    ROUGE-1/ROUGE-2 between the original and reconstructed sentences; both
    modules are switched back to train mode before returning.
    """
    print("=================Eval======================")
    encoder.eval()
    decoder.eval()
    avg_loss = 0
    rouge_1 = 0.0
    rouge_2 = 0.0
    index2word = data_iter.dataset.index2word
    for batch in data_iter:
        feature = Variable(batch, requires_grad=False)
        if args.use_cuda:
            feature = feature.cuda()
        h = encoder(feature)
        prob = decoder(h)
        _, predict_index = torch.max(prob, 2)
        original_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in batch]
        predict_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in predict_index.data]
        r1, r2 = calc_rouge(original_sentences, predict_sentences)
        rouge_1 += r1
        rouge_2 += r2
        reconstruction_loss = compute_cross_entropy(prob, feature)
        avg_loss += reconstruction_loss.data[0]
    # BUG FIX: compute_cross_entropy already averages over the batch, so the
    # accumulator holds one value per *batch*; dividing by the dataset size
    # deflated the reported loss by roughly a factor of batch_size. Divide by
    # the number of batches instead. (ROUGE is summed per sentence, so dividing
    # it by the dataset size is correct.)
    avg_loss = avg_loss / len(data_iter)
    avg_loss = avg_loss / args.sentence_len
    rouge_1 = rouge_1 / len(data_iter.dataset)
    rouge_2 = rouge_2 / len(data_iter.dataset)
    print("Evaluation - loss: {} Rouge1: {} Rouge2: {}".format(avg_loss, rouge_1, rouge_2))
    print("===============================================================")
    encoder.train()
    decoder.train()

def calc_rouge(original_sentences, predict_sentences):
    """Sum ROUGE-1 and ROUGE-2 over paired sentences, stripping padding first."""
    rouge_1 = 0.0
    rouge_2 = 0.0
    # Hoisted out of the loop: the calculator carries no per-sentence state.
    rouge = RougeCalculator(stopwords=True, lang="en")
    for original, predict in zip(original_sentences, predict_sentences):
        # BUG FIX: this read ``replace("", "")`` — a no-op — while the comment
        # says "Remove padding"; the pad token (presumably "<PAD>") was most
        # likely stripped as markup at some point. Restored so padding does not
        # inflate the ROUGE inputs. TODO(review): confirm the vocabulary's
        # actual pad-token spelling against datasets.py.
        original, predict = original.replace("<PAD>", "").strip(), predict.replace("<PAD>", "").strip()
        r1 = rouge.rouge_1(summary=predict, references=original)
        r2 = rouge.rouge_2(summary=predict, references=original)
        rouge_1 += r1
        rouge_2 += r2
    return rouge_1, rouge_2


# ----- util.py -----
import torch
import math
import os

def transform_id2word(index, id2word, lang):
    """Map a sequence of token ids to a sentence; "ja" joins without spaces."""
    if lang == "ja":
        return "".join([id2word[idx] for idx in index])
    else:
        return " ".join([id2word[idx] for idx in index])

def sigmoid_annealing_schedule(step, max_step, param_init=1.0, param_final=0.01, gain=0.3):
    """Sigmoid decay from ``param_init`` toward ``param_final``, centered at max_step/2."""
    return ((param_init - param_final) / (1 + math.exp(gain * (step - (max_step / 2))))) + param_final

def save_models(model, path, prefix, steps):
    """Serialize the whole model object to ``{path}/{prefix}_steps_{steps}.pt``,
    creating the directory if needed."""
    if not os.path.isdir(path):
        os.makedirs(path)
    model_save_path = '{}/{}_steps_{}.pt'.format(path, prefix, steps)
    torch.save(model, model_save_path)