├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── bert_pytorch ├── __init__.py ├── __main__.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── vocab.py ├── model │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── multi_head.py │ │ └── single.py │ ├── bert.py │ ├── embedding │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── position.py │ │ ├── segment.py │ │ └── token.py │ ├── language_model.py │ ├── transformer.py │ └── utils │ │ ├── __init__.py │ │ ├── feed_forward.py │ │ ├── gelu.py │ │ ├── layer_norm.py │ │ └── sublayer.py └── trainer │ ├── __init__.py │ ├── optim_schedule.py │ └── pretrain.py ├── requirements.txt ├── setup.py └── test.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/python:3.6.1 6 | 7 | working_directory: ~/repo 8 | 9 | steps: 10 | - checkout 11 | 12 | - restore_cache: 13 | keys: 14 | - v1-dependencies-{{ checksum "requirements.txt" }} 15 | - v1-dependencies- 16 | 17 | - run: 18 | name: install dependencies 19 | command: | 20 | python3 -m venv venv 21 | . venv/bin/activate 22 | pip install -r requirements.txt 23 | 24 | - save_cache: 25 | paths: 26 | - ./venv 27 | key: v1-dependencies-{{ checksum "requirements.txt" }} 28 | 29 | - run: 30 | name: run tests 31 | command: | 32 | . venv/bin/activate 33 | python -m unittest test.py 34 | 35 | - store_artifacts: 36 | path: test-reports 37 | destination: test-reports 38 | 39 | deploy: 40 | docker: 41 | - image: circleci/python:3.6.1 42 | 43 | working_directory: ~/repo 44 | 45 | steps: 46 | - checkout 47 | 48 | - restore_cache: 49 | key: v1-dependency-cache-{{ checksum "setup.py" }}-{{ checksum "Makefile" }} 50 | 51 | - run: 52 | name: verify git tag vs. version 53 | command: | 54 | python3 -m venv venv 55 | . venv/bin/activate 56 | python setup.py verify 57 | pip install twine 58 | 59 | - save_cache: 60 | key: v1-dependency-cache-{{ checksum "setup.py" }}-{{ checksum "Makefile" }} 61 | paths: 62 | - "venv" 63 | 64 | # Deploying to PyPI 65 | # for pip install kor2vec 66 | - run: 67 | name: init .pypirc 68 | command: | 69 | echo -e "[pypi]" >> ~/.pypirc 70 | echo -e "username = codertimo" >> ~/.pypirc 71 | echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc 72 | 73 | - run: 74 | name: create packages 75 | command: | 76 | make package 77 | 78 | - run: 79 | name: upload to pypi 80 | command: | 81 | . 
venv/bin/activate 82 | twine upload dist/* 83 | workflows: 84 | version: 2 85 | build_and_deploy: 86 | jobs: 87 | - build: 88 | filters: 89 | tags: 90 | only: /.*/ 91 | - deploy: 92 | requires: 93 | - build 94 | filters: 95 | tags: 96 | only: /.*/ 97 | branches: 98 | ignore: /.*/ 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | output/ 3 | 4 | # Created by .ignore support plugin (hsz.mobi) 5 | ### Python template 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | ### JetBrains template 111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 113 | 114 | # User-specific stuff 115 | .idea/**/workspace.xml 116 | .idea/**/tasks.xml 117 | .idea/**/usage.statistics.xml 118 | .idea/**/dictionaries 119 | .idea/**/shelf 120 | 121 | # Sensitive or high-churn files 122 | .idea/**/dataSources/ 123 | .idea/**/dataSources.ids 124 | .idea/**/dataSources.local.xml 125 | .idea/**/sqlDataSources.xml 126 | .idea/**/dynamic.xml 127 | .idea/**/uiDesigner.xml 128 | .idea/**/dbnavigator.xml 129 | 130 | # Gradle 131 | .idea/**/gradle.xml 132 | .idea/**/libraries 133 | 134 | # Gradle and Maven with auto-import 135 | # When using Gradle or Maven with auto-import, you should exclude module files, 136 | # since they will be recreated, and may cause churn. Uncomment if using 137 | # auto-import. 
138 | # .idea/modules.xml 139 | # .idea/*.iml 140 | # .idea/modules 141 | 142 | # CMake 143 | cmake-build-*/ 144 | 145 | # Mongo Explorer plugin 146 | .idea/**/mongoSettings.xml 147 | 148 | # File-based project format 149 | *.iws 150 | 151 | # IntelliJ 152 | out/ 153 | 154 | # mpeltonen/sbt-idea plugin 155 | .idea_modules/ 156 | 157 | # JIRA plugin 158 | atlassian-ide-plugin.xml 159 | 160 | # Cursive Clojure plugin 161 | .idea/replstate.xml 162 | 163 | # Crashlytics plugin (for Android Studio and IntelliJ) 164 | com_crashlytics_export_strings.xml 165 | crashlytics.properties 166 | crashlytics-build.properties 167 | fabric.properties 168 | 169 | # Editor-based Rest Client 170 | .idea/httpRequests 171 | 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 Junseong Kim, Scatter Lab, BERT contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | package: 2 | python setup.py sdist 3 | python setup.py bdist_wheel 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT-pytorch 2 | 3 | [![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/blob/master/LICENSE) 4 | ![GitHub issues](https://img.shields.io/github/issues/codertimo/BERT-pytorch.svg) 5 | [![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/stargazers) 6 | [![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/codertimo/BERT-pytorch) 7 | [![PyPI](https://img.shields.io/pypi/v/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/) 8 | [![PyPI - Status](https://img.shields.io/pypi/status/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/) 9 | [![Documentation Status](https://readthedocs.org/projects/bert-pytorch/badge/?version=latest)](https://bert-pytorch.readthedocs.io/en/latest/?badge=latest) 10 | 11 | PyTorch implementation of Google AI's 2018 BERT, with simple annotations 12 | 13 | > BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (2018) 14 | > Paper URL : https://arxiv.org/abs/1810.04805 15 | 16 | 17 | ## Introduction 18 | 19 | Google AI's BERT paper reports impressive results on a wide range of NLP tasks (new state of the art on 17 NLP tasks), 20 | including outperforming the human F1 score on the SQuAD v1.1 QA task. 21 | The paper shows that a Transformer (self-attention) based encoder, given a proper language model training method, 22 | is a powerful alternative to previous language model architectures. 23 | More importantly, it shows that this pre-trained language model can be transferred 24 | to any NLP task without building a task-specific model architecture. 25 | 26 | This result is likely to become a landmark in NLP history, 27 | and I expect many further papers about BERT to be published very soon. 28 | 29 | This repo is an implementation of BERT. The code is simple and quick to understand. 30 | Some of the code is based on [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 31 | 32 | This project is currently a work in progress, and the code has not been verified yet. 33 | 34 | ## Installation 35 | ``` 36 | pip install bert-pytorch 37 | ``` 38 | 39 | ## Quickstart 40 | 41 | **NOTICE: Your corpus should be prepared with two sentences per line, separated by a tab (\t)** 42 | 43 | ### 0. Prepare your corpus 44 | ``` 45 | Welcome to the \t the jungle\n 46 | I can stay \t here all night\n 47 | ``` 48 | 49 | or a tokenized corpus (tokenization is not included in this package) 50 | ``` 51 | Wel_ _come _to _the \t _the _jungle\n 52 | _I _can _stay \t _here _all _night\n 53 | ``` 54 | 55 | 56 | ### 1. Build a vocab based on your corpus 57 | ```shell 58 | bert-vocab -c data/corpus.small -o data/vocab.small 59 | ``` 60 | 61 | ### 2. Train your own BERT model 62 | ```shell 63 | bert -c data/corpus.small -v data/vocab.small -o output/bert.model 64 | ``` 65 | 66 | ## Language Model Pre-training 67 | 68 | In the paper, the authors introduce two new language model training methods: 69 | "masked language model" and "next sentence prediction".
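Both objectives are wired together by the training pipeline that the `bert` command wraps (see `bert_pytorch/__main__.py` below). As a rough sketch, the same pipeline can also be driven from Python; the corpus/output paths and hyperparameters below are illustrative only, not recommended settings:

```python
from torch.utils.data import DataLoader

from bert_pytorch import BERT
from bert_pytorch.dataset import BERTDataset, WordVocab
from bert_pytorch.trainer import BERTTrainer

# Build the vocabulary from a tab-separated corpus (what `bert-vocab` does)
with open("data/corpus.small", encoding="utf-8") as f:
    vocab = WordVocab(f, max_size=None, min_freq=1)
vocab.save_vocab("data/vocab.small")

# The dataset applies random token masking and next-sentence pairing on the fly
train_data = BERTDataset("data/corpus.small", vocab, seq_len=20)
train_loader = DataLoader(train_data, batch_size=64)

# Small BERT encoder plus the two pre-training heads, trained for a few epochs
bert = BERT(len(vocab), hidden=256, n_layers=8, attn_heads=8)
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_loader)
for epoch in range(3):
    trainer.train(epoch)
    trainer.save(epoch, "output/bert.model")
```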
70 | 71 | 72 | ### Masked Language Model 73 | 74 | > Original Paper : 3.3.1 Task #1: Masked LM 75 | 76 | ``` 77 | Input Sequence : The man went to [MASK] store with [MASK] dog 78 | Target Sequence : the his 79 | ``` 80 | 81 | #### Rules: 82 | 15% of the input tokens are randomly selected and changed, according to the following sub-rules: 83 | 84 | 1. 80% of the selected tokens are replaced with the `[MASK]` token 85 | 2. 10% of the selected tokens are replaced with a random token (another word) 86 | 3. 10% of the selected tokens are left unchanged, but still need to be predicted 87 | 88 | ### Predict Next Sentence 89 | 90 | > Original Paper : 3.3.2 Task #2: Next Sentence Prediction 91 | 92 | ``` 93 | Input : [CLS] the man went to the store [SEP] he bought a gallon of milk [SEP] 94 | Label : IsNext 95 | 96 | Input : [CLS] the man heading to the store [SEP] penguin [MASK] are flight ##less birds [SEP] 97 | Label : NotNext 98 | ``` 99 | 100 | "Can the second sentence be a natural continuation of the first?" 101 | 102 | This task teaches the model the relationship between two text sentences, which is 103 | not directly captured by language modeling. 104 | 105 | #### Rules: 106 | 107 | 1. For 50% of the pairs, the second sentence is the actual next (continuous) sentence. 108 | 2. For 50% of the pairs, the second sentence is a random, unrelated sentence. 109 | 110 | 111 | ## Author 112 | Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatterlab.co.kr) 113 | 114 | ## License 115 | 116 | This project follows the Apache 2.0 License, as stated in the LICENSE file. 117 | 118 | Copyright 2018 Junseong Kim, Scatter Lab, respective BERT contributors 119 | 120 | Copyright (c) 2018 Alexander Rush : [The Annotated Transformer](https://github.com/harvardnlp/annotated-transformer) 121 | -------------------------------------------------------------------------------- /bert_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import BERT 2 | -------------------------------------------------------------------------------- /bert_pytorch/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from .model import BERT 6 | from .trainer import BERTTrainer 7 | from .dataset import BERTDataset, WordVocab 8 | 9 | 10 | def train(): 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert") 14 | parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set") 15 | parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab") 16 | parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model") 17 | 18 | parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model") 19 | parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers") 20 | parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads") 21 | parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len") 22 | 23 | parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size") 24 | parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs") 25 | parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size") 26 | 27 | parser.add_argument("--with_cuda",
type=bool, default=True, help="training with CUDA: true, or false") 28 | parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") 29 | parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus") 30 | parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids") 31 | parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false") 32 | 33 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam") 34 | parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam") 35 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") 36 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value") 37 | 38 | args = parser.parse_args() 39 | 40 | print("Loading Vocab", args.vocab_path) 41 | vocab = WordVocab.load_vocab(args.vocab_path) 42 | print("Vocab Size: ", len(vocab)) 43 | 44 | print("Loading Train Dataset", args.train_dataset) 45 | train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, 46 | corpus_lines=args.corpus_lines, on_memory=args.on_memory) 47 | 48 | print("Loading Test Dataset", args.test_dataset) 49 | test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \ 50 | if args.test_dataset is not None else None 51 | 52 | print("Creating Dataloader") 53 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 54 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \ 55 | if test_dataset is not None else None 56 | 57 | print("Building BERT model") 58 | bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) 59 | 60 | print("Creating BERT Trainer") 61 | trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader, 62 | lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, 63 | with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq) 64 | 65 | print("Training Start") 66 | for epoch in range(args.epochs): 67 | trainer.train(epoch) 68 | trainer.save(epoch, args.output_path) 69 | 70 | if test_data_loader is not None: 71 | trainer.test(epoch) 72 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import BERTDataset 2 | from .vocab import WordVocab 3 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import tqdm 3 | import torch 4 | import random 5 | 6 | 7 | class BERTDataset(Dataset): 8 | def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): 9 | self.vocab = vocab 10 | self.seq_len = seq_len 11 | 12 | self.on_memory = on_memory 13 | self.corpus_lines = corpus_lines 14 | self.corpus_path = corpus_path 15 | self.encoding = encoding 16 | 17 | with open(corpus_path, "r", encoding=encoding) as f: 18 | if self.corpus_lines is None and not on_memory: 19 | for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines): 
20 | self.corpus_lines += 1 21 | 22 | if on_memory: 23 | self.lines = [line[:-1].split("\t") 24 | for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] 25 | self.corpus_lines = len(self.lines) 26 | 27 | if not on_memory: 28 | self.file = open(corpus_path, "r", encoding=encoding) 29 | self.random_file = open(corpus_path, "r", encoding=encoding) 30 | 31 | for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): 32 | self.random_file.__next__() 33 | 34 | def __len__(self): 35 | return self.corpus_lines 36 | 37 | def __getitem__(self, item): 38 | t1, t2, is_next_label = self.random_sent(item) 39 | t1_random, t1_label = self.random_word(t1) 40 | t2_random, t2_label = self.random_word(t2) 41 | 42 | # [CLS] tag = SOS tag, [SEP] tag = EOS tag 43 | t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index] 44 | t2 = t2_random + [self.vocab.eos_index] 45 | 46 | t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index] 47 | t2_label = t2_label + [self.vocab.pad_index] 48 | 49 | segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] 50 | bert_input = (t1 + t2)[:self.seq_len] 51 | bert_label = (t1_label + t2_label)[:self.seq_len] 52 | 53 | padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))] 54 | bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding) 55 | 56 | output = {"bert_input": bert_input, 57 | "bert_label": bert_label, 58 | "segment_label": segment_label, 59 | "is_next": is_next_label} 60 | 61 | return {key: torch.tensor(value) for key, value in output.items()} 62 | 63 | def random_word(self, sentence): 64 | tokens = sentence.split() 65 | output_label = [] 66 | 67 | for i, token in enumerate(tokens): 68 | prob = random.random() 69 | if prob < 0.15: 70 | prob /= 0.15 71 | 72 | # 80% randomly change token to mask token 73 | if prob < 0.8: 74 | tokens[i] = self.vocab.mask_index 75 | 76 | # 10% randomly change token to random token 77 | elif prob < 0.9: 78 | tokens[i] = random.randrange(len(self.vocab)) 79 | 80 | # 10% randomly change token to current token 81 | else: 82 | tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) 83 | 84 | output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) 85 | 86 | else: 87 | tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) 88 | output_label.append(0) 89 | 90 | return tokens, output_label 91 | 92 | def random_sent(self, index): 93 | t1, t2 = self.get_corpus_line(index) 94 | 95 | # output_text, label(isNotNext:0, isNext:1) 96 | if random.random() > 0.5: 97 | return t1, t2, 1 98 | else: 99 | return t1, self.get_random_line(), 0 100 | 101 | def get_corpus_line(self, item): 102 | if self.on_memory: 103 | return self.lines[item][0], self.lines[item][1] 104 | else: 105 | line = self.file.__next__() 106 | if line is None: 107 | self.file.close() 108 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 109 | line = self.file.__next__() 110 | 111 | t1, t2 = line[:-1].split("\t") 112 | return t1, t2 113 | 114 | def get_random_line(self): 115 | if self.on_memory: 116 | return self.lines[random.randrange(len(self.lines))][1] 117 | 118 | line = self.file.__next__() 119 | if line is None: 120 | self.file.close() 121 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 122 | for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): 123 | self.random_file.__next__() 124 | line = self.random_file.__next__() 125 | return 
line[:-1].split("\t")[1] 126 | -------------------------------------------------------------------------------- /bert_pytorch/dataset/vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tqdm 3 | from collections import Counter 4 | 5 | 6 | class TorchVocab(object): 7 | """Defines a vocabulary object that will be used to numericalize a field. 8 | Attributes: 9 | freqs: A collections.Counter object holding the frequencies of tokens 10 | in the data used to build the Vocab. 11 | stoi: A collections.defaultdict instance mapping token strings to 12 | numerical identifiers. 13 | itos: A list of token strings indexed by their numerical identifiers. 14 | """ 15 | 16 | def __init__(self, counter, max_size=None, min_freq=1, specials=['', ''], 17 | vectors=None, unk_init=None, vectors_cache=None): 18 | """Create a Vocab object from a collections.Counter. 19 | Arguments: 20 | counter: collections.Counter object holding the frequencies of 21 | each value found in the data. 22 | max_size: The maximum size of the vocabulary, or None for no 23 | maximum. Default: None. 24 | min_freq: The minimum frequency needed to include a token in the 25 | vocabulary. Values less than 1 will be set to 1. Default: 1. 26 | specials: The list of special tokens (e.g., padding or eos) that 27 | will be prepended to the vocabulary in addition to an 28 | token. Default: [''] 29 | vectors: One of either the available pretrained vectors 30 | or custom pretrained vectors (see Vocab.load_vectors); 31 | or a list of aforementioned vectors 32 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 33 | to zero vectors; can be any function that takes in a Tensor and 34 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 35 | vectors_cache: directory for cached vectors. 
Default: '.vector_cache' 36 | """ 37 | self.freqs = counter 38 | counter = counter.copy() 39 | min_freq = max(min_freq, 1) 40 | 41 | self.itos = list(specials) 42 | # frequencies of special tokens are not counted when building vocabulary 43 | # in frequency order 44 | for tok in specials: 45 | del counter[tok] 46 | 47 | max_size = None if max_size is None else max_size + len(self.itos) 48 | 49 | # sort by frequency, then alphabetically 50 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 51 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 52 | 53 | for word, freq in words_and_frequencies: 54 | if freq < min_freq or len(self.itos) == max_size: 55 | break 56 | self.itos.append(word) 57 | 58 | # stoi is simply a reverse dict for itos 59 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 60 | 61 | self.vectors = None 62 | if vectors is not None: 63 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 64 | else: 65 | assert unk_init is None and vectors_cache is None 66 | 67 | def __eq__(self, other): 68 | if self.freqs != other.freqs: 69 | return False 70 | if self.stoi != other.stoi: 71 | return False 72 | if self.itos != other.itos: 73 | return False 74 | if self.vectors != other.vectors: 75 | return False 76 | return True 77 | 78 | def __len__(self): 79 | return len(self.itos) 80 | 81 | def vocab_rerank(self): 82 | self.stoi = {word: i for i, word in enumerate(self.itos)} 83 | 84 | def extend(self, v, sort=False): 85 | words = sorted(v.itos) if sort else v.itos 86 | for w in words: 87 | if w not in self.stoi: 88 | self.itos.append(w) 89 | self.stoi[w] = len(self.itos) - 1 90 | 91 | 92 | class Vocab(TorchVocab): 93 | def __init__(self, counter, max_size=None, min_freq=1): 94 | self.pad_index = 0 95 | self.unk_index = 1 96 | self.eos_index = 2 97 | self.sos_index = 3 98 | self.mask_index = 4 99 | super().__init__(counter, specials=["", "", "", "", ""], 100 | max_size=max_size, min_freq=min_freq) 101 | 102 | def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list: 103 | pass 104 | 105 | def from_seq(self, seq, join=False, with_pad=False): 106 | pass 107 | 108 | @staticmethod 109 | def load_vocab(vocab_path: str) -> 'Vocab': 110 | with open(vocab_path, "rb") as f: 111 | return pickle.load(f) 112 | 113 | def save_vocab(self, vocab_path): 114 | with open(vocab_path, "wb") as f: 115 | pickle.dump(self, f) 116 | 117 | 118 | # Building Vocab with text files 119 | class WordVocab(Vocab): 120 | def __init__(self, texts, max_size=None, min_freq=1): 121 | print("Building Vocab") 122 | counter = Counter() 123 | for line in tqdm.tqdm(texts): 124 | if isinstance(line, list): 125 | words = line 126 | else: 127 | words = line.replace("\n", "").replace("\t", "").split() 128 | 129 | for word in words: 130 | counter[word] += 1 131 | super().__init__(counter, max_size=max_size, min_freq=min_freq) 132 | 133 | def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False): 134 | if isinstance(sentence, str): 135 | sentence = sentence.split() 136 | 137 | seq = [self.stoi.get(word, self.unk_index) for word in sentence] 138 | 139 | if with_eos: 140 | seq += [self.eos_index] # this would be index 1 141 | if with_sos: 142 | seq = [self.sos_index] + seq 143 | 144 | origin_seq_len = len(seq) 145 | 146 | if seq_len is None: 147 | pass 148 | elif len(seq) <= seq_len: 149 | seq += [self.pad_index for _ in range(seq_len - len(seq))] 150 | else: 151 | seq = seq[:seq_len] 152 | 153 | return (seq, origin_seq_len) if 
with_len else seq 154 | 155 | def from_seq(self, seq, join=False, with_pad=False): 156 | words = [self.itos[idx] 157 | if idx < len(self.itos) 158 | else "<%d>" % idx 159 | for idx in seq 160 | if not with_pad or idx != self.pad_index] 161 | 162 | return " ".join(words) if join else words 163 | 164 | @staticmethod 165 | def load_vocab(vocab_path: str) -> 'WordVocab': 166 | with open(vocab_path, "rb") as f: 167 | return pickle.load(f) 168 | 169 | 170 | def build(): 171 | import argparse 172 | 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument("-c", "--corpus_path", required=True, type=str) 175 | parser.add_argument("-o", "--output_path", required=True, type=str) 176 | parser.add_argument("-s", "--vocab_size", type=int, default=None) 177 | parser.add_argument("-e", "--encoding", type=str, default="utf-8") 178 | parser.add_argument("-m", "--min_freq", type=int, default=1) 179 | args = parser.parse_args() 180 | 181 | with open(args.corpus_path, "r", encoding=args.encoding) as f: 182 | vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) 183 | 184 | print("VOCAB SIZE:", len(vocab)) 185 | vocab.save_vocab(args.output_path) 186 | -------------------------------------------------------------------------------- /bert_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import BERT 2 | from .language_model import BERTLM 3 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_head import MultiHeadedAttention 2 | from .single import Attention 3 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/multi_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .single import Attention 3 | 4 | 5 | class MultiHeadedAttention(nn.Module): 6 | """ 7 | Take in model size and number of heads. 8 | """ 9 | 10 | def __init__(self, h, d_model, dropout=0.1): 11 | super().__init__() 12 | assert d_model % h == 0 13 | 14 | # We assume d_v always equals d_k 15 | self.d_k = d_model // h 16 | self.h = h 17 | 18 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 19 | self.output_linear = nn.Linear(d_model, d_model) 20 | self.attention = Attention() 21 | 22 | self.dropout = nn.Dropout(p=dropout) 23 | 24 | def forward(self, query, key, value, mask=None): 25 | batch_size = query.size(0) 26 | 27 | # 1) Do all the linear projections in batch from d_model => h x d_k 28 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 29 | for l, x in zip(self.linear_layers, (query, key, value))] 30 | 31 | # 2) Apply attention on all the projected vectors in batch. 32 | x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 33 | 34 | # 3) "Concat" using a view and apply a final linear. 
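        # (batch_size, h, seq_len, d_k) -> (batch_size, seq_len, h * d_k)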
35 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 36 | 37 | return self.output_linear(x) 38 | -------------------------------------------------------------------------------- /bert_pytorch/model/attention/single.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | import math 6 | 7 | 8 | class Attention(nn.Module): 9 | """ 10 | Compute 'Scaled Dot Product Attention 11 | """ 12 | 13 | def forward(self, query, key, value, mask=None, dropout=None): 14 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 15 | / math.sqrt(query.size(-1)) 16 | 17 | if mask is not None: 18 | scores = scores.masked_fill(mask == 0, -1e9) 19 | 20 | p_attn = F.softmax(scores, dim=-1) 21 | 22 | if dropout is not None: 23 | p_attn = dropout(p_attn) 24 | 25 | return torch.matmul(p_attn, value), p_attn 26 | -------------------------------------------------------------------------------- /bert_pytorch/model/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .transformer import TransformerBlock 4 | from .embedding import BERTEmbedding 5 | 6 | 7 | class BERT(nn.Module): 8 | """ 9 | BERT model : Bidirectional Encoder Representations from Transformers. 10 | """ 11 | 12 | def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1): 13 | """ 14 | :param vocab_size: vocab_size of total words 15 | :param hidden: BERT model hidden size 16 | :param n_layers: numbers of Transformer blocks(layers) 17 | :param attn_heads: number of attention heads 18 | :param dropout: dropout rate 19 | """ 20 | 21 | super().__init__() 22 | self.hidden = hidden 23 | self.n_layers = n_layers 24 | self.attn_heads = attn_heads 25 | 26 | # paper noted they used 4*hidden_size for ff_network_hidden_size 27 | self.feed_forward_hidden = hidden * 4 28 | 29 | # embedding for BERT, sum of positional, segment, token embeddings 30 | self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden) 31 | 32 | # multi-layers transformer blocks, deep network 33 | self.transformer_blocks = nn.ModuleList( 34 | [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)]) 35 | 36 | def forward(self, x, segment_info): 37 | # attention masking for padded token 38 | # torch.ByteTensor([batch_size, 1, seq_len, seq_len) 39 | mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) 40 | 41 | # embedding the indexed sequence to sequence of vectors 42 | x = self.embedding(x, segment_info) 43 | 44 | # running over multiple transformer blocks 45 | for transformer in self.transformer_blocks: 46 | x = transformer.forward(x, mask) 47 | 48 | return x 49 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import BERTEmbedding 2 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .token import TokenEmbedding 3 | from .position import PositionalEmbedding 4 | from .segment import SegmentEmbedding 5 | 6 | 7 | class BERTEmbedding(nn.Module): 8 | """ 9 | BERT Embedding which is consisted with under features 10 | 1. 
TokenEmbedding : normal embedding matrix 11 | 2. PositionalEmbedding : adding positional information using sin, cos 12 | 2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2) 13 | 14 | sum of all these features are output of BERTEmbedding 15 | """ 16 | 17 | def __init__(self, vocab_size, embed_size, dropout=0.1): 18 | """ 19 | :param vocab_size: total vocab size 20 | :param embed_size: embedding size of token embedding 21 | :param dropout: dropout rate 22 | """ 23 | super().__init__() 24 | self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) 25 | self.position = PositionalEmbedding(d_model=self.token.embedding_dim) 26 | self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim) 27 | self.dropout = nn.Dropout(p=dropout) 28 | self.embed_size = embed_size 29 | 30 | def forward(self, sequence, segment_label): 31 | x = self.token(sequence) + self.position(sequence) + self.segment(segment_label) 32 | return self.dropout(x) 33 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/position.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | 8 | def __init__(self, d_model, max_len=512): 9 | super().__init__() 10 | 11 | # Compute the positional encodings once in log space. 12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 17 | 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | 21 | pe = pe.unsqueeze(0) 22 | self.register_buffer('pe', pe) 23 | 24 | def forward(self, x): 25 | return self.pe[:, :x.size(1)] 26 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/segment.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SegmentEmbedding(nn.Embedding): 5 | def __init__(self, embed_size=512): 6 | super().__init__(3, embed_size, padding_idx=0) 7 | -------------------------------------------------------------------------------- /bert_pytorch/model/embedding/token.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TokenEmbedding(nn.Embedding): 5 | def __init__(self, vocab_size, embed_size=512): 6 | super().__init__(vocab_size, embed_size, padding_idx=0) 7 | -------------------------------------------------------------------------------- /bert_pytorch/model/language_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .bert import BERT 4 | 5 | 6 | class BERTLM(nn.Module): 7 | """ 8 | BERT Language Model 9 | Next Sentence Prediction Model + Masked Language Model 10 | """ 11 | 12 | def __init__(self, bert: BERT, vocab_size): 13 | """ 14 | :param bert: BERT model which should be trained 15 | :param vocab_size: total vocab size for masked_lm 16 | """ 17 | 18 | super().__init__() 19 | self.bert = bert 20 | self.next_sentence = NextSentencePrediction(self.bert.hidden) 21 | self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size) 22 | 23 | def forward(self, x, segment_label): 24 | x = self.bert(x, segment_label) 25 | return 
self.next_sentence(x), self.mask_lm(x) 26 | 27 | 28 | class NextSentencePrediction(nn.Module): 29 | """ 30 | 2-class classification model : is_next, is_not_next 31 | """ 32 | 33 | def __init__(self, hidden): 34 | """ 35 | :param hidden: BERT model output size 36 | """ 37 | super().__init__() 38 | self.linear = nn.Linear(hidden, 2) 39 | self.softmax = nn.LogSoftmax(dim=-1) 40 | 41 | def forward(self, x): 42 | return self.softmax(self.linear(x[:, 0])) 43 | 44 | 45 | class MaskedLanguageModel(nn.Module): 46 | """ 47 | predicting origin token from masked input sequence 48 | n-class classification problem, n-class = vocab_size 49 | """ 50 | 51 | def __init__(self, hidden, vocab_size): 52 | """ 53 | :param hidden: output size of BERT model 54 | :param vocab_size: total vocab size 55 | """ 56 | super().__init__() 57 | self.linear = nn.Linear(hidden, vocab_size) 58 | self.softmax = nn.LogSoftmax(dim=-1) 59 | 60 | def forward(self, x): 61 | return self.softmax(self.linear(x)) 62 | -------------------------------------------------------------------------------- /bert_pytorch/model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .attention import MultiHeadedAttention 4 | from .utils import SublayerConnection, PositionwiseFeedForward 5 | 6 | 7 | class TransformerBlock(nn.Module): 8 | """ 9 | Bidirectional Encoder = Transformer (self-attention) 10 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 11 | """ 12 | 13 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): 14 | """ 15 | :param hidden: hidden size of transformer 16 | :param attn_heads: head sizes of multi-head attention 17 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 18 | :param dropout: dropout rate 19 | """ 20 | 21 | super().__init__() 22 | self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 23 | self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 24 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 25 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | def forward(self, x, mask): 29 | x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 30 | x = self.output_sublayer(x, self.feed_forward) 31 | return self.dropout(x) 32 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .feed_forward import PositionwiseFeedForward 2 | from .layer_norm import LayerNorm 3 | from .sublayer import SublayerConnection 4 | from .gelu import GELU 5 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .gelu import GELU 3 | 4 | 5 | class PositionwiseFeedForward(nn.Module): 6 | "Implements FFN equation." 
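    # FFN(x) = w_2(dropout(GELU(w_1(x)))); BERT uses GELU instead of ReLU here (see gelu.py)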
7 | 8 | def __init__(self, d_model, d_ff, dropout=0.1): 9 | super(PositionwiseFeedForward, self).__init__() 10 | self.w_1 = nn.Linear(d_model, d_ff) 11 | self.w_2 = nn.Linear(d_ff, d_model) 12 | self.dropout = nn.Dropout(dropout) 13 | self.activation = GELU() 14 | 15 | def forward(self, x): 16 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 17 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/gelu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class GELU(nn.Module): 7 | """ 8 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 9 | """ 10 | 11 | def forward(self, x): 12 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | "Construct a layernorm module (See citation for details)." 7 | 8 | def __init__(self, features, eps=1e-6): 9 | super(LayerNorm, self).__init__() 10 | self.a_2 = nn.Parameter(torch.ones(features)) 11 | self.b_2 = nn.Parameter(torch.zeros(features)) 12 | self.eps = eps 13 | 14 | def forward(self, x): 15 | mean = x.mean(-1, keepdim=True) 16 | std = x.std(-1, keepdim=True) 17 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 18 | -------------------------------------------------------------------------------- /bert_pytorch/model/utils/sublayer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .layer_norm import LayerNorm 3 | 4 | 5 | class SublayerConnection(nn.Module): 6 | """ 7 | A residual connection followed by a layer norm. 8 | Note for code simplicity the norm is first as opposed to last. 9 | """ 10 | 11 | def __init__(self, size, dropout): 12 | super(SublayerConnection, self).__init__() 13 | self.norm = LayerNorm(size) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x, sublayer): 17 | "Apply residual connection to any sublayer with the same size." 
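        # Pre-norm residual: x + dropout(sublayer(LayerNorm(x)))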
18 | return x + self.dropout(sublayer(self.norm(x))) 19 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pretrain import BERTTrainer 2 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/optim_schedule.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim(): 6 | '''A simple wrapper class for learning rate scheduling''' 7 | 8 | def __init__(self, optimizer, d_model, n_warmup_steps): 9 | self._optimizer = optimizer 10 | self.n_warmup_steps = n_warmup_steps 11 | self.n_current_steps = 0 12 | self.init_lr = np.power(d_model, -0.5) 13 | 14 | def step_and_update_lr(self): 15 | "Step with the inner optimizer" 16 | self._update_learning_rate() 17 | self._optimizer.step() 18 | 19 | def zero_grad(self): 20 | "Zero out the gradients by the inner optimizer" 21 | self._optimizer.zero_grad() 22 | 23 | def _get_lr_scale(self): 24 | return np.min([ 25 | np.power(self.n_current_steps, -0.5), 26 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 27 | 28 | def _update_learning_rate(self): 29 | ''' Learning rate scheduling per step ''' 30 | 31 | self.n_current_steps += 1 32 | lr = self.init_lr * self._get_lr_scale() 33 | 34 | for param_group in self._optimizer.param_groups: 35 | param_group['lr'] = lr 36 | -------------------------------------------------------------------------------- /bert_pytorch/trainer/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.utils.data import DataLoader 5 | 6 | from ..model import BERTLM, BERT 7 | from .optim_schedule import ScheduledOptim 8 | 9 | import tqdm 10 | 11 | 12 | class BERTTrainer: 13 | """ 14 | BERTTrainer make the pretrained BERT model with two LM training method. 15 | 16 | 1. Masked Language Model : 3.3.1 Task #1: Masked LM 17 | 2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction 18 | 19 | please check the details on README.md with simple example. 
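    The optimizer is Adam wrapped by ScheduledOptim (optim_schedule.py), which warms the
    learning rate up as lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)),
    with d_model set to the BERT hidden size.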
20 | 21 | """ 22 | 23 | def __init__(self, bert: BERT, vocab_size: int, 24 | train_dataloader: DataLoader, test_dataloader: DataLoader = None, 25 | lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000, 26 | with_cuda: bool = True, cuda_devices=None, log_freq: int = 10): 27 | """ 28 | :param bert: BERT model which you want to train 29 | :param vocab_size: total word vocab size 30 | :param train_dataloader: train dataset data loader 31 | :param test_dataloader: test dataset data loader [can be None] 32 | :param lr: learning rate of optimizer 33 | :param betas: Adam optimizer betas 34 | :param weight_decay: Adam optimizer weight decay param 35 | :param with_cuda: traning with cuda 36 | :param log_freq: logging frequency of the batch iteration 37 | """ 38 | 39 | # Setup cuda device for BERT training, argument -c, --cuda should be true 40 | cuda_condition = torch.cuda.is_available() and with_cuda 41 | self.device = torch.device("cuda:0" if cuda_condition else "cpu") 42 | 43 | # This BERT model will be saved every epoch 44 | self.bert = bert 45 | # Initialize the BERT Language Model, with BERT model 46 | self.model = BERTLM(bert, vocab_size).to(self.device) 47 | 48 | # Distributed GPU training if CUDA can detect more than 1 GPU 49 | if with_cuda and torch.cuda.device_count() > 1: 50 | print("Using %d GPUS for BERT" % torch.cuda.device_count()) 51 | self.model = nn.DataParallel(self.model, device_ids=cuda_devices) 52 | 53 | # Setting the train and test data loader 54 | self.train_data = train_dataloader 55 | self.test_data = test_dataloader 56 | 57 | # Setting the Adam optimizer with hyper-param 58 | self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) 59 | self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) 60 | 61 | # Using Negative Log Likelihood Loss function for predicting the masked_token 62 | self.criterion = nn.NLLLoss(ignore_index=0) 63 | 64 | self.log_freq = log_freq 65 | 66 | print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()])) 67 | 68 | def train(self, epoch): 69 | self.iteration(epoch, self.train_data) 70 | 71 | def test(self, epoch): 72 | self.iteration(epoch, self.test_data, train=False) 73 | 74 | def iteration(self, epoch, data_loader, train=True): 75 | """ 76 | loop over the data_loader for training or testing 77 | if on train status, backward operation is activated 78 | and also auto save the model every peoch 79 | 80 | :param epoch: current epoch index 81 | :param data_loader: torch.utils.data.DataLoader for iteration 82 | :param train: boolean value of is train or test 83 | :return: None 84 | """ 85 | str_code = "train" if train else "test" 86 | 87 | # Setting the tqdm progress bar 88 | data_iter = tqdm.tqdm(enumerate(data_loader), 89 | desc="EP_%s:%d" % (str_code, epoch), 90 | total=len(data_loader), 91 | bar_format="{l_bar}{r_bar}") 92 | 93 | avg_loss = 0.0 94 | total_correct = 0 95 | total_element = 0 96 | 97 | for i, data in data_iter: 98 | # 0. batch_data will be sent into the device(GPU or cpu) 99 | data = {key: value.to(self.device) for key, value in data.items()} 100 | 101 | # 1. forward the next_sentence_prediction and masked_lm model 102 | next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) 103 | 104 | # 2-1. NLL(negative log likelihood) loss of is_next classification result 105 | next_loss = self.criterion(next_sent_output, data["is_next"]) 106 | 107 | # 2-2. 
NLLLoss of predicting masked token word 108 | mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) 109 | 110 | # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure 111 | loss = next_loss + mask_loss 112 | 113 | # 3. backward and optimization only in train 114 | if train: 115 | self.optim_schedule.zero_grad() 116 | loss.backward() 117 | self.optim_schedule.step_and_update_lr() 118 | 119 | # next sentence prediction accuracy 120 | correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item() 121 | avg_loss += loss.item() 122 | total_correct += correct 123 | total_element += data["is_next"].nelement() 124 | 125 | post_fix = { 126 | "epoch": epoch, 127 | "iter": i, 128 | "avg_loss": avg_loss / (i + 1), 129 | "avg_acc": total_correct / total_element * 100, 130 | "loss": loss.item() 131 | } 132 | 133 | if i % self.log_freq == 0: 134 | data_iter.write(str(post_fix)) 135 | 136 | print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", 137 | total_correct * 100.0 / total_element) 138 | 139 | def save(self, epoch, file_path="output/bert_trained.model"): 140 | """ 141 | Saving the current BERT model on file_path 142 | 143 | :param epoch: current epoch number 144 | :param file_path: model output path which gonna be file_path+"ep%d" % epoch 145 | :return: final_output_path 146 | """ 147 | output_path = file_path + ".ep%d" % epoch 148 | torch.save(self.bert.cpu(), output_path) 149 | self.bert.to(self.device) 150 | print("EP:%d Model Saved on:" % epoch, output_path) 151 | return output_path 152 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | torch>=0.4.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from setuptools.command.install import install 3 | import os 4 | import sys 5 | 6 | __version__ = "0.0.1a4" 7 | 8 | with open("requirements.txt") as f: 9 | require_packages = [line[:-1] if line[-1] == "\n" else line for line in f] 10 | 11 | with open("README.md", "r", encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | 15 | class VerifyVersionCommand(install): 16 | """Custom command to verify that the git tag matches our version""" 17 | description = 'verify that the git tag matches our version' 18 | 19 | def run(self): 20 | tag = os.getenv('CIRCLE_TAG') 21 | 22 | if tag != __version__: 23 | info = "Git tag: {0} does not match the version of this app: {1}".format( 24 | tag, __version__ 25 | ) 26 | sys.exit(info) 27 | 28 | 29 | setup( 30 | name="bert_pytorch", 31 | version=__version__, 32 | author='Junseong Kim', 33 | author_email='codertimo@gmail.com', 34 | packages=find_packages(), 35 | install_requires=require_packages, 36 | url="https://github.com/codertimo/BERT-pytorch", 37 | description="Google AI 2018 BERT pytorch implementation", 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "License :: OSI Approved :: Apache Software License", 43 | "Operating System :: OS Independent", 44 | ], 45 | entry_points={ 46 | 'console_scripts': [ 47 | 'bert = bert_pytorch.__main__:train', 48 | 'bert-vocab = bert_pytorch.dataset.vocab:build', 49 | ] 50 | }, 51 | cmdclass={ 52 | 
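        # "python setup.py verify" is run by the CircleCI deploy job to check that the
        # CIRCLE_TAG git tag matches __version__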
'verify': VerifyVersionCommand, 53 | } 54 | ) 55 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from bert_pytorch import BERT 3 | 4 | 5 | class BERTVocabTestCase(unittest.TestCase): 6 | pass 7 | --------------------------------------------------------------------------------
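The test suite above is still a placeholder. A minimal smoke test for the encoder could look like the sketch below; the model sizes and the test class name are arbitrary choices made here, and the expected output shape follows from `BERT.forward` in `bert_pytorch/model/bert.py`:

```python
import unittest

import torch

from bert_pytorch import BERT


class BERTForwardTestCase(unittest.TestCase):
    def test_output_shape(self):
        # Tiny model: vocab of 100 tokens, hidden size 32, 2 layers, 4 attention heads
        bert = BERT(vocab_size=100, hidden=32, n_layers=2, attn_heads=4)
        tokens = torch.randint(1, 100, (2, 16))    # (batch, seq_len); index 0 is padding
        segments = torch.randint(1, 3, (2, 16))    # sentence A = 1, sentence B = 2
        out = bert(tokens, segments)
        self.assertEqual(tuple(out.shape), (2, 16, 32))  # (batch, seq_len, hidden)


if __name__ == "__main__":
    unittest.main()
```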