├── .gitignore
├── README.md
├── utils.py
├── text_preprocess.ipynb
└── text_bert.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | /__pycache__
2 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ja_text_bert
2 |
3 | BERT with Japanese text
4 |
      5 | # Reproducing Results
      6 |
      7 | 1. Run `jupyter notebook` or `jupyter lab`
      8 |
      9 | 2. Run `text_preprocess.ipynb` in Jupyter
     10 |
     11 | 3. Run `text_bert.ipynb` in Jupyter
12 |
--------------------------------------------------------------------------------
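
A minimal dependency-check sketch for step 1; the package list is inferred from the imports in the two notebooks (MeCab additionally needs the system mecab library and the mecab-ipadic-NEologd dictionary referenced in `text_preprocess.ipynb`):

```python
# Check that the Python packages used by the notebooks can be imported.
import importlib

for module in ["torch", "MeCab", "bs4", "tqdm", "numpy", "matplotlib", "ipywidgets"]:
    try:
        importlib.import_module(module)
        print("ok      ", module)
    except ImportError:
        print("MISSING ", module)
```
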
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 |
      5 | # Activation function (tanh approximation of GELU)
6 | class GELU(nn.Module):
7 | def forward(self, x):
8 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
9 |
     10 | # Position-wise feed-forward network
11 | class PositionwiseFeedForward(nn.Module):
12 | def __init__(self, d_model, d_ff, dropout=0.1):
13 | super(PositionwiseFeedForward, self).__init__()
14 | self.w_1 = nn.Linear(d_model, d_ff)
15 | self.w_2 = nn.Linear(d_ff, d_model)
16 | self.dropout = nn.Dropout(dropout)
17 | self.activation = GELU()
18 |
19 | def forward(self, x):
20 | return self.w_2(self.dropout(self.activation(self.w_1(x))))
21 |
     22 | # Layer normalization
23 | class LayerNorm(nn.Module):
24 | def __init__(self, features, eps=1e-6):
25 | super(LayerNorm, self).__init__()
26 | self.a_2 = nn.Parameter(torch.ones(features))
27 | self.b_2 = nn.Parameter(torch.zeros(features))
28 | self.eps = eps
29 |
30 | def forward(self, x):
31 | mean = x.mean(-1, keepdim=True)
32 | std = x.std(-1, keepdim=True)
33 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
34 |
35 |
36 | class SublayerConnection(nn.Module):
37 | def __init__(self, size, dropout):
38 | super(SublayerConnection, self).__init__()
39 | self.norm = LayerNorm(size)
40 | self.dropout = nn.Dropout(dropout)
41 |
42 | def forward(self, x, sublayer):
43 | return x + self.dropout(sublayer(self.norm(x)))
--------------------------------------------------------------------------------
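
A minimal smoke-test sketch for the helpers in `utils.py`; the tensor sizes below are arbitrary and only illustrate the expected shapes:

```python
# Wire the utils.py building blocks together on a dummy tensor.
import torch
from utils import GELU, PositionwiseFeedForward, LayerNorm, SublayerConnection

d_model, d_ff = 256, 1024
x = torch.randn(2, 10, d_model)              # (batch, seq_len, d_model)

ffn = PositionwiseFeedForward(d_model, d_ff, dropout=0.1)
block = SublayerConnection(size=d_model, dropout=0.1)

out = block(x, ffn)                          # pre-norm residual: x + dropout(ffn(norm(x)))
assert out.shape == x.shape

print(GELU()(torch.tensor([0.0, 1.0])))            # ~[0.0000, 0.8412]
print(LayerNorm(d_model)(x).mean(-1).abs().max())  # mean over features ~ 0 after normalization
```
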
/text_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "日本語wikiからコーパスを作成するスクリプトです.
\n",
8 | "https://dumps.wikimedia.org/jawiki/latest/
\n",
9 | "こちらのサイトから最新版の\"pages-articles\"のアドレスを手に入れてください.
"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "! wget https://dumps.wikimedia.org/jawiki/latest/jawiki-latest-pages-articles.xml.bz2"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "ダンプデータには不要なマークアップなどが含まれているので、取り除くためのテキストクリーニング用のスクリプトをgitから持ってきます"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "! git clone https://github.com/attardi/wikiextractor.git"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "日本語wikiに対してテキストクリーニングを実行します"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "! python wikiextractor/WikiExtractor.py -o extracted jawiki-latest-pages-articles.xml.bz2"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "テキストに前処理を加えた上で,複数のtxtファイルをひとつに結合します"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "import glob\n",
67 | "from bs4 import BeautifulSoup\n",
68 | "\n",
69 | "with open('./tmp.txt','w') as f:\n",
70 | " for directory in glob.glob('./extracted/*'):\n",
71 | " for name in glob.glob(directory+'/*'):\n",
72 | " with open(name, 'r') as r:\n",
73 | " for line in r:\n",
74 | " # titleを削除する\n",
75 | " if '' in line:\n",
79 | " f.write('\\n')\n",
80 | " continue\n",
81 | " else:\n",
82 | " # 空白・改行削除、大文字を小文字に変換\n",
83 | " text = BeautifulSoup(line.strip()).text.lower()\n",
84 | " f.write(text)"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "ここからはBERTのトレーニング用にテキストファイルを整形していきます.
\n",
92 | "文章を単語ごとに分割し, ひとつの単元の中に偶数個の文章が含まれるように調整します."
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "import linecache\n",
102 | "import random\n",
103 | "import MeCab"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "random.seed(42)\n",
113 | "filename = 'tmp.txt'\n",
114 | "save_file = 'even_rows100M.txt'\n",
115 | "LIMIT_BYTE = 100000000 # 100Mbyte\n",
116 | "# t = MeCab.Tagger('-Owakati') # Neologdを辞書に使っている人場合はそちらを使用するのがベターです\n",
117 | "t = MeCab.Tagger('-d /usr/local/lib/mecab/dic/mecab-ipadic-neologd/ -Owakati')\n",
118 | "\n",
119 | "def get_byte_num(s):\n",
120 | " return len(s.encode('utf-8'))"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "with open(save_file, 'w') as f:\n",
130 | " count_byte = 0\n",
131 | " with open(filename) as r:\n",
132 | " for text in r:\n",
133 | " print('{} bytes'.format(count_byte))\n",
134 | " text = t.parse(text).strip()\n",
135 | " # 一文ごとに分割する\n",
136 | " text = text.split('。')\n",
137 | " # 空白要素は捨てる\n",
138 | " text = [t.strip() for t in text if t]\n",
139 | " # 一単元の文書が偶数個の文章から成るようにする(BERTのデータセットの都合上)\n",
140 | " max_text_len = len(text) // 2\n",
141 | " text = text[:max_text_len * 2]\n",
142 | " text = '\\n'.join(text)\n",
143 | " f.write(text)\n",
144 | " count_byte += get_byte_num(text)\n",
145 | " if count_byte >= LIMIT_BYTE:\n",
146 | " break"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "これでBERTの学習に使うデータセットができました.
\n",
154 | "今度はTraining用とValidation用のデータに分割します."
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "num_lines = sum(1 for line in open(save_file))\n",
164 | "print('Base file lines : ', num_lines)\n",
165 | "# 全体の80%をTraining dataに当てます\n",
166 | "train_lines = int(num_lines * 0.8)\n",
167 | "print('Train file lines : ', train_lines)"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 | "dataは前処理済みテキスト保存場所
\n",
175 | "outputは訓練モデル保存場所として作成"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 1,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "! mkdir -p data output"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "out_file_name_temp = './data/splitted_%d.txt'\n",
194 | "\n",
195 | "split_index = 1\n",
196 | "line_index = 1\n",
197 | "out_file = open(out_file_name_temp % (split_index,), 'w')\n",
198 | "in_file = open(save_file)\n",
199 | "line = in_file.readline()\n",
200 | "while line:\n",
201 | " if line_index > train_lines:\n",
202 | " print('Starting file: %d' % split_index)\n",
203 | " out_file.close()\n",
204 | " split_index = split_index + 1\n",
205 | " line_index = 1\n",
206 | " out_file = open(out_file_name_temp % (split_index,), 'w')\n",
207 | " out_file.write(line)\n",
208 | " line_index = line_index + 1\n",
209 | " line = in_file.readline()\n",
210 | " \n",
211 | "out_file.close()\n",
212 | "in_file.close()"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": null,
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "print('Train file lines : ', sum(1 for line in open('./data/splitted_1.txt')))\n",
222 | "print('Valid file lines : ', sum(1 for line in open('./data/splitted_2.txt')))"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {},
228 | "source": [
229 | "これにてテキストの前処理は完了です!"
230 | ]
231 | }
232 | ],
233 | "metadata": {
234 | "kernelspec": {
235 | "display_name": "Python 3",
236 | "language": "python",
237 | "name": "python3"
238 | },
239 | "language_info": {
240 | "codemirror_mode": {
241 | "name": "ipython",
242 | "version": 3
243 | },
244 | "file_extension": ".py",
245 | "mimetype": "text/x-python",
246 | "name": "python",
247 | "nbconvert_exporter": "python",
248 | "pygments_lexer": "ipython3",
249 | "version": "3.6.1"
250 | }
251 | },
252 | "nbformat": 4,
253 | "nbformat_minor": 2
254 | }
255 |
--------------------------------------------------------------------------------
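
A minimal sketch of the invariant the preprocessing establishes: each article becomes one whitespace-tokenized sentence per line, trimmed to an even count so that `text_bert.ipynb` can later pair consecutive lines for Next Sentence Prediction. The input string below is assumed to be already tokenized, so MeCab is not required to run it:

```python
# Sentence splitting and even-count trimming, applied per article.
article = "今日 は 晴れ です 。 散歩 に 行き ます 。 楽しい です 。"

sentences = [s.strip() for s in article.split('。') if s.strip()]
sentences = sentences[:(len(sentences) // 2) * 2]   # drop a trailing odd sentence

print('\n'.join(sentences))   # two lines -> one consecutive sentence pair for NSP
```
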
/text_bert.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "必要ModuleをImport"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import torch\n",
17 | "import torch.nn as nn\n",
18 | "import torch.nn.functional as F\n",
19 | "from torch.optim import Adam\n",
20 | "from torch.utils.data import DataLoader\n",
21 | "import math\n",
22 | "\n",
23 | "import pickle\n",
24 | "import tqdm\n",
25 | "from collections import Counter\n",
26 | "\n",
27 | "from torch.utils.data import Dataset\n",
28 | "import random\n",
29 | "import numpy as np\n",
30 | "\n",
31 | "from utils import GELU, PositionwiseFeedForward, LayerNorm, SublayerConnection, LayerNorm\n",
32 | "\n",
33 | "import matplotlib\n",
34 | "%matplotlib inline\n",
35 | "import matplotlib.pyplot as plt\n",
36 | "from ipywidgets import FloatProgress\n",
37 | "from IPython.display import display, clear_output"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "input_train_txt = './data/splitted_1.txt'\n",
47 | "input_valid_txt = './data/splitted_2.txt'\n",
48 | "processed_train_txt = './data/train_X.txt'\n",
49 | "processed_valid_txt = './data/valid_X.txt'"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "Next Sentence Predictionのために, 意味的に連続する文章をtab区切りで並べる前処理をデータセットに対して行います."
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# 偶数行の文章を奇数行の文章と接続するメソッド\n",
66 | "def load_data(path):\n",
67 | " with open(path, encoding='utf-8') as f:\n",
68 | " even_rows = []\n",
69 | " odd_rows = []\n",
70 | " all_f = f.readlines()\n",
71 | " for row in all_f[2::2]:\n",
72 | " even_rows.append(row.strip().replace('\\n', ''))\n",
73 | " for row in all_f[1::2]:\n",
74 | " odd_rows.append(row.strip().replace('\\n', ''))\n",
75 | " min_rows_len = int(min(len(even_rows), len(odd_rows)))\n",
76 | " even_rows = even_rows[:min_rows_len]\n",
77 | " odd_rows = odd_rows[:min_rows_len]\n",
78 | "\n",
79 | " concat_rows = []\n",
80 | " for even_r, odd_r in zip(even_rows, odd_rows):\n",
81 | " concat_r = '\\t'.join([even_r, odd_r])\n",
82 | " concat_rows.append(concat_r)\n",
83 | " return concat_rows"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "train_data = load_data(input_train_txt)\n",
93 | "valid_data = load_data(input_valid_txt)\n",
94 | "\n",
95 | "# ランダムに並び替える\n",
96 | "random.shuffle(train_data)\n",
97 | "random.shuffle(valid_data)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "with open(processed_train_txt, 'w') as f:\n",
107 | " f.write('\\n'.join(train_data))"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "with open(processed_valid_txt, 'w') as f:\n",
117 | " f.write('\\n'.join(valid_data))"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "Attentionセルを定義する"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "class Attention(nn.Module):\n",
134 | " \"\"\"\n",
135 | " Scaled Dot Product Attention\n",
136 | " \"\"\"\n",
137 | "\n",
138 | " def forward(self, query, key, value, mask=None, dropout=None):\n",
139 | " scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))\n",
140 | "\n",
141 | " if mask is not None:\n",
142 | " scores = scores.masked_fill(mask == 0, -1e9)\n",
143 | "\n",
144 | " p_attn = F.softmax(scores, dim=-1)\n",
145 | "\n",
146 | " if dropout is not None:\n",
147 | " p_attn = dropout(p_attn)\n",
148 | "\n",
149 | " return torch.matmul(p_attn, value), p_attn\n"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "Multi Head Attentionを定義する"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "class MultiHeadedAttention(nn.Module):\n",
166 | "\n",
167 | " def __init__(self, h, d_model, dropout=0.1):\n",
168 | " super().__init__()\n",
169 | " assert d_model % h == 0\n",
170 | "\n",
171 | " # We assume d_v always equals d_k\n",
172 | " self.d_k = d_model // h\n",
173 | " self.h = h\n",
174 | "\n",
175 | " self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])\n",
176 | " self.output_linear = nn.Linear(d_model, d_model)\n",
177 | " self.attention = Attention()\n",
178 | "\n",
179 | " self.dropout = nn.Dropout(p=dropout)\n",
180 | "\n",
181 | " def forward(self, query, key, value, mask=None):\n",
182 | " batch_size = query.size(0)\n",
183 | "\n",
184 | " query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]\n",
185 | "\n",
186 | " x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)\n",
187 | "\n",
188 | " x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)\n",
189 | "\n",
190 | " return self.output_linear(x)\n"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {},
196 | "source": [
197 | "Transformerを定義する"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "class TransformerBlock(nn.Module):\n",
207 | " \"\"\"\n",
208 | " Bidirectional Encoder = Transformer (self-attention)\n",
209 | " Transformer = MultiHead_Attention + Feed_Forward with sublayer connection\n",
210 | " \"\"\"\n",
211 | "\n",
212 | " def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):\n",
213 | " \"\"\"\n",
214 | " :param hidden: hidden size of transformer\n",
215 | " :param attn_heads: head sizes of multi-head attention\n",
216 | " :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size\n",
217 | " :param dropout: dropout rate\n",
218 | " \"\"\"\n",
219 | "\n",
220 | " super().__init__()\n",
221 | " self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)\n",
222 | " self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)\n",
223 | " self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n",
224 | " self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n",
225 | " self.dropout = nn.Dropout(p=dropout)\n",
226 | "\n",
227 | " def forward(self, x, mask):\n",
228 | " x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))\n",
229 | " x = self.output_sublayer(x, self.feed_forward)\n",
230 | " return self.dropout(x)\n"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {},
236 | "source": [
237 | "BERTクラスを定義する"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "class BERT(nn.Module):\n",
247 | "\n",
248 | " def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):\n",
249 | " \"\"\"\n",
250 | " :param vocab_size: vocab_size of total words\n",
251 | " :param hidden: BERT model hidden size\n",
252 | " :param n_layers: numbers of Transformer blocks(layers)\n",
253 | " :param attn_heads: number of attention heads\n",
254 | " :param dropout: dropout rate\n",
255 | " \"\"\"\n",
256 | "\n",
257 | " super().__init__()\n",
258 | " self.hidden = hidden\n",
259 | " self.n_layers = n_layers\n",
260 | " self.attn_heads = attn_heads\n",
261 | "\n",
262 | " self.feed_forward_hidden = hidden * 4\n",
263 | "\n",
264 | " # embedding for BERT\n",
265 | " self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, dropout=dropout)\n",
266 | "\n",
267 | " self.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])\n",
268 | "\n",
269 | " def forward(self, x, segment_info):\n",
270 | " # xの中で0以上は1, 0未満は0として, maskテンソルを作る\n",
271 | " mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n",
272 | "\n",
273 | " x = self.embedding(x, segment_info)\n",
274 | "\n",
275 | " for transformer in self.transformer_blocks:\n",
276 | " x = transformer.forward(x, mask)\n",
277 | " return x\n"
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {},
283 | "source": [
284 | "BERTのEmbedding層を定義する"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": null,
290 | "metadata": {},
291 | "outputs": [],
292 | "source": [
293 | "class TokenEmbedding(nn.Embedding):\n",
294 | " def __init__(self, vocab_size, embed_size=512):\n",
295 | " super().__init__(vocab_size, embed_size, padding_idx=0)\n",
296 | "\n",
297 | "class PositionalEmbedding(nn.Module):\n",
298 | "\n",
299 | " def __init__(self, d_model, max_len=512):\n",
300 | " super().__init__()\n",
301 | "\n",
302 | " pe = torch.zeros(max_len, d_model).float()\n",
303 | " pe.require_grad = False\n",
304 | "\n",
305 | " position = torch.arange(0, max_len).float().unsqueeze(1)\n",
306 | " div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()\n",
307 | "\n",
308 | " pe[:, 0::2] = torch.sin(position * div_term)\n",
309 | " pe[:, 1::2] = torch.cos(position * div_term)\n",
310 | "\n",
311 | " pe = pe.unsqueeze(0)\n",
312 | " self.register_buffer('pe', pe)\n",
313 | "\n",
314 | " def forward(self, x):\n",
315 | " return self.pe[:, :x.size(1)]\n",
316 | "\n",
317 | "class SegmentEmbedding(nn.Embedding):\n",
318 | " def __init__(self, embed_size=512):\n",
319 | " super().__init__(3, embed_size, padding_idx=0)\n",
320 | "\n",
321 | "class BERTEmbedding(nn.Module):\n",
322 | " \"\"\"\n",
323 | " BERT Embedding which is consisted with under features\n",
324 | " 1. TokenEmbedding : 通常のEMbedding\n",
325 | " 2. PositionalEmbedding : sin, cosを用いた位置情報付きEmbedding\n",
326 | " 2. SegmentEmbedding : Sentenceのセグメント情報 (sent_A:1, sent_B:2)\n",
327 | " \"\"\"\n",
328 | " def __init__(self, vocab_size, embed_size, dropout=0.1):\n",
329 | " super().__init__()\n",
330 | " self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)\n",
331 | " self.position = PositionalEmbedding(d_model=self.token.embedding_dim)\n",
332 | " self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)\n",
333 | " self.dropout = nn.Dropout(p=dropout)\n",
334 | " self.embed_size = embed_size\n",
335 | "\n",
336 | " def forward(self, sequence, segment_label):\n",
337 | " x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)\n",
338 | " return self.dropout(x)\n"
339 | ]
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "metadata": {},
344 | "source": [
345 | "学習用にマスク予測・隣接文予測の層を追加する"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "metadata": {},
352 | "outputs": [],
353 | "source": [
354 | "class BERTLM(nn.Module):\n",
355 | " \"\"\"\n",
356 | " BERT Language Model\n",
357 | " Next Sentence Prediction Model + Masked Language Model\n",
358 | " \"\"\"\n",
359 | "\n",
360 | " def __init__(self, bert: BERT, vocab_size):\n",
361 | " \"\"\"\n",
362 | " :param bert: BERT model which should be trained\n",
363 | " :param vocab_size: total vocab size for masked_lm\n",
364 | " \"\"\"\n",
365 | "\n",
366 | " super().__init__()\n",
367 | " self.bert = bert\n",
368 | " self.next_sentence = NextSentencePrediction(self.bert.hidden)\n",
369 | " self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)\n",
370 | "\n",
371 | " def forward(self, x, segment_label):\n",
372 | " x = self.bert(x, segment_label)\n",
373 | " return self.next_sentence(x), self.mask_lm(x)\n",
374 | "\n",
375 | "\n",
376 | "class NextSentencePrediction(nn.Module):\n",
377 | " \"\"\"\n",
378 | " 2クラス分類問題 : is_next, is_not_next\n",
379 | " \"\"\"\n",
380 | "\n",
381 | " def __init__(self, hidden):\n",
382 | " \"\"\"\n",
383 | " :param hidden: BERT model output size\n",
384 | " \"\"\"\n",
385 | " super().__init__()\n",
386 | " self.linear = nn.Linear(hidden, 2)\n",
387 | " self.softmax = nn.LogSoftmax(dim=-1)\n",
388 | "\n",
389 | " def forward(self, x):\n",
390 | " return self.softmax(self.linear(x[:, 0]))\n",
391 | "\n",
392 | "\n",
393 | "class MaskedLanguageModel(nn.Module):\n",
394 | " \"\"\"\n",
395 | " 入力系列のMASKトークンから元の単語を予測する\n",
396 | " nクラス分類問題, nクラス : vocab_size\n",
397 | " \"\"\"\n",
398 | "\n",
399 | " def __init__(self, hidden, vocab_size):\n",
400 | " \"\"\"\n",
401 | " :param hidden: output size of BERT model\n",
402 | " :param vocab_size: total vocab size\n",
403 | " \"\"\"\n",
404 | " super().__init__()\n",
405 | " self.linear = nn.Linear(hidden, vocab_size)\n",
406 | " self.softmax = nn.LogSoftmax(dim=-1)\n",
407 | "\n",
408 | " def forward(self, x):\n",
409 | " return self.softmax(self.linear(x))"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {},
415 | "source": [
416 | "BERT用のVocabを生成するクラスを定義する"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "import pickle\n",
426 | "import tqdm\n",
427 | "from collections import Counter\n",
428 | "\n",
429 | "\n",
430 | "class TorchVocab(object):\n",
431 | " \"\"\"\n",
432 | " :property freqs: collections.Counter, コーパス中の単語の出現頻度を保持するオブジェクト\n",
433 | " :property stoi: collections.defaultdict, string → id の対応を示す辞書\n",
434 | " :property itos: collections.defaultdict, id → string の対応を示す辞書\n",
435 | " \"\"\"\n",
436 | " def __init__(self, counter, max_size=None, min_freq=1, specials=['', ''],\n",
437 | " vectors=None, unk_init=None, vectors_cache=None):\n",
438 | " \"\"\"\n",
439 | " :param coutenr: collections.Counter, データ中に含まれる単語の頻度を計測するためのcounter\n",
440 | " :param max_size: int, vocabularyの最大のサイズ. Noneの場合は最大値なし. defaultはNone\n",
441 | " :param min_freq: int, vocabulary中の単語の最低出現頻度. この数以下の出現回数の単語はvocabularyに加えられない.\n",
442 | " :param specials: list of str, vocabularyにあらかじめ登録するtoken\n",
443 | " :param vecors: list of vectors, 事前学習済みのベクトル. ex)Vocab.load_vectors\n",
444 | " \"\"\"\n",
445 | " self.freqs = counter\n",
446 | " counter = counter.copy()\n",
447 | " min_freq = max(min_freq, 1)\n",
448 | "\n",
449 | " self.itos = list(specials)\n",
450 | " # special tokensの出現頻度はvocabulary作成の際にカウントされない\n",
451 | " for tok in specials:\n",
452 | " del counter[tok]\n",
453 | "\n",
454 | " max_size = None if max_size is None else max_size + len(self.itos)\n",
455 | "\n",
456 | " # まず頻度でソートし、次に文字順で並び替える\n",
457 | " words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])\n",
458 | " words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)\n",
459 | " \n",
460 | " # 出現頻度がmin_freq未満のものはvocabに加えない\n",
461 | " for word, freq in words_and_frequencies:\n",
462 | " if freq < min_freq or len(self.itos) == max_size:\n",
463 | " break\n",
464 | " self.itos.append(word)\n",
465 | "\n",
466 | " # dictのk,vをいれかえてstoiを作成する\n",
467 | " self.stoi = {tok: i for i, tok in enumerate(self.itos)}\n",
468 | "\n",
469 | " self.vectors = None\n",
470 | " if vectors is not None:\n",
471 | " self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)\n",
472 | " else:\n",
473 | " assert unk_init is None and vectors_cache is None\n",
474 | "\n",
475 | " def __eq__(self, other):\n",
476 | " if self.freqs != other.freqs:\n",
477 | " return False\n",
478 | " if self.stoi != other.stoi:\n",
479 | " return False\n",
480 | " if self.itos != other.itos:\n",
481 | " return False\n",
482 | " if self.vectors != other.vectors:\n",
483 | " return False\n",
484 | " return True\n",
485 | "\n",
486 | " def __len__(self):\n",
487 | " return len(self.itos)\n",
488 | "\n",
489 | " def vocab_rerank(self):\n",
490 | " self.stoi = {word: i for i, word in enumerate(self.itos)}\n",
491 | "\n",
492 | " def extend(self, v, sort=False):\n",
493 | " words = sorted(v.itos) if sort else v.itos\n",
494 | " for w in words:\n",
495 | " if w not in self.stoi:\n",
496 | " self.itos.append(w)\n",
497 | " self.stoi[w] = len(self.itos) - 1\n",
498 | "\n",
499 | "\n",
500 | "class Vocab(TorchVocab):\n",
501 | " def __init__(self, counter, max_size=None, min_freq=1):\n",
502 | " self.pad_index = 0\n",
503 | " self.unk_index = 1\n",
504 | " self.eos_index = 2\n",
505 | " self.sos_index = 3\n",
506 | " self.mask_index = 4\n",
507 | " super().__init__(counter, specials=[\"\", \"\", \"\", \"\", \"\"], max_size=max_size, min_freq=min_freq)\n",
508 | "\n",
509 | " # override用\n",
510 | " def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:\n",
511 | " pass\n",
512 | "\n",
513 | " # override用\n",
514 | " def from_seq(self, seq, join=False, with_pad=False):\n",
515 | " pass\n",
516 | "\n",
517 | " @staticmethod\n",
518 | " def load_vocab(vocab_path: str) -> 'Vocab':\n",
519 | " with open(vocab_path, \"rb\") as f:\n",
520 | " return pickle.load(f)\n",
521 | "\n",
522 | " def save_vocab(self, vocab_path):\n",
523 | " with open(vocab_path, \"wb\") as f:\n",
524 | " pickle.dump(self, f)\n",
525 | "\n",
526 | "\n",
527 | "# テキストファイルからvocabを作成する\n",
528 | "class WordVocab(Vocab):\n",
529 | " def __init__(self, texts, max_size=None, min_freq=1):\n",
530 | " print(\"Building Vocab\")\n",
531 | " counter = Counter()\n",
532 | " for line in texts:\n",
533 | " if isinstance(line, list):\n",
534 | " words = line\n",
535 | " else:\n",
536 | " words = line.replace(\"\\n\", \"\").replace(\"\\t\", \"\").split()\n",
537 | "\n",
538 | " for word in words:\n",
539 | " counter[word] += 1\n",
540 | " super().__init__(counter, max_size=max_size, min_freq=min_freq)\n",
541 | "\n",
542 | " def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):\n",
543 | " if isinstance(sentence, str):\n",
544 | " sentence = sentence.split()\n",
545 | "\n",
546 | " seq = [self.stoi.get(word, self.unk_index) for word in sentence]\n",
547 | "\n",
548 | " if with_eos:\n",
549 | " seq += [self.eos_index] # this would be index 1\n",
550 | " if with_sos:\n",
551 | " seq = [self.sos_index] + seq\n",
552 | "\n",
553 | " origin_seq_len = len(seq)\n",
554 | "\n",
555 | " if seq_len is None:\n",
556 | " pass\n",
557 | " elif len(seq) <= seq_len:\n",
558 | " seq += [self.pad_index for _ in range(seq_len - len(seq))]\n",
559 | " else:\n",
560 | " seq = seq[:seq_len]\n",
561 | "\n",
562 | " return (seq, origin_seq_len) if with_len else seq\n",
563 | "\n",
564 | " def from_seq(self, seq, join=False, with_pad=False):\n",
565 | " words = [self.itos[idx]\n",
566 | " if idx < len(self.itos)\n",
567 | " else \"<%d>\" % idx\n",
568 | " for idx in seq\n",
569 | " if not with_pad or idx != self.pad_index]\n",
570 | "\n",
571 | " return \" \".join(words) if join else words\n",
572 | "\n",
573 | " @staticmethod\n",
574 | " def load_vocab(vocab_path: str) -> 'WordVocab':\n",
575 | " with open(vocab_path, \"rb\") as f:\n",
576 | " return pickle.load(f)\n",
577 | "\n",
578 | "\n",
579 | "def build(corpus_path, output_path, vocab_size=None, encoding='utf-8', min_freq=1):\n",
580 | " with open(corpus_path, \"r\", encoding=encoding) as f:\n",
581 | " vocab = WordVocab(f, max_size=vocab_size, min_freq=min_freq)\n",
582 | "\n",
583 | " print(\"VOCAB SIZE:\", len(vocab))\n",
584 | " vocab.save_vocab(output_path)"
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "metadata": {},
590 | "source": [
591 | "Dataloaderを定義する.\n",
592 | "ここで文章中の単語をMASKする処理と,隣り合う文章を一定確率でシャッフルする処理を同時に行う"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": null,
598 | "metadata": {},
599 | "outputs": [],
600 | "source": [
601 | "class BERTDataset(Dataset):\n",
602 | " def __init__(self, corpus_path, vocab, seq_len, label_path='None', encoding=\"utf-8\", corpus_lines=None, is_train=True):\n",
603 | " self.vocab = vocab\n",
604 | " self.seq_len = seq_len\n",
605 | " self.is_train = is_train\n",
606 | "\n",
607 | " with open(corpus_path, \"r\", encoding=encoding) as f:\n",
608 | " self.datas = [line[:-1].split(\"\\t\") for line in f]\n",
609 | " if label_path:\n",
610 | " self.labels_data = torch.LongTensor(np.loadtxt(label_path))\n",
611 | " else:\n",
612 | " # ラベル不要の時はダミーデータを埋め込む\n",
613 | " self.labels_data = [0 for _ in range(len(self.datas))]\n",
614 | "\n",
615 | " def __len__(self):\n",
616 | " return len(self.datas)\n",
617 | "\n",
618 | " def __getitem__(self, item):\n",
619 | " t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)\n",
620 | " t1_random, t1_label = self.random_word(t1)\n",
621 | " t2_random, t2_label = self.random_word(t2)\n",
622 | " labels = self.labels_data[item]\n",
623 | "\n",
624 | " # [CLS] tag = SOS tag, [SEP] tag = EOS tag\n",
625 | " t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]\n",
626 | " t2 = t2_random + [self.vocab.eos_index]\n",
627 | "\n",
628 | " t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]\n",
629 | " t2_label = t2_label + [self.vocab.pad_index]\n",
630 | "\n",
631 | " segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]\n",
632 | " bert_input = (t1 + t2)[:self.seq_len]\n",
633 | " bert_label = (t1_label + t2_label)[:self.seq_len]\n",
634 | "\n",
635 | " padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]\n",
636 | " bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)\n",
637 | "\n",
638 | " output = {\"bert_input\": bert_input,\n",
639 | " \"bert_label\": bert_label,\n",
640 | " \"segment_label\": segment_label,\n",
641 | " \"is_next\": is_next_label,\n",
642 | " \"labels\": labels}\n",
643 | "\n",
644 | " return {key: torch.tensor(value) for key, value in output.items()}\n",
645 | "\n",
646 | " def random_word(self, sentence):\n",
647 | " tokens = sentence.split()\n",
648 | " output_label = []\n",
649 | "\n",
650 | " for i, token in enumerate(tokens):\n",
651 | " if self.is_train: # Trainingの時は確率的にMASKする\n",
652 | " prob = random.random()\n",
653 | " else: # Predictionの時はMASKをしない\n",
654 | " prob = 1.0\n",
655 | " if prob < 0.15:\n",
656 | " prob /= 0.15\n",
657 | "\n",
658 | " # 80% randomly change token to mask token\n",
659 | " if prob < 0.8:\n",
660 | " tokens[i] = self.vocab.mask_index\n",
661 | "\n",
662 | " # 10% randomly change token to random token\n",
663 | " elif prob < 0.9:\n",
664 | " tokens[i] = random.randrange(len(self.vocab))\n",
665 | "\n",
666 | " # 10% randomly change token to current token\n",
667 | " else:\n",
668 | " tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)\n",
669 | "\n",
670 | " output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))\n",
671 | "\n",
672 | " else:\n",
673 | " tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)\n",
674 | " output_label.append(0)\n",
675 | "\n",
676 | " return tokens, output_label\n",
677 | "\n",
678 | " def random_sent(self, index):\n",
679 | " # output_text, label(isNotNext:0, isNext:1)\n",
680 | " if random.random() > 0.5:\n",
681 | " return self.datas[index][1], 1\n",
682 | " else:\n",
683 | " return self.datas[random.randrange(len(self.datas))][1], 0\n"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "metadata": {},
689 | "source": [
690 | "Trainerクラスを定義する.\n",
691 | "BERTの事前学習ではふたつの言語モデル学習を行う.\n",
692 | "1. Masked Language Model : 文章中の一部の単語をマスクして,予測を行うタスク.\n",
693 | "2. Next Sentence prediction : ある文章の次に来る文章を予測するタスク."
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": null,
699 | "metadata": {},
700 | "outputs": [],
701 | "source": [
702 | "class BERTTrainer:\n",
703 | " def __init__(self, bert: BERT, vocab_size: int,\n",
704 | " train_dataloader: DataLoader, test_dataloader: DataLoader = None,\n",
705 | " lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,\n",
706 | " with_cuda: bool = True, log_freq: int = 10):\n",
707 | " \"\"\"\n",
708 | " :param bert: BERT model\n",
709 | " :param vocab_size: vocabに含まれるトータルの単語数\n",
710 | " :param train_dataloader: train dataset data loader\n",
711 | " :param test_dataloader: test dataset data loader [can be None]\n",
712 | " :param lr: 学習率\n",
713 | " :param betas: Adam optimizer betas\n",
714 | " :param weight_decay: Adam optimizer weight decay param\n",
715 | " :param with_cuda: traning with cuda\n",
716 | " :param log_freq: logを表示するiterationの頻度\n",
717 | " \"\"\"\n",
718 | "\n",
719 | " # GPU環境において、GPUを指定しているかのフラグ\n",
720 | " cuda_condition = torch.cuda.is_available() and with_cuda\n",
721 | " self.device = torch.device(\"cuda:0\" if cuda_condition else \"cpu\")\n",
722 | "\n",
723 | " self.bert = bert\n",
724 | " self.model = BERTLM(bert, vocab_size).to(self.device)\n",
725 | "\n",
726 | " if torch.cuda.device_count() > 1:\n",
727 | " print(\"Using %d GPUS for BERT\" % torch.cuda.device_count())\n",
728 | " self.model = nn.DataParallel(self.model)\n",
729 | "\n",
730 | " self.train_data = train_dataloader\n",
731 | " self.test_data = test_dataloader\n",
732 | "\n",
733 | " self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)\n",
734 | "\n",
735 | " # masked_token予測のためのLoss関数を設定\n",
736 | " self.criterion = nn.NLLLoss()\n",
737 | " self.log_freq = log_freq\n",
738 | " print(\"Total Parameters:\", sum([p.nelement() for p in self.model.parameters()]))\n",
739 | " \n",
740 | " self.train_lossses = []\n",
741 | " self.train_accs = []\n",
742 | "\n",
743 | " def train(self, epoch):\n",
744 | " self.iteration(epoch, self.train_data)\n",
745 | "\n",
746 | " def test(self, epoch):\n",
747 | " self.iteration(epoch, self.test_data, train=False)\n",
748 | "\n",
749 | " def iteration(self, epoch, data_loader, train=True):\n",
750 | " \"\"\"\n",
751 | " :param epoch: 現在のepoch\n",
752 | " :param data_loader: torch.utils.data.DataLoader\n",
753 | " :param train: trainかtestかのbool値\n",
754 | " \"\"\"\n",
755 | " str_code = \"train\" if train else \"test\"\n",
756 | "\n",
757 | " data_iter = tqdm.tqdm(enumerate(data_loader), desc=\"EP_%s:%d\" % (str_code, epoch), total=len(data_loader), bar_format=\"{l_bar}{r_bar}\")\n",
758 | "\n",
759 | "\n",
760 | " avg_loss = 0.0\n",
761 | " total_correct = 0\n",
762 | " total_element = 0\n",
763 | "\n",
764 | " for i, data in data_iter:\n",
765 | " # 0. batch_dataはGPU or CPUに載せる\n",
766 | " data = {key: value.to(self.device) for key, value in data.items()}\n",
767 | "\n",
768 | " # 1. forward the next_sentence_prediction and masked_lm model\n",
769 | " next_sent_output, mask_lm_output = self.model.forward(data[\"bert_input\"], data[\"segment_label\"])\n",
770 | "\n",
771 | " # 2-1. NLLLoss(negative log likelihood) : next_sentence_predictionのLoss\n",
772 | " next_loss = self.criterion(next_sent_output, data[\"is_next\"])\n",
773 | "\n",
774 | " # 2-2. NLLLoss(negative log likelihood) : predicting masked token word\n",
775 | " mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data[\"bert_label\"])\n",
776 | "\n",
777 | " # 2-3. next_lossとmask_lossの合計をlossとする\n",
778 | " loss = next_loss + mask_loss\n",
779 | "\n",
780 | " # 3. training時のみ,backwardとoptimizer更新を行う\n",
781 | " if train:\n",
782 | " self.optim.zero_grad()\n",
783 | " loss.backward()\n",
784 | " self.optim.step()\n",
785 | "\n",
786 | " # next sentence prediction accuracy\n",
787 | " correct = next_sent_output.argmax(dim=-1).eq(data[\"is_next\"]).sum().item()\n",
788 | " avg_loss += loss.item()\n",
789 | " total_correct += correct\n",
790 | " total_element += data[\"is_next\"].nelement()\n",
791 | "\n",
792 | " post_fix = {\n",
793 | " \"epoch\": epoch,\n",
794 | " \"iter\": i,\n",
795 | " \"avg_loss\": avg_loss / (i + 1),\n",
796 | " \"avg_acc\": total_correct / total_element * 100,\n",
797 | " \"loss\": loss.item()\n",
798 | " }\n",
799 | "\n",
800 | " if i % self.log_freq == 0:\n",
801 | " data_iter.write(str(post_fix))\n",
802 | "\n",
803 | " print(\"EP%d_%s, avg_loss=\" % (epoch, str_code), avg_loss / len(data_iter), \"total_acc=\", total_correct * 100.0 / total_element)\n",
804 | " self.train_lossses.append(avg_loss / len(data_iter))\n",
805 | " self.train_accs.append(total_correct * 100.0 / total_element)\n",
806 | " \n",
807 | " def save(self, epoch, file_path=\"output/bert_trained.model\"):\n",
808 | " \"\"\"\n",
809 | " Saving the current BERT model on file_path\n",
810 | "\n",
811 | " :param epoch: current epoch number\n",
812 | " :param file_path: model output path which gonna be file_path+\"ep%d\" % epoch\n",
813 | " :return: final_output_path\n",
814 | " \"\"\"\n",
815 | " output_path = file_path + \".ep%d\" % epoch\n",
816 | " torch.save(self.bert.cpu(), output_path)\n",
817 | " self.bert.to(self.device)\n",
818 | " print(\"EP:%d Model Saved on:\" % epoch, output_path)\n",
819 | " return output_path"
820 | ]
821 | },
822 | {
823 | "cell_type": "code",
824 | "execution_count": null,
825 | "metadata": {},
826 | "outputs": [],
827 | "source": [
828 | "import datetime\n",
829 | "dt_now = str(datetime.datetime.now()).replace(' ', '')"
830 | ]
831 | },
832 | {
833 | "cell_type": "code",
834 | "execution_count": null,
835 | "metadata": {},
836 | "outputs": [],
837 | "source": [
838 | "# 訓練用パラメタを定義する\n",
839 | "train_dataset=processed_train_txt\n",
840 | "test_dataset=processed_valid_txt\n",
841 | "vocab_path='./data/vocab'+ dt_now +'.txt'\n",
842 | "output_model_path='./output/bertmodel'+ dt_now\n",
843 | "\n",
844 | "hidden=256 #768\n",
845 | "layers=8 #12\n",
846 | "attn_heads=8 #12\n",
847 | "seq_len=60\n",
848 | "\n",
849 | "batch_size=64\n",
850 | "epochs=10\n",
851 | "num_workers=5\n",
852 | "with_cuda=True\n",
853 | "log_freq=20\n",
854 | "corpus_lines=None\n",
855 | "\n",
856 | "lr=1e-3\n",
857 | "adam_weight_decay=0.00\n",
858 | "adam_beta1=0.9\n",
859 | "adam_beta2=0.999\n",
860 | "\n",
861 | "dropout=0.0\n",
862 | "\n",
863 | "min_freq=7\n",
864 | "\n",
865 | "corpus_path=processed_train_txt\n",
866 | "label_path=None"
867 | ]
868 | },
869 | {
870 | "cell_type": "code",
871 | "execution_count": null,
872 | "metadata": {},
873 | "outputs": [],
874 | "source": [
875 | "build(corpus_path, vocab_path, min_freq=min_freq)\n",
876 | "\n",
877 | "print(\"Loading Vocab\", vocab_path)\n",
878 | "vocab = WordVocab.load_vocab(vocab_path)\n",
879 | "\n",
880 | "print(\"Loading Train Dataset\", train_dataset)\n",
881 | "train_dataset = BERTDataset(train_dataset, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)\n",
882 | "\n",
883 | "print(\"Loading Test Dataset\", test_dataset)\n",
884 | "test_dataset = BERTDataset(test_dataset, vocab, seq_len=seq_len, label_path=label_path) if test_dataset is not None else None\n",
885 | "\n",
886 | "print(\"Creating Dataloader\")\n",
887 | "train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)\n",
888 | "test_data_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) if test_dataset is not None else None\n",
889 | "\n",
890 | "print(\"Building BERT model\")\n",
891 | "bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)"
892 | ]
893 | },
894 | {
895 | "cell_type": "code",
896 | "execution_count": null,
897 | "metadata": {},
898 | "outputs": [],
899 | "source": [
900 | "print(\"Creating BERT Trainer\")\n",
901 | "trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,\n",
902 | " lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,\n",
903 | " with_cuda=with_cuda, log_freq=log_freq)"
904 | ]
905 | },
906 | {
907 | "cell_type": "code",
908 | "execution_count": null,
909 | "metadata": {},
910 | "outputs": [],
911 | "source": [
912 | "print(\"Training Start\")\n",
913 | "for epoch in range(epochs):\n",
914 | " trainer.train(epoch)\n",
915 | " # Model Save\n",
916 | " trainer.save(epoch, output_model_path)\n",
917 | " trainer.test(epoch)"
918 | ]
919 | }
920 | ],
921 | "metadata": {
922 | "kernelspec": {
923 | "display_name": "Python 3",
924 | "language": "python",
925 | "name": "python3"
926 | },
927 | "language_info": {
928 | "codemirror_mode": {
929 | "name": "ipython",
930 | "version": 3
931 | },
932 | "file_extension": ".py",
933 | "mimetype": "text/x-python",
934 | "name": "python",
935 | "nbconvert_exporter": "python",
936 | "pygments_lexer": "ipython3",
937 | "version": "3.6.1"
938 | }
939 | },
940 | "nbformat": 4,
941 | "nbformat_minor": 2
942 | }
943 |
--------------------------------------------------------------------------------
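
`BERTTrainer.save()` stores the whole `BERT` module with `torch.save`, so a trained checkpoint can be reloaded for feature extraction. A minimal sketch, assuming the `BERT` and `WordVocab` classes are importable (for example, run inside `text_bert.ipynb` after the class cells); the paths and the example sentence are placeholders:

```python
# Reload a trained checkpoint and encode one wakati-tokenized sentence.
import torch

vocab = WordVocab.load_vocab('./data/vocab<timestamp>.txt')   # hypothetical path
bert = torch.load('./output/bertmodel<timestamp>.ep9')        # hypothetical path
bert.eval()

tokens = '今日 は 晴れ です'.split()                            # MeCab -Owakati style input
seq = [vocab.sos_index] + vocab.to_seq(tokens) + [vocab.eos_index]
x = torch.tensor([seq])                                       # (1, seq_len)
segment = torch.ones_like(x)                                  # single-sentence segment ids

with torch.no_grad():
    features = bert(x, segment)                               # (1, seq_len, hidden)
print(features.shape)
```
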