├── requirements.txt ├── .github └── workflows │ └── auto_mirror.yaml ├── LICENSE ├── 5-2.BERT ├── layers.py ├── BERT.py ├── BERT_pytorch.ipynb └── BERT.ipynb ├── 5-1.Transformer ├── layers.py └── Transformer.py ├── .gitignore ├── README.md ├── 3-2.TextLSTM ├── TextLSTM_pytorch.ipynb └── TextLSTM.ipynb ├── 1-1.NNLM ├── NNLM_pytorch.ipynb └── NNLM.ipynb ├── 3-3.Bi-LSTM ├── Bi-LSTM_pytorch.ipynb └── Bi-LSTM.ipynb ├── 1-2.Word2Vec ├── Word2Vec_pytorch.ipynb └── Word2Vec.ipynb ├── 3-1.TextRNN ├── TextRNN_pytorch.ipynb └── TextRNN.ipynb ├── 2-1.TextCNN ├── TextCNN_pytorch.ipynb └── TextCNN.ipynb ├── 4-3.Bi-LSTM(Attention) ├── Bi-LSTM-Attention_pytorch.ipynb └── Bi-LSTM-Attention.ipynb ├── 4-1.Seq2Seq ├── Seq2Seq_pytorch.ipynb └── Seq2Seq.ipynb └── 4-2.Seq2Seq(Attention) ├── Seq2Seq-Attention_pytorch.ipynb └── Seq2Seq-Attention.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib -------------------------------------------------------------------------------- /.github/workflows/auto_mirror.yaml: -------------------------------------------------------------------------------- 1 | name: Mirroring 2 | 3 | on: [push, delete] 4 | 5 | jobs: 6 | sync_to_openi: 7 | runs-on: ubuntu-latest 8 | steps: # <-- must use actions/checkout before mirroring! 9 | - uses: actions/checkout@v2 10 | with: 11 | fetch-depth: 0 12 | - uses: pixta-dev/repository-mirroring-action@v1 13 | with: 14 | target_repo_url: 15 | git@git.openi.org.cn:lvyufeng/mindspore_nlp_tutorial.git 16 | ssh_private_key: # <-- use 'secrets' to pass credential information. 
17 | ${{ secrets.OPENI_SSH_PRIVATE_KEY }} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 nate.river 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /5-2.BERT/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mindspore 3 | import mindspore.nn as nn 4 | from mindspore import Parameter, Tensor 5 | from mindspore.common.initializer import initializer, HeUniform, Uniform, Normal, _calculate_fan_in_and_fan_out 6 | 7 | class Conv1d(nn.Conv1d): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1, group=1, has_bias=True): 9 | super().__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group, has_bias, weight_init='normal', bias_init='zeros') 10 | self.reset_parameters() 11 | 12 | def reset_parameters(self): 13 | self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape)) 14 | #self.weight = Parameter(initializer(HeUniform(math.sqrt(5)), self.weight.shape), name='weight') 15 | if self.has_bias: 16 | fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) 17 | bound = 1 / math.sqrt(fan_in) 18 | self.bias.set_data(initializer(Uniform(bound), [self.out_channels])) 19 | 20 | class Dense(nn.Dense): 21 | def __init__(self, in_channels, out_channels, has_bias=True, activation=None): 22 | super().__init__(in_channels, out_channels, weight_init='normal', bias_init='zeros', has_bias=has_bias, activation=activation) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape)) 27 | if self.has_bias: 28 | fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) 29 | bound = 1 / math.sqrt(fan_in) 30 | self.bias.set_data(initializer(Uniform(bound), [self.out_channels])) 31 | 32 | class Embedding(nn.Embedding): 33 | def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mindspore.float32, padding_idx=None): 34 | if embedding_table == 
'normal': 35 | embedding_table = Normal(1.0) 36 | super().__init__(vocab_size, embedding_size, use_one_hot, embedding_table, dtype, padding_idx) 37 | @classmethod 38 | def from_pretrained_embedding(cls, embeddings:Tensor, freeze=True, padding_idx=None): 39 | rows, cols = embeddings.shape 40 | embedding = cls(rows, cols, embedding_table=embeddings, padding_idx=padding_idx) 41 | embedding.embedding_table.requires_grad = not freeze 42 | return embedding -------------------------------------------------------------------------------- /5-1.Transformer/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mindspore 3 | import mindspore.nn as nn 4 | from mindspore import Parameter, Tensor 5 | from mindspore.common.initializer import initializer, HeUniform, Uniform, Normal, _calculate_fan_in_and_fan_out 6 | 7 | class Conv1d(nn.Conv1d): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1, group=1, has_bias=True): 9 | super().__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group, has_bias, weight_init='normal', bias_init='zeros') 10 | self.reset_parameters() 11 | 12 | def reset_parameters(self): 13 | self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape)) 14 | #self.weight = Parameter(initializer(HeUniform(math.sqrt(5)), self.weight.shape), name='weight') 15 | if self.has_bias: 16 | fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) 17 | bound = 1 / math.sqrt(fan_in) 18 | self.bias.set_data(initializer(Uniform(bound), [self.out_channels])) 19 | 20 | class Dense(nn.Dense): 21 | def __init__(self, in_channels, out_channels, has_bias=True, activation=None): 22 | super().__init__(in_channels, out_channels, weight_init='normal', bias_init='zeros', has_bias=has_bias, activation=activation) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | 
self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape)) 27 | if self.has_bias: 28 | fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) 29 | bound = 1 / math.sqrt(fan_in) 30 | self.bias.set_data(initializer(Uniform(bound), [self.out_channels])) 31 | 32 | class Embedding(nn.Embedding): 33 | def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mindspore.float32, padding_idx=None): 34 | if embedding_table == 'normal': 35 | embedding_table = Normal(1.0) 36 | super().__init__(vocab_size, embedding_size, use_one_hot, embedding_table, dtype, padding_idx) 37 | @classmethod 38 | def from_pretrained_embedding(cls, embeddings:Tensor, freeze=True, padding_idx=None): 39 | rows, cols = embeddings.shape 40 | embedding = cls(rows, cols, embedding_table=embeddings, padding_idx=padding_idx) 41 | embedding.embedding_table.requires_grad = not freeze 42 | return embedding -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | rank_0/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mindspore-nlp-tutorial 2 | 3 |

4 5 `mindspore-nlp-tutorial` is a tutorial for those who are studying NLP (Natural Language Processing) using **MindSpore**. This repository is migrated from [nlp-tutorial](https://github.com/graykode/nlp-tutorial). Most of the NLP models were migrated from the PyTorch version, and each is implemented in fewer than **100 lines** of code (excluding comments and blank lines). 6 | 7 | - **Notice**: All models are tested on CPU (Linux and macOS), GPU, and Ascend. 8 | 9 | ## Curriculum - (Example Purpose) 10 | 11 | #### 1. Basic Embedding Model 12 | 13 | - 1-1. [NNLM(Neural Network Language Model)](1-1.NNLM) - **Predict Next Word** 14 | - Paper - [A Neural Probabilistic Language Model(2003)](http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf) 15 | - 1-2. [Word2Vec(Skip-gram)](1-2.Word2Vec) - **Embed Words and Show Graph** 16 | - Paper - [Distributed Representations of Words and Phrases 17 | and their Compositionality(2013)](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) 18 | 21 | 22 | 23 | 24 | #### 2. CNN(Convolutional Neural Network) 25 | 26 | - 2-1. [TextCNN](2-1.TextCNN) - **Binary Sentiment Classification** 27 | - Paper - [Convolutional Neural Networks for Sentence Classification(2014)](http://www.aclweb.org/anthology/D14-1181) 28 | 29 | 30 | 31 | #### 3. RNN(Recurrent Neural Network) 32 | 33 | - 3-1. [TextRNN](3-1.TextRNN) - **Predict Next Step** 34 | - Paper - [Finding Structure in Time(1990)](http://psych.colorado.edu/~kimlab/Elman1990.pdf) 35 | - 3-2. [TextLSTM](3-2.TextLSTM) - **Autocomplete** 36 | - Paper - [LONG SHORT-TERM MEMORY(1997)](https://www.bioinf.jku.at/publications/older/2604.pdf) 37 | - 3-3. [Bi-LSTM](3-3.Bi-LSTM) - **Predict Next Word in Long Sentence** 38 | 39 | 40 | #### 4. Attention Mechanism 41 | 42 | - 4-1.
[Seq2Seq](4-1.Seq2Seq) - **Change Word** 43 | - Paper - [Learning Phrase Representations using RNN Encoder–Decoder 44 | for Statistical Machine Translation(2014)](https://arxiv.org/pdf/1406.1078.pdf) 45 | - 4-2. [Seq2Seq with Attention](4-2.Seq2Seq(Attention)) - **Translate** 46 | - Paper - [Neural Machine Translation by Jointly Learning to Align and Translate(2014)](https://arxiv.org/abs/1409.0473) 47 | - 4-3. [Bi-LSTM with Attention](4-3.Bi-LSTM(Attention)) - **Binary Sentiment Classification** 48 | 49 | #### 5. Models based on Transformer 50 | 51 | - 5-1. [The Transformer](5-1.Transformer) - **Translate** 52 | - Paper - [Attention Is All You Need(2017)](https://arxiv.org/abs/1706.03762) 53 | 54 | - 5-2. [BERT](5-2.BERT) - **Next Sentence Classification & Masked Token Prediction** 55 | - Paper - [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding(2018)](https://arxiv.org/abs/1810.04805) 56 | 57 | ## Dependencies 58 | 59 | - Python >= 3.7.5 60 | - MindSpore 1.9.0 61 | - PyTorch 1.7.1 (for comparison) 62 | 63 | ## Author 64 | 65 | - Yufeng Lyu 66 | - Author Email: lvyufeng2007@hotmail.com 67 | - Acknowledgements to [graykode](https://github.com/graykode), who open-sourced the PyTorch and TensorFlow versions.
68 | -------------------------------------------------------------------------------- /3-2.TextLSTM/TextLSTM_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "collect-government", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# %%\n", 11 | "# code by Tae Hwan Jung @graykode\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "\n", 17 | "def make_batch():\n", 18 | " input_batch, target_batch = [], []\n", 19 | "\n", 20 | " for seq in seq_data:\n", 21 | " input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is input\n", 22 | " target = word_dict[seq[-1]] # 'e' is target\n", 23 | " input_batch.append(np.eye(n_class)[input])\n", 24 | " target_batch.append(target)\n", 25 | "\n", 26 | " return input_batch, target_batch\n", 27 | "\n", 28 | "class TextLSTM(nn.Module):\n", 29 | " def __init__(self):\n", 30 | " super(TextLSTM, self).__init__()\n", 31 | "\n", 32 | " self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)\n", 33 | " self.W = nn.Linear(n_hidden, n_class, bias=False)\n", 34 | " self.b = nn.Parameter(torch.ones([n_class]))\n", 35 | "\n", 36 | " def forward(self, X):\n", 37 | " input = X.transpose(0, 1) # X : [n_step, batch_size, n_class]\n", 38 | "\n", 39 | " hidden_state = torch.zeros(1, len(X), n_hidden) # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 40 | " cell_state = torch.zeros(1, len(X), n_hidden) # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 41 | "\n", 42 | " outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))\n", 43 | " outputs = outputs[-1] # [batch_size, n_hidden]\n", 44 | " model = self.W(outputs) + self.b # model : [batch_size, n_class]\n", 45 | " return model\n", 46 | "\n", 47 | "if __name__ == '__main__':\n", 48 | " n_step = 3 # number of cells(= number of 
Step)\n", 49 | " n_hidden = 128 # number of hidden units in one cell\n", 50 | "\n", 51 | " char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']\n", 52 | " word_dict = {n: i for i, n in enumerate(char_arr)}\n", 53 | " number_dict = {i: w for i, w in enumerate(char_arr)}\n", 54 | " n_class = len(word_dict) # number of class(=number of vocab)\n", 55 | "\n", 56 | " seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']\n", 57 | "\n", 58 | " model = TextLSTM()\n", 59 | "\n", 60 | " criterion = nn.CrossEntropyLoss()\n", 61 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 62 | "\n", 63 | " input_batch, target_batch = make_batch()\n", 64 | " input_batch = torch.FloatTensor(input_batch)\n", 65 | " target_batch = torch.LongTensor(target_batch)\n", 66 | "\n", 67 | " # Training\n", 68 | " for epoch in range(1000):\n", 69 | " optimizer.zero_grad()\n", 70 | "\n", 71 | " output = model(input_batch)\n", 72 | " loss = criterion(output, target_batch)\n", 73 | " if (epoch + 1) % 100 == 0:\n", 74 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 75 | "\n", 76 | " loss.backward()\n", 77 | " optimizer.step()\n", 78 | "\n", 79 | " inputs = [sen[:3] for sen in seq_data]\n", 80 | "\n", 81 | " predict = model(input_batch).data.max(1, keepdim=True)[1]\n", 82 | " print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "imposed-facility", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | 
"version": "3.7.5" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 5 115 | } 116 | -------------------------------------------------------------------------------- /1-1.NNLM/NNLM_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "hired-pride", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# code by Tae Hwan Jung @graykode\n", 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import torch.optim as optim\n", 14 | "\n", 15 | "def make_batch():\n", 16 | " input_batch = []\n", 17 | " target_batch = []\n", 18 | "\n", 19 | " for sen in sentences:\n", 20 | " word = sen.split() # space tokenizer\n", 21 | " input = [word_dict[n] for n in word[:-1]] # create (1~n-1) as input\n", 22 | " target = word_dict[word[-1]] # create (n) as target, We usually call this 'casual language model'\n", 23 | "\n", 24 | " input_batch.append(input)\n", 25 | " target_batch.append(target)\n", 26 | "\n", 27 | " return input_batch, target_batch\n", 28 | "\n", 29 | "# Model\n", 30 | "class NNLM(nn.Module):\n", 31 | " def __init__(self):\n", 32 | " super(NNLM, self).__init__()\n", 33 | " self.C = nn.Embedding(n_class, m)\n", 34 | " self.H = nn.Linear(n_step * m, n_hidden, bias=False)\n", 35 | " self.d = nn.Parameter(torch.ones(n_hidden))\n", 36 | " self.U = nn.Linear(n_hidden, n_class, bias=False)\n", 37 | " self.W = nn.Linear(n_step * m, n_class, bias=False)\n", 38 | " self.b = nn.Parameter(torch.ones(n_class))\n", 39 | "\n", 40 | " def forward(self, X):\n", 41 | " X = self.C(X) # X : [batch_size, n_step, m]\n", 42 | " X = X.view(-1, n_step * m) # [batch_size, n_step * m]\n", 43 | " tanh = torch.tanh(self.d + self.H(X)) # [batch_size, n_hidden]\n", 44 | " output = self.b + self.W(X) + self.U(tanh) # [batch_size, n_class]\n", 45 | " return output\n", 46 | "\n", 47 | "if __name__ == '__main__':\n", 48 | " n_step = 2 # number of 
steps, n-1 in paper\n", 49 | " n_hidden = 2 # number of hidden size, h in paper\n", 50 | " m = 2 # embedding size, m in paper\n", 51 | "\n", 52 | " sentences = [\"i like dog\", \"i love coffee\", \"i hate milk\"]\n", 53 | "\n", 54 | " word_list = \" \".join(sentences).split()\n", 55 | " word_list = list(set(word_list))\n", 56 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 57 | " number_dict = {i: w for i, w in enumerate(word_list)}\n", 58 | " n_class = len(word_dict) # number of Vocabulary\n", 59 | "\n", 60 | " model = NNLM()\n", 61 | "\n", 62 | " input_batch, target_batch = make_batch()\n", 63 | " input_batch = torch.LongTensor(input_batch)\n", 64 | " target_batch = torch.LongTensor(target_batch)\n", 65 | "\n", 66 | " criterion = nn.CrossEntropyLoss()\n", 67 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 68 | "\n", 69 | " # Training\n", 70 | " for epoch in range(5000):\n", 71 | " optimizer.zero_grad()\n", 72 | " output = model(input_batch)\n", 73 | "\n", 74 | " # output : [batch_size, n_class], target_batch : [batch_size]\n", 75 | " loss = criterion(output, target_batch)\n", 76 | " if (epoch + 1) % 1000 == 0:\n", 77 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 78 | "\n", 79 | " loss.backward()\n", 80 | " optimizer.step()\n", 81 | "\n", 82 | " # Predict\n", 83 | " predict = model(input_batch)\n", 84 | " predict = predict.data.max(1, keepdim=True)[1]\n", 85 | " # Test\n", 86 | " print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "collected-tracy", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | 
"file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.7.5" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 5 119 | } 120 | -------------------------------------------------------------------------------- /3-3.Bi-LSTM/Bi-LSTM_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "qualified-shaft", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# %%\n", 11 | "# code by Tae Hwan Jung @graykode\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "\n", 17 | "def make_batch():\n", 18 | " input_batch = []\n", 19 | " target_batch = []\n", 20 | "\n", 21 | " words = sentence.split()\n", 22 | " for i, word in enumerate(words[:-1]):\n", 23 | " input = [word_dict[n] for n in words[:(i + 1)]]\n", 24 | " input = input + [0] * (max_len - len(input))\n", 25 | " target = word_dict[words[i + 1]]\n", 26 | " input_batch.append(np.eye(n_class)[input])\n", 27 | " target_batch.append(target)\n", 28 | "\n", 29 | " return input_batch, target_batch\n", 30 | "\n", 31 | "class BiLSTM(nn.Module):\n", 32 | " def __init__(self):\n", 33 | " super(BiLSTM, self).__init__()\n", 34 | "\n", 35 | " self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)\n", 36 | " self.W = nn.Linear(n_hidden * 2, n_class, bias=False)\n", 37 | " self.b = nn.Parameter(torch.ones([n_class]))\n", 38 | "\n", 39 | " def forward(self, X):\n", 40 | " input = X.transpose(0, 1) # input : [n_step, batch_size, n_class]\n", 41 | "\n", 42 | " hidden_state = torch.zeros(1*2, len(X), n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n", 43 | " cell_state = torch.zeros(1*2, len(X), n_hidden) # [num_layers(=1) * num_directions(=2), 
batch_size, n_hidden]\n", 44 | "\n", 45 | " outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))\n", 46 | " outputs = outputs[-1] # [batch_size, n_hidden]\n", 47 | " model = self.W(outputs) + self.b # model : [batch_size, n_class]\n", 48 | " return model\n", 49 | "\n", 50 | "if __name__ == '__main__':\n", 51 | " n_hidden = 5 # number of hidden units in one cell\n", 52 | "\n", 53 | " sentence = (\n", 54 | " 'Lorem ipsum dolor sit amet consectetur adipisicing elit '\n", 55 | " 'sed do eiusmod tempor incididunt ut labore et dolore magna '\n", 56 | " 'aliqua Ut enim ad minim veniam quis nostrud exercitation'\n", 57 | " )\n", 58 | "\n", 59 | " word_dict = {w: i for i, w in enumerate(list(set(sentence.split())))}\n", 60 | " number_dict = {i: w for i, w in enumerate(list(set(sentence.split())))}\n", 61 | " n_class = len(word_dict)\n", 62 | " max_len = len(sentence.split())\n", 63 | "\n", 64 | " model = BiLSTM()\n", 65 | "\n", 66 | " criterion = nn.CrossEntropyLoss()\n", 67 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 68 | "\n", 69 | " input_batch, target_batch = make_batch()\n", 70 | " input_batch = torch.FloatTensor(input_batch)\n", 71 | " target_batch = torch.LongTensor(target_batch)\n", 72 | " \n", 73 | " print(input_batch.shape, target_batch.shape)\n", 74 | " # Training\n", 75 | " for epoch in range(10000):\n", 76 | " optimizer.zero_grad()\n", 77 | " output = model(input_batch)\n", 78 | " loss = criterion(output, target_batch)\n", 79 | " if (epoch + 1) % 1000 == 0:\n", 80 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 81 | "\n", 82 | " loss.backward()\n", 83 | " optimizer.step()\n", 84 | "\n", 85 | " predict = model(input_batch).data.max(1, keepdim=True)[1]\n", 86 | " print(sentence)\n", 87 | " print([number_dict[n.item()] for n in predict.squeeze()])\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "advance-dressing", 94 | "metadata": {}, 95 | "outputs": [], 96 | 
"source": [] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.7.5" 116 | } 117 | }, 118 | "nbformat": 4, 119 | "nbformat_minor": 5 120 | } 121 | -------------------------------------------------------------------------------- /1-2.Word2Vec/Word2Vec_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# %%\n", 10 | "# code by Tae Hwan Jung @graykode\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.optim as optim\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "\n", 17 | "def random_batch():\n", 18 | " random_inputs = []\n", 19 | " random_labels = []\n", 20 | " random_index = np.random.choice(range(len(skip_grams)), batch_size, replace=False)\n", 21 | "\n", 22 | " for i in random_index:\n", 23 | " random_inputs.append(np.eye(voc_size)[skip_grams[i][0]]) # target\n", 24 | " random_labels.append(skip_grams[i][1]) # context word\n", 25 | "\n", 26 | " return random_inputs, random_labels\n", 27 | "\n", 28 | "# Model\n", 29 | "class Word2Vec(nn.Module):\n", 30 | " def __init__(self):\n", 31 | " super(Word2Vec, self).__init__()\n", 32 | " # W and WT is not Traspose relationship\n", 33 | " self.W = nn.Linear(voc_size, embedding_size, bias=False) # voc_size > embedding_size Weight\n", 34 | " self.WT = nn.Linear(embedding_size, voc_size, bias=False) # embedding_size > voc_size Weight\n", 35 | "\n", 36 | " def forward(self, X):\n", 37 | " # X : 
[batch_size, voc_size]\n", 38 | " hidden_layer = self.W(X) # hidden_layer : [batch_size, embedding_size]\n", 39 | " output_layer = self.WT(hidden_layer) # output_layer : [batch_size, voc_size]\n", 40 | " return output_layer\n", 41 | "\n", 42 | "if __name__ == '__main__':\n", 43 | " batch_size = 2 # mini-batch size\n", 44 | " embedding_size = 2 # embedding size\n", 45 | "\n", 46 | " sentences = [\"apple banana fruit\", \"banana orange fruit\", \"orange banana fruit\",\n", 47 | " \"dog cat animal\", \"cat monkey animal\", \"monkey dog animal\"]\n", 48 | "\n", 49 | " word_sequence = \" \".join(sentences).split()\n", 50 | " word_list = \" \".join(sentences).split()\n", 51 | " word_list = list(set(word_list))\n", 52 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 53 | " voc_size = len(word_list)\n", 54 | "\n", 55 | " # Make skip gram of one size window\n", 56 | " skip_grams = []\n", 57 | " for i in range(1, len(word_sequence) - 1):\n", 58 | " target = word_dict[word_sequence[i]]\n", 59 | " context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]\n", 60 | " for w in context:\n", 61 | " skip_grams.append([target, w])\n", 62 | "\n", 63 | " model = Word2Vec()\n", 64 | "\n", 65 | " criterion = nn.CrossEntropyLoss()\n", 66 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 67 | "\n", 68 | " # Training\n", 69 | " for epoch in range(5000):\n", 70 | " input_batch, target_batch = random_batch()\n", 71 | " input_batch = torch.Tensor(input_batch)\n", 72 | " target_batch = torch.LongTensor(target_batch)\n", 73 | "\n", 74 | " optimizer.zero_grad()\n", 75 | " output = model(input_batch)\n", 76 | "\n", 77 | " # output : [batch_size, voc_size], target_batch : [batch_size] (LongTensor, not one-hot)\n", 78 | " loss = criterion(output, target_batch)\n", 79 | " if (epoch + 1) % 1000 == 0:\n", 80 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 81 | "\n", 82 | " loss.backward()\n", 83 | " optimizer.step()\n", 84 | "\n", 
85 | " for i, label in enumerate(word_list):\n", 86 | " W, WT = model.parameters()\n", 87 | " x, y = W[0][i].item(), W[1][i].item()\n", 88 | " plt.scatter(x, y)\n", 89 | " plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')\n", 90 | " plt.show()" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.7.5" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 4 115 | } 116 | -------------------------------------------------------------------------------- /3-1.TextRNN/TextRNN_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "heated-fighter", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# %%\n", 11 | "# code by Tae Hwan Jung @graykode\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "\n", 17 | "def make_batch():\n", 18 | " input_batch = []\n", 19 | " target_batch = []\n", 20 | "\n", 21 | " for sen in sentences:\n", 22 | " word = sen.split() # space tokenizer\n", 23 | " input = [word_dict[n] for n in word[:-1]] # create (1~n-1) as input\n", 24 | " target = word_dict[word[-1]] # create (n) as target, We usually call this 'casual language model'\n", 25 | "\n", 26 | " input_batch.append(np.eye(n_class)[input])\n", 27 | " target_batch.append(target)\n", 28 | "\n", 29 | " return input_batch, target_batch\n", 30 | "\n", 31 | "class TextRNN(nn.Module):\n", 32 | " def __init__(self):\n", 33 | " super(TextRNN, 
self).__init__()\n", 34 | " self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n", 35 | " self.W = nn.Linear(n_hidden, n_class, bias=False)\n", 36 | " self.b = nn.Parameter(torch.ones([n_class]))\n", 37 | "\n", 38 | " def forward(self, hidden, X):\n", 39 | " X = X.transpose(0, 1) # X : [n_step, batch_size, n_class]\n", 40 | " outputs, hidden = self.rnn(X, hidden)\n", 41 | " # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]\n", 42 | " # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 43 | " outputs = outputs[-1] # [batch_size, num_directions(=1) * n_hidden]\n", 44 | " model = self.W(outputs) + self.b # model : [batch_size, n_class]\n", 45 | " return model\n", 46 | "\n", 47 | "if __name__ == '__main__':\n", 48 | " n_step = 2 # number of cells(= number of Step)\n", 49 | " n_hidden = 5 # number of hidden units in one cell\n", 50 | "\n", 51 | " sentences = [\"i like dog\", \"i love coffee\", \"i hate milk\"]\n", 52 | "\n", 53 | " word_list = \" \".join(sentences).split()\n", 54 | " word_list = list(set(word_list))\n", 55 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 56 | " number_dict = {i: w for i, w in enumerate(word_list)}\n", 57 | " n_class = len(word_dict)\n", 58 | " batch_size = len(sentences)\n", 59 | "\n", 60 | " model = TextRNN()\n", 61 | "\n", 62 | " criterion = nn.CrossEntropyLoss()\n", 63 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 64 | "\n", 65 | " input_batch, target_batch = make_batch()\n", 66 | " input_batch = torch.FloatTensor(input_batch)\n", 67 | " target_batch = torch.LongTensor(target_batch)\n", 68 | "\n", 69 | " # Training\n", 70 | " for epoch in range(5000):\n", 71 | " optimizer.zero_grad()\n", 72 | "\n", 73 | " # hidden : [num_layers * num_directions, batch, hidden_size]\n", 74 | " hidden = torch.zeros(1, batch_size, n_hidden)\n", 75 | " # input_batch : [batch_size, n_step, n_class]\n", 76 | " output = model(hidden, input_batch)\n", 77 | "\n", 78 | " # output : 
[batch_size, n_class], target_batch : [batch_size] (LongTensor, not one-hot)\n", 79 | " loss = criterion(output, target_batch)\n", 80 | " if (epoch + 1) % 1000 == 0:\n", 81 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 82 | "\n", 83 | " loss.backward()\n", 84 | " optimizer.step()\n", 85 | "\n", 86 | " input = [sen.split()[:2] for sen in sentences]\n", 87 | "\n", 88 | " # Predict\n", 89 | " hidden = torch.zeros(1, batch_size, n_hidden)\n", 90 | " predict = model(hidden, input_batch).data.max(1, keepdim=True)[1]\n", 91 | " print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "informational-channel", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "Python 3", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.7.5" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 5 124 | } 125 | -------------------------------------------------------------------------------- /2-1.TextCNN/TextCNN_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "healthy-stationery", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# %%\n", 11 | "# code by Tae Hwan Jung @graykode\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "import torch.nn.functional as F\n", 17 | "\n", 18 | "class 
TextCNN(nn.Module):\n", 19 | " def __init__(self):\n", 20 | " super(TextCNN, self).__init__()\n", 21 | " self.num_filters_total = num_filters * len(filter_sizes)\n", 22 | " self.W = nn.Embedding(vocab_size, embedding_size)\n", 23 | " self.Weight = nn.Linear(self.num_filters_total, num_classes, bias=False)\n", 24 | " self.Bias = nn.Parameter(torch.ones([num_classes]))\n", 25 | " self.filter_list = nn.ModuleList([nn.Conv2d(1, num_filters, (size, embedding_size)) for size in filter_sizes])\n", 26 | "\n", 27 | " def forward(self, X):\n", 28 | " embedded_chars = self.W(X) # [batch_size, sequence_length, embedding_size]\n", 29 | " embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]\n", 30 | "\n", 31 | " pooled_outputs = []\n", 32 | " for i, conv in enumerate(self.filter_list):\n", 33 | " # conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]\n", 34 | " h = F.relu(conv(embedded_chars))\n", 35 | " # mp : ((filter_height, filter_width))\n", 36 | " mp = nn.MaxPool2d((sequence_length - filter_sizes[i] + 1, 1))\n", 37 | " # pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]\n", 38 | " pooled = mp(h).permute(0, 3, 2, 1)\n", 39 | " pooled_outputs.append(pooled)\n", 40 | "\n", 41 | " h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]\n", 42 | " h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)]\n", 43 | " model = self.Weight(h_pool_flat) + self.Bias # [batch_size, num_classes]\n", 44 | " return model\n", 45 | "\n", 46 | "if __name__ == '__main__':\n", 47 | " embedding_size = 2 # embedding size\n", 48 | " sequence_length = 3 # sequence length\n", 49 | " num_classes = 2 # number of classes\n", 50 | " filter_sizes = [2, 2, 2] # n-gram windows\n", 51 | " num_filters = 3 # number of
filters\n", 52 | "\n", 53 | " # 3 words sentences (=sequence_length is 3)\n", 54 | " sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n", 55 | " labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.\n", 56 | "\n", 57 | " word_list = \" \".join(sentences).split()\n", 58 | " word_list = list(set(word_list))\n", 59 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 60 | " vocab_size = len(word_dict)\n", 61 | "\n", 62 | " model = TextCNN()\n", 63 | "\n", 64 | " criterion = nn.CrossEntropyLoss()\n", 65 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 66 | "\n", 67 | " inputs = torch.LongTensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences])\n", 68 | " targets = torch.LongTensor([out for out in labels]) # To using Torch Softmax Loss function\n", 69 | "\n", 70 | " # Training\n", 71 | " for epoch in range(5000):\n", 72 | " optimizer.zero_grad()\n", 73 | " output = model(inputs)\n", 74 | "\n", 75 | " # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)\n", 76 | " loss = criterion(output, targets)\n", 77 | " if (epoch + 1) % 1000 == 0:\n", 78 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 79 | "\n", 80 | " loss.backward()\n", 81 | " optimizer.step()\n", 82 | "\n", 83 | " # Test\n", 84 | " test_text = 'sorry hate you'\n", 85 | " tests = [np.asarray([word_dict[n] for n in test_text.split()])]\n", 86 | " test_batch = torch.LongTensor(tests)\n", 87 | "\n", 88 | " # Predict\n", 89 | " predict = model(test_batch).data.max(1, keepdim=True)[1]\n", 90 | " if predict[0][0] == 0:\n", 91 | " print(test_text,\"is Bad Mean...\")\n", 92 | " else:\n", 93 | " print(test_text,\"is Good Mean!!\")" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 3", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 
105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.7.5" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 5 118 | } 119 | -------------------------------------------------------------------------------- /4-3.Bi-LSTM(Attention)/Bi-LSTM-Attention_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "metadata": {}, 6 | "source": [ 7 | "# code by Tae Hwan Jung(Jeff Jung) @graykode\n", 8 | "# Reference : https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM_Attn.py\n", 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.optim as optim\n", 13 | "import torch.nn.functional as F\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "\n", 16 | "class BiLSTM_Attention(nn.Module):\n", 17 | " def __init__(self):\n", 18 | " super(BiLSTM_Attention, self).__init__()\n", 19 | "\n", 20 | " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", 21 | " self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)\n", 22 | " self.out = nn.Linear(n_hidden * 2, num_classes)\n", 23 | "\n", 24 | " # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix\n", 25 | " def attention_net(self, lstm_output, final_state):\n", 26 | " hidden = final_state.view(-1, n_hidden * 2, 1) # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]\n", 27 | " attn_weights = torch.bmm(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]\n", 28 | " soft_attn_weights = F.softmax(attn_weights, 1)\n", 29 | " # [batch_size, n_hidden * num_directions(=2), n_step] * [batch_size, n_step, 1] = [batch_size, n_hidden * num_directions(=2), 1]\n", 30 | " context = torch.bmm(lstm_output.transpose(1, 
2), soft_attn_weights.unsqueeze(2)).squeeze(2)\n", 31 | " return context, soft_attn_weights.data.numpy() # context : [batch_size, n_hidden * num_directions(=2)]\n", 32 | "\n", 33 | " def forward(self, X):\n", 34 | " input = self.embedding(X) # input : [batch_size, len_seq, embedding_dim]\n", 35 | " input = input.permute(1, 0, 2) # input : [len_seq, batch_size, embedding_dim]\n", 36 | "\n", 37 | " hidden_state = torch.zeros(1*2, len(X), n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n", 38 | " cell_state = torch.zeros(1*2, len(X), n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n", 39 | "\n", 40 | " # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n", 41 | " output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))\n", 42 | " output = output.permute(1, 0, 2) # output : [batch_size, len_seq, n_hidden * num_directions(=2)]\n", 43 | " attn_output, attention = self.attention_net(output, final_hidden_state)\n", 44 | " return self.out(attn_output), attention # model : [batch_size, num_classes], attention : [batch_size, n_step]\n", 45 | "\n", 46 | "if __name__ == '__main__':\n", 47 | " embedding_dim = 2 # embedding size\n", 48 | " n_hidden = 5 # number of hidden units in one cell\n", 49 | " num_classes = 2 # 0 or 1\n", 50 | "\n", 51 | " # 3 words sentences (=sequence_length is 3)\n", 52 | " sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n", 53 | " labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.\n", 54 | "\n", 55 | " word_list = \" \".join(sentences).split()\n", 56 | " word_list = list(set(word_list))\n", 57 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 58 | " vocab_size = len(word_dict)\n", 59 | "\n", 60 | " model = BiLSTM_Attention()\n", 61 | "\n", 62 | " criterion = nn.CrossEntropyLoss()\n", 63 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 64 |
"\n", 65 | " inputs = torch.LongTensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences])\n", 66 | " targets = torch.LongTensor([out for out in labels]) # To using Torch Softmax Loss function\n", 67 | "\n", 68 | " # Training\n", 69 | " for epoch in range(5000):\n", 70 | " optimizer.zero_grad()\n", 71 | " output, attention = model(inputs)\n", 72 | " loss = criterion(output, targets)\n", 73 | " if (epoch + 1) % 1000 == 0:\n", 74 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 75 | "\n", 76 | " loss.backward()\n", 77 | " optimizer.step()\n", 78 | "\n", 79 | " # Test\n", 80 | " test_text = 'sorry hate you'\n", 81 | " tests = [np.asarray([word_dict[n] for n in test_text.split()])]\n", 82 | " test_batch = torch.LongTensor(tests)\n", 83 | "\n", 84 | " # Predict\n", 85 | " predict, _ = model(test_batch)\n", 86 | " predict = predict.data.max(1, keepdim=True)[1]\n", 87 | " if predict[0][0] == 0:\n", 88 | " print(test_text,\"is Bad Mean...\")\n", 89 | " else:\n", 90 | " print(test_text,\"is Good Mean!!\")\n", 91 | "\n", 92 | " fig = plt.figure(figsize=(6, 3)) # [batch_size, n_step]\n", 93 | " ax = fig.add_subplot(1, 1, 1)\n", 94 | " ax.matshow(attention, cmap='viridis')\n", 95 | " ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)\n", 96 | " ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'], fontdict={'fontsize': 14})\n", 97 | " plt.show()" 98 | ], 99 | "outputs": [], 100 | "execution_count": null 101 | } 102 | ], 103 | "metadata": { 104 | "anaconda-cloud": {}, 105 | "kernelspec": { 106 | "display_name": "Python 3", 107 | "language": "python", 108 | "name": "python3" 109 | }, 110 | "language_info": { 111 | "codemirror_mode": { 112 | "name": "ipython", 113 | "version": 3 114 | }, 115 | "file_extension": ".py", 116 | "mimetype": "text/x-python", 117 | "name": "python", 118 | "nbconvert_exporter": "python", 119 | "pygments_lexer": 
"ipython3", 120 | "version": "3.6.1" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 4 125 | } -------------------------------------------------------------------------------- /4-1.Seq2Seq/Seq2Seq_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "tested-performance", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# code by Tae Hwan Jung @graykode\n", 11 | "import argparse\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "\n", 16 | "# S: Symbol that shows starting of decoding input\n", 17 | "# E: Symbol that shows starting of decoding output\n", 18 | "# P: Symbol that will fill in blank sequence if current batch data size is short than time steps\n", 19 | "\n", 20 | "def make_batch(seq_data, num_dic, n_step):\n", 21 | " input_batch, output_batch, target_batch = [], [], []\n", 22 | "\n", 23 | " for seq in seq_data:\n", 24 | " for i in range(2):\n", 25 | " seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))\n", 26 | "\n", 27 | " input = [num_dic[n] for n in seq[0]]\n", 28 | " output = [num_dic[n] for n in ('S' + seq[1])]\n", 29 | " target = [num_dic[n] for n in (seq[1] + 'E')]\n", 30 | "\n", 31 | " input_batch.append(np.eye(n_class)[input])\n", 32 | " output_batch.append(np.eye(n_class)[output])\n", 33 | " target_batch.append(target) # not one-hot\n", 34 | "\n", 35 | " # make tensor\n", 36 | " return torch.FloatTensor(input_batch), torch.FloatTensor(output_batch), torch.LongTensor(target_batch)\n", 37 | "\n", 38 | "# Model\n", 39 | "class Seq2Seq(nn.Module):\n", 40 | " def __init__(self):\n", 41 | " super(Seq2Seq, self).__init__()\n", 42 | "\n", 43 | " self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 44 | " self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 45 | " self.fc = nn.Linear(n_hidden, 
n_class)\n", 46 | "\n", 47 | " def forward(self, enc_input, enc_hidden, dec_input):\n", 48 | " enc_input = enc_input.transpose(0, 1) # enc_input: [max_len(=n_step, time step), batch_size, n_class]\n", 49 | " dec_input = dec_input.transpose(0, 1) # dec_input: [max_len(=n_step, time step), batch_size, n_class]\n", 50 | "\n", 51 | " # enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 52 | " _, enc_states = self.enc_cell(enc_input, enc_hidden)\n", 53 | " # outputs : [max_len+1(=6), batch_size, num_directions(=1) * n_hidden(=128)]\n", 54 | " outputs, _ = self.dec_cell(dec_input, enc_states)\n", 55 | "\n", 56 | " model = self.fc(outputs) # model : [max_len+1(=6), batch_size, n_class]\n", 57 | " return model\n", 58 | "\n", 59 | "if __name__ == '__main__':\n", 60 | " n_step = 5\n", 61 | " n_hidden = 128\n", 62 | "\n", 63 | " char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']\n", 64 | " num_dic = {n: i for i, n in enumerate(char_arr)}\n", 65 | " seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]\n", 66 | "\n", 67 | " n_class = len(num_dic)\n", 68 | " batch_size = len(seq_data)\n", 69 | "\n", 70 | " model = Seq2Seq()\n", 71 | "\n", 72 | " criterion = nn.CrossEntropyLoss()\n", 73 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 74 | "\n", 75 | " input_batch, output_batch, target_batch = make_batch(seq_data, num_dic, n_step)\n", 76 | "\n", 77 | " for epoch in range(5000):\n", 78 | " # make hidden shape [num_layers * num_directions, batch_size, n_hidden]\n", 79 | " hidden = torch.zeros(1, batch_size, n_hidden)\n", 80 | "\n", 81 | " optimizer.zero_grad()\n", 82 | " # input_batch : [batch_size, max_len(=n_step, time step), n_class]\n", 83 | " # output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_class]\n", 84 | " # target_batch : [batch_size, max_len+1(=n_step, time step)], not one-hot\n", 85 | " output = model(input_batch, hidden, 
output_batch)\n", 86 | " # output : [max_len+1, batch_size, n_class]\n", 87 | " output = output.transpose(0, 1) # [batch_size, max_len+1(=6), n_class]\n", 88 | " loss = 0\n", 89 | " for i in range(0, len(target_batch)):\n", 90 | " # output[i] : [max_len+1, n_class, target_batch[i] : max_len+1]\n", 91 | " loss += criterion(output[i], target_batch[i])\n", 92 | " if (epoch + 1) % 1000 == 0:\n", 93 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 94 | " loss.backward()\n", 95 | " optimizer.step()\n", 96 | "\n", 97 | " # Test\n", 98 | " def translate(word):\n", 99 | " input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]], num_dic, n_step)\n", 100 | " # make hidden shape [num_layers * num_directions, batch_size, n_hidden]\n", 101 | " hidden = torch.zeros(1, 1, n_hidden)\n", 102 | " output = model(input_batch, hidden, output_batch)\n", 103 | " # output : [max_len+1(=6), batch_size(=1), n_class]\n", 104 | "\n", 105 | " predict = output.data.max(2, keepdim=True)[1] # select n_class dimension\n", 106 | " decoded = [char_arr[i] for i in predict]\n", 107 | " end = decoded.index('E')\n", 108 | " translated = ''.join(decoded[:end])\n", 109 | "\n", 110 | " return translated.replace('P', '')\n", 111 | "\n", 112 | " print('test')\n", 113 | " print('man ->', translate('man'))\n", 114 | " print('mans ->', translate('mans'))\n", 115 | " print('king ->', translate('king'))\n", 116 | " print('black ->', translate('black'))\n", 117 | " print('upp ->', translate('upp'))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "equivalent-preview", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "Python 3", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | 
"mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.7.5" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 5 150 | } 151 | -------------------------------------------------------------------------------- /3-1.TextRNN/TextRNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "conditional-growing", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import math\n", 11 | "import mindspore\n", 12 | "import numpy as np\n", 13 | "import mindspore.nn as nn\n", 14 | "import mindspore.ops as ops\n", 15 | "from mindspore import Tensor, Parameter, ms_function" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "wanted-black", 21 | "metadata": {}, 22 | "source": [ 23 | "TextRNN Model:" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "id": "superb-decrease", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def make_batch(sentences, word_dict, n_class):\n", 34 | " input_batch = []\n", 35 | " target_batch = []\n", 36 | "\n", 37 | " for sen in sentences:\n", 38 | " word = sen.split() # space tokenizer\n", 39 | " input = [word_dict[n] for n in word[:-1]] # create (1~n-1) as input\n", 40 | " target = word_dict[word[-1]] # create (n) as target, We usually call this 'casual language model'\n", 41 | "\n", 42 | " input_batch.append(np.eye(n_class)[input])\n", 43 | " target_batch.append(target)\n", 44 | "\n", 45 | " return input_batch, target_batch" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "id": "suitable-receiver", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "class TextRNN(nn.Cell):\n", 56 | " def __init__(self, n_class, n_hidden, batch_size):\n", 57 | " super(TextRNN, self).__init__()\n", 58 | " self.rnn = nn.RNN(input_size=n_class, 
hidden_size=n_hidden, batch_first=True)\n", 59 | " self.W = nn.Dense(n_hidden, n_class, has_bias=False)\n", 60 | " self.b = Parameter(Tensor(np.ones([n_class]), mindspore.float32), 'b')\n", 61 | "\n", 62 | " def construct(self, X):\n", 63 | " X = X.swapaxes(0, 1) # X : [n_step, batch_size, n_class]\n", 64 | " outputs, _ = self.rnn(X)\n", 65 | " # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]\n", 66 | " outputs = outputs[-1] # [batch_size, num_directions(=1) * n_hidden]\n", 67 | " model = self.W(outputs)# model : [batch_size, n_class]\n", 68 | " \n", 69 | " return model" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 6, 75 | "id": "greenhouse-state", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "n_step = 2 # number of cells(= number of Step)\n", 80 | "n_hidden = 5 # number of hidden units in one cell\n", 81 | "\n", 82 | "sentences = [\"i like dog\", \"i love coffee\", \"i hate milk\"]\n", 83 | "\n", 84 | "word_list = \" \".join(sentences).split()\n", 85 | "word_list = list(set(word_list))\n", 86 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 87 | "number_dict = {i: w for i, w in enumerate(word_list)}\n", 88 | "n_class = len(word_dict)\n", 89 | "batch_size = len(sentences)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 7, 95 | "id": "quantitative-superintendent", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "model = TextRNN(n_class, n_hidden, batch_size)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 8, 105 | "id": "enabling-shore", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "criterion = nn.CrossEntropyLoss()\n", 110 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 9, 116 | "id": "afraid-pharmacology", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "input_batch, target_batch = 
make_batch(sentences, word_dict, n_class)\n", 121 | "input_batch = Tensor(input_batch, mindspore.float32)\n", 122 | "target_batch = Tensor(target_batch, mindspore.int32)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 10, 128 | "id": "e443bf76", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "def forward(inputs, targets):\n", 133 | " logits = model(inputs)\n", 134 | " loss = criterion(logits, targets)\n", 135 | " return loss" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 11, 141 | "id": "8d489907", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 12, 151 | "id": "9dd3d835", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "@ms_function\n", 156 | "def train_step(inputs, targets):\n", 157 | " loss, grads = grad_fn(inputs, targets)\n", 158 | " optimizer(grads)\n", 159 | " return loss" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 13, 165 | "id": "banner-backup", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "Epoch: 1000 cost = 0.141270\n", 173 | "Epoch: 2000 cost = 0.025611\n", 174 | "Epoch: 3000 cost = 0.010544\n", 175 | "Epoch: 4000 cost = 0.005329\n", 176 | "Epoch: 5000 cost = 0.002940\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "model.set_train()\n", 182 | "\n", 183 | "# Training\n", 184 | "for epoch in range(5000):\n", 185 | " # hidden : [num_layers * num_directions, batch, hidden_size]\n", 186 | " loss = train_step(input_batch, target_batch)\n", 187 | " if (epoch + 1) % 1000 == 0:\n", 188 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 14, 194 | "id": "established-solid", 195 
| "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "# Predict\n", 207 | "predict = model(input_batch).asnumpy().argmax(1)\n", 208 | "print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])" 209 | ] 210 | } 211 | ], 212 | "metadata": { 213 | "kernelspec": { 214 | "display_name": "Python 3.7.13 ('ms1.8')", 215 | "language": "python", 216 | "name": "python3" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | "name": "ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.7.13" 229 | }, 230 | "vscode": { 231 | "interpreter": { 232 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 233 | } 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 5 238 | } 239 | -------------------------------------------------------------------------------- /2-1.TextCNN/TextCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "id": "confidential-attendance", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "from mindspore import Parameter, Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 12, 20 | "id": "promotional-smart", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "class TextCNN(nn.Cell):\n", 25 | " def __init__(self, embedding_size, sequence_length, num_classes, filter_sizes, num_filters, vocab_size):\n", 26 | " 
super(TextCNN, self).__init__()\n", 27 | " self.num_filters_total = num_filters * len(filter_sizes)\n", 28 | " self.filter_sizes = filter_sizes\n", 29 | " self.sequence_length = sequence_length\n", 30 | " self.W = nn.Embedding(vocab_size, embedding_size)\n", 31 | " self.Weight = nn.Dense(self.num_filters_total, num_classes, has_bias=False)\n", 32 | " self.Bias = Parameter(Tensor(np.ones(num_classes), mindspore.float32), name='bias')\n", 33 | " self.filter_list = nn.CellList()\n", 34 | " for size in filter_sizes:\n", 35 | " seq_cell = nn.SequentialCell([\n", 36 | " nn.Conv2d(1, num_filters, (size, embedding_size), pad_mode='valid'),\n", 37 | " nn.ReLU(),\n", 38 | " nn.MaxPool2d(kernel_size=(sequence_length - size + 1, 1))\n", 39 | " ])\n", 40 | " self.filter_list.append(seq_cell)\n", 41 | "\n", 42 | " def construct(self, X):\n", 43 | " embedded_chars = self.W(X)\n", 44 | " embedded_chars = embedded_chars.expand_dims(1)\n", 45 | " pooled_outputs = []\n", 46 | " for conv in self.filter_list:\n", 47 | " pooled = conv(embedded_chars)\n", 48 | " pooled = pooled.transpose((0, 3, 2, 1))\n", 49 | " pooled_outputs.append(pooled)\n", 50 | " \n", 51 | " h_pool = ops.concat(pooled_outputs, len(self.filter_sizes))\n", 52 | " h_pool_flat = h_pool.view(-1, self.num_filters_total)\n", 53 | " model = self.Weight(h_pool_flat) + self.Bias\n", 54 | " return model" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 13, 60 | "id": "prime-lindsay", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "\n", 65 | "embedding_size = 2\n", 66 | "sequence_length = 3\n", 67 | "num_classes = 2\n", 68 | "filter_sizes = [2, 2, 2]\n", 69 | "num_filters = 3\n", 70 | "\n", 71 | "sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \" i hate you\", \"sorry for that\", \"this is awful\"]\n", 72 | "labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.\n", 73 | "\n", 74 | "word_list = \" \".join(sentences).split()\n", 75 | "word_list = 
list(set(word_list))\n", 76 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 77 | "vocab_size = len(word_dict)\n", 78 | "\n", 79 | "model = TextCNN(embedding_size, sequence_length, num_classes, filter_sizes, num_filters, vocab_size)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 14, 85 | "id": "gorgeous-weekly", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "criterion = nn.CrossEntropyLoss()\n", 90 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 15, 96 | "id": "instructional-scheduling", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "inputs = Tensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences], mindspore.int32)\n", 101 | "targets = Tensor([out for out in labels], mindspore.int32) " 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 16, 107 | "id": "hundred-conclusion", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "def forward(inputs, targets):\n", 112 | " logits = model(inputs)\n", 113 | " loss = criterion(logits, targets)\n", 114 | " return loss" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 17, 120 | "id": "8bf68174", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 18, 130 | "id": "c4ee2913", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "@ms_function\n", 135 | "def train_step(inputs, targets):\n", 136 | " loss, grads = grad_fn(inputs, targets)\n", 137 | " optimizer(grads)\n", 138 | " return loss" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 19, 144 | "id": "interesting-worthy", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 
| "text": [ 151 | "Epoch: 1000 cost = 0.002617\n", 152 | "Epoch: 2000 cost = 0.000449\n", 153 | "Epoch: 3000 cost = 0.000152\n", 154 | "Epoch: 4000 cost = 0.000066\n", 155 | "Epoch: 5000 cost = 0.000031\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "model.set_train()\n", 161 | "\n", 162 | "epoch = 5000\n", 163 | "for step in range(epoch):\n", 164 | " loss = train_step(inputs, targets)\n", 165 | " \n", 166 | " if (step + 1) % 1000 == 0:\n", 167 | " print('Epoch:', '%04d' % (step + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 20, 173 | "id": "decent-breakdown", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "sorry hate you is Bad Mean...\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "test_text = 'sorry hate you'\n", 186 | "tests = [np.asarray([word_dict[n] for n in test_text.split()])]\n", 187 | "test_batch = Tensor(tests, mindspore.int32)\n", 188 | "\n", 189 | "# Predict\n", 190 | "predict = model(test_batch).asnumpy().argmax(1)\n", 191 | "if predict[0] == 0:\n", 192 | " print(test_text,\"is Bad Mean...\")\n", 193 | "else:\n", 194 | " print(test_text,\"is Good Mean!!\")" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3.7.13 ('ms1.8')", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.7.13" 215 | }, 216 | "vscode": { 217 | "interpreter": { 218 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 219 | } 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 5 224 | } 225 | 
-------------------------------------------------------------------------------- /3-2.TextLSTM/TextLSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "dying-communications", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "from mindspore import Parameter, Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "suited-southeast", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def make_batch(seq_data, word_dict,vocab_size):\n", 25 | " input_batch, target_batch = [], []\n", 26 | "\n", 27 | " for seq in seq_data:\n", 28 | " input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is input\n", 29 | " target = word_dict[seq[-1]] # 'e' is target\n", 30 | " input_batch.append(np.eye(vocab_size)[input])\n", 31 | " target_batch.append(target)\n", 32 | "\n", 33 | " return input_batch, target_batch" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "id": "saving-print", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "class TextLSTM(nn.Cell):\n", 44 | " def __init__(self, batch_size, vocab_size, hidden_size):\n", 45 | " super(TextLSTM,self).__init__()\n", 46 | " self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size)\n", 47 | " self.W = nn.Dense(hidden_size, vocab_size, has_bias=False)\n", 48 | " self.b = Parameter(Tensor(np.ones(vocab_size), mindspore.float32), 'b')\n", 49 | " \n", 50 | " self.n_steps = n_steps\n", 51 | "\n", 52 | " def construct(self, X):\n", 53 | " input = X.transpose((1, 0, 2)) \n", 54 | " outputs, (_, _) = self.lstm(input)\n", 55 | " outputs = outputs[-1] \n", 56 | " model = self.W(outputs) + self.b \n", 57 | " return model" 58 | ] 59 | }, 60 | { 61 | "cell_type": 
"code", 62 | "execution_count": 4, 63 | "id": "changed-facial", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "n_steps = 3 \n", 68 | "hidden_size = 128 \n", 69 | "\n", 70 | "char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']\n", 71 | "word_dict = {n: i for i, n in enumerate(char_arr)}\n", 72 | "number_dict = {i: w for i, w in enumerate(char_arr)}\n", 73 | "vocab_size = len(word_dict) \n", 74 | "\n", 75 | "seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "id": "growing-stroke", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "input_batch, target_batch = make_batch(seq_data, word_dict, vocab_size)\n", 86 | "input_batch = Tensor(input_batch, mindspore.float32)\n", 87 | "target_batch = Tensor(target_batch, mindspore.int32)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 6, 93 | "id": "happy-medline", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "batch_size = len(input_batch)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 7, 103 | "id": "fancy-detroit", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "model = TextLSTM(batch_size, vocab_size, hidden_size)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 8, 113 | "id": "electronic-mirror", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "criterion = nn.CrossEntropyLoss()\n", 118 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 9, 124 | "id": "suspended-shuttle", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "def forward(inputs, targets):\n", 129 | " logits = model(inputs)\n", 130 | " loss = criterion(logits, targets)\n", 131 | " return loss" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | 
"execution_count": 10, 137 | "id": "6d65e68d", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 11, 147 | "id": "da270df2", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "@ms_function\n", 152 | "def train_step(inputs, targets):\n", 153 | " loss, grads = grad_fn(inputs, targets)\n", 154 | " optimizer(grads)\n", 155 | " return loss" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 12, 161 | "id": "sublime-exercise", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Epoch: 0100 cost = 1.123684\n", 169 | "Epoch: 0200 cost = 0.125290\n", 170 | "Epoch: 0300 cost = 0.027129\n", 171 | "Epoch: 0400 cost = 0.010317\n", 172 | "Epoch: 0500 cost = 0.005415\n", 173 | "Epoch: 0600 cost = 0.003387\n", 174 | "Epoch: 0700 cost = 0.002342\n", 175 | "Epoch: 0800 cost = 0.001727\n", 176 | "Epoch: 0900 cost = 0.001330\n", 177 | "Epoch: 1000 cost = 0.001058\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "model.set_train()\n", 183 | "# Training\n", 184 | "epoch = 1000\n", 185 | "for step in range(epoch):\n", 186 | " loss = train_step(input_batch, target_batch)\n", 187 | " if (step + 1) % 100 == 0:\n", 188 | " print('Epoch:', '%04d' % (step + 1), 'cost = ', '{:.6f}'.format(loss.asnumpy()))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 13, 194 | "id": "seven-tunisia", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "inputs = [sen[:3] for sen in seq_data]" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 14, 204 | "id": "norwegian-mounting", 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "['mak', 'nee', 'coa', 'wor', 'lov', 'hat', 'liv', 'hom', 
'has', 'sta'] -> ['e', 'd', 'l', 'd', 'e', 'e', 'e', 'e', 'h', 'r']\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "predict = model(input_batch).asnumpy().argmax(axis=1)\n", 217 | "print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3.7.13 ('ms1.8')", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.7.13" 238 | }, 239 | "vscode": { 240 | "interpreter": { 241 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 242 | } 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 5 247 | } 248 | -------------------------------------------------------------------------------- /1-1.NNLM/NNLM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 50, 6 | "id": "celtic-passenger", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "from mindspore import Parameter, Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 51, 20 | "id": "lined-travel", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def make_batch(sentences, word_dict):\n", 25 | " input_batch = []\n", 26 | " target_batch = []\n", 27 | " \n", 28 | " for sent in sentences:\n", 29 | " word = sent.split()\n", 30 | " inp = [word_dict[n] for n in word[:-1]]\n", 31 | " tgt = word_dict[word[-1]]\n", 32 | " \n", 33 | " input_batch.append(inp)\n", 34 | " target_batch.append(tgt)\n", 
35 | " return input_batch, target_batch" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 52, 41 | "id": "opened-guinea", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "class NNLM(nn.Cell):\n", 46 | " def __init__(self, n_steps, vocab_size, embed_size, hidden_size):\n", 47 | " super().__init__()\n", 48 | " self.C = nn.Embedding(vocab_size, embed_size)\n", 49 | " self.H = nn.Dense(n_steps * embed_size, hidden_size, has_bias=False)\n", 50 | " self.d = Parameter(Tensor(np.ones(hidden_size), mindspore.float32), name='d')\n", 51 | " self.U = nn.Dense(hidden_size, vocab_size, has_bias=False)\n", 52 | " self.W = nn.Dense(n_steps * embed_size, vocab_size, has_bias=False)\n", 53 | " self.b = Parameter(Tensor(np.ones(vocab_size), mindspore.float32), name='b')\n", 54 | " self.n_steps = n_steps\n", 55 | " self.embed_size = embed_size\n", 56 | " self.tanh = nn.Tanh()\n", 57 | "\n", 58 | " def construct(self, X):\n", 59 | " X = self.C(X)\n", 60 | " X = X.view(-1, self.n_steps * self.embed_size)\n", 61 | " tanh = self.tanh(self.d + self.H(X))\n", 62 | " output = self.b + self.W(X) + self.U(tanh)\n", 63 | " return output\n", 64 | " " 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 53, 70 | "id": "boolean-outline", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "n_steps = 2\n", 75 | "hidden_size = 2\n", 76 | "embed_size = 2\n", 77 | "\n", 78 | "sentences = [\"i like dog\", \"i love coffee\", \"i hate milk\"]\n", 79 | "\n", 80 | "word_list = \" \".join(sentences).split()\n", 81 | "word_list = list(set(word_list))\n", 82 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 83 | "number_dict = {i: w for i, w in enumerate(word_list)}\n", 84 | "vocab_size = len(word_dict)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 54, 90 | "id": "vocational-adaptation", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "Tensor(shape=[3], dtype=Int32, 
value= [1, 6, 3])" 97 | ] 98 | }, 99 | "execution_count": 54, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "input_batch, target_batch = make_batch(sentences, word_dict)\n", 106 | "input_batch = Tensor(input_batch, mindspore.int32)\n", 107 | "target_batch = Tensor(target_batch, mindspore.int32)\n", 108 | "target_batch" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 55, 114 | "id": "certain-spouse", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "model = NNLM(n_steps, vocab_size, embed_size, hidden_size)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 56, 124 | "id": "municipal-hypothetical", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "criterion = nn.CrossEntropyLoss()\n", 129 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 57, 135 | "id": "f1a65d23", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "def forward(inputs, targets):\n", 140 | " logits = model(inputs)\n", 141 | " loss = criterion(logits, targets)\n", 142 | " return loss" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 58, 148 | "id": "6121b71e", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 59, 158 | "id": "ff2c4e89", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "@ms_function\n", 163 | "def train_step(inputs, targets):\n", 164 | " loss, grads = grad_fn(inputs, targets)\n", 165 | " optimizer(grads)\n", 166 | " return loss" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 60, 172 | "id": "efficient-slope", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | 
"output_type": "stream", 178 | "text": [ 179 | "Epoch: 1000 cost = 0.159208\n", 180 | "Epoch: 2000 cost = 0.016804\n", 181 | "Epoch: 3000 cost = 0.005246\n", 182 | "Epoch: 4000 cost = 0.002221\n", 183 | "Epoch: 5000 cost = 0.001076\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "model.set_train()\n", 189 | "\n", 190 | "epoch = 5000\n", 191 | "for step in range(epoch):\n", 192 | " loss = train_step(input_batch, target_batch)\n", 193 | " if (step + 1) % 1000 == 0:\n", 194 | " print('Epoch:', '%04d' % (step + 1), 'cost = ', '{:.6f}'.format(loss.asnumpy()))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 61, 200 | "id": "hourly-senegal", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "[1 6 3]\n", 208 | "[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "model.set_train(False)\n", 214 | "predict = model(input_batch).asnumpy().argmax(axis=1)\n", 215 | "print(predict)\n", 216 | "print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict])" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3.7.13 ('ms1.8')", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.7.13" 237 | }, 238 | "vscode": { 239 | "interpreter": { 240 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 241 | } 242 | } 243 | }, 244 | "nbformat": 4, 245 | "nbformat_minor": 5 246 | } 247 | -------------------------------------------------------------------------------- /3-3.Bi-LSTM/Bi-LSTM.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "handled-script", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "from mindspore import Parameter, Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "worthy-samba", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def make_batch(sentence, word_dict, n_class, max_len):\n", 25 | " input_batch = []\n", 26 | " target_batch = []\n", 27 | "\n", 28 | " words = sentence.split()\n", 29 | " for i, word in enumerate(words[:-1]):\n", 30 | " input = [word_dict[n] for n in words[:(i + 1)]]\n", 31 | " input = input + [0] * (max_len - len(input))\n", 32 | " target = word_dict[words[i + 1]]\n", 33 | " input_batch.append(np.eye(n_class)[input])\n", 34 | " target_batch.append(target)\n", 35 | "\n", 36 | " return input_batch, target_batch" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "id": "religious-portland", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "class BiLSTM(nn.Cell):\n", 47 | " def __init__(self, n_class, n_hidden, batch_size):\n", 48 | " super().__init__()\n", 49 | " self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)\n", 50 | " self.W = nn.Dense(n_hidden * 2, n_class, has_bias=False)\n", 51 | " self.b = Parameter(Tensor(np.ones([n_class], dtype=np.float32), mindspore.float32), 'b')\n", 52 | "\n", 53 | " def construct(self, X):\n", 54 | " input = X.transpose((1, 0, 2))\n", 55 | " output, (_, _) = self.lstm(input)\n", 56 | " outputs = output[-1]\n", 57 | " model = self.W(outputs) + self.b\n", 58 | " \n", 59 | " return model" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": 
"imposed-waters", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "n_hidden = 5 # number of hidden units in one cell\n", 70 | "\n", 71 | "sentence = (\n", 72 | " 'Lorem ipsum dolor sit amet consectetur adipisicing elit '\n", 73 | " 'sed do eiusmod tempor incididunt ut labore et dolore magna '\n", 74 | " 'aliqua Ut enim ad minim veniam quis nostrud exercitation'\n", 75 | ")\n", 76 | "\n", 77 | "word_dict = {w: i for i, w in enumerate(list(set(sentence.split())))}\n", 78 | "number_dict = {i: w for i, w in enumerate(list(set(sentence.split())))}\n", 79 | "n_class = len(word_dict)\n", 80 | "max_len = len(sentence.split())\n", 81 | "vocab_size = len(word_dict)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "id": "generic-vessel", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "(26, 27, 27) (26,)\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "input_batch, target_batch = make_batch(sentence, word_dict, n_class, max_len)\n", 100 | "# print(input_batch, target_batch)\n", 101 | "input_batch = Tensor(input_batch, mindspore.float32)\n", 102 | "target_batch = Tensor(target_batch, mindspore.int32)\n", 103 | "print(input_batch.shape, target_batch.shape)\n", 104 | "\n", 105 | "batch_size = len(input_batch)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "id": "random-entertainment", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "model = BiLSTM(n_class, n_hidden, batch_size)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "id": "southwest-baltimore", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "criterion = nn.CrossEntropyLoss()\n", 126 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 8, 132 | "id": "conventional-munich", 133 | "metadata": {}, 
134 | "outputs": [], 135 | "source": [ 136 | "def forward(inputs, targets):\n", 137 | " logits = model(inputs)\n", 138 | " loss = criterion(logits, targets)\n", 139 | " return loss" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 9, 145 | "id": "b9633c3c", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 10, 155 | "id": "9b5a2242", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "@ms_function\n", 160 | "def train_step(inputs, targets):\n", 161 | " loss, grads = grad_fn(inputs, targets)\n", 162 | " optimizer(grads)\n", 163 | " return loss" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 11, 169 | "id": "accredited-manual", 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "Epoch: 1000 cost = 2.585795\n", 177 | "Epoch: 2000 cost = 2.581421\n", 178 | "Epoch: 3000 cost = 2.569982\n", 179 | "Epoch: 4000 cost = 2.311544\n", 180 | "Epoch: 5000 cost = 1.974983\n", 181 | "Epoch: 6000 cost = 1.053331\n", 182 | "Epoch: 7000 cost = 0.681154\n", 183 | "Epoch: 8000 cost = 0.568491\n", 184 | "Epoch: 9000 cost = 0.448840\n", 185 | "Epoch: 10000 cost = 0.375638\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "model.set_train()\n", 191 | "\n", 192 | "epoch = 10000\n", 193 | "for step in range(epoch):\n", 194 | " loss = train_step(input_batch, target_batch)\n", 195 | " if (step + 1) % 1000 == 0:\n", 196 | " print('Epoch:', '%04d' % (step + 1), 'cost = ', '{:.6f}'.format(loss.asnumpy()))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 12, 202 | "id": "revised-description", 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "Lorem ipsum dolor sit amet consectetur 
adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud exercitation\n", 210 | "['dolor', 'dolor', 'sit', 'amet', 'consectetur', 'adipisicing', 'elit', 'sed', 'sed', 'eiusmod', 'tempor', 'incididunt', 'ut', 'labore', 'et', 'dolore', 'magna', 'aliqua', 'ad', 'ad', 'ad', 'minim', 'veniam', 'quis', 'nostrud', 'exercitation']\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "model.set_train(False)\n", 216 | "predict = model(input_batch).asnumpy().argmax(axis=1)\n", 217 | "print(sentence)\n", 218 | "print([number_dict[n.item()] for n in predict.squeeze()])" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3.7.13 ('ms1.8')", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.7.13" 239 | }, 240 | "vscode": { 241 | "interpreter": { 242 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 243 | } 244 | } 245 | }, 246 | "nbformat": 4, 247 | "nbformat_minor": 5 248 | } 249 | -------------------------------------------------------------------------------- /4-2.Seq2Seq(Attention)/Seq2Seq-Attention_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "metadata": {}, 6 | "source": [ 7 | "# code by Tae Hwan Jung @graykode\n", 8 | "# Reference : https://github.com/hunkim/PyTorchZeroToAll/blob/master/14_2_seq2seq_att.py\n", 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "# S: Symbol that shows starting of decoding 
input\n", 16 | "# E: Symbol that shows starting of decoding output\n", 17 | "# P: Symbol that will fill in blank sequence if current batch data size is short than time steps\n", 18 | "\n", 19 | "def make_batch():\n", 20 | " input_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[0].split()]]]\n", 21 | " output_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[1].split()]]]\n", 22 | " target_batch = [[word_dict[n] for n in sentences[2].split()]]\n", 23 | "\n", 24 | " # make tensor\n", 25 | " return torch.FloatTensor(input_batch), torch.FloatTensor(output_batch), torch.LongTensor(target_batch)\n", 26 | "\n", 27 | "class Attention(nn.Module):\n", 28 | " def __init__(self):\n", 29 | " super(Attention, self).__init__()\n", 30 | " self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 31 | " self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 32 | "\n", 33 | " # Linear for attention\n", 34 | " self.attn = nn.Linear(n_hidden, n_hidden)\n", 35 | " self.out = nn.Linear(n_hidden * 2, n_class)\n", 36 | "\n", 37 | " def forward(self, enc_inputs, hidden, dec_inputs):\n", 38 | " enc_inputs = enc_inputs.transpose(0, 1) # enc_inputs: [n_step(=n_step, time step), batch_size, n_class]\n", 39 | " dec_inputs = dec_inputs.transpose(0, 1) # dec_inputs: [n_step(=n_step, time step), batch_size, n_class]\n", 40 | "\n", 41 | " # enc_outputs : [n_step, batch_size, num_directions(=1) * n_hidden], matrix F\n", 42 | " # enc_hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 43 | " enc_outputs, enc_hidden = self.enc_cell(enc_inputs, hidden)\n", 44 | "\n", 45 | " trained_attn = []\n", 46 | " hidden = enc_hidden\n", 47 | " n_step = len(dec_inputs)\n", 48 | " model = torch.empty([n_step, 1, n_class])\n", 49 | "\n", 50 | " for i in range(n_step): # each time step\n", 51 | " # dec_output : [n_step(=1), batch_size(=1), num_directions(=1) * n_hidden]\n", 52 | " # hidden : [num_layers(=1) * num_directions(=1), 
batch_size(=1), n_hidden]\n", 53 | " dec_output, hidden = self.dec_cell(dec_inputs[i].unsqueeze(0), hidden)\n", 54 | " attn_weights = self.get_att_weight(dec_output, enc_outputs) # attn_weights : [1, 1, n_step]\n", 55 | " trained_attn.append(attn_weights.squeeze().data.numpy())\n", 56 | "\n", 57 | " # matrix-matrix product of matrices [1,1,n_step] x [1,n_step,n_hidden] = [1,1,n_hidden]\n", 58 | " context = attn_weights.bmm(enc_outputs.transpose(0, 1))\n", 59 | " dec_output = dec_output.squeeze(0) # dec_output : [batch_size(=1), num_directions(=1) * n_hidden]\n", 60 | " context = context.squeeze(1) # [1, num_directions(=1) * n_hidden]\n", 61 | " model[i] = self.out(torch.cat((dec_output, context), 1))\n", 62 | "\n", 63 | " # make model shape [n_step, n_class]\n", 64 | " return model.transpose(0, 1).squeeze(0), trained_attn\n", 65 | "\n", 66 | " def get_att_weight(self, dec_output, enc_outputs): # get attention weight one 'dec_output' with 'enc_outputs'\n", 67 | " n_step = len(enc_outputs)\n", 68 | " attn_scores = torch.zeros(n_step) # attn_scores : [n_step]\n", 69 | "\n", 70 | " for i in range(n_step):\n", 71 | " attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])\n", 72 | "\n", 73 | " # Normalize scores to weights in range 0 to 1\n", 74 | " return F.softmax(attn_scores).view(1, 1, -1)\n", 75 | "\n", 76 | " def get_att_score(self, dec_output, enc_output): # enc_outputs [batch_size, num_directions(=1) * n_hidden]\n", 77 | " score = self.attn(enc_output) # score : [batch_size, n_hidden]\n", 78 | " return torch.dot(dec_output.view(-1), score.view(-1)) # inner product make scalar value\n", 79 | "\n", 80 | "if __name__ == '__main__':\n", 81 | " n_step = 5 # number of cells(= number of Step)\n", 82 | " n_hidden = 128 # number of hidden units in one cell\n", 83 | "\n", 84 | " sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']\n", 85 | "\n", 86 | " word_list = \" \".join(sentences).split()\n", 87 | " word_list = list(set(word_list))\n", 
88 | " word_dict = {w: i for i, w in enumerate(word_list)}\n", 89 | " number_dict = {i: w for i, w in enumerate(word_list)}\n", 90 | " n_class = len(word_dict) # vocab list\n", 91 | "\n", 92 | " # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 93 | " hidden = torch.zeros(1, 1, n_hidden)\n", 94 | "\n", 95 | " model = Attention()\n", 96 | " criterion = nn.CrossEntropyLoss()\n", 97 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 98 | "\n", 99 | " input_batch, output_batch, target_batch = make_batch()\n", 100 | "\n", 101 | " # Train\n", 102 | " for epoch in range(2000):\n", 103 | " optimizer.zero_grad()\n", 104 | " output, _ = model(input_batch, hidden, output_batch)\n", 105 | "\n", 106 | " loss = criterion(output, target_batch.squeeze(0))\n", 107 | " if (epoch + 1) % 400 == 0:\n", 108 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 109 | "\n", 110 | " loss.backward()\n", 111 | " optimizer.step()\n", 112 | "\n", 113 | " # Test\n", 114 | " test_batch = [np.eye(n_class)[[word_dict[n] for n in 'SPPPP']]]\n", 115 | " test_batch = torch.FloatTensor(test_batch)\n", 116 | " predict, trained_attn = model(input_batch, hidden, test_batch)\n", 117 | " predict = predict.data.max(1, keepdim=True)[1]\n", 118 | " print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])\n", 119 | "\n", 120 | " # Show Attention\n", 121 | " fig = plt.figure(figsize=(5, 5))\n", 122 | " ax = fig.add_subplot(1, 1, 1)\n", 123 | " ax.matshow(trained_attn, cmap='viridis')\n", 124 | " ax.set_xticklabels([''] + sentences[0].split(), fontdict={'fontsize': 14})\n", 125 | " ax.set_yticklabels([''] + sentences[2].split(), fontdict={'fontsize': 14})\n", 126 | " plt.show()" 127 | ], 128 | "outputs": [], 129 | "execution_count": null 130 | } 131 | ], 132 | "metadata": { 133 | "anaconda-cloud": {}, 134 | "kernelspec": { 135 | "display_name": "Python 3", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | 
"language_info": { 140 | "codemirror_mode": { 141 | "name": "ipython", 142 | "version": 3 143 | }, 144 | "file_extension": ".py", 145 | "mimetype": "text/x-python", 146 | "name": "python", 147 | "nbconvert_exporter": "python", 148 | "pygments_lexer": "ipython3", 149 | "version": "3.6.1" 150 | } 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 4 154 | } -------------------------------------------------------------------------------- /4-1.Seq2Seq/Seq2Seq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "id": "metropolitan-married", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import mindspore\n", 11 | "import numpy as np\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "from mindspore import Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 13, 20 | "id": "internal-covering", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def make_batch(seq_data, num_dic, n_step):\n", 25 | " input_batch, output_batch, target_batch = [], [], []\n", 26 | "\n", 27 | " for seq in seq_data:\n", 28 | " for i in range(2):\n", 29 | " seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))\n", 30 | "\n", 31 | " input = [num_dic[n] for n in seq[0]]\n", 32 | " output = [num_dic[n] for n in ('S' + seq[1])]\n", 33 | " target = [num_dic[n] for n in (seq[1] + 'E')]\n", 34 | "\n", 35 | " input_batch.append(np.eye(n_class)[input])\n", 36 | " output_batch.append(np.eye(n_class)[output])\n", 37 | " target_batch.append(target) # not one-hot\n", 38 | "\n", 39 | " # make tensor\n", 40 | " return Tensor(input_batch, mindspore.float32), Tensor(output_batch, mindspore.float32), Tensor(target_batch, mindspore.int32)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 14, 46 | "id": "compressed-resort", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | 
"# Model\n", 51 | "class Seq2Seq(nn.Cell):\n", 52 | " def __init__(self, n_class, n_hidden, dropout):\n", 53 | " super(Seq2Seq, self).__init__()\n", 54 | "\n", 55 | " self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=dropout)\n", 56 | " self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=dropout)\n", 57 | " self.fc = nn.Dense(n_hidden, n_class)\n", 58 | " \n", 59 | " \n", 60 | " def construct(self, enc_input, dec_input):\n", 61 | " enc_input = enc_input.transpose((1, 0, 2)) # enc_input: [max_len(=n_step, time step), batch_size, n_class]\n", 62 | " dec_input = dec_input.transpose((1, 0, 2)) # dec_input: [max_len(=n_step, time step), batch_size, n_class]\n", 63 | "\n", 64 | " # enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 65 | " _, enc_states = self.enc_cell(enc_input)\n", 66 | " # outputs : [max_len+1(=6), batch_size, num_directions(=1) * n_hidden(=128)]\n", 67 | " outputs, _ = self.dec_cell(dec_input, enc_states)\n", 68 | "\n", 69 | " model = self.fc(outputs) # model : [max_len+1(=6), batch_size, n_class]\n", 70 | " return model" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 15, 76 | "id": "impaired-treat", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "n_step = 5\n", 81 | "n_hidden = 128\n", 82 | "dropout = 0.5\n", 83 | "char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']\n", 84 | "num_dic = {n: i for i, n in enumerate(char_arr)}\n", 85 | "seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]\n", 86 | "\n", 87 | "n_class = len(num_dic)\n", 88 | "batch_size = len(seq_data)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 16, 94 | "id": "structured-external", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stderr", 99 | "output_type": "stream", 100 | "text": [ 101 | "[WARNING] ME(257088:139675582087424,MainProcess):2022-08-12-21:11:28.654.667 
[mindspore/nn/layer/rnns.py:392] dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n", 102 | "[WARNING] ME(257088:139675582087424,MainProcess):2022-08-12-21:11:28.663.657 [mindspore/nn/layer/rnns.py:392] dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "model = Seq2Seq(n_class, n_hidden, dropout)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 17, 113 | "id": "japanese-platform", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "criterion = nn.CrossEntropyLoss()\n", 118 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 18, 124 | "id": "governmental-narrative", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "input_batch, output_batch, target_batch = make_batch(seq_data, num_dic, n_step)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 19, 134 | "id": "20affee4", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "def forward(enc_input, dec_input, target):\n", 139 | " output = model(enc_input, dec_input)\n", 140 | " output = output.transpose((1, 0, 2))\n", 141 | " return criterion(output.view(-1, output.shape[-1]), target.view(-1))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 20, 147 | "id": "0727166e", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 21, 157 | "id": "bd24c88f", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "@ms_function\n", 162 | "def train_step(enc_input, dec_input, 
target):\n", 163 | " loss, grads = grad_fn(enc_input, dec_input, target)\n", 164 | " optimizer(grads)\n", 165 | " return loss" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 22, 171 | "id": "celtic-variety", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "Epoch: 1000 cost = 0.000974\n", 179 | "Epoch: 2000 cost = 0.000262\n", 180 | "Epoch: 3000 cost = 0.000112\n", 181 | "Epoch: 4000 cost = 0.000056\n", 182 | "Epoch: 5000 cost = 0.000030\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "model.set_train()\n", 188 | "\n", 189 | "for epoch in range(5000):\n", 190 | " # input_batch : [batch_size, max_len(=n_step, time step), n_class]\n", 191 | " # output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_class]\n", 192 | " # target_batch : [batch_size, max_len+1(=n_step, time step)], not one-hot\n", 193 | " loss = train_step(input_batch, output_batch, target_batch)\n", 194 | " if (epoch + 1) % 1000 == 0:\n", 195 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 23, 201 | "id": "resident-debate", 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "test\n", 209 | "man -> women\n", 210 | "mans -> women\n", 211 | "king -> queen\n", 212 | "black -> white\n", 213 | "upp -> down\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "model.set_train(False)\n", 219 | "# Test\n", 220 | "def translate(word):\n", 221 | " input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]], num_dic, n_step)\n", 222 | " output = model(input_batch, output_batch)\n", 223 | " # output : [max_len+1(=6), batch_size(=1), n_class]\n", 224 | "\n", 225 | " predict = output.asnumpy().argmax(2) # select n_class dimension\n", 226 | " decoded = [char_arr[i[0]] for i in 
predict]\n", 227 | " end = decoded.index('E')\n", 228 | " translated = ''.join(decoded[:end])\n", 229 | "\n", 230 | " return translated.replace('P', '')\n", 231 | "\n", 232 | "print('test')\n", 233 | "print('man ->', translate('man'))\n", 234 | "print('mans ->', translate('mans'))\n", 235 | "print('king ->', translate('king'))\n", 236 | "print('black ->', translate('black'))\n", 237 | "print('upp ->', translate('upp'))" 238 | ] 239 | } 240 | ], 241 | "metadata": { 242 | "kernelspec": { 243 | "display_name": "Python 3.7.13 ('ms1.8')", 244 | "language": "python", 245 | "name": "python3" 246 | }, 247 | "language_info": { 248 | "codemirror_mode": { 249 | "name": "ipython", 250 | "version": 3 251 | }, 252 | "file_extension": ".py", 253 | "mimetype": "text/x-python", 254 | "name": "python", 255 | "nbconvert_exporter": "python", 256 | "pygments_lexer": "ipython3", 257 | "version": "3.7.13" 258 | }, 259 | "vscode": { 260 | "interpreter": { 261 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 262 | } 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 5 267 | } 268 | -------------------------------------------------------------------------------- /5-2.BERT/BERT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import re 8 | from random import * 9 | import mindspore 10 | import mindspore.nn as nn 11 | import mindspore.ops as ops 12 | import mindspore.numpy as mnp 13 | from layers import Dense, Embedding 14 | 15 | 16 | # In[2]: 17 | 18 | 19 | # sample IsNext and NotNext to be same in small batch size 20 | def make_batch(): 21 | batch = [] 22 | positive = negative = 0 23 | while positive != batch_size/2 or negative != batch_size/2: 24 | tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences 25 | tokens_a, tokens_b= token_list[tokens_a_index], 
token_list[tokens_b_index] 26 | input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']] 27 | segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1) 28 | 29 | # MASK LM 30 | n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence 31 | cand_maked_pos = [i for i, token in enumerate(input_ids) 32 | if token != word_dict['[CLS]'] and token != word_dict['[SEP]']] 33 | shuffle(cand_maked_pos) 34 | masked_tokens, masked_pos = [], [] 35 | for pos in cand_maked_pos[:n_pred]: 36 | masked_pos.append(pos) 37 | masked_tokens.append(input_ids[pos]) 38 | if random() < 0.8: # 80% 39 | input_ids[pos] = word_dict['[MASK]'] # make mask 40 | elif random() < 0.5: # 10% 41 | index = randint(0, vocab_size - 1) # random index in vocabulary 42 | input_ids[pos] = word_dict[number_dict[index]] # replace 43 | 44 | # Zero Paddings 45 | n_pad = maxlen - len(input_ids) 46 | input_ids.extend([0] * n_pad) 47 | segment_ids.extend([0] * n_pad) 48 | 49 | # Zero Padding (100% - 15%) tokens 50 | if max_pred > n_pred: 51 | n_pad = max_pred - n_pred 52 | masked_tokens.extend([0] * n_pad) 53 | masked_pos.extend([0] * n_pad) 54 | 55 | if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2: 56 | batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 1]) # IsNext 57 | positive += 1 58 | elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2: 59 | batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 0]) # NotNext 60 | negative += 1 61 | return batch 62 | # Proprecessing Finished 63 | 64 | 65 | # In[3]: 66 | 67 | 68 | def get_attn_pad_mask(seq_q, seq_k): 69 | batch_size, len_q = seq_q.shape 70 | batch_size, len_k = seq_k.shape 71 | 72 | pad_attn_mask = ops.equal(seq_k, 0) 73 | pad_attn_mask = pad_attn_mask.expand_dims(1) # batch_size x 1 x len_k(=len_q), one is masking 74 | 75 | return ops.broadcast_to(pad_attn_mask, (batch_size, len_q, len_k)) # batch_size x 
len_q x len_k 76 | 77 | 78 | # In[4]: 79 | 80 | 81 | class BertEmbedding(nn.Cell): 82 | def __init__(self): 83 | super(BertEmbedding, self).__init__() 84 | self.tok_embed = Embedding(vocab_size, d_model) # token embedding 85 | self.pos_embed = Embedding(maxlen, d_model) # position embedding 86 | self.seg_embed = Embedding(n_segments, d_model) # segment(token type) embedding 87 | self.norm = nn.LayerNorm([d_model,]) 88 | 89 | def construct(self, x, seg): 90 | seq_len = x.shape[1] 91 | pos = ops.arange(seq_len, dtype=mindspore.int64) 92 | pos = pos.expand_dims(0).expand_as(x) # (seq_len,) -> (batch_size, seq_len) 93 | embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg) 94 | return self.norm(embedding) 95 | 96 | 97 | # In[5]: 98 | 99 | 100 | class ScaledDotProductAttention(nn.Cell): 101 | def __init__(self): 102 | super(ScaledDotProductAttention, self).__init__() 103 | self.softmax = nn.Softmax(axis=-1) 104 | 105 | def construct(self, Q, K, V, attn_mask): 106 | scores = ops.matmul(Q, K.swapaxes(-1, -2)) / ops.sqrt(ops.scalar_to_tensor(d_k)) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)] 107 | scores = scores.masked_fill(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one. 
108 | attn = self.softmax(scores) 109 | context = ops.matmul(attn, V) 110 | return context, attn 111 | 112 | 113 | # In[6]: 114 | 115 | 116 | class MultiHeadAttention(nn.Cell): 117 | def __init__(self): 118 | super(MultiHeadAttention, self).__init__() 119 | self.W_Q = Dense(d_model, d_k * n_heads) 120 | self.W_K = Dense(d_model, d_k * n_heads) 121 | self.W_V = Dense(d_model, d_v * n_heads) 122 | self.attn = ScaledDotProductAttention() 123 | self.out_fc = Dense(n_heads * d_v, d_model) 124 | self.norm = nn.LayerNorm([d_model,]) 125 | 126 | def construct(self, Q, K, V, attn_mask): 127 | # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model] 128 | residual, batch_size = Q, Q.shape[0] 129 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 130 | q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).swapaxes(1,2) # q_s: [batch_size x n_heads x len_q x d_k] 131 | k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).swapaxes(1,2) # k_s: [batch_size x n_heads x len_k x d_k] 132 | v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).swapaxes(1,2) # v_s: [batch_size x n_heads x len_k x d_v] 133 | 134 | attn_mask = attn_mask.expand_dims(1) 135 | attn_mask = ops.tile(attn_mask, (1, n_heads, 1, 1)) 136 | 137 | # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)] 138 | context, attn = self.attn(q_s, k_s, v_s, attn_mask) 139 | context = context.swapaxes(1, 2).view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v] 140 | output = self.out_fc(context) 141 | return self.norm(output + residual), attn # output: [batch_size x len_q x d_model] 142 | 143 | 144 | # In[7]: 145 | 146 | 147 | class PoswiseFeedForwardNet(nn.Cell): 148 | def __init__(self): 149 | super(PoswiseFeedForwardNet, self).__init__() 150 | self.fc1 = Dense(d_model, d_ff) 151 | self.fc2 = Dense(d_ff, d_model) 152 | self.activation = nn.GELU(False) 153 | 154 | def 
construct(self, x): 155 | # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model) 156 | return self.fc2(self.activation(self.fc1(x))) 157 | 158 | 159 | # In[8]: 160 | 161 | 162 | class EncoderLayer(nn.Cell): 163 | def __init__(self): 164 | super(EncoderLayer, self).__init__() 165 | self.enc_self_attn = MultiHeadAttention() 166 | self.pos_ffn = PoswiseFeedForwardNet() 167 | 168 | def construct(self, enc_inputs, enc_self_attn_mask): 169 | enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V 170 | enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model] 171 | return enc_outputs, attn 172 | 173 | 174 | # In[9]: 175 | 176 | 177 | class BERT(nn.Cell): 178 | def __init__(self): 179 | super(BERT, self).__init__() 180 | self.embedding = BertEmbedding() 181 | self.layers = nn.CellList([EncoderLayer() for _ in range(n_layers)]) 182 | self.fc = Dense(d_model, d_model) 183 | self.activ1 = nn.Tanh() 184 | self.linear = Dense(d_model, d_model) 185 | self.activ2 = nn.GELU(False) 186 | self.norm = nn.LayerNorm([d_model,]) 187 | self.classifier = Dense(d_model, 2) 188 | # decoder is shared with embedding layer 189 | embed_weight = self.embedding.tok_embed.embedding_table 190 | n_vocab, n_dim = embed_weight.shape 191 | self.decoder = Dense(n_dim, n_vocab, has_bias=False) 192 | self.decoder.weight = embed_weight 193 | self.decoder_bias = mindspore.Parameter(ops.zeros(n_vocab), 'decoder_bias') 194 | 195 | def construct(self, input_ids, segment_ids, masked_pos): 196 | output = self.embedding(input_ids, segment_ids) 197 | enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids) 198 | for layer in self.layers: 199 | output, enc_self_attn = layer(output, enc_self_attn_mask) 200 | # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model] 201 | # it will be decided by first token(CLS) 202 | h_pooled = 
self.activ1(self.fc(output[:, 0])) # [batch_size, d_model] 203 | logits_clsf = self.classifier(h_pooled) # [batch_size, 2] 204 | 205 | masked_pos = ops.tile(masked_pos[:, :, None], (1, 1, output.shape[-1])) # [batch_size, max_pred, d_model] 206 | # get masked position from final output of transformer. 207 | h_masked = ops.gather_d(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model] 208 | h_masked = self.norm(self.activ2(self.linear(h_masked))) 209 | logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab] 210 | 211 | return logits_lm, logits_clsf 212 | 213 | 214 | # In[10]: 215 | 216 | 217 | # BERT Parameters 218 | maxlen = 30 # maximum of length 219 | batch_size = 6 220 | max_pred = 5 # max tokens of prediction 221 | n_layers = 6 # number of Encoder of Encoder Layer 222 | n_heads = 12 # number of heads in Multi-Head Attention 223 | d_model = 768 # Embedding Size 224 | d_ff = 768 * 4 # 4*d_model, FeedForward dimension 225 | d_k = d_v = 64 # dimension of K(=Q), V 226 | n_segments = 2 227 | 228 | 229 | # In[11]: 230 | 231 | 232 | text = ( 233 | 'Hello, how are you? I am Romeo.\n' 234 | 'Hello, Romeo My name is Juliet. Nice to meet you.\n' 235 | 'Nice meet you too. How are you today?\n' 236 | 'Great. My baseball team won the competition.\n' 237 | 'Oh Congratulations, Juliet\n' 238 | 'Thanks you Romeo' 239 | ) 240 | sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!' 
241 | word_list = list(set(" ".join(sentences).split())) 242 | word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3} 243 | for i, w in enumerate(word_list): 244 | word_dict[w] = i + 4 245 | number_dict = {i: w for i, w in enumerate(word_dict)} 246 | vocab_size = len(word_dict) 247 | 248 | token_list = list() 249 | for sentence in sentences: 250 | arr = [word_dict[s] for s in sentence.split()] 251 | token_list.append(arr) 252 | 253 | 254 | # In[12]: 255 | 256 | 257 | model = BERT() 258 | criterion = nn.CrossEntropyLoss() 259 | optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001) 260 | 261 | 262 | # In[13]: 263 | 264 | 265 | def forward(input_ids, segment_ids, masked_pos, masked_tokens, isNext): 266 | logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos) 267 | loss_lm = criterion(logits_lm.swapaxes(1, 2), masked_tokens.astype(mindspore.int32)) 268 | loss_lm = loss_lm.mean() 269 | loss_clsf = criterion(logits_clsf, isNext.astype(mindspore.int32)) 270 | 271 | return loss_lm + loss_clsf 272 | 273 | 274 | # In[14]: 275 | 276 | 277 | grad_fn = ops.value_and_grad(forward, None, optimizer.parameters) 278 | 279 | 280 | # In[15]: 281 | 282 | 283 | @mindspore.jit 284 | def train_step(input_ids, segment_ids, masked_pos, masked_tokens, isNext): 285 | loss, grads = grad_fn(input_ids, segment_ids, masked_pos, masked_tokens, isNext) 286 | optimizer(grads) 287 | return loss 288 | 289 | 290 | # In[16]: 291 | 292 | 293 | batch = make_batch() 294 | input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(mindspore.Tensor, zip(*batch)) 295 | 296 | model.set_train() 297 | for epoch in range(100): 298 | loss = train_step(input_ids, segment_ids, masked_pos, masked_tokens, isNext) # for sentence classification 299 | if (epoch + 1) % 10 == 0: 300 | print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy())) 301 | 302 | 303 | # In[ ]: 304 | 305 | 306 | # Predict mask tokens ans isNext 307 | input_ids, segment_ids, masked_tokens, 
masked_pos, isNext = map(mindspore.Tensor, zip(batch[0])) 308 | print(text) 309 | print([number_dict[int(w.asnumpy())] for w in input_ids[0] if number_dict[int(w.asnumpy())] != '[PAD]']) 310 | 311 | logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos) 312 | logits_lm = logits_lm.argmax(2)[0].asnumpy() 313 | print('masked tokens list : ',[pos for pos in masked_tokens[0] if pos != 0]) 314 | print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0]) 315 | 316 | logits_clsf = logits_clsf.argmax(1).asnumpy()[0] 317 | print('isNext : ', True if isNext else False) 318 | print('predict isNext : ',True if logits_clsf else False) 319 | 320 | -------------------------------------------------------------------------------- /5-1.Transformer/Transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import mindspore 8 | import numpy as np 9 | import mindspore.nn as nn 10 | import mindspore.ops as ops 11 | from mindspore import Tensor 12 | import matplotlib.pyplot as plt 13 | from layers import Dense, Embedding, Conv1d 14 | # S: Symbol that shows starting of decoding input 15 | # E: Symbol that shows starting of decoding output 16 | # P: Symbol that will fill in blank sequence if current batch data size is short than time steps 17 | 18 | 19 | # In[2]: 20 | 21 | 22 | def make_batch(sentences, src_vocab, tgt_vocab): 23 | input_batch = [[src_vocab[n] for n in sentences[0].split()]] 24 | output_batch = [[tgt_vocab[n] for n in sentences[1].split()]] 25 | target_batch = [[tgt_vocab[n] for n in sentences[2].split()]] 26 | return Tensor(input_batch, mindspore.int32), Tensor(output_batch, mindspore.int32), Tensor(target_batch, mindspore.int32) 27 | 28 | 29 | # In[3]: 30 | 31 | 32 | def get_sinusoid_encoding_table(n_position, d_model): 33 | def cal_angle(position, hid_idx): 34 | return position / np.power(10000, 2 * (hid_idx // 2) / d_model) 35 | 
def get_posi_angle_vec(position): 36 | return [cal_angle(position, hid_j) for hid_j in range(d_model)] 37 | 38 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 39 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 40 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 41 | return Tensor(sinusoid_table, mindspore.float32) 42 | 43 | 44 | # In[4]: 45 | 46 | 47 | def get_attn_pad_mask(seq_q, seq_k): 48 | batch_size, len_q = seq_q.shape 49 | batch_size, len_k = seq_k.shape 50 | 51 | pad_attn_mask = seq_k.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking 52 | return pad_attn_mask.broadcast_to((batch_size, len_q, len_k)) # batch_size x len_q x len_k 53 | 54 | 55 | # In[5]: 56 | 57 | 58 | def get_attn_subsequent_mask(seq): 59 | attn_shape = [seq.shape[0], seq.shape[1], seq.shape[1]] 60 | subsequent_mask = np.triu(np.ones(attn_shape), k=1) 61 | subsequent_mask = Tensor.from_numpy(subsequent_mask).to(mindspore.uint8) 62 | return subsequent_mask 63 | 64 | 65 | # In[6]: 66 | 67 | 68 | class ScaledDotProductAttention(nn.Cell): 69 | def __init__(self, d_k): 70 | super().__init__() 71 | self.softmax = nn.Softmax(axis=-1) 72 | self.d_k = Tensor(d_k, mindspore.float32) 73 | 74 | def construct(self, Q, K, V, attn_mask): 75 | scores = ops.matmul(Q, K.swapaxes(-1, -2)) / ops.sqrt(self.d_k)# scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)] 76 | scores = scores.masked_fill(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one. 
77 | attn = ops.softmax(scores) 78 | context = ops.matmul(attn, V) 79 | return context, attn 80 | 81 | 82 | # In[7]: 83 | 84 | 85 | class MultiHeadAttention(nn.Cell): 86 | def __init__(self, d_model, d_k, d_v, n_heads): 87 | super().__init__() 88 | self.d_k = d_k 89 | self.d_v = d_v 90 | self.n_heads = n_heads 91 | self.W_Q = Dense(d_model, d_k * n_heads) 92 | self.W_K = Dense(d_model, d_k * n_heads) 93 | self.W_V = Dense(d_model, d_v * n_heads) 94 | self.linear = Dense(n_heads * d_v, d_model) 95 | self.layer_norm = nn.LayerNorm((d_model, ), epsilon=1e-5) 96 | self.attention = ScaledDotProductAttention(d_k) 97 | 98 | def construct(self, Q, K, V, attn_mask): 99 | # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model] 100 | residual, batch_size = Q, Q.shape[0] 101 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 102 | q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).swapaxes(1,2) # q_s: [batch_size x n_heads x len_q x d_k] 103 | k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).swapaxes(1,2) # k_s: [batch_size x n_heads x len_k x d_k] 104 | v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).swapaxes(1,2) # v_s: [batch_size x n_heads x len_k x d_v] 105 | 106 | attn_mask = attn_mask.unsqueeze(1).tile((1, n_heads, 1, 1)) # attn_mask : [batch_size x n_heads x len_q x len_k] 107 | 108 | # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)] 109 | context, attn = self.attention(q_s, k_s, v_s, attn_mask) 110 | context = context.swapaxes(1, 2).view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v] 111 | output = self.linear(context) 112 | return self.layer_norm(output + residual), attn # output: [batch_size x len_q x d_model] 113 | 114 | 115 | # In[8]: 116 | 117 | 118 | class PoswiseFeedForward(nn.Cell): 119 | def __init__(self, d_ff, d_model): 120 | super().__init__() 121 | 
self.conv1 = Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 122 | self.conv2 = Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 123 | self.layer_norm = nn.LayerNorm((d_model, ), epsilon=1e-5) 124 | self.relu = nn.ReLU() 125 | 126 | def construct(self, inputs): 127 | residual = inputs # inputs : [batch_size, len_q, d_model] 128 | output = self.relu(self.conv1(inputs.swapaxes(1, 2))) 129 | output = self.conv2(output).swapaxes(1, 2) 130 | return self.layer_norm(output + residual) 131 | 132 | 133 | # In[9]: 134 | 135 | 136 | class EncoderLayer(nn.Cell): 137 | def __init__(self, d_model, d_k, d_v, n_heads, d_ff): 138 | super().__init__() 139 | self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads) 140 | self.pos_ffn = PoswiseFeedForward(d_ff, d_model) 141 | 142 | def construct(self, enc_inputs, enc_self_attn_mask): 143 | enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V 144 | enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model] 145 | return enc_outputs, attn 146 | 147 | 148 | # In[10]: 149 | 150 | 151 | class DecoderLayer(nn.Cell): 152 | def __init__(self, d_model, d_k, d_v, n_heads, d_ff): 153 | super().__init__() 154 | self.dec_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads) 155 | self.dec_enc_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads) 156 | self.pos_ffn = PoswiseFeedForward(d_ff, d_model) 157 | 158 | def construct(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): 159 | dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask) 160 | dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) 161 | dec_outputs = self.pos_ffn(dec_outputs) 162 | return dec_outputs, dec_self_attn, dec_enc_attn 163 | 164 | 165 | # In[11]: 166 | 167 | 168 | class Encoder(nn.Cell): 169 | def __init__(self, 
src_vocab_size, d_model, d_k, d_v, n_heads, d_ff, n_layers, src_len): 170 | super().__init__() 171 | self.src_emb = Embedding(src_vocab_size, d_model) 172 | self.pos_emb = Embedding.from_pretrained_embedding(get_sinusoid_encoding_table(src_len+1, d_model), freeze=True) 173 | self.layers = nn.CellList([EncoderLayer(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)]) 174 | # temp positional indexes 175 | self.pos = Tensor([[1, 2, 3, 4, 0]]) 176 | 177 | def construct(self, enc_inputs): 178 | # enc_inputs : [batch_size x source_len] 179 | enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(self.pos) 180 | enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) 181 | enc_self_attns = [] 182 | for layer in self.layers: 183 | enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask) 184 | enc_self_attns.append(enc_self_attn) 185 | return enc_outputs, enc_self_attns 186 | 187 | 188 | # In[12]: 189 | 190 | 191 | class Decoder(nn.Cell): 192 | def __init__(self, tgt_vocab_size, d_model, d_k, d_v, n_heads, d_ff, n_layers, tgt_len): 193 | super().__init__() 194 | self.tgt_emb = Embedding(tgt_vocab_size, d_model) 195 | self.pos_emb = Embedding.from_pretrained_embedding(get_sinusoid_encoding_table(tgt_len+1, d_model), freeze=True) 196 | self.layers = nn.CellList([DecoderLayer(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)]) 197 | 198 | def construct(self, dec_inputs, enc_inputs, enc_outputs): 199 | # dec_inputs : [batch_size x target_len] 200 | dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(Tensor([[5,1,2,3,4]])) 201 | dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) 202 | dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs) 203 | dec_self_attn_mask = ops.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) 204 | 205 | dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) 206 | 207 | dec_self_attns, dec_enc_attns = [], [] 208 | for layer in self.layers: 209 | dec_outputs, 
dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask) 210 | dec_self_attns.append(dec_self_attn) 211 | dec_enc_attns.append(dec_enc_attn) 212 | return dec_outputs, dec_self_attns, dec_enc_attns 213 | 214 | 215 | # In[13]: 216 | 217 | 218 | class Transformer(nn.Cell): 219 | def __init__(self, d_model, d_k, d_v, n_heads, d_ff, n_layers, src_vocab_size, tgt_vocab_size, src_len, tgt_len): 220 | super(Transformer, self).__init__() 221 | self.encoder = Encoder(src_vocab_size, d_model, d_k, d_v, n_heads, d_ff, n_layers, src_len) 222 | self.decoder = Decoder(tgt_vocab_size, d_model, d_k, d_v, n_heads, d_ff, n_layers, tgt_len) 223 | self.projection = Dense(d_model, tgt_vocab_size, has_bias=False) 224 | 225 | def construct(self, enc_inputs, dec_inputs): 226 | enc_outputs, enc_self_attns = self.encoder(enc_inputs) 227 | dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs) 228 | dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size] 229 | return dec_logits.view((-1, dec_logits.shape[-1])), enc_self_attns, dec_self_attns, dec_enc_attns 230 | 231 | 232 | # In[14]: 233 | 234 | 235 | sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E'] 236 | 237 | # Transformer Parameters 238 | # Padding Should be Zero 239 | src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4} 240 | src_vocab_size = len(src_vocab) 241 | 242 | tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6} 243 | number_dict = {i: w for i, w in enumerate(tgt_vocab)} 244 | tgt_vocab_size = len(tgt_vocab) 245 | 246 | src_len = 6 # length of source 247 | tgt_len = 5 # length of target 248 | 249 | d_model = 512 # Embedding Size 250 | d_ff = 2048 # FeedForward dimension 251 | d_k = d_v = 64 # dimension of K(=Q), V 252 | n_layers = 6 # number of Encoder of Decoder Layer 253 | n_heads = 8 # number of heads in Multi-Head Attention 254 | 255 | 
256 | # In[15]: 257 | 258 | 259 | model = Transformer(d_model, d_k, d_v, n_heads, d_ff, n_layers, src_vocab_size, tgt_vocab_size, src_len, tgt_len) 260 | 261 | 262 | # In[16]: 263 | 264 | 265 | criterion = nn.CrossEntropyLoss() 266 | optimizer = nn.Adam(model.trainable_params(), learning_rate=0.0001) 267 | # print(model.trainable_params()) 268 | enc_inputs, dec_inputs, target_batch = make_batch(sentences, src_vocab, tgt_vocab) 269 | 270 | 271 | # In[17]: 272 | 273 | 274 | def forward(enc_inputs, dec_inputs, target_batch): 275 | outputs, _, _, _, = model(enc_inputs, dec_inputs) 276 | loss = criterion(outputs, target_batch) 277 | 278 | return loss 279 | 280 | 281 | # In[18]: 282 | 283 | 284 | grad_fn = ops.value_and_grad(forward, None, optimizer.parameters) 285 | 286 | 287 | # In[19]: 288 | 289 | 290 | @mindspore.jit 291 | def train_step(enc_inputs, dec_inputs, target_batch): 292 | loss, grads = grad_fn(enc_inputs, dec_inputs, target_batch) 293 | optimizer(grads) 294 | return loss 295 | 296 | 297 | # In[20]: 298 | 299 | 300 | model.set_train() 301 | 302 | # Training 303 | for epoch in range(20): 304 | # hidden : [num_layers * num_directions, batch, hidden_size] 305 | loss = train_step(enc_inputs, dec_inputs, target_batch.view(-1)) 306 | print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy())) 307 | 308 | 309 | # In[21]: 310 | 311 | 312 | # Test 313 | predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs) 314 | predict = predict.asnumpy().argmax(1) 315 | print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()]) 316 | 317 | 318 | # In[22]: 319 | 320 | 321 | def showgraph(attn): 322 | attn = attn[-1].squeeze(0)[0] 323 | attn = attn.asnumpy() 324 | fig = plt.figure(figsize=(n_heads, n_heads)) # [n_heads, n_heads] 325 | ax = fig.add_subplot(1, 1, 1) 326 | ax.matshow(attn, cmap='viridis') 327 | ax.set_xticklabels(['']+sentences[0].split(), fontdict={'fontsize': 14}, rotation=90) 328 | 
ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14}) 329 | plt.show() 330 | 331 | 332 | # In[ ]: 333 | 334 | 335 | print('first head of last state enc_self_attns') 336 | showgraph(enc_self_attns) 337 | 338 | print('first head of last state dec_self_attns') 339 | showgraph(dec_self_attns) 340 | 341 | print('first head of last state dec_enc_attns') 342 | showgraph(dec_enc_attns) 343 | 344 | -------------------------------------------------------------------------------- /5-2.BERT/BERT_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# code by Tae Hwan Jung(Jeff Jung) @graykode\n", 10 | "# Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch\n", 11 | "# https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert\n", 12 | "import math\n", 13 | "import re\n", 14 | "from random import *\n", 15 | "import numpy as np\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.optim as optim\n", 19 | "\n", 20 | "# sample IsNext and NotNext to be same in small batch size\n", 21 | "def make_batch():\n", 22 | " batch = []\n", 23 | " positive = negative = 0\n", 24 | " while positive != batch_size/2 or negative != batch_size/2:\n", 25 | " tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences\n", 26 | " tokens_a, tokens_b= token_list[tokens_a_index], token_list[tokens_b_index]\n", 27 | " input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]\n", 28 | " segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)\n", 29 | "\n", 30 | " # MASK LM\n", 31 | " n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence\n", 32 | " cand_maked_pos = [i for i, 
token in enumerate(input_ids)\n", 33 | " if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]\n", 34 | " shuffle(cand_maked_pos)\n", 35 | " masked_tokens, masked_pos = [], []\n", 36 | " for pos in cand_maked_pos[:n_pred]:\n", 37 | " masked_pos.append(pos)\n", 38 | " masked_tokens.append(input_ids[pos])\n", 39 | " if random() < 0.8: # 80%\n", 40 | " input_ids[pos] = word_dict['[MASK]'] # make mask\n", 41 | " elif random() < 0.5: # 10%\n", 42 | " index = randint(0, vocab_size - 1) # random index in vocabulary\n", 43 | " input_ids[pos] = word_dict[number_dict[index]] # replace\n", 44 | "\n", 45 | " # Zero Paddings\n", 46 | " n_pad = maxlen - len(input_ids)\n", 47 | " input_ids.extend([0] * n_pad)\n", 48 | " segment_ids.extend([0] * n_pad)\n", 49 | "\n", 50 | " # Zero Padding (100% - 15%) tokens\n", 51 | " if max_pred > n_pred:\n", 52 | " n_pad = max_pred - n_pred\n", 53 | " masked_tokens.extend([0] * n_pad)\n", 54 | " masked_pos.extend([0] * n_pad)\n", 55 | "\n", 56 | " if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:\n", 57 | " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext\n", 58 | " positive += 1\n", 59 | " elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:\n", 60 | " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext\n", 61 | " negative += 1\n", 62 | " return batch\n", 63 | "# Proprecessing Finished\n", 64 | "\n", 65 | "def get_attn_pad_mask(seq_q, seq_k):\n", 66 | " batch_size, len_q = seq_q.size()\n", 67 | " batch_size, len_k = seq_k.size()\n", 68 | " # eq(zero) is PAD token\n", 69 | " pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking\n", 70 | " return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k\n", 71 | "\n", 72 | "def gelu(x):\n", 73 | " \"Implementation of the gelu activation function by Hugging Face\"\n", 74 | " return x * 0.5 * (1.0 + torch.erf(x / 
math.sqrt(2.0)))\n", 75 | "\n", 76 | "class Embedding(nn.Module):\n", 77 | " def __init__(self):\n", 78 | " super(Embedding, self).__init__()\n", 79 | " self.tok_embed = nn.Embedding(vocab_size, d_model) # token embedding\n", 80 | " self.pos_embed = nn.Embedding(maxlen, d_model) # position embedding\n", 81 | " self.seg_embed = nn.Embedding(n_segments, d_model) # segment(token type) embedding\n", 82 | " self.norm = nn.LayerNorm(d_model)\n", 83 | "\n", 84 | " def forward(self, x, seg):\n", 85 | " seq_len = x.size(1)\n", 86 | " pos = torch.arange(seq_len, dtype=torch.long)\n", 87 | " pos = pos.unsqueeze(0).expand_as(x) # (seq_len,) -> (batch_size, seq_len)\n", 88 | " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n", 89 | " return self.norm(embedding)\n", 90 | "\n", 91 | "class ScaledDotProductAttention(nn.Module):\n", 92 | " def __init__(self):\n", 93 | " super(ScaledDotProductAttention, self).__init__()\n", 94 | "\n", 95 | " def forward(self, Q, K, V, attn_mask):\n", 96 | " scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n", 97 | " scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.\n", 98 | " attn = nn.Softmax(dim=-1)(scores)\n", 99 | " context = torch.matmul(attn, V)\n", 100 | " return context, attn\n", 101 | "\n", 102 | "class MultiHeadAttention(nn.Module):\n", 103 | " def __init__(self):\n", 104 | " super(MultiHeadAttention, self).__init__()\n", 105 | " self.W_Q = nn.Linear(d_model, d_k * n_heads)\n", 106 | " self.W_K = nn.Linear(d_model, d_k * n_heads)\n", 107 | " self.W_V = nn.Linear(d_model, d_v * n_heads)\n", 108 | " def forward(self, Q, K, V, attn_mask):\n", 109 | " # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]\n", 110 | " residual, batch_size = Q, Q.size(0)\n", 111 | " # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n", 
112 | " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # q_s: [batch_size x n_heads x len_q x d_k]\n", 113 | " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # k_s: [batch_size x n_heads x len_k x d_k]\n", 114 | " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # v_s: [batch_size x n_heads x len_k x d_v]\n", 115 | "\n", 116 | " attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]\n", 117 | "\n", 118 | " # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n", 119 | " context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)\n", 120 | " context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]\n", 121 | " output = nn.Linear(n_heads * d_v, d_model)(context)\n", 122 | " return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]\n", 123 | "\n", 124 | "class PoswiseFeedForwardNet(nn.Module):\n", 125 | " def __init__(self):\n", 126 | " super(PoswiseFeedForwardNet, self).__init__()\n", 127 | " self.fc1 = nn.Linear(d_model, d_ff)\n", 128 | " self.fc2 = nn.Linear(d_ff, d_model)\n", 129 | "\n", 130 | " def forward(self, x):\n", 131 | " # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)\n", 132 | " return self.fc2(gelu(self.fc1(x)))\n", 133 | "\n", 134 | "class EncoderLayer(nn.Module):\n", 135 | " def __init__(self):\n", 136 | " super(EncoderLayer, self).__init__()\n", 137 | " self.enc_self_attn = MultiHeadAttention()\n", 138 | " self.pos_ffn = PoswiseFeedForwardNet()\n", 139 | "\n", 140 | " def forward(self, enc_inputs, enc_self_attn_mask):\n", 141 | " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V\n", 142 | " enc_outputs = self.pos_ffn(enc_outputs) # 
enc_outputs: [batch_size x len_q x d_model]\n", 143 | " return enc_outputs, attn\n", 144 | "\n", 145 | "class BERT(nn.Module):\n", 146 | " def __init__(self):\n", 147 | " super(BERT, self).__init__()\n", 148 | " self.embedding = Embedding()\n", 149 | " self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n", 150 | " self.fc = nn.Linear(d_model, d_model)\n", 151 | " self.activ1 = nn.Tanh()\n", 152 | " self.linear = nn.Linear(d_model, d_model)\n", 153 | " self.activ2 = gelu\n", 154 | " self.norm = nn.LayerNorm(d_model)\n", 155 | " self.classifier = nn.Linear(d_model, 2)\n", 156 | " # decoder is shared with embedding layer\n", 157 | " embed_weight = self.embedding.tok_embed.weight\n", 158 | " n_vocab, n_dim = embed_weight.size()\n", 159 | " self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n", 160 | " self.decoder.weight = embed_weight\n", 161 | " self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n", 162 | "\n", 163 | " def forward(self, input_ids, segment_ids, masked_pos):\n", 164 | " output = self.embedding(input_ids, segment_ids)\n", 165 | " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n", 166 | " for layer in self.layers:\n", 167 | " output, enc_self_attn = layer(output, enc_self_attn_mask)\n", 168 | " # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]\n", 169 | " # it will be decided by first token(CLS)\n", 170 | " h_pooled = self.activ1(self.fc(output[:, 0])) # [batch_size, d_model]\n", 171 | " logits_clsf = self.classifier(h_pooled) # [batch_size, 2]\n", 172 | "\n", 173 | " masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]\n", 174 | " # get masked position from final output of transformer.\n", 175 | " h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]\n", 176 | " h_masked = self.norm(self.activ2(self.linear(h_masked)))\n", 177 | " logits_lm = self.decoder(h_masked) + self.decoder_bias # 
[batch_size, max_pred, n_vocab]\n", 178 | "\n", 179 | " return logits_lm, logits_clsf\n", 180 | "\n", 181 | "if __name__ == '__main__':\n", 182 | " # BERT Parameters\n", 183 | " maxlen = 30 # maximum of length\n", 184 | " batch_size = 6\n", 185 | " max_pred = 5 # max tokens of prediction\n", 186 | " n_layers = 6 # number of Encoder of Encoder Layer\n", 187 | " n_heads = 12 # number of heads in Multi-Head Attention\n", 188 | " d_model = 768 # Embedding Size\n", 189 | " d_ff = 768 * 4 # 4*d_model, FeedForward dimension\n", 190 | " d_k = d_v = 64 # dimension of K(=Q), V\n", 191 | " n_segments = 2\n", 192 | "\n", 193 | " text = (\n", 194 | " 'Hello, how are you? I am Romeo.\\n'\n", 195 | " 'Hello, Romeo My name is Juliet. Nice to meet you.\\n'\n", 196 | " 'Nice meet you too. How are you today?\\n'\n", 197 | " 'Great. My baseball team won the competition.\\n'\n", 198 | " 'Oh Congratulations, Juliet\\n'\n", 199 | " 'Thanks you Romeo'\n", 200 | " )\n", 201 | " sentences = re.sub(\"[.,!?\\\\-]\", '', text.lower()).split('\\n') # filter '.', ',', '?', '!'\n", 202 | " word_list = list(set(\" \".join(sentences).split()))\n", 203 | " word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}\n", 204 | " for i, w in enumerate(word_list):\n", 205 | " word_dict[w] = i + 4\n", 206 | " number_dict = {i: w for i, w in enumerate(word_dict)}\n", 207 | " vocab_size = len(word_dict)\n", 208 | "\n", 209 | " token_list = list()\n", 210 | " for sentence in sentences:\n", 211 | " arr = [word_dict[s] for s in sentence.split()]\n", 212 | " token_list.append(arr)\n", 213 | "\n", 214 | " model = BERT()\n", 215 | " criterion = nn.CrossEntropyLoss()\n", 216 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 217 | "\n", 218 | " batch = make_batch()\n", 219 | " input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))\n", 220 | "\n", 221 | " for epoch in range(100):\n", 222 | " optimizer.zero_grad()\n", 223 | " logits_lm, logits_clsf = 
model(input_ids, segment_ids, masked_pos)\n", 224 | " loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM\n", 225 | " loss_lm = (loss_lm.float()).mean()\n", 226 | " loss_clsf = criterion(logits_clsf, isNext) # for sentence classification\n", 227 | " loss = loss_lm + loss_clsf\n", 228 | " if (epoch + 1) % 10 == 0:\n", 229 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n", 230 | " loss.backward()\n", 231 | " optimizer.step()\n", 232 | "\n", 233 | " # Predict mask tokens and isNext\n", 234 | " input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))\n", 235 | " print(text)\n", 236 | " print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])\n", 237 | "\n", 238 | " logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n", 239 | " logits_lm = logits_lm.data.max(2)[1][0].data.numpy()\n", 240 | " print('masked tokens list : ',[pos.item() for pos in masked_tokens[0] if pos.item() != 0])\n", 241 | " print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])\n", 242 | "\n", 243 | " logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]\n", 244 | " print('isNext : ', True if isNext else False)\n", 245 | " print('predict isNext : ',True if logits_clsf else False)\n" 246 | ] 247 | } 248 | ], 249 | "metadata": { 250 | "anaconda-cloud": {}, 251 | "kernelspec": { 252 | "display_name": "Python 3 (ipykernel)", 253 | "language": "python", 254 | "name": "python3" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 3 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython3", 266 | "version": "3.9.18" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 4 271 | } 272 | -------------------------------------------------------------------------------- 
/1-2.Word2Vec/Word2Vec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import mindspore\n", 11 | "import mindspore.nn as nn\n", 12 | "import mindspore.ops as ops\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from mindspore import Tensor, ms_function" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "def random_batch():\n", 24 | " random_inputs = []\n", 25 | " random_labels = []\n", 26 | " random_index = np.random.choice(range(len(skip_grams)), batch_size, replace=False)\n", 27 | "\n", 28 | " for i in random_index:\n", 29 | " random_inputs.append(np.eye(voc_size)[skip_grams[i][0]]) # target\n", 30 | " random_labels.append(skip_grams[i][1]) # context word\n", 31 | "\n", 32 | " return random_inputs, random_labels" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "class Word2Vec(nn.Cell):\n", 42 | " def __init__(self, voc_size, embed_size):\n", 43 | " super(Word2Vec, self).__init__()\n", 44 | " # W and WT are separate weights (WT is not the transpose of W)\n", 45 | " self.W = nn.Dense(voc_size, embed_size, has_bias=False) # voc_size > embedding_size Weight\n", 46 | " self.WT = nn.Dense(embed_size, voc_size, has_bias=False) # embedding_size > voc_size Weight\n", 47 | " \n", 48 | " def construct(self, X):\n", 49 | " # X : [batch_size, voc_size]\n", 50 | " hidden_layer = self.W(X) # hidden_layer : [batch_size, embedding_size]\n", 51 | " output_layer = self.WT(hidden_layer) # output_layer : [batch_size, voc_size]\n", 52 | " return output_layer" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "batch_size = 2 # mini-batch 
size\n", 62 | "embed_size = 2 # embedding size\n", 63 | "\n", 64 | "sentences = [\"apple banana fruit\", \"banana orange fruit\", \"orange banana fruit\",\n", 65 | " \"dog cat animal\", \"cat monkey animal\", \"monkey dog animal\"]\n", 66 | "\n", 67 | "word_sequence = \" \".join(sentences).split()\n", 68 | "word_list = \" \".join(sentences).split()\n", 69 | "word_list = list(set(word_list))\n", 70 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 71 | "voc_size = len(word_list)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Make skip gram of one size window\n", 81 | "skip_grams = []\n", 82 | "for i in range(1, len(word_sequence) - 1):\n", 83 | " target = word_dict[word_sequence[i]]\n", 84 | " context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]\n", 85 | " for w in context:\n", 86 | " skip_grams.append([target, w])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "model = Word2Vec(voc_size, embed_size)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 7, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "criterion = nn.CrossEntropyLoss()\n", 105 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 8, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def forward(inputs, targets):\n", 115 | " logits = model(inputs)\n", 116 | " loss = criterion(logits, targets)\n", 117 | " return loss" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 9, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 10, 132 | "metadata": {}, 133 | 
"outputs": [], 134 | "source": [ 135 | "@ms_function\n", 136 | "def train_step(inputs, targets):\n", 137 | " loss, grads = grad_fn(inputs, targets)\n", 138 | " optimizer(grads)\n", 139 | " return loss" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 11, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Epoch: 1000 cost = 1.538420\n", 152 | "Epoch: 2000 cost = 1.268426\n", 153 | "Epoch: 3000 cost = 1.192565\n", 154 | "Epoch: 4000 cost = 1.204078\n", 155 | "Epoch: 5000 cost = 1.086928\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "model.set_train()\n", 161 | "\n", 162 | "epoch = 5000\n", 163 | "for step in range(epoch):\n", 164 | " input_batch, target_batch = random_batch()\n", 165 | " input_batch = Tensor(input_batch, mindspore.float32)\n", 166 | " target_batch = Tensor(target_batch, mindspore.int32)\n", 167 | " loss = train_step(input_batch, target_batch)\n", 168 | " if (step + 1) % 1000 == 0:\n", 169 | " print('Epoch:', '%04d' % (step + 1), 'cost = ', '{:.6f}'.format(loss.asnumpy()))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 12, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAe7klEQVR4nO3de3RU5b3/8fc3IQQFDVZAqFy1QCEXIASkcjGCVRArQvGuWD2IglS0rb/SxakGz7LLVls5VD0WDipaLFIUhUqLyqWCghI0plwr0silCOF+EZCQ7++PDDlccnUmMxP257XWrOz97Gee5zuzyIedfZkxd0dERM58CbEuQEREokOBLyISEAp8EZGAUOCLiASEAl9EJCAU+CIiARF24JtZCzNbaGarzWyVmY0po4+Z2UQzW29m+WaWGe68IiJSPXUiMEYR8FN3/9jMzgFWmNk77r76hD4DgLahxyXA/4R+iohIlIS9h+/uW93949DyfmANcOEp3QYBL3mJZUBDM2sW7twi8cbMWpvZyir0yzazS6NRk8hxkdjDL2VmrYEuwIenbLoQ2HTC+uZQ29aKxmvUqJG3bt06ghWK1Ky0tDTWr19PVlZWhbewN2vWjISEBLKysqJVmgTEihUrdrh747K2RSzwzawB8BrwgLvvC2OcEcAIgJYtW5KbmxuhCkVOVlBQQP/+/enRowcffPAB3bp148477+SRRx5h+/btTJs2je985zvcddddbNiwgbPPPptJkyaRkZFBTk4OGzduZMOGDWzcuJEHHniA+++/n9/+9reMGzeOo0ePctFFF/Hpp59y9tlnU1xczLZt27jwwgtJTk7m6NGjJCUlkZ+fz7x587j88svZt28fnTp14p///CdJSUmxfnukljKzL8rbFpHAN7MkSsJ+mru/XkaXLUCLE9abh9pO4+6TgElApXtJIuFav349f/7zn3n++efp1q0br7zyCkuWLGH27Nn86le/okWLFnTp0oU33niDBQsWMGzYMPLy8gBYu3YtCxcuZP/+/bRv354+ffrw9NNP07p1a2bMmMHQoUOZOnUqvXr14oorruDWW29l165d3HDDDdx0003cd999rFq1ir179wIwffp0hgwZorCXGhN24JuZAVOANe7+u3K6zQZGm9l0Sk7W7nX3Cg/niERDmzZtSE9PByA1NZV+/fphZqSnp1NQUMAXX3zBa6+9BkDfvn3ZuXMn+/aV/AE7cOBAkpOTSU5OpkmTJsyZM4err76amTNnMmjQIF5//XWOHTvGFVdcwcKFC3n//fepW7cu8+fPZ//+/QAMHz6c3/zmN1x33XW88MILTJ48OTZvhARCJK7D7wncDvQ1s7zQ42ozu9fM7g31mQtsANYDk4FREZhXJGzJycmlywkJCaXrCQkJFBUVVfm5iYmJHDt2DICUlBRatmzJkiVL+PGPf8zdd9/NBRdcwN/+9jcyMzPJy8tj9OjRAPTs2ZOCggIWLVrEsWPHSEtLi/RLFCkV9h6+uy8BrJI+DtwX7lwi0da7d2+mTZvGL3/5SxYtWkSjRo0499xzy+x76aWXcs8995CcnMysWbPo27cv27Zto23btrRp04bx48cD4O7s37+f4x9NPmzYMG655RZ++ctfRu11STDpTluRCuTk5LBixQoyMjIYO3YsU6dOLbdvu3btGD16NAUFBVx66aVcfPHF1KlTh4EDB7J//34KCgrIzc0lNTWVo0ePMmvWLDp37szFF1/M7t27ufnmm6P4yiSILJ6/ACUrK8t1lY6cKf754ZcsffNzDuw6QoNvJfO9QRfT7pKmzJw5kzfffJOXX3451iXKGcDMVrh7mdf7ag9fJAr++eGXLJy2lgO7jgBwYNcR5k/JY/C3v8tPhw3jgZ49Y1yhBIECXyQKlr75OUVfF5/UVpxQl4FX/hd/a9mK+s/9gb1z5sSoOgkKBb5IFBzfsz/VkeRvAeCHD7P9qQlRrEiCSIEvEgUNvpVcZnvykV2ly0V
bdWuK1CwFvkgUfG/QxdSpe/KvW8KxI1y8YXbpep1m+jxBqVkR/fA0ESlbu0uaAoSu0jlMvSO7uejzN2m6veQqNKtXjyYPPhDDCiUIFPgiUdLukqalwb93zhy2P/Vvisyo06wZTR58gJQf/CDGFcqZToEvEgMpP/iBAl6iTsfwRUQCQoEvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBocAXEQkIBb6ISEAo8EVEAkKBLyISEBEJfDN73sy2m9nKcrZnm9leM8sLPR6OxLwiIlJ1kfosnReBp4GXKuiz2N2vidB8IiJSTRHZw3f394BdlXYUEZGYieYx/O+Z2adm9lczSy2vk5mNMLNcM8stLCyMYnkiIme2aAX+x0Ard+8E/B54o7yO7j7J3bPcPatx48ZRKk9E5MwXlcB3933ufiC0PBdIMrNG0ZhbRERKRCXwzaypmVlouXto3p3RmFtEREpE5CodM/sTkA00MrPNwCNAEoC7PwcMBUaaWRFwCLjJ3T0Sc4uISNVEJPDd/eZKtj9NyWWbIiISI7rTVkQkIBT4IiIBocAXEQkIBb6ISEAo8EVEAkKBLyISEAp8EZGAUOCLiASEAl9EJCAU+CIiAaHAFxEJCAW+iEhAKPBFRAJCgS8iEhAKfBGRgFDgi4gEhAJfRCQgFPgiIgGhwBcRCYiIBL6ZPW9m281sZTnbzcwmmtl6M8s3s8xIzCsiIlUXqT38F4H+FWwfALQNPUYA/xOheUVEpIoiEvju/h6wq4Iug4CXvMQyoKGZNYvE3CIiUjXROoZ/IbDphPXNobbTmNkIM8s1s9zCwsKoFCciEgRxd9LW3Se5e5a7ZzVu3DjW5YiInDGiFfhbgBYnrDcPtYmISJREK/BnA8NCV+v0APa6+9YozS0iIkCdSAxiZn8CsoFGZrYZeARIAnD354C5wNXAeuAr4M5IzCsiIlUXkcB395sr2e7AfZGYS0REvpm4O2krIiI1Q4EvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBocAXEQkIBb6ISEAo8EVEAkKBLyISEAp8EZGAUOCLiASEAl9EJCAU+CIiAaHAFxEJCAW+iEhAKPBFRAJCgS8iEhAKfBGRgIhI4JtZfzNbZ2brzWxsGdt/ZGaFZpYXegyPxLwiIlJ1YX+JuZklAs8A3wc2A8vNbLa7rz6l66vuPjrc+URE5JuJxB5+d2C9u29w96+B6cCgCIwrIiIRFInAvxDYdML65lDbqX5oZvlmNtPMWkRgXhERqYZonbSdA7R29wzgHWBqeR3NbISZ5ZpZbmFhYZTKExE580Ui8LcAJ+6xNw+1lXL3ne5+JLT6v0DX8gZz90nunuXuWY0bN45AeSIiApEJ/OVAWzNrY2Z1gZuA2Sd2MLNmJ6xeC6yJwLwiIlINYV+l4+5FZjYamAckAs+7+yozexTIdffZwP1mdi1QBOwCfhTuvCIiUj3m7rGuoVxZWVmem5sb6zJERGoNM1vh7lllbdOdtiIiEfbcc8/x0ksvRWSs1q1bs2PHjoiMFfYhHREROdm9994b6xLKpD18EZEquO666+jatSupqalMmjQJgAYNGjBu3Dg6depEjx492LZtGwA5OTk8+eSTAGRnZ/Pggw+SlZVFhw4dWL58OUOGDKFt27b853/+Z4XjR5oCX0SkCp5//nlWrFhBbm4uEydOZOfOnRw8eJAePXrw6aef0qdPHyZPnlzmc+vWrUtubi733nsvgwYN4plnnmHlypW8+OKL7Ny5s9zxI02BLyJSBRMnTizdk9+0aROfffYZdevW5ZprrgGga9euFBQUlPnca6+9FoD09HRSU1Np1qwZycnJXHTRRWzatKnc8SNNx/BFRCqxaNEi3n33XZYuXcrZZ59NdnY2hw8fJikpCTMDIDExkaKiojKfn5ycDEBCQkLp8vH1oqKicsePNO3hi4hUYu/evZx33nmcffbZrF27lmXLltWq8Y9T4IuIVKJ///4UFRXRoUMHxo4
dS48ePWrV+MfpxisRkThx8JPt7JtXwLE9R0hsmMy5V7Wmfpcm1Rqj1t54tX37djp06MCtt95a5edcffXV7Nmzhz179vDss8/WYHUiIpFz8JPt7Hn9M47tKfmcyWN7jrDn9c84+Mn2iM0R13v49erV8/Xr19O8efPStqKiIurUqfxcc0FBAddccw0rV66syRJFRCJi6+MflYb9iRIbJtNsbPcqj1Mr9/DN7Lmvv/6aAQMGkJKSwu23307Pnj25/fbbefHFFxk9+v++LfGaa65h0aJFwP/dhjx27Fg+//xzOnfuzEMPPRSjVyGxcuq/EZF4V1bYV9T+TcTtZZnufm9ycvI9Cxcu5Omnn2bOnDksWbKEs846ixdffLHS5z/++OOsXLmSvLy8Gq9VRCRciQ2Ty93Dj5S43cM/1bXXXstZZ50V6zIkCsq7hf3BBx8kNTWVfv36cfzb0LKzsxkzZgydO3cmLS2Njz766LTxCgsL+eEPf0i3bt3o1q0b77//flRfj0hVnHtVayzp5Ei2pATOvap1xOaoNYFfv3790uU6depQXFxcul4TNyhI7JR3C3tWVharVq3isssuY/z48aX9v/rqK/Ly8nj22We56667ThtvzJgxPPjggyxfvpzXXnuN4cOHR/PliFRJ/S5NaDikbekefWLDZBoOaVvtq3QqEreHdCrSunVrnn32WYqLi9myZUuZe3XnnHMO+/fvj0F1Eq6JEycya9YsgNJbzBMSErjxxhsBuO222xgyZEhp/5tvvhmAPn36sG/fPvbs2XPSeO+++y6rV68uXd+3bx8HDhygQYMGNfxKak5OTg4NGjTgZz/7WaxLkQiq36VJRAP+VLUy8Hv27EmbNm3o2LEjHTp0IDMz87Q+559/Pj179iQtLY0BAwbwxBNPxKBSqa6q3mJ+/Hb2U5fLWi8uLmbZsmXUq1evZooWqSXi+pBOeno6jRo1Iicn56Q9GTNj2rRprF27llmzZrFo0SKys7MBeGbBM9yy6BYypmawY8gOfj371wr7WqS8W8yLi4uZOXMmAK+88gq9evUqfc6rr74KwJIlS0hJSSElJeWkMa+88kp+//vfl67X1hP5jz32GO3ataNXr16sW7cOKHktPXr0ICMjg8GDB7N7924Ali9fTkZGRulVamlpabEsXeJEXAd+db214S1yPshh68GtOM7Wg1vJ+SCHtza8FevSpIrKu8W8fv36fPTRR6SlpbFgwQIefvjh0ufUq1ePLl26cO+99zJlypTTxpw4cSK5ublkZGTQsWNHnnvuuai9nkhZsWIF06dPJy8vj7lz57J8+XIAhg0bxq9//Wvy8/NJT08vPbdx55138oc//IG8vDwSExNjWbrEkbi+8aq6H61w5cwr2Xpw62ntzeo34+2hb0eyNImyBg0acODAgdPas7OzefLJJ8nKKvM+E8ifAfMfhb2bIaU59HsYMm6o4Wojb8KECezatYtHH30UgJ/85CekpKQwZcoUNm7cCMDnn3/O9ddfz4IFC+jUqRNffPEFAPn5+dxyyy26CTEgavzGKzPrb2brzGy9mY0tY3uymb0a2v6hmbWOxLyn+vLgl9VqlzNc/gyYcz/s3QR4yc8595e0iwRQ2IFvZonAM8AAoCNws5l1PKXbfwC73f07wFPAr8OdtyxN6zetVrvUHmXt3UPJSd5y9+7nPwpHD53cdvRQSXst06dPH9544w0OHTrE/v37mTNnDvXr1+e8885j8eLFALz88stcdtllNGzYkHPOOYcPP/wQgOnTp8eydIkjkdjD7w6sd/cN7v41MB0YdEqfQcDU0PJMoJ+deilFBIzJHEO9xJOvxKiXWI8xmWMiPZXUBns3V689jmVmZnLjjTfSqVMnBgwYQLdu3QCYOnUqDz30EBkZGeTl5ZWe25gyZQp33303nTt35uDBg6edyJZgisRlmRcCm05Y3wxcUl4fdy8ys73A+cCOUwczsxHACICWLVtWq5CBFw0E4L8//m+
+PPglTes3ZUzmmNJ2CZiU5qHDOWW010Ljxo1j3Lhxp7WX9WUZqamp5OfnAyUfM1LuX0ESKHF3Hb67TwImQclJ2+o+f+BFAxXwUqLfwyXH7E88rJN0Vkn7GWzN4oX87r/GM3f5J5CQwHfatWfm7DmxLkviQCQCfwvQ4oT15qG2svpsNrM6QAoQ+a9kFznR8atxzoCrdKpqzeKFvD3pab57XgO+e2VvAOrUTWbH2pU0bnx5jKuTWItE4C8H2ppZG0qC/SbgllP6zAbuAJYCQ4EFHs/Xg8qZI+OGMzrgT7V4+ksUfX3yJy4WfX2ExdNfokNvBX7QhR34oWPyo4F5QCLwvLuvMrNHgVx3nw1MAV42s/XALkr+UxCRCNu/87TTYhW2S7BE5Bi+u88F5p7S9vAJy4eB6yMxl4iU75zzG7F/R2GZ7SJn1EcriARd75uGUafuyV+YUaduMr1vGhajiiSexN1VOiLyzR0/Tr94+kvs37mDc85vRO+bhun4vQAKfJEzTofelyvgpUw6pCMiEhAKfBGRgFDgi4gEhAJfRCQgFPgiIgGhwBcRCQgFvohIQCjwRUQCQoEvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBocAXEQkIBb6ISEAo8EVEAiKswDezb5nZO2b2WejneeX0O2ZmeaHH7HDmFBGRbybcPfyxwHx3bwvMD62X5ZC7dw49rg1zThER+QbCDfxBwNTQ8lTgujDHExGRGhJu4F/g7ltDy18CF5TTr56Z5ZrZMjO7rqIBzWxEqG9uYWFhmOWJiMhxlX6JuZm9CzQtY9O4E1fc3c3MyxmmlbtvMbOLgAVm9g93/7ysju4+CZgEkJWVVd54IiJSTZUGvrtfUd42M9tmZs3cfauZNQO2lzPGltDPDWa2COgClBn4IiJSM8I9pDMbuCO0fAfw5qkdzOw8M0sOLTcCegKrw5xXRESqKdzAfxz4vpl9BlwRWsfMsszsf0N9OgC5ZvYpsBB43N0V+CIiUVbpIZ2KuPtOoF8Z7bnA8NDyB0B6OPOIiEj4dKetiEhAKPBFRAJCgS8iEhAKfBGRgFDgi4gEhAJfRCQgFPgiIgGhwBcRCQgFvohIQCjwRUQCQoEvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBocAXEQkIBb6ISEAo8EVEAkKBLyISEGEFvpldb2arzKzYzLIq6NffzNaZ2XozGxvOnCIi8s2Eu4e/EhgCvFdeBzNLBJ4BBgAdgZvNrGOY84qISDXVCefJ7r4GwMwq6tYdWO/uG0J9pwODgNXhzC0iItUTjWP4FwKbTljfHGoTEZEoqnQP38zeBZqWsWmcu78Z6YLMbAQwAqBly5aRHl5EJLAqDXx3vyLMObYALU5Ybx5qK2++ScAkgKysLA9zbhERCYnGIZ3lQFsza2NmdYGbgNlRmFdERE4Q7mWZg81sM/A94C0zmxdq/7aZzQVw9yJgNDAPWAPMcPdV4ZUtIiLVFe5VOrOAWWW0/xu4+oT1ucDccOYSEZHw6E5bEZGAUOCLiASEAl9EJCAU+CIiAaHAFxEJCAW+iEhAKPBFRAJCgS8iEhAKfBGRgFDgi4gEhAJfRCQgFPgiIgGhwBcRCQgFvohIQCjwRUQCQoEvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBEVbgm9n1ZrbKzIrNLKuCfgVm9g8zyzOz3HDmlG+uoKCAtLS0WJchIjFSJ8znrwSGAH+oQt/L3X1HmPOJiMg3FNYevruvcfd1kSommtyd4uLiWJcRdUVFRdx666106NCBoUOH8tVXX/Hoo4/SrVs30tLSGDFiBO4OQHZ2Nj//+c/p3r077dq1Y/HixUDJXwq9e/cmMzOTzMxMPvjgAwAWLVpEdnY2Q4cO5bvf/S633npr6VjlzSEi0ROtY/gOvG1mK8xsREUdzWyEmeWaWW5hYWFYk/7ud78jLS2NtLQ0JkyYQEFBAe3bt2fYsGGkpaWxadMmRo4cSVZWFqm
pqTzyyCOlz23dujWPPPIImZmZpKens3btWgAKCwv5/ve/T2pqKsOHD6dVq1bs2FHyh8sf//hHunfvTufOnbnnnns4duxYWPXXhHXr1jFq1CjWrFnDueeey7PPPsvo0aNZvnw5K1eu5NChQ/zlL38p7V9UVMRHH33EhAkTGD9+PABNmjThnXfe4eOPP+bVV1/l/vvvL+3/ySefMGHCBFavXs2GDRt4//33ASqcQ0Sio9LAN7N3zWxlGY9B1Zinl7tnAgOA+8ysT3kd3X2Su2e5e1bjxo2rMcXJVqxYwQsvvMCHH37IsmXLmDx5Mrt37+azzz5j1KhRrFq1ilatWvHYY4+Rm5tLfn4+f//738nPzy8do1GjRnz88ceMHDmSJ598EoDx48fTt29fVq1axdChQ9m4cSMAa9as4dVXX+X9998nLy+PxMREpk2b9o3rryktWrSgZ8+eANx2220sWbKEhQsXcskll5Cens6CBQtYtWpVaf8hQ4YA0LVrVwoKCgA4evQod999N+np6Vx//fWsXr26tH/37t1p3rw5CQkJdO7cufQ5Fc0hItFR6TF8d78i3EncfUvo53YzmwV0B94Ld9yKLFmyhMGDB1O/fn2gJLgWL15Mq1at6NGjR2m/GTNmMGnSJIqKiti6dSurV68mIyOj9DlQEnavv/566bizZs0CoH///px33nkAzJ8/nxUrVtCtWzcADh06RJMmTWryJX4jZnba+qhRo8jNzaVFixbk5ORw+PDh0u3JyckAJCYmUlRUBMBTTz3FBRdcwKeffkpxcTH16tU7rf+Jzzl8+HCFc4hIdNT4IR0zq29m5xxfBq6k5GRvTBz/DwDgX//6F08++STz588nPz+fgQMHVhp25XF37rjjDvLy8sjLy2PdunXk5OTUyGsIx8aNG1m6dCkAr7zyCr169QJK/po5cOAAM2fOrHSMvXv30qxZMxISEnj55ZcrPXR1/D2tzhwiEnnhXpY52Mw2A98D3jKzeaH2b5vZ3FC3C4AlZvYp8BHwlrv/LZx5q6J379688cYbfPXVVxw8eJBZs2bRu3fvk/rs27eP+vXrk5KSwrZt2/jrX/9a6bg9e/ZkxowZALz99tvs3r0bgH79+jFz5ky2b98OwK5du/jiiy8i/KrC1759e5555hk6dOjA7t27GTlyJHfffTdpaWlcddVVpX+hVGTUqFFMnTqVTp06sXbt2pP+Ey1Lw4YNqz2HiNQAd4/bR9euXT0cv/3tbz01NdVTU1P9qaee8n/961+empp6Up877rjD27Zt63379vXBgwf7Cy+84O7urVq18sLCQnd3X758uV922WXu7r5t2zbv27evp6am+vDhw71p06Z++PBhd3efPn26d+rUydPT0z0zM9OXLl0aVv0iItUF5Ho5mWoex5fHZWVleW5ufN2ndeTIERITE6lTpw5Lly5l5MiR5OXlkZ+fz/z589m7dy8pKSn069ev9FxA0L3xyRaemLeOf+85xLcbnsVDV7Xnui4XxroskTOSma1w9zJvhA33xqvA2bhxIzfccAPFxcXUrVuXyZMnk5+fz5w5czh69ChQcox7zpw5AIEP/Tc+2cIvXv8Hh46WHOffsucQv3j9HwAKfZEoU+BXU9u2bfnkk09OanvqqadKw/64o0ePMn/+/MAH/hPz1pWG/XGHjh7jiXnrFPgiUaYPT4uAvXv3Vqs9SP6951C12kWk5ijwIyAlJaVa7UHy7YZnVatdRGqOAj8C+vXrR1JS0kltSUlJ9OvXL0YVxY+HrmrPWUmJJ7WdlZTIQ1e1j1FFIsGlY/gRcPw4va7SOd3x4/S6Skck9nRZpojIGaSiyzJ1SEdEJCAU+CIiAaHAFxEJCAW+iEhAKPBFRAIirq/SMbP9QK38ztwoagToy+Erp/epavQ+VS7e36NW7l7m1wXG+3X468q7vEhKmFmu3qPK6X2qGr1PlavN75EO6YiIBIQCX0QkIOI98CfFuoBaQO9R1eh9qhq
9T5Wrte9RXJ+0FRGRyIn3PXwREYmQuA58M3vCzNaaWb6ZzTKzhrGuKR6Z2fVmtsrMis2sVl49UJPMrL+ZrTOz9WY2Ntb1xCMze97MtpvZyljXEq/MrIWZLTSz1aHftzGxrqm64jrwgXeANHfPAP4J/CLG9cSrlcAQ4L1YFxJvzCwReAYYAHQEbjazjrGtKi69CPSPdRFxrgj4qbt3BHoA99W2f0txHfju/ra7F4VWlwHNY1lPvHL3Ne6uG9TK1h1Y7+4b3P1rYDowKMY1xR13fw/YFes64pm7b3X3j0PL+4E1QK36Yoe4DvxT3AX8NdZFSK1zIbDphPXN1LJfUok/ZtYa6AJ8GONSqiXmd9qa2btA0zI2jXP3N0N9xlHy59S0aNYWT6ryPolIzTOzBsBrwAPuvi/W9VRHzAPf3a+oaLuZ/Qi4BujnAb6GtLL3Scq1BWhxwnrzUJtItZlZEiVhP83dX491PdUV14d0zKw/8P+Aa939q1jXI7XScqCtmbUxs7rATcDsGNcktZCZGTAFWOPuv4t1Pd9EXAc+8DRwDvCOmeWZ2XOxLigemdlgM9sMfA94y8zmxbqmeBE66T8amEfJSbYZ7r4qtlXFHzP7E7AUaG9mm83sP2JdUxzqCdwO9A3lUZ6ZXR3roqpDd9qKiAREvO/hi4hIhCjwRUQCQoEvIhIQCnwRkYBQ4IuIBIQCX0QkIBT4IiIBocAXEQmI/w81nm1VM/4ougAAAABJRU5ErkJggg==", 180 | "text/plain": [ 181 | "
" 182 | ] 183 | }, 184 | "metadata": { 185 | "needs_background": "light" 186 | }, 187 | "output_type": "display_data" 188 | } 189 | ], 190 | "source": [ 191 | "for i, label in enumerate(word_list):\n", 192 | " W, WT = model.get_parameters()\n", 193 | " x, y = W[0][i].asnumpy(), W[1][i].asnumpy()\n", 194 | " plt.scatter(x, y)\n", 195 | " plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')\n", 196 | "plt.show()" 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "Python 3.7.13 ('ms1.8')", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.7.13" 217 | }, 218 | "vscode": { 219 | "interpreter": { 220 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 221 | } 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 4 226 | } 227 | -------------------------------------------------------------------------------- /4-2.Seq2Seq(Attention)/Seq2Seq-Attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 103, 6 | "id": "8a0766e3-3816-4303-af74-2dfc50ba5c62", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "import mindspore.numpy as mnp\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from mindspore import ms_function" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 104, 22 | "id": "f4888e33-7ea9-437a-be7d-3151891ed97d", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# S: 
Symbol that shows starting of decoding input\n", 27 | "# E: Symbol that shows end of decoding output\n", 28 | "# P: Symbol that will fill in blank sequence if current batch data size is shorter than time steps\n", 29 | "\n", 30 | "def make_batch():\n", 31 | " input_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[0].split()]]]\n", 32 | " output_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[1].split()]]]\n", 33 | " target_batch = [[word_dict[n] for n in sentences[2].split()]]\n", 34 | "\n", 35 | " # make tensor\n", 36 | " return mindspore.Tensor(input_batch), mindspore.Tensor(output_batch), \\\n", 37 | " mindspore.Tensor(target_batch, mindspore.int32)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 105, 43 | "id": "0242cf52-c02a-4670-ab34-e376cb140744", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "class Attention(nn.Cell):\n", 48 | " def __init__(self):\n", 49 | " super(Attention, self).__init__()\n", 50 | " self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 51 | " self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)\n", 52 | "\n", 53 | " # Linear for attention\n", 54 | " self.attn = nn.Dense(n_hidden, n_hidden)\n", 55 | " self.out = nn.Dense(n_hidden * 2, n_class)\n", 56 | "\n", 57 | " def construct(self, enc_inputs, dec_inputs):\n", 58 | " enc_inputs = enc_inputs.swapaxes(0, 1) # enc_inputs: [n_step(=n_step, time step), batch_size, n_class]\n", 59 | " dec_inputs = dec_inputs.swapaxes(0, 1) # dec_inputs: [n_step(=n_step, time step), batch_size, n_class]\n", 60 | "\n", 61 | " # enc_outputs : [n_step, batch_size, num_directions(=1) * n_hidden], matrix F\n", 62 | " # enc_hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n", 63 | " enc_outputs, enc_hidden = self.enc_cell(enc_inputs)\n", 64 | "\n", 65 | " trained_attn = []\n", 66 | " hidden = enc_hidden\n", 67 | " n_step = len(dec_inputs)\n", 68 | " model = []\n", 69 | "\n", 70 | 
" for i in range(n_step): # each time step\n", 71 | " # dec_output : [n_step(=1), batch_size(=1), num_directions(=1) * n_hidden]\n", 72 | " # hidden : [num_layers(=1) * num_directions(=1), batch_size(=1), n_hidden]\n", 73 | " dec_output, hidden = self.dec_cell(dec_inputs[i].expand_dims(0), hidden)\n", 74 | " attn_weights = self.get_att_weight(dec_output, enc_outputs) # attn_weights : [1, 1, n_step]\n", 75 | " trained_attn.append(attn_weights.squeeze())\n", 76 | "\n", 77 | " # matrix-matrix product of matrices [1,1,n_step] x [1,n_step,n_hidden] = [1,1,n_hidden]\n", 78 | " context = ops.matmul(attn_weights, enc_outputs.swapaxes(0, 1))\n", 79 | " dec_output = dec_output.squeeze(0) # dec_output : [batch_size(=1), num_directions(=1) * n_hidden]\n", 80 | " context = context.squeeze(1) # [1, num_directions(=1) * n_hidden]\n", 81 | " out = self.out(ops.concat((dec_output, context), 1))\n", 82 | " model.append(out)\n", 83 | " \n", 84 | " model = ops.stack(model)\n", 85 | "\n", 86 | " # make model shape [n_step, n_class]\n", 87 | " return model.swapaxes(0, 1).squeeze(0), trained_attn\n", 88 | "\n", 89 | " def get_att_weight(self, dec_output, enc_outputs): # get attention weight of one 'dec_output' with 'enc_outputs'\n", 90 | " n_step = len(enc_outputs)\n", 91 | " attn_scores = ops.zeros(n_step, mindspore.float32) # attn_scores : [n_step]\n", 92 | "\n", 93 | " for i in range(n_step):\n", 94 | " attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])\n", 95 | "\n", 96 | " # Normalize scores to weights in range 0 to 1\n", 97 | " return ops.Softmax()(attn_scores).view(1, 1, -1)\n", 98 | "\n", 99 | " def get_att_score(self, dec_output, enc_output): # enc_outputs [batch_size, num_directions(=1) * n_hidden]\n", 100 | " score = self.attn(enc_output) # score : [batch_size, n_hidden]\n", 101 | " return mnp.dot(dec_output.view(-1), score.view(-1)) # inner product makes a scalar value" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 106, 107 | "id": 
"ad18877e-322a-42af-97bf-f13b561760d3", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "n_step = 5 # number of cells(= number of Step)\n", 112 | "n_hidden = 128 # number of hidden units in one cell\n", 113 | "\n", 114 | "sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']\n", 115 | "\n", 116 | "word_list = \" \".join(sentences).split()\n", 117 | "word_list = list(set(word_list))\n", 118 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 119 | "number_dict = {i: w for i, w in enumerate(word_list)}\n", 120 | "n_class = len(word_dict) # vocab list" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 107, 126 | "id": "2ef50a8e-f4c0-44c8-82cd-750cc7ba3d5a", 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stderr", 131 | "output_type": "stream", 132 | "text": [ 133 | "[WARNING] ME(258161:139710830719232,MainProcess):2022-08-12-21:41:32.952.304 [mindspore/nn/layer/rnns.py:392] dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n", 134 | "[WARNING] ME(258161:139710830719232,MainProcess):2022-08-12-21:41:32.961.000 [mindspore/nn/layer/rnns.py:392] dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "model = Attention()\n", 140 | "criterion = nn.CrossEntropyLoss()\n", 141 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 108, 147 | "id": "612ce97a-4fc3-4953-b641-bcd9f800f700", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "input_batch, output_batch, target_batch = make_batch()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 109, 157 | "id": "002ae643", 158 | "metadata": {}, 159 | 
"outputs": [], 160 | "source": [ 161 | "def forward(enc_input, dec_input, target):\n", 162 | " output, attn = model(enc_input, dec_input)\n", 163 | " loss = criterion(output, target.squeeze(0))\n", 164 | " return loss, attn" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 110, 170 | "id": "47904591", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters, has_aux=True)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 111, 180 | "id": "843167f2", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "@ms_function\n", 185 | "def train_step(enc_input, dec_input, target):\n", 186 | " (loss, _), grads = grad_fn(enc_input, dec_input, target)\n", 187 | " optimizer(grads)\n", 188 | " return loss" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 112, 194 | "id": "5f192fcb-4bad-4b97-9611-e24c8a0d966d", 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "Epoch: 0400 cost = 0.000808\n", 202 | "Epoch: 0800 cost = 0.000266\n", 203 | "Epoch: 1200 cost = 0.000135\n", 204 | "Epoch: 1600 cost = 0.000080\n", 205 | "Epoch: 2000 cost = 0.000053\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "# Train\n", 211 | "for epoch in range(2000):\n", 212 | " loss = train_step(input_batch, output_batch, target_batch)\n", 213 | " if (epoch + 1) % 400 == 0:\n", 214 | " print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 113, 220 | "id": "fa47c242-103d-471d-983d-2480830b7d6c", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "ich mochte ein bier P -> ['i', 'want', 'a', 'beer', 'E']\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "# Test\n", 233 | "test_batch = 
[np.eye(n_class)[[word_dict[n] for n in 'SPPPP']]]\n", 234 | "test_batch = mindspore.Tensor(test_batch)\n", 235 | "predict, trained_attn = model(input_batch, test_batch)\n", 236 | "predict = predict.argmax(1)\n", 237 | "print(sentences[0], '->', [number_dict[int(n.asnumpy())] for n in predict])" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 114, 243 | "id": "fc60f4c8-461b-4640-97f2-b50fb6da5f19", 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stderr", 248 | "output_type": "stream", 249 | "text": [ 250 | "/home/lvyufeng/miniconda3/envs/ms1.8/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: FixedFormatter should only be used together with FixedLocator\n", 251 | " \"\"\"\n", 252 | "/home/lvyufeng/miniconda3/envs/ms1.8/lib/python3.7/site-packages/ipykernel_launcher.py:6: UserWarning: FixedFormatter should only be used together with FixedLocator\n", 253 | " \n" 254 | ] 255 | }, 256 | { 257 | "data": { 258 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUcAAAE2CAYAAADyN1APAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAARXUlEQVR4nO3de9BcdX3H8fcHEkK5RBvFQUFAwRtV6UC4eCUdaKF17HSU0dFCFWcMaK2oeGlrUTvqpN4GrCiayhidYkcdqVS8U8mgU67aVjFYUEHuN7kTSAJ++8ee6LL+EvI8ye5ZnrxfMzvJc/bsnu959nneOedskidVhSTpobbpewBJmkbGUZIajKMkNRhHSWowjpLUYBwlqcE4DkmyIsnZm7DeXkkqyeJJzNWHbv+O6nuOzfVI2o8kK5OcOtv7tWXN63uAKXMCkL6HeCRIshdwJXBgVV3S8zgb83jg9r6H2EJeAqzre4hxSLICeFX34QPANcCZwLur6t4+ZjKOQ6rqzr5n0JZVVTf2PcOWUlW3be5zJJlfVdMa2HOAY4D5wAuATwM7Aq/rYxhPq4cMn1Zn4MQkVyRZk+TaJMtGHrJnku8kWZ1kVZI/HtNcK5OcluQjSW5LckuSE5IsSPLxJHckuTrJMUOPeVaSc5Lc1z1mRZJHjTzvq5L8uNu/m5J8dmTTi5J8Kcm9SX6R5Oih+67sfr24O3VdOfS8x3afj/uTXJ7kzUnG8rXWvU5vT/Lzbl9/PDzn8Gn10OWQl07idZuleUk+muT27vah9Z+70dPqJNsl+UD3tbk6ycVJjhi6f0m3v3+W5KIka4EjGtucFmuq6saquqaqPg+cAfxFb9NUlbfuBqwAzu5+vwy4A3gNsA/wHOD13X17AQX8FHgx8BTgs8CvgJ3GMNdK4C7gPd22Tuy2/w0GlwL2Ad4LrGFwGrkjcD3wFeBZwKHA5cCXh57zOOB+4C3A04ADgLcN3V/AtcDR3fMvA9YCe3T3H9itcwSwK
7CoW/5a4AbgKOBJ3efnRuANY3rN3g/8H3Bkt71XAvcCLxraj6P6eN1m+TrfDXwMeDrwMuBO4C1D9586tP4ZwAXAC4EnA2/oXqP9uvuXdPv7Y+BPunV26Xs/H+57b2jZPwO39jZT35+Uabqtf4GAnbpwHL+B9dZ/kx03tGy3btnzxzDXSuD8oY8D3AL8x9Cy+d03xlFdoO4Edh66f/03yj7dx9cC/7SRbRawbOjjecBq4OiRz8HikcddDRwzsuxNwKoxfF52BO4DXjCy/BTg60P7MRrHibxus3ydLwcytOwfgGuH7j+1+/3ewK/p/rAaWv8rwCdGXvOX9r1vm7DvD4kjcBBwK/CFvmbymmPbvsAC4D8fZr0fDf3++u7Xx41loqFtVVUluZnBEcH6ZeuS3N5tfx/gR1V199Dj/4vBN9O+Se5iEIVN3r+qeiDJLWxk/5LsAjwR+FSS04bumsd43ujaF9ge+GaS4f9BZT5w1UYeN8nXbaYuqK4OnfOB9yZZOLLe/gw+p6uSh3xqFwDfHVl3mt8wG3ZkknsYfL3MB84C/qavYYzj5vnNhe0uWDC+67ijF9FrA8sebvsz+W+YZvr86+87nkGMx2399l7M4Ih12MbedJjk6zYu2zB4PQ7kd/f1vpGPe3m3dxbOA5Yy2J/rq+c3joxj22UMrt8dBlzR8yyzcRnwmiQ7Dx09PpfBN9RlVXVzkusY7N93ZrmNtd2v265fUFU3Jbke2LuqPjfL552JVQxepz2ravRo6ZHq4CQZOno8hEEo7ho5QvxvBkeOu1bVuZMeckxWV9XP+h5iPePYUFV3J/kosCzJGgZ/oj0GOKCqTtv4o6fCGcA/Ap9L8i7g94FPAWcOffG9Hzg5yU3A14AdgMOq6iObuI2bGRyhHJHkKuD+GvxVqHcDH0tyB/B1BqdH+wO7VdXou/2bpXudPgx8OINynMfgevEhwK+ravmW3N6EPAE4JcknGLyZ9jbgfaMrVdXlSc4AViQ5EfghsIjBdcZfVNWZkxt5bjKOG/Z3DP7y8EnA7sBNwCSOhjZbVa3u/krHKcBFDN5cOovBO9vr1zmt+6sdJwIfAG5jELNN3cYDSd4IvItBEL8HLKmqTye5l8E39TIGAf0JMK5/2XESg9fmrcBpDN7V/x/gg2Pa3ridweBo/EIGp82nAydvYN1jgXcy2NfdGbyGFwFz5UiyV3notV9JEjzyLkJL0kQYR0lqMI6S1GAcJanBOEpSg3GUpAbjOENJlvY9wzjM1f2Cubtv7td4GceZm4oXbgzm6n7B3N0392uMjKMkNcyJfyGzXRbU9uw4kW2tYw3zWTCRbU3SXN0vmOy+5WnzJ7IdgLV33Md2j/69iW1v53n3T2Q79962lh0XbTeRbQFc95O7bq2qXUaXz4l/W709O3JwDut7DM3ENts+/DqPQPOXT8t/C7nlHfrYy/seYSz+9g++9cvWck+rJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNUx3HJCuSnN33HJK2PtP+0wdPANL3EJK2PlMdx6q6s+8ZJG2dPK2WpIapjqMk9WWqT6s3JslSYCnA9uzQ8zSS5ppH7JFjVS2vqsVVtXg+C/oeR9Ic84iNoySNk3GUpAbjKEkNxlGSGqb63eqqenXfM0jaOnnkKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDVM9c+Q0dyV+XPzS2+7bR7oe4SxufOBHfoeYaI8cpSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqWEq4
5hkZZJT+55D0tZrKuMoSX172DgmOTLJ3UnmdR/vk6SSfHJonfclOSfJtklOT3JlkvuSXJHk7Um2GVp3RZKzk5yQ5Loktyf5TJId1t8PHAr8dbedSrLXlt5xSdqYeZuwzveB7YHFwAXAEuDW7tf1lgDfZBDb64CXAbcABwHLgV8Bpw+t/wLgBuBw4InAF4HLgWXACcBTgZ8Cf9+tf8vMdkuSNs/DHjlW1T3AD4A/6hYtAU4F9kzy+O6I70BgZVWtq6p3VdXFVXVVVX0R+CTwipGnvQs4vqouq6pvA18CDuu2dyewFlhdVTd2twdH50qyNMklSS5Zx5rZ7LskbdCmXnNcyW+PFA8FvgFc2C17LvAAcBFAkuO7aN2S5B7gzcAeI8+3aiR41wOPm8ngVbW8qhZX1eL5LJjJQyXpYc0kjs9L8gxgIYMjyZUMjiaXAOdX1dokLwdOAVYARwB/CHwC2G7k+daNfFwzmEWSxm5TrjnC4LrjAuDtwPer6sEkK4F/AW5icL0R4PnAhVX1m7+Gk2TvWcy1Fth2Fo+TpC1ik47Whq47Hg2c2y2+ANgdOITBUSQM3lTZP8mfJnlKkpMYnIbP1FXAQUn2SvLY4Xe7JWkSZhKdlQyONFcCVNX9DK47rqG73gh8isE7z58HLgb2Aj4yi7k+zODocRWDd6pHr1lK0lilqvqeYbMtzKI6OIf1PYZmIAvm5ptoO3xnYd8jjM2+C2/se4SxWLbfv/+gqhaPLvd0VZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUsKk/t1raorZZODd/ENW96+bmDw4DuO/B+X2PMFEeOUpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSw1TFMcmRSb6X5PYktyX5VpJn9D2XpK3PVMUR2BE4BTgIWALcCXw1yXajKyZZmuSSJJesY81Eh5Q0983re4BhVfXl4Y+THAvcxSCW3x9ZdzmwHGBhFtWkZpS0dZiqI8ckeyf5fJKfJ7kLuInBjHv0PJqkrcxUHTkCZwPXAscB1wEPAKuA3zmtlqRxmpo4JnkM8HTg9VV1brdsf6ZoRklbj2kKz+3ArcBrk1wD7AZ8iMHRoyRN1NRcc6yqXwMvB54NXAp8HDgJfCta0uRN05EjVfVd4Jkji3fqYxZJW7epOXKUpGliHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUsNU/QyZ2Vqzxw5c/s6D+h5ji7vyz5f3PcLYHLH7AX2PMB6HPdj3BGNzad8DTJhHjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIajKMkNRhHSWowjpLUYBwlqcE4SlKDcZSkBuMoSQ3GUZIaZhTHJCuTnDquYSRpWnjkKEkNUx/HJNv1PYOkrc9s4jgvyUeT3N7dPpRkGxiELMkHklybZHWSi5McMfzgJPsm+VqSu5PcnOTfkuw6dP+KJGcneUeSa4FrN28XJWnmZhPHv+we9xzgOGAp8Kbuvs8AhwKvBJ4JfBb4apL9AJI8HjgPuBQ4CDgc2Ak4a31gO4cCzwaOBA6bxYyStFnmzeIxNwBvrKoCfprkqcBbkpwFvALYq6qu7tY9NcnhDCL6euB1wP9W1TvWP1mSvwJuAxYDF3WL7wdeU1VrNjREkqUMwsy2ix49i92QpA2bzZHjBV0Y1zsf2A14PhBgVZJ71t+AFwF7d+seALxw5P5ruvv2HnrOSzcWRoCqWl5Vi6tq8
bY77TiL3ZCkDZvNkePGFHAgsG5k+X3dr9sAXwPe2njsTUO/v3cLzyVJMzKbOB6cJENHj4cA1zM4ggywa1Wdu4HH/hB4GfDLqhoNqCRNjdmcVj8BOCXJ05IcBbwNOLmqLgfOAFYkOSrJk5MsTvLWJC/pHvtx4FHAF5Ic3K1zeJLlSXbeInskSVvAbI4czwC2BS5kcBp9OnByd9+xwDuBDwK7M3ij5SLgXICquj7J84BlwDeB7YGrgW8DG73GKEmTNKM4VtWSoQ/f0Lh/HfCe7rah57gCOGoj9796JjNJ0jhM/b+QkaQ+GEdJajCOktRgHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktSQ3/6E1UeuhVlUB+ewvsfQTCR9TzAW15/5jL5HGJtP7vevfY8wFi980i9+UFWLR5d75ChJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUoNxlKQG4yhJDcZRkhqMoyQ1GEdJajCOktRgHCWpwThKUsPUxDHJiiTVuF3Q92yStj7z+h5gxDnAMSPL1vYxiKSt27TFcU1V3dj3EJI0NafVkjRNpi2ORya5Z+T2gdaKSZYmuSTJJetYM+k5Jc1x03ZafR6wdGTZHa0Vq2o5sBxgYRbVeMeStLWZtjiurqqf9T2EJE3babUkTYVpO3JckGTXkWUPVtUtvUwjaas1bXE8HLhhZNl1wO49zCJpKzY1p9VV9eqqSuNmGCVN3NTEUZKmiXGUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGoyjJDUYR0lqMI6S1GAcJanBOEpSg3GUpAbjKEkNxlGSGlJVfc+w2ZLcAvxyQpt7LHDrhLY1SXN1v2Du7pv7tWXsWVW7jC6cE3GcpCSXVNXivufY0ubqfsHc3Tf3a7w8rZakBuMoSQ3GceaW9z3AmMzV/YK5u2/u1xh5zVGSGjxylKQG4yhJDcZRkhqMoyQ1GEdJavh/Q8WGqu6Ak48AAAAASUVORK5CYII=", 259 | "text/plain": [ 260 | "
" 261 | ] 262 | }, 263 | "metadata": { 264 | "needs_background": "light" 265 | }, 266 | "output_type": "display_data" 267 | } 268 | ], 269 | "source": [ 270 | "# Show Attention\n", 271 | "fig = plt.figure(figsize=(5, 5))\n", 272 | "ax = fig.add_subplot(1, 1, 1)\n", 273 | "ax.matshow([attn.asnumpy() for attn in trained_attn], cmap='viridis')\n", 274 | "ax.set_xticklabels([''] + sentences[0].split(), fontdict={'fontsize': 14})\n", 275 | "ax.set_yticklabels([''] + sentences[2].split(), fontdict={'fontsize': 14})\n", 276 | "plt.show()" 277 | ] 278 | } 279 | ], 280 | "metadata": { 281 | "kernelspec": { 282 | "display_name": "Python 3.7.13 ('ms1.8')", 283 | "language": "python", 284 | "name": "python3" 285 | }, 286 | "language_info": { 287 | "codemirror_mode": { 288 | "name": "ipython", 289 | "version": 3 290 | }, 291 | "file_extension": ".py", 292 | "mimetype": "text/x-python", 293 | "name": "python", 294 | "nbconvert_exporter": "python", 295 | "pygments_lexer": "ipython3", 296 | "version": "3.7.13" 297 | }, 298 | "vscode": { 299 | "interpreter": { 300 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 301 | } 302 | } 303 | }, 304 | "nbformat": 4, 305 | "nbformat_minor": 5 306 | } 307 | -------------------------------------------------------------------------------- /5-2.BERT/BERT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "0d87f0c5-ea59-4411-8e02-1b338a2dee30", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import re\n", 11 | "from random import *\n", 12 | "import mindspore\n", 13 | "import mindspore.nn as nn\n", 14 | "import mindspore.ops as ops\n", 15 | "import mindspore.numpy as mnp\n", 16 | "from layers import Dense, Embedding" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "cd4a9097-c20d-45d3-910c-cd1bb7f32a6a", 23 | "metadata": {}, 24 | "outputs": [], 25 
| "source": [ 26 | "# sample IsNext and NotNext to be same in small batch size\n", 27 | "def make_batch():\n", 28 | " batch = []\n", 29 | " positive = negative = 0\n", 30 | " while positive != batch_size/2 or negative != batch_size/2:\n", 31 | " tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences\n", 32 | " tokens_a, tokens_b= token_list[tokens_a_index], token_list[tokens_b_index]\n", 33 | " input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]\n", 34 | " segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)\n", 35 | "\n", 36 | " # MASK LM\n", 37 | " n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence\n", 38 | " cand_maked_pos = [i for i, token in enumerate(input_ids)\n", 39 | " if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]\n", 40 | " shuffle(cand_maked_pos)\n", 41 | " masked_tokens, masked_pos = [], []\n", 42 | " for pos in cand_maked_pos[:n_pred]:\n", 43 | " masked_pos.append(pos)\n", 44 | " masked_tokens.append(input_ids[pos])\n", 45 | " if random() < 0.8: # 80%\n", 46 | " input_ids[pos] = word_dict['[MASK]'] # make mask\n", 47 | " elif random() < 0.5: # 10%\n", 48 | " index = randint(0, vocab_size - 1) # random index in vocabulary\n", 49 | " input_ids[pos] = word_dict[number_dict[index]] # replace\n", 50 | "\n", 51 | " # Zero Paddings\n", 52 | " n_pad = maxlen - len(input_ids)\n", 53 | " input_ids.extend([0] * n_pad)\n", 54 | " segment_ids.extend([0] * n_pad)\n", 55 | "\n", 56 | " # Zero Padding (100% - 15%) tokens\n", 57 | " if max_pred > n_pred:\n", 58 | " n_pad = max_pred - n_pred\n", 59 | " masked_tokens.extend([0] * n_pad)\n", 60 | " masked_pos.extend([0] * n_pad)\n", 61 | "\n", 62 | " if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:\n", 63 | " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 1]) # IsNext\n", 64 | " positive 
+= 1\n", 65 | " elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:\n", 66 | " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, 0]) # NotNext\n", 67 | " negative += 1\n", 68 | " return batch\n", 69 | "# Preprocessing Finished" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "id": "f024fd22-1226-40c1-a99c-cf1eb89471be", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def get_attn_pad_mask(seq_q, seq_k):\n", 80 | " batch_size, len_q = seq_q.shape\n", 81 | " batch_size, len_k = seq_k.shape\n", 82 | " \n", 83 | " pad_attn_mask = ops.equal(seq_k, 0)\n", 84 | " pad_attn_mask = pad_attn_mask.expand_dims(1) # batch_size x 1 x len_k(=len_q), one is masking\n", 85 | "\n", 86 | " return ops.broadcast_to(pad_attn_mask, (batch_size, len_q, len_k)) # batch_size x len_q x len_k" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "id": "55ddeca3-e465-4873-9677-f27a21e335e9", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "class BertEmbedding(nn.Cell):\n", 97 | " def __init__(self):\n", 98 | " super(BertEmbedding, self).__init__()\n", 99 | " self.tok_embed = Embedding(vocab_size, d_model) # token embedding\n", 100 | " self.pos_embed = Embedding(maxlen, d_model) # position embedding\n", 101 | " self.seg_embed = Embedding(n_segments, d_model) # segment(token type) embedding\n", 102 | " self.norm = nn.LayerNorm([d_model,])\n", 103 | "\n", 104 | " def construct(self, x, seg):\n", 105 | " seq_len = x.shape[1]\n", 106 | " pos = ops.arange(seq_len, dtype=mindspore.int64)\n", 107 | " pos = pos.expand_dims(0).expand_as(x) # (seq_len,) -> (batch_size, seq_len)\n", 108 | " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n", 109 | " return self.norm(embedding)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "id": 
"f824473a-a320-42f1-827d-cb340b92a7c0", 116 | "metadata": {}, 117 | "outputs": [], 118 | 
"source": [ 119 | "class ScaledDotProductAttention(nn.Cell):\n", 120 | " def __init__(self):\n", 121 | " super(ScaledDotProductAttention, self).__init__()\n", 122 | " self.softmax = nn.Softmax(axis=-1)\n", 123 | "\n", 124 | " def construct(self, Q, K, V, attn_mask):\n", 125 | " scores = ops.matmul(Q, K.swapaxes(-1, -2)) / ops.sqrt(ops.scalar_to_tensor(d_k)) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n", 126 | " scores = scores.masked_fill(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.\n", 127 | " attn = self.softmax(scores)\n", 128 | " context = ops.matmul(attn, V)\n", 129 | " return context, attn" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "id": "ef49a217-d48b-4a89-babd-ee2722745316", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "class MultiHeadAttention(nn.Cell):\n", 140 | " def __init__(self):\n", 141 | " super(MultiHeadAttention, self).__init__()\n", 142 | " self.W_Q = Dense(d_model, d_k * n_heads)\n", 143 | " self.W_K = Dense(d_model, d_k * n_heads)\n", 144 | " self.W_V = Dense(d_model, d_v * n_heads)\n", 145 | " self.attn = ScaledDotProductAttention()\n", 146 | " self.out_fc = Dense(n_heads * d_v, d_model)\n", 147 | " self.norm = nn.LayerNorm([d_model,])\n", 148 | "\n", 149 | " def construct(self, Q, K, V, attn_mask):\n", 150 | " # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]\n", 151 | " residual, batch_size = Q, Q.shape[0]\n", 152 | " # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n", 153 | " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).swapaxes(1,2) # q_s: [batch_size x n_heads x len_q x d_k]\n", 154 | " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).swapaxes(1,2) # k_s: [batch_size x n_heads x len_k x d_k]\n", 155 | " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).swapaxes(1,2) # v_s: [batch_size x n_heads x len_k x d_v]\n", 156 | 
"\n", 157 | " attn_mask = attn_mask.expand_dims(1)\n", 158 | " attn_mask = ops.tile(attn_mask, (1, n_heads, 1, 1))\n", 159 | " \n", 160 | " # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n", 161 | " context, attn = self.attn(q_s, k_s, v_s, attn_mask)\n", 162 | " context = context.swapaxes(1, 2).view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]\n", 163 | " output = self.out_fc(context)\n", 164 | " return self.norm(output + residual), attn # output: [batch_size x len_q x d_model]" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 7, 170 | "id": "f3ac0741-2515-413c-8aec-67a6ab654f44", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "class PoswiseFeedForwardNet(nn.Cell):\n", 175 | " def __init__(self):\n", 176 | " super(PoswiseFeedForwardNet, self).__init__()\n", 177 | " self.fc1 = Dense(d_model, d_ff)\n", 178 | " self.fc2 = Dense(d_ff, d_model)\n", 179 | " self.activation = nn.GELU(False)\n", 180 | "\n", 181 | " def construct(self, x):\n", 182 | " # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)\n", 183 | " return self.fc2(self.activation(self.fc1(x)))" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 8, 189 | "id": "cac523ae-53a0-4678-a205-6f51d2e4f4a6", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "class EncoderLayer(nn.Cell):\n", 194 | " def __init__(self):\n", 195 | " super(EncoderLayer, self).__init__()\n", 196 | " self.enc_self_attn = MultiHeadAttention()\n", 197 | " self.pos_ffn = PoswiseFeedForwardNet()\n", 198 | "\n", 199 | " def construct(self, enc_inputs, enc_self_attn_mask):\n", 200 | " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V\n", 201 | " enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]\n", 202 | " 
return enc_outputs, attn" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "id": "2891fc39-ccf0-4f8c-875a-821ad85ec029", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "class BERT(nn.Cell):\n", 213 | " def __init__(self):\n", 214 | " super(BERT, self).__init__()\n", 215 | " self.embedding = BertEmbedding()\n", 216 | " self.layers = nn.CellList([EncoderLayer() for _ in range(n_layers)])\n", 217 | " self.fc = Dense(d_model, d_model)\n", 218 | " self.activ1 = nn.Tanh()\n", 219 | " self.linear = Dense(d_model, d_model)\n", 220 | " self.activ2 = nn.GELU(False)\n", 221 | " self.norm = nn.LayerNorm([d_model,])\n", 222 | " self.classifier = Dense(d_model, 2)\n", 223 | " # decoder is shared with embedding layer\n", 224 | " embed_weight = self.embedding.tok_embed.embedding_table\n", 225 | " n_vocab, n_dim = embed_weight.shape\n", 226 | " self.decoder = Dense(n_dim, n_vocab, has_bias=False)\n", 227 | " self.decoder.weight = embed_weight\n", 228 | " self.decoder_bias = mindspore.Parameter(ops.zeros(n_vocab), 'decoder_bias')\n", 229 | "\n", 230 | " def construct(self, input_ids, segment_ids, masked_pos):\n", 231 | " output = self.embedding(input_ids, segment_ids)\n", 232 | " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n", 233 | " for layer in self.layers:\n", 234 | " output, enc_self_attn = layer(output, enc_self_attn_mask)\n", 235 | " # output : [batch_size, len, d_model], attn : [batch_size, n_heads, len, len]\n", 236 | " # it will be decided by first token(CLS)\n", 237 | " h_pooled = self.activ1(self.fc(output[:, 0])) # [batch_size, d_model]\n", 238 | " logits_clsf = self.classifier(h_pooled) # [batch_size, 2]\n", 239 | "\n", 240 | " masked_pos = ops.tile(masked_pos[:, :, None], (1, 1, output.shape[-1])) # [batch_size, max_pred, d_model]\n", 241 | " # get masked position from final output of transformer.\n", 242 | " h_masked = ops.gather_d(output, 1, masked_pos) # 
masking position [batch_size, 
max_pred, d_model]\n", 243 | " h_masked = self.norm(self.activ2(self.linear(h_masked)))\n", 244 | " logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]\n", 245 | "\n", 246 | " return logits_lm, logits_clsf" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 10, 252 | "id": "9391e0d9-f019-4d3e-9c6c-fb57a3b6a8e6", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "# BERT Parameters\n", 257 | "maxlen = 30 # maximum of length\n", 258 | "batch_size = 6\n", 259 | "max_pred = 5 # max tokens of prediction\n", 260 | "n_layers = 6 # number of Encoder of Encoder Layer\n", 261 | "n_heads = 12 # number of heads in Multi-Head Attention\n", 262 | "d_model = 768 # Embedding Size\n", 263 | "d_ff = 768 * 4 # 4*d_model, FeedForward dimension\n", 264 | "d_k = d_v = 64 # dimension of K(=Q), V\n", 265 | "n_segments = 2" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 11, 271 | "id": "24608b11-440c-4fb6-b070-45ff3d82c014", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "text = (\n", 276 | " 'Hello, how are you? I am Romeo.\\n'\n", 277 | " 'Hello, Romeo My name is Juliet. Nice to meet you.\\n'\n", 278 | " 'Nice meet you too. How are you today?\\n'\n", 279 | " 'Great. 
My baseball team won the competition.\\n'\n", 280 | " 'Oh Congratulations, Juliet\\n'\n", 281 | " 'Thanks you Romeo'\n", 282 | ")\n", 283 | "sentences = re.sub(\"[.,!?\\\\-]\", '', text.lower()).split('\\n') # filter '.', ',', '?', '!'\n", 284 | "word_list = list(set(\" \".join(sentences).split()))\n", 285 | "word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}\n", 286 | "for i, w in enumerate(word_list):\n", 287 | " word_dict[w] = i + 4\n", 288 | "number_dict = {i: w for i, w in enumerate(word_dict)}\n", 289 | "vocab_size = len(word_dict)\n", 290 | "\n", 291 | "token_list = list()\n", 292 | "for sentence in sentences:\n", 293 | " arr = [word_dict[s] for s in sentence.split()]\n", 294 | " token_list.append(arr)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 12, 300 | "id": "fe4e30ab-9e7d-4868-893f-b160cf090959", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "model = BERT()\n", 305 | "criterion = nn.CrossEntropyLoss()\n", 306 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 13, 312 | "id": "eaef3603-8154-495b-9e00-5916c65c9f3c", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "def forward(input_ids, segment_ids, masked_pos, masked_tokens, isNext):\n", 317 | " logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n", 318 | " loss_lm = criterion(logits_lm.swapaxes(1, 2), masked_tokens.astype(mindspore.int32))\n", 319 | " loss_lm = loss_lm.mean()\n", 320 | " loss_clsf = criterion(logits_clsf, isNext.astype(mindspore.int32))\n", 321 | "\n", 322 | " return loss_lm + loss_clsf" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 14, 328 | "id": "0c1c152d-d4f0-4a66-b3e2-5cbf76d15d48", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters)" 333 | ] 334 | }, 335 | { 336 | "cell_type": 
"code", 337 | "execution_count": 15, 338 | "id": "e2cd05a5-b034-46cd-980e-dfb15e7b6155", 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "@mindspore.jit\n", 343 | "def train_step(input_ids, segment_ids, masked_pos, masked_tokens, isNext):\n", 344 | " loss, grads = grad_fn(input_ids, segment_ids, masked_pos, masked_tokens, isNext)\n", 345 | " optimizer(grads)\n", 346 | " return loss" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 16, 352 | "id": "81bf550b-8239-440d-9fda-c556dee4552c", 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "name": "stderr", 357 | "output_type": "stream", 358 | "text": [ 359 | "[ERROR] CORE(1267049,7f74549fd4c0,python):2024-04-16-15:56:16.580.126 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1267049/3083615623.py]\n", 360 | "[ERROR] CORE(1267049,7f74549fd4c0,python):2024-04-16-15:56:16.580.172 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1267049/3083615623.py]\n" 361 | ] 362 | }, 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "Epoch: 0010 cost = 46.552399\n", 368 | "Epoch: 0020 cost = 19.055964\n", 369 | "Epoch: 0030 cost = 15.114850\n", 370 | "Epoch: 0040 cost = 9.543916\n", 371 | "Epoch: 0050 cost = 6.100155\n", 372 | "Epoch: 0060 cost = 2.962293\n", 373 | "Epoch: 0070 cost = 3.004694\n", 374 | "Epoch: 0080 cost = 2.631464\n", 375 | "Epoch: 0090 cost = 2.321460\n", 376 | "Epoch: 0100 cost = 2.230808\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "batch = make_batch()\n", 382 | "input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(mindspore.Tensor, zip(*batch))\n", 383 | "\n", 384 | "model.set_train()\n", 385 | "for epoch in range(100):\n", 386 | " loss = train_step(input_ids, segment_ids, masked_pos, masked_tokens, isNext) # for sentence classification\n", 387 | " if (epoch + 1) % 10 == 0:\n", 388 | " print('Epoch:', '%04d' % (epoch + 
1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "id": "f7da833d-9efb-475f-9aa1-93be15e3ea73", 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "# Predict mask tokens ans isNext\n", 399 | "input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(mindspore.Tensor, zip(batch[0]))\n", 400 | "print(text)\n", 401 | "print([number_dict[int(w.asnumpy())] for w in input_ids[0] if number_dict[int(w.asnumpy())] != '[PAD]'])\n", 402 | "\n", 403 | "logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n", 404 | "logits_lm = logits_lm.argmax(2)[0].asnumpy()\n", 405 | "print('masked tokens list : ',[pos for pos in masked_tokens[0] if pos != 0])\n", 406 | "print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])\n", 407 | "\n", 408 | "logits_clsf = logits_clsf.argmax(1).asnumpy()[0]\n", 409 | "print('isNext : ', True if isNext else False)\n", 410 | "print('predict isNext : ',True if logits_clsf else False)" 411 | ] 412 | } 413 | ], 414 | "metadata": { 415 | "kernelspec": { 416 | "display_name": "Python 3 (ipykernel)", 417 | "language": "python", 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.9.18" 431 | }, 432 | "vscode": { 433 | "interpreter": { 434 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 435 | } 436 | } 437 | }, 438 | "nbformat": 4, 439 | "nbformat_minor": 5 440 | } 441 | -------------------------------------------------------------------------------- /4-3.Bi-LSTM(Attention)/Bi-LSTM-Attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"code", 5 | "execution_count": 1, 6 | "id": "444b3f0a-f9a8-48bc-8aa9-eb6970040e2a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import mindspore\n", 12 | "import mindspore.nn as nn\n", 13 | "import mindspore.ops as ops\n", 14 | "import mindspore.numpy as mnp\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from mindspore import ms_function" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "4956c4aa-bf84-44c0-b271-58a26a26ba87", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "class BiLSTM_Attention(nn.Cell):\n", 27 | " def __init__(self):\n", 28 | " super(BiLSTM_Attention, self).__init__()\n", 29 | "\n", 30 | " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", 31 | " self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)\n", 32 | " self.out = nn.Dense(n_hidden * 2, num_classes)\n", 33 | "\n", 34 | " # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix\n", 35 | " def attention_net(self, lstm_output, final_state):\n", 36 | " hidden = final_state.view(-1, n_hidden * 2, 1) # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]\n", 37 | " attn_weights = ops.matmul(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]\n", 38 | " soft_attn_weights = ops.Softmax(1)(attn_weights)\n", 39 | " # [batch_size, n_hidden * num_directions(=2), n_step] * [batch_size, n_step, 1] = [batch_size, n_hidden * num_directions(=2), 1]\n", 40 | " context = ops.matmul(lstm_output.swapaxes(1, 2), soft_attn_weights.expand_dims(2)).squeeze(2)\n", 41 | " return context, soft_attn_weights # context : [batch_size, n_hidden * num_directions(=2)]\n", 42 | "\n", 43 | " def construct(self, X):\n", 44 | " input = self.embedding(X) # input : [batch_size, len_seq, embedding_dim]\n", 45 | " input = input.transpose(1, 0, 2) # input : [len_seq, batch_size, embedding_dim]\n", 46 | "\n", 47 | " # final_hidden_state, final_cell_state : 
[num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n", 48 | " output, (final_hidden_state, final_cell_state) = self.lstm(input)\n", 49 | " output = output.transpose(1, 0, 2) # output : [batch_size, len_seq, n_hidden * num_directions(=2)]\n", 50 | " attn_output, attention = self.attention_net(output, final_hidden_state)\n", 51 | " return self.out(attn_output), attention # model : [batch_size, num_classes], attention : [batch_size, n_step]" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "id": "ab7e4fe2-0dd5-4473-bbd4-67f2115bd181", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "embedding_dim = 2 # embedding size\n", 62 | "n_hidden = 5 # number of hidden units in one cell\n", 63 | "num_classes = 2 # 0 or 1\n", 64 | "\n", 65 | "# 3-word sentences (=sequence_length is 3)\n", 66 | "sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n", 67 | "labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.\n", 68 | "\n", 69 | "word_list = \" \".join(sentences).split()\n", 70 | "word_list = list(set(word_list))\n", 71 | "word_dict = {w: i for i, w in enumerate(word_list)}\n", 72 | "vocab_size = len(word_dict)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "id": "f0d10505-d45e-481f-a406-e66c42b15775", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "model = BiLSTM_Attention()\n", 83 | "criterion = nn.CrossEntropyLoss()\n", 84 | "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "id": "82c877f4-d0ab-48d5-a724-bff9acc7527a", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "inputs = mindspore.Tensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences])\n", 95 | "targets = mindspore.Tensor([out for out in labels], mindspore.int32)" 96 | ] 97 | }, 98 | { 
99 | "cell_type": "code", 100 | 
"execution_count": 6, 101 | "id": "dcad76d3", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "def forward(data, label):\n", 106 | "    output, attn = model(data)\n", 107 | "    loss = criterion(output, label)\n", 108 | "    return loss, attn  # with has_aux=True only loss is differentiated; attn is passed through" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 7, 114 | "id": "758a3ab5", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "grad_fn = ops.value_and_grad(forward, None, optimizer.parameters, has_aux=True)  # grad_position=None: gradients w.r.t. the optimizer's parameters only" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 8, 124 | "id": "f0bffb1c", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "@ms_function\n", 129 | "def train_step(data, label):\n", 130 | "    (loss, _), grads = grad_fn(data, label)\n", 131 | "    optimizer(grads)\n", 132 | "    return loss" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 9, 138 | "id": "79044130-6d93-41cc-b4a4-faf67858be4c", 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "Epoch: 1000 cost = 0.004177\n", 146 | "Epoch: 2000 cost = 0.000893\n", 147 | "Epoch: 3000 cost = 0.000332\n", 148 | "Epoch: 4000 cost = 0.000157\n", 149 | "Epoch: 5000 cost = 0.000082\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "model.set_train()\n", 155 | "# Training\n", 156 | "for epoch in range(5000):\n", 157 | "    loss = train_step(inputs, targets)\n", 158 | "    if (epoch + 1) % 1000 == 0:\n", 159 | "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 10, 165 | "id": "c29d2248-796b-44f5-890e-2ec9267a4de5", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "sorry hate you is Bad Mean...\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "model.set_train(False)  # Test: switch the network to inference mode before predicting\n", 178 | "test_text = 'sorry hate you'\n", 179 | 
"tests = [np.asarray([word_dict[n] for n in test_text.split()])]\n", 180 | "test_batch = mindspore.Tensor(tests)\n", 181 | "\n", 182 | "# Predict\n", 183 | "predict, attention = model(test_batch)\n", 184 | "predict = predict.argmax(1)\n", 185 | "\n", 186 | "if predict[0] == 0:\n", 187 | "    print(test_text,\"is Bad Mean...\")\n", 188 | "else:\n", 189 | "    print(test_text,\"is Good Mean!!\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 11, 195 | "id": "da2e17c5-2f00-4673-a2f7-b91cfa7e4c39", 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "/home/lvyufeng/miniconda3/envs/ms1.8/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: FixedFormatter should only be used together with FixedLocator\n", 203 | " after removing the cwd from sys.path.\n", 204 | "/home/lvyufeng/miniconda3/envs/ms1.8/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: FixedFormatter should only be used together with FixedLocator\n", 205 | " \"\"\"\n" 206 | ] 207 | } 208 | ], 209 | "data": { 210 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZsAAADjCAYAAABTnrngAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfO0lEQVR4nO3de7hcVZnn8e+PQACNmR4ICARCREQJYmIIAXRAkEGxoWkEcZ6RANqtAVvkKjpkuqFHhIhcQmjQGORmNwPTPIiI3ZrGNAEZQ0hQGhkSrpIQAiQhaBIDIcA7f6x1TJ3KOTn7XKp21dm/z/PUU1V77ctbTyXnrXXZaykiMDMza6Qtyg7AzMwGPycbMzNrOCcbMzNrOCcbMzNrOCcbMzNrOCcbMzNrOCcbMzNrOCcbMzNrOCcbMzNruC3LDsDMNiXpbaDQ9B4RMaTB4Zj1m5ONWWv6LBuTzbuBbwJ3AnPztoOAY4ELmx6ZWR/Ic6OZtTZJPwHujojr6rZ/CTg2Io4qJzKz4pxszFqcpLXAuIh4um77nsB/RMQ7y4nMrDgPEDBrfSuBz3Sx/TPAiibHYtYn7rMxa30XADdKOoyNfTYHAv8V+OvSojLrBTejmbUBSROBM4G986aFwNURMa+8qMyKc7Ixa2GStgL+CZgSEc+UHY9ZX7nPxqyFRcQG4BMUvOfGrFU52Zi1vh8Bx5UdhFl/eICAWetbAvytpIOBBcAfawsj4spSojLrBffZmLU4Sb/bTHFExB5NC8asj5xszMys4dxnY9ZGJA2T5BkDrO042Zi1AUlfkbQE+AOwWtJiSX9TdlxmRXmAgFmLkzQFOB+4HHggbz4Y+Lak4RHx7dKCMyvIfTZmLS7XaL4REbfWbT8RuCQidi8nMrPiXLMxa307AvO72P4Qaa0bayF59GDRhe8qM5LQycas9T0JfI60gFqtzwFPND8c68E1Na+HAeeQfhjULnw3EbiiyXGVys1oZi1O0nHAPwNzgP+bN38U+BhwQkT8uJzIrCeSbgKejIhL6rafD+wTEZNKCawETjZmbUDSfsDZdJ71+YqI+E15UVlPJK0Gxnez8N2vI2J4OZE1n5vRzNpARDwMVOZX8CDyR+BQ4Om67YcC65odTJmcbAYpSfdSvJPy4w0Ox/pB0kzgXmBORLxYdjzWK9OAayVNAB7M2w4ETgH+vqygyuBkM3g9VvN6CHAi8BLQsdjWRGBn0lop1treAVwKjJT0DKnvZg4p+SwrMS7rQUR8R9JzpIXvPps3LwROiYh/Li2wErjPpgIkTSMlnDOj5guXdBXp38CZZcVmxeV2/o+RmmAOAXYFno6I95cZl3UtL3x3MXBtRCwuO56yOdlUgKRXgIMi4sm67XsBD0bEduVEZr0haQtgf+DjwGGkpPNCRLynzLise5LWAh+MiOfKjqVsnhutGgTs28X2rrZZi5H0dUn/CvweuBXYC7gFeJ8TTcubRfpxUHnus6mGG4AfSHofnTspvw7cWFpUVtS3gRXARcBNEbGi5HisuNnAJZI+BDzMpgvf/aiUqErgZrQKyM0vXyN1Uu6cN78ITCfdq/FWWbFZzyQdTmoyOxSYQBpGO4c0Qu2+iHilrNhs8yS9vZniiIghTQumZE42g5ykLYHJwI8jYpmk4QARsbrcyKwvJG0LfIQ0uvBEYIuI2KrcqMx65mRTAZL+CIzxiJj2JWlHNg4KOIzUb/MSqWbz30sMzawQDxCohgeB/coOwvpG0kJSs+c04M/y894RsYsTTeuTdJSk+yWtlLRC0n2S/rzsuJrNAwSq4Trgckmj6LqT8telRGVFXUW6gdMzPLcZSV8EvksaPXhz3nwwcKekL0fEDaUF12RuRqsAd1JWQ570cVxEPFt2LJZIegqYHhHX1G3/KvDViNirnMiazzWbavC9GNWgsgOwTYwCft7F9p+RlvmuDCebCvDAALPSLAGOYNNZnz8BVOr/pZNNReSbyr4GjCHNBv04cFlEPLbZA82
sPy4H/kHSeOBXedtHgZOAr5YWVQmcbCpA0jHAj4BfkqrvAP8F+I2k4yLi7tKCMxvEIuL7kpYD5wLH5c0Lgc9GxF3lRdZ8HiBQAZIeBe6MiAvrtn8T+MuIGFtOZDaQPEDAWpnvs6mGvYB/7GL7PwKenn7w8ACBFiNpiqSD8kweleZkUw3L6fqmzv2Al5scizXOp4AXyg7COvkUaQ67VyX9W04+H6li8qncB66o64Dv58W3ajspvwZcVlpU1i1JhW/2i4i/ys8PNC4i64uIODjPZ/dR0sJ3nwL+DnhT0q8i4pOlBthE7rOpAEkCziJ1Uu6SNy8jJZqrw/8IWo6k+kEbhwBvA7/N7z9Iapm4PyKOaWZs1jeS3k1a2+Yo0hLRb0bEO8qNqnmcbCpG0rsAImJN2bFYMZLOBz4MfCEi/pi3vRO4HvhtRFxcZnzWPUmfZePkqaOAecB9pCUiHoyI9aUF12RONhUg6XPAvRHxYtmxWO9JehE4PCIer9u+DzA7InYqJzLrSZ4qagXpfptrI2JdySGVxgMEquHbwFJJT0qaKelzknbp8ShrFcPY2PxZa2egMs0wbWoy8G+kGziXSbpb0rmSxufm7cpwzaYi8uCAQ0mdlB8DRgLPkGo8p5YYmvVA0k3A4cB5dF7W+1LS9/f5ciKz3pD0XtL/wSOATwNrI2L7UoNqIiebipE0BJgIfAmYBAzxrM+tLY9mugL4K6BjVc43SX02X6ty00w7yMuy709KNB8njUwbCjwcEQeVGFpTOdlUgKSJbOyk/Ciwko2dlHM8UWd7yIMC3pvfPtMxWMBal6SfkZbx3pa0ltSc/Higat+fk00F1HVS/p+IWFJySGaVIGkqFU0u9ZxsKkDSt0j9NPuTpjq/l421mldKDM0KkLQNcCap32ZH6gb2RMSHyojLBo6k3wJ/HhHPlx1LozjZVEhu+/8IGwcKTASe8EScrS3PJvBp4HbSzbid/tNGxP8qIy4bOJLWAGMH8ySqnq6mWoYDI0i/jncidVKOKDUiK+JY4ISI+EXZgZj1le+zqQBJ35P0OOlX8TRS0rkC2DsiRpYanBWxDhi0zStWDa7ZVMOfAdNJfTRPlByL9d53gHMkneZ57Kxduc/G/kTSvwBf9LQ2rSVPynkw8AfSct4bass9EWf7c5+NVc0hpPsBrLWsBO4sOwiz/nCyMWtxEfGFsmOwhjuVQb6QoZONWZuQtAcwhjT0eeFgbnJpZ5IuKLpvRHwzP//vxkXUGtxnY39ShXbjdiRpOGketONJC6gBCLgD+GuvTdRa8g2atXYnzc69LL/fhTTC8Lkq3ZDroc9mrW868CHS3Hbb5sfhedtV5YVlXYmIfTsewJWkOdH2iIhRETEK2AOYT8W+O9ds7E9cs2lNkl4Bjo2IX9ZtPwS4s0rT1LcbSb8jfXf/Ubd9HHBXROxeSmAlcM2mAiQdImmT/jlJW+Y/WB0uAVY1LzIraFugqznsVgHbNDkW65130/UIz22o2OwdrtlUgKS3gJ0jYnnd9u2B5V7PprVJugdYDZzUsXZNXm7gh8DwiDiizPise5LuIjWbfYnUdBakOQm/D/wuIo4tL7rm8mi0ahB1kzdm2wOVnva8TZwNzAJekPRo3rYvqZP5k6VFZUV8EbgZ+BXwVt62Ben7/FJZQZXBNZtBTNJP8sujgF8A62uKhwAfJA2hPbLZsVnvSHoHcCLwgbxpIXBLRLxWXlS2OXmFzg8AS4Cdgb1z0aKIeLK0wErims3g1tHOL+BVoPYP0xvAA8B1zQ7Kei83n/m7ai8BPAKMiYingKfKDadcTjaDWMed55KeAy6v+kqB7UrSxcDzETGjbvtpwMiI+LtyIrPNiYiQ9ASwA2nRwkrzaLRquIiaWo2knSR9UdJHSozJijsJ+E0X2x8GTm5yLNY7XwculzROksoOpkzus6kAST8Dfh4R0yUNAxYB7wSGke5A/2GpAdpmSXqd1BTzbN32PYDHI8LDn1tUvndtG9I
P+zfp3G9KRAwvI64yuBmtGiaQfmEBHEcaRvseUofz10hDaK11LSEtMVB/s+0hwNLmh2O9cHrZAbQKJ5tqGAb8Pr/+BOmu8w2S/h24trSorKjvA9MkDQX+PW87HJgKXFpaVNajiLi57BhahZNNNSwBPpoX4fokcELevh3pXg1rYRFxhaQRwNXA0Lz5DWB6RHynvMisK5K2i4hVHa83t2/HflXgPpsKkHQqcA2wFlgMjI+ItyWdQZq36eOlBmiF5FkDxuS3CyNibZnxWNdqZ+yQ9DZd31At0oC1ysze4ZpNBUTE9yUtAEYB90RExzT1zwAeNts+tiV1ND8SEet72tlK83E2zjF4WJmBtBLXbAY5SVuRbt48OSKeKDse6z1J7wJuIK1nE8D7IuJZSTOAlyLi78uMz6wI32czyEXEBtLIM/+qaF+XkhbcGk/nWSB+Cny6lIisVyTtku+1GV/7KDuuZnIzWjXcTJr077yyA7E+OQb4dEQ8Iqn2R8NC0ozC1qIkfRj4J9IcafU3dQZpjsJKcLKphncCJ0o6gnTXeadpayLijFKisqL+M12vZ/MuNs4kbK1pJvA86cfeMircwuBkUw17A7/Or+t/CVf2H38bmU+q3VyV33d8Z6eSpq631jUG+HAVZ3mu52RTARHhETHtbQowS9I+pP+z5+TXB5BmFrDW9VtgJ6Dyycaj0czagKQPkvrc9iMN7HkY+E5E/LbUwGwTdTdyjiMtt/63pMSzoXZf39RpbS8vnDYpIlbnmQO6/aIj4pjmRWa9JWkM8FbH0HVJnyDN9vz/SAnH/TYtpIsbOTsGBtRv802dNih8kI3/uFeWGYj12w2k/ponJO0G3AncB3wFGA6cX15o1oXaZuvRpAEC9T8ItiDdZF0ZrtkMUvnX1U55yoxngf0joqsRTdbiJP0emBgRT0o6GzgmIg6TdBhwY0SMLjVA61bt1DV127cHllepZuObOgevVaSbOSH9uvJ33b6GkCbehDTb87/m188A7y4lIitKdN2EPQx4vcmxlMrNaIPXHcB9kl4k/WNfkH9lbSIifGNga3sM+LKkn5KSTUez2UjcRNqSJF2dXwYwVVLt7OpDgInAI82Oq0xONoPXacBPgPcBVwI3AmtKjcj66hvAj0kL3d1cMwLtGOChsoKyzdo3P4t0n9sbNWVvkO57u7zZQZXJfTYVIOlG4IyIcLJpU5KGAMMj4tWabaOBdfX9AdY68v+9MyNiddmxlM3JxszMGs6dxmZm1nBONmZm1nBONhUkaXLZMVjf+ftrb1X9/pxsqqmS/9gHEX9/7a2S35+TjZmZNZxHo3VjxHZDYvRuW5UdRkOseOUtdth+cM+S8eTT2/W8U5va8OY6ttryHWWH0TjrBveN9RtYz1ZsXXYYDfE6f+SNWF+/Iingmzq7NXq3rXho1m5lh2F9dOQxk8oOwfooFjxWdgjWR/NidrdlbkYzM7OGc7IxM7OGc7IxM7OGc7IxM7OG6zHZSJoj6ZpmBFNzzdGSQtKEZl7XzMwao+E1G0mH5sQxotHXqrvudEkLJL0u6blmXtvMzDobzM1oWwA3Az8sOxAzs6ormmy2zDWFV/PjMklbAEiaJGm+pDWSlku6XdLIXDYauDefY0Wu4dyUyyTpXElPSVovaamkqXXX3V3SPZLWSXpc0hFFP1hEfDUi/gF4sugxZmbWGEWTzYl534OAU0lz+5yVy4YCFwJjgaOBEcCtuex54Pj8eh9gZ+DM/P4S4O+AqbnshLx/rYuBq/O55wO3SRpWMGYzM2sRRWcQeJG00mMAiyTtBZwDXBkRN9Ts96ykLwMLJe0aEUslrcplyyNiJUBOGGcDZ9Uc/zQwt+660yLi7nzMFOBkYBzwQK8+ZUF5NtbJAKNGenIFM7OBUrRm82B0nkRtLjBS0nBJ4yXdJWmxpDXAgrzPqM2cbwywNdD93AbJozWvl+XnHQvG3GsRMTMiJkTEhME+d5iZWTP1d4CAgFnAOuAkYH/
gyFw2tJ/nBtjQ8aIm2Q3mQQ1mZoNS0T/cB0iqncnzQFJNY09SH82UiLg/Ihaxac3jjfxcW1VYCKwHDu99yGZm1m6KJptdgKskvV/SZ4DzgGnAElLSOF3SHpKOAi6qO3YxEMBRknaQNCwi1gDTgamSviDpvZIm5v6eASFpT0njcuxDJY3Lj4GocZmZWS8U7QW/hVQzmUdKHNeTOu/fknQKaWTZV0h9LOcAP+84MCJekHQhaWTZD0j3vXweOB94lTQibVfgZQb2npgfAB+ref+b/Pwe4LkBvI6ZmfXAi6d1Y8LYbcLr2bQvr2fTvryeTfuaF7NZHau6XDzNne1mZtZwbZlsJM2QtLabx4yy4zMzs87a9c7FC4DLuylb3cxAzMysZ22ZbCJiObC87DjMzKyYtmxGMzOz9uJkY2ZmDedkY2ZmDedkY2ZmDedkY2ZmDedkY2ZmDedkY2ZmDddjspE0R9I1zQim5pqjJYWkCc28rpmZNUbDazaSDs2JY0Sjr1VzzbGSbpX0vKTXJD0h6euSXJMzMytBW84gUMB+wArS6qFLgInAdaTPe0mJcZmZVVLRX/pbSpou6dX8uKyjliBpkqT5ktZIWi7pdkkjc9lo4N58jhW5hnNTLpOkcyU9JWm9pKWSptZdd3dJ90haJ+lxSUcUCTYiboiIMyJiTkQ8GxG3Ad8Dji/4ec3MbAAVTTYn5n0PAk4FJgNn5bKhwIXAWOBo0jLRt+ay59n4B34fYGfgzPz+EtLCaVNz2Ql5/1oXA1fnc88HbpM0rGDM9YaTFmvrlqTJkhZIWrDilbf6eBkzM6tXtBntReCMSCutLZK0F2lFzisj4oaa/Z7NSzsvlLRrRCyVtCqXLY+IlQA5YZwNnFVz/NPA3LrrTouIu/MxU4CTgXHAA735kJLGk1YHPXFz+0XETGAmpMXTenMNMzPrXtGazYPReUnPucBIScMljZd0l6TFktYAC/I+ozZzvjHA1sDsHq77aM3rZfl5x4IxAyDp/cC/AFdFxB29OdbMzAZGf0dnCZgFrCN1xu8PHJnLhvbz3AAbOl7UJLvCMUv6ADAHuC0i/scAxGNmZn1Q9A/3AZJq15U+kFTT2JPURzMlIu6PiEVsWvN4Iz8Pqdm2EFgPHN77kIuRNIaUaG6PiLMbdR0zM+tZ0WSzC3CVpPdL+gxwHjCNNKx4PXC6pD0kHQVcVHfsYiCAoyTtIGlYRKwBpgNTJX1B0nslTcz9Pf0maR/SKLg5wCWSdup4DMT5zcysd4omm1tINZN5pPtVrid13q8ATgGOBR4njUo7p/bAiHghb78YeBnomI3gfOBS0oi0hcAdwK59/yidnECqYf030uCG2oeZmTWZOvf7W4cJY7eJh2btVnYY1kdHHjOp7BCsj2LBY2WHYH00L2azOlapqzJP32JmZg3XlslG0gxJa7t5zCg7PjMz66xd50a7ALi8m7LVzQzEzMx61pbJJiKWA8vLjsPMzIppy2Y0MzNrL042ZmbWcE42ZmbWcE42ZmbWcE42ZmbWcE42ZmbWcD0mG0lzJF3T034DSdLovIT0hGZe18zMGqPhNRtJh+bEMaLR16q55g6SZklaJmm9pOclXSvpPzUrBjMz22iwNqO9DdwJ/AWwF2lJ6MNJM1abmVmTFU02W0qaLunV/LhM0hYAkiZJmi9pjaTlkm6XNDKXjSatKwOwItdwbsplknSupKdy7WOppKl1191d0j2S1kl6XNIRRYKNiFciYkZEPBwRiyNiNvBd4OCCn9fMzAZQ0WRzYt73IOBUYDJwVi4bSlqvZixwNGnlzltz2fPA8fn1PsDOwJn5/SWktWym5rIT8v61LgauzueeD9wmaVjBmP9E0i7AccB9vT3WzMz6r+jcaC8CZ0Ra/GaRpL1Ii6RdGRE31Oz3bF5tc6GkXSNiqaRVuWx5RKwEyAnjbOCsmuOfBubWXXdaRNydj5kCnAyMAx4oErSkW4G/BLYFfgp8oYf9J5MSKaNGtuW
0cWZmLalozebB6LzK2lxgpKThksZLukvSYklrgAV5n1GbOd8YYGtgdg/XfbTm9bL8vGPBmCEltPGkhLMHcNXmdo6ImRExISIm7LD9kF5cxszMNqe/P98FzAJ+AZxEmol5BPBLUvNaf23oeBERIQl6MaghIl4CXiLVxlYBv5T0rYiob64zM7MGKvqH+wDlv/TZgaSaxp6k5DIlIu6PiEVsWvN4Iz/XVhUWAutJI8SapeOzbt3Ea5qZGcVrNrsAV0n6LrAvcB7wLWAJKWmcLulaYG/gorpjFwMBHCXpbuC1iFgjaTowVdJ64H5ge2C/iPhefz+UpKPz+R4G1pIGIFxGag58ur/nNzOz3imabG4h1UzmkRLH9aTO+7cknUIaWfYVUh/LOcDPOw6MiBckXUgaWfYD4Iek+17OB14ljUjbFXg5lw2E14HTSMlva9IotzuBbw/Q+c3MrBfUud/fOkwYu008NGu3ssOwPjrymEllh2B9FAseKzsE66N5MZvVsUpdlQ3WGQTMzKyFtGWykTRD0tpuHjPKjs/MzDpr1zsXLwAu76ZsdTMDMTOznrVlsomI5aR7eszMrA20ZTOamZm1FycbMzNrOCcbMzNrOCcbMzNrOCcbMzNrOCcbMzNruB6TjaQ5kq5pRjA11xydl5Ce0MzrmplZYzS8ZiPp0Jw4RjT6Wt1cf4SkF8qMwcys6qrQjHYj8EjZQZiZVVnRZLOlpOmSXs2PyyRtASBpkqT5ktZIWi7pdkkjc9lo4N58jhW5dnFTLpOkcyU9JWm9pKWSptZdd3dJ90haJ+lxSUf05sNJOhN4B3BFb44zM7OBVTTZnJj3PQg4FZgMnJXLhgIXAmOBo0krd96ay54Hjs+v9wF2Bs7M7y8hrWUzNZedkPevdTFwdT73fOA2ScOKBCzpw8A3gJOBtwt9SjMza4iic6O9CJwRafGbRZL2Ii2SdmVE3FCz37OSvgwslLRrRCyVtCqXLY+IlQA5YZwNnFVz/NPA3LrrTouIu/MxU0iJYxzwwOaClfRO4Dbgq3nxtvcV+ZCSJpMSKaNGtuW0cWZmLalozebB6LzK2lxgpKThksZLukvSYklrgAV5n1GbOd8Y0gqas3u47qM1r5fl5x0LxHs18EBE3FFg3z+JiJkRMSEiJuyw/ZDeHGpmZpvR3wECAmYB64CTgP2BI3PZ0H6eG2BDx4uaZFck5sOBz0t6U9KbbExqL0m6eADiMjOzXijaVnSAJNX8wT+QVNPYk9RHMyUifgcg6bi6Y9/Iz7VVhYXAelJSeKovgffgE3ROdvsDNwCHNuh6Zma2GUWTzS7AVZK+C+wLnAd8C1hCShqnS7oW2Bu4qO7YxUAAR0m6G3gtItZImg5MlbQeuB/YHtgvIr7X3w8VEU/Wvq+5v2ZRR7+RmZk1T9FmtFtINZN5wHXA9aTO+xXAKcCxwOOkUWnn1B4YES/k7RcDLwMdsxGcD1xKGpG2ELgD2LXvH8XMzFqVOvf7W4cJY7eJh2btVnYY1kdHHjOp7BCsj2LBY2WHYH00L2azOlapq7IqzCBgZmYla8tkI2mGpLXdPGaUHZ+ZmXXWrncuXgBc3k3Z6mYGYmZmPWvLZBMRy4HlZcdhZmbFtGUzmpmZtRcnGzMzazgnGzMzazgnGzMzazgnGzMzazgnGzMzazgnGzMza7gek42kOZKu6Wm/gSRptKSQNKGZ1zUzs8ZoeM1G0qE5cYzoee8BvW508TitmTGYmVnSljMI9MKXgJ/WvP9DWYGYmVVZ0ZrNlpKmS3o1Py6TtAWApEmS5ktaI2m5pNsljcxlo4F78zlW5NrFTblMks6V9JSk9ZKWSppad93dJd0jaZ2kxyUd0cvP9/uIeKnm8VovjzczswFQNNmcmPc9CDgVmAyclcuGkhZHGwscTVom+tZc9jxwfH69D7AzcGZ+fwlp4bSpueyEvH+ti4Gr87nnA7dJGlYwZoDpklbmZHhaR4LsjqTJkhZIWrDilbd6cRk
zM9ucos1oLwJnRFppbZGkvUgrcl4ZETfU7PespC8DCyXtGhFLJa3KZcs7lmTOCeNs4Kya458G5tZdd1pE3J2PmQKcDIwDHigQ8wWkWtVa4HDgClIi/FZ3B0TETGAmpMXTClzDzMwKKJpsHozOS3rOBS6SNBzYk1SzGQdsB3Ss0jYKWNrN+cYAWwOze7juozWvl+XnHYsEHBEX1bx9RNIQ4H+ymWRjZmaN0d/RaAJmAeuAk4D9gSNz2dB+nhtgQ8eLmmTX15jnAcMlvbvfUZmZWa8U/cN9gKTadaUPJNU09iQ1TU2JiPsjYhGb1jzeyM9DarYtBNaTmreaZRzwOvD7Jl7TzMwo3oy2C3CVpO8C+wLnkZqjlpCSxumSrgX2Bi6qO3YxEMBRku4GXouINZKmA1MlrQfuB7YH9ouI7/X3Q0n6C2AnUnPfa8BhwDeBmRGxvr/nNzOz3imabG4h1UzmkRLH9aTO+7cknUIaWfYVUh/LOcDPOw6MiBckXUgaWfYD4IfA54HzgVdJI9J2BV7OZQNhA/A3wJWk2tuzpAED1w7Q+c3MrBfUud/fOkwYu008NGu3ssOwPjrymEllh2B9FAseKzsE66N5MZvVsUpdlXkiTjMza7i2TDaSZkha281jRtnxmZlZZ+06N9oFwOXdlK1uZiBmZtYz99l0Q9IK0ki6wWgEsLLsIKzP/P21t8H8/e0eETt0VeBkU0GSFkSE1wpqU/7+2ltVv7+27LMxM7P24mRjZmYN52RTTTPLDsD6xd9fe6vk9+c+GzMzazjXbMzMrOGcbMzMrOGcbMzMrOGcbMzMrOGcbMzMrOH+P9ckWoaiGI1RAAAAAElFTkSuQmCC", 211 | "text/plain": [ 212 | "
" 213 | ] 214 | }, 215 | "metadata": { 216 | "needs_background": "light" 217 | }, 218 | "output_type": "display_data" 219 | } 220 | ], 221 | "source": [ 222 | "fig = plt.figure(figsize=(6, 3))  # attention : [batch_size, n_step]\n", 223 | "ax = fig.add_subplot(1, 1, 1)\n", 224 | "ax.matshow(attention.asnumpy(), cmap='viridis')\n", 225 | "ax.set_xticks(range(attention.shape[1]), labels=test_text.split(), fontsize=14, rotation=90)  # one fixed tick per input word (avoids the FixedFormatter warning and the '' padding hack)\n", 226 | "ax.set_yticks(range(attention.shape[0]), labels=[f'batch_{i + 1}' for i in range(attention.shape[0])], fontsize=14)  # one row per test sentence\n", 227 | "plt.show()" 228 | ] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3.7.13 ('ms1.8')", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.7.13" 248 | }, 249 | "vscode": { 250 | "interpreter": { 251 | "hash": "bd0943702584cdb580f8947884f31a9fb49482f77f8c89ed6532de3aa180e7ba" 252 | } 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 5 257 | } 258 | --------------------------------------------------------------------------------