├── .gitignore
├── README.md
├── utils.py
├── text_preprocess.ipynb
└── text_bert.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | /__pycache__
2 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # ja_text_bert
 2 | 
 3 | BERT with Japanese text
 4 | 
 5 | # Reproducing Results
 6 | 
 7 | 1. Run `jupyter notebook` or `jupyter lab`
 8 | 
 9 | 2. Run `text_preprocess.ipynb` in Jupyter
10 | 
11 | 3. Run `text_bert.ipynb` in Jupyter
12 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import math
 3 | import torch.nn as nn
 4 | 
 5 | # GELU activation (tanh approximation used by BERT)
 6 | class GELU(nn.Module):
 7 |     def forward(self, x):
 8 |         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 9 | 
10 | # Position-wise feed-forward network, applied independently at every position
11 | class PositionwiseFeedForward(nn.Module):
12 |     def __init__(self, d_model, d_ff, dropout=0.1):
13 |         super(PositionwiseFeedForward, self).__init__()
14 |         self.w_1 = nn.Linear(d_model, d_ff)
15 |         self.w_2 = nn.Linear(d_ff, d_model)
16 |         self.dropout = nn.Dropout(dropout)
17 |         self.activation = GELU()
18 | 
19 |     def forward(self, x):
20 |         return self.w_2(self.dropout(self.activation(self.w_1(x))))
21 | 
22 | # Layer normalization over the feature dimension
23 | class LayerNorm(nn.Module):
24 |     def __init__(self, features, eps=1e-6):
25 |         super(LayerNorm, self).__init__()
26 |         self.a_2 = nn.Parameter(torch.ones(features))
27 |         self.b_2 = nn.Parameter(torch.zeros(features))
28 |         self.eps = eps
29 | 
30 |     def forward(self, x):
31 |         mean = x.mean(-1, keepdim=True)
32 |         std = x.std(-1, keepdim=True)
33 |         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
34 | 
35 | # Pre-norm residual connection: x + dropout(sublayer(LayerNorm(x)))
36 | class SublayerConnection(nn.Module):
37 |     def __init__(self, size, dropout):
38 |         super(SublayerConnection, self).__init__()
39 |         self.norm = LayerNorm(size)
40 |         self.dropout = nn.Dropout(dropout)
41 | 
42 |     def forward(self, x, sublayer):
43 |         return x + self.dropout(sublayer(self.norm(x)))
--------------------------------------------------------------------------------
/text_preprocess.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "markdown",
  5 |    "metadata": {},
  6 |    "source": [
  7 |     "This notebook builds a corpus from the Japanese Wikipedia dump.<br>
\n", 8 | "https://dumps.wikimedia.org/jawiki/latest/
\n",
  9 |     "Get the URL of the latest \"pages-articles\" dump from the page above.<br>"
 10 |    ]
 11 |   },
 12 |   {
 13 |    "cell_type": "code",
 14 |    "execution_count": null,
 15 |    "metadata": {},
 16 |    "outputs": [],
 17 |    "source": [
 18 |     "! wget https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-pages-articles.xml.bz2"
 19 |    ]
 20 |   },
 21 |   {
 22 |    "cell_type": "markdown",
 23 |    "metadata": {},
 24 |    "source": [
 25 |     "The dump still contains unnecessary markup, so clone the text-cleaning script (WikiExtractor) from GitHub to strip it"
 26 |    ]
 27 |   },
 28 |   {
 29 |    "cell_type": "code",
 30 |    "execution_count": null,
 31 |    "metadata": {},
 32 |    "outputs": [],
 33 |    "source": [
 34 |     "! git clone https://github.com/attardi/wikiextractor.git"
 35 |    ]
 36 |   },
 37 |   {
 38 |    "cell_type": "markdown",
 39 |    "metadata": {},
 40 |    "source": [
 41 |     "Run the text cleaning over the Japanese Wikipedia dump"
 42 |    ]
 43 |   },
 44 |   {
 45 |    "cell_type": "code",
 46 |    "execution_count": null,
 47 |    "metadata": {},
 48 |    "outputs": [],
 49 |    "source": [
 50 |     "! python wikiextractor/WikiExtractor.py -o extracted jawiki-latest-pages-articles.xml.bz2"
 51 |    ]
 52 |   },
 53 |   {
 54 |    "cell_type": "markdown",
 55 |    "metadata": {},
 56 |    "source": [
 57 |     "After preprocessing the text, merge the extracted txt files into a single file"
 58 |    ]
 59 |   },
 60 |   {
 61 |    "cell_type": "code",
 62 |    "execution_count": null,
 63 |    "metadata": {},
 64 |    "outputs": [],
 65 |    "source": [
 66 |     "import glob\n",
 67 |     "from bs4 import BeautifulSoup\n",
 68 |     "\n",
 69 |     "with open('./tmp.txt','w') as f:\n",
 70 |     "    for directory in glob.glob('./extracted/*'):\n",
 71 |     "        for name in glob.glob(directory+'/*'):\n",
 72 |     "            with open(name, 'r') as r:\n",
 73 |     "                for line in r:\n",
 74 |     "                    # skip the <doc ...> header and the article title on the following line\n",
 75 |     "                    if '<doc id' in line:\n",
 76 |     "                        next(r)\n",
 77 |     "                        continue\n",
 78 |     "                    elif '</doc>' in line:\n",
 79 |     "                        f.write('\\n')\n",
 80 |     "                        continue\n",
 81 |     "                    else:\n",
 82 |     "                        # strip whitespace, drop leftover markup, and lowercase\n",
 83 |     "                        text = BeautifulSoup(line.strip(), 'html.parser').text.lower()\n",
 84 |     "                        f.write(text)"
 85 |    ]
 86 |   },
 87 |   {
 88 |    "cell_type": "markdown",
 89 |    "metadata": {},
 90 |    "source": [
 91 |     "ここからはBERTのトレーニング用にテキストファイルを整形していきます.<br>
\n", 92 | "文章を単語ごとに分割し, ひとつの単元の中に偶数個の文章が含まれるように調整します." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import linecache\n", 102 | "import random\n", 103 | "import MeCab" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "random.seed(42)\n", 113 | "filename = 'tmp.txt'\n", 114 | "save_file = 'even_rows100M.txt'\n", 115 | "LIMIT_BYTE = 100000000 # 100Mbyte\n", 116 | "# t = MeCab.Tagger('-Owakati') # Neologdを辞書に使っている人場合はそちらを使用するのがベターです\n", 117 | "t = MeCab.Tagger('-d /usr/local/lib/mecab/dic/mecab-ipadic-neologd/ -Owakati')\n", 118 | "\n", 119 | "def get_byte_num(s):\n", 120 | " return len(s.encode('utf-8'))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "with open(save_file, 'w') as f:\n", 130 | " count_byte = 0\n", 131 | " with open(filename) as r:\n", 132 | " for text in r:\n", 133 | " print('{} bytes'.format(count_byte))\n", 134 | " text = t.parse(text).strip()\n", 135 | " # 一文ごとに分割する\n", 136 | " text = text.split('。')\n", 137 | " # 空白要素は捨てる\n", 138 | " text = [t.strip() for t in text if t]\n", 139 | " # 一単元の文書が偶数個の文章から成るようにする(BERTのデータセットの都合上)\n", 140 | " max_text_len = len(text) // 2\n", 141 | " text = text[:max_text_len * 2]\n", 142 | " text = '\\n'.join(text)\n", 143 | " f.write(text)\n", 144 | " count_byte += get_byte_num(text)\n", 145 | " if count_byte >= LIMIT_BYTE:\n", 146 | " break" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "これでBERTの学習に使うデータセットができました.
\n", 154 | "今度はTraining用とValidation用のデータに分割します." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "num_lines = sum(1 for line in open(save_file))\n", 164 | "print('Base file lines : ', num_lines)\n", 165 | "# 全体の80%をTraining dataに当てます\n", 166 | "train_lines = int(num_lines * 0.8)\n", 167 | "print('Train file lines : ', train_lines)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "dataは前処理済みテキスト保存場所
\n", 175 | "outputは訓練モデル保存場所として作成" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 1, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "! mkdir -p data output" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "out_file_name_temp = './data/splitted_%d.txt'\n", 194 | "\n", 195 | "split_index = 1\n", 196 | "line_index = 1\n", 197 | "out_file = open(out_file_name_temp % (split_index,), 'w')\n", 198 | "in_file = open(save_file)\n", 199 | "line = in_file.readline()\n", 200 | "while line:\n", 201 | " if line_index > train_lines:\n", 202 | " print('Starting file: %d' % split_index)\n", 203 | " out_file.close()\n", 204 | " split_index = split_index + 1\n", 205 | " line_index = 1\n", 206 | " out_file = open(out_file_name_temp % (split_index,), 'w')\n", 207 | " out_file.write(line)\n", 208 | " line_index = line_index + 1\n", 209 | " line = in_file.readline()\n", 210 | " \n", 211 | "out_file.close()\n", 212 | "in_file.close()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "print('Train file lines : ', sum(1 for line in open('./data/splitted_1.txt')))\n", 222 | "print('Valid file lines : ', sum(1 for line in open('./data/splitted_2.txt')))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "これにてテキストの前処理は完了です!" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.6.1" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 2 254 | } 255 | -------------------------------------------------------------------------------- /text_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "必要ModuleをImport" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "from torch.optim import Adam\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "import math\n", 22 | "\n", 23 | "import pickle\n", 24 | "import tqdm\n", 25 | "from collections import Counter\n", 26 | "\n", 27 | "from torch.utils.data import Dataset\n", 28 | "import random\n", 29 | "import numpy as np\n", 30 | "\n", 31 | "from utils import GELU, PositionwiseFeedForward, LayerNorm, SublayerConnection, LayerNorm\n", 32 | "\n", 33 | "import matplotlib\n", 34 | "%matplotlib inline\n", 35 | "import matplotlib.pyplot as plt\n", 36 | "from ipywidgets import FloatProgress\n", 37 | "from IPython.display import display, clear_output" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "input_train_txt = './data/splitted_1.txt'\n", 47 | "input_valid_txt = './data/splitted_2.txt'\n", 48 | "processed_train_txt = './data/train_X.txt'\n", 49 | "processed_valid_txt = 
'./data/valid_X.txt'" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Next Sentence Predictionのために, 意味的に連続する文章をtab区切りで並べる前処理をデータセットに対して行います." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# 偶数行の文章を奇数行の文章と接続するメソッド\n", 66 | "def load_data(path):\n", 67 | " with open(path, encoding='utf-8') as f:\n", 68 | " even_rows = []\n", 69 | " odd_rows = []\n", 70 | " all_f = f.readlines()\n", 71 | " for row in all_f[2::2]:\n", 72 | " even_rows.append(row.strip().replace('\\n', ''))\n", 73 | " for row in all_f[1::2]:\n", 74 | " odd_rows.append(row.strip().replace('\\n', ''))\n", 75 | " min_rows_len = int(min(len(even_rows), len(odd_rows)))\n", 76 | " even_rows = even_rows[:min_rows_len]\n", 77 | " odd_rows = odd_rows[:min_rows_len]\n", 78 | "\n", 79 | " concat_rows = []\n", 80 | " for even_r, odd_r in zip(even_rows, odd_rows):\n", 81 | " concat_r = '\\t'.join([even_r, odd_r])\n", 82 | " concat_rows.append(concat_r)\n", 83 | " return concat_rows" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "train_data = load_data(input_train_txt)\n", 93 | "valid_data = load_data(input_valid_txt)\n", 94 | "\n", 95 | "# ランダムに並び替える\n", 96 | "random.shuffle(train_data)\n", 97 | "random.shuffle(valid_data)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "with open(processed_train_txt, 'w') as f:\n", 107 | " f.write('\\n'.join(train_data))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "with open(processed_valid_txt, 'w') as f:\n", 117 | " f.write('\\n'.join(valid_data))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "Attentionセルを定義する" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "class Attention(nn.Module):\n", 134 | " \"\"\"\n", 135 | " Scaled Dot Product Attention\n", 136 | " \"\"\"\n", 137 | "\n", 138 | " def forward(self, query, key, value, mask=None, dropout=None):\n", 139 | " scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))\n", 140 | "\n", 141 | " if mask is not None:\n", 142 | " scores = scores.masked_fill(mask == 0, -1e9)\n", 143 | "\n", 144 | " p_attn = F.softmax(scores, dim=-1)\n", 145 | "\n", 146 | " if dropout is not None:\n", 147 | " p_attn = dropout(p_attn)\n", 148 | "\n", 149 | " return torch.matmul(p_attn, value), p_attn\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "Multi Head Attentionを定義する" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "class MultiHeadedAttention(nn.Module):\n", 166 | "\n", 167 | " def __init__(self, h, d_model, dropout=0.1):\n", 168 | " super().__init__()\n", 169 | " assert d_model % h == 0\n", 170 | "\n", 171 | " # We assume d_v always equals d_k\n", 172 | " self.d_k = d_model // h\n", 173 | " self.h = h\n", 174 | "\n", 175 | " self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])\n", 176 | " self.output_linear = nn.Linear(d_model, d_model)\n", 177 | " self.attention = 
Attention()\n", 178 | "\n", 179 | " self.dropout = nn.Dropout(p=dropout)\n", 180 | "\n", 181 | " def forward(self, query, key, value, mask=None):\n", 182 | " batch_size = query.size(0)\n", 183 | "\n", 184 | " query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]\n", 185 | "\n", 186 | " x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)\n", 187 | "\n", 188 | " x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)\n", 189 | "\n", 190 | " return self.output_linear(x)\n" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "Transformerを定義する" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "class TransformerBlock(nn.Module):\n", 207 | " \"\"\"\n", 208 | " Bidirectional Encoder = Transformer (self-attention)\n", 209 | " Transformer = MultiHead_Attention + Feed_Forward with sublayer connection\n", 210 | " \"\"\"\n", 211 | "\n", 212 | " def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):\n", 213 | " \"\"\"\n", 214 | " :param hidden: hidden size of transformer\n", 215 | " :param attn_heads: head sizes of multi-head attention\n", 216 | " :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size\n", 217 | " :param dropout: dropout rate\n", 218 | " \"\"\"\n", 219 | "\n", 220 | " super().__init__()\n", 221 | " self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)\n", 222 | " self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)\n", 223 | " self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n", 224 | " self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n", 225 | " self.dropout = nn.Dropout(p=dropout)\n", 226 | "\n", 227 | " def forward(self, x, mask):\n", 228 | " x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))\n", 229 | " x = self.output_sublayer(x, self.feed_forward)\n", 230 | " return self.dropout(x)\n" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "BERTクラスを定義する" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "class BERT(nn.Module):\n", 247 | "\n", 248 | " def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):\n", 249 | " \"\"\"\n", 250 | " :param vocab_size: vocab_size of total words\n", 251 | " :param hidden: BERT model hidden size\n", 252 | " :param n_layers: numbers of Transformer blocks(layers)\n", 253 | " :param attn_heads: number of attention heads\n", 254 | " :param dropout: dropout rate\n", 255 | " \"\"\"\n", 256 | "\n", 257 | " super().__init__()\n", 258 | " self.hidden = hidden\n", 259 | " self.n_layers = n_layers\n", 260 | " self.attn_heads = attn_heads\n", 261 | "\n", 262 | " self.feed_forward_hidden = hidden * 4\n", 263 | "\n", 264 | " # embedding for BERT\n", 265 | " self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, dropout=dropout)\n", 266 | "\n", 267 | " self.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])\n", 268 | "\n", 269 | " def forward(self, x, segment_info):\n", 270 | " # xの中で0以上は1, 0未満は0として, maskテンソルを作る\n", 271 | " mask = (x > 
0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n", 272 | "\n", 273 | " x = self.embedding(x, segment_info)\n", 274 | "\n", 275 | " for transformer in self.transformer_blocks:\n", 276 | " x = transformer.forward(x, mask)\n", 277 | " return x\n" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "BERTのEmbedding層を定義する" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "class TokenEmbedding(nn.Embedding):\n", 294 | " def __init__(self, vocab_size, embed_size=512):\n", 295 | " super().__init__(vocab_size, embed_size, padding_idx=0)\n", 296 | "\n", 297 | "class PositionalEmbedding(nn.Module):\n", 298 | "\n", 299 | " def __init__(self, d_model, max_len=512):\n", 300 | " super().__init__()\n", 301 | "\n", 302 | " pe = torch.zeros(max_len, d_model).float()\n", 303 | " pe.require_grad = False\n", 304 | "\n", 305 | " position = torch.arange(0, max_len).float().unsqueeze(1)\n", 306 | " div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()\n", 307 | "\n", 308 | " pe[:, 0::2] = torch.sin(position * div_term)\n", 309 | " pe[:, 1::2] = torch.cos(position * div_term)\n", 310 | "\n", 311 | " pe = pe.unsqueeze(0)\n", 312 | " self.register_buffer('pe', pe)\n", 313 | "\n", 314 | " def forward(self, x):\n", 315 | " return self.pe[:, :x.size(1)]\n", 316 | "\n", 317 | "class SegmentEmbedding(nn.Embedding):\n", 318 | " def __init__(self, embed_size=512):\n", 319 | " super().__init__(3, embed_size, padding_idx=0)\n", 320 | "\n", 321 | "class BERTEmbedding(nn.Module):\n", 322 | " \"\"\"\n", 323 | " BERT Embedding which is consisted with under features\n", 324 | " 1. TokenEmbedding : 通常のEMbedding\n", 325 | " 2. PositionalEmbedding : sin, cosを用いた位置情報付きEmbedding\n", 326 | " 2. 
SegmentEmbedding : Sentenceのセグメント情報 (sent_A:1, sent_B:2)\n", 327 | " \"\"\"\n", 328 | " def __init__(self, vocab_size, embed_size, dropout=0.1):\n", 329 | " super().__init__()\n", 330 | " self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)\n", 331 | " self.position = PositionalEmbedding(d_model=self.token.embedding_dim)\n", 332 | " self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)\n", 333 | " self.dropout = nn.Dropout(p=dropout)\n", 334 | " self.embed_size = embed_size\n", 335 | "\n", 336 | " def forward(self, sequence, segment_label):\n", 337 | " x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)\n", 338 | " return self.dropout(x)\n" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "学習用にマスク予測・隣接文予測の層を追加する" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "class BERTLM(nn.Module):\n", 355 | " \"\"\"\n", 356 | " BERT Language Model\n", 357 | " Next Sentence Prediction Model + Masked Language Model\n", 358 | " \"\"\"\n", 359 | "\n", 360 | " def __init__(self, bert: BERT, vocab_size):\n", 361 | " \"\"\"\n", 362 | " :param bert: BERT model which should be trained\n", 363 | " :param vocab_size: total vocab size for masked_lm\n", 364 | " \"\"\"\n", 365 | "\n", 366 | " super().__init__()\n", 367 | " self.bert = bert\n", 368 | " self.next_sentence = NextSentencePrediction(self.bert.hidden)\n", 369 | " self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)\n", 370 | "\n", 371 | " def forward(self, x, segment_label):\n", 372 | " x = self.bert(x, segment_label)\n", 373 | " return self.next_sentence(x), self.mask_lm(x)\n", 374 | "\n", 375 | "\n", 376 | "class NextSentencePrediction(nn.Module):\n", 377 | " \"\"\"\n", 378 | " 2クラス分類問題 : is_next, is_not_next\n", 379 | " \"\"\"\n", 380 | "\n", 381 | " def __init__(self, hidden):\n", 382 | " \"\"\"\n", 383 | " :param hidden: BERT model output size\n", 384 | " \"\"\"\n", 385 | " super().__init__()\n", 386 | " self.linear = nn.Linear(hidden, 2)\n", 387 | " self.softmax = nn.LogSoftmax(dim=-1)\n", 388 | "\n", 389 | " def forward(self, x):\n", 390 | " return self.softmax(self.linear(x[:, 0]))\n", 391 | "\n", 392 | "\n", 393 | "class MaskedLanguageModel(nn.Module):\n", 394 | " \"\"\"\n", 395 | " 入力系列のMASKトークンから元の単語を予測する\n", 396 | " nクラス分類問題, nクラス : vocab_size\n", 397 | " \"\"\"\n", 398 | "\n", 399 | " def __init__(self, hidden, vocab_size):\n", 400 | " \"\"\"\n", 401 | " :param hidden: output size of BERT model\n", 402 | " :param vocab_size: total vocab size\n", 403 | " \"\"\"\n", 404 | " super().__init__()\n", 405 | " self.linear = nn.Linear(hidden, vocab_size)\n", 406 | " self.softmax = nn.LogSoftmax(dim=-1)\n", 407 | "\n", 408 | " def forward(self, x):\n", 409 | " return self.softmax(self.linear(x))" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": {}, 415 | "source": [ 416 | "BERT用のVocabを生成するクラスを定義する" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "import pickle\n", 426 | "import tqdm\n", 427 | "from collections import Counter\n", 428 | "\n", 429 | "\n", 430 | "class TorchVocab(object):\n", 431 | " \"\"\"\n", 432 | " :property freqs: collections.Counter, コーパス中の単語の出現頻度を保持するオブジェクト\n", 433 | " :property stoi: collections.defaultdict, string → id の対応を示す辞書\n", 434 | " :property itos: 
collections.defaultdict, id → string の対応を示す辞書\n",
435 |     "    \"\"\"\n",
436 |     "    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],\n",
437 |     "                 vectors=None, unk_init=None, vectors_cache=None):\n",
438 |     "        \"\"\"\n",
439 |     "        :param counter: collections.Counter, counts how often each word appears in the data\n",
440 |     "        :param max_size: int, maximum size of the vocabulary. None means no limit. Default is None\n",
441 |     "        :param min_freq: int, minimum frequency a word needs in order to be added to the vocabulary\n",
442 |     "        :param specials: list of str, tokens registered in the vocabulary in advance\n",
443 |     "        :param vectors: list of vectors, pre-trained vectors, e.g. Vocab.load_vectors\n",
444 |     "        \"\"\"\n",
445 |     "        self.freqs = counter\n",
446 |     "        counter = counter.copy()\n",
447 |     "        min_freq = max(min_freq, 1)\n",
448 |     "\n",
449 |     "        self.itos = list(specials)\n",
450 |     "        # the frequency of special tokens is not counted when building the vocabulary\n",
451 |     "        for tok in specials:\n",
452 |     "            del counter[tok]\n",
453 |     "\n",
454 |     "        max_size = None if max_size is None else max_size + len(self.itos)\n",
455 |     "\n",
456 |     "        # sort alphabetically first, then stable-sort by frequency (descending) so ties stay alphabetical\n",
457 |     "        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])\n",
458 |     "        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)\n",
459 |     "        \n",
460 |     "        # words that appear fewer than min_freq times are not added to the vocab\n",
461 |     "        for word, freq in words_and_frequencies:\n",
462 |     "            if freq < min_freq or len(self.itos) == max_size:\n",
463 |     "                break\n",
464 |     "            self.itos.append(word)\n",
465 |     "\n",
466 |     "        # invert itos to build the string -> id mapping stoi\n",
467 |     "        self.stoi = {tok: i for i, tok in enumerate(self.itos)}\n",
468 |     "\n",
469 |     "        self.vectors = None\n",
470 |     "        if vectors is not None:\n",
471 |     "            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)\n",
472 |     "        else:\n",
473 |     "            assert unk_init is None and vectors_cache is None\n",
474 |     "\n",
475 |     "    def __eq__(self, other):\n",
476 |     "        if self.freqs != other.freqs:\n",
477 |     "            return False\n",
478 |     "        if self.stoi != other.stoi:\n",
479 |     "            return False\n",
480 |     "        if self.itos != other.itos:\n",
481 |     "            return False\n",
482 |     "        if self.vectors != other.vectors:\n",
483 |     "            return False\n",
484 |     "        return True\n",
485 |     "\n",
486 |     "    def __len__(self):\n",
487 |     "        return len(self.itos)\n",
488 |     "\n",
489 |     "    def vocab_rerank(self):\n",
490 |     "        self.stoi = {word: i for i, word in enumerate(self.itos)}\n",
491 |     "\n",
492 |     "    def extend(self, v, sort=False):\n",
493 |     "        words = sorted(v.itos) if sort else v.itos\n",
494 |     "        for w in words:\n",
495 |     "            if w not in self.stoi:\n",
496 |     "                self.itos.append(w)\n",
497 |     "                self.stoi[w] = len(self.itos) - 1\n",
498 |     "\n",
499 |     "\n",
500 |     "class Vocab(TorchVocab):\n",
501 |     "    def __init__(self, counter, max_size=None, min_freq=1):\n",
502 |     "        self.pad_index = 0\n",
503 |     "        self.unk_index = 1\n",
504 |     "        self.eos_index = 2\n",
505 |     "        self.sos_index = 3\n",
506 |     "        self.mask_index = 4\n",
507 |     "        super().__init__(counter, specials=[\"<pad>\", \"<unk>\", \"<eos>\", \"<sos>\", \"<mask>\"], max_size=max_size, min_freq=min_freq)\n",
508 |     "\n",
509 |     "    # meant to be overridden in subclasses\n",
510 |     "    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:\n",
511 |     "        pass\n",
512 |     "\n",
513 |     "    # meant to be overridden in subclasses\n",
514 |     "    def from_seq(self, seq, join=False, with_pad=False):\n",
515 |     "        pass\n",
516 |     "\n",
517 |     "    @staticmethod\n",
518 |     "    def load_vocab(vocab_path: str) -> 'Vocab':\n",
519 |     "        with open(vocab_path, \"rb\") as f:\n",
520 |     "            return pickle.load(f)\n",
521 |     "\n",
522 |     "    def save_vocab(self, vocab_path):\n",
523 |     "        with open(vocab_path, \"wb\") as f:\n",
524 |     "            
pickle.dump(self, f)\n", 525 | "\n", 526 | "\n", 527 | "# テキストファイルからvocabを作成する\n", 528 | "class WordVocab(Vocab):\n", 529 | " def __init__(self, texts, max_size=None, min_freq=1):\n", 530 | " print(\"Building Vocab\")\n", 531 | " counter = Counter()\n", 532 | " for line in texts:\n", 533 | " if isinstance(line, list):\n", 534 | " words = line\n", 535 | " else:\n", 536 | " words = line.replace(\"\\n\", \"\").replace(\"\\t\", \"\").split()\n", 537 | "\n", 538 | " for word in words:\n", 539 | " counter[word] += 1\n", 540 | " super().__init__(counter, max_size=max_size, min_freq=min_freq)\n", 541 | "\n", 542 | " def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):\n", 543 | " if isinstance(sentence, str):\n", 544 | " sentence = sentence.split()\n", 545 | "\n", 546 | " seq = [self.stoi.get(word, self.unk_index) for word in sentence]\n", 547 | "\n", 548 | " if with_eos:\n", 549 | " seq += [self.eos_index] # this would be index 1\n", 550 | " if with_sos:\n", 551 | " seq = [self.sos_index] + seq\n", 552 | "\n", 553 | " origin_seq_len = len(seq)\n", 554 | "\n", 555 | " if seq_len is None:\n", 556 | " pass\n", 557 | " elif len(seq) <= seq_len:\n", 558 | " seq += [self.pad_index for _ in range(seq_len - len(seq))]\n", 559 | " else:\n", 560 | " seq = seq[:seq_len]\n", 561 | "\n", 562 | " return (seq, origin_seq_len) if with_len else seq\n", 563 | "\n", 564 | " def from_seq(self, seq, join=False, with_pad=False):\n", 565 | " words = [self.itos[idx]\n", 566 | " if idx < len(self.itos)\n", 567 | " else \"<%d>\" % idx\n", 568 | " for idx in seq\n", 569 | " if not with_pad or idx != self.pad_index]\n", 570 | "\n", 571 | " return \" \".join(words) if join else words\n", 572 | "\n", 573 | " @staticmethod\n", 574 | " def load_vocab(vocab_path: str) -> 'WordVocab':\n", 575 | " with open(vocab_path, \"rb\") as f:\n", 576 | " return pickle.load(f)\n", 577 | "\n", 578 | "\n", 579 | "def build(corpus_path, output_path, vocab_size=None, encoding='utf-8', min_freq=1):\n", 580 | " with open(corpus_path, \"r\", encoding=encoding) as f:\n", 581 | " vocab = WordVocab(f, max_size=vocab_size, min_freq=min_freq)\n", 582 | "\n", 583 | " print(\"VOCAB SIZE:\", len(vocab))\n", 584 | " vocab.save_vocab(output_path)" 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "metadata": {}, 590 | "source": [ 591 | "Dataloaderを定義する.\n", 592 | "ここで文章中の単語をMASKする処理と,隣り合う文章を一定確率でシャッフルする処理を同時に行う" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": null, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "class BERTDataset(Dataset):\n", 602 | " def __init__(self, corpus_path, vocab, seq_len, label_path='None', encoding=\"utf-8\", corpus_lines=None, is_train=True):\n", 603 | " self.vocab = vocab\n", 604 | " self.seq_len = seq_len\n", 605 | " self.is_train = is_train\n", 606 | "\n", 607 | " with open(corpus_path, \"r\", encoding=encoding) as f:\n", 608 | " self.datas = [line[:-1].split(\"\\t\") for line in f]\n", 609 | " if label_path:\n", 610 | " self.labels_data = torch.LongTensor(np.loadtxt(label_path))\n", 611 | " else:\n", 612 | " # ラベル不要の時はダミーデータを埋め込む\n", 613 | " self.labels_data = [0 for _ in range(len(self.datas))]\n", 614 | "\n", 615 | " def __len__(self):\n", 616 | " return len(self.datas)\n", 617 | "\n", 618 | " def __getitem__(self, item):\n", 619 | " t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)\n", 620 | " t1_random, t1_label = self.random_word(t1)\n", 621 | " t2_random, t2_label = self.random_word(t2)\n", 622 | " 
labels = self.labels_data[item]\n", 623 | "\n", 624 | " # [CLS] tag = SOS tag, [SEP] tag = EOS tag\n", 625 | " t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]\n", 626 | " t2 = t2_random + [self.vocab.eos_index]\n", 627 | "\n", 628 | " t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]\n", 629 | " t2_label = t2_label + [self.vocab.pad_index]\n", 630 | "\n", 631 | " segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]\n", 632 | " bert_input = (t1 + t2)[:self.seq_len]\n", 633 | " bert_label = (t1_label + t2_label)[:self.seq_len]\n", 634 | "\n", 635 | " padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]\n", 636 | " bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)\n", 637 | "\n", 638 | " output = {\"bert_input\": bert_input,\n", 639 | " \"bert_label\": bert_label,\n", 640 | " \"segment_label\": segment_label,\n", 641 | " \"is_next\": is_next_label,\n", 642 | " \"labels\": labels}\n", 643 | "\n", 644 | " return {key: torch.tensor(value) for key, value in output.items()}\n", 645 | "\n", 646 | " def random_word(self, sentence):\n", 647 | " tokens = sentence.split()\n", 648 | " output_label = []\n", 649 | "\n", 650 | " for i, token in enumerate(tokens):\n", 651 | " if self.is_train: # Trainingの時は確率的にMASKする\n", 652 | " prob = random.random()\n", 653 | " else: # Predictionの時はMASKをしない\n", 654 | " prob = 1.0\n", 655 | " if prob < 0.15:\n", 656 | " prob /= 0.15\n", 657 | "\n", 658 | " # 80% randomly change token to mask token\n", 659 | " if prob < 0.8:\n", 660 | " tokens[i] = self.vocab.mask_index\n", 661 | "\n", 662 | " # 10% randomly change token to random token\n", 663 | " elif prob < 0.9:\n", 664 | " tokens[i] = random.randrange(len(self.vocab))\n", 665 | "\n", 666 | " # 10% randomly change token to current token\n", 667 | " else:\n", 668 | " tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)\n", 669 | "\n", 670 | " output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))\n", 671 | "\n", 672 | " else:\n", 673 | " tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)\n", 674 | " output_label.append(0)\n", 675 | "\n", 676 | " return tokens, output_label\n", 677 | "\n", 678 | " def random_sent(self, index):\n", 679 | " # output_text, label(isNotNext:0, isNext:1)\n", 680 | " if random.random() > 0.5:\n", 681 | " return self.datas[index][1], 1\n", 682 | " else:\n", 683 | " return self.datas[random.randrange(len(self.datas))][1], 0\n" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": {}, 689 | "source": [ 690 | "Trainerクラスを定義する.\n", 691 | "BERTの事前学習ではふたつの言語モデル学習を行う.\n", 692 | "1. Masked Language Model : 文章中の一部の単語をマスクして,予測を行うタスク.\n", 693 | "2. Next Sentence prediction : ある文章の次に来る文章を予測するタスク." 
694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": null, 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "class BERTTrainer:\n", 703 | " def __init__(self, bert: BERT, vocab_size: int,\n", 704 | " train_dataloader: DataLoader, test_dataloader: DataLoader = None,\n", 705 | " lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,\n", 706 | " with_cuda: bool = True, log_freq: int = 10):\n", 707 | " \"\"\"\n", 708 | " :param bert: BERT model\n", 709 | " :param vocab_size: vocabに含まれるトータルの単語数\n", 710 | " :param train_dataloader: train dataset data loader\n", 711 | " :param test_dataloader: test dataset data loader [can be None]\n", 712 | " :param lr: 学習率\n", 713 | " :param betas: Adam optimizer betas\n", 714 | " :param weight_decay: Adam optimizer weight decay param\n", 715 | " :param with_cuda: traning with cuda\n", 716 | " :param log_freq: logを表示するiterationの頻度\n", 717 | " \"\"\"\n", 718 | "\n", 719 | " # GPU環境において、GPUを指定しているかのフラグ\n", 720 | " cuda_condition = torch.cuda.is_available() and with_cuda\n", 721 | " self.device = torch.device(\"cuda:0\" if cuda_condition else \"cpu\")\n", 722 | "\n", 723 | " self.bert = bert\n", 724 | " self.model = BERTLM(bert, vocab_size).to(self.device)\n", 725 | "\n", 726 | " if torch.cuda.device_count() > 1:\n", 727 | " print(\"Using %d GPUS for BERT\" % torch.cuda.device_count())\n", 728 | " self.model = nn.DataParallel(self.model)\n", 729 | "\n", 730 | " self.train_data = train_dataloader\n", 731 | " self.test_data = test_dataloader\n", 732 | "\n", 733 | " self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)\n", 734 | "\n", 735 | " # masked_token予測のためのLoss関数を設定\n", 736 | " self.criterion = nn.NLLLoss()\n", 737 | " self.log_freq = log_freq\n", 738 | " print(\"Total Parameters:\", sum([p.nelement() for p in self.model.parameters()]))\n", 739 | " \n", 740 | " self.train_lossses = []\n", 741 | " self.train_accs = []\n", 742 | "\n", 743 | " def train(self, epoch):\n", 744 | " self.iteration(epoch, self.train_data)\n", 745 | "\n", 746 | " def test(self, epoch):\n", 747 | " self.iteration(epoch, self.test_data, train=False)\n", 748 | "\n", 749 | " def iteration(self, epoch, data_loader, train=True):\n", 750 | " \"\"\"\n", 751 | " :param epoch: 現在のepoch\n", 752 | " :param data_loader: torch.utils.data.DataLoader\n", 753 | " :param train: trainかtestかのbool値\n", 754 | " \"\"\"\n", 755 | " str_code = \"train\" if train else \"test\"\n", 756 | "\n", 757 | " data_iter = tqdm.tqdm(enumerate(data_loader), desc=\"EP_%s:%d\" % (str_code, epoch), total=len(data_loader), bar_format=\"{l_bar}{r_bar}\")\n", 758 | "\n", 759 | "\n", 760 | " avg_loss = 0.0\n", 761 | " total_correct = 0\n", 762 | " total_element = 0\n", 763 | "\n", 764 | " for i, data in data_iter:\n", 765 | " # 0. batch_dataはGPU or CPUに載せる\n", 766 | " data = {key: value.to(self.device) for key, value in data.items()}\n", 767 | "\n", 768 | " # 1. forward the next_sentence_prediction and masked_lm model\n", 769 | " next_sent_output, mask_lm_output = self.model.forward(data[\"bert_input\"], data[\"segment_label\"])\n", 770 | "\n", 771 | " # 2-1. NLLLoss(negative log likelihood) : next_sentence_predictionのLoss\n", 772 | " next_loss = self.criterion(next_sent_output, data[\"is_next\"])\n", 773 | "\n", 774 | " # 2-2. NLLLoss(negative log likelihood) : predicting masked token word\n", 775 | " mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data[\"bert_label\"])\n", 776 | "\n", 777 | " # 2-3. 
next_lossとmask_lossの合計をlossとする\n", 778 | " loss = next_loss + mask_loss\n", 779 | "\n", 780 | " # 3. training時のみ,backwardとoptimizer更新を行う\n", 781 | " if train:\n", 782 | " self.optim.zero_grad()\n", 783 | " loss.backward()\n", 784 | " self.optim.step()\n", 785 | "\n", 786 | " # next sentence prediction accuracy\n", 787 | " correct = next_sent_output.argmax(dim=-1).eq(data[\"is_next\"]).sum().item()\n", 788 | " avg_loss += loss.item()\n", 789 | " total_correct += correct\n", 790 | " total_element += data[\"is_next\"].nelement()\n", 791 | "\n", 792 | " post_fix = {\n", 793 | " \"epoch\": epoch,\n", 794 | " \"iter\": i,\n", 795 | " \"avg_loss\": avg_loss / (i + 1),\n", 796 | " \"avg_acc\": total_correct / total_element * 100,\n", 797 | " \"loss\": loss.item()\n", 798 | " }\n", 799 | "\n", 800 | " if i % self.log_freq == 0:\n", 801 | " data_iter.write(str(post_fix))\n", 802 | "\n", 803 | " print(\"EP%d_%s, avg_loss=\" % (epoch, str_code), avg_loss / len(data_iter), \"total_acc=\", total_correct * 100.0 / total_element)\n", 804 | " self.train_lossses.append(avg_loss / len(data_iter))\n", 805 | " self.train_accs.append(total_correct * 100.0 / total_element)\n", 806 | " \n", 807 | " def save(self, epoch, file_path=\"output/bert_trained.model\"):\n", 808 | " \"\"\"\n", 809 | " Saving the current BERT model on file_path\n", 810 | "\n", 811 | " :param epoch: current epoch number\n", 812 | " :param file_path: model output path which gonna be file_path+\"ep%d\" % epoch\n", 813 | " :return: final_output_path\n", 814 | " \"\"\"\n", 815 | " output_path = file_path + \".ep%d\" % epoch\n", 816 | " torch.save(self.bert.cpu(), output_path)\n", 817 | " self.bert.to(self.device)\n", 818 | " print(\"EP:%d Model Saved on:\" % epoch, output_path)\n", 819 | " return output_path" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": {}, 826 | "outputs": [], 827 | "source": [ 828 | "import datetime\n", 829 | "dt_now = str(datetime.datetime.now()).replace(' ', '')" 830 | ] 831 | }, 832 | { 833 | "cell_type": "code", 834 | "execution_count": null, 835 | "metadata": {}, 836 | "outputs": [], 837 | "source": [ 838 | "# 訓練用パラメタを定義する\n", 839 | "train_dataset=processed_train_txt\n", 840 | "test_dataset=processed_valid_txt\n", 841 | "vocab_path='./data/vocab'+ dt_now +'.txt'\n", 842 | "output_model_path='./output/bertmodel'+ dt_now\n", 843 | "\n", 844 | "hidden=256 #768\n", 845 | "layers=8 #12\n", 846 | "attn_heads=8 #12\n", 847 | "seq_len=60\n", 848 | "\n", 849 | "batch_size=64\n", 850 | "epochs=10\n", 851 | "num_workers=5\n", 852 | "with_cuda=True\n", 853 | "log_freq=20\n", 854 | "corpus_lines=None\n", 855 | "\n", 856 | "lr=1e-3\n", 857 | "adam_weight_decay=0.00\n", 858 | "adam_beta1=0.9\n", 859 | "adam_beta2=0.999\n", 860 | "\n", 861 | "dropout=0.0\n", 862 | "\n", 863 | "min_freq=7\n", 864 | "\n", 865 | "corpus_path=processed_train_txt\n", 866 | "label_path=None" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "metadata": {}, 873 | "outputs": [], 874 | "source": [ 875 | "build(corpus_path, vocab_path, min_freq=min_freq)\n", 876 | "\n", 877 | "print(\"Loading Vocab\", vocab_path)\n", 878 | "vocab = WordVocab.load_vocab(vocab_path)\n", 879 | "\n", 880 | "print(\"Loading Train Dataset\", train_dataset)\n", 881 | "train_dataset = BERTDataset(train_dataset, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)\n", 882 | "\n", 883 | "print(\"Loading Test Dataset\", test_dataset)\n", 884 | "test_dataset = 
BERTDataset(test_dataset, vocab, seq_len=seq_len, label_path=label_path) if test_dataset is not None else None\n", 885 | "\n", 886 | "print(\"Creating Dataloader\")\n", 887 | "train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)\n", 888 | "test_data_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) if test_dataset is not None else None\n", 889 | "\n", 890 | "print(\"Building BERT model\")\n", 891 | "bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)" 892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": null, 897 | "metadata": {}, 898 | "outputs": [], 899 | "source": [ 900 | "print(\"Creating BERT Trainer\")\n", 901 | "trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,\n", 902 | " lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,\n", 903 | " with_cuda=with_cuda, log_freq=log_freq)" 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": null, 909 | "metadata": {}, 910 | "outputs": [], 911 | "source": [ 912 | "print(\"Training Start\")\n", 913 | "for epoch in range(epochs):\n", 914 | " trainer.train(epoch)\n", 915 | " # Model Save\n", 916 | " trainer.save(epoch, output_model_path)\n", 917 | " trainer.test(epoch)" 918 | ] 919 | } 920 | ], 921 | "metadata": { 922 | "kernelspec": { 923 | "display_name": "Python 3", 924 | "language": "python", 925 | "name": "python3" 926 | }, 927 | "language_info": { 928 | "codemirror_mode": { 929 | "name": "ipython", 930 | "version": 3 931 | }, 932 | "file_extension": ".py", 933 | "mimetype": "text/x-python", 934 | "name": "python", 935 | "nbconvert_exporter": "python", 936 | "pygments_lexer": "ipython3", 937 | "version": "3.6.1" 938 | } 939 | }, 940 | "nbformat": 4, 941 | "nbformat_minor": 2 942 | } 943 | --------------------------------------------------------------------------------
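Neither notebook shows how to reuse what it saves: `build()` pickles the `WordVocab` under `./data`, and `trainer.save()` writes the whole `BERT` module with `torch.save` under `./output`. The snippet below is a minimal feature-extraction sketch under a few assumptions: it is run inside `text_bert.ipynb` after training (so the `BERT` and `WordVocab` class definitions are available for unpickling), MeCab is set up with the same dictionary as in `text_preprocess.ipynb`, the `<timestamp>` placeholders are replaced with the file names the notebook actually produced, and the example sentence is arbitrary.

```python
# Minimal sketch: load the pickled vocab and the saved BERT module, then embed one sentence.
# Assumes it is run inside text_bert.ipynb after training; the paths below are placeholders.
import torch
import MeCab

vocab_path = './data/vocab<timestamp>.txt'          # written by build()
model_path = './output/bertmodel<timestamp>.ep9'    # written by trainer.save() at the final epoch

vocab = WordVocab.load_vocab(vocab_path)            # unpickles the WordVocab
bert = torch.load(model_path, map_location='cpu')   # full module was saved; newer PyTorch may need weights_only=False
bert.eval()

# Tokenize the same way the corpus was tokenized (wakati-gaki with MeCab).
tagger = MeCab.Tagger('-Owakati')
sentence = tagger.parse('吾輩は猫である。').strip()

# Mirror BERTDataset's layout for a single sentence: [SOS] tokens [EOS], segment label 1.
ids = [vocab.sos_index] + vocab.to_seq(sentence) + [vocab.eos_index]
x = torch.tensor([ids])                  # shape (1, seq_len)
segment = torch.ones_like(x)             # every position belongs to sentence A

with torch.no_grad():
    hidden = bert(x, segment)            # shape (1, seq_len, hidden)

sentence_vector = hidden[:, 0]           # position 0 plays the [CLS]/SOS role
print(hidden.shape, sentence_vector.shape)
```

Fine-tuning for a downstream task would follow the same pattern: feed `bert_input` and `segment_label` through the loaded `bert` and attach a task-specific head to `hidden[:, 0]`, just as `NextSentencePrediction` does during pre-training.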