├── README.md
├── LICENSE
├── src
│   ├── inference.py
│   ├── eval.py
│   └── train.py
├── bert-japanese-ner-finetuning-kyoto.ipynb
└── bert-japanese-ner-finetuning-tohoku.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # bert-japanese-ner-finetuning
2 | Code to fine-tune Japanese BERT models for named entity recognition (NER).
3 | Sample code for building and using an NER model by fine-tuning a pre-trained Japanese BERT model.
4 | 
5 | ## Based on Kyoto
6 | `bert-japanese-ner-finetuning-kyoto.ipynb` fine-tunes the BERT model released by Kyoto University.
7 | See [this article](https://zenn.dev/ken_11/articles/ca61812791c4d9) for details.
8 | 
9 | ## Based on Tohoku
10 | `bert-japanese-ner-finetuning-tohoku.ipynb` fine-tunes the BERT model released by Tohoku University (`cl-tohoku/bert-base-japanese-whole-word-masking`).
11 | 
12 | ## src directory
13 | The Tohoku-BERT-based fine-tuning above, reorganized as plain Python code (`train.py`, `eval.py`, `inference.py`) instead of a notebook.
14 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 ken
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from transformers import BertJapaneseTokenizer, BertForTokenClassification 4 | 5 | 6 | class Inference: 7 | def __init__(self, args): 8 | self.tokenizer = BertJapaneseTokenizer.from_pretrained(args.model_dir) 9 | self.model = BertForTokenClassification.from_pretrained(args.model_dir) 10 | 11 | def run(self, text): 12 | tokenized_text = self.tokenizer.tokenize(text) 13 | inputs = self.tokenizer(tokenized_text, return_tensors="pt", padding='max_length', truncation=True, max_length=64, is_split_into_words=True) 14 | pred = self.model(**inputs).logits[0] 15 | pred = np.argmax(pred.detach().numpy(), axis=-1) 16 | labels = [] 17 | for i, label in enumerate(pred): 18 | if i + 1 > len(tokenized_text): 19 | continue 20 | labels.append(self.model.config.id2label[label]) 21 | print(tokenized_text) 22 | print(labels) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument("--model-dir", type=str, default="./dest") 29 | parser.add_argument("--input-text", type=str, default="田中さんの会社の社長は鈴木さんです") 30 | 31 | args, _ = parser.parse_known_args() 32 | evaluate = Inference(args) 33 | evaluate.run(args.input_text) 34 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | import noyaki 5 | from seqeval.metrics import classification_report 6 | from seqeval.scheme import BILOU 7 | from transformers import BertJapaneseTokenizer, BertForTokenClassification 8 | 9 | 10 | class Eval: 11 | def __init__(self, args): 12 | self.tokenizer = BertJapaneseTokenizer.from_pretrained(args.model_dir) 13 | self.model = BertForTokenClassification.from_pretrained(args.model_dir) 14 | self.args = args 15 | 16 | def run(self): 17 | features = self._load_from_json(self.args.training_data_path) 18 | y_true = [] 19 | for unit in features: 20 | y_true.append(unit["y"][:64]) 21 | y_pred = [] 22 | for unit in features: 23 | y_pred.append(self._inference(unit["x"])) 24 | print(classification_report(y_true, y_pred, mode='strict', scheme=BILOU)) 25 | 26 | def _inference(self, tokenized_text: list) -> list: 27 | inputs = self.tokenizer(tokenized_text, return_tensors="pt", padding='max_length', truncation=True, max_length=64, is_split_into_words=True) 28 | pred = self.model(**inputs).logits[0] 29 | pred = np.argmax(pred.detach().numpy(), axis=-1) 30 | labels = [] 31 | for i, label in enumerate(pred): 32 | if i + 1 > len(tokenized_text): 33 | continue 34 | labels.append(self.model.config.id2label[label]) 35 | return labels 36 | 37 | def _load_from_json(self, path: str) -> list: 38 | json_dict = json.load(open(path, "r")) 39 | features = [] 40 | for unit in json_dict: 41 | tokenized_text = self.tokenizer.tokenize(unit["text"]) 42 | spans = [] 43 | for entity in unit["entities"]: 44 | span_list = [] 45 | span_list.extend(entity["span"]) 46 | span_list.append(entity["type"]) 47 | spans.append(span_list) 48 | label = noyaki.convert(tokenized_text, spans, subword="##") 49 | features.append({"x": tokenized_text, "y": label}) 50 | return features 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | 56 | parser.add_argument("--model-dir", type=str, 
default="./dest") 57 | parser.add_argument("--training-data-path", type=str, default="./ner.json") 58 | 59 | args, _ = parser.parse_known_args() 60 | evaluate = Eval(args) 61 | evaluate.run() 62 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import ( 3 | BertForTokenClassification, BertJapaneseTokenizer, BertConfig, 4 | TrainingArguments, Trainer, 5 | EarlyStoppingCallback 6 | ) 7 | from sklearn.model_selection import train_test_split 8 | 9 | import torch 10 | import noyaki 11 | import json 12 | 13 | 14 | class Train: 15 | def __init__(self, args): 16 | self.tokenizer = BertJapaneseTokenizer.from_pretrained(args.model_name) 17 | self.args = args 18 | 19 | def run(self): 20 | features = self._load_from_json(self.args.training_data_path) 21 | train_data, val_data = train_test_split(features, test_size=0.2, random_state=123) 22 | train_data, test_data = train_test_split(train_data, test_size=0.1, random_state=123) 23 | self.label2id, id2label = self._create_label_vocab(features) 24 | 25 | config = BertConfig.from_pretrained(self.args.model_name, label2id=self.label2id, id2label=id2label) 26 | model = BertForTokenClassification.from_pretrained(self.args.model_name, config=config) 27 | 28 | args = TrainingArguments(output_dir=self.args.checkpoint_dir, 29 | do_train=True, 30 | do_eval=True, 31 | do_predict=True, 32 | per_device_train_batch_size=self.args.batch_size, 33 | per_device_eval_batch_size=self.args.batch_size, 34 | learning_rate=self.args.lr, 35 | num_train_epochs=self.args.num_epochs, 36 | evaluation_strategy="steps", 37 | eval_steps=self.args.save_freq, 38 | save_strategy="steps", 39 | save_steps=self.args.save_freq, 40 | load_best_model_at_end=True, 41 | ) 42 | trainer = Trainer(model=model, 43 | args=args, 44 | data_collator=self._data_collator, 45 | train_dataset=train_data, 46 | eval_dataset=val_data, 47 | callbacks=[EarlyStoppingCallback(early_stopping_patience=2)] 48 | ) 49 | trainer.train() 50 | 51 | _, _, metrics = trainer.predict(test_data, metric_key_prefix="test") 52 | print(metrics) 53 | 54 | trainer.save_model(self.args.model_output_dir) 55 | 56 | def _load_from_json(self, path: str) -> list: 57 | json_dict = json.load(open(path, "r")) 58 | features = [] 59 | for unit in json_dict: 60 | tokenized_text = self.tokenizer.tokenize(unit["text"]) 61 | spans = [] 62 | for entity in unit["entities"]: 63 | span_list = [] 64 | span_list.extend(entity["span"]) 65 | span_list.append(entity["type"]) 66 | spans.append(span_list) 67 | label = noyaki.convert(tokenized_text, spans, subword="##") 68 | features.append({"x": tokenized_text, "y": label}) 69 | return features 70 | 71 | def _data_collator(self, features: list) -> dict: 72 | x = [f["x"] for f in features] 73 | y = [f["y"] for f in features] 74 | inputs = self.tokenizer(x, return_tensors=None, padding='max_length', truncation=True, max_length=64, is_split_into_words=True) 75 | input_labels = [] 76 | for labels in y: 77 | pad_list = [-100] * 64 78 | for i, label in enumerate(labels): 79 | pad_list.insert(i, self.label2id[label]) 80 | input_labels.append(pad_list[:64]) 81 | inputs['labels'] = input_labels 82 | batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in inputs.items()} 83 | return batch 84 | 85 | @staticmethod 86 | def _create_label_vocab(features: list) -> tuple: 87 | labels = [f["y"] for f in features] 88 | unique_labels = list(set(sum(labels, 
[]))) 89 | label2id = {} 90 | for i, label in enumerate(unique_labels): 91 | label2id[label] = i 92 | id2label = {v: k for k, v in label2id.items()} 93 | return label2id, id2label 94 | 95 | 96 | if __name__ == '__main__': 97 | parser = argparse.ArgumentParser() 98 | 99 | parser.add_argument("--lr", type=float, default=3e-5) 100 | parser.add_argument("--model-name", type=str, default="cl-tohoku/bert-base-japanese-whole-word-masking") 101 | parser.add_argument("--model-output-dir", type=str, default="./dest") 102 | parser.add_argument("--training-data-path", type=str, default="./ner.json") 103 | parser.add_argument("--checkpoint-dir", type=str, default="./ckpt") 104 | parser.add_argument("--batch-size", type=int, default=8) 105 | parser.add_argument("--save-freq", type=int, default=100) 106 | parser.add_argument("--num-epochs", type=int, default=10) 107 | 108 | args, _ = parser.parse_known_args() 109 | 110 | train = Train(args) 111 | train.run() 112 | -------------------------------------------------------------------------------- /bert-japanese-ner-finetuning-kyoto.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 京大BERTファインチューニング\n", 8 | "[京大BERT](https://nlp.ist.i.kyoto-u.ac.jp/?ku_bert_japanese)をベースにして、[ストックマーク株式会社が公開しているner-wikipedia-dataset](https://github.com/stockmarkteam/ner-wikipedia-dataset)を使って固有表現抽出タスク向けにファインチューニングを行う例です \n", 9 | "PyTorch+transformersです(not Tensorflow)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## 準備\n", 17 | "学習に必要なものを用意します\n", 18 | "主に必要になるもの\n", 19 | "- [京大BERTモデル](https://nlp.ist.i.kyoto-u.ac.jp/?ku_bert_japanese)\n", 20 | "- [Juman++](https://nlp.ist.i.kyoto-u.ac.jp/?JUMAN%2B%2B)\n", 21 | "- [pyknp](https://github.com/ku-nlp/pyknp)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "!wget \"http://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBertPretrainedModel/Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers.zip&name=Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers.zip\" -O bert.zip\n", 31 | "!unzip bert.zip" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "!mkdir kyoto\n", 41 | "!mv Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers kyoto/bert" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "※Juman++のインストールは大きめのインスタンスでないと時間がかかるorフリーズするかもしれません" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": { 55 | "collapsed": true, 56 | "jupyter": { 57 | "outputs_hidden": true 58 | } 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "!wget \"https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc2/jumanpp-2.0.0-rc2.tar.xz\"\n", 63 | "!tar xvf jumanpp-2.0.0-rc2.tar.xz\n", 64 | "!apt-get update -y\n", 65 | "!apt-get install -y cmake gcc build-essential\n", 66 | "%cd jumanpp-2.0.0-rc2\n", 67 | "!mkdir bld\n", 68 | "%cd bld\n", 69 | "!cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local\n", 70 | "!make install -j\n", 71 | "%cd ../.." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "!pip install --upgrade pip\n", 81 | "!pip install transformers[\"ja\"] numpy noyaki sklearn pyknp\n", 82 | "!pip install -U jupyter ipywidgets" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 22, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "!mkdir outputs\n", 101 | "!mkdir ckpt" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## 動作確認\n", 109 | "Juman++が動いていることを確認します" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "!echo \"こんにちは\" | jumanpp" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## 学習データのダウンロード\n", 126 | "今回は冒頭でも述べたとおり[ストックマーク株式会社が公開しているner-wikipedia-dataset](https://github.com/stockmarkteam/ner-wikipedia-dataset)を利用させていただきます" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "!wget \"https://github.com/stockmarkteam/ner-wikipedia-dataset/raw/main/ner.json\"" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## 学習データの確認\n", 143 | "ダウンロードしてきた`ner.json`がどのようになっているか軽く確認してみましょう" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 16, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "!head -15 ner.json" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "## 学習\n", 160 | "実際に学習をしてみます" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "from transformers import (\n", 170 | " BertForTokenClassification, BertTokenizer, BertConfig,\n", 171 | " TrainingArguments, Trainer,\n", 172 | " EarlyStoppingCallback\n", 173 | ")\n", 174 | "from pyknp import Juman\n", 175 | "from sklearn.model_selection import train_test_split\n", 176 | "\n", 177 | "import torch\n", 178 | "import noyaki\n", 179 | "import os\n", 180 | "import numpy as np\n", 181 | "import argparse\n", 182 | "import re\n", 183 | "import json" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "関数を定義しておきます" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 24, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "def load_from_json(path: str) -> list:\n", 200 | " jumanpp = Juman()\n", 201 | " json_dict = json.load(open(path, \"r\"))\n", 202 | " features = []\n", 203 | " for unit in json_dict:\n", 204 | " result = jumanpp.analysis(unit[\"text\"])\n", 205 | " tokenized_text = [mrph.midasi for mrph in result.mrph_list()]\n", 206 | " spans = []\n", 207 | " for entity in unit[\"entities\"]:\n", 208 | " span_list = []\n", 209 | " span_list.extend(entity[\"span\"])\n", 210 | " span_list.append(entity[\"type\"])\n", 211 | " spans.append(span_list)\n", 212 | " label = noyaki.convert(tokenized_text, spans)\n", 213 | " features.append({\"x\": tokenized_text, \"y\": 
label})\n", 214 | " return features" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 25, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "def create_label_vocab(features: list) -> tuple:\n", 224 | " labels = [f[\"y\"] for f in features]\n", 225 | " unique_labels = list(set(sum(labels, [])))\n", 226 | " label2id = {}\n", 227 | " for i, label in enumerate(unique_labels):\n", 228 | " label2id[label] = i\n", 229 | " id2label = {v: k for k, v in label2id.items()}\n", 230 | " return label2id, id2label" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 26, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "def data_collator(features: list) -> dict:\n", 240 | " x = [f[\"x\"] for f in features]\n", 241 | " y = [f[\"y\"] for f in features]\n", 242 | " inputs = tokenizer(x, return_tensors=None, padding='max_length', truncation=True, max_length=64, is_split_into_words=True)\n", 243 | " input_labels = []\n", 244 | " for labels in y:\n", 245 | " pad_list = [-100] * 64\n", 246 | " for i, label in enumerate(labels):\n", 247 | " pad_list.insert(i, label2id[label])\n", 248 | " input_labels.append(pad_list[:64])\n", 249 | " inputs['labels'] = input_labels\n", 250 | " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in inputs.items()}\n", 251 | " return batch" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "変数を定義しておきます" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 27, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "model_output_dir = \"./outputs\"\n", 268 | "ckpt_dir = \"./ckpt\"\n", 269 | "training_data_directory = \"./\"\n", 270 | "base_model_directory = \"./kyoto/bert\"\n", 271 | "batch_size = 32\n", 272 | "epochs = 3\n", 273 | "learning_rate = 3e-5\n", 274 | "save_freq = 200" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 28, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "tokenizer = BertTokenizer.from_pretrained(base_model_directory, tokenize_chinese_chars=False, do_lower_case=False)\n", 284 | "features = load_from_json(os.path.join(training_data_directory, \"ner.json\"))\n", 285 | "label2id, id2label = create_label_vocab(features)\n", 286 | "\n", 287 | "train_data, val_data = train_test_split(features, test_size=0.2, random_state=123)\n", 288 | "train_data, test_data = train_test_split(train_data, test_size=0.1, random_state=123)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "`features`の中身を確認してみます" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 29, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "print(features[:10])" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "`label2id`と`id2label`の中身を確認してみます" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 30, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "print(label2id)\n", 321 | "print(id2label)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "モデルの用意をします" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 31, 334 | "metadata": { 335 | "collapsed": true, 336 | "jupyter": { 337 | "outputs_hidden": true 338 | } 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "config = BertConfig.from_pretrained(base_model_directory, 
label2id=label2id, id2label=id2label)\n", 343 | "model = BertForTokenClassification.from_pretrained(base_model_directory, config=config)\n", 344 | "print(model)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "学習の設定をつくります" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 32, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "args = TrainingArguments(output_dir=ckpt_dir,\n", 361 | " do_train=True,\n", 362 | " do_eval=True,\n", 363 | " do_predict=True,\n", 364 | " per_device_train_batch_size=batch_size,\n", 365 | " per_device_eval_batch_size=batch_size,\n", 366 | " learning_rate=learning_rate,\n", 367 | " num_train_epochs=epochs,\n", 368 | " evaluation_strategy=\"steps\",\n", 369 | " eval_steps=save_freq,\n", 370 | " save_strategy=\"steps\",\n", 371 | " save_steps=save_freq,\n", 372 | " load_best_model_at_end=True,\n", 373 | " remove_unused_columns=False,\n", 374 | " )" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": {}, 380 | "source": [ 381 | "Trainerをつくります" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 33, 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "trainer = Trainer(model=model,\n", 391 | " args=args,\n", 392 | " data_collator=data_collator,\n", 393 | " train_dataset=train_data,\n", 394 | " eval_dataset=val_data,\n", 395 | " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n", 396 | " )" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "学習を実行します" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 34, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "trainer.train()" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": {}, 418 | "source": [ 419 | "テストしてみます" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 18, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "_, _, metrics = trainer.predict(test_data, metric_key_prefix=\"test\")\n", 429 | "print(metrics)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 35, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "trainer.save_model(model_output_dir)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 20, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "!ls outputs" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "## 推論\n", 455 | "できあがったモデルを使って推論を行ってみます" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 36, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "text = \"田中さんはhogehoge株式会社の社員です\"\n", 465 | "model = BertForTokenClassification.from_pretrained(\"outputs\")\n", 466 | "\n", 467 | "jumanpp = Juman()\n", 468 | "result = jumanpp.analysis(text)\n", 469 | "tokenized_text = [mrph.midasi for mrph in result.mrph_list()]\n", 470 | "inputs = tokenizer(tokenized_text, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=64, is_split_into_words=True)\n", 471 | "pred = model(**inputs).logits[0]\n", 472 | "pred = np.argmax(pred.detach().numpy(), axis=-1)\n", 473 | "labels = []\n", 474 | "for i, label in enumerate(pred):\n", 475 | " if i + 1 > len(tokenized_text):\n", 476 | " continue\n", 477 | " labels.append(model.config.id2label[label])\n", 478 | " print(f\"{tokenized_text[i]}: 
{model.config.id2label[label]}\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 37, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "print(tokenized_text)\n", 488 | "print(labels)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [] 497 | } 498 | ], 499 | "metadata": { 500 | "kernelspec": { 501 | "display_name": "Python 3 (ipykernel)", 502 | "language": "python", 503 | "name": "python3" 504 | }, 505 | "language_info": { 506 | "codemirror_mode": { 507 | "name": "ipython", 508 | "version": 3 509 | }, 510 | "file_extension": ".py", 511 | "mimetype": "text/x-python", 512 | "name": "python", 513 | "nbconvert_exporter": "python", 514 | "pygments_lexer": "ipython3", 515 | "version": "3.9.7" 516 | } 517 | }, 518 | "nbformat": 4, 519 | "nbformat_minor": 4 520 | } 521 | -------------------------------------------------------------------------------- /bert-japanese-ner-finetuning-tohoku.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 東北大BERTをベースにファインチューニングで固有表現抽出用モデルを作成する\n", 8 | "huggingfaceで公開されている東北大BERTこと `cl-tohoku/bert-base-japanese-whole-word-masking` をベースに、ファインチューニングをして固有表現抽出タスク用のモデルを作成します" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## 準備" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "### ライブラリのインストール\n", 23 | "必要なライブラリをインストールします" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "!pip3 install --upgrade pip\n", 33 | "!pip3 install transformers[\"ja\"] numpy noyaki sklearn seqeval\n", 34 | "!pip3 install -U jupyter ipywidgets\n", 35 | "!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### 学習データのダウンロード\n", 43 | "今回は[ストックマーク株式会社が公開しているner-wikipedia-dataset](https://github.com/stockmarkteam/ner-wikipedia-dataset)を利用させていただきます" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "!wget \"https://github.com/stockmarkteam/ner-wikipedia-dataset/raw/main/ner.json\"" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "ダウンロードした学習データを確認してみましょう" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "!head -15 ner.json" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## 学習の実行\n", 76 | "実際に学習を行っていきます" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "model_output_dir = \"./dest\"\n", 86 | "model_name = \"cl-tohoku/bert-base-japanese-whole-word-masking\"" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Tokenizerの準備\n", 94 | "Tokenizerを用意します" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from transformers import BertJapaneseTokenizer\n", 104 | "\n", 105 | 
"tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "なお、NEologdを使いたい場合など、TokenizerのMeCabにオプションを渡したい場合は[こちら](https://qiita.com/ken11_/items/fd20e69103bb0ce698af)を参考にしてください" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "### 学習データの前処理\n", 120 | "先ほどダウンロードしてきた学習データを、学習に使えるように前処理していきます" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import noyaki\n", 130 | "import json\n", 131 | "\n", 132 | "def load_from_json(path: str) -> list:\n", 133 | " json_dict = json.load(open(path, \"r\"))\n", 134 | " features = []\n", 135 | " for unit in json_dict:\n", 136 | " tokenized_text = tokenizer.tokenize(unit[\"text\"])\n", 137 | " spans = []\n", 138 | " for entity in unit[\"entities\"]:\n", 139 | " span_list = []\n", 140 | " span_list.extend(entity[\"span\"])\n", 141 | " span_list.append(entity[\"type\"])\n", 142 | " spans.append(span_list)\n", 143 | " label = noyaki.convert(tokenized_text, spans, subword=\"##\")\n", 144 | " features.append({\"x\": tokenized_text, \"y\": label})\n", 145 | " return features" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "features = load_from_json(\"./ner.json\")" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "featuresの中身を確認します" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "print(features[:10])" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "学習データを `train`, `valid`, `test` に分割します" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "from sklearn.model_selection import train_test_split\n", 187 | "\n", 188 | "train_data, val_data = train_test_split(features, test_size=0.2, random_state=123)\n", 189 | "train_data, test_data = train_test_split(train_data, test_size=0.1, random_state=123)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "### ラベル辞書の作成\n", 197 | "ラベルの辞書を作成します \n", 198 | "これはあとでmodelのconfigに渡す情報となります" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "def create_label_vocab(features: list) -> tuple:\n", 208 | " labels = [f[\"y\"] for f in features]\n", 209 | " unique_labels = list(set(sum(labels, [])))\n", 210 | " label2id = {}\n", 211 | " for i, label in enumerate(unique_labels):\n", 212 | " label2id[label] = i\n", 213 | " id2label = {v: k for k, v in label2id.items()}\n", 214 | " return label2id, id2label" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "label2id, id2label = create_label_vocab(features)\n", 224 | "print(label2id)\n", 225 | "print(id2label)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "### モデルの準備\n", 233 | "ベースモデルを用意します \n", 234 | "ここで先ほどの `label2id`, `id2label` を渡してあげることで、推論時のラベル復元が楽になります" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | 
"execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "from transformers import BertForTokenClassification, BertConfig\n", 244 | "\n", 245 | "config = BertConfig.from_pretrained(model_name, label2id=label2id, id2label=id2label)\n", 246 | "model = BertForTokenClassification.from_pretrained(model_name, config=config)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "print(model)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "### Trainerの準備\n", 263 | "TrainingArgumentsを設定し、Trainerを作成していきます \n", 264 | "Trainerにはdata_collatorを渡してあげる必要があるので、data_collatorも作成します" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": {}, 270 | "source": [ 271 | "data_collatorは[transformersにすでにあるもの](https://huggingface.co/docs/transformers/main_classes/data_collator)を利用することもできますが、ここでは自前で定義していきます" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "import torch\n", 281 | "\n", 282 | "def data_collator(features: list) -> dict:\n", 283 | " x = [f[\"x\"] for f in features]\n", 284 | " y = [f[\"y\"] for f in features]\n", 285 | " inputs = tokenizer(x, return_tensors=None, padding='max_length', truncation=True, max_length=64, is_split_into_words=True)\n", 286 | " input_labels = []\n", 287 | " for labels in y:\n", 288 | " pad_list = [-100] * 64\n", 289 | " for i, label in enumerate(labels):\n", 290 | " pad_list.insert(i, label2id[label])\n", 291 | " input_labels.append(pad_list[:64])\n", 292 | " inputs['labels'] = input_labels\n", 293 | " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in inputs.items()}\n", 294 | " return batch" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "ハイパーパラメータなどを定義しておきます" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "ckpt_dir = \"./ckpt\"\n", 311 | "batch_size = 8\n", 312 | "epochs = 3\n", 313 | "learning_rate = 3e-5\n", 314 | "save_freq = 100" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "from transformers import TrainingArguments\n", 324 | "\n", 325 | "args = TrainingArguments(output_dir=ckpt_dir,\n", 326 | " do_train=True,\n", 327 | " do_eval=True,\n", 328 | " do_predict=True,\n", 329 | " per_device_train_batch_size=batch_size,\n", 330 | " per_device_eval_batch_size=batch_size,\n", 331 | " learning_rate=learning_rate,\n", 332 | " num_train_epochs=epochs,\n", 333 | " evaluation_strategy=\"steps\",\n", 334 | " eval_steps=save_freq,\n", 335 | " save_strategy=\"steps\",\n", 336 | " save_steps=save_freq,\n", 337 | " load_best_model_at_end=True,\n", 338 | " remove_unused_columns=False,\n", 339 | " )" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "from transformers import Trainer, EarlyStoppingCallback\n", 349 | "\n", 350 | "trainer = Trainer(model=model,\n", 351 | " args=args,\n", 352 | " data_collator=data_collator,\n", 353 | " train_dataset=train_data,\n", 354 | " eval_dataset=val_data,\n", 355 | " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n", 356 | " )" 357 | ] 358 | }, 359 | { 360 | "cell_type": 
"markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "学習を実行します" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "trainer.train()" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "できあがったモデルをテストします" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "_, _, metrics = trainer.predict(test_data, metric_key_prefix=\"test\")\n", 389 | "print(metrics)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": {}, 395 | "source": [ 396 | "モデルをsaveします" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "trainer.save_model(model_output_dir)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## モデルの検証\n", 413 | "[seqeval](https://github.com/chakki-works/seqeval)を使って実際のモデル精度を検証していきます" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "### 推論用の関数を定義\n", 421 | "学習したモデルを使って推論をするための関数を定義します" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "import numpy as np\n", 431 | "\n", 432 | "inference_model = BertForTokenClassification.from_pretrained(model_output_dir)\n", 433 | "def inference(tokenized_text: list) -> list:\n", 434 | " inputs = tokenizer(tokenized_text, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=64, is_split_into_words=True)\n", 435 | " pred = inference_model(**inputs).logits[0]\n", 436 | " pred = np.argmax(pred.detach().numpy(), axis=-1)\n", 437 | " labels = []\n", 438 | " for i, label in enumerate(pred):\n", 439 | " if i + 1 > len(tokenized_text):\n", 440 | " continue\n", 441 | " labels.append(inference_model.config.id2label[label])\n", 442 | " return labels" 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "metadata": {}, 448 | "source": [ 449 | "正解データを用意します" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "y_true = []\n", 459 | "for unit in test_data:\n", 460 | " # 今回はmax_lengthを64にしているので正解データも切り詰めておく\n", 461 | " y_true.append(unit[\"y\"][:64])\n", 462 | "print(y_true)" 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": {}, 468 | "source": [ 469 | "同様に推論結果も用意します" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "y_pred = []\n", 479 | "for unit in test_data:\n", 480 | " y_pred.append(inference(unit[\"x\"]))\n", 481 | "print(y_pred)" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "### seqevalのclassification_reportを実行\n", 489 | "seqevalのclassification_reportを使って検証します" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "metadata": {}, 496 | "outputs": [], 497 | "source": [ 498 | "from seqeval.metrics import classification_report\n", 499 | "from seqeval.scheme import BILOU\n", 500 | "\n", 501 | "print(classification_report(y_true, y_pred, mode='strict', scheme=BILOU))" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | 
"seqevalのstrictモードは厳密なので精度は低くなりがちです \n", 509 | "BILUOではstrictモードしかサポートされていないため、適宜BILUOをBIOに変換して使用するなど、タスクに合った精度検証を行ってください" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "## 推論\n", 517 | "最後に、通常の推論用コードを紹介します" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "def inference(text: str):\n", 527 | " model = BertForTokenClassification.from_pretrained(model_output_dir)\n", 528 | " tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)\n", 529 | " \n", 530 | " \n", 531 | " tokenized_text = tokenizer.tokenize(text)\n", 532 | " inputs = tokenizer(tokenized_text, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=64, is_split_into_words=True)\n", 533 | " pred = model(**inputs).logits[0]\n", 534 | " pred = np.argmax(pred.detach().numpy(), axis=-1)\n", 535 | " labels = []\n", 536 | " for i, label in enumerate(pred):\n", 537 | " if i + 1 > len(tokenized_text):\n", 538 | " continue\n", 539 | " labels.append(inference_model.config.id2label[label])\n", 540 | " print(tokenized_text)\n", 541 | " print(labels)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": { 548 | "scrolled": true 549 | }, 550 | "outputs": [], 551 | "source": [ 552 | "print(inference(\"田中さんの会社の社長は鈴木さんです\"))" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [] 561 | } 562 | ], 563 | "metadata": { 564 | "instance_type": "ml.g4dn.xlarge", 565 | "kernelspec": { 566 | "display_name": "Python 3 (ipykernel)", 567 | "language": "python", 568 | "name": "python3" 569 | }, 570 | "language_info": { 571 | "codemirror_mode": { 572 | "name": "ipython", 573 | "version": 3 574 | }, 575 | "file_extension": ".py", 576 | "mimetype": "text/x-python", 577 | "name": "python", 578 | "nbconvert_exporter": "python", 579 | "pygments_lexer": "ipython3", 580 | "version": "3.9.7" 581 | } 582 | }, 583 | "nbformat": 4, 584 | "nbformat_minor": 4 585 | } 586 | --------------------------------------------------------------------------------