├── .ipynb_checkpoints └── attention-checkpoint.ipynb ├── README.md ├── attention.ipynb ├── data ├── data.csv ├── human_vocab.json └── machine_vocab.json ├── generate.py └── img ├── attn_ex_2.png └── attn_seq2seq.png /.ipynb_checkpoints/attention-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Tìm Hiểu và Áp Dụng Cơ Chế Attention\n", 8 | "### Giới thiệu\n", 9 | "Trong phần này mình hướng dẫn các bạn cách cài đặt cơ chế attention trong bài toán seq2seq đơn giản hóa. \n", 10 | "\n", 11 | "Cơ chế attention chỉ đơn giản là trung bình có trọng số của những “thứ” mà chúng ta nghĩ nó cần thiết cho bài toán, điều đặc biệt là trọng số này do mô hình tự học được. Cụ thể, trong bài toán dịch máy ở ví dụ dưới, khi sử dụng cơ chế attention để phát sinh từ little, mình sẽ cần tính một vector context C là trung bình có trọng số của vector biểu diễn các từ mặt, trời, bé, nhỏ tương ứng với vector h1,h2,h3,h4, rồi sử dụng thêm vector context c này tại lúc dự đoán từ little, và nhớ rằng, trọng số này là các số scalar, được mô hình tự học \n", 12 | "\n", 13 | "![attention](https://github.com/pbcquoc/attention_tutorial/raw/master/img/attn_seq2seq.png)\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "### Import thư viện\n", 21 | "Mình sử dụng thêm cái lib của keras để hỗ trợ tiền xử lý cho nhanh, và tập trung chủ yếu vào phần cài đặt cơ chế attention\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": { 28 | "colab": { 29 | "base_uri": "https://localhost:8080/", 30 | "height": 35 31 | }, 32 | "colab_type": "code", 33 | "id": "o43HsYMod3xO", 34 | "outputId": "867dbbd8-5404-4f06-9e07-d5a2fc4ed5ab" 35 | }, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "Using TensorFlow backend.\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "import pandas as pd\n", 47 | "import numpy as np\n", 48 | "from keras.preprocessing.text import Tokenizer\n", 49 | "from keras.preprocessing.sequence import pad_sequences\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "import torch\n", 52 | "import torch.nn as nn\n", 53 | "import torch.nn.functional as F\n", 54 | "from torch import optim\n", 55 | "from torch.autograd import Variable\n", 56 | "from sklearn.model_selection import train_test_split\n", 57 | "import torch.utils.data\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "import matplotlib.ticker as ticker\n", 60 | "from random import randint" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Download tập huấn luyện\n", 68 | "Để minh họa cơ chế attention, mình sử dụng tập dataset tự phát sinh, với đầu vào là các câu biểu diễn ngày tháng năm của con người đọc, và nhãn là ngày tháng năm tương ứng do máy tính hiểu.\n", 69 | "\n", 70 | "Mình đã phát sinh tổng cộng 20k mẫu, trong đó 5k dùng để validation." 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": { 77 | "colab": { 78 | "base_uri": "https://localhost:8080/", 79 | "height": 35 80 | }, 81 | "colab_type": "code", 82 | "id": "iDdQn4P6d3xl", 83 | "outputId": "957055c0-912b-469b-a78c-0c7e215b8c5c" 84 | }, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "data.csv human_vocab.json machine_vocab.json\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "! curl --silent -L -o data.zip \"https://drive.google.com/uc?export=download&id=1d6eUqRstk7NIpyASzbuIsDvBdHEwfU0g\"\n", 96 | "! unzip -q data.zip\n", 97 | "! ls data" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "### Tiền xử lý\n", 105 | "Tập vocab mình xử dụng là các kí tự alphabet và số. Các bạn không cần phải filter, các kí tự đặt biệt trong tập dữ liệu" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 3, 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/", 114 | "height": 35 115 | }, 116 | "colab_type": "code", 117 | "id": "17yzXiMWd3xp", 118 | "outputId": "39370600-b519-4d4d-bae7-79dd8fd38a27" 119 | }, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "train size: 18750 - test size: 6250\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "def load_data(path):\n", 131 | " df = pd.read_csv(path, header=None)\n", 132 | " X = df[0].values\n", 133 | " y = df[1].values\n", 134 | " x_tok = Tokenizer(char_level=True, filters='')\n", 135 | " x_tok.fit_on_texts(X)\n", 136 | " y_tok = Tokenizer(char_level=True, filters='')\n", 137 | " y_tok.fit_on_texts(y)\n", 138 | " \n", 139 | " X = x_tok.texts_to_sequences(X)\n", 140 | " y = y_tok.texts_to_sequences(y)\n", 141 | " \n", 142 | " X = pad_sequences(X)\n", 143 | " y = np.asarray(y)\n", 144 | " \n", 145 | " return X, y, x_tok.word_index, y_tok.word_index\n", 146 | "\n", 147 | "X, y, x_wid, y_wid= load_data('data/data.csv')\n", 148 | "x_id2w = dict(zip(x_wid.values(), x_wid.keys()))\n", 149 | "y_id2w = dict(zip(y_wid.values(), y_wid.keys()))\n", 150 | "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", 151 | "print('train size: {} - test size: {}'.format(len(X_train), len(X_test)))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "### Định nghĩa các tham số" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 4, 164 | "metadata": { 165 | "colab": { 166 | "base_uri": "https://localhost:8080/", 167 | "height": 35 168 | }, 169 | "colab_type": "code", 170 | "id": "Fy7StSyzd3xz", 171 | "outputId": "b77edae4-f419-412a-dd33-fccf98786432" 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "input vocab: 35 - output vocab: 13 - length of target: 10\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "# hidden size cho môt hình LSTM\n", 184 | "hidden_size = 128\n", 185 | "learning_rate = 0.001\n", 186 | "decoder_learning_ratio = 0.1\n", 187 | "\n", 188 | "# tập tự vựng của các câu đầu vào \n", 189 | "# +1 vì các bạn cần kí tự padding nhé!\n", 190 | "input_size = len(x_wid) + 1\n", 191 | "\n", 192 | "# +2 vì các bạn cần kí tự bắt đầu và kết thức\n", 193 | "output_size = len(y_wid) + 2\n", 194 | "# 2 kí tự này nằm ở cuối\n", 195 | "sos_idx = len(y_wid) \n", 196 | "eos_idx = len(y_wid) + 1\n", 197 | "\n", 198 | "max_length = y.shape[1]\n", 199 | "print(\"input vocab: {} - output vocab: {} - length of target: {}\".format(input_size, output_size, max_length))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "Chuyển sang dạng chuỗi kí tự đọc được từ chuỗi số" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 0, 212 | "metadata": { 213 | "colab": {}, 214 | "colab_type": "code", 215 | "id": "2LksDULqd3x3" 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "def decoder_sentence(idxs, vocab):\n", 220 | " text = ''.join([vocab[w] for w in idxs if (w > 0) and (w in vocab)])\n", 221 | " return text" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "## Định nghĩa mô hình\n", 229 | "Ở phần này, các bạn cần định nghĩa 3 mô hình nhỏ\n", 230 | "* Encoder: là một mô hình LSTM, dùng để học biểu diễn của câu\n", 231 | "* Attention: dùng để học cách kết hợp để tạo ra context vector\n", 232 | "* Decoder: là một mô hình LSTM, chúng ta sẽ kết hợp context vector vào mô hình này để dự đoán các từ tại mỗi thời điểm\n", 233 | "\n", 234 | "![model](https://github.com/pbcquoc/pbcquoc.github.io/raw/master/images/attn_seq2seq_attn.png)\n", 235 | "\n", 236 | "### Encoder\n", 237 | "Mô hình này nhận đầu vào là các câu, các bạn có thể xem các hidden state h1,h2,h3,h4 như các biểu diễn của mỗi từ, và muốn tổng hợp context vector trên những thông tin này. \n", 238 | "### Attention\n", 239 | "Mô hình này học các trọng số alpha trên các h1,h2,h3,h4 rồi sau đó tổng hợp context vector theo trung bình có trọng số alpha này\n", 240 | "### Decoder\n", 241 | "Ở thời điểm dự đoán, các bạn sử dụng thêm context vector để bổ sung thông tin cho mô hình. " 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 0, 247 | "metadata": { 248 | "colab": {}, 249 | "colab_type": "code", 250 | "id": "yVx1T2LJd3x6" 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "class Encoder(nn.Module):\n", 255 | " def __init__(self, input_size, hidden_size):\n", 256 | " super(Encoder, self).__init__()\n", 257 | " self.hidden_size = hidden_size\n", 258 | " # embedding vector của từ\n", 259 | " self.embedding = nn.Embedding(input_size, hidden_size)\n", 260 | " # mô hình GRU biến thể RNN để học vector biểu diễn của câu\n", 261 | " self.gru = nn.GRU(hidden_size, hidden_size)\n", 262 | " \n", 263 | " def forward(self, input):\n", 264 | " # input: SxB \n", 265 | " embedded = self.embedding(input)\n", 266 | " output, hidden = self.gru(embedded)\n", 267 | " return output, hidden # SxBxH, 1xBxH \n", 268 | "\n", 269 | "class Attn(nn.Module):\n", 270 | " def __init__(self, hidden_size):\n", 271 | " super(Attn ,self).__init__()\n", 272 | " \n", 273 | " def forward(self, hidden, encoder_outputs):\n", 274 | " ### Mô hình nhận trạng thái hidden hiện tại của mô hình decoder, \n", 275 | " ### và các hidden states của mô hình encoder\n", 276 | " # encoder_outputs: TxBxH\n", 277 | " # hidden: SxBxH\n", 278 | " \n", 279 | " # tranpose về đúng shape để nhận ma trận\n", 280 | " encoder_outputs = torch.transpose(encoder_outputs, 0, 1) #BxTxH\n", 281 | " hidden = torch.transpose(torch.transpose(hidden, 0, 1), 1, 2) # BxHxS\n", 282 | " # tính e, chính là tương tác giữ hidden và các trạng thái ẩn của mô hình encoder \n", 283 | " energies = torch.bmm(encoder_outputs, hidden) # BxTxS\n", 284 | " energies = torch.transpose(energies, 1, 2) # BxSxT\n", 285 | " # tính alpha, chính là trọng số của trung bình có trọng số cần tính bằng hàm softmax\n", 286 | " attn_weights = F.softmax(energies, dim=-1) #BxSxT\n", 287 | " \n", 288 | " # tính context vector bằng trung binh có trọng số\n", 289 | " output = torch.bmm(attn_weights, encoder_outputs) # BxSxH\n", 290 | " \n", 291 | " # trả về chiều cần thiết\n", 292 | " output = torch.transpose(output, 0, 1) # SxBxH\n", 293 | " attn_weights = torch.transpose(attn_weights, 0, 1) #SxBxT\n", 294 | " \n", 295 | " # return context vector và các trọng số alpha cho mục đích biểu diễn cơ chế attention\n", 296 | " return output, attn_weights\n", 297 | " \n", 298 | "class Decoder(nn.Module):\n", 299 | " def __init__(self, output_size, hidden_size, dropout):\n", 300 | " super(Decoder, self).__init__()\n", 301 | " self.hidden_size = hidden_size\n", 302 | " self.output_size = output_size\n", 303 | " \n", 304 | " # vector biểu diễn cho các từ của output\n", 305 | " self.embedding = nn.Embedding(output_size, hidden_size)\n", 306 | " # định nghĩa mô hình attention ở trên\n", 307 | " self.attn = Attn(hidden_size)\n", 308 | " self.dropout = nn.Dropout(dropout)\n", 309 | " # mô hình decoder là GRU\n", 310 | " self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n", 311 | " \n", 312 | " # dự đoán các từ tại mội thời điểm, chúng ta nối 2 vector hidden và context lại với nhau \n", 313 | " self.concat = nn.Linear(self.hidden_size*2, hidden_size) \n", 314 | " self.out = nn.Linear(self.hidden_size, self.output_size)\n", 315 | " \n", 316 | " def forward(self, input, hidden, encoder_outputs):\n", 317 | " # input: SxB\n", 318 | " # encoder_outputs: BxSxH\n", 319 | " # hidden: 1xBxH\n", 320 | " embedded = self.embedding(input) # 1xBxH\n", 321 | " embedded = self.dropout(embedded)\n", 322 | " \n", 323 | " # biểu diễn của câu\n", 324 | " rnn_output, hidden = self.gru(embedded, hidden) #SxBxH, 1xBxH\n", 325 | " # tính context vector dựa trên các hidden states\n", 326 | " context, attn_weights = self.attn(rnn_output, encoder_outputs) # SxBxH\n", 327 | " \n", 328 | " # nối hidden state của mô hình decoder hiện tại và context vector để dự đoán \n", 329 | " concat_input = torch.cat((rnn_output, context), -1)\n", 330 | " concat_output = torch.tanh(self.concat(concat_input)) #SxBxH\n", 331 | " \n", 332 | " # dự đoán kết quả tại mỗi thời điểm\n", 333 | " output = self.out(concat_output) # SxBxoutput_size\n", 334 | " return output, hidden, attn_weights\n", 335 | "\n" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "### Kiểm tra\n", 343 | "Chúng ta khởi tạo mô hình để kiểm tra xem mô hình có chạy được không, ít nhất là không bị lỗi về tính toán" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 0, 349 | "metadata": { 350 | "colab": {}, 351 | "colab_type": "code", 352 | "id": "ArLtO8rsd3yA" 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "encoder = Encoder(input_size, hidden_size)\n", 357 | "decoder = Decoder(output_size, hidden_size, 0.1)\n", 358 | "\n", 359 | "# Initialize optimizers and criterion\n", 360 | "encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", 361 | "decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)\n", 362 | "criterion = nn.CrossEntropyLoss()\n", 363 | "\n", 364 | "\n", 365 | "input_encoder = torch.randint(1, input_size, (34, 6), dtype=torch.long)\n", 366 | "encoder_outputs, hidden = encoder(input_encoder)\n", 367 | "input_decoder = torch.randint(1, output_size, (10, 6), dtype=torch.long)\n", 368 | "output, hidden, attn_weights = decoder(input_decoder, hidden, encoder_outputs)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "## Train/test\n", 376 | "Phần này chúng ta định nghĩa một số hàm để huấn luyện, dự đoán mô hình " 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 6, 382 | "metadata": { 383 | "colab": {}, 384 | "colab_type": "code", 385 | "id": "TnARQv5td3yG" 386 | }, 387 | "outputs": [ 388 | { 389 | "ename": "NameError", 390 | "evalue": "name 'max_length' is not defined", 391 | "output_type": "error", 392 | "traceback": [ 393 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 394 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 395 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0meval_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_idx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecoder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;31m### Lúc dự đoán chúng ta cần tính kết quả ngay lập tức tại mỗi thời điểm, rồi sau đó dừng từ được dự đoán để tính từ tiếp theo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 396 | "\u001b[0;31mNameError\u001b[0m: name 'max_length' is not defined" 397 | ] 398 | } 399 | ], 400 | "source": [ 401 | "def forward_and_compute_loss(inputs, targets, encoder, decoder, criterion):\n", 402 | " batch_size = inputs.size()[1]\n", 403 | " \n", 404 | " # định nghĩa 2 kí tự bắt đầu và kết thúc\n", 405 | " sos = Variable(torch.ones((1, batch_size), dtype=torch.long)*sos_idx)\n", 406 | " eos = Variable(torch.ones((1, batch_size), dtype=torch.long)*eos_idx)\n", 407 | " \n", 408 | " # input của mô hình decoder phải thêm kí tự bắt đầu\n", 409 | " decoder_inputs = torch.cat((sos, targets), dim=0)\n", 410 | " # output cần dự đoán của mô hình decoder phải thêm kí tự kết thúc\n", 411 | " decoder_targets = torch.cat((targets, eos), dim=0)\n", 412 | " \n", 413 | " # forward tính hidden states của câu\n", 414 | " encoder_outputs, encoder_hidden = encoder(inputs)\n", 415 | " # tính output của mô hình decoder\n", 416 | " output, hidden, attn_weights = decoder(decoder_inputs, encoder_hidden, encoder_outputs)\n", 417 | " \n", 418 | " output = torch.transpose(torch.transpose(output, 0, 1), 1, 2) # BxCxS\n", 419 | " decoder_targets = torch.transpose(decoder_targets, 0, 1)\n", 420 | " # tính loss \n", 421 | " loss = criterion(output, decoder_targets)\n", 422 | " \n", 423 | " return loss, output\n", 424 | "\n", 425 | "def train(inputs, targets, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):\n", 426 | " # khai báo train để mô hình biết là đang train hay test\n", 427 | " encoder.train()\n", 428 | " decoder.train()\n", 429 | " \n", 430 | " # zero gradient, phải làm mỗi khi cập nhất gradient\n", 431 | " encoder_optimizer.zero_grad()\n", 432 | " decoder_optimizer.zero_grad()\n", 433 | " \n", 434 | " # tính loss dựa vào hàm đã định nghĩa ở trên\n", 435 | " train_loss, output = forward_and_compute_loss(inputs, targets,encoder, decoder,criterion) \n", 436 | " \n", 437 | " train_loss.backward()\n", 438 | " # cập nhật một step\n", 439 | " encoder_optimizer.step()\n", 440 | " decoder_optimizer.step()\n", 441 | " \n", 442 | " # return loss để print :D\n", 443 | " return train_loss.item()\n", 444 | "\n", 445 | "def evaluate(inputs, targets, encoder, decoder, criterion):\n", 446 | " # báo cho mô hình biết đang test/eval\n", 447 | " encoder.eval()\n", 448 | " decoder.eval()\n", 449 | " # tính loss\n", 450 | " eval_loss, output = forward_and_compute_loss(inputs, targets, encoder, decoder,criterion)\n", 451 | " output = torch.transpose(output, 1, 2)\n", 452 | " # dự đoán của mỗi thời điểm các vị trí có prob lớn nhất\n", 453 | " pred_idx = torch.argmax(output, dim=-1).squeeze(-1)\n", 454 | " pred_idx = pred_idx.data.cpu().numpy()\n", 455 | " \n", 456 | " # return loss và kết quả dự đoán\n", 457 | " return eval_loss.item(), pred_idx\n", 458 | "\n", 459 | "def predict(inputs, encoder, decoder, target_length=max_length):\n", 460 | " ### Lúc dự đoán chúng ta cần tính kết quả ngay lập tức tại mỗi thời điểm, \n", 461 | " ### rồi sau đó dừng từ được dự đoán để tính từ tiếp theo \n", 462 | " batch_size = inputs.size()[1]\n", 463 | " \n", 464 | " # input đầu tiên của mô hình decoder là kí tự bắt đầu, chúng ta dự đoán kí tự tiếp theo, sau đó lại dùng kí tự này để dự đoán từ kế tiếp\n", 465 | " decoder_inputs = Variable(torch.ones((1, batch_size), dtype=torch.long)*sos_idx)\n", 466 | " \n", 467 | " # tính hidden state của mô hình encoder, cũng là vector biểu diễn của các từ, chúng ta cần tính context vector dựa trên những hidden states này\n", 468 | " encoder_outputs, encoder_hidden = encoder(inputs)\n", 469 | " hidden = encoder_hidden\n", 470 | " \n", 471 | " preds = []\n", 472 | " attn_weights = []\n", 473 | " # chúng ta tính từng từ tại mỗi thời điểm\n", 474 | " for i in range(target_length):\n", 475 | " # dự đoán từ đầu tiên\n", 476 | " output, hidden, attn_weight = decoder(decoder_inputs, hidden, encoder_outputs)\n", 477 | " output = output.squeeze(dim=0)\n", 478 | " pred_idx = torch.argmax(output, dim=-1)\n", 479 | " \n", 480 | " # thay đổi input tiếp theo bằng từ vừa được dự đoán\n", 481 | " decoder_inputs = Variable(torch.ones((1, batch_size), dtype=torch.long)*pred_idx)\n", 482 | " preds.append(decoder_inputs)\n", 483 | " attn_weights.append(attn_weight.detach())\n", 484 | " \n", 485 | " preds = torch.cat(preds, dim=0)\n", 486 | " preds = torch.transpose(preds, 0, 1)\n", 487 | " attn_weights = torch.cat(attn_weights, dim=0)\n", 488 | " attn_weights = torch.transpose(attn_weights, 0, 1)\n", 489 | " return preds, attn_weights" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": { 495 | "colab": { 496 | "base_uri": "https://localhost:8080/", 497 | "height": 35 498 | }, 499 | "colab_type": "code", 500 | "id": "l-h8o5Rtd3yW", 501 | "outputId": "244ec330-69d5-49f9-8256-c2ab12b6bdfe" 502 | }, 503 | "source": [ 504 | "### Train và eval\n", 505 | "Trong phần này, chúng ta train mô hình, cũng như theo dõi độ lỗi, kết quả dự đoán tại mỗi epoch. " 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 10, 511 | "metadata": { 512 | "colab": { 513 | "base_uri": "https://localhost:8080/", 514 | "height": 725 515 | }, 516 | "colab_type": "code", 517 | "id": "Fyfzas04d3yZ", 518 | "outputId": "7da23839-0161-4916-e023-8bed89f11b8d" 519 | }, 520 | "outputs": [ 521 | { 522 | "name": "stdout", 523 | "output_type": "stream", 524 | "text": [ 525 | "Epoch 0 - train loss: 0.381 - eval loss: 0.341\n", 526 | " 10 03 17 \t1017-03-10\n", 527 | " 11 thg 6 2004 \t2004-06-11\n", 528 | " 01 thg 1 1982 \t1982-01-01\n", 529 | "Epoch 1 - train loss: 0.044 - eval loss: 0.041\n", 530 | " 2 tháng 9 1994 \t1994-09-02\n", 531 | " thứ năm, ngày 15 tháng 11 năm 2012 \t 2012-1115\n", 532 | " 9 thg 3, 1987 \t1987-03-09\n", 533 | "Epoch 2 - train loss: 0.020 - eval loss: 0.017\n", 534 | " 15.04.03 \t2003-04-15\n", 535 | " 24 06 96 \t1996-06-24\n", 536 | " 21 tháng 1 2013 \t2013-01-21\n", 537 | "Epoch 3 - train loss: 0.010 - eval loss: 0.009\n", 538 | " thứ hai, ngày 07 tháng 5 năm 1984 \t1984-05-07\n", 539 | " 4 thg 4, 1984 \t1984-04-04\n", 540 | " ngày 25 tháng 11 năm 1990 \t1990-11-25\n", 541 | "Epoch 4 - train loss: 0.008 - eval loss: 0.008\n", 542 | " 23 thg 2 2003 \t2003-02-23\n", 543 | " 15 thg 1, 2015 \t2015-01-15\n", 544 | " 24/07/1986 \t1986-07-24\n", 545 | "Epoch 5 - train loss: 0.004 - eval loss: 0.004\n", 546 | " 22 tháng 12 2002 \t2002-12-22\n", 547 | " 20 thg 5, 2003 \t2003-05-20\n", 548 | " thứ ba, ngày 06 tháng 8 năm 1991 \t1991-08-06\n", 549 | "Epoch 6 - train loss: 0.005 - eval loss: 0.003\n", 550 | " 6 thg 12 2009 \t2009-12-06\n", 551 | " ngày 12 tháng 10 năm 1972 \t1972-10-12\n", 552 | " 10 thg 5, 1974 \t1974-05-10\n", 553 | "Epoch 7 - train loss: 0.002 - eval loss: 0.002\n", 554 | " 4 tháng 1, 2001 \t2001-01-04\n", 555 | " 22 thg 8 1972 \t1972-08-22\n", 556 | " 16 thg 8 1997 \t1997-08-16\n", 557 | "Epoch 8 - train loss: 0.002 - eval loss: 0.002\n", 558 | " tháng 11 13, 1981 \t1981-11-13\n", 559 | " 15 thg 6, 1994 \t1994-06-15\n", 560 | " 29 thg 10, 1996 \t1996-10-29\n", 561 | "Epoch 9 - train loss: 0.003 - eval loss: 0.001\n", 562 | " 9 thg 9, 2012 \t2012-09-09\n", 563 | " 17.11.01 \t2001-11-17\n", 564 | " 17 tháng 11, 1971 \t1971-11-17\n" 565 | ] 566 | } 567 | ], 568 | "source": [ 569 | "epochs = 10\n", 570 | "batch_size = 64\n", 571 | "\n", 572 | "encoder = Encoder(input_size, hidden_size)\n", 573 | "decoder = Decoder(output_size, hidden_size, 0.1)\n", 574 | "\n", 575 | "# Initialize optimizers and criterion\n", 576 | "encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", 577 | "decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)\n", 578 | "criterion = nn.CrossEntropyLoss()\n", 579 | "\n", 580 | "X_val = torch.tensor(X_test, dtype=torch.long)\n", 581 | "y_val = torch.tensor(y_test, dtype=torch.long)\n", 582 | "X_val = torch.transpose(X_val, 0, 1)\n", 583 | "y_val = torch.transpose(y_val, 0, 1)\n", 584 | "\n", 585 | "for epoch in range(epochs):\n", 586 | " for idx in range(len(X_train)//batch_size):\n", 587 | " # input đầu vào của chúng ta là timestep first nhé. \n", 588 | " X_train_batch = torch.tensor(X_train[batch_size*idx:batch_size*(idx+1)], dtype=torch.long)\n", 589 | " y_train_batch = torch.tensor(y_train[batch_size*idx:batch_size*(idx+1)], dtype=torch.long)\n", 590 | " \n", 591 | " X_train_batch = torch.transpose(X_train_batch, 0, 1)\n", 592 | " y_train_batch = torch.transpose(y_train_batch, 0, 1)\n", 593 | " train_loss= train(X_train_batch, y_train_batch, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n", 594 | " eval_loss, preds = evaluate(X_val, y_val, encoder, decoder, criterion)\n", 595 | " \n", 596 | " print('Epoch {} - train loss: {:.3f} - eval loss: {:.3f}'.format(epoch, train_loss, eval_loss))\n", 597 | " print_idx = np.random.randint(0, len(preds), 3)\n", 598 | " for i in print_idx:\n", 599 | " x_val = decoder_sentence(X_val[:,i].numpy(), x_id2w)\n", 600 | " y_pred = decoder_sentence(preds[i], y_id2w)\n", 601 | " print(\" {:<35s}\\t{:>10}\".format(x_val, y_pred))" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": {}, 607 | "source": [ 608 | "## Predict\n", 609 | "Chúng ta dự đoán một vài mẫu và phân tích một số kết quả của cơ chế attention" 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": 0, 615 | "metadata": { 616 | "colab": {}, 617 | "colab_type": "code", 618 | "id": "NglXGh77d3yc" 619 | }, 620 | "outputs": [], 621 | "source": [ 622 | "preds, attn_weights = predict(X_val ,encoder, decoder, target_length=10)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 0, 628 | "metadata": { 629 | "colab": {}, 630 | "colab_type": "code", 631 | "id": "vRMILrf6d3yj" 632 | }, 633 | "outputs": [], 634 | "source": [ 635 | "def show_attention(input_sentence, output_words, attentions):\n", 636 | " # Set up figure with colorbar\n", 637 | " fig = plt.figure()\n", 638 | " ax = fig.add_subplot(111)\n", 639 | " cax = ax.matshow(attentions.numpy(), cmap='bone')\n", 640 | " fig.colorbar(cax)\n", 641 | "\n", 642 | " # Set up axes\n", 643 | " ax.set_xticks(np.arange(len(input_sentence)))\n", 644 | " ax.set_xticklabels(list(input_sentence), rotation=90)\n", 645 | " ax.set_yticks(np.arange(len(output_words)))\n", 646 | " ax.set_yticklabels(list(output_words))\n", 647 | " ax.grid()\n", 648 | " ax.set_xlabel('Input Sequence')\n", 649 | " ax.set_ylabel('Output Sequence')\n", 650 | " plt.show()" 651 | ] 652 | }, 653 | { 654 | "cell_type": "markdown", 655 | "metadata": {}, 656 | "source": [ 657 | "Chọn ngẫu nhiên một câu trong tập validation để hiển thị. Khi hiển thị cơ chế attention, chúng ta có một cái nhìn về quá trình dự đoán của mô hình rõ ràng hơn, giúp đánh giá có thể interpretable hơn. " 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": 27, 663 | "metadata": { 664 | "colab": { 665 | "base_uri": "https://localhost:8080/", 666 | "height": 336 667 | }, 668 | "colab_type": "code", 669 | "id": "eFJvKjnOLL9W", 670 | "outputId": "05ac40f8-92be-47ff-95a2-5594eb53cfc1" 671 | }, 672 | "outputs": [ 673 | { 674 | "data": { 675 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcgAAAE/CAYAAADCNlNLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X1cVGX+//H3MAqkQwnKkBlmstkN\nKyUZm1HSFnZjbvuwtUQqa5fNfHRnudYaptSukLZqm5rZjeV2Ky3LtzXD1a2trTWEbrxJygpK1F8G\nDCmKRoCc3x+us5IDg+OcI2d4PX3M4+GZca7rGpadd5/rXOc6DsMwDAEAgFbCjvUAAADojAhIAAB8\nICABAPCBgAQAwAcCEgAAHwhIAAB8ICABAPCh27EewJHau3evPB6PJCk2NlY9evQwvc/m5mZVVVUp\nLi5O3brZ7keGEBSKv5OGYcjhcLR67ttvv9WJJ55oSn9Wfpcci++tzupoLr3/8e+H2Rx22Sjgk08+\nUW5urnbv3q3o6GgZhqHq6mrFxcVpxowZOv3004PW18yZM/XAAw9Ikt5//31NmzZNffr0UW1trR56\n6CFddNFFQeurq3j88cf14osveo8PfhkWFxcHva9LL730sOecTqfi4+M1efJkJSYmBr1Ps4Xy7+Q/\n//lP5eXl6fvvv1daWpqmT58ul8slSRo/fryef/75oPZn5XeJlX1JUnJyskaPHq3bbrtNvXv3Dmrb\nwWKngJRhExkZGUZ5eflhz2/atMnIzMwMal833HCD9++ZmZnG1q1bDcMwjOrqauO6664Lal9dxahR\no4y9e/da0tfixYuNl156yfj222+Nb7/91nj11VeNxYsXG+vWrTMyMjIsGUOwhfLv5JgxY4ydO3ca\n+/fvN5YtW2Zcd911xu7duw3DaP25g8XK7xIr+zKMAz+v0tJS46abbjKmTp1qlJaWGk1NTUHv52g0\n798f8MNqtjkHaRiGEhISDns+MTFR+/fvD2pfh/5XygknnKD4+HhJB6ZGQmU6a9GiRYc9N2vWLNP6\nO+OMMyz72b377rvKzMxUXFyc4uLidO2112rNmjU655xzLOnfDKH8O+l0OtWrVy+FhYVp7NixuuWW\nW5SVlaXvvvvOlIrByu8SK/uSDvyenHfeeVq6dKkyMzP1+uuva9SoUfrVr36lCRMmBL2/QBiGEfDD\narb5f9bZZ5+tiRMnKj09XTExMZIkj8ejVatWKSUlJah9ffnll5o0aZIMw1BlZaVWrlypK6+8Us8+\n+6yioqKC2pfVVq9erRUrVujDDz/U559/7n2+ublZn332maZOnRrU/u666y45HA7t3btXV1xxhc46\n6yw5nU7v64899lhQ+5OkiIgI5eXlKTk5WWFhYdq0aZOampq0Zs0a2577CeXfyeTkZN1666167LHH\nFBkZqfT0dEVEROjmm2/Wrl27gt6fld8lVvYltZ6+HDx4sAYPHixJqq6uVk1NTdD7C4QhW5zVk2Sj\nc5CS9MEHH6i4uNh7stvtdis1NVVDhgwJaj+lpaWtjk855RTFxcXp9ddf1yWXXKKePXsGtT+rbd++\nXX/84x+VlZXlfS4sLEwDBw70/p84WH78s/wxM74k6uvr9dprr6miokKGYah///4aPXq0vv/+e0VF\nRdkyUEL9d7KkpEQpKSmtKsb6+noVFRXpuuuuC3p/Vn2XWN1XQUGBxowZE/R2g+mH5uaA3xth8WyJ\nrQISAGBvDU1NAb83snv3II7EP9tMsQIA7K/FRjWZbRbpAABgJSpIAIBl7HRWj4AEAFiGgAyA5Tsk\nAIDN2Clc2mKnc5CdJiABAKHPTiFPQAIALGOnjQJYxQoAgA9UkAAAy7TYp4AkIAEA1uEcJAAAPrCK\nFQAAH6ggAQDwgYAEAMAHO02xcpkHAAA+mFpBPvLII/roo4/U3NysW2+9VZdddpmZ3QEAOjmmWCWt\nXbtWX375pfLz87Vz506NHj2agASALs5OO+mYFpDnnXeekpKSJEnHH3+8vv/+e+3fv19Op9OsLgEA\nnRwbBUhyOp3q0aOHJKmgoEDDhw8nHAGgi2OK9RBvvvmmCgoK9Oyzz5rdFQCgkyMg/+u9997T4sWL\n9cwzzygqKsrMrgAANmCnyzxMC8g9e/bokUce0dKlS9WrVy+zugEAwBSmBWRRUZF27typu+++2/vc\n7NmzddJJJ5nVJQCgk7PTFKvD6CSjdTgcx3oIANCpdZKv66NSXlUV8Ht/EhcXxJH4x1ZzAADL2Cnk\nCUgAgGXYKAAAAB/YKAAAAB/sNMXK3TwAAPCBChIAYBk7VZAEJADAMuykAwDHiNt9iqX9rdlYYllf\nVl8vbka1RwUJAIAPBCQAAD4wxQoAgA922iiAyzwAAPCBChIAYBl20gEAwAcW6QAA4AMBCQCAD6xi\nBQDABypIAAB8ICD/Ky8vTxs2bJDD4VB2draSkpLM7A4AgKAxLSBLS0tVWVmp/Px8VVRUKDs7W/n5\n+WZ1BwCwATudgzRto4Di4mKlp6dLkhISElRXV6f6+nqzugMA2IBxFH+sZlpAejweRUdHe49jYmJU\nU1NjVncAABtoMQJ/WM2yRTp2OjELADCHnbLAtIB0u93yeDze4+rqasXGxprVHQDABuwUkKZNsaam\npmrVqlWSpLKyMrndbrlcLrO6AwDYQIthBPzoiLy8PI0dO1YZGRnauHFjq9deeukljR07VuPGjVNu\nbq7ftkyrIJOTk5WYmKiMjAw5HA7l5OSY1RUAAO1ePVFfX68lS5Zo9erV6tatm37zm99o/fr1Ouec\nc9psz9RzkFOmTDGzeQCAzZg5xdrW1RMul0vdu3dX9+7dtW/fPvXo0UPff/+9TjjhhHbbYycdAIBl\nzAxIj8ejxMRE7/HBqydcLpciIiJ0++23Kz09XREREbrqqqt06qmnttseN0wGAFjG7HOQhzo0jOvr\n6/Xkk0/qH//4h9566y1t2LBBmzdvbvf9BCQAwDJmbhTQ3tUTFRUVio+PV0xMjMLDwzV06FBt2rSp\n3fYISACAZQwj8Ic/7V090a9fP1VUVKihoUGStGnTJg0YMKDd9jgHCQCwjJl7sfq6eqKwsFBRUVEa\nMWKEsrKyNH78eDmdTg0ZMkRDhw5ttz2H0Umu2nQ4HMd6CABCgNt9iqX9rdlYYllfp514omV9SeYs\nqCnasCHg9448++wgjsQ/KkgAgGU6SU3WIQQkAMAydrrdFQEJIKRUV1da2t9P4uIs7c/uqCABAPCB\ngAQAwAemWAEA8KEjF/x3FmwUAACAD1SQAADL2GiGlYAEAFiHc5AAAPjAKlYAAHygggQAwAcqSAAA\nfCAg/ysvL08bNmyQw+FQdna2kpKSzOwOAICgMS0gS0tLVVlZqfz8fFVUVCg7O1v5+flmdQcAsAMq\nSKm4uFjp6emSpISEBNXV1am+vt57d2cAQNdjtNgnIE3bScfj8Sg6Otp7HBMTo5qaGrO6AwDYgGEE\n/rCaZYt07HRiFgBgDjtlgWkB6Xa75fF4vMfV1dWKjY01qzsAgA3YKSBNm2JNTU3VqlWrJEllZWVy\nu92cfwSALs4wjIAfVjOtgkxOTlZiYqIyMjLkcDiUk5NjVlcAAASdqecgp0yZYmbzAACbsdMqVnbS\nAQBYxk7nIAlIAIBlCEgAAHwhIAEAOJyN8pGABABYx06LdEy7DhIAADujggQAWIZFOgDQRTid1n2N\nNjY3W9aXWQhIAAB8ICABAPCBgAQAwBcbrWIlIAEAlrFTBcllHgAA+EAFCQCwjI0KSAISAGAdO02x\nEpAAAMsQkAAA+GCnvVgJSACAZaggJf31r3/V8uXLvcebNm3SunXrzOoOAGADBKSka6+9Vtdee60k\nqbS0VCtXrjSrKwAAgs6SKdbHH39cc+bMsaIrAEAnZqcKskMbBXzxxRd68803JUm7d+8+og42btyo\nvn37KjY29shHBwAILYYR+MNifivIpUuXasWKFWpsbFR6eroWLVqk448/XrfddluHOigoKNDo0aOP\neqAAAPszWo71CDrObwW5YsUKvfrqqzrhhBMkSffdd5/eeeedDndQUlKiIUOGBDxAAEDoMAwj4IfV\n/FaQPXv2VFjY/3I0LCys1XF7qqqq1LNnT4WHhwc+QgBAyLDTOUi/Adm/f38tXLhQu3fv1urVq1VU\nVKSEhIQONV5TU6OYmJijHiQAIDTYKSAdhp/RNjU16fnnn1dJSYnCw8M1dOhQZWZmBr0qdDgcQW0P\nAKwQFua0rK+Gxh8s60uSujuD/9nmPF8Q8HunjB8TxJH457eCdDqdOvvss5WVlSVJ+te//qVu3diA\nBwBw5OxUQfo9mThjxgz9+9//9h6XlpZq2rRppg4KABCajBYj4IfV/Abkli1b9Lvf/c57PHXqVG3f\nvt3UQQEAQpTJ10Hm5eVp7NixysjI0MaNG1u9tmPHDo0bN05jxozRjBkz/LblNyAbGhq0a9cu73FV\nVZV++MHaeXAAQGgw8zKP0tJSVVZWKj8/X7m5ucrNzW31+qxZs/Sb3/xGBQUFcjqd+uabb9ptz+/J\nxNtvv12jRo1S3759tX//flVXVx/WKQAAHWHmKcji4mKlp6dLkhISElRXV6f6+nq5XC61tLToo48+\n0rx58yRJOTk5ftvzG5A///nP9eabb6q8vFwOh0MDBw7Ucccdd5QfAwDQFZm5SMfj8SgxMdF7HBMT\no5qaGrlcLn333Xfq2bOnHn74YZWVlWno0KGtTh/64jcga2pqVFRUpLq6ulYfbNKkSUfxMQAAMNeh\nmWUYhqqqqjR+/Hj169dPEyZM0DvvvKOLL764zff7PQd56623avPmzQoLC5PT6fQ+AAA4UmauYnW7\n3fJ4PN7j6upq740yoqOjddJJJ6l///5yOp0aNmyYvvzyy3bb81tB9ujRQw8//LDfgQFAV9TSst+y\nvsItvgbdjOlQM6dYU1NTtWDBAmVkZKisrExut1sul0uS1K1bN8XHx2vLli0aMGCAysrKdNVVV7Xb\nnt+f9tlnn62KiooOby8HAEBbzAzI5ORkJSYmKiMjQw6HQzk5OSosLFRUVJRGjBih7OxsTZ06VYZh\naNCgQbrkkkvabc/vVnNXX321KioqFB0drW7duskwDDkcjiO6o0dHsNUcAHQuZoTZHxe9EPB7p992\nYxBH4p/fCvKJJ56wYhwAgC4gpLaai42N1TvvvKNXXnlF/fr1k8fjUZ8+fawYGwAg1LQYgT8s5jcg\nH3zwQW3dulUlJSWSpLKyMk2dOtX0gQEAcCz5DcivvvpK999/vyIjIyVJmZmZqq6uNn1gAIDQY/JW\nrEHl9xzkwVtbHVxEs2/fPjU0NJg7KgBASLLTOUi/AXnFFVfopptu0vbt2zVz5ky9++67yszMtGJs\nAIAQE1IBecMNNygpKUmlpaUKDw/XvHnz9NOf/tSKsQEAQsyxuK9joPwGZHFxsSR5N4Dds2ePiouL\nNWzYMHNHBgAIOSFVQS5atMj796amJpWXlys5OZmABAAcsZAKyBdeaL3rQW1trebOnWvagAAA6AyO\neOfb3r1766uvvjJjLACAUBdKFeS9997bap/UHTt2KCzM7+WTAAAcJqSmWC+44ALv3x0Oh1wul1JT\nUzvcwcsvv6yVK1cqOjpa8+fPD2yUAICQYLQc6xF0nN+AHDp06GHPHXpDyvj4+Hbfn5mZyXWTAABJ\nIVZBZmVladu2berVq5ccDod27typk046yXvbq7feesuKcQIAQkBIBeTw4cM1evRo73WQ69ev14oV\nK/TAAw+YPjgAQGixU0D6XW3z+eefe8NRks455xxt3rzZ1EEBAHCs+a0gGxoa9NJLL+m8886TJH34\n4Yfat2+f6QMDAIQeO1WQfgNy7ty5WrBggZYtWyZJGjRokP70pz+ZPjAAQOgJqb1Y+/fvr9mzZ8vj\n8cjtdlsxJgBAiLJTBen3HGRxcbHS09M1fvx4SVJeXp7efvtt0wcGAAhBNrpjst+AfPTRR/Xqq68q\nNjZWkjRx4kQ98cQTpg8MABB6bJSP/qdYe/TooT59+niPY2Ji1L17d1MHBQAITXaaYvUbkJGRkSot\nLZUk1dXV6Y033lBERITpAwMA4FjyG5A5OTl68MEH9cknn2jEiBE699xz9Yc//MGKsQHHlNX/pXvo\nTQGAUBVSq1j79u2rJ5980oqxAABCnJ2mWNtcpLNjxw7NmjXLe/zoo49q6NChuuaaa/T1119bMjgA\nQGgxDCPgh9XaDMgZM2Z479Tx6aefqqCgQH/72990zz33tApOAAA6KiQCcs+ePbr++uslSatXr9bI\nkSN1yimn6KKLLlJDQ4NlAwQAhBAbXefRZkAeulK1tLRU559/vvfYTnPIAIDOw2gxAn5Yrc1FOg6H\nQ5s3b9aePXv0xRdf6IILLpAk1dTUqLGx0bIBAgBwLLQZkJMnT9akSZNUV1en6dOn67jjjlNDQ4PG\njBmjqVOnWjlGAECIsNMEZJsBmZSUpFWrVrV6LjIyUs8995wGDhxo+sAAAKHHTqfo/F4H+WOEIwAg\nUCEdkAAABCrkA7KhoUGRkZF+/11eXp42bNggh8Oh7OxsJSUlBdIdACBE2GmrOb+3u8rKyjrsuYPX\nR7antLRUlZWVys/PV25urnJzcwMbIQAgZNhpo4A2K8jly5fr8ccf1zfffKOLL77Y+3xTU1Or21+1\n5eCNliUpISFBdXV1qq+vl8vlOvpRAwBgsjYD8uqrr9ZVV12ladOm6c477/Q+HxYWJrfb7bdhj8ej\nxMRE73FMTIxqamoISADoykLlHKTT6dQvf/lLbd26tdXzW7Zs0bBhw46oIzudmAUAmMNOWeB3kc6i\nRYu8f29qalJ5ebmSk5P9BqTb7ZbH4/EeV1dXKzY29iiGCgCwOxvlo/+AfOGFF1od19bWau7cuX4b\nTk1N1YIFC5SRkaGysjK53W6mVwGgi7PTKtYjvsyjd+/e+uqrr/z+u+TkZCUmJiojI0MOh0M5OTkB\nDRAAEDpCaor13nvvlcPh8B7v2LFDYWF+rw6RJE2ZMiXwkQEAQk5IBeTBu3hIB+7w4XK5lJqaauqg\nAAA41vyWgqNHj1ZiYqIiIiIUERGhgQMH6rjjjrNibACAEBMSGwUcNHv2bL311lsaPHiwWlpaNHfu\nXI0aNUp33323FeMDAISQkJpiLSkp0RtvvKHu3btLkhobG5WRkUFAAgCOWEitYu3Tp4+6dfvfP+ve\nvbv69etn6qAAACEqlCrI6Oho/epXv9L5558vwzD0wQcfKD4+Xo899pgkadKkSaYPEgAQGmyUj/4D\nMj4+XvHx8d7jQzcuBwAgVPkNSJfLpZtvvrnVc/Pnz9ddd91l1pgAACHK7EU6HbkP8dy5c7V+/frD\ndor7sTYDcu3atVq7dq2WL1+uuro67/PNzc0qLCwkIBHyDt0gwwpWru6z+rMBB5n5e37ofYgrKiqU\nnZ2t/Pz8Vv+mvLxcH3zwgXfhaXvavA5y4MCBSkhIkHTgrh4HH5GRkZo3b95RfgwAQFdktBgBP/xp\n6z7Eh5o1a5buueeeDo21zQrS7XbrF7/4hZKTk1m1CgAICjMrSH/3IS4sLFRKSkqHM83vOcjMzEyf\n0zHvvPNOB4cMAMABVp5KOLSvXbt2qbCwUM8995yqqqo69H6/Afnyyy97/97U1KTi4mL98MMPAQwV\nANDVmRmQ7d2HeO3atfruu+90/fXXq7GxUVu3blVeXp6ys7PbbM/vXqz9+vXzPgYMGKBx48bpvffe\nC8JHAQAgeFJTU7Vq1SpJOuw+xFdccYWKior06quvauHChUpMTGw3HKUOVJDFxcWtjr/99ltt3bo1\n0PEDALoyEytIX/chLiwsVFRUlEaMGHHE7TkMP/XujTfe+L9//N/bXd1www2tboMVDCw7R1fHZR7o\nbMz4nRyb8fuA35u/bHYQR+Kf3wrS34WUAAB0lJ3u5tHuOcji4mJdf/31GjJkiJKTk3XzzTdr/fr1\nVo0NABBiQuJ+kEVFRVq0aJEmT56sc845R5L0ySefKCcnR5MmTdIll1zSbsMlJSWaNGmSTjvtNEnS\noEGDNH369CAOHQBgN3aqINsMyKVLl+rpp59W3759vc+lpaXpzDPP7FBASlJKSormz58fnJECAGzP\nTgHZ5hSrw+FoFY4Hud1uW31AAAAC0WZANjQ0tPmmffv2dajx8vJyTZw4UePGjdOaNWuOfHQAgJBi\n5l6swdZmQJ555pk+V7A+88wzSk5O9tvwgAEDdMcdd+iJJ57Q7NmzNW3aNDU2Nh7daAEA9mYYgT8s\n1uY5yPvuu0+33XabVqxYocGDB8swDK1bt04ul0tPPvmk34bj4uI0cuRISVL//v3Vp08fVVVVtbr5\nMgCgazFkn1N0bQZkTEyMli1bpjVr1ujTTz9Vjx49dOWVV2ro0KEdanj58uWqqalRVlaWampqVFtb\nq7i4uKANHABgP3Zaw+J3J51A1dfXa8qUKdq9e7eampp0xx13KC0tre2BsLMHujh20kFnY8bv5C9/\neWfA7/373xcEcST++d1JJ1Aul0uLFy82q3kAgA3ZqYL0ezcPAAC6ItMqSAAAfsxOFSQBCQCwDAEJ\nAIAPhtFyrIfQYQQkAMA6VJAAABwuJDYKAAAg2Ox0DpLLPAAA8IEKEugk2N0GXYGdKkgCEgBgGVax\nAgDgAxUkAAA+EJAAAPhAQAIA4IuNApLLPAAA8IEKEgBgGUOsYgUA4DCcgwQAwAcCEgAAHwhIAAB8\nYCcdAAB8sFMFyWUeAAD4QAUJALCMnSpIAhIAYB0C8n9efvllrVy5UtHR0Zo/f77Z3QEAOjFD9glI\nh9FJ6l1uFgsAnYsZ8ZCaOjrg965Z839BHIl/TLECACzTSWqyDiEgAQCWsVNAcpkHAAA+UEECACxj\npwqSgAQAWIat5gAA8IEKEgAAXwhIAAAOZ6eNAghIAIBl7DTFymUeAAD40GkqSCv/q4Jt7QDg2GAV\nKwAAPthpipWABABYhoAEAMAHAhIAAB/MDsi8vDxt2LBBDodD2dnZSkpK8r62du1azZs3T2FhYTr1\n1FOVm5ursLC216qyihUAYB2jJfCHH6WlpaqsrFR+fr5yc3OVm5vb6vUZM2Zo/vz5WrZsmfbu3av3\n3nuv3fYISABASCguLlZ6erokKSEhQXV1daqvr/e+XlhYqBNPPFGSFBMTo507d7bbHgEJALCMcRR/\n/PF4PIqOjvYex8TEqKamxnvscrkkSdXV1VqzZo3S0tLabY9zkAAAy1i5SMdXX7W1tZo4caJycnJa\nhakvBCQAwDJmBqTb7ZbH4/EeV1dXKzY21ntcX1+vW265RXfffbcuvPBCv+2ZOsWal5ensWPHKiMj\nQxs3bjSzKwCADRhGS8APf1JTU7Vq1SpJUllZmdxut3daVZJmzZqlm266ScOHD+/QWE2rIA9dTVRR\nUaHs7Gzl5+eb1R0AwAbMrCCTk5OVmJiojIwMORwO5eTkqLCwUFFRUbrwwgv12muvqbKyUgUFBZKk\nUaNGaezYsW22Z1pAtrWa6NA0BwB0LWafg5wyZUqr4zPOOMP7902bNh1RW6ZNsfpbTQQAQGdm2SId\nO20vBAAwh52ywLSA9LeaCADQBdkoIE2bYvW3mggA0PUYagn4YTXTKkhfq4kAAF2bnaZYHYadRhsk\nDofjWA8BADo9M+LhtNPODfi9X375URBH4h876QAALGOnmozNygEA8IEKEgBgmY5sGddZEJAAAMvY\naYqVgAQAWIaABADAFwISAIDDGSIgAQA4DIt0AmDlxfuNzc2W9SVJ4d06zY8ZnVhYmNOyvlpa9lvW\nF2BXfHMDACzDIh0AAHwgIAEA8IGABADABwISAAAfWMUKAIAvNqoguZsHAAA+UEECACzDTjoAAPjA\nIh0AAHxgkY6kkpISTZo0SaeddpokadCgQZo+fbpZ3QEAbIAK8r9SUlI0f/58M7sAANgIAQkAgA92\nCkhTL/MoLy/XxIkTNW7cOK1Zs8bMrgAACCqHYVKcV1VV6aOPPtKVV16pbdu2afz48Vq9erXCw8N9\nD4TbXaGL43ZX6GzMiIfevfsF/N7a2v8XxJH4Z1oFGRcXp5EjR8rhcKh///7q06ePqqqqzOoOAGAH\nRkvgD4uZFpDLly/XkiVLJEk1NTWqra1VXFycWd0BAGzAOIo/VjNtirW+vl5TpkzR7t271dTUpDvu\nuENpaWltD4QpVnRxTLGiszEjHqKjAy+Udu60dhbStIA8UgQkujoCEp2NGfHQq5c74Pfu2lUdxJH4\nxzc3AMAydtpJh7t5AADgAxUkAMAyneSsXocQkAAAyxCQAAD4QEACAOALAQkAwOEM2WcVKwEJALAM\nU6wBsNMP7UiF8mcDgFDVaQISABD67FQwEJAAAMsQkAAA+EBAAgDgg532YiUgYVunn366ysrK1C2I\nd0v5+OOPFRsbq/j4+FbPNzQ0aObMmaqoqFC3bt20d+9e/fa3v9XIkSOD1jfQFVBBAjZVWFiokSNH\nHhaQzz33nCIjI/XKK69Iknbs2KEJEyYoLS1NPXv2PBZDBeyJgASsU1JSoqeeekonnniiysvL1a1b\nNz3zzDOqra3VzTffrOHDh2vz5s2SpEcffVRxcXGtqs/CwkK9//77uvzyy/WPf/xDGzdu1P33369h\nw4Z5+6irq9PevXtlGIYcDof69u2r119/3fv6vHnz9PHHH6uhoUHnnXee7rvvPknSjBkztGnTJrnd\nbkVHRysuLk733HOPz/7nzJmjzZs3a/bs2WpublZTU5NmzJihs846SzfeeKOGDRumdevWacuWLbrz\nzjt19dVXq7a2Vvfff7/27Nkjp9OpGTNmaNCgQSoqKtKLL74owzAUExOjmTNnKjo62tr/YQCb43ZX\nCAnr16/X5MmTlZ+fr7CwMP3nP/+RJG3btk3XXHONXn75ZaWkpOjZZ59ts40RI0bozDPP1NSpU1uF\noySNHz9emzZt0qWXXqpp06Zp5cqVamxslCStXLlSVVVVevHFF1VQUKCtW7fq7bffVnFxsT777DMV\nFBRo4cKF+uKLL/x+jnvvvVcPPfSQXnjhBT344IN64IEHvK/t27dPTz/9tHJzc/XMM89IkubOnau0\ntDS98soruuuuu/T3v/9dO3ZRTlBHAAAEZElEQVTs0OLFi7V06VK98sorSklJ0ZNPPnnEP1PADMZR\n/LEaFSRCQkJCgnr37i1J6tevn3bt2iVJ6tWrl376059KkpKTk/WXv/wloPZPOukkLV++XJ988onW\nrl2rZ599Vn/+85/1t7/9TSUlJVq/fr1uvPFGSdKePXu0fft2NTc369xzz5XT6ZTT6dTPfvazdvuo\nra3V119/rWnTpnmfq6+vV0vLgUUNKSkp3rHU1dVJkjZu3Khf//rX3tdTUlJUVFSkmpoaZWVlSZIa\nGxt18sknB/S5gWAze5FOXl6eNmzYIIfDoezsbCUlJXlfe//99zVv3jw5nU4NHz5ct99+e7ttEZAI\nCU6n0+fzhy4IODg9+mNNTU1+229oaFBERISSkpKUlJSkW265RZmZmXr//fcVHh6u6667zhtIBy1Z\nsqTVsa++D+0/PDxc3bt31wsvvODz3x26GOng53I4HN4APSg8PFxJSUlUjeiUzFykU1paqsrKSuXn\n56uiokLZ2dnKz8/3vj5z5kwtWbJEcXFxuuGGG3T55ZfrJz/5SZvtMcWKkFZXV6dPP/1U0oEVqqef\nfrokyeVyaceOHZIOnMM8yOFw+AzMm266Sa+99pr3eO/evdq5c6fi4+N17rnn6p///Keam5slSQsX\nLtSWLVt02mmnad26dWppaVFjY6N32ret/qOionTyySfr3//+tyTp66+/1sKFC9v9fEOGDNF7770n\nSfrwww/1+9//XoMHD9bGjRtVU1Mj6cAU8JtvvtnRHxlgKsMwAn74U1xcrPT0dEkHZpXq6upUX18v\n6cDplhNOOEF9+/ZVWFiY0tLSVFxc3G57VJAIaXFxcSosLNSsWbNkGIbmzZsnSZowYYKysrJ0yimn\n6IwzzvCGVWpqqnJycpSdna3LLrvM287cuXOVm5ur/Px8hYeH64cfftCECRN05pln6owzztD69euV\nkZEhp9Ops846S/Hx8erfv7/eeOMNXXPNNYqNjdWgQYO87bXV/+zZszVz5kw99dRTam5u1tSpU9v9\nfJMmTdL999+vt99+W5I0ffp0xcXFadq0abr11lt13HHHKTIyUrNnzw7qzxUIlJkVpMfjUWJiovc4\nJiZGNTU1crlcqqmpUUxMTKvXtm3b1m57DsNOF6UAR2D79u3KzMzUu+++e6yHIklasGCBmpubdc89\n9xzroQAhafr06UpLS/NWkePGjVNeXp5OPfVUffzxx1qyZIkef/xxSdJf//pXbdu2TZMnT26zPaZY\nAQAhwe12y+PxeI+rq6sVGxvr87Wqqiq53e522yMgEbJOPvnkTlM9StKdd95J9QiYKDU1VatWrZIk\nlZWVye12y+VySTrwfVBfX+9dYf72228rNTW13faYYgUAhIw5c+boww8/lMPhUE5Ojj799FNFRUVp\nxIgR+uCDDzRnzhxJ0mWXXXbYyvMfIyABAPCBKVYAAHwgIAEA8IGABADABwISAAAfCEgAAHwgIAEA\n8IGABADAh/8PvC7C/aLzHrUAAAAASUVORK5CYII=\n", 676 | "text/plain": [ 677 | "
" 678 | ] 679 | }, 680 | "metadata": { 681 | "tags": [] 682 | }, 683 | "output_type": "display_data" 684 | } 685 | ], 686 | "source": [ 687 | "show_idx = randint(0, len(preds))\n", 688 | "text_x = decoder_sentence(X_val[:,show_idx].numpy(), x_id2w)\n", 689 | "text_y = decoder_sentence(preds[show_idx].numpy(), y_id2w)\n", 690 | "attn_weight = attn_weights[show_idx, :, -len(text_x):]\n", 691 | "show_attention(text_x, text_y, attn_weight)" 692 | ] 693 | } 694 | ], 695 | "metadata": { 696 | "accelerator": "GPU", 697 | "colab": { 698 | "collapsed_sections": [], 699 | "name": "attention.ipynb", 700 | "provenance": [], 701 | "version": "0.3.2" 702 | }, 703 | "kernelspec": { 704 | "display_name": "Python 3", 705 | "language": "python", 706 | "name": "python3" 707 | }, 708 | "language_info": { 709 | "codemirror_mode": { 710 | "name": "ipython", 711 | "version": 3 712 | }, 713 | "file_extension": ".py", 714 | "mimetype": "text/x-python", 715 | "name": "python", 716 | "nbconvert_exporter": "python", 717 | "pygments_lexer": "ipython3", 718 | "version": "3.6.6" 719 | } 720 | }, 721 | "nbformat": 4, 722 | "nbformat_minor": 1 723 | } 724 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tìm hiểu và áp dụng cơ chế attention trong deep learning - Understanding Attention Mechanism 2 | ## Giới Thiệu 3 | Ở git này, mình cung cấp cho các bạn cài đặt chi tiết của cơ chế attention trong bài toàn seq2seq đã được đơn giản hóa. Đồng thời giải nghĩa kết quả tại mỗi thời điểm thông qua cơ chế này. Git này để phục vụ cho blog về attention mà mình đã viết, các bạn có thể đọc thêm tại [đây](https://pbcquoc.github.io/attention/) nhé. 4 | ## Tổng quan cơ chế attention trong deep learning 5 | Cơ chế attention chỉ đơn giản là trung bình có trọng số của những “thứ” mà chúng ta nghĩ nó cần thiết cho bài toán, điều đặc biệt là trọng số này do mô hình tự học được. Cụ thể, trong bài toán dịch máy ở ví dụ dưới, khi sử dụng cơ chế attention để phát sinh từ little, mình sẽ cần tính một vector context C là trung bình có trọng số của vector biểu diễn các từ mặt, trời, bé, nhỏ tương ứng với vector h1,h2,h3,h4, rồi sử dụng thêm vector context c này tại lúc dự đoán từ little, và nhớ rằng, trọng số này là các số scalar, được mô hình tự học 6 | ![attention](./img/attn_seq2seq.png) 7 | 8 | [READ MORE](https://pbcquoc.github.io/attention/) 9 | 10 | ## Dataset 11 | Để minh họa cơ chế attention, mình sử dụng tập dataset tự phát sinh, với đầu vào là các câu biểu diễn ngày tháng năm của con người đọc, và nhãn là ngày tháng năm tương ứng do máy tính hiểu. 12 | 13 | | Input | Label | 14 | | ---------------------------------| ------------- | 15 | | 12, thg 9 2010 | 2010-09-12 | 16 | | Thứ Tư, ngày 21 tháng 3 năm 1973 | 1973-03-21 | 17 | | 31 thg 7, 1988 | 1988-07-31 | 18 | 19 | Mình đã phát sinh tổng cộng 20k mẫu, trong đó 5k dùng để validation. 20 | 21 | ## Kết quả 22 | Vì tập dữ liệu mình dùng để minh họa khá đơn giản, nên chỉ cần sau 3 epochs bạn đã có kết quả tương đối chính xác. Mình huấn luyện đến 10 epochs thì loss là 0.023 trên tập validation. 23 | 24 | Dưới này là một minh họa của câu đầu vào là “05 thg 5 2017”, các bạn có thể thấy rằng các phần ngày tháng năm khi phát sinh đều được mô hình chú ý một cách đúng lúc và chính xác. 25 | 26 | ![result](/img/attn_ex_2.png) 27 | ## Any Problems? 28 | Nếu có bất kì câu hỏi gì, các bạn có thể liên hệ mình thông qua địa chỉ pbcquoc@gmail.com nhé !. 29 | -------------------------------------------------------------------------------- /attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "attention.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "language": "python", 14 | "name": "python3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "metadata": { 21 | "id": "_5S-TnnTbpVR", 22 | "colab_type": "text" 23 | }, 24 | "cell_type": "markdown", 25 | "source": [ 26 | "## Tìm Hiểu và Áp Dụng Cơ Chế Attention\n", 27 | "### Giới thiệu\n", 28 | "Trong phần này mình hướng dẫn các bạn cách cài đặt cơ chế attention trong bài toán seq2seq đơn giản hóa. \n", 29 | "\n", 30 | "Cơ chế attention chỉ đơn giản là trung bình có trọng số của những “thứ” mà chúng ta nghĩ nó cần thiết cho bài toán, điều đặc biệt là trọng số này do mô hình tự học được. Cụ thể, trong bài toán dịch máy ở ví dụ dưới, khi sử dụng cơ chế attention để phát sinh từ little, mình sẽ cần tính một vector context C là trung bình có trọng số của vector biểu diễn các từ mặt, trời, bé, nhỏ tương ứng với vector h1,h2,h3,h4, rồi sử dụng thêm vector context c này tại lúc dự đoán từ little, và nhớ rằng, trọng số này là các số scalar, được mô hình tự học \n", 31 | "\n", 32 | "![attention](https://github.com/pbcquoc/attention_tutorial/raw/master/img/attn_seq2seq.png)\n" 33 | ] 34 | }, 35 | { 36 | "metadata": { 37 | "id": "RkZkpr3qbpVT", 38 | "colab_type": "text" 39 | }, 40 | "cell_type": "markdown", 41 | "source": [ 42 | "### Import thư viện\n", 43 | "Mình sử dụng thêm cái lib của keras để hỗ trợ tiền xử lý cho nhanh, và tập trung chủ yếu vào phần cài đặt cơ chế attention\n" 44 | ] 45 | }, 46 | { 47 | "metadata": { 48 | "colab_type": "code", 49 | "id": "o43HsYMod3xO", 50 | "outputId": "be8b39bf-d2d2-44c9-facf-1daedf385340", 51 | "colab": { 52 | "base_uri": "https://localhost:8080/", 53 | "height": 35 54 | } 55 | }, 56 | "cell_type": "code", 57 | "source": [ 58 | "import pandas as pd\n", 59 | "import numpy as np\n", 60 | "from keras.preprocessing.text import Tokenizer\n", 61 | "from keras.preprocessing.sequence import pad_sequences\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "import torch\n", 64 | "import torch.nn as nn\n", 65 | "import torch.nn.functional as F\n", 66 | "from torch import optim\n", 67 | "from torch.autograd import Variable\n", 68 | "from sklearn.model_selection import train_test_split\n", 69 | "import torch.utils.data\n", 70 | "import matplotlib.pyplot as plt\n", 71 | "import matplotlib.ticker as ticker\n", 72 | "from random import randint" 73 | ], 74 | "execution_count": 1, 75 | "outputs": [ 76 | { 77 | "output_type": "stream", 78 | "text": [ 79 | "Using TensorFlow backend.\n" 80 | ], 81 | "name": "stderr" 82 | } 83 | ] 84 | }, 85 | { 86 | "metadata": { 87 | "id": "MQ9naryYbpVd", 88 | "colab_type": "text" 89 | }, 90 | "cell_type": "markdown", 91 | "source": [ 92 | "### Download tập huấn luyện\n", 93 | "Để minh họa cơ chế attention, mình sử dụng tập dataset tự phát sinh, với đầu vào là các câu biểu diễn ngày tháng năm của con người đọc, và nhãn là ngày tháng năm tương ứng do máy tính hiểu.\n", 94 | "\n", 95 | "Mình đã phát sinh tổng cộng 20k mẫu, trong đó 5k dùng để validation." 96 | ] 97 | }, 98 | { 99 | "metadata": { 100 | "colab_type": "code", 101 | "id": "iDdQn4P6d3xl", 102 | "outputId": "ee6a4b1d-cac9-46bf-8e8f-53dea5d56b3a", 103 | "colab": { 104 | "base_uri": "https://localhost:8080/", 105 | "height": 52 106 | } 107 | }, 108 | "cell_type": "code", 109 | "source": [ 110 | "! curl --silent -L -o data.zip \"https://drive.google.com/uc?export=download&id=1d6eUqRstk7NIpyASzbuIsDvBdHEwfU0g\"\n", 111 | "! unzip -q data.zip\n", 112 | "! ls data" 113 | ], 114 | "execution_count": 2, 115 | "outputs": [ 116 | { 117 | "output_type": "stream", 118 | "text": [ 119 | "replace data/machine_vocab.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: A\n", 120 | "data.csv human_vocab.json machine_vocab.json\n" 121 | ], 122 | "name": "stdout" 123 | } 124 | ] 125 | }, 126 | { 127 | "metadata": { 128 | "id": "d-PgIPRqbpVj", 129 | "colab_type": "text" 130 | }, 131 | "cell_type": "markdown", 132 | "source": [ 133 | "### Tiền xử lý\n", 134 | "Tập vocab mình xử dụng là các kí tự alphabet và số. Các bạn không cần phải filter, các kí tự đặt biệt trong tập dữ liệu" 135 | ] 136 | }, 137 | { 138 | "metadata": { 139 | "colab_type": "code", 140 | "id": "17yzXiMWd3xp", 141 | "outputId": "529006a9-221d-42a2-8fee-dc84d96f55b5", 142 | "colab": { 143 | "base_uri": "https://localhost:8080/", 144 | "height": 35 145 | } 146 | }, 147 | "cell_type": "code", 148 | "source": [ 149 | "def load_data(path):\n", 150 | " df = pd.read_csv(path, header=None)\n", 151 | " X = df[0].values\n", 152 | " y = df[1].values\n", 153 | " x_tok = Tokenizer(char_level=True, filters='')\n", 154 | " x_tok.fit_on_texts(X)\n", 155 | " y_tok = Tokenizer(char_level=True, filters='')\n", 156 | " y_tok.fit_on_texts(y)\n", 157 | " \n", 158 | " X = x_tok.texts_to_sequences(X)\n", 159 | " y = y_tok.texts_to_sequences(y)\n", 160 | " \n", 161 | " X = pad_sequences(X)\n", 162 | " y = np.asarray(y)\n", 163 | " \n", 164 | " return X, y, x_tok.word_index, y_tok.word_index\n", 165 | "\n", 166 | "X, y, x_wid, y_wid= load_data('data/data.csv')\n", 167 | "x_id2w = dict(zip(x_wid.values(), x_wid.keys()))\n", 168 | "y_id2w = dict(zip(y_wid.values(), y_wid.keys()))\n", 169 | "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", 170 | "print('train size: {} - test size: {}'.format(len(X_train), len(X_test)))" 171 | ], 172 | "execution_count": 3, 173 | "outputs": [ 174 | { 175 | "output_type": "stream", 176 | "text": [ 177 | "train size: 18750 - test size: 6250\n" 178 | ], 179 | "name": "stdout" 180 | } 181 | ] 182 | }, 183 | { 184 | "metadata": { 185 | "id": "yvTDTUDIbpVo", 186 | "colab_type": "text" 187 | }, 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### Định nghĩa các tham số" 191 | ] 192 | }, 193 | { 194 | "metadata": { 195 | "colab_type": "code", 196 | "id": "Fy7StSyzd3xz", 197 | "outputId": "bcea276a-aada-4770-cb14-3387a7498dd1", 198 | "colab": { 199 | "base_uri": "https://localhost:8080/", 200 | "height": 35 201 | } 202 | }, 203 | "cell_type": "code", 204 | "source": [ 205 | "# hidden size cho môt hình LSTM\n", 206 | "hidden_size = 128\n", 207 | "learning_rate = 0.001\n", 208 | "decoder_learning_ratio = 0.1\n", 209 | "\n", 210 | "# tập tự vựng của các câu đầu vào \n", 211 | "# +1 vì các bạn cần kí tự padding nhé!\n", 212 | "input_size = len(x_wid) + 1\n", 213 | "\n", 214 | "# +2 vì các bạn cần kí tự bắt đầu và kết thức\n", 215 | "output_size = len(y_wid) + 2\n", 216 | "# 2 kí tự này nằm ở cuối\n", 217 | "sos_idx = len(y_wid) \n", 218 | "eos_idx = len(y_wid) + 1\n", 219 | "\n", 220 | "max_length = y.shape[1]\n", 221 | "print(\"input vocab: {} - output vocab: {} - length of target: {}\".format(input_size, output_size, max_length))" 222 | ], 223 | "execution_count": 4, 224 | "outputs": [ 225 | { 226 | "output_type": "stream", 227 | "text": [ 228 | "input vocab: 35 - output vocab: 13 - length of target: 10\n" 229 | ], 230 | "name": "stdout" 231 | } 232 | ] 233 | }, 234 | { 235 | "metadata": { 236 | "id": "l6TVVm61bpVu", 237 | "colab_type": "text" 238 | }, 239 | "cell_type": "markdown", 240 | "source": [ 241 | "Chuyển sang dạng chuỗi kí tự đọc được từ chuỗi số" 242 | ] 243 | }, 244 | { 245 | "metadata": { 246 | "colab_type": "code", 247 | "id": "2LksDULqd3x3", 248 | "colab": {} 249 | }, 250 | "cell_type": "code", 251 | "source": [ 252 | "def decoder_sentence(idxs, vocab):\n", 253 | " text = ''.join([vocab[w] for w in idxs if (w > 0) and (w in vocab)])\n", 254 | " return text" 255 | ], 256 | "execution_count": 0, 257 | "outputs": [] 258 | }, 259 | { 260 | "metadata": { 261 | "id": "almb2BMwbpVz", 262 | "colab_type": "text" 263 | }, 264 | "cell_type": "markdown", 265 | "source": [ 266 | "## Định nghĩa mô hình\n", 267 | "Ở phần này, các bạn cần định nghĩa 3 mô hình nhỏ\n", 268 | "* Encoder: là một mô hình LSTM, dùng để học biểu diễn của câu\n", 269 | "* Attention: dùng để học cách kết hợp để tạo ra context vector\n", 270 | "* Decoder: là một mô hình LSTM, chúng ta sẽ kết hợp context vector vào mô hình này để dự đoán các từ tại mỗi thời điểm\n", 271 | "\n", 272 | "![model](https://github.com/pbcquoc/pbcquoc.github.io/raw/master/images/attn_seq2seq_attn.png)\n", 273 | "\n", 274 | "### Encoder\n", 275 | "Mô hình này nhận đầu vào là các câu, các bạn có thể xem các hidden state h1,h2,h3,h4 như các biểu diễn của mỗi từ, và muốn tổng hợp context vector trên những thông tin này. \n", 276 | "### Attention\n", 277 | "Mô hình này học các trọng số alpha trên các h1,h2,h3,h4 rồi sau đó tổng hợp context vector theo trung bình có trọng số alpha này\n", 278 | "### Decoder\n", 279 | "Ở thời điểm dự đoán, các bạn sử dụng thêm context vector để bổ sung thông tin cho mô hình. " 280 | ] 281 | }, 282 | { 283 | "metadata": { 284 | "colab_type": "code", 285 | "id": "yVx1T2LJd3x6", 286 | "colab": {} 287 | }, 288 | "cell_type": "code", 289 | "source": [ 290 | "class Encoder(nn.Module):\n", 291 | " def __init__(self, input_size, hidden_size):\n", 292 | " super(Encoder, self).__init__()\n", 293 | " self.hidden_size = hidden_size\n", 294 | " # embedding vector của từ\n", 295 | " self.embedding = nn.Embedding(input_size, hidden_size)\n", 296 | " # mô hình GRU biến thể RNN để học vector biểu diễn của câu\n", 297 | " self.gru = nn.GRU(hidden_size, hidden_size)\n", 298 | " \n", 299 | " def forward(self, input):\n", 300 | " # input: SxB \n", 301 | " embedded = self.embedding(input)\n", 302 | " output, hidden = self.gru(embedded)\n", 303 | " return output, hidden # SxBxH, 1xBxH \n", 304 | "\n", 305 | "class Attn(nn.Module):\n", 306 | " def __init__(self, hidden_size):\n", 307 | " super(Attn ,self).__init__()\n", 308 | " \n", 309 | " def forward(self, hidden, encoder_outputs):\n", 310 | " ### Mô hình nhận trạng thái hidden hiện tại của mô hình decoder, \n", 311 | " ### và các hidden states của mô hình encoder\n", 312 | " # encoder_outputs: TxBxH\n", 313 | " # hidden: SxBxH\n", 314 | " \n", 315 | " # tranpose về đúng shape để nhận ma trận\n", 316 | " encoder_outputs = torch.transpose(encoder_outputs, 0, 1) #BxTxH\n", 317 | " hidden = torch.transpose(torch.transpose(hidden, 0, 1), 1, 2) # BxHxS\n", 318 | " # tính e, chính là tương tác giữ hidden và các trạng thái ẩn của mô hình encoder \n", 319 | " energies = torch.bmm(encoder_outputs, hidden) # BxTxS\n", 320 | " energies = torch.transpose(energies, 1, 2) # BxSxT\n", 321 | " # tính alpha, chính là trọng số của trung bình có trọng số cần tính bằng hàm softmax\n", 322 | " attn_weights = F.softmax(energies, dim=-1) #BxSxT\n", 323 | " \n", 324 | " # tính context vector bằng trung binh có trọng số\n", 325 | " output = torch.bmm(attn_weights, encoder_outputs) # BxSxH\n", 326 | " \n", 327 | " # trả về chiều cần thiết\n", 328 | " output = torch.transpose(output, 0, 1) # SxBxH\n", 329 | " attn_weights = torch.transpose(attn_weights, 0, 1) #SxBxT\n", 330 | " \n", 331 | " # return context vector và các trọng số alpha cho mục đích biểu diễn cơ chế attention\n", 332 | " return output, attn_weights\n", 333 | " \n", 334 | "class Decoder(nn.Module):\n", 335 | " def __init__(self, output_size, hidden_size, dropout):\n", 336 | " super(Decoder, self).__init__()\n", 337 | " self.hidden_size = hidden_size\n", 338 | " self.output_size = output_size\n", 339 | " \n", 340 | " # vector biểu diễn cho các từ của output\n", 341 | " self.embedding = nn.Embedding(output_size, hidden_size)\n", 342 | " # định nghĩa mô hình attention ở trên\n", 343 | " self.attn = Attn(hidden_size)\n", 344 | " self.dropout = nn.Dropout(dropout)\n", 345 | " # mô hình decoder là GRU\n", 346 | " self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n", 347 | " \n", 348 | " # dự đoán các từ tại mội thời điểm, chúng ta nối 2 vector hidden và context lại với nhau \n", 349 | " self.concat = nn.Linear(self.hidden_size*2, hidden_size) \n", 350 | " self.out = nn.Linear(self.hidden_size, self.output_size)\n", 351 | " \n", 352 | " def forward(self, input, hidden, encoder_outputs):\n", 353 | " # input: SxB\n", 354 | " # encoder_outputs: BxSxH\n", 355 | " # hidden: 1xBxH\n", 356 | " embedded = self.embedding(input) # 1xBxH\n", 357 | " embedded = self.dropout(embedded)\n", 358 | " \n", 359 | " # biểu diễn của câu\n", 360 | " rnn_output, hidden = self.gru(embedded, hidden) #SxBxH, 1xBxH\n", 361 | " # tính context vector dựa trên các hidden states\n", 362 | " context, attn_weights = self.attn(rnn_output, encoder_outputs) # SxBxH\n", 363 | " \n", 364 | " # nối hidden state của mô hình decoder hiện tại và context vector để dự đoán \n", 365 | " concat_input = torch.cat((rnn_output, context), -1)\n", 366 | " concat_output = torch.tanh(self.concat(concat_input)) #SxBxH\n", 367 | " \n", 368 | " # dự đoán kết quả tại mỗi thời điểm\n", 369 | " output = self.out(concat_output) # SxBxoutput_size\n", 370 | " return output, hidden, attn_weights\n", 371 | "\n" 372 | ], 373 | "execution_count": 0, 374 | "outputs": [] 375 | }, 376 | { 377 | "metadata": { 378 | "id": "nRDEAU7WbpV3", 379 | "colab_type": "text" 380 | }, 381 | "cell_type": "markdown", 382 | "source": [ 383 | "### Kiểm tra\n", 384 | "Chúng ta khởi tạo mô hình để kiểm tra xem mô hình có chạy được không, ít nhất là không bị lỗi về tính toán" 385 | ] 386 | }, 387 | { 388 | "metadata": { 389 | "colab_type": "code", 390 | "id": "ArLtO8rsd3yA", 391 | "colab": {} 392 | }, 393 | "cell_type": "code", 394 | "source": [ 395 | "encoder = Encoder(input_size, hidden_size)\n", 396 | "decoder = Decoder(output_size, hidden_size, 0.1)\n", 397 | "\n", 398 | "# Initialize optimizers and criterion\n", 399 | "encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", 400 | "decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)\n", 401 | "criterion = nn.CrossEntropyLoss()\n", 402 | "\n", 403 | "\n", 404 | "input_encoder = torch.randint(1, input_size, (34, 6), dtype=torch.long)\n", 405 | "encoder_outputs, hidden = encoder(input_encoder)\n", 406 | "input_decoder = torch.randint(1, output_size, (10, 6), dtype=torch.long)\n", 407 | "output, hidden, attn_weights = decoder(input_decoder, hidden, encoder_outputs)" 408 | ], 409 | "execution_count": 0, 410 | "outputs": [] 411 | }, 412 | { 413 | "metadata": { 414 | "id": "exEjFZUpbpV7", 415 | "colab_type": "text" 416 | }, 417 | "cell_type": "markdown", 418 | "source": [ 419 | "## Train/test\n", 420 | "Phần này chúng ta định nghĩa một số hàm để huấn luyện, dự đoán mô hình " 421 | ] 422 | }, 423 | { 424 | "metadata": { 425 | "colab_type": "code", 426 | "id": "TnARQv5td3yG", 427 | "colab": {} 428 | }, 429 | "cell_type": "code", 430 | "source": [ 431 | "def forward_and_compute_loss(inputs, targets, encoder, decoder, criterion):\n", 432 | " batch_size = inputs.size()[1]\n", 433 | " \n", 434 | " # định nghĩa 2 kí tự bắt đầu và kết thúc\n", 435 | " sos = Variable(torch.ones((1, batch_size), dtype=torch.long)*sos_idx)\n", 436 | " eos = Variable(torch.ones((1, batch_size), dtype=torch.long)*eos_idx)\n", 437 | " \n", 438 | " # input của mô hình decoder phải thêm kí tự bắt đầu\n", 439 | " decoder_inputs = torch.cat((sos, targets), dim=0)\n", 440 | " # output cần dự đoán của mô hình decoder phải thêm kí tự kết thúc\n", 441 | " decoder_targets = torch.cat((targets, eos), dim=0)\n", 442 | " \n", 443 | " # forward tính hidden states của câu\n", 444 | " encoder_outputs, encoder_hidden = encoder(inputs)\n", 445 | " # tính output của mô hình decoder\n", 446 | " output, hidden, attn_weights = decoder(decoder_inputs, encoder_hidden, encoder_outputs)\n", 447 | " \n", 448 | " output = torch.transpose(torch.transpose(output, 0, 1), 1, 2) # BxCxS\n", 449 | " decoder_targets = torch.transpose(decoder_targets, 0, 1)\n", 450 | " # tính loss \n", 451 | " loss = criterion(output, decoder_targets)\n", 452 | " \n", 453 | " return loss, output\n", 454 | "\n", 455 | "def train(inputs, targets, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):\n", 456 | " # khai báo train để mô hình biết là đang train hay test\n", 457 | " encoder.train()\n", 458 | " decoder.train()\n", 459 | " \n", 460 | " # zero gradient, phải làm mỗi khi cập nhất gradient\n", 461 | " encoder_optimizer.zero_grad()\n", 462 | " decoder_optimizer.zero_grad()\n", 463 | " \n", 464 | " # tính loss dựa vào hàm đã định nghĩa ở trên\n", 465 | " train_loss, output = forward_and_compute_loss(inputs, targets,encoder, decoder,criterion) \n", 466 | " \n", 467 | " train_loss.backward()\n", 468 | " # cập nhật một step\n", 469 | " encoder_optimizer.step()\n", 470 | " decoder_optimizer.step()\n", 471 | " \n", 472 | " # return loss để print :D\n", 473 | " return train_loss.item()\n", 474 | "\n", 475 | "def evaluate(inputs, targets, encoder, decoder, criterion):\n", 476 | " # báo cho mô hình biết đang test/eval\n", 477 | " encoder.eval()\n", 478 | " decoder.eval()\n", 479 | " # tính loss\n", 480 | " eval_loss, output = forward_and_compute_loss(inputs, targets, encoder, decoder,criterion)\n", 481 | " output = torch.transpose(output, 1, 2)\n", 482 | " # dự đoán của mỗi thời điểm các vị trí có prob lớn nhất\n", 483 | " pred_idx = torch.argmax(output, dim=-1).squeeze(-1)\n", 484 | " pred_idx = pred_idx.data.cpu().numpy()\n", 485 | " \n", 486 | " # return loss và kết quả dự đoán\n", 487 | " return eval_loss.item(), pred_idx\n", 488 | "\n", 489 | "def predict(inputs, encoder, decoder, target_length=max_length):\n", 490 | " ### Lúc dự đoán chúng ta cần tính kết quả ngay lập tức tại mỗi thời điểm, \n", 491 | " ### rồi sau đó dừng từ được dự đoán để tính từ tiếp theo \n", 492 | " batch_size = inputs.size()[1]\n", 493 | " \n", 494 | " # input đầu tiên của mô hình decoder là kí tự bắt đầu, chúng ta dự đoán kí tự tiếp theo, sau đó lại dùng kí tự này để dự đoán từ kế tiếp\n", 495 | " decoder_inputs = Variable(torch.ones((1, batch_size), dtype=torch.long)*sos_idx)\n", 496 | " \n", 497 | " # tính hidden state của mô hình encoder, cũng là vector biểu diễn của các từ, chúng ta cần tính context vector dựa trên những hidden states này\n", 498 | " encoder_outputs, encoder_hidden = encoder(inputs)\n", 499 | " hidden = encoder_hidden\n", 500 | " \n", 501 | " preds = []\n", 502 | " attn_weights = []\n", 503 | " # chúng ta tính từng từ tại mỗi thời điểm\n", 504 | " for i in range(target_length):\n", 505 | " # dự đoán từ đầu tiên\n", 506 | " output, hidden, attn_weight = decoder(decoder_inputs, hidden, encoder_outputs)\n", 507 | " output = output.squeeze(dim=0)\n", 508 | " pred_idx = torch.argmax(output, dim=-1)\n", 509 | " \n", 510 | " # thay đổi input tiếp theo bằng từ vừa được dự đoán\n", 511 | " decoder_inputs = Variable(torch.ones((1, batch_size), dtype=torch.long)*pred_idx)\n", 512 | " preds.append(decoder_inputs)\n", 513 | " attn_weights.append(attn_weight.detach())\n", 514 | " \n", 515 | " preds = torch.cat(preds, dim=0)\n", 516 | " preds = torch.transpose(preds, 0, 1)\n", 517 | " attn_weights = torch.cat(attn_weights, dim=0)\n", 518 | " attn_weights = torch.transpose(attn_weights, 0, 1)\n", 519 | " return preds, attn_weights" 520 | ], 521 | "execution_count": 0, 522 | "outputs": [] 523 | }, 524 | { 525 | "metadata": { 526 | "id": "VvrO070PcmWQ", 527 | "colab_type": "text" 528 | }, 529 | "cell_type": "markdown", 530 | "source": [ 531 | "### Train và eval\n", 532 | "Trong phần này, chúng ta train mô hình, cũng như theo dõi độ lỗi, kết quả dự đoán tại mỗi epoch. " 533 | ] 534 | }, 535 | { 536 | "metadata": { 537 | "colab_type": "code", 538 | "id": "Fyfzas04d3yZ", 539 | "outputId": "174b5a9c-a703-46e8-fb17-7f795589d665", 540 | "colab": { 541 | "base_uri": "https://localhost:8080/", 542 | "height": 725 543 | } 544 | }, 545 | "cell_type": "code", 546 | "source": [ 547 | "epochs = 10\n", 548 | "batch_size = 64\n", 549 | "\n", 550 | "encoder = Encoder(input_size, hidden_size)\n", 551 | "decoder = Decoder(output_size, hidden_size, 0.1)\n", 552 | "\n", 553 | "# Initialize optimizers and criterion\n", 554 | "encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", 555 | "decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)\n", 556 | "criterion = nn.CrossEntropyLoss()\n", 557 | "\n", 558 | "X_val = torch.tensor(X_test, dtype=torch.long)\n", 559 | "y_val = torch.tensor(y_test, dtype=torch.long)\n", 560 | "X_val = torch.transpose(X_val, 0, 1)\n", 561 | "y_val = torch.transpose(y_val, 0, 1)\n", 562 | "\n", 563 | "for epoch in range(epochs):\n", 564 | " for idx in range(len(X_train)//batch_size):\n", 565 | " # input đầu vào của chúng ta là timestep first nhé. \n", 566 | " X_train_batch = torch.tensor(X_train[batch_size*idx:batch_size*(idx+1)], dtype=torch.long)\n", 567 | " y_train_batch = torch.tensor(y_train[batch_size*idx:batch_size*(idx+1)], dtype=torch.long)\n", 568 | " \n", 569 | " X_train_batch = torch.transpose(X_train_batch, 0, 1)\n", 570 | " y_train_batch = torch.transpose(y_train_batch, 0, 1)\n", 571 | " train_loss= train(X_train_batch, y_train_batch, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n", 572 | " eval_loss, preds = evaluate(X_val, y_val, encoder, decoder, criterion)\n", 573 | " \n", 574 | " print('Epoch {} - train loss: {:.3f} - eval loss: {:.3f}'.format(epoch, train_loss, eval_loss))\n", 575 | " print_idx = np.random.randint(0, len(preds), 3)\n", 576 | " for i in print_idx:\n", 577 | " x_val = decoder_sentence(X_val[:,i].numpy(), x_id2w)\n", 578 | " y_pred = decoder_sentence(preds[i], y_id2w)\n", 579 | " print(\" {:<35s}\\t{:>10}\".format(x_val, y_pred))" 580 | ], 581 | "execution_count": 9, 582 | "outputs": [ 583 | { 584 | "output_type": "stream", 585 | "text": [ 586 | "Epoch 0 - train loss: 0.533 - eval loss: 0.505\n", 587 | " 7 11 17 \t7017-01-17\n", 588 | " tháng 3 14 1971 \t1971-03-12\n", 589 | " 26 thg 3, 2013 \t2013-03-16\n", 590 | "Epoch 1 - train loss: 0.197 - eval loss: 0.170\n", 591 | " ngày 03 tháng 04 năm 2008 \t2008-04-03\n", 592 | " 27 thg 11 1985 \t1985-01-27\n", 593 | " 30 tháng 7 1975 \t1975-07-30\n", 594 | "Epoch 2 - train loss: 0.112 - eval loss: 0.097\n", 595 | " 14, thg 7 2013 \t2013-07-14\n", 596 | " 05/12/1990 \t1990-02-05\n", 597 | " 27 thg 10, 1990 \t1990-00-27\n", 598 | "Epoch 3 - train loss: 0.050 - eval loss: 0.044\n", 599 | " 10 thg 7, 1975 \t1975-07-10\n", 600 | " 5 12 89 \t1989-12-05\n", 601 | " thứ ba, ngày 31 tháng 5 năm 2016 \t2016-05-31\n", 602 | "Epoch 4 - train loss: 0.028 - eval loss: 0.025\n", 603 | " ngày 27 tháng 04 năm 2000 \t2000-04-27\n", 604 | " 27 thg 1, 1974 \t1974-11-27\n", 605 | " 27 11 00 \t2000-11-27\n", 606 | "Epoch 5 - train loss: 0.027 - eval loss: 0.021\n", 607 | " thứ tư, ngày 12 tháng 2 năm 1997 \t1997-02-12\n", 608 | " 11 tháng 1, 1981 \t1981-11-11\n", 609 | " 4 tháng 12, 2012 \t2012-12-04\n", 610 | "Epoch 6 - train loss: 0.018 - eval loss: 0.020\n", 611 | " 1 04 02 \t2002-04-01\n", 612 | " 19 thg 11, 1984 \t1984-11-19\n", 613 | " 21.10.12 \t2012-10-21\n", 614 | "Epoch 7 - train loss: 0.022 - eval loss: 0.019\n", 615 | " 07.04.95 \t1995-04-07\n", 616 | " 1 tháng 3 1970 \t1970-03-01\n", 617 | " 09/03/1993 \t1993-03-09\n", 618 | "Epoch 8 - train loss: 0.014 - eval loss: 0.018\n", 619 | " 05.07.80 \t1980-07-05\n", 620 | " 9 thg 9, 1978 \t1978-09-09\n", 621 | " ngày 29 tháng 12 năm 2004 \t2004-12-29\n", 622 | "Epoch 9 - train loss: 0.013 - eval loss: 0.017\n", 623 | " 06/08/1996 \t1996-08-06\n", 624 | " 04, thg 3 1980 \t1980-03-04\n", 625 | " 25 thg 3 1972 \t1972-03-25\n" 626 | ], 627 | "name": "stdout" 628 | } 629 | ] 630 | }, 631 | { 632 | "metadata": { 633 | "id": "XL0477_kbpWK", 634 | "colab_type": "text" 635 | }, 636 | "cell_type": "markdown", 637 | "source": [ 638 | "## Predict\n", 639 | "Chúng ta dự đoán một vài mẫu và phân tích một số kết quả của cơ chế attention" 640 | ] 641 | }, 642 | { 643 | "metadata": { 644 | "colab_type": "code", 645 | "id": "NglXGh77d3yc", 646 | "colab": {} 647 | }, 648 | "cell_type": "code", 649 | "source": [ 650 | "preds, attn_weights = predict(X_val ,encoder, decoder, target_length=10)" 651 | ], 652 | "execution_count": 0, 653 | "outputs": [] 654 | }, 655 | { 656 | "metadata": { 657 | "colab_type": "code", 658 | "id": "vRMILrf6d3yj", 659 | "colab": {} 660 | }, 661 | "cell_type": "code", 662 | "source": [ 663 | "def show_attention(input_sentence, output_words, attentions):\n", 664 | " # Set up figure with colorbar\n", 665 | " fig = plt.figure()\n", 666 | " ax = fig.add_subplot(111)\n", 667 | " cax = ax.matshow(attentions.numpy(), cmap='bone')\n", 668 | " fig.colorbar(cax)\n", 669 | "\n", 670 | " # Set up axes\n", 671 | " ax.set_xticks(np.arange(len(input_sentence)))\n", 672 | " ax.set_xticklabels(list(input_sentence), rotation=90)\n", 673 | " ax.set_yticks(np.arange(len(output_words)))\n", 674 | " ax.set_yticklabels(list(output_words))\n", 675 | " ax.grid()\n", 676 | " ax.set_xlabel('Input Sequence')\n", 677 | " ax.set_ylabel('Output Sequence')\n", 678 | " plt.show()" 679 | ], 680 | "execution_count": 0, 681 | "outputs": [] 682 | }, 683 | { 684 | "metadata": { 685 | "id": "FvmyeXyBbpWW", 686 | "colab_type": "text" 687 | }, 688 | "cell_type": "markdown", 689 | "source": [ 690 | "Chọn ngẫu nhiên một câu trong tập validation để hiển thị. Khi hiển thị cơ chế attention, chúng ta có một cái nhìn về quá trình dự đoán của mô hình rõ ràng hơn, giúp đánh giá có thể interpretable hơn. " 691 | ] 692 | }, 693 | { 694 | "metadata": { 695 | "colab_type": "code", 696 | "id": "eFJvKjnOLL9W", 697 | "outputId": "d0f8a27b-85ba-46cb-d2e6-5289aad0ec58", 698 | "colab": { 699 | "base_uri": "https://localhost:8080/", 700 | "height": 336 701 | } 702 | }, 703 | "cell_type": "code", 704 | "source": [ 705 | "show_idx = randint(0, len(preds))\n", 706 | "text_x = decoder_sentence(X_val[:,show_idx].numpy(), x_id2w)\n", 707 | "text_y = decoder_sentence(preds[show_idx].numpy(), y_id2w)\n", 708 | "attn_weight = attn_weights[show_idx, :, -len(text_x):]\n", 709 | "show_attention(text_x, text_y, attn_weight)" 710 | ], 711 | "execution_count": 12, 712 | "outputs": [ 713 | { 714 | "output_type": "display_data", 715 | "data": { 716 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcgAAAE/CAYAAADCNlNLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XFcVHW+//H3YRRQoWSUITXMlV9W\nopSk3oxd3dtiW+a2S5vKUmn3ctMedTfLW61hQu0Kq121m5prm2vetVKKuHvNcDXd2ryKkq2i4lpK\niblrwJChqATo+f1hTrIODMKcgTO+nj3O48GZYT7nMwTz8fM93/M9hmmapgAAQCMh7Z0AAAAdEQUS\nAAAvKJAAAHhBgQQAwAsKJAAAXlAgAQDwggIJAIAXndo7AQDApaMtl94bhuHHTHyjgwQAwAs6SABA\nwJxpQwfpCHAHSYEEAASMnVY3pUACAALGFAUSAIALnLFPfaRAAgAChyFWAAC8aMsknUDjMg8AALyg\ngwQABAxDrAAAeEGBBADACzudg6RAAgAChg4SAAAv7LRQALNYAQDwgg4SABAwrKQDAIAXnIMEAMAL\nZrECAOAFHSQAAF5QIAEA8MJOQ6xc5gEAgBd0kACAgGGIFQAAL+y0kg4FEgAQMHZaKMA25yC9teVf\nfPFFO2QCf/if//mf9k4BQcrKz4oTJ06orKxMZWVlOnnypF9iWhn3nIaGBv3tb39TQ0OD32NfLNM0\nW70FmmF28AHhd999Vzk5OTp16pRGjRqlmTNnKiIiQpI0ceJE/f73v29V3MTERKWkpOihhx5Sjx49\n/Jmybb344ot69dVXPfumacowDBUWFrYp7u7du/Xyyy/rq6++kiTV19fL7Xbr3XffbVNcSXrrrbe0\nYsUK1dTUeP6IDMPQxo0b2xT3Bz/4wQWPORwOxcbGatq0aYqPj29TfH/h9/hbVn1WSGd/h7Ozs3Xs\n2DFFRUXJNE1VVFQoJiZGmZmZuuaaazpU3FmzZunpp5+WJG3ZskUzZsxQz549VVVVpWeffVbf+973\nWhXXH8rc7la/9qqePf2YSQuYHdzdd99tHj161Dx9+rS5atUqc/z48eaxY8dM0zTNe++9t9Vx7733\nXrOoqMicNGmSOX36dLOoqMisr6/3V9q2NHbsWPPEiRN+i/fss8+apmma48ePNwsLC81JkyaZu3fv\nNufNm2f+6U9/8ssxbr/9drO0tNQ8ceJEo62tlixZYr722mvmF198YX7xxRfmG2+8YS5ZssTcsWOH\nmZqa6ofM/YPf429Z9VlhmqaZmppqHjhw4ILH9+zZY6alpXW4uOe/37S0NPPQoUOmaZpmRUWFOX78\n+FbH9YdPKypavQVahx9idTgc6t69u0JCQjRhwgQ98MADSk9P15dffinDMFod1zAMDRs2TMuXL1da\nWprefvttjR07Vj/96U81efJkP74D/1q8ePEFj82ePdsvsa+99lp16uS/09JDhgzRxo0bFRYWpptu\nuklhYWEaNGiQpk2b1qhTbYt+/fqpf//+6tq1a6OtrT744AOlpaUpJiZGMTExGjdunDZv3qwbbrjB\nD1n7j11/j61g1WeFdHY0JS4u7oLH4+Pjdfr06Q4X9/z3e/nllys2NlaSFB0d7de/8WDX4X9SiYmJ\nmjJlil544QWFh4crOTlZYWFhuv/++z1Ddq1hnjeyPHjwYA0ePFiSVFFRocrKyjbn7W/r16/XmjVr\ntH37dn388ceexxsaGvTXv/5V06dPb3XsRx55RIZh6MSJE7rttts0cOBAORwOz/MvvPBCq+L+6Ec/\nkiS9+eabWr9+vZxOpxYuXKgrrrhCR44caXW+53M6nZowYYJuuOGGRjk/+eSTbYobFhamnJwcJSYm\nKiQkRHv27FF9fb02b97slwLsL+31e5yZmanIyEglJSXp5ptvtuw4F8OqzwpJuv766/Xggw8qOTlZ\nTqdTkuR2u7Vu3ToNHz68w8Xdv3+/pk6dKtM0VVZWprVr1+r222/XsmXLFBkZ2eq4/mB27LN6jXT4\nc5CStG3bNg0fPrzRv4pqampUUFCg8ePHtypmXl6e7r77bn+lGBCHDx/Wr371K6Wnp3seCwkJUf/+\n/T1/XK1RVFTU7PNt+UOVzk5AqKiokMvl0ssvv6zjx4/rJz/5iefDvC2amuyTkpLSprg1NTX6wx/+\noNLSUpmmqb59+yolJUWnTp1SZGRku3/InNNev8dut1s9e/ZUfX29OnfuHPDjN8WKz4pzPvzwQxUW\nFsr9zTk0l8ulpKQkDRkypMPF/ce/6auuukoxMTF6++23dcstt6hbt25tyrktDpSXt/q1/y8mxo+Z\n+GaLAgkACA772zCj+OorrvBjJr51+CFWAEDwYKEAAAC8sNNCARRIAEDA2OmsXoe/zAMAgPZABwkA\nCBg7dZAdpkC29UJeAPCH0NBwS+KWH62yJK4k9bzsckviNjTU+z2mnW6Y3GEKJAAg+NFBAgDgBQUS\nAAAvGGIFAMALOy0UwGUeAAB4QQcJAAgYVtL5xnPPPaePPvpIDQ0NmjJlim699VYrDwcA6OCYpCNp\n69at2r9/v3Jzc3X06FGlpKRQIAHgEkeBlDRs2DAlJCRIki677DKdOnVKp0+fbnRTWwDApYVZrJIc\nDofnzut5eXkaOXIkxREALnF0kOfZsGGD8vLytGzZMqsPBQDo4CiQ39i0aZOWLFmipUuXKjIy0spD\nAQDgV5YVyOPHj+u5557T8uXL1b17d6sOAwCwEc5BSiooKNDRo0f16KOPeh6bM2eOevfubdUhAQAd\nnJ1W0jHMDjIgzO2uAHQE3O7qW1bc7uqDjz9u9WtHXnONHzPxjZV0AAAB00F6shahQAIAAoYCCQCA\nF1ZP0snJyVFxcbEMw1BGRoZnwRpJeu2117R69WqFhIRo0KBBmjFjRrOxuJsHACAoFBUVqaysTLm5\nucrOzlZ2drbnuZqaGv3ud7/Ta6+9ppUrV6q0tFQ7d+5sNh4FEgAQMKZptnrzpbCwUMnJyZKkuLg4\nVVdXq6amRpLUuXNnde7cWSdPnlRDQ4NOnTqlyy9vfnITQ6wdklUzeu0z9g+0l7q6WkviRnXrZklc\nu7HyHKTb7VZ8fLxn3+l0qrKyUhEREQoLC9PDDz+s5ORkhYWF6Y477tB3vvOdZuPRQQIAAuaMabZ6\nu1jnF+Oamhq99NJL+uMf/6iNGzequLhY+/bta/b1FEgAQMCYbfjPF5fLJbfb7dmvqKhQdHS0JKm0\ntFSxsbFyOp0KDQ3V0KFDtWfPnmbjUSABAAFjmq3ffElKStK6deskSSUlJXK5XIqIiJAk9enTR6Wl\npaqtPTuEvmfPHvXr16/ZeJyDBAAEjJWXeSQmJio+Pl6pqakyDENZWVnKz89XZGSkRo8erfT0dE2c\nOFEOh0NDhgzR0KFDm43HUnMdEpN0ALQ/K8pDQXFxq1875vrr/ZiJb3SQAICA6SA9WYtQIAEAAcPt\nrgAA8IIO8hvNrYkHALj0UCDVeE280tJSZWRkKDc316rDAQBswE5DrJZdB9ncmngAgEuTlQsF+Jtl\nBdLtdisqKsqzf25NPAAA7CBgk3TsNO4MALCGnUqBZQWyuTXxAACXJs5Bqvk18QAAlyYr7wfpb5Z1\nkN7WxAMAXNrs1EFaeg7y8ccftzI8AMBm7DQfhZV0AAABY6cCyf0gAQDwgg4SABA4NuogKZAAgIAx\nz1AgAQC4gI0aSApkR9SpU2dL4jY01FkSFwBayk6TdCiQAICAoUACAOCFnQokl3kAAOAFHSQAIGCY\nxQoAgBd2GmKlQAIAAoYCCQCANxTIs3JyclRcXCzDMJSRkaGEhAQrDwcA6OBsVB+tK5BFRUUqKytT\nbm6uSktLlZGRodzcXKsOBwCwATtN0rHsMo/CwkIlJydLkuLi4lRdXa2amhqrDgcAgF9ZViDdbrei\noqI8+06nU5WVlVYdDgBgA6ZptnoLtIBN0rHTzCUAgDXsVAssK5Aul0tut9uzX1FRoejoaKsOBwCw\nATsVSMuGWJOSkrRu3TpJUklJiVwulyIiIqw6HADABhhilZSYmKj4+HilpqbKMAxlZWVZdSgAgF3Y\naBarpecgH3/8cSvDAwBshiFWAABsjqXmAAABY6MGkgIJAAgcOw2xUiABAAFDgQQAwAs7rcVKgeyA\nGhrq2jsFADZip67MTrlSIAEAAWOnAsllHgAAeEEHCQAImKDrID/55BNt2LBBknTs2DFLEwIABDHT\nbP0WYD47yOXLl2vNmjWqq6tTcnKyFi9erMsuu0wPPfRQIPIDAAQR80x7Z9ByPjvINWvW6I033tDl\nl18uSXryySf1/vvvW50XACAIBdXdPLp166aQkG/raEhISKN9AABayk7nIH0WyL59+2rRokU6duyY\n1q9fr4KCAsXFxfkMvG3bNk2dOlVXX321JGnAgAGaOXNm2zMGANhWUBXIzMxM/f73v1dMTIxWr16t\noUOHKi0trUXBhw8frgULFrQ5SQAAAs1ngXQ4HLr++uuVnp4uSfrTn/6kTp24OgQAcPHs1EH6PJmY\nmZmpP//5z579oqIizZgxo0XBDxw4oAcffFA/+9nPtHnz5tZnCQAICuYZs9VboPlsBQ8ePKhZs2Z5\n9qdPn6777rvPZ+B+/frp3//933X77bfr888/18SJE7V+/XqFhoa2LWMAgH1Z3EHm5OSouLhYhmEo\nIyNDCQkJnueOHDmiadOmqb6+XgMHDtQvf/nLZmP57CBra2v11VdfefbLy8v19ddf+0wyJiZGY8aM\nkWEY6tu3r3r27Kny8nKfrwMABC8rL/MoKipSWVmZcnNzlZ2drezs7EbPz549W//6r/+qvLw8ORwO\n/f3vf282ns8O8uGHH9bYsWPVq1cvnT59WhUVFRcc1JvVq1ersrJS6enpqqysVFVVlWJiYny+DgAQ\nvKxsIAsLC5WcnCxJiouLU3V1tWpqahQREaEzZ87oo48+0vz58yVJWVlZPuP5LJD//M//rA0bNujA\ngQMyDEP9+/dXly5dfAa+5ZZb9Pjjj2vjxo2qr6/XM888w/AqAFzirJyk43a7FR8f79l3Op2qrKxU\nRESEvvzyS3Xr1k2//vWvVVJSoqFDh+o//uM/mo3ns0BWVlaqoKBA1dXVjd7Y1KlTm31dRESElixZ\n4is8AACWOL9mmaap8vJyTZw4UX369NHkyZP1/vvv6/vf/36Tr/d5DnLKlCnat2+fQkJC5HA4PBsA\nABfLylmsLpdLbrfbs19RUaHo6GhJUlRUlHr37q2+ffvK4XBoxIgR2r9/f7PxfHaQXbt21a9//Wuf\niQEA4IuVQ6xJSUlauHChUlNTVVJSIpfLpYiICElSp06dFBsbq4MHD6pfv34qKSnRHXfc0Ww8nwXy\n+uuvV2lpaYuWlwMAoDlWFsjExETFx8crNTVVhmEoKytL+fn5ioyM1OjRo5WRkaHp06fLNE0NGDBA\nt9xyS7PxDNNHtnfeeadKS0sVFRWlTp06yTRNGYbh9zt6GIbh13gAcKmw0+o0v1q8otWvnfmQ72vw\n/clnB/mb3/wmEHkAAC4BdirmPifpREdH6/3339fKlSvVp08fud1u9ezZMxC5AQCCzRmz9VuA+SyQ\nzzzzjA4dOqRt27ZJkkpKSjR9+nTLE4P/tWUFi452I1MA3zIMw5LtUuezQH766ad66qmnFB4eLklK\nS0tTRUWF5YkBAIKPabZ+CzSf5yDP3drq3L8mTp48qdraWmuzAgAEJTuNOPkskLfddpsmTZqkw4cP\na9asWfrggw9afMNkAADOF1QF8t5771VCQoKKiooUGhqq+fPna9CgQYHIDQAQZNrjvo6t5bNAFhYW\nSpJnAdjjx4+rsLBQI0aMsDYzAEDQCaoOcvHixZ6v6+vrdeDAASUmJlIgAQAXLagK5IoVjVc9qKqq\n0rx58yxLCACAjsBngfxHPXr00KeffmpFLgCAYBdMHeQTTzzR6ILRI0eOKCTE5+WTAABcIKiGWG++\n+WbP14ZhKCIiQklJSZYmBQAITuaZ9s6g5XwWyKFDh17w2Pk3pIyNjfVvRgCAoBVUHWR6ero+//xz\nde/eXYZh6OjRo+rdu7fntlcbN25s9vWvv/661q5dq6ioKC1YsMBviQMA7CeoCuTIkSOVkpLiuQ5y\n586dWrNmjZ5++ukWHSAtLY2VdwAAkuxVIH3Otvn44489xVGSbrjhBu3bt8/SpAAAaG8+O8ja2lq9\n9tprGjZsmCRp+/btOnnypOWJAQCCj506SJ8Fct68eVq4cKFWrVolSRowYID+8z//0/LEAADBJ6jW\nYu3bt6/mzJkjt9stl8sViJwAAEHKTh2kz3OQhYWFSk5O1sSJEyVJOTk5eu+99yxPDAAQhGx0x2Sf\nBfL555/XG2+8oejoaEnSgw8+qN/85jeWJwYACD42qo++h1i7du2qnj17evadTqc6d+5saVIAgOBk\npyFWnwUyPDxcRUVFkqTq6mq98847CgsLszwxAADak88CmZWVpWeeeUa7d+/W6NGjdeONN+qXv/xl\nIHKDn52/6DwAtIegmsXaq1cvvfTSS4HIBQAQ5Ow0xNrkJJ0jR45o9uzZnv3nn39eQ4cO1V133aXP\nPvssIMkBAIKLaZqt3gKtyQKZmZnpuVPH3r17lZeXp7feekuPPfZYo8IJAEBLBUWBPH78uO655x5J\n0vr16zVmzBhdddVV+t73vqfa2tqAJQgACCI2us6jyQJ5/kzVoqIi3XTTTZ59O40hAwA6DvOM2eot\n0JqcpGMYhvbt26fjx4/rk08+0c033yxJqqysVF1dXcASBACgPTRZIKdNm6apU6equrpaM2fOVJcu\nXVRbW6u7775b06dPD2SOAIAgYacByCYLZEJCgtatW9fosfDwcL3yyivq379/i4Ln5OSouLhYhmEo\nIyNDCQkJbcsWAGBrdjpF5/M6yH/U0uJYVFSksrIy5ebmqrS0VBkZGcrNzb3oBAEAwSOoC2RLnbsL\niCTFxcWpurpaNTU1ioiIsOqQAIAOzk4F0ufdPLxpyWUebrdbUVFRnn2n06nKysrWHA4AECTsNIvV\nZ4FMT0+/4LFz10deDDv9qwEAYA07LRTQ5BDr6tWr9eKLL+rvf/+7vv/973ser6+vb3T7q6a4XC65\n3W7PfkVFheeekgAAdHRNFsg777xTd9xxh2bMmKGf//znnsdDQkLkcrl8Bk5KStLChQuVmpqqkpIS\nuVwuzj8CwKXORqOJzU7ScTgc+vGPf6xDhw41evzgwYMaMWJEs4ETExMVHx+v1NRUGYahrKystmcL\nALA1O51u8zmLdfHixZ6v6+vrdeDAASUmJvoskJL0+OOPty07AEBQsVF99F0gV6xY0Wi/qqpK8+bN\nsywhAEDwCqobJv+jHj166NNPP7UiFwBAkAuqIdYnnnhChmF49o8cOaKQkFZdPgkAuMQFVYE8dxcP\n6ewdPiIiIpSUlGRpUgAAtDefrWBKSori4+MVFhamsLAw9e/fX126dAlEbgCAIBMUCwWcM2fOHG3c\nuFGDBw/WmTNnNG/ePI0dO1aPPvpoIPID8A0rPyDOP40CWCmohli3bdumd955R507d5Yk1dXVKTU1\nlQIJALhoQTWLtWfPnurU6dtv69y5s/r06WNpUgCAIBVMHWRUVJR++tOf6qabbpJpmvrwww8VGxur\nF154QZI0depUy5MEAAQHG9VH3wUyNjZWsbGxnv3zFy4HACBY+SyQERERuv/++xs9tmDBAj3yyCNW\n5QQACFJWT9LJyclRcXGxDMNQRkaGEhISLvieefPmaefOnResFPePmiyQW7du1datW7V69WpVV1d7\nHm9oaFB+fj4FEgBw0awskEVFRSorK1Nubq5KS0uVkZGh3NzcRt9z4MABffjhh56Jp81p8jrI/v37\nKy4uTtLZu3qc28LDwzV//vw2vg0AwKXIPGO2evOlsLBQycnJkqS4uDhVV1erpqam0ffMnj1bjz32\nWItybbKDdLlc+tGPfqTExMRWzVp98803tXr1as/+nj17tGPHjouOAwAIHlZ2kG63W/Hx8Z59p9Op\nyspKz72I8/PzNXz48BbXNJ/nINPS0rxeRPz+++83+7px48Zp3Lhxks62vWvXrm1RQgCA4BXIhQLO\nP9ZXX32l/Px8vfLKKyovL2/R630WyNdff93zdX19vQoLC/X1119fVJIvvvii5s6de1GvAQAEHysL\npMvlktvt9uxXVFQoOjpa0tl5NV9++aXuuece1dXV6dChQ8rJyVFGRkaT8XyuxdqnTx/P1q9fP/3s\nZz/Tpk2bWpzwrl271KtXL0+SAABYISkpSevWrZMklZSUyOVyeYZXb7vtNhUUFOiNN97QokWLFB8f\n32xxlFrQQRYWFjba/+KLL3To0KEWJ5yXl6eUlJQWfz8AIIhZ2EEmJiYqPj5eqampMgxDWVlZys/P\nV2RkpEaPHn3R8QzTR7973333ffvN39zu6t577210G6zm/PCHP9Tbb7+t0NDQ5hNhsWSgWSxWjkCz\n4nduQuovWv3a3FVz/JiJbz47SF8XUjanvLxc3bp181kcAQCXBjvdzaPZc5CFhYW65557NGTIECUm\nJur+++/Xzp07Wxy8srJSTqezzUkCAIJDUNwPsqCgQIsXL9a0adN0ww03SJJ2796trKwsTZ06Vbfc\ncovP4IMGDdLSpUv9ly0AwNbs1EE2WSCXL1+ul19+Wb169fI8NmrUKF133XUtLpAAAJzPTgWyySFW\nwzAaFcdzXC6Xrd4gAACt0WQHWVtb2+SLTp48aUkyAIDg1pI1VTuKJjvI6667zusM1qVLlyoxMdHS\npAAAQco0W78FWJMd5JNPPqmHHnpIa9as0eDBg2Wapnbs2KGIiAi99NJLgcwRABAkTNmng2yyQDqd\nTq1atUqbN2/W3r171bVrV91+++0aOnRoIPMDAAQRO81h8blQQFJSkpKSkixPJCy0iyVxv647ZUlc\nO6qpte5nERFuzf8/fIvVbhAMTPNMe6fQYj4LJAAA/mKnDtLn3TwAALgU0UECAALGTh0kBRIAEDAU\nSAAAvGCSDgAA3tBBAgBwoaBYKAAAAH+z0zlILvMAAMALOkgAQMDYqYOkQAIAAoZZrOd5/fXXtXbt\nWkVFRWnBggVWHw4A0IHRQZ4nLS1NaWlpVh8GAGADFEgAALygQAIA4I2NCiSXeQAA4AUdJAAgYEwx\nixUAgAtwDhIAAC8okAAAeEGBBADAC1bSaYWw8G6WxP34b4csiStJ/aKjLYtthYjwLu2dAoBLnJ06\nSC7zAADAiw7TQQIAgp+dOkgKJAAgcCiQAABcyBQFEgCACzCLFQAALzgHCQCAFxTIbzz33HP66KOP\n1NDQoClTpujWW2+18nAAAPiNZQVy69at2r9/v3Jzc3X06FGlpKRQIAHgEkcHKWnYsGFKSEiQJF12\n2WU6deqUTp8+LYfDYdUhAQAdHJN0JDkcDnXt2lWSlJeXp5EjR1IcAeASRwd5ng0bNigvL0/Lli2z\n+lAAgI6OAnnWpk2btGTJEi1dulSRkZFWHgoAYAMsFCDp+PHjeu6557R8+XJ1797dqsMAAGyEIVZJ\nBQUFOnr0qB599FHPY3PmzFHv3r2tOiQAAH5jWYGcMGGCJkyYYFV4AIANMYsVAAAvGGIFAMALCiQA\nAF5YXSBzcnJUXFwswzCUkZHhWbBGOrvC2/z58xUSEqLvfOc7ys7OVkhISJOxmn4GAAA/M02z1Zsv\nRUVFKisrU25urrKzs5Wdnd3o+czMTC1YsECrVq3SiRMntGnTpmbj0UECAALHwkk6hYWFSk5OliTF\nxcWpurpaNTU1ioiIkCTl5+d7vnY6nTp69Giz8eggAQBBwe12KyoqyrPvdDpVWVnp2T9XHCsqKrR5\n82aNGjWq2XgdpoM8dsxtSdyreva0JC4A4OIFciUdb8OyVVVVevDBB5WVldWomHrTYQokACD4WTlJ\nx+Vyye3+ttmqqKhQdHS0Z7+mpkYPPPCAHn30UX33u9/1GY8hVgBAwFg5SScpKUnr1q2TJJWUlMjl\ncnmGVSVp9uzZmjRpkkaOHNmiXA2zg1yUYhiGJXGtfHtW5QwAHYEVn5/x8Umtfm1JyWaf3zN37lxt\n375dhmEoKytLe/fuVWRkpL773e9q2LBhGjJkiOd7x44d2+yKbxTINqBAAghmVnx+Dhx4c6tfu3fv\nFj9m4hvnIAEAAdNBerIW4RwkAABeWNZBvvnmm1q9erVnf8+ePdqxY4dVhwMA2ICdOkjLCuS4ceM0\nbtw4SWeX/1m7dq1VhwIA2AUFsrEXX3xRc+fODcShAAAdmCnuB+mxa9cu9erVq9HFmgCASxNDrOfJ\ny8tTSkqK1YcBANiAnQqk5bNYt23b1ujCTADApcvKlXT8zdICWV5erm7duik0NNTKwwAA4HeWDrFW\nVlbK6XRaeQgAgI2YFt4P0t9Yaq4NWGoOQDCz4vOzf//rW/3aTz8t9mMmvrHUHAAgYDpIT9YiFEgA\nQOBQIAEAuJApCiQAABew0yQd7uYBAIAXdJAAgIBhkk4r2OmHdo4dcwaA9mSnz80OUyABAMGPAgkA\ngBcUSAAAvLDTLFYKJGzrmmuuUUlJiTp18t+v8V/+8hdFR0crNja20eO1tbWaNWuWSktL1alTJ504\ncUL/9m//pjFjxvjt2MAlgQ4SsKf8/HyNGTPmggL5yiuvKDw8XCtXrpQkHTlyRJMnT9aoUaPUrVu3\n9kgVgMUokLC9bdu26be//a2uuOIKHThwQJ06ddLSpUtVVVWl+++/XyNHjtS+ffskSc8//7xiYmIa\ndZ/5+fnasmWLfvjDH+qPf/yjdu3apaeeekojRozwHKO6ulonTpyQaZoyDEO9evXS22+/7Xl+/vz5\n+stf/qLa2loNGzZMTz75pCQpMzNTe/bskcvlUlRUlGJiYvTYY495Pf7cuXO1b98+zZkzRw0NDaqv\nr1dmZqYGDhyo++67TyNGjNCOHTt08OBB/fznP9edd96pqqoqPfXUUzp+/LgcDocyMzM1YMAAFRQU\n6NVXX5VpmnI6nZo1a5aioqIC+z8G8MJOK+mwUACCws6dOzVt2jTl5uYqJCRE//d//ydJ+vzzz3XX\nXXfp9ddf1/Dhw7Vs2bImY4yu5zaYAAAFjUlEQVQePVrXXXedpk+f3qg4StLEiRO1Z88e/eAHP9CM\nGTO0du1a1dXVSZLWrl2r8vJyvfrqq8rLy9OhQ4f03nvvqbCwUH/961+Vl5enRYsW6ZNPPvH5Pp54\n4gk9++yzWrFihZ555hk9/fTTnudOnjypl19+WdnZ2Vq6dKkkad68eRo1apRWrlypRx55RP/7v/+r\nI0eOaMmSJVq+fLlWrlyp4cOH66WXXrronylgBTvdMJkOEkEhLi5OPXr0kCT16dNHX331lSSpe/fu\nGjRokCQpMTFR//3f/92q+L1799bq1au1e/dubd26VcuWLdN//dd/6a233tK2bdu0c+dO3XfffZKk\n48eP6/Dhw2poaNCNN94oh8Mhh8Ohf/qnf2r2GFVVVfrss880Y8YMz2M1NTU6c+bspIbhw4d7cqmu\nrpYk7dq1S//yL//ieX748OEqKChQZWWl0tPTJUl1dXW68sorW/W+AX9jkg4QYA6Hw+vj5/+r89zw\n6D+qr6/3Gb+2tlZhYWFKSEhQQkKCHnjgAaWlpWnLli0KDQ3V+PHjPQXpnN/97neN9pu6f+i544eG\nhqpz585asWKF1+87fzLSufdlGIangJ4TGhqqhIQEukZ0SHa6zIMhVgS16upq7d27V9LZGarXXHON\nJCkiIkJHjhyRdPYc5jmGYXgtmJMmTdIf/vAHz/6JEyd09OhRxcbG6sYbb9S7776rhoYGSdKiRYt0\n8OBBXX311dqxY4fOnDmjuro6z7BvU8ePjIzUlVdeqT//+c+SpM8++0yLFi1q9v0NGTJEmzZtkiRt\n375dv/jFLzR48GDt2rVLlZWVks4OAW/YsKGlPzLAUgyxAh1ETEyM8vPzNXv2bJmmqfnz50uSJk+e\nrPT0dF111VW69tprPcUqKSlJWVlZysjI0K233uqJM2/ePGVnZys3N1ehoaH6+uuvNXnyZF133XW6\n9tprtXPnTqWmpsrhcGjgwIGKjY1V37599c477+iuu+5SdHS0BgwY4InX1PHnzJmjWbNm6be//a0a\nGho0ffr0Zt/f1KlT9dRTT+m9996TJM2cOVMxMTGaMWOGpkyZoi5duig8PFxz5szx688VaC07dZCG\naadsgYtw+PBhpaWl6YMPPmjvVCRJCxcuVENDgx577LH2TgVoNz17tv58uNt92I+Z+EYHCQAIGDv1\nZHSQAICA6eHs1erXVn15xI+Z+EYHCQAIGDstFECBBAAEjJ0GLSmQAICAoUACAOCFnVbSYaEAAAC8\noIMEAAQMQ6wAAHhBgQQAwAsKJAAA3lAgAQC4kCn7zGKlQAIAAsZOQ6xc5gEAgBd0kACAgLFTB0mB\nBAAEDAUSAAAvKJAAAHhhp7VYKZAAgIChgwQAwBsbFUgu8wAAwAs6SABAwJiytoPMyclRcXGxDMNQ\nRkaGEhISPM9t2bJF8+fPl8Ph0MiRI/Xwww83G4sOEgAQMKZ5ptWbL0VFRSorK1Nubq6ys7OVnZ3d\n6PlZs2Zp4cKFWrlypTZv3qwDBw40G48CCQAIGNM0W735UlhYqOTkZElSXFycqqurVVNTI0n6/PPP\ndfnll6tXr14KCQnRqFGjVFhY2Gw8CiQAIGCsLJBut1tRUVGefafTqcrKSklSZWWlnE6n1+eawjlI\nAEDABPIyj7Yeiw4SABAUXC6X3G63Z7+iokLR0dFenysvL5fL5Wo2HgUSABAUkpKStG7dOklSSUmJ\nXC6XIiIiJElXXnmlampqdPjwYTU0NOi9995TUlJSs/EM007LGgAA0Iy5c+dq+/btMgxDWVlZ2rt3\nryIjIzV69Gh9+OGHmjt3riTp1ltvVXp6erOxKJAAAHjBECsAAF5QIAEA8IICCQCAFxRIAAC8oEAC\nAOAFBRIAAC8okAAAeEGBBADAi/8P5AZI3mz+VgYAAAAASUVORK5CYII=\n", 717 | "text/plain": [ 718 | "
" 719 | ] 720 | }, 721 | "metadata": { 722 | "tags": [] 723 | } 724 | } 725 | ] 726 | } 727 | ] 728 | } -------------------------------------------------------------------------------- /data/human_vocab.json: -------------------------------------------------------------------------------- 1 | {"N": 0, "1": 1, "\u00e0": 2, "I": 3, "A": 4, "3": 5, "\u1eac": 6, "T": 7, "\u0103": 8, "\u0102": 9, "\u1ea2": 10, "n": 11, "t": 12, "\u1ea3": 13, "9": 14, "\u1ee9": 15, "S": 16, "\u1ee7": 17, "G": 18, " ": 19, "\u00e1": 20, "\u1ee8": 21, ",": 22, "h": 23, "/": 24, "i": 25, "u": 26, "8": 27, "\u00c0": 28, "M": 29, "\u1ee6": 30, "m": 31, "c": 32, "B": 33, "U": 34, "0": 35, "2": 36, "b": 37, "\u1ead": 38, "s": 39, "4": 40, "\u01af": 41, ".": 42, "H": 43, "Y": 44, "y": 45, "\u01b0": 46, "g": 47, "a": 48, "5": 49, "C": 50, "\u00c1": 51, "6": 52, "7": 53, "": 54, "": 55} -------------------------------------------------------------------------------- /data/machine_vocab.json: -------------------------------------------------------------------------------- 1 | {"4": 0, "1": 1, "7": 2, "0": 3, "2": 4, "3": 5, "-": 6, "8": 7, "5": 8, "6": 9, "9": 10, "": 11, "": 12} -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Date Generator 3 | This code creates data for our date translation model 4 | 5 | References: 6 | https://github.com/rasmusbergpalm/normalization/blob/master/babel_data.py 7 | https://github.com/joke2k/faker 8 | https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior 9 | 10 | Contact: 11 | zaf@datalogue.io (@zafarali) 12 | """ 13 | import random 14 | import json 15 | import os 16 | 17 | DATA_FOLDER = os.path.realpath(os.path.join(os.path.realpath(__file__), '..')) 18 | 19 | from faker import Faker 20 | import babel 21 | from babel.dates import format_date 22 | 23 | fake = Faker() 24 | fake.seed(230517) 25 | random.seed(230517) 26 | 27 | FORMATS = ['short', 28 | 'medium', 29 | 'long', 30 | 'full', 31 | 'd MMM YYY', 32 | 'd MMMM YYY', 33 | 'dd MMM YYY', 34 | 'd MMM, YYY', 35 | 'd MMMM, YYY', 36 | 'dd, MMM YYY', 37 | 'd MM YY', 38 | 'd MMMM YYY', 39 | 'MMMM d YYY', 40 | 'MMMM d, YYY', 41 | 'dd.MM.YY', 42 | ] 43 | 44 | # change this if you want it to work with only a single language 45 | LOCALES = ['vi_VN'] 46 | #LOCALES = babel.localedata.locale_identifiers() 47 | 48 | 49 | def create_date(): 50 | """ 51 | Creates some fake dates 52 | :returns: tuple containing 53 | 1. human formatted string 54 | 2. machine formatted string 55 | 3. date object. 56 | """ 57 | dt = fake.date_object() 58 | 59 | # wrapping this in a try catch because 60 | # the locale 'vo' and format 'full' will fail 61 | try: 62 | human = format_date(dt, 63 | format=random.choice(FORMATS), 64 | locale=random.choice(LOCALES)) 65 | 66 | case_change = random.randint(0,3) # 1/2 chance of case change 67 | if case_change == 1: 68 | human = human.upper() 69 | elif case_change == 2: 70 | human = human.lower() 71 | 72 | machine = dt.isoformat() 73 | except AttributeError as e: 74 | # print(e) 75 | return None, None, None 76 | 77 | return human, machine, dt 78 | 79 | 80 | def create_dataset(dataset_name, n_examples, vocabulary=False): 81 | """ 82 | Creates a csv dataset with n_examples and optional vocabulary 83 | :param dataset_name: name of the file to save as 84 | :n_examples: the number of examples to generate 85 | :vocabulary: if true, will also save the vocabulary 86 | """ 87 | human_vocab = set() 88 | machine_vocab = set() 89 | 90 | with open(dataset_name, 'w') as f: 91 | for i in range(n_examples): 92 | h, m, _ = create_date() 93 | if h is not None: 94 | f.write('"'+h + '","' + m + '"\n') 95 | human_vocab.update(tuple(h)) 96 | machine_vocab.update(tuple(m)) 97 | 98 | if vocabulary: 99 | int2human = dict(enumerate(human_vocab)) 100 | int2human.update({len(int2human): '', 101 | len(int2human)+1: ''}) 102 | int2machine = dict(enumerate(machine_vocab)) 103 | int2machine.update({len(int2machine):'', 104 | len(int2machine)+1:''}) 105 | 106 | human2int = {v: k for k, v in int2human.items()} 107 | machine2int = {v: k for k, v in int2machine.items()} 108 | 109 | with open(os.path.join(DATA_FOLDER, 'human_vocab.json'), 'w') as f: 110 | json.dump(human2int, f) 111 | with open(os.path.join(DATA_FOLDER, 'machine_vocab.json'), 'w') as f: 112 | json.dump(machine2int, f) 113 | 114 | if __name__ == '__main__': 115 | print('creating dataset') 116 | create_dataset(os.path.join(DATA_FOLDER, 'data.csv'), 25000, 117 | vocabulary=True) 118 | -------------------------------------------------------------------------------- /img/attn_ex_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/attention_tutorial/085bbee03433bf7a7bd3a341daeaa228ac15cb50/img/attn_ex_2.png -------------------------------------------------------------------------------- /img/attn_seq2seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/attention_tutorial/085bbee03433bf7a7bd3a341daeaa228ac15cb50/img/attn_seq2seq.png --------------------------------------------------------------------------------