├── .dockerignore ├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── config.ini ├── data ├── .gitkeep └── wiki │ └── .gitkeep ├── model └── .gitkeep ├── notebook ├── .gitkeep ├── check-extract-features.ipynb ├── check-trained-tokenizer.ipynb ├── finetune-to-livedoor-corpus.ipynb └── pretraining.ipynb ├── pretraining-loss.png ├── requirements.txt └── src ├── create_pretraining_data.py ├── data-download-and-extract.py ├── extract_features.py ├── file-preprocessing.sh ├── run_classifier.py ├── run_pretraining.py ├── tokenization_sentencepiece.py ├── train-sentencepiece.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data 2 | model 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | !data/.gitkeep 3 | !data/wiki/ 4 | data/wiki/* 5 | !data/wiki/.gitkeep 6 | model/* 7 | !model/.gitkeep 8 | __pycache__ 9 | .ipynb_checkpoints 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "bert"] 2 | path = bert 3 | url = https://github.com/google-research/bert.git 4 | [submodule "wikiextractor"] 5 | path = wikiextractor 6 | url = https://github.com/attardi/wikiextractor.git 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | LABEL maintainer="diracdiego@gmail.com" 3 | LABEL version="1.0" 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | python3 \ 7 | python3-pip \ 8 | python3-dev \ 9 | git 10 | 11 | RUN pip3 install --upgrade pip 12 | 13 | # Install python modules. 14 | COPY ./requirements.txt /requirements.txt 15 | RUN pip install -r /requirements.txt 16 | 17 | # Set Japanese environment 18 | RUN apt-get update && \ 19 | apt-get install -y locales && \ 20 | locale-gen ja_JP.UTF-8 && \ 21 | echo "export LANG=ja_JP.UTF-8" >> ~/.bashrc 22 | 23 | # Set alias for python3 24 | RUN echo "alias python=python3" >> $HOME/.bashrc && \ 25 | echo "alias pip=pip3" >> $HOME/.bashrc 26 | 27 | WORKDIR /work 28 | 29 | CMD ["/bin/bash"] 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT with SentencePiece for Japanese text. 2 | This is a repository for a Japanese BERT model with a SentencePiece tokenizer. 3 | 4 | To clone this repository together with the required 5 | [BERT](https://github.com/google-research/bert) and 6 | [WikiExtractor](https://github.com/attardi/wikiextractor): 7 | 8 | ```sh 9 | git clone --recurse-submodules https://github.com/yoheikikuta/bert-japanese 10 | ``` 11 | 12 | ## Pretrained models 13 | We provide a pretrained BERT model and a trained SentencePiece model for Japanese text. 14 | The training data is the Japanese Wikipedia corpus from [`Wikimedia Downloads`](https://dumps.wikimedia.org/). 15 | Please download all objects from the following Google Drive folder to the `model/` directory. 16 | - **[`Pretrained BERT model and trained SentencePiece model`](https://drive.google.com/drive/folders/1Zsm9DD40lrUVu6iAnIuTH2ODIkh-WM-O?usp=sharing)** 17 | 18 | The training loss is shown below (after 1M steps the loss changes sharply because `max_seq_length` is changed from `128` to `512`): 19 | ![pretraining-loss](pretraining-loss.png) 20 | 21 | ```sh 22 | ***** Eval results ***** 23 | global_step = 1400000 24 | loss = 1.3773012 25 | masked_lm_accuracy = 0.6810424 26 | masked_lm_loss = 1.4216621 27 | next_sentence_accuracy = 0.985 28 | next_sentence_loss = 0.059553143 29 | ``` 30 | 31 | ## Finetuning with BERT Japanese 32 | We also provide a simple Japanese text classification problem using the [`livedoor ニュースコーパス`](https://www.rondhuit.com/download.html). 33 | Try the following notebook to check how finetuning works. 34 | You can run the notebook on CPU (too slow) or in GPU/TPU environments. 35 | - **[finetune-to-livedoor-corpus.ipynb](https://github.com/yoheikikuta/bert-japanese/blob/master/notebook/finetune-to-livedoor-corpus.ipynb)** 36 | 37 | The results are as follows: 38 | - BERT with SentencePiece 39 | 40 | ```sh 41 | precision recall f1-score support 42 | 43 | dokujo-tsushin 0.98 0.94 0.96 178 44 | it-life-hack 0.96 0.97 0.96 172 45 | kaden-channel 0.99 0.98 0.99 176 46 | livedoor-homme 0.98 0.88 0.93 95 47 | movie-enter 0.96 0.99 0.98 158 48 | peachy 0.94 0.98 0.96 174 49 | smax 0.98 0.99 0.99 167 50 | sports-watch 0.98 1.00 0.99 190 51 | topic-news 0.99 0.98 0.98 163 52 | 53 | micro avg 0.97 0.97 0.97 1473 54 | macro avg 0.97 0.97 0.97 1473 55 | weighted avg 0.97 0.97 0.97 1473 56 | ``` 57 | 58 | - sklearn GradientBoostingClassifier with MeCab 59 | 60 | ```sh 61 | precision recall f1-score support 62 | 63 | dokujo-tsushin 0.89 0.86 0.88 178 64 | it-life-hack 0.91 0.90 0.91 172 65 | kaden-channel 0.90 0.94 0.92 176 66 | livedoor-homme 0.79 0.74 0.76 95 67 | movie-enter 0.93 0.96 0.95 158 68 | peachy 0.87 0.92 0.89 174 69 | smax 0.99 1.00 1.00 167 70 | sports-watch 0.93 0.98 0.96 190 71 | topic-news 0.96 0.86 0.91 163 72 | 73 | micro avg 0.92 0.92 0.92 1473 74 | macro avg 0.91 0.91 0.91 1473 75 | weighted avg 0.92 0.92 0.91 1473 76 | ``` 77 | 78 | ## Cautions when using the model as a sentence generation model 79 | The model expects lowercase input and the tokenizer is assumed to be used with the `do_lower_case=True` option, but the special tokens such as `[CLS]` are registered in uppercase characters.
80 | Therefore, when we put `"[CLS] I am ..."` into the tokenizer as a raw string, the tokenizer first makes it lowercase (`"[cls] i am ..."`) and then cannot interpret `"[cls]"` as a special token, which causes problems. 81 | If you want to use the model as a sentence generation model, follow these steps (sorry, it is a little bit confusing): 82 | - keep special tokens (such as `[CLS]` or `[SEP]`) uppercase 83 | - make the original input sentences lowercase manually (e.g., `"i am ..."`) 84 | - join them together (e.g., `"[CLS] i am ..."`) and put the result into the tokenizer with the `do_lower_case=False` option 85 | - put the obtained tokens into the model 86 | 87 | ## Pretraining from scratch 88 | All scripts for pretraining from scratch are provided. 89 | Follow the instructions below. 90 | 91 | ### Environment setup 92 | Build a Docker image from the Dockerfile and create a Docker container. 93 | 94 | ```sh 95 | docker build -t bert-ja . 96 | docker run -it --rm -v $PWD:/work -p 8888:8888 bert-ja 97 | ``` 98 | 99 | ### Data preparation 100 | Download and preprocess the data. 101 | It takes a few hours on a GCP n1-standard-16 (16 CPUs, 60 GB memory) instance. 102 | 103 | ```sh 104 | python3 src/data-download-and-extract.py 105 | bash src/file-preprocessing.sh 106 | ``` 107 | 108 | The above scripts use the latest jawiki data and wikiextractor module, which are different from those used for the pretrained model. 109 | If you want to reproduce the same setup, use the following information: 110 | 111 | - bert-japanese: commit `074fe20f33a020769091e1e5552b33867ccbd750` 112 | - dataset: `jawiki-20181220-pages-articles-multistream.xml.bz2` in the [Google Drive](https://drive.google.com/drive/folders/1Zsm9DD40lrUVu6iAnIuTH2ODIkh-WM-O?usp=sharing) 113 | - wikiextractor: commit `1e4236de4237d0a89d0ad7241505d73ee7e23517` 114 | 115 | ### Training SentencePiece model 116 | Train a SentencePiece model using the preprocessed data. 117 | It takes about two hours on the instance. 118 | 119 | ```sh 120 | python3 src/train-sentencepiece.py 121 | ``` 122 | 123 | ### Creating data for pretraining 124 | Create .tfrecord files for pretraining. 125 | For longer sequence data, set `max_seq_length` to `512`. 126 | 127 | ```sh 128 | for DIR in $( find /work/data/wiki/ -mindepth 1 -type d ); do 129 | python3 src/create_pretraining_data.py \ 130 | --input_file=${DIR}/all.txt \ 131 | --output_file=${DIR}/all-maxseq128.tfrecord \ 132 | --model_file=./model/wiki-ja.model \ 133 | --vocab_file=./model/wiki-ja.vocab \ 134 | --do_lower_case=True \ 135 | --max_seq_length=128 \ 136 | --max_predictions_per_seq=20 \ 137 | --masked_lm_prob=0.15 \ 138 | --random_seed=12345 \ 139 | --dupe_factor=5 \ 140 | --do_whole_word_mask=False 141 | done 142 | ``` 143 | 144 | ### Pretraining 145 | You need a GPU/TPU environment to pretrain a BERT model. 146 | The following notebook provides a link to a Colab notebook where you can run the scripts on TPUs. 147 | 148 | - **[pretraining.ipynb](https://github.com/yoheikikuta/bert-japanese/blob/master/notebook/pretraining.ipynb)** 149 | 150 | 151 | ## How to cite this work in papers 152 | We have not published any paper about this work.
153 | Please cite this repository in publications as the following: 154 | 155 | ```bibtex 156 | @misc{bertjapanese, 157 | author = {Yohei Kikuta}, 158 | title = {BERT Pretrained model Trained On Japanese Wikipedia Articles}, 159 | year = {2019}, 160 | publisher = {GitHub}, 161 | journal = {GitHub repository}, 162 | howpublished = {\url{https://github.com/yoheikikuta/bert-japanese}}, 163 | } 164 | ``` 165 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [DATA] 2 | FILEURL = https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-pages-articles-multistream.xml.bz2 3 | FILEPATH = /work/data/jawiki-latest-pages-articles-multistream.xml.bz2 4 | DATADIR = /work/data/ 5 | TEXTDIR = /work/data/wiki/ 6 | 7 | [SENTENCEPIECE] 8 | PREFIX = /work/model/wiki-ja 9 | VOCABSIZE = 32000 10 | CTLSYMBOLS = [PAD],[CLS],[SEP],[MASK] 11 | 12 | [BERT-CONFIG] 13 | attention_probs_dropout_prob = 0.1 14 | hidden_act = gelu 15 | hidden_dropout_prob = 0.1 16 | hidden_size = 768 17 | initializer_range = 0.02 18 | intermediate_size = 3072 19 | max_position_embeddings = 512 20 | num_attention_heads = 12 21 | num_hidden_layers = 12 22 | type_vocab_size = 2 23 | vocab_size = 32000 24 | 25 | [FINETUNING-DATA] 26 | FILEURL = https://www.rondhuit.com/download/ldcc-20140209.tar.gz 27 | FILEPATH = /work/data/ldcc-20140209.tar.gz 28 | DATADIR = /work/data/ 29 | TEXTDIR = /work/data/livedoor/ 30 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoheikikuta/bert-japanese/8d197e23b0e54da785ca9d16b7998c708767d649/data/.gitkeep -------------------------------------------------------------------------------- /data/wiki/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoheikikuta/bert-japanese/8d197e23b0e54da785ca9d16b7998c708767d649/data/wiki/.gitkeep -------------------------------------------------------------------------------- /model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoheikikuta/bert-japanese/8d197e23b0e54da785ca9d16b7998c708767d649/model/.gitkeep -------------------------------------------------------------------------------- /notebook/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoheikikuta/bert-japanese/8d197e23b0e54da785ca9d16b7998c708767d649/notebook/.gitkeep -------------------------------------------------------------------------------- /notebook/check-extract-features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 37, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import configparser\n", 10 | "import glob\n", 11 | "import os\n", 12 | "import pandas as pd\n", 13 | "import subprocess\n", 14 | "import sys\n", 15 | "import tarfile \n", 16 | "from urllib.request import urlretrieve\n", 17 | "import json\n", 18 | "import tempfile\n", 19 | "import tensorflow as tf" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 38, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "sys.path.append(\"../src\")\n", 29 | "from utils import str_to_value" 
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 39, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "sys.path.append(\"../bert\")\n", 39 | "import modeling" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "CURDIR = os.getcwd()\n", 49 | "CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini')\n", 50 | "config = configparser.ConfigParser()\n", 51 | "config.read(CONFIGPATH)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 41, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "FILEURL = config['FINETUNING-DATA']['FILEURL']\n", 61 | "FILEPATH = config['FINETUNING-DATA']['FILEPATH']\n", 62 | "EXTRACTDIR = config['FINETUNING-DATA']['TEXTDIR']" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 42, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "PRETRAINED_MODEL_PATH = '../model/bert-wiki-ja/model.ckpt-1400000'" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 43, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "bert_config_file = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8', suffix='.json')\n", 81 | "bert_config_file.write(json.dumps({k:str_to_value(v) for k,v in config['BERT-CONFIG'].items()}))\n", 82 | "bert_config_file.seek(0)\n", 83 | "bert_config_file_path = str(bert_config_file.name)\n", 84 | "bert_config = modeling.BertConfig.from_json_file(bert_config_file.name)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 44, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "!echo 'すべての人間は、生れながらにして自由であり、かつ、尊厳と権利とについて平等である。 ||| 人間は、理性と良心とを授けられており、互いに同胞の精神をもって行動しなければならない。' > /tmp/input.txt" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "scrolled": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "!python ../src/extract_features.py \\\n", 105 | " --input_file=/tmp/input.txt \\\n", 106 | " --output_file=output.json \\\n", 107 | " --vocab_file=../model/bert-wiki-ja/wiki-ja.vocab \\\n", 108 | " --model_file=../model/bert-wiki-ja/wiki-ja.model \\\n", 109 | " --bert_config_file={bert_config_file.name} \\\n", 110 | " --init_checkpoint={PRETRAINED_MODEL_PATH} \\\n", 111 | " --layers=-1,-2,-3,-4 \\\n", 112 | " --max_seq_length=128 \\\n", 113 | " --batch_size=8" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "Let's have a look at the features of the last layer for the word \"人間\".\n", 121 | "\n", 122 | "The 0-th token is always [CLS], and the 1st token of a sentence is [▁]. So the word comes in 3rd position." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 59, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "人間\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "import json\n", 140 | "\n", 141 | "outputs = json.load(open('output.json'))\n", 142 | "print(outputs['features'][3]['token'])" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "The last layer is at index 0, the one before it is at index 1, etc...\n", 150 | "The embeddings are stored in the *values* entry."
151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 66, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "768\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "embeddings = outputs['features'][3]['layers'][0]['values']\n", 168 | "print(len(embeddings))" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython3", 195 | "version": "3.6.8" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 2 200 | } 201 | -------------------------------------------------------------------------------- /notebook/check-trained-tokenizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Check trained tokenizer\n", 8 | "\n", 9 | "Check trained SentencePiece tokenizer." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "sys.path.append(\"../src/\")" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import tokenization_sentencepiece as tokenization" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "text1 = \"信じられているから走るのだ。間に合う、間に合わぬは問題でないのだ。\"\n", 38 | "text2 = \"新たな時代のMarxよこれらの盲目な衝動から動く世界を素晴しく美しい構成に変へよ\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## `do_lower_case = True` case (expected)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Loaded a trained SentencePiece model.\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "tokenizer = tokenization.FullTokenizer(\n", 63 | " model_file=\"../model/wiki-ja.model\",\n", 64 | " vocab_file=\"../model/wiki-ja.vocab\",\n", 65 | " do_lower_case=True)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "['▁',\n", 77 | " '信じられ',\n", 78 | " 'ている',\n", 79 | " 'から',\n", 80 | " '走る',\n", 81 | " 'のだ',\n", 82 | " '。',\n", 83 | " '間に',\n", 84 | " '合う',\n", 85 | " '、',\n", 86 | " '間に合わ',\n", 87 | " 'ぬ',\n", 88 | " 'は',\n", 89 | " '問題',\n", 90 | " 'でない',\n", 91 | " 'のだ',\n", 92 | " '。']" 93 | ] 94 | }, 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "tokenizer.tokenize(text1)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "[9,\n", 113 | " 25435,\n", 114 | " 68,\n", 115 | " 28,\n", 116 | " 8956,\n", 117 | " 6312,\n", 118 | " 8,\n", 119 | " 3637,\n", 120 | 
" 4277,\n", 121 | " 7,\n", 122 | " 23144,\n", 123 | " 2777,\n", 124 | " 11,\n", 125 | " 451,\n", 126 | " 4772,\n", 127 | " 6312,\n", 128 | " 8]" 129 | ] 130 | }, 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "tokenizer.convert_tokens_to_ids( tokenizer.tokenize(text1) )" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 7, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "['▁',\n", 149 | " '新たな',\n", 150 | " '時代の',\n", 151 | " 'mar',\n", 152 | " 'x',\n", 153 | " 'よ',\n", 154 | " 'これらの',\n", 155 | " '盲',\n", 156 | " '目',\n", 157 | " 'な',\n", 158 | " '衝動',\n", 159 | " 'から',\n", 160 | " '動く',\n", 161 | " '世界',\n", 162 | " 'を',\n", 163 | " '素',\n", 164 | " '晴',\n", 165 | " 'しく',\n", 166 | " '美しい',\n", 167 | " '構成',\n", 168 | " 'に',\n", 169 | " '変',\n", 170 | " 'へ',\n", 171 | " 'よ']" 172 | ] 173 | }, 174 | "execution_count": 7, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "tokenizer.tokenize(text2)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 8, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "[9,\n", 192 | " 1379,\n", 193 | " 1097,\n", 194 | " 6459,\n", 195 | " 282,\n", 196 | " 842,\n", 197 | " 1432,\n", 198 | " 12428,\n", 199 | " 303,\n", 200 | " 57,\n", 201 | " 30802,\n", 202 | " 28,\n", 203 | " 11612,\n", 204 | " 301,\n", 205 | " 18,\n", 206 | " 1407,\n", 207 | " 3606,\n", 208 | " 3526,\n", 209 | " 5797,\n", 210 | " 1171,\n", 211 | " 17,\n", 212 | " 1987,\n", 213 | " 90,\n", 214 | " 842]" 215 | ] 216 | }, 217 | "execution_count": 8, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "tokenizer.convert_tokens_to_ids( tokenizer.tokenize(text2) )" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 9, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "['▁',\n", 235 | " '新たな',\n", 236 | " '時代の',\n", 237 | " 'mar',\n", 238 | " 'x',\n", 239 | " 'よ',\n", 240 | " 'これらの',\n", 241 | " '盲',\n", 242 | " '目',\n", 243 | " 'な',\n", 244 | " '衝動',\n", 245 | " 'から',\n", 246 | " '動く',\n", 247 | " '世界',\n", 248 | " 'を',\n", 249 | " '素',\n", 250 | " '晴',\n", 251 | " 'しく',\n", 252 | " '美しい',\n", 253 | " '構成',\n", 254 | " 'に',\n", 255 | " '変',\n", 256 | " 'へ',\n", 257 | " 'よ']" 258 | ] 259 | }, 260 | "execution_count": 9, 261 | "metadata": {}, 262 | "output_type": "execute_result" 263 | } 264 | ], 265 | "source": [ 266 | "tokenizer.convert_ids_to_tokens( tokenizer.convert_tokens_to_ids( tokenizer.tokenize(text2) ) )" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "## `do_lower_case = False` case (Unexpected)\n", 274 | "\n", 275 | "Note that English uppercase characters are NOT converted into lowercase ones." 
276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 10, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Loaded a trained SentencePiece model.\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "tokenizer = tokenization.FullTokenizer(\n", 293 | " model_file=\"../model/wiki-ja.model\",\n", 294 | " vocab_file=\"../model/wiki-ja.vocab\",\n", 295 | " do_lower_case=False)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 11, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "['▁',\n", 307 | " '信じられ',\n", 308 | " 'ている',\n", 309 | " 'から',\n", 310 | " '走る',\n", 311 | " 'のだ',\n", 312 | " '。',\n", 313 | " '間に',\n", 314 | " '合う',\n", 315 | " '、',\n", 316 | " '間に合わ',\n", 317 | " 'ぬ',\n", 318 | " 'は',\n", 319 | " '問題',\n", 320 | " 'でない',\n", 321 | " 'のだ',\n", 322 | " '。']" 323 | ] 324 | }, 325 | "execution_count": 11, 326 | "metadata": {}, 327 | "output_type": "execute_result" 328 | } 329 | ], 330 | "source": [ 331 | "tokenizer.tokenize(text1)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 13, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "text/plain": [ 342 | "['▁',\n", 343 | " '新たな',\n", 344 | " '時代の',\n", 345 | " 'M',\n", 346 | " 'ar',\n", 347 | " 'x',\n", 348 | " 'よ',\n", 349 | " 'これらの',\n", 350 | " '盲',\n", 351 | " '目',\n", 352 | " 'な',\n", 353 | " '衝動',\n", 354 | " 'から',\n", 355 | " '動く',\n", 356 | " '世界',\n", 357 | " 'を',\n", 358 | " '素',\n", 359 | " '晴',\n", 360 | " 'しく',\n", 361 | " '美しい',\n", 362 | " '構成',\n", 363 | " 'に',\n", 364 | " '変',\n", 365 | " 'へ',\n", 366 | " 'よ']" 367 | ] 368 | }, 369 | "execution_count": 13, 370 | "metadata": {}, 371 | "output_type": "execute_result" 372 | } 373 | ], 374 | "source": [ 375 | "tokenizer.tokenize(text2)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 14, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "text/plain": [ 386 | "[9,\n", 387 | " 1379,\n", 388 | " 1097,\n", 389 | " 0,\n", 390 | " 2381,\n", 391 | " 282,\n", 392 | " 842,\n", 393 | " 1432,\n", 394 | " 12428,\n", 395 | " 303,\n", 396 | " 57,\n", 397 | " 30802,\n", 398 | " 28,\n", 399 | " 11612,\n", 400 | " 301,\n", 401 | " 18,\n", 402 | " 1407,\n", 403 | " 3606,\n", 404 | " 3526,\n", 405 | " 5797,\n", 406 | " 1171,\n", 407 | " 17,\n", 408 | " 1987,\n", 409 | " 90,\n", 410 | " 842]" 411 | ] 412 | }, 413 | "execution_count": 14, 414 | "metadata": {}, 415 | "output_type": "execute_result" 416 | } 417 | ], 418 | "source": [ 419 | "tokenizer.convert_tokens_to_ids( tokenizer.tokenize(text2) )" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 15, 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "data": { 429 | "text/plain": [ 430 | "['▁',\n", 431 | " '新たな',\n", 432 | " '時代の',\n", 433 | " '',\n", 434 | " 'ar',\n", 435 | " 'x',\n", 436 | " 'よ',\n", 437 | " 'これらの',\n", 438 | " '盲',\n", 439 | " '目',\n", 440 | " 'な',\n", 441 | " '衝動',\n", 442 | " 'から',\n", 443 | " '動く',\n", 444 | " '世界',\n", 445 | " 'を',\n", 446 | " '素',\n", 447 | " '晴',\n", 448 | " 'しく',\n", 449 | " '美しい',\n", 450 | " '構成',\n", 451 | " 'に',\n", 452 | " '変',\n", 453 | " 'へ',\n", 454 | " 'よ']" 455 | ] 456 | }, 457 | "execution_count": 15, 458 | "metadata": {}, 459 | "output_type": "execute_result" 460 | } 461 | ], 462 | "source": [ 463 | "tokenizer.convert_ids_to_tokens( tokenizer.convert_tokens_to_ids( tokenizer.tokenize(text2) 
) )" 464 | ] 465 | } 466 | ], 467 | "metadata": { 468 | "kernelspec": { 469 | "display_name": "Python 3", 470 | "language": "python", 471 | "name": "python3" 472 | }, 473 | "language_info": { 474 | "codemirror_mode": { 475 | "name": "ipython", 476 | "version": 3 477 | }, 478 | "file_extension": ".py", 479 | "mimetype": "text/x-python", 480 | "name": "python", 481 | "nbconvert_exporter": "python", 482 | "pygments_lexer": "ipython3", 483 | "version": "3.6.7" 484 | } 485 | }, 486 | "nbformat": 4, 487 | "nbformat_minor": 2 488 | } 489 | -------------------------------------------------------------------------------- /notebook/finetune-to-livedoor-corpus.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Finetuning of the pretrained Japanese BERT model\n", 8 | "\n", 9 | "Finetune the pretrained model to solve multi-class classification problems. \n", 10 | "This notebook requires the following objects:\n", 11 | "- trained sentencepiece model (model and vocab files)\n", 12 | "- pretraiend Japanese BERT model\n", 13 | "\n", 14 | "Dataset is livedoor ニュースコーパス in https://www.rondhuit.com/download.html. \n", 15 | "We make test:dev:train = 2:2:6 datasets." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "Results:\n", 23 | "\n", 24 | "- Full training data\n", 25 | " - BERT with SentencePiece\n", 26 | " ```\n", 27 | " precision recall f1-score support\n", 28 | "\n", 29 | " dokujo-tsushin 0.98 0.94 0.96 178\n", 30 | " it-life-hack 0.96 0.97 0.96 172\n", 31 | " kaden-channel 0.99 0.98 0.99 176\n", 32 | " livedoor-homme 0.98 0.88 0.93 95\n", 33 | " movie-enter 0.96 0.99 0.98 158\n", 34 | " peachy 0.94 0.98 0.96 174\n", 35 | " smax 0.98 0.99 0.99 167\n", 36 | " sports-watch 0.98 1.00 0.99 190\n", 37 | " topic-news 0.99 0.98 0.98 163\n", 38 | "\n", 39 | " micro avg 0.97 0.97 0.97 1473\n", 40 | " macro avg 0.97 0.97 0.97 1473\n", 41 | " weighted avg 0.97 0.97 0.97 1473\n", 42 | " ```\n", 43 | " - sklearn GradientBoostingClassifier with MeCab\n", 44 | " ```\n", 45 | " precision recall f1-score support\n", 46 | "\n", 47 | " dokujo-tsushin 0.89 0.86 0.88 178\n", 48 | " it-life-hack 0.91 0.90 0.91 172\n", 49 | " kaden-channel 0.90 0.94 0.92 176\n", 50 | " livedoor-homme 0.79 0.74 0.76 95\n", 51 | " movie-enter 0.93 0.96 0.95 158\n", 52 | " peachy 0.87 0.92 0.89 174\n", 53 | " smax 0.99 1.00 1.00 167\n", 54 | " sports-watch 0.93 0.98 0.96 190\n", 55 | " topic-news 0.96 0.86 0.91 163\n", 56 | "\n", 57 | " micro avg 0.92 0.92 0.92 1473\n", 58 | " macro avg 0.91 0.91 0.91 1473\n", 59 | " weighted avg 0.92 0.92 0.91 1473\n", 60 | " ```\n", 61 | "\n", 62 | "- Small training data (1/5 of full training data)\n", 63 | " - BERT with SentencePiece\n", 64 | " ```\n", 65 | " precision recall f1-score support\n", 66 | "\n", 67 | " dokujo-tsushin 0.97 0.87 0.92 178\n", 68 | " it-life-hack 0.86 0.86 0.86 172\n", 69 | " kaden-channel 0.95 0.94 0.95 176\n", 70 | " livedoor-homme 0.82 0.82 0.82 95\n", 71 | " movie-enter 0.97 0.99 0.98 158\n", 72 | " peachy 0.89 0.95 0.92 174\n", 73 | " smax 0.94 0.96 0.95 167\n", 74 | " sports-watch 0.97 0.97 0.97 190\n", 75 | " topic-news 0.94 0.94 0.94 163\n", 76 | "\n", 77 | " micro avg 0.93 0.93 0.93 1473\n", 78 | " macro avg 0.92 0.92 0.92 1473\n", 79 | " weighted avg 0.93 0.93 0.93 1473\n", 80 | " ```\n", 81 | " - sklearn GradientBoostingClassifier with MeCab\n", 82 | " ```\n", 83 | " precision recall f1-score 
support\n", 84 | "\n", 85 | " dokujo-tsushin 0.82 0.71 0.76 178\n", 86 | " it-life-hack 0.86 0.88 0.87 172\n", 87 | " kaden-channel 0.91 0.87 0.89 176\n", 88 | " livedoor-homme 0.67 0.63 0.65 95\n", 89 | " movie-enter 0.87 0.95 0.91 158\n", 90 | " peachy 0.70 0.78 0.73 174\n", 91 | " smax 1.00 1.00 1.00 167\n", 92 | " sports-watch 0.87 0.95 0.91 190\n", 93 | " topic-news 0.92 0.82 0.87 163\n", 94 | "\n", 95 | " micro avg 0.85 0.85 0.85 1473\n", 96 | " macro avg 0.85 0.84 0.84 1473\n", 97 | " weighted avg 0.86 0.85 0.85 1473\n", 98 | " ```" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "import configparser\n", 108 | "import glob\n", 109 | "import os\n", 110 | "import pandas as pd\n", 111 | "import subprocess\n", 112 | "import sys\n", 113 | "import tarfile \n", 114 | "from urllib.request import urlretrieve\n", 115 | "\n", 116 | "CURDIR = os.getcwd()\n", 117 | "CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini')\n", 118 | "config = configparser.ConfigParser()\n", 119 | "config.read(CONFIGPATH)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Data preparation\n", 127 | "\n", 128 | "You need to execute the following cells just once." 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "FILEURL = config['FINETUNING-DATA']['FILEURL']\n", 138 | "FILEPATH = config['FINETUNING-DATA']['FILEPATH']\n", 139 | "EXTRACTDIR = config['FINETUNING-DATA']['TEXTDIR']" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "Download and unzip data." 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "%%time\n", 156 | "\n", 157 | "urlretrieve(FILEURL, FILEPATH)\n", 158 | "\n", 159 | "mode = \"r:gz\"\n", 160 | "tar = tarfile.open(FILEPATH, mode) \n", 161 | "tar.extractall(EXTRACTDIR) \n", 162 | "tar.close()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Data preprocessing."
170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "def extract_txt(filename):\n", 179 | " with open(filename) as text_file:\n", 180 | " # 0: URL, 1: timestamp\n", 181 | " text = text_file.readlines()[2:]\n", 182 | " text = [sentence.strip() for sentence in text]\n", 183 | " text = list(filter(lambda line: line != '', text))\n", 184 | " return ''.join(text)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "categories = [ \n", 194 | " name for name \n", 195 | " in os.listdir( os.path.join(EXTRACTDIR, \"text\") ) \n", 196 | " if os.path.isdir( os.path.join(EXTRACTDIR, \"text\", name) ) ]\n", 197 | "\n", 198 | "categories = sorted(categories)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "categories" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "table = str.maketrans({\n", 217 | " '\\n': '',\n", 218 | " '\\t': ' ',\n", 219 | " '\\r': '',\n", 220 | "})" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "%%time\n", 230 | "\n", 231 | "all_text = []\n", 232 | "all_label = []\n", 233 | "\n", 234 | "for cat in categories:\n", 235 | " files = glob.glob(os.path.join(EXTRACTDIR, \"text\", cat, \"{}*.txt\".format(cat)))\n", 236 | " files = sorted(files)\n", 237 | " body = [ extract_txt(elem).translate(table) for elem in files ]\n", 238 | " label = [cat] * len(body)\n", 239 | " \n", 240 | " all_text.extend(body)\n", 241 | " all_label.extend(label)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "df = pd.DataFrame({'text' : all_text, 'label' : all_label})" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "df.head()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "df = df.sample(frac=1, random_state=23).reset_index(drop=True)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "df.head()" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "Save data as tsv files. \n", 285 | "test:dev:train = 2:2:6. To check the usability of finetuning, we also prepare sampled training data (1/5 of full training data)." 
286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "df[:len(df) // 5].to_csv( os.path.join(EXTRACTDIR, \"test.tsv\"), sep='\\t', index=False)\n", 295 | "df[len(df) // 5:len(df)*2 // 5].to_csv( os.path.join(EXTRACTDIR, \"dev.tsv\"), sep='\\t', index=False)\n", 296 | "df[len(df)*2 // 5:].to_csv( os.path.join(EXTRACTDIR, \"train.tsv\"), sep='\\t', index=False)\n", 297 | "\n", 298 | "### 1/5 of full training data.\n", 299 | "# df[:len(df) // 5].to_csv( os.path.join(EXTRACTDIR, \"test.tsv\"), sep='\\t', index=False)\n", 300 | "# df[len(df) // 5:len(df)*2 // 5].to_csv( os.path.join(EXTRACTDIR, \"dev.tsv\"), sep='\\t', index=False)\n", 301 | "# df[len(df)*2 // 5:].sample(frac=0.2, random_state=23).to_csv( os.path.join(EXTRACTDIR, \"train.tsv\"), sep='\\t', index=False)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "## Finetune the pretrained model\n", 309 | "\n", 310 | "It will take many hours to execute the following cells in a CPU environment. \n", 311 | "You can also use Colab to leverage the power of TPUs. You need to upload the created data to your GCS bucket.\n", 312 | "\n", 313 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1zZH2GWe0U-7GjJ2w2duodFfEUptvHjcx)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "PRETRAINED_MODEL_PATH = '../model/model.ckpt-1400000'\n", 323 | "FINETUNE_OUTPUT_DIR = '../model/livedoor_output'" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "%%time\n", 333 | "# It will take many hours in a CPU environment.\n", 334 | "\n", 335 | "!python3 ../src/run_classifier.py \\\n", 336 | " --task_name=livedoor \\\n", 337 | " --do_train=true \\\n", 338 | " --do_eval=true \\\n", 339 | " --data_dir=../data/livedoor \\\n", 340 | " --model_file=../model/wiki-ja.model \\\n", 341 | " --vocab_file=../model/wiki-ja.vocab \\\n", 342 | " --init_checkpoint={PRETRAINED_MODEL_PATH} \\\n", 343 | " --max_seq_length=512 \\\n", 344 | " --train_batch_size=4 \\\n", 345 | " --learning_rate=2e-5 \\\n", 346 | " --num_train_epochs=10 \\\n", 347 | " --output_dir={FINETUNE_OUTPUT_DIR}" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": {}, 353 | "source": [ 354 | "## Predict using the finetuned model\n", 355 | "\n", 356 | "Let's predict test data using the finetuned model. 
" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "%%time\n", 366 | "# It will take many hours on CPU environment.\n", 367 | "\n", 368 | "!python3 ../src/run_classifier.py \\\n", 369 | " --task_name=livedoor \\\n", 370 | " --do_predict=true \\\n", 371 | " --data_dir=../data/livedoor \\\n", 372 | " --model_file=../model/wiki-ja.model \\\n", 373 | " --vocab_file=../model/wiki-ja.vocab \\\n", 374 | " --output_dir={FINETUNE_OUTPUT_DIR}" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "import sys\n", 384 | "sys.path.append(\"../src\")\n", 385 | "\n", 386 | "from run_classifier import LivedoorProcessor" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "processor = LivedoorProcessor()\n", 396 | "label_list = processor.get_labels()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "result = pd.read_csv(FINETUNE_OUTPUT_DIR+\"/test_results.tsv\", sep='\\t', header=None)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "result.head()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": {}, 420 | "source": [ 421 | "Read test data set and add prediction results." 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "import pandas as pd" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "test_df = pd.read_csv(\"../data/livedoor/test.tsv\", sep='\\t')" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "test_df['predict'] = [ label_list[idx] for idx in result.idxmax(axis=1) ]" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "test_df.head()" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "sum( test_df['label'] == test_df['predict'] ) / len(test_df)" 467 | ] 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "A littel more detailed check using `sklearn.metrics`." 
474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "!pip install scikit-learn" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": {}, 489 | "outputs": [], 490 | "source": [ 491 | "from sklearn.metrics import classification_report\n", 492 | "from sklearn.metrics import confusion_matrix" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "print(classification_report(test_df['label'], test_df['predict']))" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "print(confusion_matrix(test_df['label'], test_df['predict']))" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### Simple baseline model." 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "import pandas as pd\n", 527 | "from sklearn.metrics import classification_report\n", 528 | "from sklearn.metrics import confusion_matrix" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "train_df = pd.read_csv(\"../data/livedoor/train.tsv\", sep='\\t')\n", 538 | "dev_df = pd.read_csv(\"../data/livedoor/dev.tsv\", sep='\\t')\n", 539 | "test_df = pd.read_csv(\"../data/livedoor/test.tsv\", sep='\\t')" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "!apt-get install -q -y mecab libmecab-dev mecab-ipadic mecab-ipadic-utf8" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": {}, 555 | "outputs": [], 556 | "source": [ 557 | "!pip install mecab-python3==0.7" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 567 | "from sklearn.ensemble import GradientBoostingClassifier\n", 568 | "import MeCab" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [ 577 | "m = MeCab.Tagger(\"-Owakati\")" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [ 586 | "train_dev_df = pd.concat([train_df, dev_df])" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": null, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "train_dev_xs = train_dev_df['text'].apply(lambda x: m.parse(x))\n", 596 | "train_dev_ys = train_dev_df['label']\n", 597 | "\n", 598 | "test_xs = test_df['text'].apply(lambda x: m.parse(x))\n", 599 | "test_ys = test_df['label']" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "vectorizer = TfidfVectorizer(max_features=750)\n", 609 | "train_dev_xs_ = vectorizer.fit_transform(train_dev_xs)\n", 610 | "test_xs_ = vectorizer.transform(test_xs)" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "metadata": {}, 616 | "source": [ 617 | "The 
following set up is not exactly identical to that of BERT because inside Classifier it uses `train_test_split` with shuffle. \n", 618 | "In addition, parameters are not well tuned, however, we think it's enough to check the power of BERT." 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "%%time\n", 628 | "\n", 629 | "model = GradientBoostingClassifier(n_estimators=200,\n", 630 | " validation_fraction=len(dev_df)/len(train_df),\n", 631 | " n_iter_no_change=5,\n", 632 | " tol=0.01,\n", 633 | " random_state=23)\n", 634 | "\n", 635 | "### 1/5 of full training data.\n", 636 | "# model = GradientBoostingClassifier(n_estimators=200,\n", 637 | "# validation_fraction=len(dev_df)/len(train_df),\n", 638 | "# n_iter_no_change=5,\n", 639 | "# tol=0.01,\n", 640 | "# random_state=23)\n", 641 | "\n", 642 | "model.fit(train_dev_xs_, train_dev_ys)" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "metadata": {}, 649 | "outputs": [], 650 | "source": [ 651 | "print(classification_report(test_ys, model.predict(test_xs_)))" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [ 660 | "print(confusion_matrix(test_ys, model.predict(test_xs_)))" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [] 669 | } 670 | ], 671 | "metadata": { 672 | "kernelspec": { 673 | "display_name": "Python 3", 674 | "language": "python", 675 | "name": "python3" 676 | }, 677 | "language_info": { 678 | "codemirror_mode": { 679 | "name": "ipython", 680 | "version": 3 681 | }, 682 | "file_extension": ".py", 683 | "mimetype": "text/x-python", 684 | "name": "python", 685 | "nbconvert_exporter": "python", 686 | "pygments_lexer": "ipython3", 687 | "version": "3.6.7" 688 | } 689 | }, 690 | "nbformat": 4, 691 | "nbformat_minor": 2 692 | } 693 | -------------------------------------------------------------------------------- /pretraining-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoheikikuta/bert-japanese/8d197e23b0e54da785ca9d16b7998c708767d649/pretraining-loss.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentencepiece 2 | jupyter 3 | tensorflow<2.0.0 4 | pandas 5 | -------------------------------------------------------------------------------- /src/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is based on https://github.com/google-research/bert/blob/master/create_pretraining_data.py. 3 | # It is changed to use SentencePiece tokenizer for tokenizations. 
4 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import collections 11 | import random 12 | import tensorflow as tf 13 | import tokenization_sentencepiece as tokenization 14 | 15 | flags = tf.flags 16 | 17 | FLAGS = flags.FLAGS 18 | 19 | flags.DEFINE_string("input_file", None, 20 | "Input raw text file (or comma-separated list of files).") 21 | 22 | flags.DEFINE_string( 23 | "output_file", None, 24 | "Output TF example file (or comma-separated list of files).") 25 | 26 | flags.DEFINE_string("model_file", None, 27 | "The model file that the SentencePiece model was trained on.") 28 | 29 | flags.DEFINE_string("vocab_file", None, 30 | "The vocabulary file that the SentencePiece model was trained on.") 31 | 32 | flags.DEFINE_bool( 33 | "do_lower_case", True, 34 | "Whether to lower case the input text. Should be True for uncased " 35 | "models and False for cased models.") 36 | 37 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 38 | 39 | flags.DEFINE_integer("max_predictions_per_seq", 20, 40 | "Maximum number of masked LM predictions per sequence.") 41 | 42 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 43 | 44 | flags.DEFINE_integer( 45 | "dupe_factor", 10, 46 | "Number of times to duplicate the input data (with different masks).") 47 | 48 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 49 | 50 | flags.DEFINE_float( 51 | "short_seq_prob", 0.1, 52 | "Probability of creating sequences which are shorter than the " 53 | "maximum length.") 54 | 55 | 56 | class TrainingInstance(object): 57 | """A single training instance (sentence pair).""" 58 | 59 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 60 | is_random_next): 61 | self.tokens = tokens 62 | self.segment_ids = segment_ids 63 | self.is_random_next = is_random_next 64 | self.masked_lm_positions = masked_lm_positions 65 | self.masked_lm_labels = masked_lm_labels 66 | 67 | def __str__(self): 68 | s = "" 69 | s += "tokens: %s\n" % (" ".join( 70 | [tokenization.printable_text(x) for x in self.tokens])) 71 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 72 | s += "is_random_next: %s\n" % self.is_random_next 73 | s += "masked_lm_positions: %s\n" % (" ".join( 74 | [str(x) for x in self.masked_lm_positions])) 75 | s += "masked_lm_labels: %s\n" % (" ".join( 76 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 77 | s += "\n" 78 | return s 79 | 80 | def __repr__(self): 81 | return self.__str__() 82 | 83 | 84 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 85 | max_predictions_per_seq, output_files): 86 | """Create TF example files from `TrainingInstance`s.""" 87 | writers = [] 88 | for output_file in output_files: 89 | writers.append(tf.python_io.TFRecordWriter(output_file)) 90 | 91 | writer_index = 0 92 | 93 | total_written = 0 94 | for (inst_index, instance) in enumerate(instances): 95 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 96 | input_mask = [1] * len(input_ids) 97 | segment_ids = list(instance.segment_ids) 98 | assert len(input_ids) <= max_seq_length 99 | 100 | while len(input_ids) < max_seq_length: 101 | input_ids.append(0) 102 | input_mask.append(0) 103 | segment_ids.append(0) 104 | 105 | assert len(input_ids) == max_seq_length 106 | assert len(input_mask) == max_seq_length 107 | assert 
len(segment_ids) == max_seq_length 108 | 109 | masked_lm_positions = list(instance.masked_lm_positions) 110 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 111 | masked_lm_weights = [1.0] * len(masked_lm_ids) 112 | 113 | while len(masked_lm_positions) < max_predictions_per_seq: 114 | masked_lm_positions.append(0) 115 | masked_lm_ids.append(0) 116 | masked_lm_weights.append(0.0) 117 | 118 | next_sentence_label = 1 if instance.is_random_next else 0 119 | 120 | features = collections.OrderedDict() 121 | features["input_ids"] = create_int_feature(input_ids) 122 | features["input_mask"] = create_int_feature(input_mask) 123 | features["segment_ids"] = create_int_feature(segment_ids) 124 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 125 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 126 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 127 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 128 | 129 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 130 | 131 | writers[writer_index].write(tf_example.SerializeToString()) 132 | writer_index = (writer_index + 1) % len(writers) 133 | 134 | total_written += 1 135 | 136 | if inst_index < 20: 137 | tf.logging.info("*** Example ***") 138 | tf.logging.info("tokens: %s" % " ".join( 139 | [tokenization.printable_text(x) for x in instance.tokens])) 140 | 141 | for feature_name in features.keys(): 142 | feature = features[feature_name] 143 | values = [] 144 | if feature.int64_list.value: 145 | values = feature.int64_list.value 146 | elif feature.float_list.value: 147 | values = feature.float_list.value 148 | tf.logging.info( 149 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 150 | 151 | for writer in writers: 152 | writer.close() 153 | 154 | tf.logging.info("Wrote %d total instances", total_written) 155 | 156 | 157 | def create_int_feature(values): 158 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 159 | return feature 160 | 161 | 162 | def create_float_feature(values): 163 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 164 | return feature 165 | 166 | 167 | def create_training_instances(input_files, tokenizer, max_seq_length, 168 | dupe_factor, short_seq_prob, masked_lm_prob, 169 | max_predictions_per_seq, rng): 170 | """Create `TrainingInstance`s from raw text.""" 171 | all_documents = [[]] 172 | 173 | # Input file format: 174 | # (1) One sentence per line. These should ideally be actual sentences, not 175 | # entire paragraphs or arbitrary spans of text. (Because we use the 176 | # sentence boundaries for the "next sentence prediction" task). 177 | # (2) Blank lines between documents. Document boundaries are needed so 178 | # that the "next sentence prediction" task doesn't span between documents. 
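# For example, a (hypothetical) input file holding two short documents would
# look like this:
#
#   吾輩は猫である。
#   名前はまだ無い。
#
#   メロスは激怒した。
#
# i.e. one sentence per line, with a single blank line marking the boundary
# between the two documents.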
179 | for input_file in input_files: 180 | with tf.gfile.GFile(input_file, "r") as reader: 181 | while True: 182 | line = tokenization.convert_to_unicode(reader.readline()) 183 | if not line: 184 | break 185 | line = line.strip() 186 | 187 | # Empty lines are used as document delimiters 188 | if not line: 189 | all_documents.append([]) 190 | tokens = tokenizer.tokenize(line) 191 | if tokens: 192 | all_documents[-1].append(tokens) 193 | 194 | # Remove empty documents 195 | all_documents = [x for x in all_documents if x] 196 | rng.shuffle(all_documents) 197 | 198 | vocab_words = list(tokenizer.vocab.keys()) 199 | instances = [] 200 | for _ in range(dupe_factor): 201 | for document_index in range(len(all_documents)): 202 | instances.extend( 203 | create_instances_from_document( 204 | all_documents, document_index, max_seq_length, short_seq_prob, 205 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 206 | 207 | rng.shuffle(instances) 208 | return instances 209 | 210 | 211 | def create_instances_from_document( 212 | all_documents, document_index, max_seq_length, short_seq_prob, 213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 214 | """Creates `TrainingInstance`s for a single document.""" 215 | document = all_documents[document_index] 216 | 217 | # Account for [CLS], [SEP], [SEP] 218 | max_num_tokens = max_seq_length - 3 219 | 220 | # We *usually* want to fill up the entire sequence since we are padding 221 | # to `max_seq_length` anyways, so short sequences are generally wasted 222 | # computation. However, we *sometimes* 223 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 224 | # sequences to minimize the mismatch between pre-training and fine-tuning. 225 | # The `target_seq_length` is just a rough target however, whereas 226 | # `max_seq_length` is a hard limit. 227 | target_seq_length = max_num_tokens 228 | if rng.random() < short_seq_prob: 229 | target_seq_length = rng.randint(2, max_num_tokens) 230 | 231 | # We DON'T just concatenate all of the tokens from a document into a long 232 | # sequence and choose an arbitrary split point because this would make the 233 | # next sentence prediction task too easy. Instead, we split the input into 234 | # segments "A" and "B" based on the actual "sentences" provided by the user 235 | # input. 236 | instances = [] 237 | current_chunk = [] 238 | current_length = 0 239 | i = 0 240 | while i < len(document): 241 | segment = document[i] 242 | current_chunk.append(segment) 243 | current_length += len(segment) 244 | if i == len(document) - 1 or current_length >= target_seq_length: 245 | if current_chunk: 246 | # `a_end` is how many segments from `current_chunk` go into the `A` 247 | # (first) sentence. 248 | a_end = 1 249 | if len(current_chunk) >= 2: 250 | a_end = rng.randint(1, len(current_chunk) - 1) 251 | 252 | tokens_a = [] 253 | for j in range(a_end): 254 | tokens_a.extend(current_chunk[j]) 255 | 256 | tokens_b = [] 257 | # Random next 258 | is_random_next = False 259 | if len(current_chunk) == 1 or rng.random() < 0.5: 260 | is_random_next = True 261 | target_b_length = target_seq_length - len(tokens_a) 262 | 263 | # This should rarely go for more than one iteration for large 264 | # corpora. However, just to be careful, we try to make sure that 265 | # the random document is not the same as the document 266 | # we're processing. 
267 | for _ in range(10): 268 | random_document_index = rng.randint(0, len(all_documents) - 1) 269 | if random_document_index != document_index: 270 | break 271 | 272 | random_document = all_documents[random_document_index] 273 | random_start = rng.randint(0, len(random_document) - 1) 274 | for j in range(random_start, len(random_document)): 275 | tokens_b.extend(random_document[j]) 276 | if len(tokens_b) >= target_b_length: 277 | break 278 | # We didn't actually use these segments so we "put them back" so 279 | # they don't go to waste. 280 | num_unused_segments = len(current_chunk) - a_end 281 | i -= num_unused_segments 282 | # Actual next 283 | else: 284 | is_random_next = False 285 | for j in range(a_end, len(current_chunk)): 286 | tokens_b.extend(current_chunk[j]) 287 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 288 | 289 | assert len(tokens_a) >= 1 290 | assert len(tokens_b) >= 1 291 | 292 | tokens = [] 293 | segment_ids = [] 294 | tokens.append("[CLS]") 295 | segment_ids.append(0) 296 | for token in tokens_a: 297 | tokens.append(token) 298 | segment_ids.append(0) 299 | 300 | tokens.append("[SEP]") 301 | segment_ids.append(0) 302 | 303 | for token in tokens_b: 304 | tokens.append(token) 305 | segment_ids.append(1) 306 | tokens.append("[SEP]") 307 | segment_ids.append(1) 308 | 309 | (tokens, masked_lm_positions, 310 | masked_lm_labels) = create_masked_lm_predictions( 311 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 312 | instance = TrainingInstance( 313 | tokens=tokens, 314 | segment_ids=segment_ids, 315 | is_random_next=is_random_next, 316 | masked_lm_positions=masked_lm_positions, 317 | masked_lm_labels=masked_lm_labels) 318 | instances.append(instance) 319 | current_chunk = [] 320 | current_length = 0 321 | i += 1 322 | 323 | return instances 324 | 325 | 326 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 327 | ["index", "label"]) 328 | 329 | 330 | def create_masked_lm_predictions(tokens, masked_lm_prob, 331 | max_predictions_per_seq, vocab_words, rng): 332 | """Creates the predictions for the masked LM objective.""" 333 | 334 | cand_indexes = [] 335 | for (i, token) in enumerate(tokens): 336 | if token == "[CLS]" or token == "[SEP]": 337 | continue 338 | cand_indexes.append(i) 339 | 340 | rng.shuffle(cand_indexes) 341 | 342 | output_tokens = list(tokens) 343 | 344 | num_to_predict = min(max_predictions_per_seq, 345 | max(1, int(round(len(tokens) * masked_lm_prob)))) 346 | 347 | masked_lms = [] 348 | covered_indexes = set() 349 | for index in cand_indexes: 350 | if len(masked_lms) >= num_to_predict: 351 | break 352 | if index in covered_indexes: 353 | continue 354 | covered_indexes.add(index) 355 | 356 | masked_token = None 357 | # 80% of the time, replace with [MASK] 358 | if rng.random() < 0.8: 359 | masked_token = "[MASK]" 360 | else: 361 | # 10% of the time, keep original 362 | if rng.random() < 0.5: 363 | masked_token = tokens[index] 364 | # 10% of the time, replace with random word 365 | else: 366 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 367 | 368 | output_tokens[index] = masked_token 369 | 370 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 371 | 372 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 373 | 374 | masked_lm_positions = [] 375 | masked_lm_labels = [] 376 | for p in masked_lms: 377 | masked_lm_positions.append(p.index) 378 | masked_lm_labels.append(p.label) 379 | 380 | return (output_tokens, masked_lm_positions, masked_lm_labels) 381 | 382 | 383 | def 
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 384 | """Truncates a pair of sequences to a maximum sequence length.""" 385 | while True: 386 | total_length = len(tokens_a) + len(tokens_b) 387 | if total_length <= max_num_tokens: 388 | break 389 | 390 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 391 | assert len(trunc_tokens) >= 1 392 | 393 | # We want to sometimes truncate from the front and sometimes from the 394 | # back to add more randomness and avoid biases. 395 | if rng.random() < 0.5: 396 | del trunc_tokens[0] 397 | else: 398 | trunc_tokens.pop() 399 | 400 | 401 | def main(_): 402 | tf.logging.set_verbosity(tf.logging.INFO) 403 | 404 | tokenizer = tokenization.FullTokenizer( 405 | model_file=FLAGS.model_file, vocab_file=FLAGS.vocab_file, 406 | do_lower_case=FLAGS.do_lower_case) 407 | 408 | input_files = [] 409 | for input_pattern in FLAGS.input_file.split(","): 410 | input_files.extend(tf.gfile.Glob(input_pattern)) 411 | 412 | tf.logging.info("*** Reading from input files ***") 413 | for input_file in input_files: 414 | tf.logging.info(" %s", input_file) 415 | 416 | rng = random.Random(FLAGS.random_seed) 417 | instances = create_training_instances( 418 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 419 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 420 | rng) 421 | 422 | output_files = FLAGS.output_file.split(",") 423 | tf.logging.info("*** Writing to output files ***") 424 | for output_file in output_files: 425 | tf.logging.info(" %s", output_file) 426 | 427 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 428 | FLAGS.max_predictions_per_seq, output_files) 429 | 430 | 431 | if __name__ == "__main__": 432 | flags.mark_flag_as_required("input_file") 433 | flags.mark_flag_as_required("output_file") 434 | flags.mark_flag_as_required("model_file") 435 | flags.mark_flag_as_required("vocab_file") 436 | tf.app.run() 437 | -------------------------------------------------------------------------------- /src/data-download-and-extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import configparser 4 | import os 5 | import subprocess 6 | import sys 7 | from urllib.request import urlretrieve 8 | 9 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 10 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 11 | config = configparser.ConfigParser() 12 | config.read(CONFIGPATH) 13 | 14 | FILEURL = config['DATA']['FILEURL'] 15 | FILEPATH = config['DATA']['FILEPATH'] 16 | EXTRACTDIR = config['DATA']['TEXTDIR'] 17 | 18 | 19 | def reporthook(blocknum, blocksize, totalsize): 20 | ''' 21 | Callback function to show progress of file downloading. 
22 | ''' 23 | readsofar = blocknum * blocksize 24 | if totalsize > 0: 25 | percent = readsofar * 1e2 / totalsize 26 | s = "\r%5.1f%% %*d / %d" % ( 27 | percent, len(str(totalsize)), readsofar, totalsize) 28 | sys.stderr.write(s) 29 | if readsofar >= totalsize: # near the end 30 | sys.stderr.write("\n") 31 | else: # total size is unknown 32 | sys.stderr.write("read %d\n" % (readsofar,)) 33 | 34 | 35 | def download(): 36 | urlretrieve(FILEURL, FILEPATH, reporthook) 37 | 38 | 39 | def extract(): 40 | subprocess.call(['python3', 41 | os.path.join(CURDIR, os.pardir, 42 | 'wikiextractor', 'WikiExtractor.py'), 43 | FILEPATH, "-o={}".format(EXTRACTDIR)]) 44 | 45 | 46 | def main(): 47 | download() 48 | extract() 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /src/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import sys 22 | sys.path.append("../bert") 23 | 24 | 25 | import codecs 26 | import collections 27 | import json 28 | import re 29 | 30 | import modeling 31 | #import tokenization 32 | import tokenization_sentencepiece as tokenization 33 | import tensorflow as tf 34 | 35 | 36 | flags = tf.flags 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | flags.DEFINE_string("input_file", None, "") 41 | 42 | flags.DEFINE_string("output_file", None, "") 43 | 44 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 45 | 46 | flags.DEFINE_string( 47 | "bert_config_file", None, 48 | "The config json file corresponding to the pre-trained BERT model. " 49 | "This specifies the model architecture.") 50 | 51 | flags.DEFINE_integer( 52 | "max_seq_length", 128, 53 | "The maximum total input sequence length after WordPiece tokenization. " 54 | "Sequences longer than this will be truncated, and sequences shorter " 55 | "than this will be padded.") 56 | 57 | flags.DEFINE_string( 58 | "init_checkpoint", None, 59 | "Initial checkpoint (usually from a pre-trained BERT model).") 60 | 61 | flags.DEFINE_string("vocab_file", None, 62 | "The vocabulary file that the BERT model was trained on.") 63 | 64 | flags.DEFINE_bool( 65 | "do_lower_case", True, 66 | "Whether to lower case the input text. Should be True for uncased " 67 | "models and False for cased models.") 68 | 69 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 70 | 71 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 72 | 73 | flags.DEFINE_string("master", None, 74 | "If using a TPU, the address of the master.") 75 | 76 | flags.DEFINE_integer( 77 | "num_tpu_cores", 8, 78 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 79 | 80 | flags.DEFINE_bool( 81 | "use_one_hot_embeddings", False, 82 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 83 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 84 | "since it is much faster.") 85 | 86 | # Adeed to use the sentencepiece model 87 | flags.DEFINE_string("model_file", None, 88 | "The model file that the SentencePiece model was trained on.") 89 | 90 | class InputExample(object): 91 | 92 | def __init__(self, unique_id, text_a, text_b): 93 | self.unique_id = unique_id 94 | self.text_a = text_a 95 | self.text_b = text_b 96 | 97 | 98 | class InputFeatures(object): 99 | """A single set of features of data.""" 100 | 101 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 102 | self.unique_id = unique_id 103 | self.tokens = tokens 104 | self.input_ids = input_ids 105 | self.input_mask = input_mask 106 | self.input_type_ids = input_type_ids 107 | 108 | 109 | def input_fn_builder(features, seq_length): 110 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 111 | 112 | all_unique_ids = [] 113 | all_input_ids = [] 114 | all_input_mask = [] 115 | all_input_type_ids = [] 116 | 117 | for feature in features: 118 | all_unique_ids.append(feature.unique_id) 119 | all_input_ids.append(feature.input_ids) 120 | all_input_mask.append(feature.input_mask) 121 | all_input_type_ids.append(feature.input_type_ids) 122 | 123 | def input_fn(params): 124 | """The actual input function.""" 125 | batch_size = params["batch_size"] 126 | 127 | num_examples = len(features) 128 | 129 | # This is for demo purposes and does NOT scale to large data sets. We do 130 | # not use Dataset.from_generator() because that uses tf.py_func which is 131 | # not TPU compatible. The right way to load data is with TFRecordReader. 
132 | d = tf.data.Dataset.from_tensor_slices({ 133 | "unique_ids": 134 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 135 | "input_ids": 136 | tf.constant( 137 | all_input_ids, shape=[num_examples, seq_length], 138 | dtype=tf.int32), 139 | "input_mask": 140 | tf.constant( 141 | all_input_mask, 142 | shape=[num_examples, seq_length], 143 | dtype=tf.int32), 144 | "input_type_ids": 145 | tf.constant( 146 | all_input_type_ids, 147 | shape=[num_examples, seq_length], 148 | dtype=tf.int32), 149 | }) 150 | 151 | d = d.batch(batch_size=batch_size, drop_remainder=False) 152 | return d 153 | 154 | return input_fn 155 | 156 | 157 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 158 | use_one_hot_embeddings): 159 | """Returns `model_fn` closure for TPUEstimator.""" 160 | 161 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 162 | """The `model_fn` for TPUEstimator.""" 163 | 164 | unique_ids = features["unique_ids"] 165 | input_ids = features["input_ids"] 166 | input_mask = features["input_mask"] 167 | input_type_ids = features["input_type_ids"] 168 | 169 | model = modeling.BertModel( 170 | config=bert_config, 171 | is_training=False, 172 | input_ids=input_ids, 173 | input_mask=input_mask, 174 | token_type_ids=input_type_ids, 175 | use_one_hot_embeddings=use_one_hot_embeddings) 176 | 177 | if mode != tf.estimator.ModeKeys.PREDICT: 178 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 179 | 180 | tvars = tf.trainable_variables() 181 | scaffold_fn = None 182 | (assignment_map, 183 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 184 | tvars, init_checkpoint) 185 | if use_tpu: 186 | 187 | def tpu_scaffold(): 188 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 189 | return tf.train.Scaffold() 190 | 191 | scaffold_fn = tpu_scaffold 192 | else: 193 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 194 | 195 | tf.logging.info("**** Trainable Variables ****") 196 | for var in tvars: 197 | init_string = "" 198 | if var.name in initialized_variable_names: 199 | init_string = ", *INIT_FROM_CKPT*" 200 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 201 | init_string) 202 | 203 | all_layers = model.get_all_encoder_layers() 204 | 205 | predictions = { 206 | "unique_id": unique_ids, 207 | } 208 | 209 | for (i, layer_index) in enumerate(layer_indexes): 210 | predictions["layer_output_%d" % i] = all_layers[layer_index] 211 | 212 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 213 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 214 | return output_spec 215 | 216 | return model_fn 217 | 218 | 219 | def convert_examples_to_features(examples, seq_length, tokenizer): 220 | """Loads a data file into a list of `InputBatch`s.""" 221 | 222 | features = [] 223 | for (ex_index, example) in enumerate(examples): 224 | tokens_a = tokenizer.tokenize(example.text_a) 225 | 226 | tokens_b = None 227 | if example.text_b: 228 | tokens_b = tokenizer.tokenize(example.text_b) 229 | 230 | if tokens_b: 231 | # Modifies `tokens_a` and `tokens_b` in place so that the total 232 | # length is less than the specified length. 
233 | # Account for [CLS], [SEP], [SEP] with "- 3" 234 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 235 | else: 236 | # Account for [CLS] and [SEP] with "- 2" 237 | if len(tokens_a) > seq_length - 2: 238 | tokens_a = tokens_a[0:(seq_length - 2)] 239 | 240 | # The convention in BERT is: 241 | # (a) For sequence pairs: 242 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 243 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 244 | # (b) For single sequences: 245 | # tokens: [CLS] the dog is hairy . [SEP] 246 | # type_ids: 0 0 0 0 0 0 0 247 | # 248 | # Where "type_ids" are used to indicate whether this is the first 249 | # sequence or the second sequence. The embedding vectors for `type=0` and 250 | # `type=1` were learned during pre-training and are added to the wordpiece 251 | # embedding vector (and position vector). This is not *strictly* necessary 252 | # since the [SEP] token unambiguously separates the sequences, but it makes 253 | # it easier for the model to learn the concept of sequences. 254 | # 255 | # For classification tasks, the first vector (corresponding to [CLS]) is 256 | # used as as the "sentence vector". Note that this only makes sense because 257 | # the entire model is fine-tuned. 258 | tokens = [] 259 | input_type_ids = [] 260 | tokens.append("[CLS]") 261 | input_type_ids.append(0) 262 | for token in tokens_a: 263 | tokens.append(token) 264 | input_type_ids.append(0) 265 | tokens.append("[SEP]") 266 | input_type_ids.append(0) 267 | 268 | if tokens_b: 269 | for token in tokens_b: 270 | tokens.append(token) 271 | input_type_ids.append(1) 272 | tokens.append("[SEP]") 273 | input_type_ids.append(1) 274 | 275 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 276 | 277 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 278 | # tokens are attended to. 279 | input_mask = [1] * len(input_ids) 280 | 281 | # Zero-pad up to the sequence length. 282 | while len(input_ids) < seq_length: 283 | input_ids.append(0) 284 | input_mask.append(0) 285 | input_type_ids.append(0) 286 | 287 | assert len(input_ids) == seq_length 288 | assert len(input_mask) == seq_length 289 | assert len(input_type_ids) == seq_length 290 | 291 | if ex_index < 5: 292 | tf.logging.info("*** Example ***") 293 | tf.logging.info("unique_id: %s" % (example.unique_id)) 294 | tf.logging.info("tokens: %s" % " ".join( 295 | [tokenization.printable_text(x) for x in tokens])) 296 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 297 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 298 | tf.logging.info( 299 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 300 | 301 | features.append( 302 | InputFeatures( 303 | unique_id=example.unique_id, 304 | tokens=tokens, 305 | input_ids=input_ids, 306 | input_mask=input_mask, 307 | input_type_ids=input_type_ids)) 308 | return features 309 | 310 | 311 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 312 | """Truncates a sequence pair in place to the maximum length.""" 313 | 314 | # This is a simple heuristic which will always truncate the longer sequence 315 | # one token at a time. This makes more sense than truncating an equal percent 316 | # of tokens from each, since if one sequence is very short then each token 317 | # that's truncated likely contains more information than a longer sequence. 
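# A small worked example with hypothetical lengths: for max_length=5,
# len(tokens_a)=4 and len(tokens_b)=3, the first iteration pops from tokens_a
# (the longer list) and the second pops from tokens_b, leaving 3 + 2 = 5 tokens.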
318 | while True: 319 | total_length = len(tokens_a) + len(tokens_b) 320 | if total_length <= max_length: 321 | break 322 | if len(tokens_a) > len(tokens_b): 323 | tokens_a.pop() 324 | else: 325 | tokens_b.pop() 326 | 327 | 328 | def read_examples(input_file): 329 | """Read a list of `InputExample`s from an input file.""" 330 | examples = [] 331 | unique_id = 0 332 | with tf.gfile.GFile(input_file, "r") as reader: 333 | while True: 334 | 335 | line = tokenization.convert_to_unicode(reader.readline()) 336 | if not line: 337 | break 338 | 339 | line = line.strip() 340 | text_a = None 341 | text_b = None 342 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 343 | 344 | if m is None: 345 | text_a = line 346 | else: 347 | text_a = m.group(1) 348 | text_b = m.group(2) 349 | 350 | examples.append( 351 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 352 | unique_id += 1 353 | 354 | return examples 355 | 356 | 357 | def main(_): 358 | 359 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 360 | 361 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 362 | 363 | # tokenizer = tokenization.FullTokenizer( 364 | # vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 365 | 366 | tokenizer = tokenization.FullTokenizer( 367 | model_file=FLAGS.model_file, vocab_file=FLAGS.vocab_file, 368 | do_lower_case=FLAGS.do_lower_case) 369 | 370 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 371 | run_config = tf.contrib.tpu.RunConfig( 372 | master=FLAGS.master, 373 | tpu_config=tf.contrib.tpu.TPUConfig( 374 | num_shards=FLAGS.num_tpu_cores, 375 | per_host_input_for_training=is_per_host)) 376 | 377 | examples = read_examples(FLAGS.input_file) 378 | 379 | 380 | features = convert_examples_to_features( 381 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 382 | 383 | unique_id_to_feature = {} 384 | for feature in features: 385 | 386 | unique_id_to_feature[feature.unique_id] = feature 387 | 388 | model_fn = model_fn_builder( 389 | bert_config=bert_config, 390 | init_checkpoint=FLAGS.init_checkpoint, 391 | layer_indexes=layer_indexes, 392 | use_tpu=FLAGS.use_tpu, 393 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 394 | 395 | 396 | # If TPU is not available, this will fall back to normal Estimator on CPU 397 | # or GPU. 
398 | estimator = tf.contrib.tpu.TPUEstimator( 399 | use_tpu=FLAGS.use_tpu, 400 | model_fn=model_fn, 401 | config=run_config, 402 | predict_batch_size=FLAGS.batch_size) 403 | 404 | input_fn = input_fn_builder( 405 | features=features, seq_length=FLAGS.max_seq_length) 406 | 407 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 408 | "w")) as writer: 409 | for result in estimator.predict(input_fn, yield_single_examples=True): 410 | unique_id = int(result["unique_id"]) 411 | feature = unique_id_to_feature[unique_id] 412 | output_json = collections.OrderedDict() 413 | output_json["linex_index"] = unique_id 414 | all_features = [] 415 | for (i, token) in enumerate(feature.tokens): 416 | all_layers = [] 417 | for (j, layer_index) in enumerate(layer_indexes): 418 | layer_output = result["layer_output_%d" % j] 419 | layers = collections.OrderedDict() 420 | layers["index"] = layer_index 421 | layers["values"] = [ 422 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 423 | ] 424 | all_layers.append(layers) 425 | features = collections.OrderedDict() 426 | features["token"] = token 427 | features["layers"] = all_layers 428 | all_features.append(features) 429 | output_json["features"] = all_features 430 | writer.write(json.dumps(output_json) + "\n") 431 | 432 | 433 | if __name__ == "__main__": 434 | flags.mark_flag_as_required("input_file") 435 | flags.mark_flag_as_required("vocab_file") 436 | flags.mark_flag_as_required("bert_config_file") 437 | flags.mark_flag_as_required("init_checkpoint") 438 | flags.mark_flag_as_required("output_file") 439 | tf.app.run() 440 | -------------------------------------------------------------------------------- /src/file-preprocessing.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Read data_text_dir path from a config file. 4 | CURDIR=$(cd $(dirname $0); pwd) 5 | source <(sed -n '/^\[DATA\]/,/^\[/p' ${CURDIR}/../config.ini | grep TEXTDIR | sed 's/ *= */=/g') 6 | 7 | # Text preprocessing. 8 | # 1-1. Remove blank lines. 9 | # 1-2. Remove <doc id ...> lines. 10 | # 1-3. Replace </doc> line with a blank line. 11 | # 2-1. Remove spaces at the end of each line. 12 | # 2-2. Break line at each 。, but not at 。」 or 。), position. 13 | # 2-3. Remove spaces at the head of each line. 14 | # 3. Remove lines with the head 。(these lines are not meaningful). 15 | # 4. Convert upper case characters to lower case ones. 16 | for FILE in $( find ${TEXTDIR} -name "wiki_*" ); do 17 | echo "Processing ${FILE}" 18 | sed -i -e '/^$/d; /<doc id/d; s/<\/doc>//g' ${FILE} 19 | sed -i -e 's/ *$//g; s/。\([^」|)|)|"]\)/。\n\1/g; s/^[ ]*//g' ${FILE} 20 | sed -i -e '/^。/d' ${FILE} 21 | sed -i -e 's/\(.*\)/\L\1/' ${FILE} 22 | done 23 | 24 | # Concat all text files in each text directory. 25 | for DIR in $( find ${TEXTDIR} -mindepth 1 -type d ); do 26 | echo "Processing ${DIR}" 27 | for f in $( find ${DIR} -name "wiki_*" ); do cat $f >> ${DIR}/all.txt; done 28 | done 29 | -------------------------------------------------------------------------------- /src/run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is based on https://github.com/google-research/bert/blob/master/run_classifier.py. 3 | # It is changed to use SentencePiece tokenizer and https://www.rondhuit.com/download/ldcc-20140209.tar.gz. 
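# Example fine-tuning invocation (a sketch; paths, checkpoint name and
# hyper-parameter values are illustrative assumptions only):
#
#   python3 run_classifier.py \
#     --task_name=livedoor \
#     --data_dir=../data/livedoor \
#     --model_file=../model/wiki-ja.model \
#     --vocab_file=../model/wiki-ja.vocab \
#     --init_checkpoint=../model/model.ckpt \
#     --output_dir=../model/livedoor \
#     --do_train=true --do_eval=true --do_predict=true \
#     --max_seq_length=512 --train_batch_size=4 \
#     --learning_rate=2e-5 --num_train_epochs=10
#
# task_name must match a key of the `processors` dict in main() ("livedoor"),
# and data_dir must contain the train.tsv / dev.tsv / test.tsv files read by
# LivedoorProcessor. The BERT architecture itself is taken from the
# [BERT-CONFIG] section of config.ini rather than from a --bert_config_file flag.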
4 | """BERT finetuning runner.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import collections 11 | import configparser 12 | import csv 13 | import json 14 | import os 15 | import sys 16 | import tempfile 17 | import tokenization_sentencepiece as tokenization 18 | import tensorflow as tf 19 | import utils 20 | 21 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 22 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 23 | config = configparser.ConfigParser() 24 | config.read(CONFIGPATH) 25 | bert_config_file = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8', suffix='.json') 26 | bert_config_file.write(json.dumps({k:utils.str_to_value(v) for k,v in config['BERT-CONFIG'].items()})) 27 | bert_config_file.seek(0) 28 | 29 | sys.path.append(os.path.join(CURDIR, os.pardir, 'bert')) 30 | import modeling 31 | import optimization 32 | 33 | flags = tf.flags 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | # Required parameters 38 | flags.DEFINE_string( 39 | "data_dir", None, 40 | "The input data dir. Should contain the .tsv files (or other data files) " 41 | "for the task.") 42 | 43 | flags.DEFINE_string( 44 | "bert_config_file", None, 45 | "The config json file corresponding to the pre-trained BERT model. " 46 | "This specifies the model architecture.") 47 | 48 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 49 | 50 | flags.DEFINE_string("model_file", None, 51 | "The model file that the SentencePiece model was trained on.") 52 | 53 | flags.DEFINE_string("vocab_file", None, 54 | "The vocabulary file that the BERT model was trained on.") 55 | 56 | flags.DEFINE_string( 57 | "output_dir", None, 58 | "The output directory where the model checkpoints will be written.") 59 | 60 | # Other parameters 61 | 62 | flags.DEFINE_string( 63 | "init_checkpoint", None, 64 | "Initial checkpoint (usually from a pre-trained BERT model).") 65 | 66 | flags.DEFINE_bool( 67 | "do_lower_case", True, 68 | "Whether to lower case the input text. Should be True for uncased " 69 | "models and False for cased models.") 70 | 71 | flags.DEFINE_integer( 72 | "max_seq_length", 128, 73 | "The maximum total input sequence length after WordPiece tokenization. " 74 | "Sequences longer than this will be truncated, and sequences shorter " 75 | "than this will be padded.") 76 | 77 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 78 | 79 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 80 | 81 | flags.DEFINE_bool( 82 | "do_predict", False, 83 | "Whether to run the model in inference mode on the test set.") 84 | 85 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 86 | 87 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 88 | 89 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 90 | 91 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 92 | 93 | flags.DEFINE_float("num_train_epochs", 3.0, 94 | "Total number of training epochs to perform.") 95 | 96 | flags.DEFINE_float( 97 | "warmup_proportion", 0.1, 98 | "Proportion of training to perform linear learning rate warmup for. 
" 99 | "E.g., 0.1 = 10% of training.") 100 | 101 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 102 | "How often to save the model checkpoint.") 103 | 104 | flags.DEFINE_integer("iterations_per_loop", 1000, 105 | "How many steps to make in each estimator call.") 106 | 107 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_name", None, 111 | "The Cloud TPU to use for training. This should be either the name " 112 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 113 | "url.") 114 | 115 | tf.flags.DEFINE_string( 116 | "tpu_zone", None, 117 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string( 122 | "gcp_project", None, 123 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 124 | "specified, we will attempt to automatically detect the GCE project from " 125 | "metadata.") 126 | 127 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 128 | 129 | flags.DEFINE_integer( 130 | "num_tpu_cores", 8, 131 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 132 | 133 | 134 | class InputExample(object): 135 | """A single training/test example for simple sequence classification.""" 136 | 137 | def __init__(self, guid, text_a, text_b=None, label=None): 138 | """Constructs a InputExample. 139 | 140 | Args: 141 | guid: Unique id for the example. 142 | text_a: string. The untokenized text of the first sequence. For single 143 | sequence tasks, only this sequence must be specified. 144 | text_b: (Optional) string. The untokenized text of the second sequence. 145 | Only must be specified for sequence pair tasks. 146 | label: (Optional) string. The label of the example. This should be 147 | specified for train and dev examples, but not for test examples. 148 | """ 149 | self.guid = guid 150 | self.text_a = text_a 151 | self.text_b = text_b 152 | self.label = label 153 | 154 | 155 | class PaddingInputExample(object): 156 | """Fake example so the num input examples is a multiple of the batch size. 157 | 158 | When running eval/predict on the TPU, we need to pad the number of examples 159 | to be a multiple of the batch size, because the TPU requires a fixed batch 160 | size. The alternative is to drop the last batch, which is bad because it means 161 | the entire output data won't be generated. 162 | 163 | We use this class instead of `None` because treating `None` as padding 164 | battches could cause silent errors. 
165 | """ 166 | 167 | 168 | class InputFeatures(object): 169 | """A single set of features of data.""" 170 | 171 | def __init__(self, 172 | input_ids, 173 | input_mask, 174 | segment_ids, 175 | label_id, 176 | is_real_example=True): 177 | self.input_ids = input_ids 178 | self.input_mask = input_mask 179 | self.segment_ids = segment_ids 180 | self.label_id = label_id 181 | self.is_real_example = is_real_example 182 | 183 | 184 | class DataProcessor(object): 185 | """Base class for data converters for sequence classification data sets.""" 186 | 187 | def get_train_examples(self, data_dir): 188 | """Gets a collection of `InputExample`s for the train set.""" 189 | raise NotImplementedError() 190 | 191 | def get_dev_examples(self, data_dir): 192 | """Gets a collection of `InputExample`s for the dev set.""" 193 | raise NotImplementedError() 194 | 195 | def get_test_examples(self, data_dir): 196 | """Gets a collection of `InputExample`s for prediction.""" 197 | raise NotImplementedError() 198 | 199 | def get_labels(self): 200 | """Gets the list of labels for this data set.""" 201 | raise NotImplementedError() 202 | 203 | @classmethod 204 | def _read_tsv(cls, input_file, quotechar=None): 205 | """Reads a tab separated value file.""" 206 | with tf.gfile.Open(input_file, "r") as f: 207 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 208 | lines = [] 209 | for line in reader: 210 | lines.append(line) 211 | return lines 212 | 213 | 214 | class LivedoorProcessor(DataProcessor): 215 | """Processor for the livedoor data set (see https://www.rondhuit.com/download.html).""" 216 | 217 | def get_train_examples(self, data_dir): 218 | """See base class.""" 219 | return self._create_examples( 220 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 221 | 222 | def get_dev_examples(self, data_dir): 223 | """See base class.""" 224 | return self._create_examples( 225 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 226 | 227 | def get_test_examples(self, data_dir): 228 | """See base class.""" 229 | return self._create_examples( 230 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 231 | 232 | def get_labels(self): 233 | """See base class.""" 234 | return ['dokujo-tsushin', 'it-life-hack', 'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax', 'sports-watch', 'topic-news'] 235 | 236 | def _create_examples(self, lines, set_type): 237 | """Creates examples for the training and dev sets.""" 238 | examples = [] 239 | for (i, line) in enumerate(lines): 240 | if i == 0: 241 | idx_text = line.index('text') 242 | idx_label = line.index('label') 243 | else: 244 | guid = "%s-%s" % (set_type, i) 245 | text_a = tokenization.convert_to_unicode(line[idx_text]) 246 | label = tokenization.convert_to_unicode(line[idx_label]) 247 | examples.append( 248 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 249 | return examples 250 | 251 | 252 | def convert_single_example(ex_index, example, label_list, max_seq_length, 253 | tokenizer): 254 | """Converts a single `InputExample` into a single `InputFeatures`.""" 255 | 256 | if isinstance(example, PaddingInputExample): 257 | return InputFeatures( 258 | input_ids=[0] * max_seq_length, 259 | input_mask=[0] * max_seq_length, 260 | segment_ids=[0] * max_seq_length, 261 | label_id=0, 262 | is_real_example=False) 263 | 264 | label_map = {} 265 | for (i, label) in enumerate(label_list): 266 | label_map[label] = i 267 | 268 | tokens_a = tokenizer.tokenize(example.text_a) 269 | tokens_b = None 270 | if example.text_b: 271 
| tokens_b = tokenizer.tokenize(example.text_b) 272 | 273 | if tokens_b: 274 | # Modifies `tokens_a` and `tokens_b` in place so that the total 275 | # length is less than the specified length. 276 | # Account for [CLS], [SEP], [SEP] with "- 3" 277 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 278 | else: 279 | # Account for [CLS] and [SEP] with "- 2" 280 | if len(tokens_a) > max_seq_length - 2: 281 | tokens_a = tokens_a[0:(max_seq_length - 2)] 282 | 283 | # The convention in BERT is: 284 | # (a) For sequence pairs: 285 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 286 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 287 | # (b) For single sequences: 288 | # tokens: [CLS] the dog is hairy . [SEP] 289 | # type_ids: 0 0 0 0 0 0 0 290 | # 291 | # Where "type_ids" are used to indicate whether this is the first 292 | # sequence or the second sequence. The embedding vectors for `type=0` and 293 | # `type=1` were learned during pre-training and are added to the wordpiece 294 | # embedding vector (and position vector). This is not *strictly* necessary 295 | # since the [SEP] token unambiguously separates the sequences, but it makes 296 | # it easier for the model to learn the concept of sequences. 297 | # 298 | # For classification tasks, the first vector (corresponding to [CLS]) is 299 | # used as the "sentence vector". Note that this only makes sense because 300 | # the entire model is fine-tuned. 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | if tokens_b: 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 319 | 320 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 321 | # tokens are attended to. 322 | input_mask = [1] * len(input_ids) 323 | 324 | # Zero-pad up to the sequence length. 
325 | while len(input_ids) < max_seq_length: 326 | input_ids.append(0) 327 | input_mask.append(0) 328 | segment_ids.append(0) 329 | 330 | assert len(input_ids) == max_seq_length 331 | assert len(input_mask) == max_seq_length 332 | assert len(segment_ids) == max_seq_length 333 | 334 | label_id = label_map[example.label] 335 | if ex_index < 5: 336 | tf.logging.info("*** Example ***") 337 | tf.logging.info("guid: %s" % (example.guid)) 338 | tf.logging.info("tokens: %s" % " ".join( 339 | [tokenization.printable_text(x) for x in tokens])) 340 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 341 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 342 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 343 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 344 | 345 | feature = InputFeatures( 346 | input_ids=input_ids, 347 | input_mask=input_mask, 348 | segment_ids=segment_ids, 349 | label_id=label_id, 350 | is_real_example=True) 351 | return feature 352 | 353 | 354 | def file_based_convert_examples_to_features( 355 | examples, label_list, max_seq_length, tokenizer, output_file): 356 | """Convert a set of `InputExample`s to a TFRecord file.""" 357 | 358 | writer = tf.python_io.TFRecordWriter(output_file) 359 | 360 | for (ex_index, example) in enumerate(examples): 361 | if ex_index % 10000 == 0: 362 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 363 | 364 | feature = convert_single_example(ex_index, example, label_list, 365 | max_seq_length, tokenizer) 366 | 367 | def create_int_feature(values): 368 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 369 | return f 370 | 371 | features = collections.OrderedDict() 372 | features["input_ids"] = create_int_feature(feature.input_ids) 373 | features["input_mask"] = create_int_feature(feature.input_mask) 374 | features["segment_ids"] = create_int_feature(feature.segment_ids) 375 | features["label_ids"] = create_int_feature([feature.label_id]) 376 | features["is_real_example"] = create_int_feature( 377 | [int(feature.is_real_example)]) 378 | 379 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 380 | writer.write(tf_example.SerializeToString()) 381 | writer.close() 382 | 383 | 384 | def file_based_input_fn_builder(input_file, seq_length, is_training, 385 | drop_remainder): 386 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 387 | 388 | name_to_features = { 389 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 390 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 391 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 392 | "label_ids": tf.FixedLenFeature([], tf.int64), 393 | "is_real_example": tf.FixedLenFeature([], tf.int64), 394 | } 395 | 396 | def _decode_record(record, name_to_features): 397 | """Decodes a record to a TensorFlow example.""" 398 | example = tf.parse_single_example(record, name_to_features) 399 | 400 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 401 | # So cast all int64 to int32. 402 | for name in list(example.keys()): 403 | t = example[name] 404 | if t.dtype == tf.int64: 405 | t = tf.to_int32(t) 406 | example[name] = t 407 | 408 | return example 409 | 410 | def input_fn(params): 411 | """The actual input function.""" 412 | batch_size = params["batch_size"] 413 | 414 | # For training, we want a lot of parallel reading and shuffling. 
415 | # For eval, we want no shuffling and parallel reading doesn't matter. 416 | d = tf.data.TFRecordDataset(input_file) 417 | if is_training: 418 | d = d.repeat() 419 | d = d.shuffle(buffer_size=100) 420 | 421 | d = d.apply( 422 | tf.contrib.data.map_and_batch( 423 | lambda record: _decode_record(record, name_to_features), 424 | batch_size=batch_size, 425 | drop_remainder=drop_remainder)) 426 | 427 | return d 428 | 429 | return input_fn 430 | 431 | 432 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 433 | """Truncates a sequence pair in place to the maximum length.""" 434 | 435 | # This is a simple heuristic which will always truncate the longer sequence 436 | # one token at a time. This makes more sense than truncating an equal percent 437 | # of tokens from each, since if one sequence is very short then each token 438 | # that's truncated likely contains more information than a longer sequence. 439 | while True: 440 | total_length = len(tokens_a) + len(tokens_b) 441 | if total_length <= max_length: 442 | break 443 | if len(tokens_a) > len(tokens_b): 444 | tokens_a.pop() 445 | else: 446 | tokens_b.pop() 447 | 448 | 449 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 450 | labels, num_labels, use_one_hot_embeddings): 451 | """Creates a classification model.""" 452 | model = modeling.BertModel( 453 | config=bert_config, 454 | is_training=is_training, 455 | input_ids=input_ids, 456 | input_mask=input_mask, 457 | token_type_ids=segment_ids, 458 | use_one_hot_embeddings=use_one_hot_embeddings) 459 | 460 | # In the demo, we are doing a simple classification task on the entire 461 | # segment. 462 | # 463 | # If you want to use the token-level output, use model.get_sequence_output() 464 | # instead. 465 | output_layer = model.get_pooled_output() 466 | 467 | hidden_size = output_layer.shape[-1].value 468 | 469 | output_weights = tf.get_variable( 470 | "output_weights", [num_labels, hidden_size], 471 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 472 | 473 | output_bias = tf.get_variable( 474 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 475 | 476 | with tf.variable_scope("loss"): 477 | if is_training: 478 | # I.e., 0.1 dropout 479 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 480 | 481 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 482 | logits = tf.nn.bias_add(logits, output_bias) 483 | probabilities = tf.nn.softmax(logits, axis=-1) 484 | log_probs = tf.nn.log_softmax(logits, axis=-1) 485 | 486 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 487 | 488 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 489 | loss = tf.reduce_mean(per_example_loss) 490 | 491 | return (loss, per_example_loss, logits, probabilities) 492 | 493 | 494 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 495 | num_train_steps, num_warmup_steps, use_tpu, 496 | use_one_hot_embeddings): 497 | """Returns `model_fn` closure for TPUEstimator.""" 498 | 499 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 500 | """The `model_fn` for TPUEstimator.""" 501 | 502 | tf.logging.info("*** Features ***") 503 | for name in sorted(features.keys()): 504 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 505 | 506 | input_ids = features["input_ids"] 507 | input_mask = features["input_mask"] 508 | segment_ids = features["segment_ids"] 509 | label_ids = features["label_ids"] 510 | is_real_example = 
None 511 | if "is_real_example" in features: 512 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 513 | else: 514 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 515 | 516 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 517 | 518 | (total_loss, per_example_loss, logits, probabilities) = create_model( 519 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 520 | num_labels, use_one_hot_embeddings) 521 | 522 | tvars = tf.trainable_variables() 523 | initialized_variable_names = {} 524 | scaffold_fn = None 525 | if init_checkpoint: 526 | (assignment_map, initialized_variable_names 527 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 528 | if use_tpu: 529 | 530 | def tpu_scaffold(): 531 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 532 | return tf.train.Scaffold() 533 | 534 | scaffold_fn = tpu_scaffold 535 | else: 536 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 537 | 538 | tf.logging.info("**** Trainable Variables ****") 539 | for var in tvars: 540 | init_string = "" 541 | if var.name in initialized_variable_names: 542 | init_string = ", *INIT_FROM_CKPT*" 543 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 544 | init_string) 545 | 546 | output_spec = None 547 | if mode == tf.estimator.ModeKeys.TRAIN: 548 | 549 | train_op = optimization.create_optimizer( 550 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 551 | 552 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 553 | mode=mode, 554 | loss=total_loss, 555 | train_op=train_op, 556 | scaffold_fn=scaffold_fn) 557 | elif mode == tf.estimator.ModeKeys.EVAL: 558 | 559 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 560 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 561 | accuracy = tf.metrics.accuracy( 562 | labels=label_ids, predictions=predictions, weights=is_real_example) 563 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 564 | return { 565 | "eval_accuracy": accuracy, 566 | "eval_loss": loss, 567 | } 568 | 569 | eval_metrics = (metric_fn, 570 | [per_example_loss, label_ids, logits, is_real_example]) 571 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 572 | mode=mode, 573 | loss=total_loss, 574 | eval_metrics=eval_metrics, 575 | scaffold_fn=scaffold_fn) 576 | else: 577 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 578 | mode=mode, 579 | predictions={"probabilities": probabilities}, 580 | scaffold_fn=scaffold_fn) 581 | return output_spec 582 | 583 | return model_fn 584 | 585 | 586 | # This function is not used by this file but is still used by the Colab and 587 | # people who depend on it. 588 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 589 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 590 | 591 | all_input_ids = [] 592 | all_input_mask = [] 593 | all_segment_ids = [] 594 | all_label_ids = [] 595 | 596 | for feature in features: 597 | all_input_ids.append(feature.input_ids) 598 | all_input_mask.append(feature.input_mask) 599 | all_segment_ids.append(feature.segment_ids) 600 | all_label_ids.append(feature.label_id) 601 | 602 | def input_fn(params): 603 | """The actual input function.""" 604 | batch_size = params["batch_size"] 605 | 606 | num_examples = len(features) 607 | 608 | # This is for demo purposes and does NOT scale to large data sets. 
We do 609 | # not use Dataset.from_generator() because that uses tf.py_func which is 610 | # not TPU compatible. The right way to load data is with TFRecordReader. 611 | d = tf.data.Dataset.from_tensor_slices({ 612 | "input_ids": 613 | tf.constant( 614 | all_input_ids, shape=[num_examples, seq_length], 615 | dtype=tf.int32), 616 | "input_mask": 617 | tf.constant( 618 | all_input_mask, 619 | shape=[num_examples, seq_length], 620 | dtype=tf.int32), 621 | "segment_ids": 622 | tf.constant( 623 | all_segment_ids, 624 | shape=[num_examples, seq_length], 625 | dtype=tf.int32), 626 | "label_ids": 627 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 628 | }) 629 | 630 | if is_training: 631 | d = d.repeat() 632 | d = d.shuffle(buffer_size=100) 633 | 634 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 635 | return d 636 | 637 | return input_fn 638 | 639 | 640 | # This function is not used by this file but is still used by the Colab and 641 | # people who depend on it. 642 | def convert_examples_to_features(examples, label_list, max_seq_length, 643 | tokenizer): 644 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 645 | 646 | features = [] 647 | for (ex_index, example) in enumerate(examples): 648 | if ex_index % 10000 == 0: 649 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 650 | 651 | feature = convert_single_example(ex_index, example, label_list, 652 | max_seq_length, tokenizer) 653 | 654 | features.append(feature) 655 | return features 656 | 657 | 658 | def main(_): 659 | tf.logging.set_verbosity(tf.logging.INFO) 660 | 661 | processors = { 662 | "livedoor": LivedoorProcessor, 663 | } 664 | 665 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 666 | FLAGS.init_checkpoint) 667 | 668 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 669 | raise ValueError( 670 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 671 | 672 | bert_config = modeling.BertConfig.from_json_file(bert_config_file.name) 673 | 674 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 675 | raise ValueError( 676 | "Cannot use sequence length %d because the BERT model " 677 | "was only trained up to sequence length %d" % 678 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 679 | 680 | tf.gfile.MakeDirs(FLAGS.output_dir) 681 | 682 | task_name = FLAGS.task_name.lower() 683 | 684 | if task_name not in processors: 685 | raise ValueError("Task not found: %s" % (task_name)) 686 | 687 | processor = processors[task_name]() 688 | 689 | label_list = processor.get_labels() 690 | 691 | tokenizer = tokenization.FullTokenizer( 692 | model_file=FLAGS.model_file, vocab_file=FLAGS.vocab_file, 693 | do_lower_case=FLAGS.do_lower_case) 694 | 695 | tpu_cluster_resolver = None 696 | if FLAGS.use_tpu and FLAGS.tpu_name: 697 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 698 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 699 | 700 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 701 | run_config = tf.contrib.tpu.RunConfig( 702 | cluster=tpu_cluster_resolver, 703 | master=FLAGS.master, 704 | model_dir=FLAGS.output_dir, 705 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 706 | tpu_config=tf.contrib.tpu.TPUConfig( 707 | iterations_per_loop=FLAGS.iterations_per_loop, 708 | num_shards=FLAGS.num_tpu_cores, 709 | per_host_input_for_training=is_per_host)) 710 | 711 | train_examples = None 712 | num_train_steps = None 713 | 
num_warmup_steps = None 714 | if FLAGS.do_train: 715 | train_examples = processor.get_train_examples(FLAGS.data_dir) 716 | num_train_steps = int( 717 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 718 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 719 | 720 | model_fn = model_fn_builder( 721 | bert_config=bert_config, 722 | num_labels=len(label_list), 723 | init_checkpoint=FLAGS.init_checkpoint, 724 | learning_rate=FLAGS.learning_rate, 725 | num_train_steps=num_train_steps, 726 | num_warmup_steps=num_warmup_steps, 727 | use_tpu=FLAGS.use_tpu, 728 | use_one_hot_embeddings=FLAGS.use_tpu) 729 | 730 | # If TPU is not available, this will fall back to normal Estimator on CPU 731 | # or GPU. 732 | estimator = tf.contrib.tpu.TPUEstimator( 733 | use_tpu=FLAGS.use_tpu, 734 | model_fn=model_fn, 735 | config=run_config, 736 | train_batch_size=FLAGS.train_batch_size, 737 | eval_batch_size=FLAGS.eval_batch_size, 738 | predict_batch_size=FLAGS.predict_batch_size) 739 | 740 | if FLAGS.do_train: 741 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 742 | file_based_convert_examples_to_features( 743 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 744 | tf.logging.info("***** Running training *****") 745 | tf.logging.info(" Num examples = %d", len(train_examples)) 746 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 747 | tf.logging.info(" Num steps = %d", num_train_steps) 748 | train_input_fn = file_based_input_fn_builder( 749 | input_file=train_file, 750 | seq_length=FLAGS.max_seq_length, 751 | is_training=True, 752 | drop_remainder=True) 753 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 754 | 755 | if FLAGS.do_eval: 756 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 757 | num_actual_eval_examples = len(eval_examples) 758 | if FLAGS.use_tpu: 759 | # TPU requires a fixed batch size for all batches, therefore the number 760 | # of examples must be a multiple of the batch size, or else examples 761 | # will get dropped. So we pad with fake examples which are ignored 762 | # later on. These do NOT count towards the metric (all tf.metrics 763 | # support a per-instance weight, and these get a weight of 0.0). 764 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 765 | eval_examples.append(PaddingInputExample()) 766 | 767 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 768 | file_based_convert_examples_to_features( 769 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 770 | 771 | tf.logging.info("***** Running evaluation *****") 772 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 773 | len(eval_examples), num_actual_eval_examples, 774 | len(eval_examples) - num_actual_eval_examples) 775 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 776 | 777 | # This tells the estimator to run through the entire set. 778 | eval_steps = None 779 | # However, if running eval on the TPU, you will need to specify the 780 | # number of steps. 
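# For illustration, assuming a hypothetical dev set of 1,470 examples and
# --eval_batch_size=8: two PaddingInputExamples get appended above, giving
# 1,472 examples, so the eval_steps computed below works out to 1472 // 8 = 184.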
781 | if FLAGS.use_tpu: 782 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 783 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 784 | 785 | eval_drop_remainder = True if FLAGS.use_tpu else False 786 | eval_input_fn = file_based_input_fn_builder( 787 | input_file=eval_file, 788 | seq_length=FLAGS.max_seq_length, 789 | is_training=False, 790 | drop_remainder=eval_drop_remainder) 791 | 792 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 793 | 794 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 795 | with tf.gfile.GFile(output_eval_file, "w") as writer: 796 | tf.logging.info("***** Eval results *****") 797 | for key in sorted(result.keys()): 798 | tf.logging.info(" %s = %s", key, str(result[key])) 799 | writer.write("%s = %s\n" % (key, str(result[key]))) 800 | 801 | if FLAGS.do_predict: 802 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 803 | num_actual_predict_examples = len(predict_examples) 804 | if FLAGS.use_tpu: 805 | # TPU requires a fixed batch size for all batches, therefore the number 806 | # of examples must be a multiple of the batch size, or else examples 807 | # will get dropped. So we pad with fake examples which are ignored 808 | # later on. 809 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 810 | predict_examples.append(PaddingInputExample()) 811 | 812 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 813 | file_based_convert_examples_to_features(predict_examples, label_list, 814 | FLAGS.max_seq_length, tokenizer, 815 | predict_file) 816 | 817 | tf.logging.info("***** Running prediction*****") 818 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 819 | len(predict_examples), num_actual_predict_examples, 820 | len(predict_examples) - num_actual_predict_examples) 821 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 822 | 823 | predict_drop_remainder = True if FLAGS.use_tpu else False 824 | predict_input_fn = file_based_input_fn_builder( 825 | input_file=predict_file, 826 | seq_length=FLAGS.max_seq_length, 827 | is_training=False, 828 | drop_remainder=predict_drop_remainder) 829 | 830 | result = estimator.predict(input_fn=predict_input_fn) 831 | 832 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 833 | with tf.gfile.GFile(output_predict_file, "w") as writer: 834 | num_written_lines = 0 835 | tf.logging.info("***** Predict results *****") 836 | for (i, prediction) in enumerate(result): 837 | probabilities = prediction["probabilities"] 838 | if i >= num_actual_predict_examples: 839 | break 840 | output_line = "\t".join( 841 | str(class_probability) 842 | for class_probability in probabilities) + "\n" 843 | writer.write(output_line) 844 | num_written_lines += 1 845 | assert num_written_lines == num_actual_predict_examples 846 | 847 | 848 | if __name__ == "__main__": 849 | flags.mark_flag_as_required("data_dir") 850 | flags.mark_flag_as_required("task_name") 851 | flags.mark_flag_as_required("model_file") 852 | flags.mark_flag_as_required("vocab_file") 853 | flags.mark_flag_as_required("output_dir") 854 | tf.app.run() 855 | -------------------------------------------------------------------------------- /src/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is based on https://github.com/google-research/bert/blob/master/run_pretraining.py. 3 | # It is changed to read model parameters from config.ini. 
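# For illustration only, the [BERT-CONFIG] section of config.ini is assumed to
# hold ordinary BertConfig fields as key/value pairs, e.g. (hypothetical values):
#
#   [BERT-CONFIG]
#   hidden_size = 768
#   num_hidden_layers = 12
#   num_attention_heads = 12
#   vocab_size = 32000
#
# Below, each value is converted with utils.str_to_value() and written to a
# temporary JSON file so that modeling.BertConfig.from_json_file() can read it.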
4 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import configparser 11 | import json 12 | import os 13 | import sys 14 | import tempfile 15 | import tensorflow as tf 16 | import utils 17 | 18 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 19 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 20 | config = configparser.ConfigParser() 21 | config.read(CONFIGPATH) 22 | bert_config_file = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8', suffix='.json') 23 | bert_config_file.write(json.dumps({k:utils.str_to_value(v) for k,v in config['BERT-CONFIG'].items()})) 24 | bert_config_file.seek(0) 25 | 26 | sys.path.append(os.path.join(CURDIR, os.pardir, 'bert')) 27 | import modeling 28 | import optimization 29 | 30 | 31 | flags = tf.flags 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | # Required parameters 36 | flags.DEFINE_string( 37 | "bert_config_file", None, 38 | "The config json file corresponding to the pre-trained BERT model. " 39 | "This specifies the model architecture.") 40 | 41 | flags.DEFINE_string( 42 | "input_file", None, 43 | "Input TF example files (can be a glob or comma separated).") 44 | 45 | flags.DEFINE_string( 46 | "output_dir", None, 47 | "The output directory where the model checkpoints will be written.") 48 | 49 | # Other parameters 50 | flags.DEFINE_string( 51 | "init_checkpoint", None, 52 | "Initial checkpoint (usually from a pre-trained BERT model).") 53 | 54 | flags.DEFINE_integer( 55 | "max_seq_length", 128, 56 | "The maximum total input sequence length after WordPiece tokenization. " 57 | "Sequences longer than this will be truncated, and sequences shorter " 58 | "than this will be padded. Must match data generation.") 59 | 60 | flags.DEFINE_integer( 61 | "max_predictions_per_seq", 20, 62 | "Maximum number of masked LM predictions per sequence. " 63 | "Must match data generation.") 64 | 65 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 66 | 67 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 68 | 69 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 70 | 71 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 72 | 73 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 74 | 75 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 76 | 77 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 78 | 79 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 80 | "How often to save the model checkpoint.") 81 | 82 | flags.DEFINE_integer("iterations_per_loop", 1000, 83 | "How many steps to make in each estimator call.") 84 | 85 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 86 | 87 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 88 | 89 | tf.flags.DEFINE_string( 90 | "tpu_name", None, 91 | "The Cloud TPU to use for training. This should be either the name " 92 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 93 | "url.") 94 | 95 | tf.flags.DEFINE_string( 96 | "tpu_zone", None, 97 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 98 | "specified, we will attempt to automatically detect the GCE project from " 99 | "metadata.") 100 | 101 | tf.flags.DEFINE_string( 102 | "gcp_project", None, 103 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 104 | "specified, we will attempt to automatically detect the GCE project from " 105 | "metadata.") 106 | 107 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 108 | 109 | flags.DEFINE_integer( 110 | "num_tpu_cores", 8, 111 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 112 | 113 | 114 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 115 | num_train_steps, num_warmup_steps, use_tpu, 116 | use_one_hot_embeddings): 117 | """Returns `model_fn` closure for TPUEstimator.""" 118 | 119 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 120 | """The `model_fn` for TPUEstimator.""" 121 | 122 | tf.logging.info("*** Features ***") 123 | for name in sorted(features.keys()): 124 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 125 | 126 | input_ids = features["input_ids"] 127 | input_mask = features["input_mask"] 128 | segment_ids = features["segment_ids"] 129 | masked_lm_positions = features["masked_lm_positions"] 130 | masked_lm_ids = features["masked_lm_ids"] 131 | masked_lm_weights = features["masked_lm_weights"] 132 | next_sentence_labels = features["next_sentence_labels"] 133 | 134 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 135 | 136 | model = modeling.BertModel( 137 | config=bert_config, 138 | is_training=is_training, 139 | input_ids=input_ids, 140 | input_mask=input_mask, 141 | token_type_ids=segment_ids, 142 | use_one_hot_embeddings=use_one_hot_embeddings) 143 | 144 | (masked_lm_loss, 145 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 146 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 147 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 148 | 149 | (next_sentence_loss, next_sentence_example_loss, 150 | next_sentence_log_probs) = get_next_sentence_output( 151 | bert_config, model.get_pooled_output(), next_sentence_labels) 152 | 153 | total_loss = masked_lm_loss + next_sentence_loss 154 | 155 | tvars = tf.trainable_variables() 156 | 157 | initialized_variable_names = {} 158 | scaffold_fn = None 159 | if init_checkpoint: 160 | (assignment_map, initialized_variable_names 161 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 162 | if use_tpu: 163 | 164 | def tpu_scaffold(): 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | return tf.train.Scaffold() 167 | 168 | scaffold_fn = tpu_scaffold 169 | else: 170 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 171 | 172 | tf.logging.info("**** Trainable Variables ****") 173 | for var in tvars: 174 | init_string = "" 175 | if var.name in initialized_variable_names: 176 | init_string = ", *INIT_FROM_CKPT*" 177 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 178 | init_string) 179 | 180 | output_spec = None 181 | if mode == tf.estimator.ModeKeys.TRAIN: 182 | train_op = optimization.create_optimizer( 183 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 184 | 185 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 186 | mode=mode, 187 | loss=total_loss, 188 | train_op=train_op, 189 | scaffold_fn=scaffold_fn) 190 | elif mode == tf.estimator.ModeKeys.EVAL: 191 | 192 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, 
masked_lm_ids, 193 | masked_lm_weights, next_sentence_example_loss, 194 | next_sentence_log_probs, next_sentence_labels): 195 | """Computes the loss and accuracy of the model.""" 196 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 197 | [-1, masked_lm_log_probs.shape[-1]]) 198 | masked_lm_predictions = tf.argmax( 199 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 200 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 201 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 202 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 203 | masked_lm_accuracy = tf.metrics.accuracy( 204 | labels=masked_lm_ids, 205 | predictions=masked_lm_predictions, 206 | weights=masked_lm_weights) 207 | masked_lm_mean_loss = tf.metrics.mean( 208 | values=masked_lm_example_loss, weights=masked_lm_weights) 209 | 210 | next_sentence_log_probs = tf.reshape( 211 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 212 | next_sentence_predictions = tf.argmax( 213 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 214 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 215 | next_sentence_accuracy = tf.metrics.accuracy( 216 | labels=next_sentence_labels, predictions=next_sentence_predictions) 217 | next_sentence_mean_loss = tf.metrics.mean( 218 | values=next_sentence_example_loss) 219 | 220 | return { 221 | "masked_lm_accuracy": masked_lm_accuracy, 222 | "masked_lm_loss": masked_lm_mean_loss, 223 | "next_sentence_accuracy": next_sentence_accuracy, 224 | "next_sentence_loss": next_sentence_mean_loss, 225 | } 226 | 227 | eval_metrics = (metric_fn, [ 228 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 229 | masked_lm_weights, next_sentence_example_loss, 230 | next_sentence_log_probs, next_sentence_labels 231 | ]) 232 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 233 | mode=mode, 234 | loss=total_loss, 235 | eval_metrics=eval_metrics, 236 | scaffold_fn=scaffold_fn) 237 | else: 238 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 239 | 240 | return output_spec 241 | 242 | return model_fn 243 | 244 | 245 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 246 | label_ids, label_weights): 247 | """Get loss and log probs for the masked LM.""" 248 | input_tensor = gather_indexes(input_tensor, positions) 249 | 250 | with tf.variable_scope("cls/predictions"): 251 | # We apply one more non-linear transformation before the output layer. 252 | # This matrix is not used after pre-training. 253 | with tf.variable_scope("transform"): 254 | input_tensor = tf.layers.dense( 255 | input_tensor, 256 | units=bert_config.hidden_size, 257 | activation=modeling.get_activation(bert_config.hidden_act), 258 | kernel_initializer=modeling.create_initializer( 259 | bert_config.initializer_range)) 260 | input_tensor = modeling.layer_norm(input_tensor) 261 | 262 | # The output weights are the same as the input embeddings, but there is 263 | # an output-only bias for each token. 
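# Shape sketch: after gather_indexes() and the transform above, input_tensor is
# [batch_size * max_predictions_per_seq, hidden_size] and output_weights is the
# embedding table of shape [vocab_size, hidden_size], so the matmul with
# transpose_b=True yields logits of shape
# [batch_size * max_predictions_per_seq, vocab_size].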
264 | output_bias = tf.get_variable( 265 | "output_bias", 266 | shape=[bert_config.vocab_size], 267 | initializer=tf.zeros_initializer()) 268 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 269 | logits = tf.nn.bias_add(logits, output_bias) 270 | log_probs = tf.nn.log_softmax(logits, axis=-1) 271 | 272 | label_ids = tf.reshape(label_ids, [-1]) 273 | label_weights = tf.reshape(label_weights, [-1]) 274 | 275 | one_hot_labels = tf.one_hot( 276 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 277 | 278 | # The `positions` tensor might be zero-padded (if the sequence is too 279 | # short to have the maximum number of predictions). The `label_weights` 280 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 281 | # padding predictions. 282 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 283 | numerator = tf.reduce_sum(label_weights * per_example_loss) 284 | denominator = tf.reduce_sum(label_weights) + 1e-5 285 | loss = numerator / denominator 286 | 287 | return (loss, per_example_loss, log_probs) 288 | 289 | 290 | def get_next_sentence_output(bert_config, input_tensor, labels): 291 | """Get loss and log probs for the next sentence prediction.""" 292 | 293 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 294 | # "random sentence". This weight matrix is not used after pre-training. 295 | with tf.variable_scope("cls/seq_relationship"): 296 | output_weights = tf.get_variable( 297 | "output_weights", 298 | shape=[2, bert_config.hidden_size], 299 | initializer=modeling.create_initializer(bert_config.initializer_range)) 300 | output_bias = tf.get_variable( 301 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 302 | 303 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 304 | logits = tf.nn.bias_add(logits, output_bias) 305 | log_probs = tf.nn.log_softmax(logits, axis=-1) 306 | labels = tf.reshape(labels, [-1]) 307 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 308 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 309 | loss = tf.reduce_mean(per_example_loss) 310 | return (loss, per_example_loss, log_probs) 311 | 312 | 313 | def gather_indexes(sequence_tensor, positions): 314 | """Gathers the vectors at the specific positions over a minibatch.""" 315 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 316 | batch_size = sequence_shape[0] 317 | seq_length = sequence_shape[1] 318 | width = sequence_shape[2] 319 | 320 | flat_offsets = tf.reshape( 321 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 322 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 323 | flat_sequence_tensor = tf.reshape(sequence_tensor, 324 | [batch_size * seq_length, width]) 325 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 326 | return output_tensor 327 | 328 | 329 | def input_fn_builder(input_files, 330 | max_seq_length, 331 | max_predictions_per_seq, 332 | is_training, 333 | num_cpu_threads=4): 334 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 335 | 336 | def input_fn(params): 337 | """The actual input function.""" 338 | batch_size = params["batch_size"] 339 | 340 | name_to_features = { 341 | "input_ids": 342 | tf.FixedLenFeature([max_seq_length], tf.int64), 343 | "input_mask": 344 | tf.FixedLenFeature([max_seq_length], tf.int64), 345 | "segment_ids": 346 | tf.FixedLenFeature([max_seq_length], tf.int64), 347 | "masked_lm_positions": 348 | 
tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 349 | "masked_lm_ids": 350 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 351 | "masked_lm_weights": 352 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 353 | "next_sentence_labels": 354 | tf.FixedLenFeature([1], tf.int64), 355 | } 356 | 357 | # For training, we want a lot of parallel reading and shuffling. 358 | # For eval, we want no shuffling and parallel reading doesn't matter. 359 | if is_training: 360 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 361 | d = d.repeat() 362 | d = d.shuffle(buffer_size=len(input_files)) 363 | 364 | # `cycle_length` is the number of parallel files that get read. 365 | cycle_length = min(num_cpu_threads, len(input_files)) 366 | 367 | # `sloppy` mode means that the interleaving is not exact. This adds 368 | # even more randomness to the training pipeline. 369 | d = d.apply( 370 | tf.contrib.data.parallel_interleave( 371 | tf.data.TFRecordDataset, 372 | sloppy=is_training, 373 | cycle_length=cycle_length)) 374 | d = d.shuffle(buffer_size=100) 375 | else: 376 | d = tf.data.TFRecordDataset(input_files) 377 | # Since we evaluate for a fixed number of steps we don't want to encounter 378 | # out-of-range exceptions. 379 | d = d.repeat() 380 | 381 | # We must `drop_remainder` on training because the TPU requires fixed 382 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 383 | # and we *don't* want to drop the remainder, otherwise we wont cover 384 | # every sample. 385 | d = d.apply( 386 | tf.contrib.data.map_and_batch( 387 | lambda record: _decode_record(record, name_to_features), 388 | batch_size=batch_size, 389 | num_parallel_batches=num_cpu_threads, 390 | drop_remainder=True)) 391 | return d 392 | 393 | return input_fn 394 | 395 | 396 | def _decode_record(record, name_to_features): 397 | """Decodes a record to a TensorFlow example.""" 398 | example = tf.parse_single_example(record, name_to_features) 399 | 400 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 401 | # So cast all int64 to int32. 
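# For example, the parsed "input_ids" feature of shape [max_seq_length] arrives
# as tf.int64 and is cast to tf.int32 in the loop below, while the float32
# "masked_lm_weights" feature passes through unchanged.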
402 | for name in list(example.keys()): 403 | t = example[name] 404 | if t.dtype == tf.int64: 405 | t = tf.to_int32(t) 406 | example[name] = t 407 | 408 | return example 409 | 410 | 411 | def main(_): 412 | tf.logging.set_verbosity(tf.logging.INFO) 413 | 414 | if not FLAGS.do_train and not FLAGS.do_eval: 415 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 416 | 417 | bert_config = modeling.BertConfig.from_json_file(bert_config_file.name) 418 | 419 | tf.gfile.MakeDirs(FLAGS.output_dir) 420 | 421 | input_files = [] 422 | for input_pattern in FLAGS.input_file.split(","): 423 | input_files.extend(tf.gfile.Glob(input_pattern)) 424 | 425 | tf.logging.info("*** Input Files ***") 426 | for input_file in input_files: 427 | tf.logging.info(" %s" % input_file) 428 | 429 | tpu_cluster_resolver = None 430 | if FLAGS.use_tpu and FLAGS.tpu_name: 431 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 432 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 433 | 434 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 435 | run_config = tf.contrib.tpu.RunConfig( 436 | cluster=tpu_cluster_resolver, 437 | master=FLAGS.master, 438 | model_dir=FLAGS.output_dir, 439 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 440 | tpu_config=tf.contrib.tpu.TPUConfig( 441 | iterations_per_loop=FLAGS.iterations_per_loop, 442 | num_shards=FLAGS.num_tpu_cores, 443 | per_host_input_for_training=is_per_host)) 444 | 445 | model_fn = model_fn_builder( 446 | bert_config=bert_config, 447 | init_checkpoint=FLAGS.init_checkpoint, 448 | learning_rate=FLAGS.learning_rate, 449 | num_train_steps=FLAGS.num_train_steps, 450 | num_warmup_steps=FLAGS.num_warmup_steps, 451 | use_tpu=FLAGS.use_tpu, 452 | use_one_hot_embeddings=FLAGS.use_tpu) 453 | 454 | # If TPU is not available, this will fall back to normal Estimator on CPU 455 | # or GPU. 
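# Either way, input_fn_builder's input_fn reads its batch size from
# params["batch_size"], which TPUEstimator fills in from train_batch_size or
# eval_batch_size depending on the mode.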
456 | estimator = tf.contrib.tpu.TPUEstimator( 457 | use_tpu=FLAGS.use_tpu, 458 | model_fn=model_fn, 459 | config=run_config, 460 | train_batch_size=FLAGS.train_batch_size, 461 | eval_batch_size=FLAGS.eval_batch_size) 462 | 463 | if FLAGS.do_train: 464 | tf.logging.info("***** Running training *****") 465 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 466 | train_input_fn = input_fn_builder( 467 | input_files=input_files, 468 | max_seq_length=FLAGS.max_seq_length, 469 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 470 | is_training=True) 471 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 472 | 473 | if FLAGS.do_eval: 474 | tf.logging.info("***** Running evaluation *****") 475 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 476 | 477 | eval_input_fn = input_fn_builder( 478 | input_files=input_files, 479 | max_seq_length=FLAGS.max_seq_length, 480 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 481 | is_training=False) 482 | 483 | result = estimator.evaluate( 484 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 485 | 486 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 487 | with tf.gfile.GFile(output_eval_file, "w") as writer: 488 | tf.logging.info("***** Eval results *****") 489 | for key in sorted(result.keys()): 490 | tf.logging.info(" %s = %s", key, str(result[key])) 491 | writer.write("%s = %s\n" % (key, str(result[key]))) 492 | 493 | 494 | if __name__ == "__main__": 495 | flags.mark_flag_as_required("input_file") 496 | flags.mark_flag_as_required("output_dir") 497 | tf.app.run() 498 | -------------------------------------------------------------------------------- /src/tokenization_sentencepiece.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is based on https://github.com/google-research/bert/blob/master/tokenization.py. 3 | # It is changed to use SentencePiece tokenizer for tokenizations. 4 | """Tokenization classes.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import collections 11 | import re 12 | import unicodedata 13 | import sentencepiece as sp 14 | import six 15 | import tensorflow as tf 16 | 17 | 18 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 19 | """Checks whether the casing config is consistent with the checkpoint name.""" 20 | 21 | # The casing has to be passed in by the user and there is no explicit check 22 | # as to whether it matches the checkpoint. The casing information probably 23 | # should have been stored in the bert_config.json file, but it's not, so 24 | # we have to heuristically detect it to validate. 
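# For example, a hypothetical init_checkpoint of
# "model/uncased_L-12_H-768_A-12/bert_model.ckpt" matches the regex below with
# model_name = "uncased_L-12_H-768_A-12", which must be paired with
# --do_lower_case=True or a ValueError is raised.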
25 | 26 | if not init_checkpoint: 27 | return 28 | 29 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 30 | if m is None: 31 | return 32 | 33 | model_name = m.group(1) 34 | 35 | lower_models = [ 36 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 37 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 38 | ] 39 | 40 | cased_models = [ 41 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 42 | "multi_cased_L-12_H-768_A-12" 43 | ] 44 | 45 | is_bad_config = False 46 | if model_name in lower_models and not do_lower_case: 47 | is_bad_config = True 48 | actual_flag = "False" 49 | case_name = "lowercased" 50 | opposite_flag = "True" 51 | 52 | if model_name in cased_models and do_lower_case: 53 | is_bad_config = True 54 | actual_flag = "True" 55 | case_name = "cased" 56 | opposite_flag = "False" 57 | 58 | if is_bad_config: 59 | raise ValueError( 60 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 61 | "However, `%s` seems to be a %s model, so you " 62 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 63 | "how the model was pre-training. If this error is wrong, please " 64 | "just comment out this check." % (actual_flag, init_checkpoint, 65 | model_name, case_name, opposite_flag)) 66 | 67 | 68 | def convert_to_unicode(text): 69 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 70 | if six.PY3: 71 | if isinstance(text, str): 72 | return text 73 | elif isinstance(text, bytes): 74 | return text.decode("utf-8", "ignore") 75 | else: 76 | raise ValueError("Unsupported string type: %s" % (type(text))) 77 | elif six.PY2: 78 | if isinstance(text, str): 79 | return text.decode("utf-8", "ignore") 80 | elif isinstance(text, unicode): 81 | return text 82 | else: 83 | raise ValueError("Unsupported string type: %s" % (type(text))) 84 | else: 85 | raise ValueError("Not running on Python2 or Python 3?") 86 | 87 | 88 | def printable_text(text): 89 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 90 | 91 | # These functions want `str` for both Python2 and Python3, but in one case 92 | # it's a Unicode string and in the other it's a byte string. 
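# For example, on Python 3 printable_text(b"\xe6\x97\xa5") decodes the UTF-8
# bytes and returns "日", while printable_text("日") is returned unchanged.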
93 | if six.PY3: 94 | if isinstance(text, str): 95 | return text 96 | elif isinstance(text, bytes): 97 | return text.decode("utf-8", "ignore") 98 | else: 99 | raise ValueError("Unsupported string type: %s" % (type(text))) 100 | elif six.PY2: 101 | if isinstance(text, str): 102 | return text 103 | elif isinstance(text, unicode): 104 | return text.encode("utf-8") 105 | else: 106 | raise ValueError("Unsupported string type: %s" % (type(text))) 107 | else: 108 | raise ValueError("Not running on Python2 or Python 3?") 109 | 110 | 111 | def load_vocab(vocab_file): 112 | """Loads a vocabulary file into a dictionary.""" 113 | vocab = collections.OrderedDict() 114 | index = 0 115 | with tf.gfile.GFile(vocab_file, "r") as reader: 116 | while True: 117 | token = convert_to_unicode(reader.readline()) 118 | if not token: 119 | break 120 | token, _ = token.split("\t") 121 | token = token.strip() 122 | vocab[token] = index 123 | index += 1 124 | return vocab 125 | 126 | 127 | def convert_by_vocab(vocab, items, unk_info): 128 | """Converts a sequence of [tokens|ids] using the vocab.""" 129 | output = [] 130 | for item in items: 131 | if item in vocab: 132 | output.append(vocab[item]) 133 | else: 134 | output.append(unk_info) 135 | return output 136 | 137 | 138 | def convert_tokens_to_ids(vocab, tokens): 139 | """Id of <unk> is assumed as 0 according to sentencepiece""" 140 | return convert_by_vocab(vocab, tokens, unk_info=0) 141 | 142 | 143 | def convert_ids_to_tokens(inv_vocab, ids): 144 | """Token of unknown word is assumed as <unk> according to sentencepiece""" 145 | return convert_by_vocab(inv_vocab, ids, unk_info="<unk>") 146 | 147 | 148 | class FullTokenizer(object): 149 | """Runs end-to-end tokenization.""" 150 | 151 | def __init__(self, model_file, vocab_file, do_lower_case=True): 152 | self.tokenizer = SentencePieceTokenizer(model_file, do_lower_case=do_lower_case) 153 | self.vocab = load_vocab(vocab_file) 154 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 155 | 156 | def tokenize(self, text): 157 | split_tokens = self.tokenizer.tokenize(text) 158 | return split_tokens 159 | 160 | def convert_tokens_to_ids(self, tokens): 161 | """Id of <unk> is assumed as 0 according to sentencepiece""" 162 | return convert_by_vocab(self.vocab, tokens, unk_info=0) 163 | 164 | def convert_ids_to_tokens(self, ids): 165 | """Token of unknown word is assumed as <unk> according to sentencepiece""" 166 | return convert_by_vocab(self.inv_vocab, ids, unk_info="<unk>") 167 | 168 | 169 | class SentencePieceTokenizer(object): 170 | """Runs SentencePiece tokenization (from raw text to tokens list)""" 171 | 172 | def __init__(self, model_file=None, do_lower_case=True): 173 | """Constructs a SentencePieceTokenizer.""" 174 | self.tokenizer = sp.SentencePieceProcessor() 175 | if self.tokenizer.Load(model_file): 176 | print("Loaded a trained SentencePiece model.") 177 | else: 178 | print("You have to give a path of trained SentencePiece model.") 179 | raise ValueError("Failed to load the SentencePiece model file: %s" % model_file) 180 | self.do_lower_case = do_lower_case 181 | 182 | def tokenize(self, text): 183 | """Tokenizes a piece of text.""" 184 | text = convert_to_unicode(text) 185 | if self.do_lower_case: 186 | text = text.lower() 187 | output_tokens = self.tokenizer.EncodeAsPieces(text) 188 | return output_tokens 189 | -------------------------------------------------------------------------------- /src/train-sentencepiece.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import configparser 4 | import glob 5 | import os 6 | import
sentencepiece as sp 7 | 8 | CURDIR = os.path.dirname(os.path.abspath(__file__)) 9 | CONFIGPATH = os.path.join(CURDIR, os.pardir, 'config.ini') 10 | config = configparser.ConfigParser() 11 | config.read(CONFIGPATH) 12 | 13 | TEXTDIR = config['DATA']['TEXTDIR'] 14 | PREFIX = config['SENTENCEPIECE']['PREFIX'] 15 | VOCABSIZE = config['SENTENCEPIECE']['VOCABSIZE'] 16 | CTLSYMBOLS = config['SENTENCEPIECE']['CTLSYMBOLS'] 17 | 18 | 19 | def _get_text_file(text_dir=TEXTDIR): 20 | file_list = glob.glob(f'{text_dir}/**/*.txt') 21 | files = ",".join(file_list) 22 | return files 23 | 24 | 25 | def train(prefix=PREFIX, vocab_size=VOCABSIZE, ctl_symbols=CTLSYMBOLS): 26 | files = _get_text_file() 27 | command = f'--input={files} --model_prefix={prefix} --vocab_size={vocab_size} --control_symbols={ctl_symbols}' 28 | sp.SentencePieceTrainer.Train(command) 29 | 30 | 31 | def main(): 32 | train() 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | def str_to_value(input_str): 2 | """ 3 | Convert data type of value of dict to appropriate one. 4 | Assume there are only three types: str, int, float. 5 | """ 6 | if input_str.isalpha(): 7 | return input_str 8 | elif input_str.isdigit(): 9 | return int(input_str) 10 | else: 11 | return float(input_str) 12 | --------------------------------------------------------------------------------