├── .gitignore ├── LICENSE ├── README.md ├── config.yml ├── data └── rt-polaritydata │ ├── rt-polarity.neg │ └── rt-polarity.pos ├── data_helpers.py ├── eval.py ├── text_cnn.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | runs/ 3 | 4 | # Created by https://www.gitignore.io/api/python,ipythonnotebook 5 | 6 | ### Python ### 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *,cover 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | 67 | ### IPythonNotebook ### 68 | # Temporary data 69 | .ipynb_checkpoints/ 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **[This code belongs to the "Implementing a CNN for Text Classification in Tensorflow" blog post.](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/)** 2 | 3 | It is a slightly simplified implementation of Kim's [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) paper in Tensorflow.
4 | 5 | ## Requirements 6 | 7 | - Python 3 8 | - Tensorflow > 0.12 9 | - Numpy 10 | 11 | ## Training 12 | 13 | Print parameters: 14 | 15 | ```bash 16 | ./train.py --help 17 | ``` 18 | 19 | ``` 20 | optional arguments: 21 | -h, --help show this help message and exit 22 | --embedding_dim EMBEDDING_DIM 23 | Dimensionality of character embedding (default: 128) 24 | --enable_word_embeddings 25 | Enable/disable the word embeddings (default: True) 26 | --filter_sizes FILTER_SIZES 27 | Comma-separated filter sizes (default: '3,4,5') 28 | --num_filters NUM_FILTERS 29 | Number of filters per filter size (default: 128) 30 | --l2_reg_lambda L2_REG_LAMBDA 31 | L2 regularization lambda (default: 0.0) 32 | --dropout_keep_prob DROPOUT_KEEP_PROB 33 | Dropout keep probability (default: 0.5) 34 | --batch_size BATCH_SIZE 35 | Batch Size (default: 64) 36 | --num_epochs NUM_EPOCHS 37 | Number of training epochs (default: 200) 38 | --evaluate_every EVALUATE_EVERY 39 | Evaluate model on dev set after this many steps 40 | (default: 100) 41 | --checkpoint_every CHECKPOINT_EVERY 42 | Save model after this many steps (default: 100) 43 | --allow_soft_placement ALLOW_SOFT_PLACEMENT 44 | Allow device soft device placement 45 | --noallow_soft_placement 46 | --log_device_placement LOG_DEVICE_PLACEMENT 47 | Log placement of ops on devices 48 | --nolog_device_placement 49 | 50 | ``` 51 | 52 | Train: 53 | 54 | ```bash 55 | ./train.py 56 | ``` 57 | 58 | ## Evaluating 59 | 60 | ```bash 61 | ./eval.py --eval_train --checkpoint_dir="./runs/1459637919/checkpoints/" 62 | ``` 63 | 64 | Replace the checkpoint dir with the output from the training. To use your own data, change the `eval.py` script to load your data.
65 | 66 | 67 | ## References 68 | 69 | - [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) 70 | - [A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1510.03820) 71 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | word_embeddings: 2 | # Two word embedding algorithms (word2vec and glove) are supported. 3 | # Set the default to an empty string to disable the word embeddings 4 | default: word2vec 5 | word2vec: 6 | path: ../../data/input/word_embeddings/GoogleNews-vectors-negative300.bin 7 | dimension: 300 8 | binary: True 9 | glove: 10 | path: ../../data/glove.6B.100d.txt 11 | dimension: 100 12 | length: 400000 13 | 14 | datasets: 15 | # Currently 3 datasets are supported: mrpolarity, 20newsgroup and localdata 16 | default: 20newsgroup 17 | mrpolarity: 18 | positive_data_file: 19 | path: "data/rt-polaritydata/rt-polarity.pos" 20 | info: "Data source for the positive data" 21 | negative_data_file: 22 | path: "data/rt-polaritydata/rt-polarity.neg" 23 | info: "Data source for the negative data" 24 | 20newsgroup: 25 | # The dataset includes the following 20 newsgroups: 26 | # alt.atheism, comp.windows.x, rec.sport.hockey, soc.religion.christian 27 | # comp.graphics, misc.forsale, sci.crypt, talk.politics.guns 28 | # comp.os.ms-windows.misc, rec.autos, sci.electronics, talk.politics.mideast 29 | # comp.sys.ibm.pc.hardware, rec.motorcycles, sci.med, talk.politics.misc 30 | # comp.sys.mac.hardware, rec.sport.baseball, sci.space, talk.religion.misc 31 | categories: 32 | - alt.atheism 33 | - comp.graphics 34 | - sci.med 35 | - soc.religion.christian 36 | shuffle: True 37 | random_state: 42 38 | localdata: 39 | # Load text files with categories as subfolder names.
40 | # Individual samples are assumed to be files stored 41 | # a two levels folder structure such as the following: 42 | # container_folder/ 43 | # category_1_folder/ 44 | # file_1.txt file_2.txt ... file_42.txt 45 | # category_2_folder/ 46 | # file_43.txt file_44.txt ... 47 | # 48 | # As an example, a SentenceCorpus dataset from 49 | # https://archive.ics.uci.edu/ml/datasets/Sentence+Classification 50 | # has been used. The dataset includes following 3 domains: 51 | # arxiv, jdm and plos 52 | container_path: ../../data/input/SentenceCorpus 53 | categories: 54 | shuffle: True 55 | random_state: 42 56 | 57 | -------------------------------------------------------------------------------- /data_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | from sklearn.datasets import fetch_20newsgroups 4 | from sklearn.datasets import load_files 5 | 6 | 7 | def clean_str(string): 8 | """ 9 | Tokenization/string cleaning for all datasets except for SST. 10 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 11 | """ 12 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 13 | string = re.sub(r"\'s", " \'s", string) 14 | string = re.sub(r"\'ve", " \'ve", string) 15 | string = re.sub(r"n\'t", " n\'t", string) 16 | string = re.sub(r"\'re", " \'re", string) 17 | string = re.sub(r"\'d", " \'d", string) 18 | string = re.sub(r"\'ll", " \'ll", string) 19 | string = re.sub(r",", " , ", string) 20 | string = re.sub(r"!", " ! ", string) 21 | string = re.sub(r"\(", " \( ", string) 22 | string = re.sub(r"\)", " \) ", string) 23 | string = re.sub(r"\?", " \? ", string) 24 | string = re.sub(r"\s{2,}", " ", string) 25 | return string.strip().lower() 26 | 27 | 28 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 29 | """ 30 | Generates a batch iterator for a dataset. 
31 | """ 32 | data = np.array(data) 33 | data_size = len(data) 34 | num_batches_per_epoch = int((len(data)-1)/batch_size) + 1 35 | for epoch in range(num_epochs): 36 | # Shuffle the data at each epoch 37 | if shuffle: 38 | shuffle_indices = np.random.permutation(np.arange(data_size)) 39 | shuffled_data = data[shuffle_indices] 40 | else: 41 | shuffled_data = data 42 | for batch_num in range(num_batches_per_epoch): 43 | start_index = batch_num * batch_size 44 | end_index = min((batch_num + 1) * batch_size, data_size) 45 | yield shuffled_data[start_index:end_index] 46 | 47 | 48 | def get_datasets_20newsgroup(subset='train', categories=None, shuffle=True, random_state=42): 49 | """ 50 | Retrieve data from 20 newsgroups 51 | :param subset: train, test or all 52 | :param categories: List of newsgroup name 53 | :param shuffle: shuffle the list or not 54 | :param random_state: seed integer to shuffle the dataset 55 | :return: data and labels of the newsgroup 56 | """ 57 | datasets = fetch_20newsgroups(subset=subset, categories=categories, shuffle=shuffle, random_state=random_state) 58 | return datasets 59 | 60 | 61 | def get_datasets_mrpolarity(positive_data_file, negative_data_file): 62 | """ 63 | Loads MR polarity data from files, splits the data into words and generates labels. 64 | Returns split sentences and labels. 
65 | """ 66 | # Load data from files 67 | positive_examples = list(open(positive_data_file, "r").readlines()) 68 | positive_examples = [s.strip() for s in positive_examples] 69 | negative_examples = list(open(negative_data_file, "r").readlines()) 70 | negative_examples = [s.strip() for s in negative_examples] 71 | 72 | datasets = dict() 73 | datasets['data'] = positive_examples + negative_examples 74 | target = [0 for x in positive_examples] + [1 for x in negative_examples] 75 | datasets['target'] = target 76 | datasets['target_names'] = ['positive_examples', 'negative_examples'] 77 | return datasets 78 | 79 | 80 | def get_datasets_localdata(container_path=None, categories=None, load_content=True, 81 | encoding='utf-8', shuffle=True, random_state=42): 82 | """ 83 | Load text files with categories as subfolder names. 84 | Individual samples are assumed to be files stored a two levels folder structure. 85 | :param container_path: The path of the container 86 | :param categories: List of classes to choose, all classes are chosen by default (if empty or omitted) 87 | :param shuffle: shuffle the list or not 88 | :param random_state: seed integer to shuffle the dataset 89 | :return: data and labels of the dataset 90 | """ 91 | datasets = load_files(container_path=container_path, categories=categories, 92 | load_content=load_content, shuffle=shuffle, encoding=encoding, 93 | random_state=random_state) 94 | return datasets 95 | 96 | 97 | def load_data_labels(datasets): 98 | """ 99 | Load data and labels 100 | :param datasets: 101 | :return: 102 | """ 103 | # Split by words 104 | x_text = datasets['data'] 105 | x_text = [clean_str(sent) for sent in x_text] 106 | # Generate labels 107 | labels = [] 108 | for i in range(len(x_text)): 109 | label = [0 for j in datasets['target_names']] 110 | label[datasets['target'][i]] = 1 111 | labels.append(label) 112 | y = np.array(labels) 113 | return [x_text, y] 114 | 115 | 116 | def load_embedding_vectors_word2vec(vocabulary, filename, 
binary): 117 | # load embedding_vectors from the word2vec 118 | encoding = 'utf-8' 119 | with open(filename, "rb") as f: 120 | header = f.readline() 121 | vocab_size, vector_size = map(int, header.split()) 122 | # initial matrix with random uniform 123 | embedding_vectors = np.random.uniform(-0.25, 0.25, (len(vocabulary), vector_size)) 124 | if binary: 125 | binary_len = np.dtype('float32').itemsize * vector_size 126 | for line_no in range(vocab_size): 127 | word = [] 128 | while True: 129 | ch = f.read(1) 130 | if ch == b' ': 131 | break 132 | if ch == b'': 133 | raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") 134 | if ch != b'\n': 135 | word.append(ch) 136 | word = str(b''.join(word), encoding=encoding, errors='strict') 137 | idx = vocabulary.get(word) 138 | if idx != 0: 139 | embedding_vectors[idx] = np.fromstring(f.read(binary_len), dtype='float32') 140 | else: 141 | f.seek(binary_len, 1) 142 | else: 143 | for line_no in range(vocab_size): 144 | line = f.readline() 145 | if line == b'': 146 | raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") 147 | parts = str(line.rstrip(), encoding=encoding, errors='strict').split(" ") 148 | if len(parts) != vector_size + 1: 149 | raise ValueError("invalid vector on line %s (is this really the text format?)" % (line_no)) 150 | word, vector = parts[0], list(map('float32', parts[1:])) 151 | idx = vocabulary.get(word) 152 | if idx != 0: 153 | embedding_vectors[idx] = vector 154 | f.close() 155 | return embedding_vectors 156 | 157 | 158 | def load_embedding_vectors_glove(vocabulary, filename, vector_size): 159 | # load embedding_vectors from the glove 160 | # initial matrix with random uniform 161 | embedding_vectors = np.random.uniform(-0.25, 0.25, (len(vocabulary), vector_size)) 162 | f = open(filename) 163 | for line in f: 164 | values = line.split() 165 | word = values[0] 166 | vector = np.asarray(values[1:], dtype="float32") 167 | idx = 
vocabulary.get(word) 168 | if idx != 0: 169 | embedding_vectors[idx] = vector 170 | f.close() 171 | return embedding_vectors 172 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import os 6 | import data_helpers 7 | from tensorflow.contrib import learn 8 | import csv 9 | from sklearn import metrics 10 | import yaml 11 | 12 | 13 | def softmax(x): 14 | """Compute softmax values for each sets of scores in x.""" 15 | if x.ndim == 1: 16 | x = x.reshape((1, -1)) 17 | max_x = np.max(x, axis=1).reshape((-1, 1)) 18 | exp_x = np.exp(x - max_x) 19 | return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1)) 20 | 21 | with open("config.yml", 'r') as ymlfile: 22 | cfg = yaml.load(ymlfile) 23 | 24 | # Parameters 25 | # ================================================== 26 | 27 | # Data Parameters 28 | 29 | # Eval Parameters 30 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)") 31 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 32 | tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data") 33 | 34 | # Misc Parameters 35 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 36 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 37 | 38 | 39 | FLAGS = tf.flags.FLAGS 40 | FLAGS._parse_flags() 41 | print("\nParameters:") 42 | for attr, value in sorted(FLAGS.__flags.items()): 43 | print("{}={}".format(attr.upper(), value)) 44 | print("") 45 | 46 | datasets = None 47 | 48 | # CHANGE THIS: Load data. 
Load your own data here 49 | dataset_name = cfg["datasets"]["default"] 50 | if FLAGS.eval_train: 51 | if dataset_name == "mrpolarity": 52 | datasets = data_helpers.get_datasets_mrpolarity(cfg["datasets"][dataset_name]["positive_data_file"]["path"], 53 | cfg["datasets"][dataset_name]["negative_data_file"]["path"]) 54 | elif dataset_name == "20newsgroup": 55 | datasets = data_helpers.get_datasets_20newsgroup(subset="test", 56 | categories=cfg["datasets"][dataset_name]["categories"], 57 | shuffle=cfg["datasets"][dataset_name]["shuffle"], 58 | random_state=cfg["datasets"][dataset_name]["random_state"]) 59 | x_raw, y_test = data_helpers.load_data_labels(datasets) 60 | y_test = np.argmax(y_test, axis=1) 61 | print("Total number of test examples: {}".format(len(y_test))) 62 | else: 63 | if dataset_name == "mrpolarity": 64 | datasets = {"target_names": ['positive_examples', 'negative_examples']} 65 | x_raw = ["a masterpiece four years in the making", "everything is off."] 66 | y_test = [1, 0] 67 | else: 68 | datasets = {"target_names": ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']} 69 | x_raw = ["The number of reported cases of gonorrhea in Colorado increased", 70 | "I am in the market for a 24-bit graphics card for a PC"] 71 | y_test = [2, 1] 72 | 73 | # Map data into vocabulary 74 | vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab") 75 | vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path) 76 | x_test = np.array(list(vocab_processor.transform(x_raw))) 77 | 78 | print("\nEvaluating...\n") 79 | 80 | # Evaluation 81 | # ================================================== 82 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 83 | graph = tf.Graph() 84 | with graph.as_default(): 85 | session_conf = tf.ConfigProto( 86 | allow_soft_placement=FLAGS.allow_soft_placement, 87 | log_device_placement=FLAGS.log_device_placement) 88 | sess = tf.Session(config=session_conf) 89 | with sess.as_default(): 90 | # 
Load the saved meta graph and restore variables 91 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 92 | saver.restore(sess, checkpoint_file) 93 | 94 | # Get the placeholders from the graph by name 95 | input_x = graph.get_operation_by_name("input_x").outputs[0] 96 | # input_y = graph.get_operation_by_name("input_y").outputs[0] 97 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 98 | 99 | # Tensors we want to evaluate 100 | scores = graph.get_operation_by_name("output/scores").outputs[0] 101 | 102 | # Tensors we want to evaluate 103 | predictions = graph.get_operation_by_name("output/predictions").outputs[0] 104 | 105 | # Generate batches for one epoch 106 | batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False) 107 | 108 | # Collect the predictions here 109 | all_predictions = [] 110 | all_probabilities = None 111 | 112 | for x_test_batch in batches: 113 | batch_predictions_scores = sess.run([predictions, scores], {input_x: x_test_batch, dropout_keep_prob: 1.0}) 114 | all_predictions = np.concatenate([all_predictions, batch_predictions_scores[0]]) 115 | probabilities = softmax(batch_predictions_scores[1]) 116 | if all_probabilities is not None: 117 | all_probabilities = np.concatenate([all_probabilities, probabilities]) 118 | else: 119 | all_probabilities = probabilities 120 | 121 | # Print accuracy if y_test is defined 122 | if y_test is not None: 123 | correct_predictions = float(sum(all_predictions == y_test)) 124 | print("Total number of test examples: {}".format(len(y_test))) 125 | print("Accuracy: {:g}".format(correct_predictions/float(len(y_test)))) 126 | print(metrics.classification_report(y_test, all_predictions, target_names=datasets['target_names'])) 127 | print(metrics.confusion_matrix(y_test, all_predictions)) 128 | 129 | # Save the evaluation to a csv 130 | predictions_human_readable = np.column_stack((np.array(x_raw), 131 | [int(prediction) for prediction in 
all_predictions], 132 | [ "{}".format(probability) for probability in all_probabilities])) 133 | out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv") 134 | print("Saving evaluation to {0}".format(out_path)) 135 | with open(out_path, 'w') as f: 136 | csv.writer(f).writerows(predictions_human_readable) 137 | -------------------------------------------------------------------------------- /text_cnn.py: --------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


class TextCNN(object):
    """
    A CNN for text classification.
    Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.

    The whole model is built in the current default TensorFlow graph using the
    TF 1.x API (placeholders, tf.contrib). Several ops carry explicit names
    ("input_x", "dropout_keep_prob", "output/scores", "output/predictions")
    because the evaluation script fetches them by name from a restored meta
    graph; renaming them breaks loading of saved checkpoints.

    Args:
        sequence_length: Number of token ids per example (second dim of input_x).
        num_classes: Number of output classes (width of input_y and of scores).
        vocab_size: Number of rows of the embedding matrix.
        embedding_size: Width of each embedding vector.
        filter_sizes: Iterable of ints; one conv + max-pool branch is built per
            filter height.
        num_filters: Number of filters in each branch.
        l2_reg_lambda: Weight of the L2 penalty on the output layer's W and b.

    Attributes (graph tensors / variables):
        input_x: int32 placeholder, shape [batch, sequence_length], token ids.
        input_y: float32 placeholder, shape [batch, num_classes]; used as the
            labels of softmax cross-entropy (presumably one-hot — confirm).
        dropout_keep_prob: float32 scalar placeholder; callers feed 1.0 when
            evaluating.
        learning_rate: float32 placeholder that is not used inside this class;
            the training script builds its optimizer from it and feeds it.
        W: embedding matrix variable, shape [vocab_size, embedding_size].
        scores: unnormalized logits, shape [batch, num_classes].
        predictions: argmax class index per example.
        loss: mean cross-entropy plus l2_reg_lambda times the L2 penalty.
        accuracy: fraction of examples whose prediction equals argmax(input_y).
    """
    def __init__(
            self, sequence_length, num_classes, vocab_size,
            embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):

        # Placeholders for input, output and dropout
        self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        self.learning_rate = tf.placeholder(tf.float32)

        # Keeping track of l2 regularization loss (optional).
        # Only the output layer's W and b are added to it below.
        l2_loss = tf.constant(0.0)

        # Embedding layer.
        # Pinned to the CPU. NOTE(review): presumably because embedding_lookup
        # lacked GPU kernels in early TF releases — confirm before changing.
        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            # [batch, sequence_length] -> [batch, sequence_length, embedding_size]
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            # Append a channel dim -> [batch, sequence_length, embedding_size, 1],
            # so conv2d can treat each embedded sentence as a 1-channel image.
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer.
                # Each filter spans the full embedding width, so with VALID
                # padding and unit strides the output has shape
                # [batch, sequence_length - filter_size + 1, 1, num_filters].
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                conv = tf.nn.conv2d(
                    self.embedded_chars_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Maxpooling over the outputs: the window covers every valid
                # position, leaving shape [batch, 1, 1, num_filters].
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features along the channel axis
        # -> [batch, 1, 1, num_filters_total], then flatten to [batch, num_filters_total].
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            # NOTE(review): tf.name_scope does not prefix tf.get_variable names,
            # so this variable is likely created as top-level "W" rather than
            # "output/W". Building a second TextCNN in the same graph would then
            # presumably collide on it — confirm before reusing the class that way.
            W = tf.get_variable(
                "W",
                shape=[num_filters_total, num_classes],
                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Calculate mean cross-entropy loss (plus the weighted L2 penalty)
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
#! /usr/bin/env python
# Train a TextCNN text classifier.
#
# The dataset and the optional pre-trained word embedding are selected in
# config.yml; hyperparameters come from command-line flags. The script builds
# a vocabulary over the whole corpus, holds out a shuffled dev split, and
# trains with Adam using an exponentially decaying learning rate. Summaries,
# checkpoints and the vocabulary are written under runs/<unix timestamp>/.

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn
import yaml
import math

# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")

# Model Hyperparameters
tf.flags.DEFINE_boolean("enable_word_embeddings", True, "Enable/disable the word embedding (default: True)")
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of word embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
tf.flags.DEFINE_float("decay_coefficient", 2.5, "Decay coefficient (default: 2.5)")

FLAGS = tf.flags.FLAGS
# NOTE(review): `_parse_flags()` and `__flags` are private flag APIs of early
# TF 1.x; later (absl-based) releases appear not to provide them — confirm
# against the pinned TensorFlow version.
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

with open("config.yml", 'r') as ymlfile:
    # safe_load: the config is plain YAML and needs no arbitrary-object
    # construction. yaml.load() without an explicit Loader is unsafe and is
    # an error in PyYAML >= 6.
    cfg = yaml.safe_load(ymlfile)

dataset_name = cfg["datasets"]["default"]
if FLAGS.enable_word_embeddings and cfg['word_embeddings']['default'] is not None:
    # Pre-trained vectors fix the embedding width.
    embedding_name = cfg['word_embeddings']['default']
    embedding_dimension = cfg['word_embeddings'][embedding_name]['dimension']
else:
    embedding_dimension = FLAGS.embedding_dim

# Data Preparation
# ==================================================

# Load data
print("Loading data...")
if dataset_name == "mrpolarity":
    datasets = data_helpers.get_datasets_mrpolarity(cfg["datasets"][dataset_name]["positive_data_file"]["path"],
                                                    cfg["datasets"][dataset_name]["negative_data_file"]["path"])
elif dataset_name == "20newsgroup":
    datasets = data_helpers.get_datasets_20newsgroup(subset="train",
                                                     categories=cfg["datasets"][dataset_name]["categories"],
                                                     shuffle=cfg["datasets"][dataset_name]["shuffle"],
                                                     random_state=cfg["datasets"][dataset_name]["random_state"])
elif dataset_name == "localdata":
    datasets = data_helpers.get_datasets_localdata(container_path=cfg["datasets"][dataset_name]["container_path"],
                                                   categories=cfg["datasets"][dataset_name]["categories"],
                                                   shuffle=cfg["datasets"][dataset_name]["shuffle"],
                                                   random_state=cfg["datasets"][dataset_name]["random_state"])
else:
    # Fail loudly here instead of handing None to load_data_labels.
    raise ValueError("Unknown dataset '{}' in config.yml "
                     "(expected 'mrpolarity', '20newsgroup' or 'localdata')".format(dataset_name))
x_text, y = data_helpers.load_data_labels(datasets)

# Build vocabulary: every document is padded/truncated to the longest one.
max_document_length = max([len(x.split(" ")) for x in x_text])
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
x = np.array(list(vocab_processor.fit_transform(x_text)))

# Randomly shuffle data (fixed seed, so the dev split is reproducible)
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices]

# Split train/test set
# TODO: This is very crude, should use cross-validation
# The last `dev_size` shuffled examples form the dev set. Slicing at an explicit
# train size (instead of a negative index) avoids `x[:-0]`, which would yield an
# EMPTY training set whenever the dev count rounds down to 0.
dev_size = int(FLAGS.dev_sample_percentage * float(len(y)))
if FLAGS.dev_sample_percentage > 0 and dev_size == 0:
    # Keep at least one dev example so evaluation (a mean over the dev set) is defined.
    dev_size = 1
train_size = len(y) - dev_size
if train_size <= 0:
    raise ValueError("dev_sample_percentage={} leaves no training examples out of {}".format(
        FLAGS.dev_sample_percentage, len(y)))
x_train, x_dev = x_shuffled[:train_size], x_shuffled[train_size:]
y_train, y_dev = y_shuffled[:train_size], y_shuffled[train_size:]
print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


# Training
# ==================================================

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(
            sequence_length=x_train.shape[1],
            num_classes=y_train.shape[1],
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=embedding_dimension,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure.
        # The learning rate is a placeholder on the model, so the decay schedule
        # below can feed a new value on every step.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(cnn.learning_rate)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                grad_summaries.append(grad_hist_summary)
                grad_summaries.append(sparsity_summary)
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # Write vocabulary
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables
        sess.run(tf.global_variables_initializer())
        if FLAGS.enable_word_embeddings and cfg['word_embeddings']['default'] is not None:
            # Overwrite the randomly initialized embedding matrix with pre-trained vectors.
            vocabulary = vocab_processor.vocabulary_
            initW = None
            if embedding_name == 'word2vec':
                # load embedding vectors from the word2vec
                print("Load word2vec file {}".format(cfg['word_embeddings']['word2vec']['path']))
                initW = data_helpers.load_embedding_vectors_word2vec(vocabulary,
                                                                     cfg['word_embeddings']['word2vec']['path'],
                                                                     cfg['word_embeddings']['word2vec']['binary'])
                print("word2vec file has been loaded")
            elif embedding_name == 'glove':
                # load embedding vectors from the glove
                print("Load glove file {}".format(cfg['word_embeddings']['glove']['path']))
                initW = data_helpers.load_embedding_vectors_glove(vocabulary,
                                                                  cfg['word_embeddings']['glove']['path'],
                                                                  embedding_dimension)
                print("glove file has been loaded\n")
            if initW is None:
                # Without this check, cnn.W.assign(None) fails with an obscure TF error.
                raise ValueError("No embedding vectors loaded for '{}' "
                                 "(supported: 'word2vec', 'glove')".format(embedding_name))
            sess.run(cnn.W.assign(initW))

        def train_step(x_batch, y_batch, learning_rate):
            """
            Run one optimizer step on a batch and log its loss/accuracy.

            Dropout is active (keep probability from FLAGS.dropout_keep_prob),
            and `learning_rate` is fed into the model's learning-rate placeholder.
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                cnn.learning_rate: learning_rate
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}, learning_rate {:g}"
                  .format(time_str, step, loss, accuracy, learning_rate))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """
            Evaluate the model on a dev set without updating weights.

            Dropout is disabled (keep probability 1.0). If `writer` is given,
            the dev summaries are written at the current global step.
            NOTE(review): the whole dev set is fed as a single batch; very large
            dev sets may need batching.
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            if writer:
                writer.add_summary(summaries, step)

        # Generate batches
        batches = data_helpers.batch_iter(
            list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
        # It uses dynamic learning rate with a high value at the beginning to speed up the training
        max_learning_rate = 0.005
        min_learning_rate = 0.0001
        # e-folding time of the decay, in steps: about `decay_coefficient` epochs'
        # worth of batches (assuming batch_iter yields batch_size examples per batch).
        decay_speed = FLAGS.decay_coefficient*len(y_train)/FLAGS.batch_size
        # Training loop. For each batch...
        for counter, batch in enumerate(batches):
            learning_rate = min_learning_rate + (max_learning_rate - min_learning_rate) * math.exp(-counter/decay_speed)
            x_batch, y_batch = zip(*batch)
            train_step(x_batch, y_batch, learning_rate)
            current_step = tf.train.global_step(sess, global_step)
            # Skip evaluation when there is no dev data (the mean would be NaN).
            if current_step % FLAGS.evaluate_every == 0 and len(y_dev) > 0:
                print("\nEvaluation:")
                dev_step(x_dev, y_dev, writer=dev_summary_writer)
                print("")
            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))
# --------------------------------------------------------------------------------