├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── README_back.md ├── bert_pytorch ├── __init__.py ├── __main__.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── vocab.py ├── model │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── multi_head.py │ │ └── single.py │ ├── bert.py │ ├── embedding │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── position.py │ │ ├── segment.py │ │ └── token.py │ ├── language_model.py │ ├── transformer.py │ └── utils │ │ ├── __init__.py │ │ ├── feed_forward.py │ │ ├── gelu.py │ │ ├── layer_norm.py │ │ └── sublayer.py └── trainer │ ├── __init__.py │ ├── optim_schedule.py │ └── pretrain.py ├── data └── corpus.small ├── img ├── 1.png └── 2.png ├── requirements.txt ├── setup.py ├── test.py ├── test_bert.py └── test_bert_vocab.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/python:3.6.1 6 | 7 | working_directory: ~/repo 8 | 9 | steps: 10 | - checkout 11 | 12 | - restore_cache: 13 | keys: 14 | - v1-dependencies-{{ checksum "requirements.txt" }} 15 | - v1-dependencies- 16 | 17 | - run: 18 | name: install dependencies 19 | command: | 20 | python3 -m venv venv 21 | . venv/bin/activate 22 | pip install -r requirements.txt 23 | 24 | - save_cache: 25 | paths: 26 | - ./venv 27 | key: v1-dependencies-{{ checksum "requirements.txt" }} 28 | 29 | - run: 30 | name: run tests 31 | command: | 32 | . venv/bin/activate 33 | python -m unittest test.py 34 | 35 | - store_artifacts: 36 | path: test-reports 37 | destination: test-reports 38 | 39 | deploy: 40 | docker: 41 | - image: circleci/python:3.6.1 42 | 43 | working_directory: ~/repo 44 | 45 | steps: 46 | - checkout 47 | 48 | - restore_cache: 49 | key: v1-dependency-cache-{{ checksum "setup.py" }}-{{ checksum "Makefile" }} 50 | 51 | - run: 52 | name: verify git tag vs. version 53 | command: | 54 | python3 -m venv venv 55 | . venv/bin/activate 56 | python setup.py verify 57 | pip install twine 58 | 59 | - save_cache: 60 | key: v1-dependency-cache-{{ checksum "setup.py" }}-{{ checksum "Makefile" }} 61 | paths: 62 | - "venv" 63 | 64 | # Deploying to PyPI 65 | # for pip install kor2vec 66 | - run: 67 | name: init .pypirc 68 | command: | 69 | echo -e "[pypi]" >> ~/.pypirc 70 | echo -e "username = codertimo" >> ~/.pypirc 71 | echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc 72 | 73 | - run: 74 | name: create packages 75 | command: | 76 | make package 77 | 78 | - run: 79 | name: upload to pypi 80 | command: | 81 | . 
venv/bin/activate 82 | twine upload dist/* 83 | workflows: 84 | version: 2 85 | build_and_deploy: 86 | jobs: 87 | - build: 88 | filters: 89 | tags: 90 | only: /.*/ 91 | - deploy: 92 | requires: 93 | - build 94 | filters: 95 | tags: 96 | only: /.*/ 97 | branches: 98 | ignore: /.*/ 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Created by .ignore support plugin (hsz.mobi) 4 | ### Python template 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | ### JetBrains template 110 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 111 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 112 | 113 | # User-specific stuff 114 | .idea/**/workspace.xml 115 | .idea/**/tasks.xml 116 | .idea/**/usage.statistics.xml 117 | .idea/**/dictionaries 118 | .idea/**/shelf 119 | 120 | # Sensitive or high-churn files 121 | .idea/**/dataSources/ 122 | .idea/**/dataSources.ids 123 | .idea/**/dataSources.local.xml 124 | .idea/**/sqlDataSources.xml 125 | .idea/**/dynamic.xml 126 | .idea/**/uiDesigner.xml 127 | .idea/**/dbnavigator.xml 128 | 129 | # Gradle 130 | .idea/**/gradle.xml 131 | .idea/**/libraries 132 | 133 | # Gradle and Maven with auto-import 134 | # When using Gradle or Maven with auto-import, you should exclude module files, 135 | # since they will be recreated, and may cause churn. Uncomment if using 136 | # auto-import. 
137 | # .idea/modules.xml 138 | # .idea/*.iml 139 | # .idea/modules 140 | 141 | # CMake 142 | cmake-build-*/ 143 | 144 | # Mongo Explorer plugin 145 | .idea/**/mongoSettings.xml 146 | 147 | # File-based project format 148 | *.iws 149 | 150 | # IntelliJ 151 | out/ 152 | 153 | # mpeltonen/sbt-idea plugin 154 | .idea_modules/ 155 | 156 | # JIRA plugin 157 | atlassian-ide-plugin.xml 158 | 159 | # Cursive Clojure plugin 160 | .idea/replstate.xml 161 | 162 | # Crashlytics plugin (for Android Studio and IntelliJ) 163 | com_crashlytics_export_strings.xml 164 | crashlytics.properties 165 | crashlytics-build.properties 166 | fabric.properties 167 | 168 | # Editor-based Rest Client 169 | .idea/httpRequests 170 | 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 Junseong Kim, Scatter Lab, BERT contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | package: 2 | python setup.py sdist 3 | python setup.py bdist_wheel 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT-Pytorch 源码阅读 2 | 3 | --- 4 | 5 | [TOC] 6 | 7 | ## 1. 
Overview 8 | 9 | When packaged for distribution, BERT-Pytorch exposes two main commands: 10 | 11 | - bert-vocab: collects word frequencies and the token2idx / idx2token mappings. It corresponds to the `build` function in `bert_pytorch.dataset.vocab`. 12 | - bert: corresponds to the `train` function in `bert_pytorch.__main__`. 13 | 14 | To make debugging possible, I created two separate files, one for each of these two features. 15 | 16 | ## 2. bert-vocab 17 | 18 | ``` 19 | python3 -m ipdb test_bert_vocab.py # debug bert-vocab 20 | ``` 21 | 22 | There is nothing particularly deep inside bert-vocab: it is just the usual preprocessing found in NLP pipelines. Ten minutes of stepping through it is enough to understand it; I added a few comments, so it should be easy to follow. 23 | 24 | The internal inheritance chain is: 25 | 26 | ``` 27 | TorchVocab --> Vocab --> WordVocab 28 | ``` 29 | 30 | ## 3. Model architecture 31 | 32 | - Debug command: 33 | 34 | ``` 35 | python3 -m ipdb test_bert.py -c data/corpus.small -v data/vocab.small -o output/bert.model 36 | ``` 37 | 38 | ![](http://ww1.sinaimg.cn/large/006gOeiSly1g5qw6nkhhgj31400u0myh.jpg) 39 | 40 | Viewed as a whole, the model splits into two parts, **MaskedLanguageModel** and **NextSentencePrediction**. Both sit on top of a shared **BERT** backbone, and each adds a fully connected layer followed by a softmax layer to produce its output. 41 | 42 | This part of the code is fairly simple and easy to understand, so I skip it here. 43 | 44 | ### 1. Bert Model 45 | 46 | ![](http://ww1.sinaimg.cn/large/006gOeiSly1g5qw6wqgjoj31400u0dhz.jpg) 47 | 48 | This part is essentially the Transformer encoder plus the BERT embedding. If you are not yet familiar with the Transformer, this is a good place to deepen your understanding. 49 | 50 | For this part of the source I suggest first skimming the whole thing to get a rough picture and to see how the classes depend on each other, then working from the details back up to the whole; in terms of the figure above, reading from right to left works best. 51 | 52 | #### 1. BERTEmbedding 53 | 54 | It consists of three parts: 55 | 56 | - TokenEmbedding: encodes the tokens; inherits from `nn.Embedding` and is initialized to `N(0,1)` by default 57 | - SegmentEmbedding: encodes sentence membership; inherits from `nn.Embedding` and is initialized to `N(0,1)` by default 58 | - PositionalEmbedding: encodes position information; see the paper — it produces a fixed vector representation that is not trained 59 | 60 | The one worth paying attention to is PositionalEmbedding: some interviewers dig into exactly these details. I used to skim over things I thought would not help me and never pinned the details down, and that turned out to cost me. 61 | 62 | #### 2. Transformer 63 | 64 | I strongly recommend reading this part alongside the paper; if you already know it well, feel free to skip it. I added comments at the key places; if anything is still unclear you can open an issue, so I will not repeat it here. 65 | 66 | ## Closing remarks 67 | 68 | Personally I think this code is really elegantly written and the structure is very clear; working through all of it takes no more than a few hours. I recommend stepping through it end to end with the debugging approach above — everything becomes much clearer that way. 69 | 70 | If you found this useful, a star would be appreciated. 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /README_back.md: -------------------------------------------------------------------------------- 1 | # BERT-pytorch 2 | 3 | [![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/blob/master/LICENSE) 4 | ![GitHub issues](https://img.shields.io/github/issues/codertimo/BERT-pytorch.svg) 5 | [![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/stargazers) 6 | [![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/codertimo/BERT-pytorch) 7 | [![PyPI](https://img.shields.io/pypi/v/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/) 8 | [![PyPI - Status](https://img.shields.io/pypi/status/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/) 9 | [![Documentation Status](https://readthedocs.org/projects/bert-pytorch/badge/?version=latest)](https://bert-pytorch.readthedocs.io/en/latest/?badge=latest) 10 | 11 | Pytorch implementation of Google AI's 2018 BERT, with simple annotation 12 | 13 | > BERT 2018 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 14 | > Paper URL : https://arxiv.org/abs/1810.04805 15 | 16 | 17 | ## Introduction 18 | 19 | Google AI's BERT paper shows amazing results on various NLP tasks (new state of the art on 17 NLP tasks), 20 | including outperforming the human F1 score on the SQuAD v1.1 QA task.
21 | This paper proved that a Transformer (self-attention) based encoder can be used as a powerful 22 | alternative to previous language models, given a proper language-model training method. 23 | More importantly, it showed that this pre-trained language model can be transferred 24 | to any NLP task without building a task-specific model architecture. 25 | 26 | This amazing result will be remembered in NLP history, 27 | and I expect many further papers about BERT to be published very soon. 28 | 29 | This repo is an implementation of BERT. The code is very simple and easy to understand quickly. 30 | Some of the code is based on [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 31 | 32 | Currently this project is a work in progress, and the code is not verified yet. 33 | 34 | ## Installation 35 | ``` 36 | pip install bert-pytorch 37 | ``` 38 | 39 | ## Quickstart 40 | 41 | **NOTICE : Your corpus should be prepared as two sentences per line, separated by a tab (\t)** 42 | 43 | ### 0. Prepare your corpus 44 | ``` 45 | Welcome to the \t the jungle\n 46 | I can stay \t here all night\n 47 | ``` 48 | 49 | or a tokenized corpus (tokenization is not included in the package) 50 | ``` 51 | Wel_ _come _to _the \t _the _jungle\n 52 | _I _can _stay \t _here _all _night\n 53 | ``` 54 | 55 | 56 | ### 1. Building vocab based on your corpus 57 | ```shell 58 | bert-vocab -c data/corpus.small -o data/vocab.small 59 | ``` 60 | 61 | ### 2. Train your own BERT model 62 | ```shell 63 | bert -c data/corpus.small -v data/vocab.small -o output/bert.model 64 | ``` 65 | 66 | ## Language Model Pre-training 67 | 68 | In the paper, the authors present two new language-model training methods: 69 | "masked language model" and "predict next sentence". 70 | 71 | 72 | ### Masked Language Model 73 | 74 | > Original Paper : 3.3.1 Task #1: Masked LM 75 | 76 | ``` 77 | Input Sequence : The man went to [MASK] store with [MASK] dog 78 | Target Sequence : the his 79 | ``` 80 | 81 | #### Rules: 82 | 15% of the input tokens are randomly selected and altered, according to the following sub-rules: 83 | 84 | 1. 80% of the selected tokens are replaced with the `[MASK]` token. 85 | 2. 10% are replaced with a random token (another word). 86 | 3. 10% are left unchanged, but still have to be predicted. 87 | 88 | ### Predict Next Sentence 89 | 90 | > Original Paper : 3.3.2 Task #2: Next Sentence Prediction 91 | 92 | ``` 93 | Input : [CLS] the man went to the store [SEP] he bought a gallon of milk [SEP] 94 | Label : Is Next 95 | 96 | Input = [CLS] the man heading to the store [SEP] penguin [MASK] are flight ##less birds [SEP] 97 | Label = NotNext 98 | ``` 99 | 100 | Can the second sentence be read as a natural continuation of the first? 101 | 102 | The task is about understanding the relationship between two text sentences, which is 103 | not directly captured by language modeling. 104 | 105 | #### Rules: 106 | 107 | 1. 50% of the time, the second sentence is the actual next sentence. 108 | 2. 50% of the time, it is a random, unrelated sentence (see the sketch below).
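A minimal, self-contained sketch of how these two rules combine into a single pre-training instance is shown below. The helper names (`mask_tokens`, `next_sentence_pair`) and the toy vocabulary are made up for illustration only; the package's actual logic lives in `BERTDataset.random_word` / `random_sent` in `bert_pytorch/dataset/dataset.py` (shown later in this repo), which works on token ids and additionally pads and truncates to `seq_len`.

```python
import random

MASK, PAD = "[MASK]", "[PAD]"


def mask_tokens(tokens, vocab, mask_prob=0.15):
    """15% of tokens are selected; of those, 80% -> [MASK], 10% -> random word, 10% -> unchanged."""
    inputs, labels = [], []
    for tok in tokens:
        if random.random() < mask_prob:
            r = random.random()
            if r < 0.8:
                inputs.append(MASK)                   # 80%: replace with [MASK]
            elif r < 0.9:
                inputs.append(random.choice(vocab))   # 10%: replace with a random word
            else:
                inputs.append(tok)                    # 10%: keep the original token
            labels.append(tok)                        # the model must predict the original token
        else:
            inputs.append(tok)
            labels.append(PAD)                        # not selected: ignored by the MLM loss
    return inputs, labels


def next_sentence_pair(sent_a, sent_b, random_sent):
    """50%: keep the true next sentence (label 1); 50%: swap in an unrelated one (label 0)."""
    if random.random() > 0.5:
        return sent_a, sent_b, 1                      # IsNext
    return sent_a, random_sent, 0                     # NotNext


if __name__ == "__main__":
    vocab = "the man went to store bought a gallon of milk".split()
    t1, t2, is_next = next_sentence_pair(
        "the man went to the store".split(),
        "he bought a gallon of milk".split(),
        "penguins are flightless birds".split())
    t1_in, t1_lab = mask_tokens(t1, vocab)
    t2_in, t2_lab = mask_tokens(t2, vocab)
    print(["[CLS]"] + t1_in + ["[SEP]"] + t2_in + ["[SEP]"], t1_lab + t2_lab, is_next)
```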
109 | 110 | 111 | ## Author 112 | Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatterlab.co.kr) 113 | 114 | ## License 115 | 116 | This project follows the Apache 2.0 License, as written in the LICENSE file. 117 | 118 | Copyright 2018 Junseong Kim, Scatter Lab, respective BERT contributors 119 | 120 | Copyright (c) 2018 Alexander Rush : [The Annotated Transformer](https://github.com/harvardnlp/annotated-transformer) 121 | -------------------------------------------------------------------------------- /bert_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import BERT 2 | -------------------------------------------------------------------------------- /bert_pytorch/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from .model import BERT 6 | from .trainer import BERTTrainer 7 | from .dataset import BERTDataset, WordVocab 8 | 9 | 10 | def train(): 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument("-c", "--train_dataset", required=True, type=str, help="training dataset for training BERT") 14 | parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluating the trained model") 15 | parser.add_argument("-v", "--vocab_path", required=True, type=str, help="vocab model path built with bert-vocab") 16 | parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model") 17 | 18 | parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model") 19 | parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers") 20 | parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads") 21 | parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len") 22 | 23 | parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size") 24 | parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs") 25 | parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size") 26 | 27 | parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false") 28 | parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") 29 | parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus") 30 | parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids") 31 | parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false") 32 | 33 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam") 34 | parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam") 35 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") 36 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value") 37 | 38 | args = parser.parse_args() 39 | 40 | # Load the vocabulary 41 | print("Loading Vocab", args.vocab_path) 42 | vocab = WordVocab.load_vocab(args.vocab_path) 43 | print("Vocab Size: ", len(vocab)) 44 | 45 | # Prepare the datasets 46 | print("Loading Train Dataset", args.train_dataset) 47 | train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, 48 | corpus_lines=args.corpus_lines,
on_memory=args.on_memory) 49 | 50 | print("Loading Test Dataset", args.test_dataset) 51 | test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \ 52 | if args.test_dataset is not None else None 53 | 54 | print("Creating Dataloader") 55 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 56 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \ 57 | if test_dataset is not None else None 58 | 59 | # Build the model 60 | print("Building BERT model") 61 | bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) 62 | 63 | # Set up the training procedure 64 | print("Creating BERT Trainer") 65 | trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader, 66 | lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, 67 | with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq) 68 | 69 | # Start training 70 | print("Training Start") 71 | for epoch in range(args.epochs): 72 | trainer.train(epoch) 73 | trainer.save(epoch, args.output_path) 74 | 75 | if test_data_loader is not None: 76 | trainer.test(epoch) 77 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import BERTDataset 2 | from .vocab import WordVocab 3 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import tqdm 3 | import torch 4 | import random 5 | 6 | 7 | class BERTDataset(Dataset): 8 | def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): 9 | self.vocab = vocab 10 | self.seq_len = seq_len 11 | 12 | self.on_memory = on_memory 13 | self.corpus_lines = corpus_lines 14 | self.corpus_path = corpus_path 15 | self.encoding = encoding 16 | 17 | with open(corpus_path, "r", encoding=encoding) as f: 18 | if self.corpus_lines is None and not on_memory: 19 | for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines): 20 | self.corpus_lines = (self.corpus_lines or 0) + 1 21 | 22 | if on_memory: 23 | self.lines = [line[:-1].split("\\t") 24 | for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] 25 | self.corpus_lines = len(self.lines) 26 | 27 | if not on_memory: 28 | self.file = open(corpus_path, "r", encoding=encoding) 29 | self.random_file = open(corpus_path, "r", encoding=encoding) 30 | 31 | for _ in range(random.randrange(self.corpus_lines if self.corpus_lines < 1000 else 1000)): 32 | self.random_file.__next__() 33 | 34 | def __len__(self): 35 | return self.corpus_lines 36 | 37 | def __getitem__(self, item): 38 | t1, t2, is_next_label = self.random_sent(item) 39 | t1_random, t1_label = self.random_word(t1) 40 | t2_random, t2_label = self.random_word(t2) 41 | 42 | # [CLS] tag = SOS tag, [SEP] tag = EOS tag 43 | t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index] 44 | t2 = t2_random + [self.vocab.eos_index] 45 | 46 | t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index] 47 | t2_label = t2_label + [self.vocab.pad_index] 48 | 49 | segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] 50 | bert_input = (t1 + t2)[:self.seq_len] 51 | bert_label
= (t1_label + t2_label)[:self.seq_len] 52 | 53 | padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))] 54 | bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding) 55 | 56 | output = {"bert_input": bert_input, 57 | "bert_label": bert_label, 58 | "segment_label": segment_label, 59 | "is_next": is_next_label} 60 | 61 | return {key: torch.tensor(value) for key, value in output.items()} 62 | 63 | def random_word(self, sentence): 64 | tokens = sentence.split() 65 | output_label = [] 66 | 67 | for i, token in enumerate(tokens): 68 | prob = random.random() 69 | if prob < 0.15: 70 | prob /= 0.15 71 | 72 | # 80% randomly change token to mask token 73 | if prob < 0.8: 74 | tokens[i] = self.vocab.mask_index 75 | 76 | # 10% randomly change token to random token 77 | elif prob < 0.9: 78 | tokens[i] = random.randrange(len(self.vocab)) 79 | 80 | # 10% randomly change token to current token 81 | else: 82 | tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) 83 | 84 | output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) 85 | 86 | else: 87 | tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) 88 | output_label.append(0) 89 | 90 | return tokens, output_label 91 | 92 | def random_sent(self, index): 93 | t1, t2 = self.get_corpus_line(index) 94 | 95 | # output_text, label(isNotNext:0, isNext:1) 96 | if random.random() > 0.5: 97 | return t1, t2, 1 98 | else: 99 | return t1, self.get_random_line(), 0 100 | 101 | def get_corpus_line(self, item): 102 | if self.on_memory: 103 | return self.lines[item][0], self.lines[item][1] 104 | else: 105 | line = self.file.__next__() 106 | if line is None: 107 | self.file.close() 108 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 109 | line = self.file.__next__() 110 | 111 | t1, t2 = line[:-1].split("\t") 112 | return t1, t2 113 | 114 | def get_random_line(self): 115 | if self.on_memory: 116 | return self.lines[random.randrange(len(self.lines))][1] 117 | 118 | line = self.file.__next__() 119 | if line is None: 120 | self.file.close() 121 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 122 | for _ in range(random.randrange(self.corpus_lines if self.corpus_lines < 1000 else 1000)): 123 | self.random_file.__next__() 124 | line = self.random_file.__next__() 125 | return line[:-1].split("\t")[1] 126 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tqdm 3 | from collections import Counter 4 | 5 | 6 | class TorchVocab(object): 7 | """Defines a vocabulary object that will be used to numericalize a field. 8 | Attributes: 9 | freqs: A collections.Counter object holding the frequencies of tokens 10 | in the data used to build the Vocab. 11 | stoi: A collections.defaultdict instance mapping token strings to 12 | numerical identifiers. 13 | itos: A list of token strings indexed by their numerical identifiers. 14 | """ 15 | 16 | def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<unk>'], 17 | vectors=None, unk_init=None, vectors_cache=None): 18 | """Build a Vocab from a collections.Counter object. 19 | Args: 20 | counter: a collections.Counter with the token counts from the pre-training file, e.g. {'token': 10} 21 | max_size: maximum vocabulary size. None for no maximum. Default: None. 22 | min_freq: minimum frequency for a token to be kept. Default: 1.
23 | specials: a list of special tokens, e.g. ['<pad>'] 24 | vectors: One of either the available pretrained vectors 25 | or custom pretrained vectors (see Vocab.load_vectors); 26 | or a list of aforementioned vectors 27 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 28 | to zero vectors; can be any function that takes in a Tensor and 29 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 30 | vectors_cache: directory for cached vectors. Default: '.vector_cache' 31 | """ 32 | self.freqs = counter 33 | counter = counter.copy() 34 | min_freq = max(min_freq, 1) 35 | 36 | self.itos = list(specials) 37 | 38 | # special tokens are not counted in the frequency statistics 39 | for tok in specials: 40 | del counter[tok] 41 | 42 | max_size = None if max_size is None else max_size + len(self.itos) 43 | 44 | # sort alphabetically first, then by frequency (descending) 45 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 46 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 47 | 48 | # filter by frequency and by maximum vocabulary size 49 | for word, freq in words_and_frequencies: 50 | if freq < min_freq or len(self.itos) == max_size: 51 | break 52 | self.itos.append(word) 53 | 54 | # token2idx 55 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 56 | 57 | self.vectors = None 58 | if vectors is not None: 59 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 60 | else: 61 | assert unk_init is None and vectors_cache is None 62 | 63 | def __eq__(self, other): 64 | if self.freqs != other.freqs: 65 | return False 66 | if self.stoi != other.stoi: 67 | return False 68 | if self.itos != other.itos: 69 | return False 70 | if self.vectors != other.vectors: 71 | return False 72 | return True 73 | 74 | def __len__(self): 75 | return len(self.itos) 76 | 77 | def vocab_rerank(self): 78 | self.stoi = {word: i for i, word in enumerate(self.itos)} 79 | 80 | def extend(self, v, sort=False): 81 | words = sorted(v.itos) if sort else v.itos 82 | for w in words: 83 | if w not in self.stoi: 84 | self.itos.append(w) 85 | self.stoi[w] = len(self.itos) - 1 86 | 87 | 88 | class Vocab(TorchVocab): 89 | def __init__(self, counter, max_size=None, min_freq=1): 90 | self.pad_index = 0 91 | self.unk_index = 1 92 | self.eos_index = 2 93 | self.sos_index = 3 94 | self.mask_index = 4 95 | super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"], 96 | max_size=max_size, min_freq=min_freq) 97 | 98 | def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list: 99 | pass 100 | 101 | def from_seq(self, seq, join=False, with_pad=False): 102 | pass 103 | 104 | @staticmethod 105 | def load_vocab(vocab_path: str) -> 'Vocab': 106 | with open(vocab_path, "rb") as f: 107 | return pickle.load(f) 108 | 109 | def save_vocab(self, vocab_path): 110 | with open(vocab_path, "wb") as f: 111 | pickle.dump(self, f) 112 | 113 | 114 | # Building Vocab with text files 115 | class WordVocab(Vocab): 116 | def __init__(self, texts, max_size=None, min_freq=1): 117 | print("Building Vocab") 118 | counter = Counter() 119 | for line in tqdm.tqdm(texts): 120 | if isinstance(line, list): 121 | words = line 122 | else: 123 | words = line.replace("\n", "").replace("\t", "").split() 124 | 125 | for word in words: 126 | counter[word] += 1 127 | super().__init__(counter, max_size=max_size, min_freq=min_freq) 128 | 129 | def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False): 130 | if isinstance(sentence, str): 131 | sentence = sentence.split() 132 | 133 | seq = [self.stoi.get(word, self.unk_index) for word in sentence] 134 | 135 | if with_eos: 136 | seq +=
[self.eos_index] # this would be index 1 137 | if with_sos: 138 | seq = [self.sos_index] + seq 139 | 140 | origin_seq_len = len(seq) 141 | 142 | if seq_len is None: 143 | pass 144 | elif len(seq) <= seq_len: 145 | seq += [self.pad_index for _ in range(seq_len - len(seq))] 146 | else: 147 | seq = seq[:seq_len] 148 | 149 | return (seq, origin_seq_len) if with_len else seq 150 | 151 | def from_seq(self, seq, join=False, with_pad=False): 152 | words = [self.itos[idx] 153 | if idx < len(self.itos) 154 | else "<%d>" % idx 155 | for idx in seq 156 | if not with_pad or idx != self.pad_index] 157 | 158 | return " ".join(words) if join else words 159 | 160 | @staticmethod 161 | def load_vocab(vocab_path: str) -> 'WordVocab': 162 | """将 WordVocab 对象序列化到 vocab_path 文件中 """ 163 | with open(vocab_path, "rb") as f: 164 | return pickle.load(f) 165 | 166 | 167 | def build(): 168 | import argparse 169 | 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument("-c", "--corpus_path", required=True, type=str) 172 | parser.add_argument("-o", "--output_path", required=True, type=str) 173 | parser.add_argument("-s", "--vocab_size", type=int, default=None) 174 | parser.add_argument("-e", "--encoding", type=str, default="utf-8") 175 | parser.add_argument("-m", "--min_freq", type=int, default=1) 176 | args = parser.parse_args() 177 | 178 | with open(args.corpus_path, "r", encoding=args.encoding) as f: 179 | vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) 180 | 181 | print("VOCAB SIZE:", len(vocab)) 182 | vocab.save_vocab(args.output_path) 183 | -------------------------------------------------------------------------------- /bert_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import BERT 2 | from .language_model import BERTLM 3 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_head import MultiHeadedAttention 2 | from .single import Attention 3 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/multi_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .single import Attention 3 | 4 | 5 | class MultiHeadedAttention(nn.Module): 6 | """ 7 | Take in model size and number of heads. 8 | """ 9 | 10 | def __init__(self, h, d_model, dropout=0.1): 11 | super().__init__() 12 | assert d_model % h == 0 13 | 14 | # We assume d_v always equals d_k 15 | self.d_k = d_model // h 16 | self.h = h 17 | 18 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 19 | self.output_linear = nn.Linear(d_model, d_model) 20 | self.attention = Attention() 21 | 22 | self.dropout = nn.Dropout(p=dropout) 23 | 24 | def forward(self, query, key, value, mask=None): 25 | batch_size = query.size(0) 26 | 27 | # 1) Do all the linear projections in batch from d_model => h x d_k 28 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 29 | for l, x in zip(self.linear_layers, (query, key, value))] 30 | 31 | # 2) Apply attention on all the projected vectors in batch. 32 | x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 33 | 34 | # 3) "Concat" using a view and apply a final linear. 
35 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 36 | 37 | return self.output_linear(x) 38 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/single.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | import math 6 | 7 | 8 | class Attention(nn.Module): 9 | """ 10 | Compute 'Scaled Dot Product Attention 11 | """ 12 | 13 | def forward(self, query, key, value, mask=None, dropout=None): 14 | """ 15 | Args: query, key, value 同源且 shape 相同 16 | query: [batch_size, head_num, seq_len, dim] 17 | key: [batch_size, head_num, seq_len, dim] 18 | value: [batch_size, head_num, seq_len, dim] 19 | """ 20 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 21 | / math.sqrt(query.size(-1)) 22 | 23 | if mask is not None: 24 | scores = scores.masked_fill(mask == 0, -1e9) 25 | 26 | p_attn = F.softmax(scores, dim=-1) 27 | 28 | if dropout is not None: 29 | p_attn = dropout(p_attn) 30 | 31 | return torch.matmul(p_attn, value), p_attn 32 | -------------------------------------------------------------------------------- /bert_pytorch/model/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .transformer import TransformerBlock 4 | from .embedding import BERTEmbedding 5 | 6 | 7 | class BERT(nn.Module): 8 | """ 9 | BERT model : Bidirectional Encoder Representations from Transformers. 10 | """ 11 | 12 | def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1): 13 | """ Bert 模型 14 | Args: 15 | vocab_size: 词表大小 16 | hidden: BERT 的 hidden size 17 | n_layers: Transformer 的层数 18 | attn_heads: Multi-head Attention 中的 head 数 19 | dropout: dropout rate 20 | """ 21 | 22 | super().__init__() 23 | self.hidden = hidden 24 | self.n_layers = n_layers 25 | self.attn_heads = attn_heads 26 | 27 | # paper noted they used 4*hidden_size for ff_network_hidden_size 28 | self.feed_forward_hidden = hidden * 4 29 | 30 | # BERT的输入 embedding, 由 positional, segment, token embeddings 三部分组成 31 | self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden) 32 | 33 | # 多层的 Transformer (Encoder), 由多个 TransformerBlock 组成 34 | self.transformer_blocks = nn.ModuleList( 35 | [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)]) 36 | 37 | def forward(self, x, segment_info): 38 | """ 39 | x: [batch_size, seq_len] 40 | segment_info: [batch_size, seq_len] 41 | """ 42 | 43 | # attention masking for padded token, 44 | mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) 45 | # [batch_size, 1, seq_len, seq_len] 46 | 47 | # embedding the indexed sequence to sequence of vectors 48 | x = self.embedding(x, segment_info) 49 | 50 | # running over multiple transformer blocks 51 | for transformer in self.transformer_blocks: 52 | x = transformer.forward(x, mask) 53 | 54 | return x 55 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import BERTEmbedding 2 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .token import TokenEmbedding 3 | from 
.position import PositionalEmbedding 4 | from .segment import SegmentEmbedding 5 | 6 | 7 | class BERTEmbedding(nn.Module): 8 | """ 9 | The BERT embedding is made up of the following three parts: 10 | 1. TokenEmbedding : token embedding matrix 11 | 2. PositionalEmbedding : encodes position information 12 | 3. SegmentEmbedding : encodes sentence membership, (sent_A:1, sent_B:2) 13 | """ 14 | 15 | def __init__(self, vocab_size, embed_size, dropout=0.1): 16 | """ 17 | Args: 18 | vocab_size: vocabulary size 19 | embed_size: dimension of the token embeddings 20 | dropout: dropout rate 21 | """ 22 | super().__init__() 23 | self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) 24 | self.position = PositionalEmbedding(d_model=self.token.embedding_dim) 25 | self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim) 26 | self.dropout = nn.Dropout(p=dropout) 27 | self.embed_size = embed_size 28 | 29 | def forward(self, sequence, segment_label): 30 | x = self.token(sequence) + self.position(sequence) + self.segment(segment_label) 31 | return self.dropout(x) 32 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/position.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | 8 | def __init__(self, d_model, max_len=512): 9 | super().__init__() 10 | 11 | # Compute the positional encodings once in log space. 12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.requires_grad = False 14 | 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 17 | 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | 21 | pe = pe.unsqueeze(0) 22 | self.register_buffer('pe', pe) 23 | 24 | def forward(self, x): 25 | return self.pe[:, :x.size(1)] 26 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/segment.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SegmentEmbedding(nn.Embedding): 5 | def __init__(self, embed_size=512): 6 | """ 3 embeddings: 0 = padding_idx, 1 = sent_A, 2 = sent_B """ 7 | super().__init__(3, embed_size, padding_idx=0) 8 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/token.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TokenEmbedding(nn.Embedding): 5 | def __init__(self, vocab_size, embed_size=512): 6 | super().__init__(vocab_size, embed_size, padding_idx=0) 7 | -------------------------------------------------------------------------------- /bert_pytorch/model/language_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .bert import BERT 4 | 5 | 6 | class BERTLM(nn.Module): 7 | """ 8 | BERT Language Model 9 | Next Sentence Prediction Model + Masked Language Model 10 | """ 11 | 12 | def __init__(self, bert: BERT, vocab_size): 13 | """ 14 | :param bert: BERT model which should be trained 15 | :param vocab_size: total vocab size for masked_lm 16 | """ 17 | 18 | super().__init__() 19 | self.bert = bert 20 | self.next_sentence = NextSentencePrediction(self.bert.hidden) 21 | self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size) 22 | 23 | def forward(self, x, segment_label): 24
| """ 25 | Args: 26 | x: [batch_size, seq_len] 27 | segment_label: [batch_size, seq_len], 句子标识,是句子1 还是句子2 28 | """ 29 | x = self.bert(x, segment_label) 30 | # x: [batch_size, seq_len, hidden] 31 | return self.next_sentence(x), self.mask_lm(x) 32 | 33 | 34 | class NextSentencePrediction(nn.Module): 35 | """ 36 | 2-class classification model : is_next, is_not_next 37 | """ 38 | 39 | def __init__(self, hidden): 40 | """ 41 | :param hidden: BERT model output size 42 | """ 43 | super().__init__() 44 | self.linear = nn.Linear(hidden, 2) 45 | self.softmax = nn.LogSoftmax(dim=-1) 46 | 47 | def forward(self, x): 48 | """ 49 | Args: 50 | x: [batch_size, seq_len, hidden] 51 | """ 52 | return self.softmax(self.linear(x[:, 0])) 53 | 54 | 55 | class MaskedLanguageModel(nn.Module): 56 | """ 57 | predicting origin token from masked input sequence 58 | n-class classification problem, n-class = vocab_size 59 | """ 60 | 61 | def __init__(self, hidden, vocab_size): 62 | """ 63 | :param hidden: output size of BERT model 64 | :param vocab_size: total vocab size 65 | """ 66 | super().__init__() 67 | self.linear = nn.Linear(hidden, vocab_size) 68 | self.softmax = nn.LogSoftmax(dim=-1) 69 | 70 | def forward(self, x): 71 | """ 72 | Args: 73 | x: [batch_size, seq_len, hidden] 74 | """ 75 | return self.softmax(self.linear(x)) 76 | -------------------------------------------------------------------------------- /bert_pytorch/model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .attention import MultiHeadedAttention 4 | from .utils import SublayerConnection, PositionwiseFeedForward 5 | 6 | 7 | class TransformerBlock(nn.Module): 8 | """ 9 | Bidirectional Encoder = Transformer (self-attention) 10 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 11 | """ 12 | 13 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): 14 | """ 15 | :param hidden: hidden size of transformer 16 | :param attn_heads: head sizes of multi-head attention 17 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 18 | :param dropout: dropout rate 19 | """ 20 | 21 | super().__init__() 22 | self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 23 | self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 24 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 25 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | def forward(self, x, mask): 29 | x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 30 | x = self.output_sublayer(x, self.feed_forward) 31 | return self.dropout(x) 32 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .feed_forward import PositionwiseFeedForward 2 | from .layer_norm import LayerNorm 3 | from .sublayer import SublayerConnection 4 | from .gelu import GELU 5 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .gelu import GELU 3 | 4 | 5 | class PositionwiseFeedForward(nn.Module): 6 | "Implements FFN equation." 
7 | 8 | def __init__(self, d_model, d_ff, dropout=0.1): 9 | super(PositionwiseFeedForward, self).__init__() 10 | self.w_1 = nn.Linear(d_model, d_ff) 11 | self.w_2 = nn.Linear(d_ff, d_model) 12 | self.dropout = nn.Dropout(dropout) 13 | self.activation = GELU() 14 | 15 | def forward(self, x): 16 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 17 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/gelu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class GELU(nn.Module): 7 | """ 8 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 9 | """ 10 | 11 | def forward(self, x): 12 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | "Construct a layernorm module (See citation for details)." 7 | 8 | def __init__(self, features, eps=1e-6): 9 | super(LayerNorm, self).__init__() 10 | self.a_2 = nn.Parameter(torch.ones(features)) 11 | self.b_2 = nn.Parameter(torch.zeros(features)) 12 | self.eps = eps 13 | 14 | def forward(self, x): 15 | mean = x.mean(-1, keepdim=True) 16 | std = x.std(-1, keepdim=True) 17 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 18 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/sublayer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .layer_norm import LayerNorm 3 | 4 | 5 | class SublayerConnection(nn.Module): 6 | """ 7 | A residual connection followed by a layer norm. 8 | Note for code simplicity the norm is first as opposed to last. 9 | """ 10 | 11 | def __init__(self, size, dropout): 12 | super(SublayerConnection, self).__init__() 13 | self.norm = LayerNorm(size) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x, sublayer): 17 | "Apply residual connection to any sublayer with the same size." 
18 | return x + self.dropout(sublayer(self.norm(x))) 19 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pretrain import BERTTrainer 2 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/optim_schedule.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim(): 6 | '''A simple wrapper class for learning rate scheduling''' 7 | 8 | def __init__(self, optimizer, d_model, n_warmup_steps): 9 | self._optimizer = optimizer 10 | self.n_warmup_steps = n_warmup_steps 11 | self.n_current_steps = 0 12 | self.init_lr = np.power(d_model, -0.5) 13 | 14 | def step_and_update_lr(self): 15 | "Step with the inner optimizer" 16 | self._update_learning_rate() 17 | self._optimizer.step() 18 | 19 | def zero_grad(self): 20 | "Zero out the gradients by the inner optimizer" 21 | self._optimizer.zero_grad() 22 | 23 | def _get_lr_scale(self): 24 | return np.min([ 25 | np.power(self.n_current_steps, -0.5), 26 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 27 | 28 | def _update_learning_rate(self): 29 | ''' Learning rate scheduling per step ''' 30 | 31 | self.n_current_steps += 1 32 | lr = self.init_lr * self._get_lr_scale() 33 | 34 | for param_group in self._optimizer.param_groups: 35 | param_group['lr'] = lr 36 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.utils.data import DataLoader 5 | 6 | from ..model import BERTLM, BERT 7 | from .optim_schedule import ScheduledOptim 8 | 9 | import tqdm 10 | 11 | 12 | class BERTTrainer: 13 | """ 14 | BERTTrainer make the pretrained BERT model with two LM training method. 15 | 16 | 1. Masked Language Model : 3.3.1 Task #1: Masked LM 17 | 2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction 18 | 19 | please check the details on README.md with simple example. 
20 | 21 | """ 22 | 23 | def __init__(self, bert: BERT, vocab_size: int, 24 | train_dataloader: DataLoader, test_dataloader: DataLoader = None, 25 | lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000, 26 | with_cuda: bool = True, cuda_devices=None, log_freq: int = 10): 27 | """ 28 | :param bert: BERT model which you want to train 29 | :param vocab_size: total word vocab size 30 | :param train_dataloader: train dataset data loader 31 | :param test_dataloader: test dataset data loader [can be None] 32 | :param lr: learning rate of optimizer 33 | :param betas: Adam optimizer betas 34 | :param weight_decay: Adam optimizer weight decay param 35 | :param with_cuda: traning with cuda 36 | :param log_freq: logging frequency of the batch iteration 37 | """ 38 | 39 | # Setup cuda device for BERT training, argument -c, --cuda should be true 40 | cuda_condition = torch.cuda.is_available() and with_cuda 41 | self.device = torch.device("cuda:0" if cuda_condition else "cpu") 42 | 43 | # This BERT model will be saved every epoch 44 | self.bert = bert 45 | # Initialize the BERT Language Model, with BERT model 46 | self.model = BERTLM(bert, vocab_size).to(self.device) 47 | 48 | # Distributed GPU training if CUDA can detect more than 1 GPU 49 | # if with_cuda and torch.cuda.device_count() > 1: 50 | # print("Using %d GPUS for BERT" % torch.cuda.device_count()) 51 | # self.model = nn.DataParallel(self.model, device_ids=cuda_devices) 52 | 53 | # Setting the train and test data loader 54 | self.train_data = train_dataloader 55 | self.test_data = test_dataloader 56 | 57 | # Setting the Adam optimizer with hyper-param 58 | self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) 59 | self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) 60 | 61 | # Using Negative Log Likelihood Loss function for predicting the masked_token 62 | self.criterion = nn.NLLLoss(ignore_index=0) 63 | 64 | self.log_freq = log_freq 65 | 66 | print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()])) 67 | 68 | def train(self, epoch): 69 | self.iteration(epoch, self.train_data) 70 | 71 | def test(self, epoch): 72 | self.iteration(epoch, self.test_data, train=False) 73 | 74 | def iteration(self, epoch, data_loader, train=True): 75 | """ 76 | loop over the data_loader for training or testing 77 | if on train status, backward operation is activated 78 | and also auto save the model every peoch 79 | 80 | :param epoch: current epoch index 81 | :param data_loader: torch.utils.data.DataLoader for iteration 82 | :param train: boolean value of is train or test 83 | :return: None 84 | """ 85 | str_code = "train" if train else "test" 86 | 87 | # Setting the tqdm progress bar 88 | data_iter = tqdm.tqdm(enumerate(data_loader), 89 | desc="EP_%s:%d" % (str_code, epoch), 90 | total=len(data_loader), 91 | bar_format="{l_bar}{r_bar}") 92 | 93 | avg_loss = 0.0 94 | total_correct = 0 95 | total_element = 0 96 | 97 | for i, data in data_iter: 98 | # 0. batch_data will be sent into the device(GPU or cpu) 99 | data = {key: value.to(self.device) for key, value in data.items()} 100 | 101 | # 1. forward the next_sentence_prediction and masked_lm model 102 | next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) 103 | 104 | # 2-1. NLL(negative log likelihood) loss of is_next classification result 105 | next_loss = self.criterion(next_sent_output, data["is_next"]) 106 | 107 | # 2-2. 
NLLLoss of predicting masked token word 108 | mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) 109 | 110 | # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure 111 | loss = next_loss + mask_loss 112 | 113 | # 3. backward and optimization only in train 114 | if train: 115 | self.optim_schedule.zero_grad() 116 | loss.backward() 117 | self.optim_schedule.step_and_update_lr() 118 | 119 | # next sentence prediction accuracy 120 | correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item() 121 | avg_loss += loss.item() 122 | total_correct += correct 123 | total_element += data["is_next"].nelement() 124 | 125 | post_fix = { 126 | "epoch": epoch, 127 | "iter": i, 128 | "avg_loss": avg_loss / (i + 1), 129 | "avg_acc": total_correct / total_element * 100, 130 | "loss": loss.item() 131 | } 132 | 133 | if i % self.log_freq == 0: 134 | data_iter.write(str(post_fix)) 135 | 136 | print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", 137 | total_correct * 100.0 / total_element) 138 | 139 | def save(self, epoch, file_path="output/bert_trained.model"): 140 | """ 141 | Saving the current BERT model on file_path 142 | 143 | :param epoch: current epoch number 144 | :param file_path: model output path which gonna be file_path+"ep%d" % epoch 145 | :return: final_output_path 146 | """ 147 | output_path = file_path + ".ep%d" % epoch 148 | torch.save(self.bert.cpu(), output_path) 149 | self.bert.to(self.device) 150 | print("EP:%d Model Saved on:" % epoch, output_path) 151 | return output_path 152 | -------------------------------------------------------------------------------- /data/corpus.small: -------------------------------------------------------------------------------- 1 | Welcome to the \t the jungle\n 2 | I can stay \t here all night\n 3 | Welcome to the \t the jungle\n 4 | I can stay \t here all night\n 5 | Welcome to the \t the jungle\n 6 | I can stay \t here all night\n 7 | Welcome to the \t the jungle\n 8 | I can stay \t here all night\n 9 | Welcome to the \t the jungle\n 10 | I can stay \t here all night\n 11 | Welcome to the \t the jungle\n 12 | I can stay \t here all night\n 13 | Welcome to the \t the jungle\n 14 | I can stay \t here all night\n 15 | Welcome to the \t the jungle\n 16 | I can stay \t here all night\n 17 | Welcome to the \t the jungle\n 18 | I can stay \t here all night\n 19 | Welcome to the \t the jungle\n 20 | I can stay \t here all night\n 21 | Welcome to the \t the jungle\n 22 | I can stay \t here all night\n 23 | Welcome to the \t the jungle\n 24 | I can stay \t here all night\n 25 | Welcome to the \t the jungle\n 26 | I can stay \t here all night\n 27 | Welcome to the \t the jungle\n 28 | I can stay \t here all night\n 29 | Welcome to the \t the jungle\n 30 | I can stay \t here all night\n 31 | Welcome to the \t the jungle\n 32 | I can stay \t here all night\n 33 | Welcome to the \t the jungle\n 34 | I can stay \t here all night\n 35 | Welcome to the \t the jungle\n 36 | I can stay \t here all night\n 37 | Welcome to the \t the jungle\n 38 | I can stay \t here all night\n 39 | Welcome to the \t the jungle\n 40 | I can stay \t here all night\n 41 | Welcome to the \t the jungle\n 42 | I can stay \t here all night\n 43 | Welcome to the \t the jungle\n 44 | I can stay \t here all night\n 45 | Welcome to the \t the jungle\n 46 | I can stay \t here all night\n 47 | Welcome to the \t the jungle\n 48 | I can stay \t here all night\n 49 | Welcome to the \t the jungle\n 50 | I can stay \t here 
all night\n 51 | Welcome to the \t the jungle\n 52 | I can stay \t here all night\n 53 | Welcome to the \t the jungle\n 54 | I can stay \t here all night\n 55 | Welcome to the \t the jungle\n 56 | I can stay \t here all night\n 57 | Welcome to the \t the jungle\n 58 | I can stay \t here all night\n 59 | Welcome to the \t the jungle\n 60 | I can stay \t here all night\n 61 | Welcome to the \t the jungle\n 62 | I can stay \t here all night\n 63 | Welcome to the \t the jungle\n 64 | I can stay \t here all night\n 65 | Welcome to the \t the jungle\n 66 | I can stay \t here all night\n 67 | Welcome to the \t the jungle\n 68 | I can stay \t here all night\n 69 | Welcome to the \t the jungle\n 70 | I can stay \t here all night\n 71 | Welcome to the \t the jungle\n 72 | I can stay \t here all night\n 73 | Welcome to the \t the jungle\n 74 | I can stay \t here all night\n 75 | Welcome to the \t the jungle\n 76 | I can stay \t here all night\n 77 | Welcome to the \t the jungle\n 78 | I can stay \t here all night\n 79 | Welcome to the \t the jungle\n 80 | I can stay \t here all night\n 81 | Welcome to the \t the jungle\n 82 | I can stay \t here all night\n 83 | Welcome to the \t the jungle\n 84 | I can stay \t here all night\n 85 | Welcome to the \t the jungle\n 86 | I can stay \t here all night\n 87 | Welcome to the \t the jungle\n 88 | I can stay \t here all night\n 89 | Welcome to the \t the jungle\n 90 | I can stay \t here all night\n 91 | Welcome to the \t the jungle\n 92 | I can stay \t here all night\n 93 | Welcome to the \t the jungle\n 94 | I can stay \t here all night\n 95 | Welcome to the \t the jungle\n 96 | I can stay \t here all night\n 97 | Welcome to the \t the jungle\n 98 | I can stay \t here all night\n 99 | Welcome to the \t the jungle\n 100 | I can stay \t here all night\n 101 | Welcome to the \t the jungle\n 102 | I can stay \t here all night\n 103 | -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songyingxin/BERT-pytorch/c03bc7f9ab3d382dce10d07ecb12cbf74b38ba51/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songyingxin/BERT-pytorch/c03bc7f9ab3d382dce10d07ecb12cbf74b38ba51/img/2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | torch>=0.4.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from setuptools.command.install import install 3 | import os 4 | import sys 5 | 6 | __version__ = "0.0.1a4" 7 | 8 | with open("requirements.txt") as f: 9 | require_packages = [line[:-1] if line[-1] == "\n" else line for line in f] 10 | 11 | with open("README.md", "r", encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | 15 | class VerifyVersionCommand(install): 16 | """Custom command to verify that the git tag matches our version""" 17 | description = 'verify that the git tag matches our version' 18 | 19 | def run(self): 20 | tag = os.getenv('CIRCLE_TAG') 21 | 22 | if tag != __version__: 23 | info = "Git tag: {0} does not match 
the version of this app: {1}".format( 24 | tag, __version__ 25 | ) 26 | sys.exit(info) 27 | 28 | 29 | setup( 30 | name="bert_pytorch", 31 | version=__version__, 32 | author='Junseong Kim', 33 | author_email='codertimo@gmail.com', 34 | packages=find_packages(), 35 | install_requires=require_packages, 36 | url="https://github.com/codertimo/BERT-pytorch", 37 | description="Google AI 2018 BERT pytorch implementation", 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "License :: OSI Approved :: Apache Software License", 43 | "Operating System :: OS Independent", 44 | ], 45 | entry_points={ 46 | 'console_scripts': [ 47 | 'bert = bert_pytorch.__main__:train', 48 | 'bert-vocab = bert_pytorch.dataset.vocab:build', 49 | ] 50 | }, 51 | cmdclass={ 52 | 'verify': VerifyVersionCommand, 53 | } 54 | ) 55 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from bert_pytorch import BERT 3 | 4 | 5 | class BERTVocabTestCase(unittest.TestCase): 6 | pass 7 | -------------------------------------------------------------------------------- /test_bert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from bert_pytorch.model import BERT 6 | from bert_pytorch.trainer import BERTTrainer 7 | from bert_pytorch.dataset import BERTDataset, WordVocab 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument("-c", "--train_dataset", required=True, 14 | type=str, help="train dataset for train bert") 15 | parser.add_argument("-t", "--test_dataset", type=str, 16 | default=None, help="test set for evaluate train set") 17 | parser.add_argument("-v", "--vocab_path", required=True, 18 | type=str, help="built vocab model path with bert-vocab") 19 | parser.add_argument("-o", "--output_path", required=True, 20 | type=str, help="ex)output/bert.model") 21 | 22 | parser.add_argument("-hs", "--hidden", type=int, 23 | default=256, help="hidden size of transformer model") 24 | parser.add_argument("-l", "--layers", type=int, 25 | default=8, help="number of layers") 26 | parser.add_argument("-a", "--attn_heads", type=int, 27 | default=8, help="number of attention heads") 28 | parser.add_argument("-s", "--seq_len", type=int, 29 | default=20, help="maximum sequence len") 30 | 31 | parser.add_argument("-b", "--batch_size", type=int, 32 | default=64, help="number of batch_size") 33 | parser.add_argument("-e", "--epochs", type=int, 34 | default=10, help="number of epochs") 35 | parser.add_argument("-w", "--num_workers", type=int, 36 | default=5, help="dataloader worker size") 37 | 38 | parser.add_argument("--with_cuda", type=bool, default=True, 39 | help="training with CUDA: true, or false") 40 | parser.add_argument("--log_freq", type=int, default=10, 41 | help="printing loss every n iter: setting n") 42 | parser.add_argument("--corpus_lines", type=int, 43 | default=None, help="total number of lines in corpus") 44 | parser.add_argument("--cuda_devices", type=int, nargs='+', 45 | default=None, help="CUDA device ids") 46 | parser.add_argument("--on_memory", type=bool, default=True, 47 | help="Loading on memory: true or false") 48 | 49 | parser.add_argument("--lr", type=float, default=1e-3, 50 | help="learning rate of adam") 51 | 
parser.add_argument("--adam_weight_decay", type=float, 52 | default=0.01, help="weight_decay of adam") 53 | parser.add_argument("--adam_beta1", type=float, 54 | default=0.9, help="adam first beta value") 55 | parser.add_argument("--adam_beta2", type=float, 56 | default=0.999, help="adam first beta value") 57 | 58 | args = parser.parse_args() 59 | 60 | print("Loading Vocab", args.vocab_path) 61 | vocab = WordVocab.load_vocab(args.vocab_path) 62 | print("Vocab Size: ", len(vocab)) 63 | 64 | print("Loading Train Dataset", args.train_dataset) 65 | train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, 66 | corpus_lines=args.corpus_lines, on_memory=args.on_memory) 67 | 68 | print("Loading Test Dataset", args.test_dataset) 69 | test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \ 70 | if args.test_dataset is not None else None 71 | 72 | print("Creating Dataloader") 73 | train_data_loader = DataLoader( 74 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 75 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \ 76 | if test_dataset is not None else None 77 | 78 | print("Building BERT model") 79 | bert = BERT(len(vocab), hidden=args.hidden, 80 | n_layers=args.layers, attn_heads=args.attn_heads) 81 | 82 | print("Creating BERT Trainer") 83 | trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader, 84 | lr=args.lr, betas=( 85 | args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, 86 | with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq) 87 | 88 | print("Training Start") 89 | for epoch in range(args.epochs): 90 | trainer.train(epoch) 91 | trainer.save(epoch, args.output_path) 92 | 93 | if test_data_loader is not None: 94 | trainer.test(epoch) 95 | -------------------------------------------------------------------------------- /test_bert_vocab.py: -------------------------------------------------------------------------------- 1 | 2 | from bert_pytorch.dataset.vocab import * 3 | 4 | 5 | if __name__ == "__main__": 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("-c", "--corpus_path", required=True, type=str) 10 | parser.add_argument("-o", "--output_path", required=True, type=str) 11 | parser.add_argument("-s", "--vocab_size", type=int, default=None) 12 | parser.add_argument("-e", "--encoding", type=str, default="utf-8") 13 | parser.add_argument("-m", "--min_freq", type=int, default=1) 14 | args = parser.parse_args() 15 | 16 | with open(args.corpus_path, "r", encoding=args.encoding) as f: 17 | vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) 18 | 19 | print("VOCAB SIZE:", len(vocab)) 20 | vocab.save_vocab(args.output_path) 21 | --------------------------------------------------------------------------------