├── .gitignore ├── LICENSE ├── README.md ├── data ├── kkk.cls ├── kkk.dev └── kkk.train ├── data_helpers.py ├── eval.py ├── multi_class_data_loader.py ├── text_cnn.py ├── train.py └── word_data_processor.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | runs/ 3 | 4 | # Created by https://www.gitignore.io/api/python,ipythonnotebook 5 | 6 | ### Python ### 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *,cover 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | 67 | ### IPythonNotebook ### 68 | # Temporary data 69 | .ipynb_checkpoints/ 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **[This code belongs to the "Implementing a CNN for Text Classification in Tensorflow" blog post.](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/)** 2 | 3 | It is a slightly simplified implementation of Kim's [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) paper in TensorFlow.
4 | 5 | 한글로 소개한 문서는 [합성곱 신경망(CNN) 딥러닝을 이용한 한국어 문장 분류](http://docs.likejazz.com/cnn-text-classification-tf)를 참고하시기 바랍니다 6 | 7 | ## Requirements 8 | 9 | - Python 3 10 | - TensorFlow 0.12.x (the code needs `tf.global_variables_initializer`, added in 0.12, and still uses pre-1.0 APIs such as `tf.concat(3, ...)` and `tf.scalar_summary`) 11 | - Numpy 12 | 13 | ## Training 14 | 15 | Print parameters: 16 | 17 | ```bash 18 | ./train.py --help 19 | ``` 20 | 21 | ``` 22 | optional arguments: 23 | -h, --help show this help message and exit 24 | --embedding_dim EMBEDDING_DIM 25 | Dimensionality of character embedding (default: 128) 26 | --filter_sizes FILTER_SIZES 27 | Comma-separated filter sizes (default: '3,4,5') 28 | --num_filters NUM_FILTERS 29 | Number of filters per filter size (default: 128) 30 | --l2_reg_lambda L2_REG_LAMBDA 31 | L2 regularizaion lambda (default: 0.0) 32 | --dropout_keep_prob DROPOUT_KEEP_PROB 33 | Dropout keep probability (default: 0.5) 34 | --batch_size BATCH_SIZE 35 | Batch Size (default: 64) 36 | --num_epochs NUM_EPOCHS 37 | Number of training epochs (default: 200) 38 | --evaluate_every EVALUATE_EVERY 39 | Evaluate model on dev set after this many steps 40 | (default: 100) 41 | --checkpoint_every CHECKPOINT_EVERY 42 | Save model after this many steps (default: 100) 43 | --allow_soft_placement ALLOW_SOFT_PLACEMENT 44 | Allow device soft device placement 45 | --noallow_soft_placement 46 | --log_device_placement LOG_DEVICE_PLACEMENT 47 | Log placement of ops on devices 48 | --nolog_device_placement 49 | 50 | ``` 51 | 52 | Train: 53 | 54 | ```bash 55 | ./train.py 56 | ``` 57 | 58 | ## Evaluating 59 | 60 | ```bash 61 | ./eval.py 62 | ``` 63 | 64 | ## References 65 | 66 | - [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) 67 | - [A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1510.03820) -------------------------------------------------------------------------------- /data/kkk.cls: -------------------------------------------------------------------------------- 1 |
politics 2 | entertain 3 | etc -------------------------------------------------------------------------------- /data/kkk.dev: -------------------------------------------------------------------------------- 1 | 문재인 박지원 안희정 이유,politics 2 | 오늘 초아 태연,entertain 3 | 반기문 일본 세자,politics 4 | 러블리즈 지수,entertain 5 | 기타 환불 카드사 채무,etc 6 | 안희정 문재인 과거,politics 7 | 트위터 서울대 수석,etc 8 | 40 직장 남성 고민 상담,etc 9 | 오늘 첫끼 떡볶이,etc 10 | 눈 이유,etc -------------------------------------------------------------------------------- /data/kkk.train: -------------------------------------------------------------------------------- 1 | 도깨비 사랑 물리학 김인육,etc 2 | 종합 검진 대장 용종 제거 수술,etc 3 | 라이트룸 서적 중,etc 4 | 탐론 렌즈 조합 가능,etc 5 | 준중형 크루즈 독일제,etc 6 | 박 법적 수사 심판 준비,etc 7 | 치과 치료 간만 금액 후덜덜,etc 8 | 사진사 문제 가방 관련 질문,etc 9 | 카메라 의식 때 새누리 이정현,politics 10 | 질문 괌 여행 렌즈 채택 부탁,etc 11 | 고려 말 왜구 규모,etc 12 | 문재인 이재명 비교 글 금지,politics 13 | 얘 사진 곳 가면,etc 14 | 메모리 카드,etc 15 | 장수 사진 부탁,etc 16 | 하연 생각,etc 17 | 요즘 사진,etc 18 | 후방 논쟁 장면,etc 19 | 2016 발롱 발롱도르 도르 축구계 1 형 전성기,etc 20 | 현재 바르셀로나 팬 카페 상황,etc 21 | 아제로스 합법 무기 소 서리,etc 22 | 메이킹 육교 저승,etc 23 | 충격 촬영 장면,etc 24 | 초아,entertain 25 | 조국 교수 트위터,politics 26 | 검찰 요약 청와대 권력,politics 27 | 오늘 태연,entertain 28 | 클로저스 애니 수준,etc 29 | 일본 유심,etc 30 | 휴면 계좌 검색 차마,etc 31 | 예전 조개 무한 리필 집 요즘 이유,etc 32 | 일본 세자 부부 반기문,politics 33 | 트와이스 내부 불화,entertain 34 | 프로토 오늘 밤 모임 관계 베팅 좀,etc 35 | 한국 만화 일본 범람 결과 글,etc 36 | 멍멍이 멍멍이,etc 37 | 정신,etc 38 | 호불호,etc 39 | 솔지 담배 탁재훈,entertain 40 | 전설 답변,etc 41 | 대학생 현상황,etc 42 | 외출,etc 43 | 아프리카 사람 한국 처음,etc 44 | 무한 무한도전 도전 때 무도,etc 45 | 펌 엄마,etc 46 | 휴면 계좌 자기 돈 5만 이상,etc 47 | 아사히 맥주 필스너 필스너우르켈 우르켈 인수,etc 48 | 약 아이,etc 49 | 유럽 화장실,etc 50 | 여성 유저 고난,etc 51 | 연세 연세대학교 대학교 나무 나무인간 인간 카메라,etc 52 | 종편 자칭 전문가 문재인 평가,politics 53 | 일본 동물원,etc 54 | 연예 러블리즈 지수,entertain 55 | 이슈 울산 군 부대 폭발 추정 사고 현역 2,etc 56 | 이황 선생,etc 57 | 선인 상가 55 픽업 후기,etc 58 | 집값 2 7천만,etc 59 | 기타리스트 최고 간지,etc 60 | 내년 연봉 협상 8 인상,etc 61 | 셔틀콕 컨트롤 달인,etc 62 | 컬투쇼 공효진 사면초가 사면 초과,entertain 63 | 기타 사람 개,etc 64 | 누가 
더 불호,etc 65 | 갓 독일 경제 위엄,etc 66 | 이슈 누가,etc 67 | 차 똥 냄새,etc 68 | 공중파 집밥 레전드,etc 69 | 국민,etc 70 | 요즘 아무,etc 71 | 일본 미녀 한국 한국 남자 1 데이,etc 72 | 제품 의자,etc 73 | 세관 추가 직원 실수,etc 74 | 이번 빅뱅 의외 복병,entertain 75 | 휴직 고려중,etc 76 | 토익 점수,etc 77 | 파주 우동 집 운영 질문,etc 78 | 드래곤 드래곤볼 볼 야무치 만화 다음 에피소드,etc 79 | 인도 카스트 제도 수준,etc 80 | 박지원 문재인 이유,politics 81 | 명절날 가장 안전 건프라,etc 82 | 편의점 인기 라면 10,etc 83 | 모바일 리니지 2 실제 게임 플레이 영상,etc 84 | 전세 계약 종료 후,etc 85 | 이후 유용 기능,etc 86 | 공대식 유우머,etc 87 | 비행기 여성,etc 88 | 아버지 고백,etc 89 | 닭 살처분,etc 90 | 슬라임 카레,etc 91 | 4만 후 저녁 버라이어티,etc 92 | 스압 남편 요정 인형,etc 93 | 집 앞 야옹이,etc 94 | 알 유,etc 95 | 아마존 개 시리즈,etc 96 | 초보 펀드 투자,etc 97 | 계층 아이폰,etc 98 | 여사 오늘 건,etc 99 | 당시 한국 국방 발전,etc 100 | 문재인 안희정,politics -------------------------------------------------------------------------------- /data_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 4 | """ 5 | Generates a batch iterator for a dataset. 6 | """ 7 | data = np.array(data) 8 | data_size = len(data) 9 | num_batches_per_epoch = int(len(data)/batch_size) + 1 10 | for epoch in range(num_epochs): 11 | # Shuffle the data at each epoch 12 | if shuffle: 13 | shuffle_indices = np.random.permutation(np.arange(data_size)) 14 | shuffled_data = data[shuffle_indices] 15 | else: 16 | shuffled_data = data 17 | for batch_num in range(num_batches_per_epoch): 18 | start_index = batch_num * batch_size 19 | end_index = min((batch_num + 1) * batch_size, data_size) 20 | yield shuffled_data[start_index:end_index] 21 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import json 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import os 8 | import data_helpers 9 | from multi_class_data_loader import MultiClassDataLoader 10 | from word_data_processor import WordDataProcessor 11 | import csv 12 | 13 | # Parameters 14 | # ================================================== 15 | 16 | # Eval Parameters 17 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)") 18 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 19 | tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data") 20 | 21 | # Misc Parameters 22 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 23 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 24 | 25 | data_loader = MultiClassDataLoader(tf.flags, WordDataProcessor()) 26 | data_loader.define_flags() 27 | 28 | FLAGS = tf.flags.FLAGS 29 | FLAGS._parse_flags() 30 | print("\nParameters:") 31 | for attr, value in sorted(FLAGS.__flags.items()): 32 | print("{}={}".format(attr.upper(), value)) 33 | print("") 34 | 35 | if FLAGS.eval_train: 36 | x_raw, y_test = data_loader.load_data_and_labels() 37 | y_test = np.argmax(y_test, axis=1) 38 | else: 39 | x_raw, y_test = data_loader.load_dev_data_and_labels() 40 | y_test = np.argmax(y_test, axis=1) 41 | 42 | # If --checkpoint_dir was not given, use the most recently modified run directory under ./runs/ 43 | if FLAGS.checkpoint_dir == "": 44 | all_subdirs = ["./runs/" + d for d in os.listdir('./runs/.') if os.path.isdir("./runs/" + d)] 45 | latest_subdir = max(all_subdirs, key=os.path.getmtime) 46 | FLAGS.checkpoint_dir = latest_subdir + "/checkpoints/" 47 | 48 | # Map data into vocabulary using the processor saved in the run directory (the parent of checkpoint_dir) 49 | vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab") 50 | vocab_processor = data_loader.restore_vocab_processor(vocab_path) 51 | x_test = np.array(list(vocab_processor.transform(x_raw))) 52 | 53 | print("\nEvaluating...\n")
54 | 55 | # Evaluation 56 | # ================================================== 57 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 58 | graph = tf.Graph() 59 | with graph.as_default(): 60 | session_conf = tf.ConfigProto( 61 | allow_soft_placement=FLAGS.allow_soft_placement, 62 | log_device_placement=FLAGS.log_device_placement) 63 | sess = tf.Session(config=session_conf) 64 | with sess.as_default(): 65 | # Load the saved meta graph and restore variables 66 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 67 | saver.restore(sess, checkpoint_file) 68 | 69 | # Get the placeholders from the graph by name 70 | input_x = graph.get_operation_by_name("input_x").outputs[0] 71 | # input_y = graph.get_operation_by_name("input_y").outputs[0] 72 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 73 | 74 | # Tensors we want to evaluate 75 | predictions = graph.get_operation_by_name("output/predictions").outputs[0] 76 | 77 | # Generate batches for one epoch (no shuffling, so predictions stay aligned with x_raw/y_test) 78 | batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False) 79 | 80 | # Collect the predictions here 81 | all_predictions = [] 82 | 83 | for x_test_batch in batches: 84 | batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0}) 85 | all_predictions = np.concatenate([all_predictions, batch_predictions]) 86 | 87 | # Print accuracy (y_test is always loaded above, so this check is always true) 88 | if y_test is not None: 89 | correct_predictions = float(sum(all_predictions == y_test)) 90 | print("Total number of test examples: {}".format(len(y_test))) 91 | print("Accuracy: {:g}".format(correct_predictions/float(len(y_test)))) 92 | 93 | # Save "text,predicted label" rows to a csv; with the default ./runs/<run>/checkpoints/ layout the path below resolves to ./prediction.csv 94 | class_predictions = data_loader.class_labels(all_predictions.astype(int)) 95 | predictions_human_readable = np.column_stack((np.array(x_raw), class_predictions)) 96 | out_path = os.path.join(FLAGS.checkpoint_dir, "../../../", "prediction.csv") 97 | print("Saving
evaluation to {0}".format(out_path)) 98 | with open(out_path, 'w') as f: 99 | csv.writer(f).writerows(predictions_human_readable) -------------------------------------------------------------------------------- /multi_class_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | 4 | class MultiClassDataLoader(object): 5 | """ 6 | Handles multi-class training data. It takes predefined sets of "train_data_file" and "dev_data_file" 7 | of the following record format. 8 | \t 9 | ex. "what a masterpiece! Positive" 10 | 11 | Class labels are given as "class_data_file", which is a list of class labels. 12 | """ 13 | def __init__(self, flags, data_processor): 14 | self.__flags = flags 15 | self.__data_processor = data_processor 16 | self.__train_data_file = None 17 | self.__dev_data_file = None 18 | self.__class_data_file = None 19 | self.__classes_cache = None 20 | 21 | 22 | def define_flags(self): 23 | self.__flags.DEFINE_string("train_data_file", "./data/kkk.train", "Data source for the training data.") 24 | self.__flags.DEFINE_string("dev_data_file", "./data/kkk.dev", "Data source for the cross validation data.") 25 | self.__flags.DEFINE_string("class_data_file", "./data/kkk.cls", "Data source for the class list.") 26 | 27 | def prepare_data(self): 28 | self.__resolve_params() 29 | x_train, y_train = self.__load_data_and_labels(self.__train_data_file) 30 | x_dev, y_dev = self.__load_data_and_labels(self.__dev_data_file) 31 | 32 | max_doc_len = max([len(doc.decode("utf-8")) for doc in x_train]) 33 | max_doc_len_dev = max([len(doc.decode("utf-8")) for doc in x_dev]) 34 | if max_doc_len_dev > max_doc_len: 35 | max_doc_len = max_doc_len_dev 36 | # Build vocabulary 37 | self.vocab_processor = self.__data_processor.vocab_processor(x_train, x_dev) 38 | x_train = np.array(list(self.vocab_processor.fit_transform(x_train))) 39 | # Build vocabulary 40 | x_dev = 
np.array(list(self.vocab_processor.fit_transform(x_dev))) 41 | return [x_train, y_train, x_dev, y_dev] 42 | 43 | def restore_vocab_processor(self, vocab_path): 44 | return self.__data_processor.restore_vocab_processor(vocab_path) 45 | 46 | def class_labels(self, class_indexes): 47 | return [ self.__classes()[idx] for idx in class_indexes ] 48 | 49 | def class_count(self): 50 | return self.__classes().__len__() 51 | 52 | def load_dev_data_and_labels(self): 53 | self.__resolve_params() 54 | x_dev, y_dev = self.__load_data_and_labels(self.__dev_data_file) 55 | return [x_dev, y_dev] 56 | 57 | def load_data_and_labels(self): 58 | self.__resolve_params() 59 | x_train, y_train = self.__load_data_and_labels(self.__train_data_file) 60 | x_dev, y_dev = self.__load_data_and_labels(self.__dev_data_file) 61 | x_all = x_train + x_dev 62 | y_all = np.concatenate([y_train, y_dev], 0) 63 | return [x_all, y_all] 64 | 65 | def __load_data_and_labels(self, data_file): 66 | x_text = [] 67 | y = [] 68 | with open(data_file, 'r') as tsvin: 69 | classes = self.__classes() 70 | one_hot_vectors = np.eye(len(classes), dtype=int) 71 | class_vectors = {} 72 | for i, cls in enumerate(classes): 73 | class_vectors[cls] = one_hot_vectors[i] 74 | tsvin = csv.reader(tsvin, delimiter=',') 75 | for row in tsvin: 76 | data = self.__data_processor.clean_data(row[0]) 77 | x_text.append(data) 78 | y.append(class_vectors[row[1]]) 79 | return [x_text, np.array(y)] 80 | 81 | def __classes(self): 82 | self.__resolve_params() 83 | if self.__classes_cache is None: 84 | with open(self.__class_data_file, 'r') as catin: 85 | classes = list(catin.readlines()) 86 | self.__classes_cache = [s.strip() for s in classes] 87 | return self.__classes_cache 88 | 89 | def __resolve_params(self): 90 | if self.__class_data_file is None: 91 | self.__train_data_file = self.__flags.FLAGS.train_data_file 92 | self.__dev_data_file = self.__flags.FLAGS.dev_data_file 93 | self.__class_data_file = self.__flags.FLAGS.class_data_file 
-------------------------------------------------------------------------------- /text_cnn.py: --------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


class TextCNN(object):
    """
    A CNN for text classification.
    Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.

    Constructing an instance adds the graph to the current default tf.Graph.

    Args:
        sequence_length: Length of every (padded) input sequence; fixes the
            second dimension of ``input_x``. The VALID conv/pool windows below
            need sequence_length >= max(filter_sizes) -- assumed to be
            guaranteed by the caller.
        num_classes: Number of output classes (width of ``input_y`` and scores).
        vocab_size: Number of rows of the embedding matrix; ids in ``input_x``
            index into it.
        embedding_size: Width of each embedding vector.
        filter_sizes: Sequence of convolution window heights (in tokens); one
            conv + max-pool branch is built per size.
        num_filters: Number of filters per branch.
        l2_reg_lambda: Weight of the L2 penalty added to the loss; only the
            output layer's W and b are penalised.

    Ops fetched by name in eval.py: "input_x", "dropout_keep_prob",
    "output/predictions".
    """
    def __init__(
            self, sequence_length, num_classes, vocab_size,
            embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):

        # Placeholders for input, output and dropout
        # input_x: token ids, [batch, sequence_length].
        # input_y: label distributions, [batch, num_classes] (one-hot rows from the data loader).
        self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        # Scalar keep probability: train.py feeds FLAGS.dropout_keep_prob, evaluation feeds 1.0.
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # Keeping track of l2 regularization loss (optional)
        l2_loss = tf.constant(0.0)

        # Embedding layer (pinned to the CPU; initialised uniformly in [-1, 1])
        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            # [batch, sequence_length, embedding_size]
            self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
            # Add a trailing channel axis for conv2d: [batch, sequence_length, embedding_size, 1]
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer
                # Each filter spans filter_size tokens and the full embedding width.
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                # VALID padding -> [batch, sequence_length - filter_size + 1, 1, num_filters]
                conv = tf.nn.conv2d(
                    self.embedded_chars_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Maxpooling over the outputs: the window covers the whole conv
                # output, leaving [batch, 1, 1, num_filters].
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        # Concatenate along the filter axis (3); pre-TF-1.0 argument order
        # tf.concat(concat_dim, values).
        self.h_pool = tf.concat(3, pooled_outputs)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            # tf.get_variable is not affected by tf.name_scope, so this
            # variable is named "W" at the top level, not "output/W".
            W = tf.get_variable(
                "W",
                shape=[num_filters_total, num_classes],
                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            # Op name "output/predictions" is what eval.py looks up.
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            # Positional (logits, labels) order of the pre-1.0 API.
            losses = tf.nn.softmax_cross_entropy_with_logits(self.scores, self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train the TextCNN classifier and write summaries/checkpoints under ./runs/<timestamp>.

Hyperparameters come from tf.flags, plus the data-loader flags registered by
MultiClassDataLoader.define_flags(). Summaries use the tf.summary API, which
was introduced in TensorFlow 0.12 together with tf.global_variables /
tf.global_variables_initializer (both used below). The older tf.*_summary /
tf.train.SummaryWriter functions were removed in TensorFlow 1.0.
"""

import tensorflow as tf
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from multi_class_data_loader import MultiClassDataLoader
from word_data_processor import WordDataProcessor

# Parameters
# ==================================================

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularizaion lambda (default: 0.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

# The loader registers its own flags, so it must exist before the flags are parsed.
data_loader = MultiClassDataLoader(tf.flags, WordDataProcessor())
data_loader.define_flags()

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")


# Data Preparation
# ==================================================

# Load data
print("Loading data...")
x_train, y_train, x_dev, y_dev = data_loader.prepare_data()
vocab_processor = data_loader.vocab_processor

print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


# Training
# ==================================================

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(
            sequence_length=x_train.shape[1],
            num_classes=y_train.shape[1],
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=FLAGS.embedding_dim,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                # Variable names end in ":0"; ':' is not a valid summary-tag
                # character for tf.summary, so replace it explicitly.
                tag = v.name.replace(":", "_")
                grad_summaries.append(
                    tf.summary.histogram("{}/grad/hist".format(tag), g))
                grad_summaries.append(
                    tf.summary.scalar("{}/grad/sparsity".format(tag), tf.nn.zero_fraction(g)))
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables())

        # Write vocabulary
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """Run one optimizer step on a batch, print its loss/accuracy and log summaries."""
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """Evaluate the model on a dev set with dropout disabled (keep_prob 1.0).

            The train op is not run, so the weights are not updated. If *writer*
            is given, the dev summaries are written to it.
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            if writer:
                writer.add_summary(summaries, step)

        def save_checkpoint(step):
            """Save all global variables under checkpoint_prefix and print the path."""
            path = saver.save(sess, checkpoint_prefix, global_step=step)
            print("Saved model checkpoint to {}\n".format(path))

        # Generate batches
        batches = data_helpers.batch_iter(
            list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
        current_step = tf.train.global_step(sess, global_step)
        try:
            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("\nEvaluation:")
                    dev_step(x_dev, y_dev, writer=dev_summary_writer)
                    print("")
                if current_step % FLAGS.checkpoint_every == 0:
                    save_checkpoint(current_step)
            # The loop only saves on multiples of checkpoint_every. Also keep the
            # final weights when training ended between two checkpoints.
            if current_step % FLAGS.checkpoint_every != 0:
                save_checkpoint(current_step)
        finally:
            # Close the writers so any events they still buffer reach disk,
            # including when training is interrupted.
            train_summary_writer.close()
            dev_summary_writer.close()

# --------------------------------------------------------------------------------
# /word_data_processor.py:
# --------------------------------------------------------------------------------
#!
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from tensorflow.contrib import learn


class WordDataProcessor(object):
    """Word-level text hooks: vocabulary construction/restoration and line cleaning."""

    def vocab_processor(self, *texts):
        """Build a VocabularyProcessor sized to the longest line across *texts*.

        Args:
            *texts: one or more iterables of lines (strings). A line's length is
                the number of pieces produced by ``line.split(" ")``.

        Returns:
            A new ``learn.preprocessing.VocabularyProcessor`` whose maximum
            document length is the largest line length found.

        Raises:
            ValueError: if *texts* contain no lines at all.
        """
        # Flatten over all texts so that one empty text (e.g. an empty dev
        # split) no longer makes max() fail while other texts still have lines.
        # NOTE(review): split(" ") counts pieces between single spaces, which
        # may differ from the processor's own tokenizer; kept as-is.
        lengths = [len(line.split(" ")) for text in texts for line in text]
        if not lengths:
            raise ValueError("vocab_processor() needs at least one line of text")
        return learn.preprocessing.VocabularyProcessor(max(lengths))

    def restore_vocab_processor(self, vocab_path):
        """Load a VocabularyProcessor previously saved to *vocab_path*."""
        return learn.preprocessing.VocabularyProcessor.restore(vocab_path)

    def clean_data(self, string):
        """Normalize one line of input data.

        Training uses morpheme-analyzed (DHA) output, so no real data cleaning
        is needed. The only normalization is this: a line that contains no
        ':' is stripped of surrounding whitespace and lower-cased, while a
        line that contains ':' is returned unchanged.
        """
        if ":" in string:
            # NOTE(review): presumably ':' marks analyzer/tag output whose exact
            # form must be kept; confirm against the data files.
            return string
        return string.strip().lower()

# --------------------------------------------------------------------------------