├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ └── feature.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE ├── README.md ├── imgs └── 2019-04-29_TensorBoard.png ├── kobert ├── __init__.py ├── mxnet_kobert.py ├── onnx_kobert.py ├── pytorch_kobert.py └── utils │ ├── __init__.py │ ├── aws_s3_downloader.py │ └── utils.py ├── kobert_hf ├── README.md ├── kobert_tokenizer │ ├── __init__.py │ └── kobert_tokenizer.py ├── requirements.txt └── setup.py ├── logs └── bert_naver_small_512_news_simple_20190624.txt ├── requirements.txt ├── scripts └── NSMC │ ├── naver_review_classifications_gluon_kobert.ipynb │ └── naver_review_classifications_pytorch_kobert.ipynb └── setup.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: 버그 관련 리포팅을 합니다. 4 | title: "[BUG] " 5 | labels: bug 6 | assignees: "" 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | ## To Reproduce 13 | 14 | 15 | 버그를 재현하기 위한 재현절차를 작성해주세요. 16 | 17 | 1. - 18 | 2. - 19 | 3. - 20 | 21 | ## Expected behavior 22 | 23 | 24 | ## Environment 25 | 26 | 27 | ## Additional context 28 | 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature 3 | about: 개발할 기능에 대해 서술합니다. 4 | title: "[FEATURE] " 5 | labels: enhancement 6 | assignees: "" 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | ## Motivation 13 | 14 | 15 | ## Pitch 16 | 17 | 18 | ## Additional context 19 | 20 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Pull Request 2 | 레파지토리에 기여해주셔서 감사드립니다. 3 | 4 | 해당 PR을 제출하기 전에 아래 사항이 완료되었는지 확인 부탁드립니다: 5 | - [ ] 작성한 코드가 어떤 에러나 경고없이 빌드가 되었나요? 6 | - [ ] 충분한 테스트를 수행하셨나요? 7 | 8 | ## 1. 해당 PR은 어떤 내용인가요? 9 | 10 | 11 | ## 2. PR과 관련된 이슈가 있나요? 12 | 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .cache 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright (c) 2019 SK T-Brain 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KoBERT
2 | 
3 | * [KoBERT](#kobert)
4 |   * [Korean BERT pre-trained cased (KoBERT)](#korean-bert-pre-trained-cased-kobert)
5 |     * [Why'?'](#why)
6 |     * [Training Environment](#training-environment)
7 |     * [Requirements](#requirements)
8 |     * [How to install](#how-to-install)
9 |   * [How to use](#how-to-use)
10 |     * [Using with PyTorch](#using-with-pytorch)
11 |     * [Using with ONNX](#using-with-onnx)
12 |     * [Using with MXNet-Gluon](#using-with-mxnet-gluon)
13 |     * [Tokenizer](#tokenizer)
14 |   * [Subtasks](#subtasks)
15 |     * [Naver Sentiment Analysis](#naver-sentiment-analysis)
16 |     * [KoBERT와 CRF로 만든 한국어 객체명인식기](#kobert와-crf로-만든-한국어-객체명인식기)
17 |     * [Korean Sentence BERT](#korean-sentence-bert)
18 |   * [Release](#release)
19 |   * [Contacts](#contacts)
20 |   * [License](#license)
21 | 
22 | ---
23 | 
24 | ## Korean BERT pre-trained cased (KoBERT)
25 | 
26 | ### Why'?'
27 | 
28 | * The limited Korean performance of Google's [BERT base multilingual cased](https://github.com/google-research/bert/blob/master/multilingual.md)
29 | 
30 | ### Training Environment
31 | 
32 | * Architecture
33 | 
34 | ```python
35 | predefined_args = {
36 |     'attention_cell': 'multi_head',
37 |     'num_layers': 12,
38 |     'units': 768,
39 |     'hidden_size': 3072,
40 |     'max_length': 512,
41 |     'num_heads': 12,
42 |     'scaled': True,
43 |     'dropout': 0.1,
44 |     'use_residual': True,
45 |     'embed_size': 768,
46 |     'embed_dropout': 0.1,
47 |     'token_type_vocab_size': 2,
48 |     'word_embed': None,
49 | }
50 | ```
51 | 
52 | * Training data
53 | 
54 | | Data | Sentences | Words |
55 | | ----------- | ---- | ---- |
56 | | Korean Wikipedia | 5M | 54M |
57 | 
58 | * Training setup
59 |   * V100 GPU x 32, Horovod (with InfiniBand)
60 | 
61 | ![2019-04-29 TensorBoard log](imgs/2019-04-29_TensorBoard.png)
62 | 
63 | * Vocabulary
64 |   * Size: 8,002
65 |   * SentencePiece tokenizer trained on Korean Wikipedia
66 | * Fewer parameters (92M < 110M)
67 | 
68 | ### Requirements
69 | 
70 | * see [requirements.txt](https://github.com/SKTBrain/KoBERT/blob/master/requirements.txt)
71 | 
72 | ### How to install
73 | 
74 | * Install KoBERT as a python package
75 | 
76 | ```sh
77 | pip install git+https://git@github.com/SKTBrain/KoBERT.git@master
78 | ```
79 | 
80 | * If you want to modify the source code, please clone this repository
81 | 
82 | ```sh
83 | git clone https://github.com/SKTBrain/KoBERT.git
84 | cd KoBERT
85 | pip install -r requirements.txt
86 | ```
87 | 
88 | ---
89 | 
90 | ## How to use
91 | 
92 | ### Using with PyTorch
93 | 
94 | *If you prefer the Huggingface transformers API, see [here](kobert_hf).*
95 | 
96 | ```python
97 | >>> import torch
98 | >>> from kobert import get_pytorch_kobert_model
99 | >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
100 | >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
101 | >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
102 | >>> model, vocab = get_pytorch_kobert_model()
103 | >>> sequence_output, pooled_output = model(input_ids, input_mask, token_type_ids)
104 | >>> pooled_output.shape
105 | torch.Size([2, 768])
106 | >>> vocab
107 | Vocab(size=8002, unk="[UNK]", reserved="['[MASK]', '[SEP]', '[CLS]']")
108 | >>> # Last Encoding Layer
109 | >>> sequence_output[0]
110 | tensor([[-0.2461, 0.2428, 0.2590, ..., -0.4861, -0.0731, 0.0756],
111 | [-0.2478, 0.2420, 0.2552, ..., -0.4877, -0.0727, 0.0754],
112 | [-0.2472, 0.2420, 0.2561, ..., -0.4874, -0.0733, 0.0765]],
113 | grad_fn=)
114 | ```
115 | 
116 | `model` is returned in `eval()` mode by default, so switch it to training mode with `model.train()` before using it for fine-tuning.
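As a rough sketch of what that looks like in practice, the snippet below attaches a linear classification head to the pooled output and switches the encoder to training mode. The `KoBERTClassifier` wrapper, its hyperparameters, and the dummy inputs are illustrative only; the NSMC notebook linked below is the actual reference implementation.

```python
import torch
from torch import nn
from kobert import get_pytorch_kobert_model


class KoBERTClassifier(nn.Module):
    """Illustrative wrapper: KoBERT encoder plus a linear head on the pooled [CLS] output."""

    def __init__(self, bert, num_classes=2, dr_rate=0.1):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(dr_rate)
        self.classifier = nn.Linear(768, num_classes)  # KoBERT hidden size is 768

    def forward(self, input_ids, attention_mask, token_type_ids):
        _, pooled_output = self.bert(input_ids, attention_mask, token_type_ids)
        return self.classifier(self.dropout(pooled_output))


bert, vocab = get_pytorch_kobert_model()
model = KoBERTClassifier(bert)
model.train()  # the encoder comes back in eval() mode, so switch explicitly
logits = model(torch.LongTensor([[31, 51, 99]]),
               torch.LongTensor([[1, 1, 1]]),
               torch.LongTensor([[0, 0, 0]]))  # shape: [1, num_classes]
```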
117 | 
118 | * Naver Sentiment Analysis Fine-Tuning with pytorch
119 |   * In Colab, we recommend using a GPU: [Runtime] - [Change runtime type] - Hardware accelerator (GPU).
120 |   * [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb)
121 | 
122 | ### Using with ONNX
123 | 
124 | ```python
125 | >>> import onnxruntime
126 | >>> import numpy as np
127 | >>> from kobert import get_onnx_kobert_model
128 | >>> onnx_path = get_onnx_kobert_model()
129 | >>> sess = onnxruntime.InferenceSession(onnx_path)
130 | >>> input_ids = [[31, 51, 99], [15, 5, 0]]
131 | >>> input_mask = [[1, 1, 1], [1, 1, 0]]
132 | >>> token_type_ids = [[0, 0, 1], [0, 1, 0]]
133 | >>> len_seq = len(input_ids[0])
134 | >>> pred_onnx = sess.run(None, {'input_ids':np.array(input_ids),
135 | >>>                             'token_type_ids':np.array(token_type_ids),
136 | >>>                             'input_mask':np.array(input_mask),
137 | >>>                             'position_ids':np.array(range(len_seq))})
138 | >>> # Last Encoding Layer
139 | >>> pred_onnx[-2][0]
140 | array([[-0.24610452, 0.24282141, 0.25895312, ..., -0.48613444,
141 | -0.07305173, 0.07560554],
142 | [-0.24783179, 0.24200465, 0.25520486, ..., -0.4877185 ,
143 | -0.0727044 , 0.07536091],
144 | [-0.24721591, 0.24196623, 0.2560626 , ..., -0.48743123,
145 | -0.07326943, 0.07650235]], dtype=float32)
146 | ```
147 | 
148 | _Thanks to [soeque1](https://github.com/soeque1) for helping with the ONNX conversion._
149 | 
150 | ### Using with MXNet-Gluon
151 | 
152 | ```python
153 | >>> import mxnet as mx
154 | >>> from kobert import get_mxnet_kobert_model
155 | >>> input_id = mx.nd.array([[31, 51, 99], [15, 5, 0]])
156 | >>> input_mask = mx.nd.array([[1, 1, 1], [1, 1, 0]])
157 | >>> token_type_ids = mx.nd.array([[0, 0, 1], [0, 1, 0]])
158 | >>> model, vocab = get_mxnet_kobert_model(use_decoder=False, use_classifier=False)
159 | >>> encoder_layer, pooled_output = model(input_id, token_type_ids)
160 | >>> pooled_output.shape
161 | (2, 768)
162 | >>> vocab
163 | Vocab(size=8002, unk="[UNK]", reserved="['[MASK]', '[SEP]', '[CLS]']")
164 | >>> # Last Encoding Layer
165 | >>> encoder_layer[0]
166 | [[-0.24610372 0.24282135 0.2589539 ... -0.48613444 -0.07305248
167 | 0.07560539]
168 | [-0.24783105 0.242005 0.25520545 ... -0.48771808 -0.07270523
169 | 0.07536077]
170 | [-0.24721491 0.241966 0.25606337 ... -0.48743105 -0.07327032
171 | 0.07650219]]
172 | 
173 | ```
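The PyTorch, ONNX, and MXNet-Gluon snippets above should produce near-identical last-layer encodings for the same inputs (the printed arrays already agree to about four decimal places). A quick sanity check, assuming all three snippets have been run in the same session with the variable names used above:

```python
import numpy as np

# First sequence of the batch, as printed in the three examples above.
torch_enc = sequence_output[0].detach().numpy()   # PyTorch
onnx_enc = pred_onnx[-2][0]                       # ONNX Runtime
mxnet_enc = encoder_layer[0].asnumpy()            # MXNet-Gluon

print(np.allclose(torch_enc, onnx_enc, atol=1e-3))   # should print True
print(np.allclose(torch_enc, mxnet_enc, atol=1e-3))  # should print True
```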
174 | 
175 | * Naver Sentiment Analysis Fine-Tuning with MXNet
176 |   * [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_gluon_kobert.ipynb)
177 | 
178 | ### Tokenizer
179 | 
180 | * Pretrained [Sentencepiece](https://github.com/google/sentencepiece) tokenizer
181 | 
182 | ```python
183 | >>> from gluonnlp.data import SentencepieceTokenizer
184 | >>> from kobert import get_tokenizer
185 | >>> tok_path = get_tokenizer()
186 | >>> sp = SentencepieceTokenizer(tok_path)
187 | >>> sp('한국어 모델을 공유합니다.')
188 | ['▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.']
189 | ```
190 | 
191 | ---
192 | 
193 | ## Subtasks
194 | 
195 | ### Naver Sentiment Analysis
196 | 
197 | * Dataset :
198 | 
199 | | Model | Accuracy |
200 | | --------------------------------------------------------------------------------------------------- | --------------------------------------------------------------- |
201 | | [BERT base multilingual cased](https://github.com/google-research/bert/blob/master/multilingual.md) | 0.875 |
202 | | KoBERT | **[0.901](logs/bert_naver_small_512_news_simple_20190624.txt)** |
203 | | [KoGPT2](https://github.com/SKT-AI/KoGPT2) | 0.899 |
204 | 
205 | ### KoBERT와 CRF로 만든 한국어 객체명인식기
206 | 
207 | *
208 | 
209 | ```text
210 | 문장을 입력하세요: SKTBrain에서 KoBERT 모델을 공개해준 덕분에 BERT-CRF 기반 객체명인식기를 쉽게 개발할 수 있었다.
211 | len: 40, input_token:['[CLS]', '▁SK', 'T', 'B', 'ra', 'in', '에서', '▁K', 'o', 'B', 'ER', 'T', '▁모델', '을', '▁공개', '해', '준', '▁덕분에', '▁B', 'ER', 'T', '-', 'C', 'R', 'F', '▁기반', '▁', '객', '체', '명', '인', '식', '기를', '▁쉽게', '▁개발', '할', '▁수', '▁있었다', '.', '[SEP]']
212 | len: 40, pred_ner_tag:['[CLS]', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'B-POH', 'I-POH', 'I-POH', 'I-POH', 'I-POH', 'O', 'O', 'O', 'O', 'O', 'O', 'B-POH', 'I-POH', 'I-POH', 'I-POH', 'I-POH', 'I-POH', 'I-POH', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[SEP]']
213 | decoding_ner_sentence: [CLS] 에서 모델을 공개해준 덕분에 기반 객체명인식기를 쉽게 개발할 수 있었다.[SEP]
214 | ```
215 | 
216 | ### Korean Sentence BERT
217 | 
218 | *
219 | 
220 | |Model|Cosine Pearson|Cosine Spearman|Euclidean Pearson|Euclidean Spearman|Manhattan Pearson|Manhattan Spearman|Dot Pearson|Dot Spearman|
221 | |:------------------------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
222 | |NLI|65.05|68.48|68.81|68.18|68.90|68.20|65.22|66.81|
223 | |STS|**80.42**|**79.64**|**77.93**|77.43|**77.92**|77.44|**76.56**|**75.83**|
224 | |STS + NLI|78.81|78.47|77.68|**77.78**|77.71|**77.83**|75.75|75.22|
225 | ---
226 | 
227 | ## Release
228 | 
229 | * v0.2.3
230 |   * support `onnx 1.8.0`
231 | * v0.2.2
232 |   * fix `No module named 'kobert.utils'`
233 | * v0.2.1
234 |   * guide default 'import statements'
235 | * v0.2
236 |   * download large files from `aws s3`
237 |   * rename functions
238 | * v0.1.2
239 |   * Guaranteed compatibility with higher versions of transformers
240 |   * fix pad token index id
241 | * v0.1.1
242 |   * integrated the vocabulary and the tokenizer
243 | * v0.1
244 |   * initial model release
245 | 
246 | ## Contacts
247 | 
248 | Please file `KoBERT`-related issues [here](https://github.com/SKTBrain/KoBERT/issues).
249 | 
250 | ## License
251 | 
252 | `KoBERT` is released under the `Apache-2.0` license. Please comply with the license terms when using the model and the code. The full license text is available in the `LICENSE` file.
253 | -------------------------------------------------------------------------------- /imgs/2019-04-29_TensorBoard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SKTBrain/KoBERT/5c46b1c68e4755b54879431bd302db621f4d2f47/imgs/2019-04-29_TensorBoard.png -------------------------------------------------------------------------------- /kobert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 SK T-Brain Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from kobert.utils.utils import download, get_tokenizer 17 | from kobert.pytorch_kobert import get_pytorch_kobert_model 18 | from kobert.mxnet_kobert import get_mxnet_kobert_model 19 | from kobert.onnx_kobert import get_onnx_kobert_model 20 | 21 | __all__ = ("download", "get_tokenizer", "get_pytorch_kobert_model" ,"get_mxnet_kobert_model", "get_onnx_kobert_model") 22 | -------------------------------------------------------------------------------- /kobert/mxnet_kobert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 SK T-Brain Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import gluonnlp as nlp 17 | import mxnet as mx 18 | from gluonnlp.model import BERTEncoder, BERTModel 19 | 20 | from kobert import download, get_tokenizer 21 | 22 | 23 | def get_mxnet_kobert_model( 24 | use_pooler=True, 25 | use_decoder=True, 26 | use_classifier=True, 27 | ctx=mx.cpu(0), 28 | cachedir=".cache", 29 | ): 30 | def get_kobert_model( 31 | model_file, 32 | vocab_file, 33 | use_pooler=True, 34 | use_decoder=True, 35 | use_classifier=True, 36 | ctx=mx.cpu(0), 37 | ): 38 | vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece( 39 | vocab_file, padding_token="[PAD]" 40 | ) 41 | 42 | predefined_args = { 43 | "attention_cell": "multi_head", 44 | "num_layers": 12, 45 | "units": 768, 46 | "hidden_size": 3072, 47 | "max_length": 512, 48 | "num_heads": 12, 49 | "scaled": True, 50 | "dropout": 0.1, 51 | "use_residual": True, 52 | "embed_size": 768, 53 | "embed_dropout": 0.1, 54 | "token_type_vocab_size": 2, 55 | "word_embed": None, 56 | } 57 | 58 | encoder = BERTEncoder( 59 | num_layers=predefined_args["num_layers"], 60 | units=predefined_args["units"], 61 | hidden_size=predefined_args["hidden_size"], 62 | max_length=predefined_args["max_length"], 63 | num_heads=predefined_args["num_heads"], 64 | dropout=predefined_args["dropout"], 65 | output_attention=False, 66 | output_all_encodings=False, 67 | ) 68 | 69 | # BERT 70 | net = BERTModel( 71 | encoder, 72 | len(vocab_b_obj.idx_to_token), 73 | token_type_vocab_size=predefined_args["token_type_vocab_size"], 74 | units=predefined_args["units"], 75 | embed_size=predefined_args["embed_size"], 76 | word_embed=predefined_args["word_embed"], 77 | use_pooler=use_pooler, 78 | use_decoder=use_decoder, 79 | use_classifier=use_classifier, 80 | ) 81 | net.initialize(ctx=ctx) 82 | net.load_parameters(model_file, ctx, ignore_extra=True) 83 | return (net, vocab_b_obj) 84 | 85 | mxnet_kobert = { 86 | "url": "s3://skt-lsl-nlp-model/KoBERT/models/mxnet_kobert_45b6957552.params", 87 | "chksum": "45b6957552", 88 | } 89 | 90 | # download model 91 | model_info = mxnet_kobert 92 | model_path, is_cached = download( 93 | model_info["url"], model_info["chksum"], cachedir=cachedir 94 | ) 95 | # download vocab 96 | vocab_path = get_tokenizer() 97 | return get_kobert_model( 98 | model_path, vocab_path, use_pooler, use_decoder, use_classifier, ctx 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | import mxnet as mx 104 | from kobert import get_mxnet_kobert_model 105 | 106 | input_id = mx.nd.array([[31, 51, 99], [15, 5, 0]]) 107 | input_mask = mx.nd.array([[1, 1, 1], [1, 1, 0]]) 108 | token_type_ids = mx.nd.array([[0, 0, 1], [0, 1, 0]]) 109 | model, vocab = get_mxnet_kobert_model(use_decoder=False, use_classifier=False) 110 | encoder_layer, pooled_output = model(input_id, token_type_ids) 111 | print(pooled_output.shape) 112 | print(vocab) 113 | print(encoder_layer[0]) 114 | -------------------------------------------------------------------------------- /kobert/onnx_kobert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 SK T-Brain Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from kobert import download 17 | 18 | 19 | def get_onnx_kobert_model(cachedir=".cache"): 20 | """Get KoBERT ONNX file path after downloading""" 21 | onnx_kobert = { 22 | "url": "s3://skt-lsl-nlp-model/KoBERT/models/kobert.onnx1.8.0.onnx", 23 | "chksum": "6f6610f2e3b61da6de8dbce", 24 | } 25 | 26 | model_info = onnx_kobert 27 | model_path, is_cached = download( 28 | model_info["url"], model_info["chksum"], cachedir=cachedir 29 | ) 30 | return model_path 31 | 32 | 33 | def make_dummy_input(max_seq_len): 34 | def do_pad(x, max_seq_len, pad_id): 35 | return [_x + [pad_id] * (max_seq_len - len(_x)) for _x in x] 36 | 37 | input_ids = do_pad([[31, 51, 99], [15, 5]], max_seq_len, pad_id=1) 38 | token_type_ids = do_pad([[0, 0, 0], [0, 0]], max_seq_len, pad_id=0) 39 | input_mask = do_pad([[1, 1, 1], [1, 1]], max_seq_len, pad_id=0) 40 | position_ids = list(range(max_seq_len)) 41 | return (input_ids, token_type_ids, input_mask, position_ids) 42 | 43 | 44 | if __name__ == "__main__": 45 | import onnxruntime 46 | import numpy as np 47 | from kobert import get_onnx_kobert_model 48 | 49 | onnx_path = get_onnx_kobert_model() 50 | dummy_input = make_dummy_input(max_seq_len=512) 51 | so = onnxruntime.SessionOptions() 52 | sess = onnxruntime.InferenceSession(onnx_path) 53 | outputs = sess.run( 54 | None, 55 | { 56 | "input_ids": np.array(dummy_input[0]), 57 | "token_type_ids": np.array(dummy_input[1]), 58 | "input_mask": np.array(dummy_input[2]), 59 | "position_ids": np.array(dummy_input[3]), 60 | }, 61 | ) 62 | print(outputs[-2][0]) 63 | -------------------------------------------------------------------------------- /kobert/pytorch_kobert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 SK T-Brain Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | from zipfile import ZipFile 18 | import torch 19 | from transformers import BertModel 20 | import gluonnlp as nlp 21 | 22 | from kobert import download, get_tokenizer 23 | 24 | 25 | def get_pytorch_kobert_model(ctx="cpu", cachedir=".cache"): 26 | def get_kobert_model(model_path, vocab_file, ctx="cpu"): 27 | bertmodel = BertModel.from_pretrained(model_path, return_dict=False) 28 | device = torch.device(ctx) 29 | bertmodel.to(device) 30 | bertmodel.eval() 31 | vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece( 32 | vocab_file, padding_token="[PAD]" 33 | ) 34 | return bertmodel, vocab_b_obj 35 | 36 | pytorch_kobert = { 37 | "url": "s3://skt-lsl-nlp-model/KoBERT/models/kobert_v1.zip", 38 | "chksum": "411b242919", # 411b2429199bc04558576acdcac6d498 39 | } 40 | 41 | # download model 42 | model_info = pytorch_kobert 43 | model_path, is_cached = download( 44 | model_info["url"], model_info["chksum"], cachedir=cachedir 45 | ) 46 | cachedir_full = os.path.expanduser(cachedir) 47 | zipf = ZipFile(os.path.expanduser(model_path)) 48 | zipf.extractall(path=cachedir_full) 49 | model_path = os.path.join(os.path.expanduser(cachedir), "kobert_from_pretrained") 50 | # download vocab 51 | vocab_path = get_tokenizer() 52 | return get_kobert_model(model_path, vocab_path, ctx) 53 | 54 | 55 | if __name__ == "__main__": 56 | import torch 57 | from kobert import get_pytorch_kobert_model 58 | 59 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 60 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 61 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 62 | model, vocab = get_pytorch_kobert_model() 63 | sequence_output, pooled_output = model(input_ids, input_mask, token_type_ids) 64 | print(pooled_output.shape) 65 | print(vocab) 66 | print(sequence_output[0]) 67 | -------------------------------------------------------------------------------- /kobert/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from kobert.utils.utils import download, get_tokenizer 2 | -------------------------------------------------------------------------------- /kobert/utils/aws_s3_downloader.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import os 3 | import sys 4 | from botocore import UNSIGNED 5 | from botocore.client import Config 6 | 7 | 8 | class AwsS3Downloader(object): 9 | def __init__( 10 | self, 11 | aws_access_key_id=None, 12 | aws_secret_access_key=None, 13 | ): 14 | self.resource = boto3.Session( 15 | aws_access_key_id=aws_access_key_id, 16 | aws_secret_access_key=aws_secret_access_key, 17 | ).resource("s3") 18 | self.client = boto3.client( 19 | "s3", 20 | aws_access_key_id=aws_access_key_id, 21 | aws_secret_access_key=aws_secret_access_key, 22 | config=Config(signature_version=UNSIGNED), 23 | ) 24 | 25 | def __split_url(self, url: str): 26 | if url.startswith("s3://"): 27 | url = url.replace("s3://", "") 28 | bucket, key = url.split("/", maxsplit=1) 29 | return bucket, key 30 | 31 | def download(self, url: str, local_dir: str): 32 | bucket, key = self.__split_url(url) 33 | filename = os.path.basename(key) 34 | file_path = os.path.join(local_dir, filename) 35 | 36 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 37 | meta_data = self.client.head_object(Bucket=bucket, Key=key) 38 | total_length = int(meta_data.get("ContentLength", 0)) 39 | 40 | downloaded = 0 41 | 42 | def progress(chunk): 43 | nonlocal downloaded 44 | downloaded += chunk 45 | done = int(50 * 
downloaded / total_length) 46 | sys.stdout.write( 47 | "\r{}[{}{}]".format(file_path, "█" * done, "." * (50 - done)) 48 | ) 49 | sys.stdout.flush() 50 | 51 | try: 52 | with open(file_path, "wb") as f: 53 | self.client.download_fileobj(bucket, key, f, Callback=progress) 54 | sys.stdout.write("\n") 55 | sys.stdout.flush() 56 | except: 57 | raise Exception(f"downloading file is failed. {url}") 58 | return file_path 59 | 60 | 61 | if __name__ == "__main__": 62 | s3 = AwsS3Downloader() 63 | 64 | s3.download( 65 | url="s3://skt-lsl-nlp-model/KoBERT/tokenizers/kobert_news_wiki_ko_cased-1087f8699e.spiece", 66 | local_dir=".cache", 67 | ) 68 | -------------------------------------------------------------------------------- /kobert/utils/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 SK T-Brain Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import hashlib 17 | import os 18 | 19 | from kobert.utils.aws_s3_downloader import AwsS3Downloader 20 | 21 | 22 | def download(url, chksum=None, cachedir=".cache"): 23 | cachedir_full = os.path.join(os.getcwd(), cachedir) 24 | os.makedirs(cachedir_full, exist_ok=True) 25 | filename = os.path.basename(url) 26 | file_path = os.path.join(cachedir_full, filename) 27 | if os.path.isfile(file_path): 28 | if hashlib.md5(open(file_path, "rb").read()).hexdigest()[:10] == chksum[:10]: 29 | print(f"using cached model. {file_path}") 30 | return file_path, True 31 | 32 | s3 = AwsS3Downloader() 33 | file_path = s3.download(url, cachedir_full) 34 | if chksum: 35 | assert ( 36 | chksum[:10] == hashlib.md5(open(file_path, "rb").read()).hexdigest()[:10] 37 | ), "corrupted file!" 38 | return file_path, False 39 | 40 | 41 | def get_tokenizer(cachedir=".cache"): 42 | """Get KoBERT Tokenizer file path after downloading""" 43 | tokenizer = { 44 | "url": "s3://skt-lsl-nlp-model/KoBERT/tokenizers/kobert_news_wiki_ko_cased-1087f8699e.spiece", 45 | "chksum": "ae5711deb3", 46 | } 47 | 48 | model_info = tokenizer 49 | model_path, is_cached = download(model_info["url"], model_info["chksum"], cachedir=cachedir) 50 | return model_path 51 | -------------------------------------------------------------------------------- /kobert_hf/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | - [Korean BERT pre-trained cased (KoBERT) for Huggingface Transformers](#korean-bert-pre-trained-cased-kobert-for-huggingface-transformers) 8 | - [Requirements](#requirements) 9 | - [How to install](#how-to-install) 10 | - [Tokenizer](#tokenizer) 11 | - [Model](#model) 12 | - [License](#license) 13 | 14 | 15 | 16 | --- 17 | 18 | ### Korean BERT pre-trained cased (KoBERT) for Huggingface Transformers 19 | 20 | KoBERT를 Huggingface.co 기반으로 사용할 수 있게 Wrapping 작업을 수행하였습니다. 
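Once the package described below is installed, KoBERT plugs into the standard `transformers` workflow. As a quick preview, a sentence-classification fine-tuning step might look like the sketch below; the checkpoint name `skt/kobert-base-v1` is the one used in the Tokenizer and Model sections further down, while `num_labels`, the sample text, and the label are placeholders rather than part of this repository.

```python
import torch
from transformers import BertForSequenceClassification
from kobert_tokenizer import KoBERTTokenizer

# Tokenizer and encoder weights both come from the Hugging Face Hub checkpoint.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
model = BertForSequenceClassification.from_pretrained('skt/kobert-base-v1', num_labels=2)

inputs = tokenizer(["한국어 모델을 공유합니다."], return_tensors='pt', padding=True)
labels = torch.tensor([1])

outputs = model(**inputs, labels=labels)  # returns loss and logits for one training step
print(outputs.loss, outputs.logits.shape)
```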
21 | 22 | 23 | #### Requirements 24 | 25 | * Python >= 3.6 26 | * PyTorch >= 1.8.1 27 | * transformers >= 4.8.2 28 | * sentencepiece >= 0.1.91 29 | 30 | #### How to install 31 | 32 | ```sh 33 | pip install 'git+https://github.com/SKTBrain/KoBERT.git#egg=kobert_tokenizer&subdirectory=kobert_hf' 34 | ``` 35 | 36 | --- 37 | 38 | ### Tokenizer (BPE-dropout) 39 | 40 | [XLNetTokenizer](https://github.com/huggingface/transformers/blob/master/src/transformers/models/xlnet/tokenization_xlnet.py)를 활용하여 Wrapping 작업을 진행하였습니다. 41 | 42 | 기존 Tokenizer와 동일하게 사전 크기는 8,002개 입니다. 43 | 44 | 일반적인 토크나이저 사용시(예: inference) 아래와 같이 사용하면 됩니다. 45 | 46 | ```python 47 | > from kobert_tokenizer import KoBERTTokenizer 48 | > tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1') 49 | > tokenizer.encode("한국어 모델을 공유합니다.") 50 | [2, 4958, 6855, 2046, 7088, 1050, 7843, 54, 3] 51 | ``` 52 | 53 | [`BPE-dropout`](https://arxiv.org/pdf/1910.13267.pdf)을 이용하면 서비스에 적합한 띄어쓰기에 강건한 모델로 튜닝 할 수 있습니다. 학습시 아래와 유사한 토크나이저 설정으로 학습을 진행할 수 있습니다. 자세한 옵션 설명은 [이곳](https://github.com/google/sentencepiece/tree/master/python)을 참고하세요. 54 | 55 | ```python 56 | > from kobert_tokenizer import KoBERTTokenizer 57 | > tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1', sp_model_kwargs={'nbest_size': -1, 'alpha': 0.6, 'enable_sampling': True}) 58 | > tokenizer.encode("한국어 모델을 공유합니다.") 59 | [2, 4958, 6855, 2046, 7088, 1023, 7063, 7843, 54, 3] 60 | ``` 61 | 62 | 63 | 64 | ### Model 65 | ```python 66 | > import torch 67 | > from transformers import BertModel 68 | > model = BertModel.from_pretrained('skt/kobert-base-v1') 69 | > text = "한국어 모델을 공유합니다." 70 | > inputs = tokenizer.batch_encode_plus([text]) 71 | > out = model(input_ids = torch.tensor(inputs['input_ids']), 72 | attention_mask = torch.tensor(inputs['attention_mask'])) 73 | > out.pooler_output.shape 74 | torch.Size([1, 768]) 75 | 76 | ``` 77 | 78 | ### License 79 | 80 | `KoBERT`는 Apache-2.0 라이선스 하에 공개되어 있습니다. 모델 및 코드를 사용할 경우 라이선스 내용을 준수해주세요. 라이선스 전문은 `LICENSE` 파일에서 확인하실 수 있습니다. 81 | -------------------------------------------------------------------------------- /kobert_hf/kobert_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .kobert_tokenizer import KoBERTTokenizer 2 | -------------------------------------------------------------------------------- /kobert_hf/kobert_tokenizer/kobert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 SKT AI Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from typing import Any, Dict, List, Optional 17 | from transformers.tokenization_utils import AddedToken 18 | from transformers import XLNetTokenizer 19 | from transformers import SPIECE_UNDERLINE 20 | 21 | 22 | class KoBERTTokenizer(XLNetTokenizer): 23 | padding_side = "right" 24 | 25 | def __init__( 26 | self, 27 | vocab_file, 28 | do_lower_case=False, 29 | remove_space=True, 30 | keep_accents=False, 31 | bos_token="[CLS]", 32 | eos_token="[SEP]", 33 | unk_token="[UNK]", 34 | sep_token="[SEP]", 35 | pad_token="[PAD]", 36 | cls_token="[CLS]", 37 | mask_token="[MASK]", 38 | additional_special_tokens=None, 39 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 40 | **kwargs 41 | ) -> None: 42 | # Mask token behave like a normal word, i.e. include the space before it 43 | mask_token = ( 44 | AddedToken(mask_token, lstrip=True, rstrip=False) 45 | if isinstance(mask_token, str) 46 | else mask_token 47 | ) 48 | 49 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 50 | 51 | super().__init__( 52 | vocab_file, 53 | do_lower_case=do_lower_case, 54 | remove_space=remove_space, 55 | keep_accents=keep_accents, 56 | bos_token=bos_token, 57 | eos_token=eos_token, 58 | unk_token=unk_token, 59 | sep_token=sep_token, 60 | pad_token=pad_token, 61 | cls_token=cls_token, 62 | mask_token=mask_token, 63 | additional_special_tokens=additional_special_tokens, 64 | sp_model_kwargs=self.sp_model_kwargs, 65 | **kwargs, 66 | ) 67 | self._pad_token_type_id = 0 68 | 69 | def build_inputs_with_special_tokens( 70 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 71 | ) -> List[int]: 72 | """ 73 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 74 | adding special tokens. An XLNet sequence has the following format: 75 | - single sequence: `` X `` 76 | - pair of sequences: `` A B `` 77 | Args: 78 | token_ids_0 (:obj:`List[int]`): 79 | List of IDs to which the special tokens will be added. 80 | token_ids_1 (:obj:`List[int]`, `optional`): 81 | Optional second list of IDs for sequence pairs. 82 | Returns: 83 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 84 | """ 85 | sep = [self.sep_token_id] 86 | cls = [self.cls_token_id] 87 | if token_ids_1 is None: 88 | return cls + token_ids_0 + sep 89 | return cls + token_ids_0 + sep + token_ids_1 + sep 90 | 91 | def _tokenize(self, text: str) -> List[str]: 92 | """Tokenize a string.""" 93 | text = self.preprocess_text(text) 94 | pieces = self.sp_model.encode(text, out_type=str, **self.sp_model_kwargs) 95 | new_pieces = [] 96 | for piece in pieces: 97 | if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit(): 98 | cur_pieces = self.sp_model.EncodeAsPieces( 99 | piece[:-1].replace(SPIECE_UNDERLINE, "") 100 | ) 101 | if ( 102 | piece[0] != SPIECE_UNDERLINE 103 | and cur_pieces[0][0] == SPIECE_UNDERLINE 104 | ): 105 | if len(cur_pieces[0]) == 1: 106 | cur_pieces = cur_pieces[1:] 107 | else: 108 | cur_pieces[0] = cur_pieces[0][1:] 109 | cur_pieces.append(piece[-1]) 110 | new_pieces.extend(cur_pieces) 111 | else: 112 | new_pieces.append(piece) 113 | 114 | return new_pieces 115 | 116 | def build_inputs_with_special_tokens( 117 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 118 | ) -> List[int]: 119 | """ 120 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 121 | adding special tokens. 
An XLNet sequence has the following format: 122 | 123 | - single sequence: `` X `` 124 | - pair of sequences: `` A B `` 125 | 126 | Args: 127 | token_ids_0 (:obj:`List[int]`): 128 | List of IDs to which the special tokens will be added. 129 | token_ids_1 (:obj:`List[int]`, `optional`): 130 | Optional second list of IDs for sequence pairs. 131 | 132 | Returns: 133 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 134 | """ 135 | sep = [self.sep_token_id] 136 | cls = [self.cls_token_id] 137 | if token_ids_1 is None: 138 | return cls + token_ids_0 + sep 139 | return cls + token_ids_0 + sep + token_ids_1 + sep 140 | 141 | def create_token_type_ids_from_sequences( 142 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 143 | ) -> List[int]: 144 | """ 145 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet 146 | sequence pair mask has the following format: 147 | 148 | :: 149 | 150 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 151 | | first sequence | second sequence | 152 | 153 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 154 | 155 | Args: 156 | token_ids_0 (:obj:`List[int]`): 157 | List of IDs. 158 | token_ids_1 (:obj:`List[int]`, `optional`): 159 | Optional second list of IDs for sequence pairs. 160 | 161 | Returns: 162 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 163 | sequence(s). 164 | """ 165 | sep = [self.sep_token_id] 166 | cls = [self.cls_token_id] 167 | if token_ids_1 is None: 168 | return len(cls + token_ids_0 + sep) * [0] 169 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 170 | -------------------------------------------------------------------------------- /kobert_hf/requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.8.1 2 | transformers >= 4.8.2 3 | sentencepiece >= 0.1.91 4 | -------------------------------------------------------------------------------- /kobert_hf/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | __version__ = "0.1" 4 | 5 | 6 | setup( 7 | name="kobert-tokenizer", 8 | version=__version__, 9 | url="https://github.com/SKTBrain/KoBERT", 10 | license="Apache-2.0", 11 | author="SeungHwan Jung", 12 | author_email="digit82@gmail.com", 13 | description="Korean BERT pre-trained cased (KoBERT) for HuggingFace ", 14 | packages=[ 15 | "kobert_tokenizer", 16 | ], 17 | long_description=open("README.md", encoding="utf-8").read(), 18 | zip_safe=False, 19 | include_package_data=True, 20 | ) 21 | -------------------------------------------------------------------------------- /logs/bert_naver_small_512_news_simple_20190624.txt: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 2 | INFO:root:Namespace(accumulate=2, batch_size=64, dev_batch_size=8, epochs=5, gpu=True, init_param='/root/gfs/gogamza/korean_bert_pre_trained/mxnet_model/korbert_model_054_00003271.params', log_interval=10, lr=5e-05, max_len=512, optimizer='bertadam', seed=2, tokenizer='/root/gfs/gogamza/korean_bert_pre_trained/tokenizer/tokenizer.model', vocab_file='/root/gfs/gogamza/korean_bert_pre_trained/vocabs/news_wiki_2019_small_bertvocab.json', warmup_ratio=0.1) 3 | INFO:root:Using gradient accumulation. 
Effective batch size = 128 4 | INFO:root:[Epoch 0 Batch 20/2349] loss=0.6985, lr=0.0000009, acc=0.521 5 | INFO:root:[Epoch 0 Batch 40/2349] loss=0.6912, lr=0.0000017, acc=0.525 6 | INFO:root:[Epoch 0 Batch 60/2349] loss=0.6763, lr=0.0000026, acc=0.543 7 | INFO:root:[Epoch 0 Batch 80/2349] loss=0.6519, lr=0.0000034, acc=0.572 8 | INFO:root:[Epoch 0 Batch 100/2349] loss=0.6092, lr=0.0000043, acc=0.598 9 | INFO:root:[Epoch 0 Batch 120/2349] loss=0.5466, lr=0.0000051, acc=0.623 10 | INFO:root:[Epoch 0 Batch 140/2349] loss=0.5037, lr=0.0000060, acc=0.643 11 | INFO:root:[Epoch 0 Batch 160/2349] loss=0.4727, lr=0.0000068, acc=0.662 12 | INFO:root:[Epoch 0 Batch 180/2349] loss=0.4602, lr=0.0000077, acc=0.677 13 | INFO:root:[Epoch 0 Batch 200/2349] loss=0.4767, lr=0.0000085, acc=0.688 14 | INFO:root:[Epoch 0 Batch 220/2349] loss=0.4414, lr=0.0000094, acc=0.699 15 | INFO:root:[Epoch 0 Batch 240/2349] loss=0.4079, lr=0.0000103, acc=0.709 16 | INFO:root:[Epoch 0 Batch 260/2349] loss=0.4265, lr=0.0000111, acc=0.716 17 | INFO:root:[Epoch 0 Batch 280/2349] loss=0.4092, lr=0.0000120, acc=0.724 18 | INFO:root:[Epoch 0 Batch 300/2349] loss=0.4460, lr=0.0000128, acc=0.729 19 | INFO:root:[Epoch 0 Batch 320/2349] loss=0.3898, lr=0.0000137, acc=0.736 20 | INFO:root:[Epoch 0 Batch 340/2349] loss=0.3804, lr=0.0000145, acc=0.742 21 | INFO:root:[Epoch 0 Batch 360/2349] loss=0.3750, lr=0.0000154, acc=0.747 22 | INFO:root:[Epoch 0 Batch 380/2349] loss=0.3776, lr=0.0000162, acc=0.752 23 | INFO:root:[Epoch 0 Batch 400/2349] loss=0.3826, lr=0.0000171, acc=0.755 24 | INFO:root:[Epoch 0 Batch 420/2349] loss=0.3574, lr=0.0000179, acc=0.760 25 | INFO:root:[Epoch 0 Batch 440/2349] loss=0.3806, lr=0.0000188, acc=0.762 26 | INFO:root:[Epoch 0 Batch 460/2349] loss=0.3789, lr=0.0000197, acc=0.765 27 | INFO:root:[Epoch 0 Batch 480/2349] loss=0.3418, lr=0.0000205, acc=0.769 28 | INFO:root:[Epoch 0 Batch 500/2349] loss=0.3144, lr=0.0000214, acc=0.773 29 | INFO:root:[Epoch 0 Batch 520/2349] loss=0.3721, lr=0.0000222, acc=0.776 30 | INFO:root:[Epoch 0 Batch 540/2349] loss=0.3593, lr=0.0000231, acc=0.778 31 | INFO:root:[Epoch 0 Batch 560/2349] loss=0.3343, lr=0.0000239, acc=0.781 32 | INFO:root:[Epoch 0 Batch 580/2349] loss=0.2992, lr=0.0000248, acc=0.784 33 | INFO:root:[Epoch 0 Batch 600/2349] loss=0.3499, lr=0.0000256, acc=0.786 34 | INFO:root:[Epoch 0 Batch 620/2349] loss=0.3485, lr=0.0000265, acc=0.788 35 | INFO:root:[Epoch 0 Batch 640/2349] loss=0.3567, lr=0.0000274, acc=0.790 36 | INFO:root:[Epoch 0 Batch 660/2349] loss=0.3207, lr=0.0000282, acc=0.792 37 | INFO:root:[Epoch 0 Batch 680/2349] loss=0.3397, lr=0.0000291, acc=0.794 38 | INFO:root:[Epoch 0 Batch 700/2349] loss=0.3379, lr=0.0000299, acc=0.796 39 | INFO:root:[Epoch 0 Batch 720/2349] loss=0.3357, lr=0.0000308, acc=0.797 40 | INFO:root:[Epoch 0 Batch 740/2349] loss=0.3422, lr=0.0000316, acc=0.799 41 | INFO:root:[Epoch 0 Batch 760/2349] loss=0.3019, lr=0.0000325, acc=0.801 42 | INFO:root:[Epoch 0 Batch 780/2349] loss=0.3251, lr=0.0000333, acc=0.802 43 | INFO:root:[Epoch 0 Batch 800/2349] loss=0.3399, lr=0.0000342, acc=0.803 44 | INFO:root:[Epoch 0 Batch 820/2349] loss=0.3116, lr=0.0000350, acc=0.805 45 | INFO:root:[Epoch 0 Batch 840/2349] loss=0.3557, lr=0.0000359, acc=0.806 46 | INFO:root:[Epoch 0 Batch 860/2349] loss=0.3411, lr=0.0000368, acc=0.808 47 | INFO:root:[Epoch 0 Batch 880/2349] loss=0.3002, lr=0.0000376, acc=0.809 48 | INFO:root:[Epoch 0 Batch 900/2349] loss=0.3168, lr=0.0000385, acc=0.810 49 | INFO:root:[Epoch 0 Batch 920/2349] loss=0.3541, lr=0.0000393, 
acc=0.811 50 | INFO:root:[Epoch 0 Batch 940/2349] loss=0.3555, lr=0.0000402, acc=0.811 51 | INFO:root:[Epoch 0 Batch 960/2349] loss=0.2993, lr=0.0000410, acc=0.813 52 | INFO:root:[Epoch 0 Batch 980/2349] loss=0.3127, lr=0.0000419, acc=0.814 53 | INFO:root:[Epoch 0 Batch 1000/2349] loss=0.3205, lr=0.0000427, acc=0.815 54 | INFO:root:[Epoch 0 Batch 1020/2349] loss=0.3432, lr=0.0000436, acc=0.816 55 | INFO:root:[Epoch 0 Batch 1040/2349] loss=0.3004, lr=0.0000444, acc=0.817 56 | INFO:root:[Epoch 0 Batch 1060/2349] loss=0.3220, lr=0.0000453, acc=0.818 57 | INFO:root:[Epoch 0 Batch 1080/2349] loss=0.3243, lr=0.0000462, acc=0.818 58 | INFO:root:[Epoch 0 Batch 1100/2349] loss=0.3269, lr=0.0000470, acc=0.819 59 | INFO:root:[Epoch 0 Batch 1120/2349] loss=0.3718, lr=0.0000479, acc=0.819 60 | INFO:root:[Epoch 0 Batch 1140/2349] loss=0.3387, lr=0.0000487, acc=0.820 61 | INFO:root:[Epoch 0 Batch 1160/2349] loss=0.3155, lr=0.0000496, acc=0.820 62 | INFO:root:[Epoch 0 Batch 1180/2349] loss=0.3200, lr=0.0000500, acc=0.821 63 | INFO:root:[Epoch 0 Batch 1200/2349] loss=0.3082, lr=0.0000499, acc=0.822 64 | INFO:root:[Epoch 0 Batch 1220/2349] loss=0.3230, lr=0.0000498, acc=0.823 65 | INFO:root:[Epoch 0 Batch 1240/2349] loss=0.3285, lr=0.0000497, acc=0.823 66 | INFO:root:[Epoch 0 Batch 1260/2349] loss=0.3115, lr=0.0000496, acc=0.824 67 | INFO:root:[Epoch 0 Batch 1280/2349] loss=0.3151, lr=0.0000495, acc=0.824 68 | INFO:root:[Epoch 0 Batch 1300/2349] loss=0.2915, lr=0.0000494, acc=0.825 69 | INFO:root:[Epoch 0 Batch 1320/2349] loss=0.3143, lr=0.0000493, acc=0.826 70 | INFO:root:[Epoch 0 Batch 1340/2349] loss=0.3417, lr=0.0000492, acc=0.826 71 | INFO:root:[Epoch 0 Batch 1360/2349] loss=0.2848, lr=0.0000491, acc=0.827 72 | INFO:root:[Epoch 0 Batch 1380/2349] loss=0.2971, lr=0.0000490, acc=0.828 73 | INFO:root:[Epoch 0 Batch 1400/2349] loss=0.2928, lr=0.0000489, acc=0.829 74 | INFO:root:[Epoch 0 Batch 1420/2349] loss=0.3258, lr=0.0000488, acc=0.829 75 | INFO:root:[Epoch 0 Batch 1440/2349] loss=0.2955, lr=0.0000487, acc=0.830 76 | INFO:root:[Epoch 0 Batch 1460/2349] loss=0.2928, lr=0.0000486, acc=0.831 77 | INFO:root:[Epoch 0 Batch 1480/2349] loss=0.3149, lr=0.0000485, acc=0.831 78 | INFO:root:[Epoch 0 Batch 1500/2349] loss=0.2939, lr=0.0000484, acc=0.831 79 | INFO:root:[Epoch 0 Batch 1520/2349] loss=0.2996, lr=0.0000483, acc=0.832 80 | INFO:root:[Epoch 0 Batch 1540/2349] loss=0.2749, lr=0.0000482, acc=0.833 81 | INFO:root:[Epoch 0 Batch 1560/2349] loss=0.3180, lr=0.0000482, acc=0.833 82 | INFO:root:[Epoch 0 Batch 1580/2349] loss=0.3121, lr=0.0000481, acc=0.834 83 | INFO:root:[Epoch 0 Batch 1600/2349] loss=0.3010, lr=0.0000480, acc=0.834 84 | INFO:root:[Epoch 0 Batch 1620/2349] loss=0.3024, lr=0.0000479, acc=0.834 85 | INFO:root:[Epoch 0 Batch 1640/2349] loss=0.2806, lr=0.0000478, acc=0.835 86 | INFO:root:[Epoch 0 Batch 1660/2349] loss=0.2992, lr=0.0000477, acc=0.836 87 | INFO:root:[Epoch 0 Batch 1680/2349] loss=0.2903, lr=0.0000476, acc=0.836 88 | INFO:root:[Epoch 0 Batch 1700/2349] loss=0.3172, lr=0.0000475, acc=0.836 89 | INFO:root:[Epoch 0 Batch 1720/2349] loss=0.3008, lr=0.0000474, acc=0.837 90 | INFO:root:[Epoch 0 Batch 1740/2349] loss=0.2878, lr=0.0000473, acc=0.837 91 | INFO:root:[Epoch 0 Batch 1760/2349] loss=0.2797, lr=0.0000472, acc=0.838 92 | INFO:root:[Epoch 0 Batch 1780/2349] loss=0.2787, lr=0.0000471, acc=0.839 93 | INFO:root:[Epoch 0 Batch 1800/2349] loss=0.2899, lr=0.0000470, acc=0.839 94 | INFO:root:[Epoch 0 Batch 1820/2349] loss=0.2943, lr=0.0000469, acc=0.839 95 | INFO:root:[Epoch 0 Batch 
1840/2349] loss=0.3081, lr=0.0000468, acc=0.840 96 | INFO:root:[Epoch 0 Batch 1860/2349] loss=0.2504, lr=0.0000467, acc=0.840 97 | INFO:root:[Epoch 0 Batch 1880/2349] loss=0.2934, lr=0.0000466, acc=0.841 98 | INFO:root:[Epoch 0 Batch 1900/2349] loss=0.3081, lr=0.0000465, acc=0.841 99 | INFO:root:[Epoch 0 Batch 1920/2349] loss=0.2955, lr=0.0000464, acc=0.841 100 | INFO:root:[Epoch 0 Batch 1940/2349] loss=0.2660, lr=0.0000464, acc=0.842 101 | INFO:root:[Epoch 0 Batch 1960/2349] loss=0.2827, lr=0.0000463, acc=0.842 102 | INFO:root:[Epoch 0 Batch 1980/2349] loss=0.2583, lr=0.0000462, acc=0.843 103 | INFO:root:[Epoch 0 Batch 2000/2349] loss=0.2739, lr=0.0000461, acc=0.843 104 | INFO:root:[Epoch 0 Batch 2020/2349] loss=0.2727, lr=0.0000460, acc=0.844 105 | INFO:root:[Epoch 0 Batch 2040/2349] loss=0.2707, lr=0.0000459, acc=0.844 106 | INFO:root:[Epoch 0 Batch 2060/2349] loss=0.3077, lr=0.0000458, acc=0.844 107 | INFO:root:[Epoch 0 Batch 2080/2349] loss=0.2741, lr=0.0000457, acc=0.845 108 | INFO:root:[Epoch 0 Batch 2100/2349] loss=0.2567, lr=0.0000456, acc=0.845 109 | INFO:root:[Epoch 0 Batch 2120/2349] loss=0.2881, lr=0.0000455, acc=0.846 110 | INFO:root:[Epoch 0 Batch 2140/2349] loss=0.2805, lr=0.0000454, acc=0.846 111 | INFO:root:[Epoch 0 Batch 2160/2349] loss=0.2873, lr=0.0000453, acc=0.846 112 | INFO:root:[Epoch 0 Batch 2180/2349] loss=0.2876, lr=0.0000452, acc=0.846 113 | INFO:root:[Epoch 0 Batch 2200/2349] loss=0.2848, lr=0.0000451, acc=0.847 114 | INFO:root:[Epoch 0 Batch 2220/2349] loss=0.2573, lr=0.0000450, acc=0.847 115 | INFO:root:[Epoch 0 Batch 2240/2349] loss=0.2612, lr=0.0000449, acc=0.848 116 | INFO:root:[Epoch 0 Batch 2260/2349] loss=0.2889, lr=0.0000448, acc=0.848 117 | INFO:root:[Epoch 0 Batch 2280/2349] loss=0.2706, lr=0.0000447, acc=0.848 118 | INFO:root:[Epoch 0 Batch 2300/2349] loss=0.2847, lr=0.0000446, acc=0.848 119 | INFO:root:[Epoch 0 Batch 2320/2349] loss=0.2673, lr=0.0000445, acc=0.849 120 | INFO:root:[Epoch 0 Batch 2340/2349] loss=0.2624, lr=0.0000445, acc=0.849 121 | INFO:root:Validation accuracy: 0.889 122 | INFO:root:Time cost=1112.9s 123 | INFO:root:[Epoch 1 Batch 20/2349] loss=0.2233, lr=0.0000443, acc=0.909 124 | INFO:root:[Epoch 1 Batch 40/2349] loss=0.2262, lr=0.0000442, acc=0.911 125 | INFO:root:[Epoch 1 Batch 60/2349] loss=0.2202, lr=0.0000441, acc=0.913 126 | INFO:root:[Epoch 1 Batch 80/2349] loss=0.1996, lr=0.0000440, acc=0.913 127 | INFO:root:[Epoch 1 Batch 100/2349] loss=0.2070, lr=0.0000439, acc=0.915 128 | INFO:root:[Epoch 1 Batch 120/2349] loss=0.2591, lr=0.0000438, acc=0.911 129 | INFO:root:[Epoch 1 Batch 140/2349] loss=0.2206, lr=0.0000437, acc=0.912 130 | INFO:root:[Epoch 1 Batch 160/2349] loss=0.2414, lr=0.0000436, acc=0.912 131 | INFO:root:[Epoch 1 Batch 180/2349] loss=0.2285, lr=0.0000436, acc=0.911 132 | INFO:root:[Epoch 1 Batch 200/2349] loss=0.2258, lr=0.0000435, acc=0.911 133 | INFO:root:[Epoch 1 Batch 220/2349] loss=0.2438, lr=0.0000434, acc=0.910 134 | INFO:root:[Epoch 1 Batch 240/2349] loss=0.2635, lr=0.0000433, acc=0.909 135 | INFO:root:[Epoch 1 Batch 260/2349] loss=0.2210, lr=0.0000432, acc=0.909 136 | INFO:root:[Epoch 1 Batch 280/2349] loss=0.2180, lr=0.0000431, acc=0.910 137 | INFO:root:[Epoch 1 Batch 300/2349] loss=0.2051, lr=0.0000430, acc=0.910 138 | INFO:root:[Epoch 1 Batch 320/2349] loss=0.2268, lr=0.0000429, acc=0.910 139 | INFO:root:[Epoch 1 Batch 340/2349] loss=0.2329, lr=0.0000428, acc=0.910 140 | INFO:root:[Epoch 1 Batch 360/2349] loss=0.2296, lr=0.0000427, acc=0.910 141 | INFO:root:[Epoch 1 Batch 380/2349] loss=0.1956, 
lr=0.0000426, acc=0.911 142 | INFO:root:[Epoch 1 Batch 400/2349] loss=0.2434, lr=0.0000425, acc=0.910 143 | INFO:root:[Epoch 1 Batch 420/2349] loss=0.2016, lr=0.0000424, acc=0.911 144 | INFO:root:[Epoch 1 Batch 440/2349] loss=0.2303, lr=0.0000423, acc=0.910 145 | INFO:root:[Epoch 1 Batch 460/2349] loss=0.2212, lr=0.0000422, acc=0.910 146 | INFO:root:[Epoch 1 Batch 480/2349] loss=0.2246, lr=0.0000421, acc=0.910 147 | INFO:root:[Epoch 1 Batch 500/2349] loss=0.2499, lr=0.0000420, acc=0.909 148 | INFO:root:[Epoch 1 Batch 520/2349] loss=0.2516, lr=0.0000419, acc=0.909 149 | INFO:root:[Epoch 1 Batch 540/2349] loss=0.2362, lr=0.0000418, acc=0.909 150 | INFO:root:[Epoch 1 Batch 560/2349] loss=0.2645, lr=0.0000418, acc=0.908 151 | INFO:root:[Epoch 1 Batch 580/2349] loss=0.2116, lr=0.0000417, acc=0.908 152 | INFO:root:[Epoch 1 Batch 600/2349] loss=0.2422, lr=0.0000416, acc=0.908 153 | INFO:root:[Epoch 1 Batch 620/2349] loss=0.2013, lr=0.0000415, acc=0.908 154 | INFO:root:[Epoch 1 Batch 640/2349] loss=0.2036, lr=0.0000414, acc=0.908 155 | INFO:root:[Epoch 1 Batch 660/2349] loss=0.2012, lr=0.0000413, acc=0.909 156 | INFO:root:[Epoch 1 Batch 680/2349] loss=0.2188, lr=0.0000412, acc=0.909 157 | INFO:root:[Epoch 1 Batch 700/2349] loss=0.2139, lr=0.0000411, acc=0.909 158 | INFO:root:[Epoch 1 Batch 720/2349] loss=0.2246, lr=0.0000410, acc=0.909 159 | INFO:root:[Epoch 1 Batch 740/2349] loss=0.2387, lr=0.0000409, acc=0.909 160 | INFO:root:[Epoch 1 Batch 760/2349] loss=0.2298, lr=0.0000408, acc=0.909 161 | INFO:root:[Epoch 1 Batch 780/2349] loss=0.2360, lr=0.0000407, acc=0.909 162 | INFO:root:[Epoch 1 Batch 800/2349] loss=0.2352, lr=0.0000406, acc=0.908 163 | INFO:root:[Epoch 1 Batch 820/2349] loss=0.2290, lr=0.0000405, acc=0.909 164 | INFO:root:[Epoch 1 Batch 840/2349] loss=0.2175, lr=0.0000404, acc=0.909 165 | INFO:root:[Epoch 1 Batch 860/2349] loss=0.2253, lr=0.0000403, acc=0.909 166 | INFO:root:[Epoch 1 Batch 880/2349] loss=0.2683, lr=0.0000402, acc=0.908 167 | INFO:root:[Epoch 1 Batch 900/2349] loss=0.2149, lr=0.0000401, acc=0.908 168 | INFO:root:[Epoch 1 Batch 920/2349] loss=0.2285, lr=0.0000400, acc=0.908 169 | INFO:root:[Epoch 1 Batch 940/2349] loss=0.2042, lr=0.0000400, acc=0.908 170 | INFO:root:[Epoch 1 Batch 960/2349] loss=0.2121, lr=0.0000399, acc=0.909 171 | INFO:root:[Epoch 1 Batch 980/2349] loss=0.2260, lr=0.0000398, acc=0.909 172 | INFO:root:[Epoch 1 Batch 1000/2349] loss=0.2110, lr=0.0000397, acc=0.909 173 | INFO:root:[Epoch 1 Batch 1020/2349] loss=0.2101, lr=0.0000396, acc=0.909 174 | INFO:root:[Epoch 1 Batch 1040/2349] loss=0.1777, lr=0.0000395, acc=0.910 175 | INFO:root:[Epoch 1 Batch 1060/2349] loss=0.1972, lr=0.0000394, acc=0.910 176 | INFO:root:[Epoch 1 Batch 1080/2349] loss=0.2308, lr=0.0000393, acc=0.910 177 | INFO:root:[Epoch 1 Batch 1100/2349] loss=0.2192, lr=0.0000392, acc=0.910 178 | INFO:root:[Epoch 1 Batch 1120/2349] loss=0.2243, lr=0.0000391, acc=0.910 179 | INFO:root:[Epoch 1 Batch 1140/2349] loss=0.1994, lr=0.0000390, acc=0.910 180 | INFO:root:[Epoch 1 Batch 1160/2349] loss=0.2122, lr=0.0000389, acc=0.910 181 | INFO:root:[Epoch 1 Batch 1180/2349] loss=0.2322, lr=0.0000388, acc=0.909 182 | INFO:root:[Epoch 1 Batch 1200/2349] loss=0.2328, lr=0.0000387, acc=0.909 183 | INFO:root:[Epoch 1 Batch 1220/2349] loss=0.2218, lr=0.0000386, acc=0.909 184 | INFO:root:[Epoch 1 Batch 1240/2349] loss=0.2323, lr=0.0000385, acc=0.909 185 | INFO:root:[Epoch 1 Batch 1260/2349] loss=0.2130, lr=0.0000384, acc=0.910 186 | INFO:root:[Epoch 1 Batch 1280/2349] loss=0.2207, lr=0.0000383, acc=0.910 187 | 
INFO:root:[Epoch 1 Batch 1300/2349] loss=0.1797, lr=0.0000382, acc=0.910 188 | INFO:root:[Epoch 1 Batch 1320/2349] loss=0.2173, lr=0.0000381, acc=0.910 189 | INFO:root:[Epoch 1 Batch 1340/2349] loss=0.2212, lr=0.0000381, acc=0.910 190 | INFO:root:[Epoch 1 Batch 1360/2349] loss=0.2119, lr=0.0000380, acc=0.910 191 | INFO:root:[Epoch 1 Batch 1380/2349] loss=0.2112, lr=0.0000379, acc=0.910 192 | INFO:root:[Epoch 1 Batch 1400/2349] loss=0.2070, lr=0.0000378, acc=0.910 193 | INFO:root:[Epoch 1 Batch 1420/2349] loss=0.2457, lr=0.0000377, acc=0.910 194 | INFO:root:[Epoch 1 Batch 1440/2349] loss=0.2218, lr=0.0000376, acc=0.910 195 | INFO:root:[Epoch 1 Batch 1460/2349] loss=0.1926, lr=0.0000375, acc=0.910 196 | INFO:root:[Epoch 1 Batch 1480/2349] loss=0.2147, lr=0.0000374, acc=0.910 197 | INFO:root:[Epoch 1 Batch 1500/2349] loss=0.2371, lr=0.0000373, acc=0.910 198 | INFO:root:[Epoch 1 Batch 1520/2349] loss=0.2106, lr=0.0000372, acc=0.910 199 | INFO:root:[Epoch 1 Batch 1540/2349] loss=0.2332, lr=0.0000371, acc=0.910 200 | INFO:root:[Epoch 1 Batch 1560/2349] loss=0.1809, lr=0.0000370, acc=0.911 201 | INFO:root:[Epoch 1 Batch 1580/2349] loss=0.2317, lr=0.0000369, acc=0.911 202 | INFO:root:[Epoch 1 Batch 1600/2349] loss=0.2128, lr=0.0000368, acc=0.911 203 | INFO:root:[Epoch 1 Batch 1620/2349] loss=0.2263, lr=0.0000367, acc=0.911 204 | INFO:root:[Epoch 1 Batch 1640/2349] loss=0.2215, lr=0.0000366, acc=0.911 205 | INFO:root:[Epoch 1 Batch 1660/2349] loss=0.2291, lr=0.0000365, acc=0.910 206 | INFO:root:[Epoch 1 Batch 1680/2349] loss=0.2073, lr=0.0000364, acc=0.911 207 | INFO:root:[Epoch 1 Batch 1700/2349] loss=0.2083, lr=0.0000363, acc=0.911 208 | INFO:root:[Epoch 1 Batch 1720/2349] loss=0.2227, lr=0.0000363, acc=0.911 209 | INFO:root:[Epoch 1 Batch 1740/2349] loss=0.2015, lr=0.0000362, acc=0.911 210 | INFO:root:[Epoch 1 Batch 1760/2349] loss=0.2062, lr=0.0000361, acc=0.911 211 | INFO:root:[Epoch 1 Batch 1780/2349] loss=0.2114, lr=0.0000360, acc=0.911 212 | INFO:root:[Epoch 1 Batch 1800/2349] loss=0.1953, lr=0.0000359, acc=0.911 213 | INFO:root:[Epoch 1 Batch 1820/2349] loss=0.2316, lr=0.0000358, acc=0.911 214 | INFO:root:[Epoch 1 Batch 1840/2349] loss=0.2266, lr=0.0000357, acc=0.911 215 | INFO:root:[Epoch 1 Batch 1860/2349] loss=0.2026, lr=0.0000356, acc=0.911 216 | INFO:root:[Epoch 1 Batch 1880/2349] loss=0.1875, lr=0.0000355, acc=0.911 217 | INFO:root:[Epoch 1 Batch 1900/2349] loss=0.2406, lr=0.0000354, acc=0.911 218 | INFO:root:[Epoch 1 Batch 1920/2349] loss=0.1935, lr=0.0000353, acc=0.911 219 | INFO:root:[Epoch 1 Batch 1940/2349] loss=0.2212, lr=0.0000352, acc=0.911 220 | INFO:root:[Epoch 1 Batch 1960/2349] loss=0.2096, lr=0.0000351, acc=0.911 221 | INFO:root:[Epoch 1 Batch 1980/2349] loss=0.2192, lr=0.0000350, acc=0.912 222 | INFO:root:[Epoch 1 Batch 2000/2349] loss=0.2126, lr=0.0000349, acc=0.912 223 | INFO:root:[Epoch 1 Batch 2020/2349] loss=0.2200, lr=0.0000348, acc=0.912 224 | INFO:root:[Epoch 1 Batch 2040/2349] loss=0.2011, lr=0.0000347, acc=0.912 225 | INFO:root:[Epoch 1 Batch 2060/2349] loss=0.2247, lr=0.0000346, acc=0.912 226 | INFO:root:[Epoch 1 Batch 2080/2349] loss=0.2204, lr=0.0000345, acc=0.912 227 | INFO:root:[Epoch 1 Batch 2100/2349] loss=0.2040, lr=0.0000345, acc=0.912 228 | INFO:root:[Epoch 1 Batch 2120/2349] loss=0.2322, lr=0.0000344, acc=0.912 229 | INFO:root:[Epoch 1 Batch 2140/2349] loss=0.1871, lr=0.0000343, acc=0.912 230 | INFO:root:[Epoch 1 Batch 2160/2349] loss=0.2286, lr=0.0000342, acc=0.912 231 | INFO:root:[Epoch 1 Batch 2180/2349] loss=0.1914, lr=0.0000341, acc=0.912 232 | 
INFO:root:[Epoch 1 Batch 2200/2349] loss=0.2238, lr=0.0000340, acc=0.912 233 | INFO:root:[Epoch 1 Batch 2220/2349] loss=0.2056, lr=0.0000339, acc=0.912 234 | INFO:root:[Epoch 1 Batch 2240/2349] loss=0.2146, lr=0.0000338, acc=0.912 235 | INFO:root:[Epoch 1 Batch 2260/2349] loss=0.1991, lr=0.0000337, acc=0.912 236 | INFO:root:[Epoch 1 Batch 2280/2349] loss=0.2153, lr=0.0000336, acc=0.912 237 | INFO:root:[Epoch 1 Batch 2300/2349] loss=0.2157, lr=0.0000335, acc=0.912 238 | INFO:root:[Epoch 1 Batch 2320/2349] loss=0.2053, lr=0.0000334, acc=0.912 239 | INFO:root:[Epoch 1 Batch 2340/2349] loss=0.2125, lr=0.0000333, acc=0.912 240 | INFO:root:Validation accuracy: 0.900 241 | INFO:root:Time cost=1119.2s 242 | INFO:root:[Epoch 2 Batch 20/2349] loss=0.1304, lr=0.0000332, acc=0.955 243 | INFO:root:[Epoch 2 Batch 40/2349] loss=0.1755, lr=0.0000331, acc=0.947 244 | INFO:root:[Epoch 2 Batch 60/2349] loss=0.1244, lr=0.0000330, acc=0.951 245 | INFO:root:[Epoch 2 Batch 80/2349] loss=0.1791, lr=0.0000329, acc=0.946 246 | INFO:root:[Epoch 2 Batch 100/2349] loss=0.1618, lr=0.0000328, acc=0.946 247 | INFO:root:[Epoch 2 Batch 120/2349] loss=0.1467, lr=0.0000327, acc=0.946 248 | INFO:root:[Epoch 2 Batch 140/2349] loss=0.1281, lr=0.0000326, acc=0.948 249 | INFO:root:[Epoch 2 Batch 160/2349] loss=0.1397, lr=0.0000325, acc=0.949 250 | INFO:root:[Epoch 2 Batch 180/2349] loss=0.1528, lr=0.0000324, acc=0.947 251 | INFO:root:[Epoch 2 Batch 200/2349] loss=0.1688, lr=0.0000323, acc=0.946 252 | INFO:root:[Epoch 2 Batch 220/2349] loss=0.1431, lr=0.0000322, acc=0.946 253 | INFO:root:[Epoch 2 Batch 240/2349] loss=0.1519, lr=0.0000321, acc=0.946 254 | INFO:root:[Epoch 2 Batch 260/2349] loss=0.1376, lr=0.0000320, acc=0.946 255 | INFO:root:[Epoch 2 Batch 280/2349] loss=0.1606, lr=0.0000319, acc=0.946 256 | INFO:root:[Epoch 2 Batch 300/2349] loss=0.1352, lr=0.0000318, acc=0.946 257 | INFO:root:[Epoch 2 Batch 320/2349] loss=0.1712, lr=0.0000318, acc=0.946 258 | INFO:root:[Epoch 2 Batch 340/2349] loss=0.1347, lr=0.0000317, acc=0.946 259 | INFO:root:[Epoch 2 Batch 360/2349] loss=0.1213, lr=0.0000316, acc=0.946 260 | INFO:root:[Epoch 2 Batch 380/2349] loss=0.1337, lr=0.0000315, acc=0.946 261 | INFO:root:[Epoch 2 Batch 400/2349] loss=0.1686, lr=0.0000314, acc=0.946 262 | INFO:root:[Epoch 2 Batch 420/2349] loss=0.1508, lr=0.0000313, acc=0.946 263 | INFO:root:[Epoch 2 Batch 440/2349] loss=0.1619, lr=0.0000312, acc=0.945 264 | INFO:root:[Epoch 2 Batch 460/2349] loss=0.1140, lr=0.0000311, acc=0.946 265 | INFO:root:[Epoch 2 Batch 480/2349] loss=0.1127, lr=0.0000310, acc=0.946 266 | INFO:root:[Epoch 2 Batch 500/2349] loss=0.1468, lr=0.0000309, acc=0.947 267 | INFO:root:[Epoch 2 Batch 520/2349] loss=0.1844, lr=0.0000308, acc=0.946 268 | INFO:root:[Epoch 2 Batch 540/2349] loss=0.1160, lr=0.0000307, acc=0.947 269 | INFO:root:[Epoch 2 Batch 560/2349] loss=0.1582, lr=0.0000306, acc=0.947 270 | INFO:root:[Epoch 2 Batch 580/2349] loss=0.1315, lr=0.0000305, acc=0.947 271 | INFO:root:[Epoch 2 Batch 600/2349] loss=0.1660, lr=0.0000304, acc=0.947 272 | INFO:root:[Epoch 2 Batch 620/2349] loss=0.1738, lr=0.0000303, acc=0.946 273 | INFO:root:[Epoch 2 Batch 640/2349] loss=0.1525, lr=0.0000302, acc=0.946 274 | INFO:root:[Epoch 2 Batch 660/2349] loss=0.1257, lr=0.0000301, acc=0.946 275 | INFO:root:[Epoch 2 Batch 680/2349] loss=0.1421, lr=0.0000300, acc=0.946 276 | INFO:root:[Epoch 2 Batch 700/2349] loss=0.1389, lr=0.0000299, acc=0.947 277 | INFO:root:[Epoch 2 Batch 720/2349] loss=0.1440, lr=0.0000299, acc=0.946 278 | INFO:root:[Epoch 2 Batch 740/2349] 
loss=0.1741, lr=0.0000298, acc=0.946 279 | INFO:root:[Epoch 2 Batch 760/2349] loss=0.1400, lr=0.0000297, acc=0.946 280 | INFO:root:[Epoch 2 Batch 780/2349] loss=0.1668, lr=0.0000296, acc=0.946 281 | INFO:root:[Epoch 2 Batch 800/2349] loss=0.1171, lr=0.0000295, acc=0.946 282 | INFO:root:[Epoch 2 Batch 820/2349] loss=0.1660, lr=0.0000294, acc=0.946 283 | INFO:root:[Epoch 2 Batch 840/2349] loss=0.1562, lr=0.0000293, acc=0.946 284 | INFO:root:[Epoch 2 Batch 860/2349] loss=0.1316, lr=0.0000292, acc=0.946 285 | INFO:root:[Epoch 2 Batch 880/2349] loss=0.1534, lr=0.0000291, acc=0.946 286 | INFO:root:[Epoch 2 Batch 900/2349] loss=0.1541, lr=0.0000290, acc=0.946 287 | INFO:root:[Epoch 2 Batch 920/2349] loss=0.1554, lr=0.0000289, acc=0.946 288 | INFO:root:[Epoch 2 Batch 940/2349] loss=0.1265, lr=0.0000288, acc=0.946 289 | INFO:root:[Epoch 2 Batch 960/2349] loss=0.1317, lr=0.0000287, acc=0.946 290 | INFO:root:[Epoch 2 Batch 980/2349] loss=0.1633, lr=0.0000286, acc=0.946 291 | INFO:root:[Epoch 2 Batch 1000/2349] loss=0.1158, lr=0.0000285, acc=0.946 292 | INFO:root:[Epoch 2 Batch 1020/2349] loss=0.1435, lr=0.0000284, acc=0.946 293 | INFO:root:[Epoch 2 Batch 1040/2349] loss=0.1600, lr=0.0000283, acc=0.946 294 | INFO:root:[Epoch 2 Batch 1060/2349] loss=0.1303, lr=0.0000282, acc=0.946 295 | INFO:root:[Epoch 2 Batch 1080/2349] loss=0.1450, lr=0.0000281, acc=0.946 296 | INFO:root:[Epoch 2 Batch 1100/2349] loss=0.1393, lr=0.0000281, acc=0.946 297 | INFO:root:[Epoch 2 Batch 1120/2349] loss=0.1438, lr=0.0000280, acc=0.946 298 | INFO:root:[Epoch 2 Batch 1140/2349] loss=0.1516, lr=0.0000279, acc=0.946 299 | INFO:root:[Epoch 2 Batch 1160/2349] loss=0.1219, lr=0.0000278, acc=0.946 300 | INFO:root:[Epoch 2 Batch 1180/2349] loss=0.1538, lr=0.0000277, acc=0.946 301 | INFO:root:[Epoch 2 Batch 1200/2349] loss=0.1509, lr=0.0000276, acc=0.946 302 | INFO:root:[Epoch 2 Batch 1220/2349] loss=0.1508, lr=0.0000275, acc=0.946 303 | INFO:root:[Epoch 2 Batch 1240/2349] loss=0.1380, lr=0.0000274, acc=0.946 304 | INFO:root:[Epoch 2 Batch 1260/2349] loss=0.1513, lr=0.0000273, acc=0.946 305 | INFO:root:[Epoch 2 Batch 1280/2349] loss=0.1543, lr=0.0000272, acc=0.946 306 | INFO:root:[Epoch 2 Batch 1300/2349] loss=0.1648, lr=0.0000271, acc=0.946 307 | INFO:root:[Epoch 2 Batch 1320/2349] loss=0.1586, lr=0.0000270, acc=0.946 308 | INFO:root:[Epoch 2 Batch 1340/2349] loss=0.1468, lr=0.0000269, acc=0.945 309 | INFO:root:[Epoch 2 Batch 1360/2349] loss=0.1388, lr=0.0000268, acc=0.945 310 | INFO:root:[Epoch 2 Batch 1380/2349] loss=0.1050, lr=0.0000267, acc=0.946 311 | INFO:root:[Epoch 2 Batch 1400/2349] loss=0.1374, lr=0.0000266, acc=0.946 312 | INFO:root:[Epoch 2 Batch 1420/2349] loss=0.1636, lr=0.0000265, acc=0.946 313 | INFO:root:[Epoch 2 Batch 1440/2349] loss=0.1430, lr=0.0000264, acc=0.946 314 | INFO:root:[Epoch 2 Batch 1460/2349] loss=0.1411, lr=0.0000263, acc=0.946 315 | INFO:root:[Epoch 2 Batch 1480/2349] loss=0.1259, lr=0.0000263, acc=0.946 316 | INFO:root:[Epoch 2 Batch 1500/2349] loss=0.1280, lr=0.0000262, acc=0.946 317 | INFO:root:[Epoch 2 Batch 1520/2349] loss=0.1701, lr=0.0000261, acc=0.946 318 | INFO:root:[Epoch 2 Batch 1540/2349] loss=0.1508, lr=0.0000260, acc=0.946 319 | INFO:root:[Epoch 2 Batch 1560/2349] loss=0.1237, lr=0.0000259, acc=0.946 320 | INFO:root:[Epoch 2 Batch 1580/2349] loss=0.1398, lr=0.0000258, acc=0.946 321 | INFO:root:[Epoch 2 Batch 1600/2349] loss=0.1406, lr=0.0000257, acc=0.946 322 | INFO:root:[Epoch 2 Batch 1620/2349] loss=0.1495, lr=0.0000256, acc=0.946 323 | INFO:root:[Epoch 2 Batch 1640/2349] 
loss=0.1369, lr=0.0000255, acc=0.946 324 | INFO:root:[Epoch 2 Batch 1660/2349] loss=0.1874, lr=0.0000254, acc=0.946 325 | INFO:root:[Epoch 2 Batch 1680/2349] loss=0.1558, lr=0.0000253, acc=0.946 326 | INFO:root:[Epoch 2 Batch 1700/2349] loss=0.1182, lr=0.0000252, acc=0.946 327 | INFO:root:[Epoch 2 Batch 1720/2349] loss=0.1452, lr=0.0000251, acc=0.946 328 | INFO:root:[Epoch 2 Batch 1740/2349] loss=0.1211, lr=0.0000250, acc=0.946 329 | INFO:root:[Epoch 2 Batch 1760/2349] loss=0.1866, lr=0.0000249, acc=0.946 330 | INFO:root:[Epoch 2 Batch 1780/2349] loss=0.1620, lr=0.0000248, acc=0.946 331 | INFO:root:[Epoch 2 Batch 1800/2349] loss=0.1377, lr=0.0000247, acc=0.946 332 | INFO:root:[Epoch 2 Batch 1820/2349] loss=0.1552, lr=0.0000246, acc=0.946 333 | INFO:root:[Epoch 2 Batch 1840/2349] loss=0.1447, lr=0.0000245, acc=0.946 334 | INFO:root:[Epoch 2 Batch 1860/2349] loss=0.1402, lr=0.0000245, acc=0.946 335 | INFO:root:[Epoch 2 Batch 1880/2349] loss=0.1288, lr=0.0000244, acc=0.946 336 | INFO:root:[Epoch 2 Batch 1900/2349] loss=0.1362, lr=0.0000243, acc=0.946 337 | INFO:root:[Epoch 2 Batch 1920/2349] loss=0.1413, lr=0.0000242, acc=0.946 338 | INFO:root:[Epoch 2 Batch 1940/2349] loss=0.1575, lr=0.0000241, acc=0.946 339 | INFO:root:[Epoch 2 Batch 1960/2349] loss=0.1297, lr=0.0000240, acc=0.946 340 | INFO:root:[Epoch 2 Batch 1980/2349] loss=0.1368, lr=0.0000239, acc=0.946 341 | INFO:root:[Epoch 2 Batch 2000/2349] loss=0.1324, lr=0.0000238, acc=0.946 342 | INFO:root:[Epoch 2 Batch 2020/2349] loss=0.1312, lr=0.0000237, acc=0.946 343 | INFO:root:[Epoch 2 Batch 2040/2349] loss=0.1133, lr=0.0000236, acc=0.946 344 | INFO:root:[Epoch 2 Batch 2060/2349] loss=0.1617, lr=0.0000235, acc=0.946 345 | INFO:root:[Epoch 2 Batch 2080/2349] loss=0.1629, lr=0.0000234, acc=0.946 346 | INFO:root:[Epoch 2 Batch 2100/2349] loss=0.1266, lr=0.0000233, acc=0.946 347 | INFO:root:[Epoch 2 Batch 2120/2349] loss=0.1597, lr=0.0000232, acc=0.946 348 | INFO:root:[Epoch 2 Batch 2140/2349] loss=0.1222, lr=0.0000231, acc=0.946 349 | INFO:root:[Epoch 2 Batch 2160/2349] loss=0.1443, lr=0.0000230, acc=0.946 350 | INFO:root:[Epoch 2 Batch 2180/2349] loss=0.1100, lr=0.0000229, acc=0.946 351 | INFO:root:[Epoch 2 Batch 2200/2349] loss=0.1378, lr=0.0000228, acc=0.946 352 | INFO:root:[Epoch 2 Batch 2220/2349] loss=0.1469, lr=0.0000227, acc=0.946 353 | INFO:root:[Epoch 2 Batch 2240/2349] loss=0.1492, lr=0.0000226, acc=0.946 354 | INFO:root:[Epoch 2 Batch 2260/2349] loss=0.1693, lr=0.0000226, acc=0.946 355 | INFO:root:[Epoch 2 Batch 2280/2349] loss=0.1691, lr=0.0000225, acc=0.946 356 | INFO:root:[Epoch 2 Batch 2300/2349] loss=0.1463, lr=0.0000224, acc=0.946 357 | INFO:root:[Epoch 2 Batch 2320/2349] loss=0.1373, lr=0.0000223, acc=0.946 358 | INFO:root:[Epoch 2 Batch 2340/2349] loss=0.1490, lr=0.0000222, acc=0.946 359 | INFO:root:Validation accuracy: 0.899 360 | INFO:root:Time cost=1114.4s 361 | INFO:root:[Epoch 3 Batch 20/2349] loss=0.0887, lr=0.0000220, acc=0.969 362 | INFO:root:[Epoch 3 Batch 40/2349] loss=0.0966, lr=0.0000219, acc=0.968 363 | INFO:root:[Epoch 3 Batch 60/2349] loss=0.1050, lr=0.0000218, acc=0.967 364 | INFO:root:[Epoch 3 Batch 80/2349] loss=0.0728, lr=0.0000217, acc=0.969 365 | INFO:root:[Epoch 3 Batch 100/2349] loss=0.0899, lr=0.0000217, acc=0.969 366 | INFO:root:[Epoch 3 Batch 120/2349] loss=0.0921, lr=0.0000216, acc=0.969 367 | INFO:root:[Epoch 3 Batch 140/2349] loss=0.0765, lr=0.0000215, acc=0.970 368 | INFO:root:[Epoch 3 Batch 160/2349] loss=0.0799, lr=0.0000214, acc=0.971 369 | INFO:root:[Epoch 3 Batch 180/2349] loss=0.1005, 
lr=0.0000213, acc=0.970 370 | INFO:root:[Epoch 3 Batch 200/2349] loss=0.0803, lr=0.0000212, acc=0.970 371 | INFO:root:[Epoch 3 Batch 220/2349] loss=0.0934, lr=0.0000211, acc=0.969 372 | INFO:root:[Epoch 3 Batch 240/2349] loss=0.0922, lr=0.0000210, acc=0.969 373 | INFO:root:[Epoch 3 Batch 260/2349] loss=0.1014, lr=0.0000209, acc=0.969 374 | INFO:root:[Epoch 3 Batch 280/2349] loss=0.0616, lr=0.0000208, acc=0.969 375 | INFO:root:[Epoch 3 Batch 300/2349] loss=0.0956, lr=0.0000207, acc=0.969 376 | INFO:root:[Epoch 3 Batch 320/2349] loss=0.0963, lr=0.0000206, acc=0.969 377 | INFO:root:[Epoch 3 Batch 340/2349] loss=0.1058, lr=0.0000205, acc=0.968 378 | INFO:root:[Epoch 3 Batch 360/2349] loss=0.1081, lr=0.0000204, acc=0.968 379 | INFO:root:[Epoch 3 Batch 380/2349] loss=0.0721, lr=0.0000203, acc=0.969 380 | INFO:root:[Epoch 3 Batch 400/2349] loss=0.0829, lr=0.0000202, acc=0.969 381 | INFO:root:[Epoch 3 Batch 420/2349] loss=0.1092, lr=0.0000201, acc=0.969 382 | INFO:root:[Epoch 3 Batch 440/2349] loss=0.1071, lr=0.0000200, acc=0.969 383 | INFO:root:[Epoch 3 Batch 460/2349] loss=0.0945, lr=0.0000199, acc=0.969 384 | INFO:root:[Epoch 3 Batch 480/2349] loss=0.1088, lr=0.0000199, acc=0.968 385 | INFO:root:[Epoch 3 Batch 500/2349] loss=0.0943, lr=0.0000198, acc=0.968 386 | INFO:root:[Epoch 3 Batch 520/2349] loss=0.0869, lr=0.0000197, acc=0.968 387 | INFO:root:[Epoch 3 Batch 540/2349] loss=0.1120, lr=0.0000196, acc=0.968 388 | INFO:root:[Epoch 3 Batch 560/2349] loss=0.0640, lr=0.0000195, acc=0.968 389 | INFO:root:[Epoch 3 Batch 580/2349] loss=0.0963, lr=0.0000194, acc=0.968 390 | INFO:root:[Epoch 3 Batch 600/2349] loss=0.0680, lr=0.0000193, acc=0.969 391 | INFO:root:[Epoch 3 Batch 620/2349] loss=0.0771, lr=0.0000192, acc=0.969 392 | INFO:root:[Epoch 3 Batch 640/2349] loss=0.0962, lr=0.0000191, acc=0.969 393 | INFO:root:[Epoch 3 Batch 660/2349] loss=0.0929, lr=0.0000190, acc=0.969 394 | INFO:root:[Epoch 3 Batch 680/2349] loss=0.1006, lr=0.0000189, acc=0.969 395 | INFO:root:[Epoch 3 Batch 700/2349] loss=0.0941, lr=0.0000188, acc=0.968 396 | INFO:root:[Epoch 3 Batch 720/2349] loss=0.0807, lr=0.0000187, acc=0.968 397 | INFO:root:[Epoch 3 Batch 740/2349] loss=0.0820, lr=0.0000186, acc=0.969 398 | INFO:root:[Epoch 3 Batch 760/2349] loss=0.0982, lr=0.0000185, acc=0.969 399 | INFO:root:[Epoch 3 Batch 780/2349] loss=0.1066, lr=0.0000184, acc=0.969 400 | INFO:root:[Epoch 3 Batch 800/2349] loss=0.0931, lr=0.0000183, acc=0.969 401 | INFO:root:[Epoch 3 Batch 820/2349] loss=0.0988, lr=0.0000182, acc=0.969 402 | INFO:root:[Epoch 3 Batch 840/2349] loss=0.1155, lr=0.0000181, acc=0.968 403 | INFO:root:[Epoch 3 Batch 860/2349] loss=0.0908, lr=0.0000181, acc=0.968 404 | INFO:root:[Epoch 3 Batch 880/2349] loss=0.1082, lr=0.0000180, acc=0.968 405 | INFO:root:[Epoch 3 Batch 900/2349] loss=0.1005, lr=0.0000179, acc=0.968 406 | INFO:root:[Epoch 3 Batch 920/2349] loss=0.0837, lr=0.0000178, acc=0.968 407 | INFO:root:[Epoch 3 Batch 940/2349] loss=0.0863, lr=0.0000177, acc=0.968 408 | INFO:root:[Epoch 3 Batch 960/2349] loss=0.1011, lr=0.0000176, acc=0.968 409 | INFO:root:[Epoch 3 Batch 980/2349] loss=0.0897, lr=0.0000175, acc=0.968 410 | INFO:root:[Epoch 3 Batch 1000/2349] loss=0.0892, lr=0.0000174, acc=0.968 411 | INFO:root:[Epoch 3 Batch 1020/2349] loss=0.0952, lr=0.0000173, acc=0.968 412 | INFO:root:[Epoch 3 Batch 1040/2349] loss=0.0852, lr=0.0000172, acc=0.968 413 | INFO:root:[Epoch 3 Batch 1060/2349] loss=0.0666, lr=0.0000171, acc=0.968 414 | INFO:root:[Epoch 3 Batch 1080/2349] loss=0.0968, lr=0.0000170, acc=0.968 415 | 
INFO:root:[Epoch 3 Batch 1100/2349] loss=0.0801, lr=0.0000169, acc=0.968 416 | INFO:root:[Epoch 3 Batch 1120/2349] loss=0.0832, lr=0.0000168, acc=0.968 417 | INFO:root:[Epoch 3 Batch 1140/2349] loss=0.0931, lr=0.0000167, acc=0.968 418 | INFO:root:[Epoch 3 Batch 1160/2349] loss=0.0765, lr=0.0000166, acc=0.968 419 | INFO:root:[Epoch 3 Batch 1180/2349] loss=0.1259, lr=0.0000165, acc=0.968 420 | INFO:root:[Epoch 3 Batch 1200/2349] loss=0.1182, lr=0.0000164, acc=0.968 421 | INFO:root:[Epoch 3 Batch 1220/2349] loss=0.0642, lr=0.0000163, acc=0.968 422 | INFO:root:[Epoch 3 Batch 1240/2349] loss=0.0974, lr=0.0000162, acc=0.968 423 | INFO:root:[Epoch 3 Batch 1260/2349] loss=0.0918, lr=0.0000162, acc=0.968 424 | INFO:root:[Epoch 3 Batch 1280/2349] loss=0.0742, lr=0.0000161, acc=0.968 425 | INFO:root:[Epoch 3 Batch 1300/2349] loss=0.0882, lr=0.0000160, acc=0.968 426 | INFO:root:[Epoch 3 Batch 1320/2349] loss=0.0876, lr=0.0000159, acc=0.968 427 | INFO:root:[Epoch 3 Batch 1340/2349] loss=0.0797, lr=0.0000158, acc=0.968 428 | INFO:root:[Epoch 3 Batch 1360/2349] loss=0.1009, lr=0.0000157, acc=0.968 429 | INFO:root:[Epoch 3 Batch 1380/2349] loss=0.0946, lr=0.0000156, acc=0.968 430 | INFO:root:[Epoch 3 Batch 1400/2349] loss=0.1064, lr=0.0000155, acc=0.968 431 | INFO:root:[Epoch 3 Batch 1420/2349] loss=0.0820, lr=0.0000154, acc=0.968 432 | INFO:root:[Epoch 3 Batch 1440/2349] loss=0.0910, lr=0.0000153, acc=0.968 433 | INFO:root:[Epoch 3 Batch 1460/2349] loss=0.0943, lr=0.0000152, acc=0.968 434 | INFO:root:[Epoch 3 Batch 1480/2349] loss=0.0707, lr=0.0000151, acc=0.968 435 | INFO:root:[Epoch 3 Batch 1500/2349] loss=0.0542, lr=0.0000150, acc=0.969 436 | INFO:root:[Epoch 3 Batch 1520/2349] loss=0.0643, lr=0.0000149, acc=0.969 437 | INFO:root:[Epoch 3 Batch 1540/2349] loss=0.0903, lr=0.0000148, acc=0.969 438 | INFO:root:[Epoch 3 Batch 1560/2349] loss=0.1009, lr=0.0000147, acc=0.969 439 | INFO:root:[Epoch 3 Batch 1580/2349] loss=0.0789, lr=0.0000146, acc=0.969 440 | INFO:root:[Epoch 3 Batch 1600/2349] loss=0.0903, lr=0.0000145, acc=0.969 441 | INFO:root:[Epoch 3 Batch 1620/2349] loss=0.0920, lr=0.0000144, acc=0.969 442 | INFO:root:[Epoch 3 Batch 1640/2349] loss=0.1110, lr=0.0000144, acc=0.969 443 | INFO:root:[Epoch 3 Batch 1660/2349] loss=0.0823, lr=0.0000143, acc=0.969 444 | INFO:root:[Epoch 3 Batch 1680/2349] loss=0.0873, lr=0.0000142, acc=0.969 445 | INFO:root:[Epoch 3 Batch 1700/2349] loss=0.0887, lr=0.0000141, acc=0.969 446 | INFO:root:[Epoch 3 Batch 1720/2349] loss=0.0929, lr=0.0000140, acc=0.969 447 | INFO:root:[Epoch 3 Batch 1740/2349] loss=0.1094, lr=0.0000139, acc=0.969 448 | INFO:root:[Epoch 3 Batch 1760/2349] loss=0.0854, lr=0.0000138, acc=0.969 449 | INFO:root:[Epoch 3 Batch 1780/2349] loss=0.1062, lr=0.0000137, acc=0.969 450 | INFO:root:[Epoch 3 Batch 1800/2349] loss=0.0951, lr=0.0000136, acc=0.969 451 | INFO:root:[Epoch 3 Batch 1820/2349] loss=0.0960, lr=0.0000135, acc=0.969 452 | INFO:root:[Epoch 3 Batch 1840/2349] loss=0.0848, lr=0.0000134, acc=0.969 453 | INFO:root:[Epoch 3 Batch 1860/2349] loss=0.0857, lr=0.0000133, acc=0.969 454 | INFO:root:[Epoch 3 Batch 1880/2349] loss=0.0994, lr=0.0000132, acc=0.969 455 | INFO:root:[Epoch 3 Batch 1900/2349] loss=0.0955, lr=0.0000131, acc=0.969 456 | INFO:root:[Epoch 3 Batch 1920/2349] loss=0.1111, lr=0.0000130, acc=0.968 457 | INFO:root:[Epoch 3 Batch 1940/2349] loss=0.1101, lr=0.0000129, acc=0.968 458 | INFO:root:[Epoch 3 Batch 1960/2349] loss=0.0787, lr=0.0000128, acc=0.968 459 | INFO:root:[Epoch 3 Batch 1980/2349] loss=0.0890, lr=0.0000127, acc=0.968 460 | 
INFO:root:[Epoch 3 Batch 2000/2349] loss=0.1081, lr=0.0000126, acc=0.968 461 | INFO:root:[Epoch 3 Batch 2020/2349] loss=0.0923, lr=0.0000126, acc=0.968 462 | INFO:root:[Epoch 3 Batch 2040/2349] loss=0.0855, lr=0.0000125, acc=0.968 463 | INFO:root:[Epoch 3 Batch 2060/2349] loss=0.0953, lr=0.0000124, acc=0.968 464 | INFO:root:[Epoch 3 Batch 2080/2349] loss=0.0991, lr=0.0000123, acc=0.968 465 | INFO:root:[Epoch 3 Batch 2100/2349] loss=0.0786, lr=0.0000122, acc=0.968 466 | INFO:root:[Epoch 3 Batch 2120/2349] loss=0.0863, lr=0.0000121, acc=0.968 467 | INFO:root:[Epoch 3 Batch 2140/2349] loss=0.0841, lr=0.0000120, acc=0.968 468 | INFO:root:[Epoch 3 Batch 2160/2349] loss=0.0886, lr=0.0000119, acc=0.969 469 | INFO:root:[Epoch 3 Batch 2180/2349] loss=0.0730, lr=0.0000118, acc=0.969 470 | INFO:root:[Epoch 3 Batch 2200/2349] loss=0.0848, lr=0.0000117, acc=0.969 471 | INFO:root:[Epoch 3 Batch 2220/2349] loss=0.0962, lr=0.0000116, acc=0.969 472 | INFO:root:[Epoch 3 Batch 2240/2349] loss=0.0900, lr=0.0000115, acc=0.969 473 | INFO:root:[Epoch 3 Batch 2260/2349] loss=0.1082, lr=0.0000114, acc=0.969 474 | INFO:root:[Epoch 3 Batch 2280/2349] loss=0.0690, lr=0.0000113, acc=0.969 475 | INFO:root:[Epoch 3 Batch 2300/2349] loss=0.1000, lr=0.0000112, acc=0.969 476 | INFO:root:[Epoch 3 Batch 2320/2349] loss=0.0890, lr=0.0000111, acc=0.969 477 | INFO:root:[Epoch 3 Batch 2340/2349] loss=0.0927, lr=0.0000110, acc=0.969 478 | INFO:root:Validation accuracy: 0.899 479 | INFO:root:Time cost=1118.6s 480 | INFO:root:[Epoch 4 Batch 20/2349] loss=0.0582, lr=0.0000109, acc=0.981 481 | INFO:root:[Epoch 4 Batch 40/2349] loss=0.0704, lr=0.0000108, acc=0.980 482 | INFO:root:[Epoch 4 Batch 60/2349] loss=0.0796, lr=0.0000107, acc=0.978 483 | INFO:root:[Epoch 4 Batch 80/2349] loss=0.0481, lr=0.0000106, acc=0.979 484 | INFO:root:[Epoch 4 Batch 100/2349] loss=0.0580, lr=0.0000105, acc=0.979 485 | INFO:root:[Epoch 4 Batch 120/2349] loss=0.0469, lr=0.0000104, acc=0.980 486 | INFO:root:[Epoch 4 Batch 140/2349] loss=0.0580, lr=0.0000103, acc=0.980 487 | INFO:root:[Epoch 4 Batch 160/2349] loss=0.0460, lr=0.0000102, acc=0.980 488 | INFO:root:[Epoch 4 Batch 180/2349] loss=0.0444, lr=0.0000101, acc=0.981 489 | INFO:root:[Epoch 4 Batch 200/2349] loss=0.0424, lr=0.0000100, acc=0.981 490 | INFO:root:[Epoch 4 Batch 220/2349] loss=0.0722, lr=0.0000099, acc=0.981 491 | INFO:root:[Epoch 4 Batch 240/2349] loss=0.0645, lr=0.0000099, acc=0.981 492 | INFO:root:[Epoch 4 Batch 260/2349] loss=0.0715, lr=0.0000098, acc=0.981 493 | INFO:root:[Epoch 4 Batch 280/2349] loss=0.1030, lr=0.0000097, acc=0.980 494 | INFO:root:[Epoch 4 Batch 300/2349] loss=0.0666, lr=0.0000096, acc=0.979 495 | INFO:root:[Epoch 4 Batch 320/2349] loss=0.0529, lr=0.0000095, acc=0.980 496 | INFO:root:[Epoch 4 Batch 340/2349] loss=0.0743, lr=0.0000094, acc=0.980 497 | INFO:root:[Epoch 4 Batch 360/2349] loss=0.0573, lr=0.0000093, acc=0.980 498 | INFO:root:[Epoch 4 Batch 380/2349] loss=0.0356, lr=0.0000092, acc=0.980 499 | INFO:root:[Epoch 4 Batch 400/2349] loss=0.0631, lr=0.0000091, acc=0.980 500 | INFO:root:[Epoch 4 Batch 420/2349] loss=0.0518, lr=0.0000090, acc=0.980 501 | INFO:root:[Epoch 4 Batch 440/2349] loss=0.0762, lr=0.0000089, acc=0.980 502 | INFO:root:[Epoch 4 Batch 460/2349] loss=0.0546, lr=0.0000088, acc=0.980 503 | INFO:root:[Epoch 4 Batch 480/2349] loss=0.0619, lr=0.0000087, acc=0.980 504 | INFO:root:[Epoch 4 Batch 500/2349] loss=0.0361, lr=0.0000086, acc=0.980 505 | INFO:root:[Epoch 4 Batch 520/2349] loss=0.0476, lr=0.0000085, acc=0.980 506 | INFO:root:[Epoch 4 Batch 
540/2349] loss=0.0624, lr=0.0000084, acc=0.980 507 | INFO:root:[Epoch 4 Batch 560/2349] loss=0.0579, lr=0.0000083, acc=0.980 508 | INFO:root:[Epoch 4 Batch 580/2349] loss=0.0560, lr=0.0000082, acc=0.980 509 | INFO:root:[Epoch 4 Batch 600/2349] loss=0.0526, lr=0.0000081, acc=0.981 510 | INFO:root:[Epoch 4 Batch 620/2349] loss=0.0471, lr=0.0000080, acc=0.981 511 | INFO:root:[Epoch 4 Batch 640/2349] loss=0.0764, lr=0.0000080, acc=0.980 512 | INFO:root:[Epoch 4 Batch 660/2349] loss=0.0682, lr=0.0000079, acc=0.980 513 | INFO:root:[Epoch 4 Batch 680/2349] loss=0.0576, lr=0.0000078, acc=0.980 514 | INFO:root:[Epoch 4 Batch 700/2349] loss=0.0553, lr=0.0000077, acc=0.980 515 | INFO:root:[Epoch 4 Batch 720/2349] loss=0.0768, lr=0.0000076, acc=0.980 516 | INFO:root:[Epoch 4 Batch 740/2349] loss=0.0625, lr=0.0000075, acc=0.980 517 | INFO:root:[Epoch 4 Batch 760/2349] loss=0.0521, lr=0.0000074, acc=0.980 518 | INFO:root:[Epoch 4 Batch 780/2349] loss=0.0657, lr=0.0000073, acc=0.980 519 | INFO:root:[Epoch 4 Batch 800/2349] loss=0.0749, lr=0.0000072, acc=0.980 520 | INFO:root:[Epoch 4 Batch 820/2349] loss=0.0573, lr=0.0000071, acc=0.980 521 | INFO:root:[Epoch 4 Batch 840/2349] loss=0.0713, lr=0.0000070, acc=0.980 522 | INFO:root:[Epoch 4 Batch 860/2349] loss=0.0604, lr=0.0000069, acc=0.980 523 | INFO:root:[Epoch 4 Batch 880/2349] loss=0.0449, lr=0.0000068, acc=0.980 524 | INFO:root:[Epoch 4 Batch 900/2349] loss=0.0719, lr=0.0000067, acc=0.980 525 | INFO:root:[Epoch 4 Batch 920/2349] loss=0.0693, lr=0.0000066, acc=0.980 526 | INFO:root:[Epoch 4 Batch 940/2349] loss=0.0678, lr=0.0000065, acc=0.980 527 | INFO:root:[Epoch 4 Batch 960/2349] loss=0.0485, lr=0.0000064, acc=0.980 528 | INFO:root:[Epoch 4 Batch 980/2349] loss=0.0578, lr=0.0000063, acc=0.980 529 | INFO:root:[Epoch 4 Batch 1000/2349] loss=0.0494, lr=0.0000062, acc=0.980 530 | INFO:root:[Epoch 4 Batch 1020/2349] loss=0.0510, lr=0.0000062, acc=0.980 531 | INFO:root:[Epoch 4 Batch 1040/2349] loss=0.0660, lr=0.0000061, acc=0.980 532 | INFO:root:[Epoch 4 Batch 1060/2349] loss=0.0576, lr=0.0000060, acc=0.980 533 | INFO:root:[Epoch 4 Batch 1080/2349] loss=0.0457, lr=0.0000059, acc=0.980 534 | INFO:root:[Epoch 4 Batch 1100/2349] loss=0.0493, lr=0.0000058, acc=0.980 535 | INFO:root:[Epoch 4 Batch 1120/2349] loss=0.0540, lr=0.0000057, acc=0.980 536 | INFO:root:[Epoch 4 Batch 1140/2349] loss=0.0622, lr=0.0000056, acc=0.980 537 | INFO:root:[Epoch 4 Batch 1160/2349] loss=0.0684, lr=0.0000055, acc=0.980 538 | INFO:root:[Epoch 4 Batch 1180/2349] loss=0.0520, lr=0.0000054, acc=0.980 539 | INFO:root:[Epoch 4 Batch 1200/2349] loss=0.0725, lr=0.0000053, acc=0.980 540 | INFO:root:[Epoch 4 Batch 1220/2349] loss=0.0824, lr=0.0000052, acc=0.980 541 | INFO:root:[Epoch 4 Batch 1240/2349] loss=0.0708, lr=0.0000051, acc=0.980 542 | INFO:root:[Epoch 4 Batch 1260/2349] loss=0.0545, lr=0.0000050, acc=0.980 543 | INFO:root:[Epoch 4 Batch 1280/2349] loss=0.0685, lr=0.0000049, acc=0.980 544 | INFO:root:[Epoch 4 Batch 1300/2349] loss=0.0500, lr=0.0000048, acc=0.980 545 | INFO:root:[Epoch 4 Batch 1320/2349] loss=0.0633, lr=0.0000047, acc=0.980 546 | INFO:root:[Epoch 4 Batch 1340/2349] loss=0.0526, lr=0.0000046, acc=0.980 547 | INFO:root:[Epoch 4 Batch 1360/2349] loss=0.0570, lr=0.0000045, acc=0.980 548 | INFO:root:[Epoch 4 Batch 1380/2349] loss=0.0614, lr=0.0000044, acc=0.980 549 | INFO:root:[Epoch 4 Batch 1400/2349] loss=0.0564, lr=0.0000044, acc=0.980 550 | INFO:root:[Epoch 4 Batch 1420/2349] loss=0.0735, lr=0.0000043, acc=0.980 551 | INFO:root:[Epoch 4 Batch 1440/2349] 
loss=0.0646, lr=0.0000042, acc=0.980 552 | INFO:root:[Epoch 4 Batch 1460/2349] loss=0.0708, lr=0.0000041, acc=0.980 553 | INFO:root:[Epoch 4 Batch 1480/2349] loss=0.0508, lr=0.0000040, acc=0.980 554 | INFO:root:[Epoch 4 Batch 1500/2349] loss=0.0656, lr=0.0000039, acc=0.980 555 | INFO:root:[Epoch 4 Batch 1520/2349] loss=0.0493, lr=0.0000038, acc=0.980 556 | INFO:root:[Epoch 4 Batch 1540/2349] loss=0.0468, lr=0.0000037, acc=0.980 557 | INFO:root:[Epoch 4 Batch 1560/2349] loss=0.0316, lr=0.0000036, acc=0.980 558 | INFO:root:[Epoch 4 Batch 1580/2349] loss=0.0508, lr=0.0000035, acc=0.980 559 | INFO:root:[Epoch 4 Batch 1600/2349] loss=0.0540, lr=0.0000034, acc=0.980 560 | INFO:root:[Epoch 4 Batch 1620/2349] loss=0.0823, lr=0.0000033, acc=0.980 561 | INFO:root:[Epoch 4 Batch 1640/2349] loss=0.0550, lr=0.0000032, acc=0.980 562 | INFO:root:[Epoch 4 Batch 1660/2349] loss=0.0571, lr=0.0000031, acc=0.980 563 | INFO:root:[Epoch 4 Batch 1680/2349] loss=0.0819, lr=0.0000030, acc=0.980 564 | INFO:root:[Epoch 4 Batch 1700/2349] loss=0.0654, lr=0.0000029, acc=0.980 565 | INFO:root:[Epoch 4 Batch 1720/2349] loss=0.0933, lr=0.0000028, acc=0.980 566 | INFO:root:[Epoch 4 Batch 1740/2349] loss=0.0707, lr=0.0000027, acc=0.980 567 | INFO:root:[Epoch 4 Batch 1760/2349] loss=0.0749, lr=0.0000026, acc=0.980 568 | INFO:root:[Epoch 4 Batch 1780/2349] loss=0.0636, lr=0.0000026, acc=0.980 569 | INFO:root:[Epoch 4 Batch 1800/2349] loss=0.0728, lr=0.0000025, acc=0.980 570 | INFO:root:[Epoch 4 Batch 1820/2349] loss=0.0666, lr=0.0000024, acc=0.980 571 | INFO:root:[Epoch 4 Batch 1840/2349] loss=0.0753, lr=0.0000023, acc=0.980 572 | INFO:root:[Epoch 4 Batch 1860/2349] loss=0.0628, lr=0.0000022, acc=0.980 573 | INFO:root:[Epoch 4 Batch 1880/2349] loss=0.0595, lr=0.0000021, acc=0.980 574 | INFO:root:[Epoch 4 Batch 1900/2349] loss=0.0509, lr=0.0000020, acc=0.980 575 | INFO:root:[Epoch 4 Batch 1920/2349] loss=0.0504, lr=0.0000019, acc=0.980 576 | INFO:root:[Epoch 4 Batch 1940/2349] loss=0.0583, lr=0.0000018, acc=0.980 577 | INFO:root:[Epoch 4 Batch 1960/2349] loss=0.0448, lr=0.0000017, acc=0.980 578 | INFO:root:[Epoch 4 Batch 1980/2349] loss=0.0556, lr=0.0000016, acc=0.980 579 | INFO:root:[Epoch 4 Batch 2000/2349] loss=0.0431, lr=0.0000015, acc=0.980 580 | INFO:root:[Epoch 4 Batch 2020/2349] loss=0.0563, lr=0.0000014, acc=0.980 581 | INFO:root:[Epoch 4 Batch 2040/2349] loss=0.0556, lr=0.0000013, acc=0.980 582 | INFO:root:[Epoch 4 Batch 2060/2349] loss=0.0733, lr=0.0000012, acc=0.980 583 | INFO:root:[Epoch 4 Batch 2080/2349] loss=0.0640, lr=0.0000011, acc=0.980 584 | INFO:root:[Epoch 4 Batch 2100/2349] loss=0.0629, lr=0.0000010, acc=0.980 585 | INFO:root:[Epoch 4 Batch 2120/2349] loss=0.0422, lr=0.0000009, acc=0.980 586 | INFO:root:[Epoch 4 Batch 2140/2349] loss=0.0708, lr=0.0000008, acc=0.980 587 | INFO:root:[Epoch 4 Batch 2160/2349] loss=0.0477, lr=0.0000007, acc=0.980 588 | INFO:root:[Epoch 4 Batch 2180/2349] loss=0.0877, lr=0.0000007, acc=0.980 589 | INFO:root:[Epoch 4 Batch 2200/2349] loss=0.0550, lr=0.0000006, acc=0.980 590 | INFO:root:[Epoch 4 Batch 2220/2349] loss=0.0581, lr=0.0000005, acc=0.980 591 | INFO:root:[Epoch 4 Batch 2240/2349] loss=0.0392, lr=0.0000004, acc=0.980 592 | INFO:root:[Epoch 4 Batch 2260/2349] loss=0.0531, lr=0.0000003, acc=0.980 593 | INFO:root:[Epoch 4 Batch 2280/2349] loss=0.0452, lr=0.0000002, acc=0.980 594 | INFO:root:[Epoch 4 Batch 2300/2349] loss=0.0864, lr=0.0000001, acc=0.980 595 | INFO:root:[Epoch 4 Batch 2320/2349] loss=0.0621, lr=-0.0000000, acc=0.980 596 | INFO:root:[Epoch 4 Batch 
2340/2349] loss=0.0537, lr=-0.0000001, acc=0.980 597 | INFO:root:Validation accuracy: 0.901 598 | INFO:root:Time cost=1113.0s 599 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 <=1.15.18 2 | gluonnlp >= 0.6.0, <=0.10.0 3 | mxnet >= 1.4.0, <=1.7.0.post2 4 | onnxruntime == 1.8.0, <=1.8.0 5 | sentencepiece >= 0.1.6, <=0.1.96 6 | torch >= 1.7.0, <=1.10.1 7 | transformers >= 4.8.1, <=4.8.1 8 | -------------------------------------------------------------------------------- /scripts/NSMC/naver_review_classifications_gluon_kobert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# KoBERT finetuning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/", 16 | "height": 188 17 | }, 18 | "colab_type": "code", 19 | "id": "-sx87sgK7_pz", 20 | "outputId": "9f1c67bf-7c67-45d7-88b7-4bbc7384a29b" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install ipywidgets # for vscode\n", 25 | "!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "colab": {}, 33 | "colab_type": "code", 34 | "id": "5mTNl7BKT2Fx" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import numpy as np\n", 39 | "from tqdm.notebook import tqdm\n", 40 | "\n", 41 | "from mxnet.gluon import nn\n", 42 | "from mxnet import gluon\n", 43 | "import mxnet as mx\n", 44 | "import gluonnlp as nlp\n", 45 | "\n", 46 | "from kobert import get_mxnet_kobert_model\n", 47 | "from kobert import get_tokenizer" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "colab_type": "text", 54 | "id": "Cc-zco-ST2F_" 55 | }, 56 | "source": [ 57 | "### Loading KoBERT" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# CPU\n", 67 | "ctx = mx.cpu()\n", 68 | "\n", 69 | "# GPU\n", 70 | "# ctx = mx.gpu()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "colab": { 78 | "base_uri": "https://localhost:8080/", 79 | "height": 55 80 | }, 81 | "colab_type": "code", 82 | "id": "wI841Zb38XOn", 83 | "outputId": "f9794e99-c913-4ca0-b8fd-6e15ce9d74c7" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "bert_base, vocab = get_mxnet_kobert_model(use_decoder=False, use_classifier=False, ctx=ctx, cachedir=\".cache\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "colab": { 95 | "base_uri": "https://localhost:8080/", 96 | "height": 36 97 | }, 98 | "colab_type": "code", 99 | "id": "NijpWe8J8isZ", 100 | "outputId": "d03d1cc1-327f-4b44-ed66-f02126687688" 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "tokenizer = get_tokenizer()\n", 105 | "tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/", 114 | "height": 93 115 | }, 116 | "colab_type": "code", 117 | "id": "i69AUj9gT2Gk", 118 | "outputId": "050d42de-ac07-4c04-9f14-b5ea411df008" 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "ds = 
gluon.data.SimpleDataset([['나 보기가 역겨워', '김소월']])\n", 123 | "trans = nlp.data.BERTSentenceTransform(tok, max_seq_length=10)\n", 124 | "\n", 125 | "list(ds.transform(trans))" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "### Loading Data" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/", 141 | "height": 796 142 | }, 143 | "colab_type": "code", 144 | "id": "4qy9g_UMVtdj", 145 | "outputId": "c4546df4-ce5b-4484-e245-309c90fed014" 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "!wget -O .cache/ratings_train.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_train.txt\n", 150 | "!wget -O .cache/ratings_test.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_test.txt" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": { 157 | "colab": {}, 158 | "colab_type": "code", 159 | "id": "4LfCTweqT2Gt" 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "dataset_train = nlp.data.TSVDataset(\".cache/ratings_train.txt\", field_indices=[1,2], num_discard_samples=1)\n", 164 | "dataset_test = nlp.data.TSVDataset(\".cache/ratings_test.txt\", field_indices=[1,2], num_discard_samples=1)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "colab": {}, 172 | "colab_type": "code", 173 | "id": "pt0raV8uT2G2" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "class BERTDataset(mx.gluon.data.Dataset):\n", 178 | " def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,\n", 179 | " pad, pair):\n", 180 | " transform = nlp.data.BERTSentenceTransform(\n", 181 | " bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)\n", 182 | " sent_dataset = gluon.data.SimpleDataset([[\n", 183 | " i[sent_idx],\n", 184 | " ] for i in dataset])\n", 185 | " self.sentences = sent_dataset.transform(transform)\n", 186 | " self.labels = gluon.data.SimpleDataset(\n", 187 | " [np.array(np.int32(i[label_idx])) for i in dataset])\n", 188 | "\n", 189 | " def __getitem__(self, i):\n", 190 | " return (self.sentences[i] + (self.labels[i], ))\n", 191 | "\n", 192 | " def __len__(self):\n", 193 | " return (len(self.labels))\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "colab": {}, 201 | "colab_type": "code", 202 | "id": "vtk-8pQST2G9" 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "max_len = 128" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": { 213 | "colab": {}, 214 | "colab_type": "code", 215 | "id": "_K_BLZP_T2HF" 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)\n", 220 | "data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": { 227 | "colab": {}, 228 | "colab_type": "code", 229 | "id": "rhaw0H4ST2HM" 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "class BERTClassifier(nn.Block):\n", 234 | " def __init__(self,\n", 235 | " bert,\n", 236 | " num_classes=2,\n", 237 | " dropout=None,\n", 238 | " prefix=None,\n", 239 | " params=None):\n", 240 | " super(BERTClassifier, self).__init__(prefix=prefix, params=params)\n", 241 | " self.bert = bert\n", 242 | " with 
self.name_scope():\n", 243 | " self.classifier = nn.HybridSequential(prefix=prefix)\n", 244 | " if dropout:\n", 245 | " self.classifier.add(nn.Dropout(rate=dropout))\n", 246 | " self.classifier.add(nn.Dense(units=num_classes))\n", 247 | "\n", 248 | " def forward(self, inputs, token_types, valid_length=None):\n", 249 | " _, pooler = self.bert(inputs, token_types, valid_length)\n", 250 | " return self.classifier(pooler)\n", 251 | " " 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": { 258 | "colab": {}, 259 | "colab_type": "code", 260 | "id": "Y00BOPwST2HX" 261 | }, 262 | "outputs": [], 263 | "source": [ 264 | "model = BERTClassifier(bert_base, num_classes=2, dropout=0.1)\n", 265 | "# 분류 레이어만 초기화 한다. \n", 266 | "model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)\n", 267 | "model.hybridize()\n", 268 | "\n", 269 | "# softmax cross entropy loss for classification\n", 270 | "loss_function = gluon.loss.SoftmaxCELoss()\n", 271 | "\n", 272 | "metric = mx.metric.Accuracy()" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": { 279 | "colab": {}, 280 | "colab_type": "code", 281 | "id": "A2dLhnHkT2Hf" 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "batch_size = 32\n", 286 | "lr = 5e-5\n", 287 | "\n", 288 | "train_dataloader = mx.gluon.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)\n", 289 | "test_dataloader = mx.gluon.data.DataLoader(data_test, batch_size=int(batch_size/2), num_workers=5)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": { 296 | "colab": {}, 297 | "colab_type": "code", 298 | "id": "ESo76UH-T2Hr" 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "trainer = gluon.Trainer(model.collect_params(), 'bertadam',\n", 303 | " {'learning_rate': lr, 'epsilon': 1e-9, 'wd':0.01})\n", 304 | "\n", 305 | "log_interval = 4\n", 306 | "num_epochs = 5" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": { 313 | "colab": {}, 314 | "colab_type": "code", 315 | "id": "wspMBDOAT2H0" 316 | }, 317 | "outputs": [], 318 | "source": [ 319 | "# LayerNorm과 Bias에는 Weight Decay를 적용하지 않는다. 
\n", 320 | "for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():\n", 321 | " v.wd_mult = 0.0\n", 322 | "params = [\n", 323 | " p for p in model.collect_params().values() if p.grad_req != 'null'\n", 324 | "]\n" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": { 331 | "colab": {}, 332 | "colab_type": "code", 333 | "id": "NCR6AMKHT2H6" 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "def evaluate_accuracy(model, data_iter, ctx=ctx):\n", 338 | " acc = mx.metric.Accuracy()\n", 339 | " i = 0\n", 340 | " for i, (t,v,s, label) in enumerate(data_iter):\n", 341 | " token_ids = t.as_in_context(ctx)\n", 342 | " valid_length = v.as_in_context(ctx)\n", 343 | " segment_ids = s.as_in_context(ctx)\n", 344 | " label = label.as_in_context(ctx)\n", 345 | " output = model(token_ids, segment_ids, valid_length.astype('float32'))\n", 346 | " acc.update(preds=output, labels=label)\n", 347 | " if i > 1000:\n", 348 | " break\n", 349 | " i += 1\n", 350 | " return(acc.get()[1])" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": { 357 | "colab": {}, 358 | "colab_type": "code", 359 | "id": "SkcW6GyeT2IA" 360 | }, 361 | "outputs": [], 362 | "source": [ 363 | "#learning rate warmup을 위한 준비 \n", 364 | "accumulate = 4\n", 365 | "step_size = batch_size * accumulate if accumulate else batch_size\n", 366 | "num_train_examples = len(data_train)\n", 367 | "num_train_steps = int(num_train_examples / step_size * num_epochs)\n", 368 | "warmup_ratio = 0.1\n", 369 | "num_warmup_steps = int(num_train_steps * warmup_ratio)\n", 370 | "step_num = 0\n", 371 | "all_model_params = model.collect_params()" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "colab": {}, 379 | "colab_type": "code", 380 | "id": "Yf_rpZTq6uES" 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "# Set grad_req if gradient accumulation is required\n", 385 | "if accumulate and accumulate > 1:\n", 386 | " for p in params:\n", 387 | " p.grad_req = 'add'" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": { 394 | "colab": { 395 | "base_uri": "https://localhost:8080/", 396 | "height": 984 397 | }, 398 | "colab_type": "code", 399 | "id": "0mJ3Pw_VT2IH", 400 | "outputId": "abc9ecfb-8674-445f-cd5d-3fcd57252f39" 401 | }, 402 | "outputs": [], 403 | "source": [ 404 | "for epoch_id in range(num_epochs):\n", 405 | " metric.reset()\n", 406 | " step_loss = 0\n", 407 | " for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):\n", 408 | " if step_num < num_warmup_steps:\n", 409 | " new_lr = lr * step_num / num_warmup_steps\n", 410 | " else:\n", 411 | " non_warmup_steps = step_num - num_warmup_steps\n", 412 | " offset = non_warmup_steps / (num_train_steps - num_warmup_steps)\n", 413 | " new_lr = lr - offset * lr\n", 414 | " trainer.set_learning_rate(new_lr)\n", 415 | " with mx.autograd.record():\n", 416 | " # load data to GPU\n", 417 | " token_ids = token_ids.as_in_context(ctx)\n", 418 | " valid_length = valid_length.as_in_context(ctx)\n", 419 | " segment_ids = segment_ids.as_in_context(ctx)\n", 420 | " label = label.as_in_context(ctx)\n", 421 | "\n", 422 | " # forward computation\n", 423 | " out = model(token_ids, segment_ids, valid_length.astype('float32'))\n", 424 | " ls = loss_function(out, label).mean()\n", 425 | "\n", 426 | " # backward computation\n", 427 | " ls.backward()\n", 
428 | " if not accumulate or (batch_id + 1) % accumulate == 0:\n", 429 | " trainer.allreduce_grads()\n", 430 | " nlp.utils.clip_grad_global_norm(params, 1)\n", 431 | " trainer.update(accumulate if accumulate else 1)\n", 432 | " step_num += 1\n", 433 | " if accumulate and accumulate > 1:\n", 434 | " # set grad to zero for gradient accumulation\n", 435 | " all_model_params.zero_grad()\n", 436 | "\n", 437 | " step_loss += ls.asscalar()\n", 438 | " metric.update([label], [out])\n", 439 | " if (batch_id + 1) % (50) == 0:\n", 440 | " print('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.10f}, acc={:.3f}'\n", 441 | " .format(epoch_id + 1, batch_id + 1, len(train_dataloader),\n", 442 | " step_loss / log_interval,\n", 443 | " trainer.learning_rate, metric.get()[1]))\n", 444 | " step_loss = 0\n", 445 | " test_acc = evaluate_accuracy(model, test_dataloader, ctx)\n", 446 | " print('Test Acc : {}'.format(test_acc))" 447 | ] 448 | } 449 | ], 450 | "metadata": { 451 | "accelerator": "GPU", 452 | "colab": { 453 | "collapsed_sections": [], 454 | "name": "naver_review_classifications_gluon_bert.ipynb의 사본", 455 | "provenance": [] 456 | }, 457 | "kernelspec": { 458 | "display_name": "Python 3", 459 | "language": "python", 460 | "name": "python3" 461 | }, 462 | "language_info": { 463 | "codemirror_mode": { 464 | "name": "ipython", 465 | "version": 3 466 | }, 467 | "file_extension": ".py", 468 | "mimetype": "text/x-python", 469 | "name": "python", 470 | "nbconvert_exporter": "python", 471 | "pygments_lexer": "ipython3", 472 | "version": "3.7.0" 473 | }, 474 | "toc": { 475 | "nav_menu": {}, 476 | "number_sections": true, 477 | "sideBar": true, 478 | "skip_h1_title": false, 479 | "toc_cell": false, 480 | "toc_position": {}, 481 | "toc_section_display": "block", 482 | "toc_window_display": false 483 | } 484 | }, 485 | "nbformat": 4, 486 | "nbformat_minor": 1 487 | } 488 | -------------------------------------------------------------------------------- /scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# KoBERT finetuning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "!pip install ipywidgets # for vscode\n", 17 | "!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import torch\n", 27 | "from torch import nn\n", 28 | "import torch.nn.functional as F\n", 29 | "import torch.optim as optim\n", 30 | "from torch.utils.data import Dataset, DataLoader\n", 31 | "import gluonnlp as nlp\n", 32 | "import numpy as np\n", 33 | "from tqdm.notebook import tqdm" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from kobert import get_tokenizer\n", 43 | "from kobert import get_pytorch_kobert_model" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from transformers import AdamW\n", 53 | "from transformers.optimization import get_cosine_schedule_with_warmup" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "## CPU\n", 63 | "device 
= torch.device(\"cpu\")\n", 64 | "\n", 65 | "## GPU\n", 66 | "# device = torch.device(\"cuda:0\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "bertmodel, vocab = get_pytorch_kobert_model(cachedir=\".cache\")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "!wget -O .cache/ratings_train.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_train.txt\n", 85 | "!wget -O .cache/ratings_test.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_test.txt" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "dataset_train = nlp.data.TSVDataset(\".cache/ratings_train.txt\", field_indices=[1,2], num_discard_samples=1)\n", 95 | "dataset_test = nlp.data.TSVDataset(\".cache/ratings_test.txt\", field_indices=[1,2], num_discard_samples=1)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "tokenizer = get_tokenizer()\n", 105 | "tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "class BERTDataset(Dataset):\n", 115 | " def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,\n", 116 | " pad, pair):\n", 117 | " transform = nlp.data.BERTSentenceTransform(\n", 118 | " bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)\n", 119 | "\n", 120 | " self.sentences = [transform([i[sent_idx]]) for i in dataset]\n", 121 | " self.labels = [np.int32(i[label_idx]) for i in dataset]\n", 122 | "\n", 123 | " def __getitem__(self, i):\n", 124 | " return (self.sentences[i] + (self.labels[i], ))\n", 125 | "\n", 126 | " def __len__(self):\n", 127 | " return (len(self.labels))\n" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "## Setting parameters\n", 137 | "max_len = 64\n", 138 | "batch_size = 64\n", 139 | "warmup_ratio = 0.1\n", 140 | "num_epochs = 5\n", 141 | "max_grad_norm = 1\n", 142 | "log_interval = 200\n", 143 | "learning_rate = 5e-5" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)\n", 153 | "data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)\n", 163 | "test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "class BERTClassifier(nn.Module):\n", 173 | " def __init__(self,\n", 174 | " bert,\n", 175 | " hidden_size = 768,\n", 176 | " num_classes=2,\n", 177 | " dr_rate=None,\n", 178 | " params=None):\n", 179 | " super(BERTClassifier, self).__init__()\n", 180 | " self.bert = bert\n", 181 | " self.dr_rate 
= dr_rate\n", 182 | " \n", 183 | " self.classifier = nn.Linear(hidden_size , num_classes)\n", 184 | " if dr_rate:\n", 185 | " self.dropout = nn.Dropout(p=dr_rate)\n", 186 | " \n", 187 | " def gen_attention_mask(self, token_ids, valid_length):\n", 188 | " attention_mask = torch.zeros_like(token_ids)\n", 189 | " for i, v in enumerate(valid_length):\n", 190 | " attention_mask[i][:v] = 1\n", 191 | " return attention_mask.float()\n", 192 | "\n", 193 | " def forward(self, token_ids, valid_length, segment_ids):\n", 194 | " attention_mask = self.gen_attention_mask(token_ids, valid_length)\n", 195 | " \n", 196 | " _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))\n", 197 | " if self.dr_rate:\n", 198 | " out = self.dropout(pooler)\n", 199 | " else:\n", 200 | " out = pooler\n", 201 | " return self.classifier(out)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "# Prepare optimizer and schedule (linear warmup and decay)\n", 220 | "no_decay = ['bias', 'LayerNorm.weight']\n", 221 | "optimizer_grouped_parameters = [\n", 222 | " {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n", 223 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 224 | "]" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)\n", 234 | "loss_fn = nn.CrossEntropyLoss()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "t_total = len(train_dataloader) * num_epochs\n", 244 | "warmup_step = int(t_total * warmup_ratio)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "def calc_accuracy(X,Y):\n", 263 | " max_vals, max_indices = torch.max(X, 1)\n", 264 | " train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]\n", 265 | " return train_acc" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "for e in range(num_epochs):\n", 275 | " train_acc = 0.0\n", 276 | " test_acc = 0.0\n", 277 | " model.train()\n", 278 | " for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):\n", 279 | " optimizer.zero_grad()\n", 280 | " token_ids = token_ids.long().to(device)\n", 281 | " segment_ids = segment_ids.long().to(device)\n", 282 | " valid_length= valid_length\n", 283 | " label = label.long().to(device)\n", 284 | " out = model(token_ids, valid_length, segment_ids)\n", 285 | " loss = loss_fn(out, label)\n", 286 | " 
loss.backward()\n", 287 | " torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", 288 | " optimizer.step()\n", 289 | " scheduler.step() # Update learning rate schedule\n", 290 | " train_acc += calc_accuracy(out, label)\n", 291 | " if batch_id % log_interval == 0:\n", 292 | " print(\"epoch {} batch id {} loss {} train acc {}\".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))\n", 293 | " print(\"epoch {} train acc {}\".format(e+1, train_acc / (batch_id+1)))\n", 294 | " model.eval()\n", 295 | " for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):\n", 296 | " token_ids = token_ids.long().to(device)\n", 297 | " segment_ids = segment_ids.long().to(device)\n", 298 | " valid_length= valid_length\n", 299 | " label = label.long().to(device)\n", 300 | " out = model(token_ids, valid_length, segment_ids)\n", 301 | " test_acc += calc_accuracy(out, label)\n", 302 | " print(\"epoch {} test acc {}\".format(e+1, test_acc / (batch_id+1)))" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "kernelspec": { 308 | "display_name": "Python 3", 309 | "language": "python", 310 | "name": "python3" 311 | }, 312 | "language_info": { 313 | "codemirror_mode": { 314 | "name": "ipython", 315 | "version": 3 316 | }, 317 | "file_extension": ".py", 318 | "mimetype": "text/x-python", 319 | "name": "python", 320 | "nbconvert_exporter": "python", 321 | "pygments_lexer": "ipython3", 322 | "version": "3.7.0" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 4 327 | } 328 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | def install_requires(): 5 | with open("requirements.txt") as f: 6 | lines = f.read().splitlines() 7 | install_requires = [line for line in lines] 8 | return install_requires 9 | 10 | setup( 11 | name="kobert", 12 | version="0.2.3", 13 | url="https://github.com/SKTBrain/KoBERT", 14 | license="Apache-2.0", 15 | author="Heewon Jeon", 16 | author_email="madjakarta@gmail.com", 17 | description="Korean BERT pre-trained cased (KoBERT) ", 18 | packages=find_packages(), 19 | long_description=open("README.md", encoding="utf-8").read(), 20 | zip_safe=False, 21 | include_package_data=True, 22 | python_requires=">=3.6", 23 | install_requires=install_requires(), 24 | ) 25 | --------------------------------------------------------------------------------
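A minimal inference sketch for the PyTorch notebook above (naver_review_classifications_pytorch_kobert.ipynb): it assumes the objects that notebook defines — `model`, `tok`, `max_len`, and `device` — already exist and that fine-tuning has finished. The helper name `predict_sentiment` is illustrative only and is not part of the kobert package.

    import torch
    import gluonnlp as nlp

    def predict_sentiment(sentence):
        # Reuse the same preprocessing the training pipeline used.
        transform = nlp.data.BERTSentenceTransform(
            tok, max_seq_length=max_len, pad=True, pair=False)
        token_ids, valid_length, segment_ids = transform([sentence])

        # Add a batch dimension and move inputs to the training device;
        # valid_length stays on CPU, matching the notebook's training loop.
        token_ids = torch.tensor([token_ids]).long().to(device)
        segment_ids = torch.tensor([segment_ids]).long().to(device)
        valid_length = torch.tensor([valid_length])

        model.eval()
        with torch.no_grad():
            out = model(token_ids, valid_length, segment_ids)
        # NSMC labels: 1 = positive, 0 = negative.
        return int(out.argmax(dim=-1).item())

    # Example (hypothetical input): predict_sentiment("정말 재미있게 봤습니다")  ->  1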