├── .docs ├── attention_heads.png ├── bert_embeddings.png ├── bert_fine_tune_pre_train.png ├── encoder.png ├── gelu.png ├── graph.png ├── model-pyTorch_BERT_model.svg ├── model.puml ├── res │ ├── mlm_loss.png │ ├── mlm_train_accuracy.png │ ├── nsp_loss.png │ ├── nsp_loss_smoothed.png │ └── nsp_train_accuracy.png └── the_process.png ├── .gitignore ├── .wheels └── torchtext-0.12.0a0+2683aa8-cp39-cp39-macosx_12_0_universal2.whl ├── README.md ├── bert ├── __init__.py ├── dataset.py ├── model.py └── trainer.py ├── graph.py ├── main.py └── requirements.txt /.docs/attention_heads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/attention_heads.png -------------------------------------------------------------------------------- /.docs/bert_embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/bert_embeddings.png -------------------------------------------------------------------------------- /.docs/bert_fine_tune_pre_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/bert_fine_tune_pre_train.png -------------------------------------------------------------------------------- /.docs/encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/encoder.png -------------------------------------------------------------------------------- /.docs/gelu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/gelu.png -------------------------------------------------------------------------------- /.docs/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/graph.png -------------------------------------------------------------------------------- /.docs/model-pyTorch_BERT_model.svg: -------------------------------------------------------------------------------- 1 | pyTorch BERT modelBERTThe main class of our model.It unites all the sub-modelsJointEmbeddingContainer for embeddings.The first layer in the modelAttentionHeadAttention modelMultiHeadAttentionContainer for attention headsEncoderPass embeddings through attentionand feed-forward neural network -------------------------------------------------------------------------------- /.docs/model.puml: -------------------------------------------------------------------------------- 1 | @startuml 2 | 3 | !theme cerulean-outline 4 | 5 | skinparam backgroundColor white 6 | 7 | title "pyTorch BERT model" 8 | 9 | class BERT { 10 | The main class of our model. 11 | It unites all the sub-models 12 | .. 13 | } 14 | 15 | class JointEmbedding { 16 | Container for embeddings. 17 | The first layer in the model 18 | .. 19 | } 20 | 21 | class AttentionHead { 22 | Attention model 23 | .. 24 | } 25 | 26 | class MultiHeadAttention { 27 | Container for attention heads 28 | .. 
29 | } 30 | 31 | class Encoder { 32 | Pass embeddings through attention 33 | and feed-forward neural network 34 | .. 35 | } 36 | 37 | JointEmbedding -* BERT 38 | Encoder -left-* BERT 39 | AttentionHead -left-* MultiHeadAttention 40 | MultiHeadAttention --* Encoder 41 | 42 | @enduml -------------------------------------------------------------------------------- /.docs/res/mlm_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/res/mlm_loss.png -------------------------------------------------------------------------------- /.docs/res/mlm_train_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/res/mlm_train_accuracy.png -------------------------------------------------------------------------------- /.docs/res/nsp_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/res/nsp_loss.png -------------------------------------------------------------------------------- /.docs/res/nsp_loss_smoothed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/res/nsp_loss_smoothed.png -------------------------------------------------------------------------------- /.docs/res/nsp_train_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/res/nsp_train_accuracy.png -------------------------------------------------------------------------------- /.docs/the_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.docs/the_process.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all 4 | 5 | ### PyCharm+all ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # AWS User-specific 17 | .idea/**/aws.xml 18 | 19 | # Generated files 20 | .idea/**/contentModel.xml 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. 
Uncomment if using 38 | # auto-import. 39 | # .idea/artifacts 40 | # .idea/compiler.xml 41 | # .idea/jarRepositories.xml 42 | # .idea/modules.xml 43 | # .idea/*.iml 44 | # .idea/modules 45 | # *.iml 46 | # *.ipr 47 | 48 | # CMake 49 | cmake-build-*/ 50 | 51 | # Mongo Explorer plugin 52 | .idea/**/mongoSettings.xml 53 | 54 | # File-based project format 55 | *.iws 56 | 57 | # IntelliJ 58 | out/ 59 | 60 | # mpeltonen/sbt-idea plugin 61 | .idea_modules/ 62 | 63 | # JIRA plugin 64 | atlassian-ide-plugin.xml 65 | 66 | # Cursive Clojure plugin 67 | .idea/replstate.xml 68 | 69 | # SonarLint plugin 70 | .idea/sonarlint/ 71 | 72 | # Crashlytics plugin (for Android Studio and IntelliJ) 73 | com_crashlytics_export_strings.xml 74 | crashlytics.properties 75 | crashlytics-build.properties 76 | fabric.properties 77 | 78 | # Editor-based Rest Client 79 | .idea/httpRequests 80 | 81 | # Android studio 3.1+ serialized cache file 82 | .idea/caches/build_file_checksums.ser 83 | 84 | ### PyCharm+all Patch ### 85 | # Ignores the whole .idea folder and all .iml files 86 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 87 | 88 | .idea/* 89 | 90 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 91 | 92 | *.iml 93 | modules.xml 94 | .idea/misc.xml 95 | *.ipr 96 | 97 | # Sonarlint plugin 98 | .idea/sonarlint 99 | 100 | ### Python ### 101 | # Byte-compiled / optimized / DLL files 102 | __pycache__/ 103 | *.py[cod] 104 | *$py.class 105 | 106 | # C extensions 107 | *.so 108 | 109 | # Distribution / packaging 110 | .Python 111 | build/ 112 | develop-eggs/ 113 | dist/ 114 | downloads/ 115 | eggs/ 116 | .eggs/ 117 | lib/ 118 | lib64/ 119 | parts/ 120 | sdist/ 121 | var/ 122 | wheels/ 123 | share/python-wheels/ 124 | *.egg-info/ 125 | .installed.cfg 126 | *.egg 127 | MANIFEST 128 | 129 | # PyInstaller 130 | # Usually these files are written by a python script from a template 131 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 132 | *.manifest 133 | *.spec 134 | 135 | # Installer logs 136 | pip-log.txt 137 | pip-delete-this-directory.txt 138 | 139 | # Unit test / coverage reports 140 | htmlcov/ 141 | .tox/ 142 | .nox/ 143 | .coverage 144 | .coverage.* 145 | .cache 146 | nosetests.xml 147 | coverage.xml 148 | *.cover 149 | *.py,cover 150 | .hypothesis/ 151 | .pytest_cache/ 152 | cover/ 153 | 154 | # Translations 155 | *.mo 156 | *.pot 157 | 158 | # Django stuff: 159 | *.log 160 | local_settings.py 161 | db.sqlite3 162 | db.sqlite3-journal 163 | 164 | # Flask stuff: 165 | instance/ 166 | .webassets-cache 167 | 168 | # Scrapy stuff: 169 | .scrapy 170 | 171 | # Sphinx documentation 172 | docs/_build/ 173 | 174 | # PyBuilder 175 | .pybuilder/ 176 | target/ 177 | 178 | # Jupyter Notebook 179 | .ipynb_checkpoints 180 | 181 | # IPython 182 | profile_default/ 183 | ipython_config.py 184 | 185 | # pyenv 186 | # For a library or package, you might want to ignore these files since the code is 187 | # intended to run in multiple environments; otherwise, check them in: 188 | # .python-version 189 | 190 | # pipenv 191 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 192 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 193 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 194 | # install all needed dependencies. 
195 | #Pipfile.lock 196 | 197 | # poetry 198 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 199 | # This is especially recommended for binary packages to ensure reproducibility, and is more 200 | # commonly ignored for libraries. 201 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 202 | #poetry.lock 203 | 204 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 205 | __pypackages__/ 206 | 207 | # Celery stuff 208 | celerybeat-schedule 209 | celerybeat.pid 210 | 211 | # SageMath parsed files 212 | *.sage.py 213 | 214 | # Environments 215 | .env 216 | .venv 217 | env/ 218 | venv/ 219 | ENV/ 220 | env.bak/ 221 | venv.bak/ 222 | 223 | # Spyder project settings 224 | .spyderproject 225 | .spyproject 226 | 227 | # Rope project settings 228 | .ropeproject 229 | 230 | # mkdocs documentation 231 | /site 232 | 233 | # mypy 234 | .mypy_cache/ 235 | .dmypy.json 236 | dmypy.json 237 | 238 | # Pyre type checker 239 | .pyre/ 240 | 241 | # pytype static type analyzer 242 | .pytype/ 243 | 244 | # Cython debug symbols 245 | cython_debug/ 246 | 247 | # PyCharm 248 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 249 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 250 | # and can be added to the global gitignore or merged into this file. For a more nuclear 251 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 252 | #.idea/ 253 | 254 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all -------------------------------------------------------------------------------- /.wheels/torchtext-0.12.0a0+2683aa8-cp39-cp39-macosx_12_0_universal2.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coaxsoft/pytorch_bert/f4b8fba2c7e0fe77986c65d1ef6fe8f8b08cc313/.wheels/torchtext-0.12.0a0+2683aa8-cp39-cp39-macosx_12_0_universal2.whl -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Building BERT with PyTorch from scratch 2 | 3 | ![img](https://uploads-ssl.webflow.com/60100d26d33c7cce48258afd/6244769a9ec65d641e367414_BERT%20with%20PyTorch.png) 4 | 5 | This is the repository containing the code for a tutorial 6 | 7 | [Building BERT with PyTorch from scratch](https://coaxsoft.com/blog/building-bert-with-pytorch-from-scratch) 8 | 9 | ## Installation 10 | 11 | After you clone the repository and setup virtual environment, 12 | install dependencies 13 | 14 | ```shell 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Installation on Mac M1 19 | 20 | You may experience difficulties installing `tensorboard`. 21 | Tensorboard requires `grpcio` that should be installed with extra environment 22 | variables. Read more in [StackOverflow](https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop). 
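This repository also ships a prebuilt `torchtext` wheel under `.wheels/` (a CPython 3.9 `universal2` build), which may help if `torchtext` itself refuses to build on Apple Silicon (a hedged suggestion; the wheel's purpose is not documented in the README):

```shell
pip install .wheels/torchtext-0.12.0a0+2683aa8-cp39-cp39-macosx_12_0_universal2.whl
```

For `tensorboard`'s `grpcio` dependency, the extra environment variables are still needed.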
23 | 24 | So, your installation line for Mac M1 should look like 25 | 26 | ```shell 27 | export GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 28 | export GRPC_PYTHON_BUILD_SYSTEM_ZLIB=1 29 | 30 | pip install -r requirements.txt 31 | ``` 32 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | 4 | from random import choice, sample 5 | 6 | import pandas as pd 7 | import torch 8 | import torch.nn.functional as f 9 | 10 | from torch import nn 11 | from torch.utils.data import Dataset 12 | 13 | text = ( 14 | 'Hello, how are you? I am Romeo.\n' 15 | 'Hello, Romeo My name is Juliet. Nice to meet you.\n' 16 | 'Nice meet you too. How are you today?\n' 17 | 'Great. My baseball team won the competition.\n' 18 | 'Oh Congratulations, Juliet\n' 19 | 'Thanks you Romeo' 20 | ) 21 | SENTENCES = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!' 22 | word_list = list(set(" ".join(SENTENCES).split())) 23 | 24 | word_index = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3} 25 | index_word = {v: i for i, v in word_index.items()} 26 | 27 | 28 | def max_len(sents): 29 | m = 0 30 | for v in sents: 31 | v = v.split() 32 | m = len(v) if len(v) > m else m 33 | return m 34 | 35 | 36 | def pad_sentence(sentence, length): 37 | l = len(sentence) 38 | if l < length: 39 | sentence += ["[PAD]"] * (length - l) 40 | return sentence 41 | 42 | 43 | def mask_sentence(size, mask_indices): 44 | m = [False for _ in range(size)] 45 | for v in mask_indices: 46 | m[v] = True 47 | return m 48 | 49 | 50 | def preprocess_sentences(sentences): 51 | sent = [] 52 | mask = [] 53 | max_l = max_len(sentences) + 1 54 | 55 | for sentence in sentences: 56 | sentence = sentence.split() 57 | p = round(len(sentence) * 0.15) 58 | mask_indices = sample(range(len(sentence)), p) 59 | 60 | p = round(len(mask_indices) * 0.15) 61 | mask_wrong_indices = sample(mask_indices, p) 62 | 63 | for v in mask_indices: 64 | if v in mask_wrong_indices: 65 | symb = choice(word_list) 66 | else: 67 | symb = "[MASK]" 68 | sentence[v] = symb 69 | 70 | mask.append(mask_sentence(max_l, mask_indices)) 71 | 72 | s = ["[CLS]"] + sentence 73 | s = pad_sentence(s, max_l) 74 | 75 | sent.append(s) 76 | return sent, mask 77 | 78 | 79 | def form_ds(sentences): 80 | x, y = [], [] 81 | for i in range(len(sentences) - 1): 82 | x.append(sentences[i] + sentences[i + 1]) 83 | y.append(1) 84 | 85 | for i in range(len(sentences) - 1): 86 | new_sentences = [v for j, v in enumerate(sentences) if j != i + 1] 87 | s = choice(new_sentences) 88 | x.append(sentences[i] + s) 89 | y.append(0) 90 | return x, y 91 | 92 | 93 | def tokenize(sentences): 94 | res = [] 95 | for s in sentences: 96 | tokens = [word_index[v] for v in s] 97 | res.append(tokens) 98 | return res 99 | 100 | 101 | def get_attn_pad_mask(seq_q): 102 | return seq_q.data.eq(0) 103 | 104 | 105 | class JointEmbedding(nn.Module): 106 | SEGMENTS = 2 # 0 - first sentence, 1 - second sentence. 
2 is amount of segments 107 | 108 | def __init__(self, vocab_size, size): 109 | super(JointEmbedding, self).__init__() 110 | 111 | self.token_emb = nn.Embedding(vocab_size, size) 112 | self.segment_emb = nn.Embedding(vocab_size, size) 113 | self.position_emb = nn.Embedding(vocab_size, size) 114 | 115 | self.norm = nn.LayerNorm(size) 116 | 117 | def forward(self, input_tensor, segment_tensor): 118 | # TODO: apply sin - cos functions as stated in `Attention is all you need` 119 | pos_tensor = torch.arange(input_tensor.size(1), dtype=torch.long) 120 | pos_tensor = pos_tensor.expand_as(input_tensor) 121 | 122 | output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + self.position_emb(pos_tensor) 123 | return self.norm(output) 124 | 125 | 126 | class EncoderLayer(nn.Module): 127 | 128 | def __init__(self, dim, dim_ff, dropout=0.1, num_heads=4): 129 | super(EncoderLayer, self).__init__() 130 | 131 | dim_q = dim_k = max(dim // num_heads, 1) 132 | 133 | self.multi_head = MultiHeadAttention(num_heads, dim, dim_q, dim_k) 134 | self.position_feed_forward = nn.Sequential( 135 | nn.Linear(dim, dim_ff), 136 | nn.ReLU(), 137 | nn.Linear(dim_ff, dim), 138 | nn.Dropout(dropout) 139 | ) 140 | 141 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 142 | output = self.multi_head(input_tensor, attention_mask) 143 | return self.position_feed_forward(output) 144 | 145 | 146 | class MultiHeadAttention(nn.Module): 147 | 148 | def __init__(self, num_heads, dim_inp, dim_q, dim_k): 149 | super(MultiHeadAttention, self).__init__() 150 | 151 | self.heads = nn.ModuleList([ 152 | AttentionHead(dim_inp, dim_q, dim_k) for _ in range(num_heads) 153 | ]) 154 | self.linear = nn.Linear(dim_q * num_heads, dim_inp) 155 | 156 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 157 | scores = torch.cat([h(input_tensor, attention_mask) for h in self.heads], dim=-1) 158 | return self.linear(scores) 159 | 160 | 161 | class AttentionHead(nn.Module): 162 | 163 | def __init__(self, dim_inp, dim_q, dim_k): 164 | super(AttentionHead, self).__init__() 165 | 166 | self.dim_inp = dim_inp 167 | 168 | self.q = nn.Linear(dim_inp, dim_q) 169 | self.k = nn.Linear(dim_inp, dim_k) 170 | self.v = nn.Linear(dim_inp, dim_k) 171 | 172 | def forward(self, input_tensor: torch.Tensor, attention_mask): 173 | # input_tensor = input_tensor.squeeze(0) 174 | query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor) 175 | 176 | scale = query.size(1) ** 0.5 177 | scores = torch.bmm(query, key.transpose(1, 2)) / scale 178 | scores = scores.masked_fill_(attention_mask, -1e9) 179 | attn = f.softmax(scores, dim=-1) 180 | context = torch.bmm(attn, value) 181 | 182 | return context 183 | 184 | 185 | class BERT(nn.Module): 186 | 187 | def __init__(self, vocab_size, size, encoder_size, num_heads=4): 188 | super(BERT, self).__init__() 189 | 190 | self.embedding = JointEmbedding(vocab_size, size) 191 | self.encoder = EncoderLayer(size, encoder_size, num_heads=num_heads) 192 | 193 | self.token_prediction_layer = nn.Linear(size, 1) 194 | self.classification_layer = nn.Linear(size, 2) 195 | 196 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 197 | embedded = self.embedding(input_tensor, input_tensor) 198 | encoded_sources = self.encoder(embedded, attention_mask) 199 | 200 | classification_embedding = encoded_sources[:, 0, :] 201 | classification_output = self.classification_layer(classification_embedding) 202 | 203 | token_output = 
self.token_prediction_layer(encoded_sources) 204 | 205 | return classification_output, token_output 206 | 207 | 208 | if __name__ == '__main__': 209 | for i, w in enumerate(word_list): 210 | word_index[w] = i + 4 211 | index_word[word_index[w]] = w 212 | vocab_size = len(word_index) 213 | 214 | ds, mask = preprocess_sentences(SENTENCES) 215 | x_s, y = form_ds(ds) 216 | x = tokenize(x_s) 217 | 218 | inp = torch.tensor(x, dtype=torch.long) 219 | 220 | inp_mask = get_attn_pad_mask(inp) 221 | 222 | emb_size = 64 223 | encoder_size = 12 224 | 225 | bert = BERT(len(word_index), emb_size, encoder_size, num_heads=4).to('cpu') 226 | criterion = nn.CrossEntropyLoss() 227 | optimizer = torch.optim.Adam(bert.parameters(), lr=0.0001) 228 | optimizer.zero_grad() 229 | 230 | ds_size = inp.size(0) 231 | 232 | for i in range(20000): 233 | j = random.randint(0, ds_size - 1) 234 | inp_x, inp_y = inp[j, :].unsqueeze(0), torch.Tensor([y[j]]).long() 235 | inp_mask = get_attn_pad_mask(inp_x) 236 | class_out, token_out = bert(inp_x, inp_mask) 237 | 238 | loss = criterion(class_out, inp_y) 239 | loss.backward() 240 | optimizer.step() 241 | 242 | if i % 100 == 0: 243 | print(f"Epoch {i}. Loss {loss}") 244 | 245 | for j in range(len(x)): 246 | inp_x, inp_y = inp[j, :].unsqueeze(0), torch.Tensor([y[j]]).long() 247 | class_out, token_out = bert(inp_x, inp_mask) 248 | print(inp_y, class_out.argmax()) 249 | -------------------------------------------------------------------------------- /bert/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | from collections import Counter 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset 12 | from torchtext.vocab import vocab 13 | from torchtext.data.utils import get_tokenizer 14 | 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class IMDBBertDataset(Dataset): 20 | CLS = '[CLS]' 21 | PAD = '[PAD]' 22 | SEP = '[SEP]' 23 | MASK = '[MASK]' 24 | UNK = '[UNK]' 25 | 26 | MASK_PERCENTAGE = 0.15 27 | 28 | MASKED_INDICES_COLUMN = 'masked_indices' 29 | TARGET_COLUMN = 'indices' 30 | NSP_TARGET_COLUMN = 'is_next' 31 | TOKEN_MASK_COLUMN = 'token_mask' 32 | 33 | OPTIMAL_LENGTH_PERCENTILE = 70 34 | 35 | def __init__(self, path, ds_from=None, ds_to=None, should_include_text=False): 36 | self.ds: pd.Series = pd.read_csv(path)['review'] 37 | 38 | if ds_from is not None or ds_to is not None: 39 | self.ds = self.ds[ds_from:ds_to] 40 | 41 | self.tokenizer = get_tokenizer('basic_english') 42 | self.counter = Counter() 43 | self.vocab = None 44 | 45 | self.optimal_sentence_length = None 46 | self.should_include_text = should_include_text 47 | 48 | if should_include_text: 49 | self.columns = ['masked_sentence', self.MASKED_INDICES_COLUMN, 'sentence', self.TARGET_COLUMN, 50 | self.TOKEN_MASK_COLUMN, 51 | self.NSP_TARGET_COLUMN] 52 | else: 53 | self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.TOKEN_MASK_COLUMN, 54 | self.NSP_TARGET_COLUMN] 55 | 56 | self.df = self.prepare_dataset() 57 | 58 | def __len__(self): 59 | return len(self.df) 60 | 61 | def __getitem__(self, idx): 62 | item = self.df.iloc[idx] 63 | 64 | inp = torch.Tensor(item[self.MASKED_INDICES_COLUMN]).long() 65 | token_mask = torch.Tensor(item[self.TOKEN_MASK_COLUMN]).bool() 66 | 67 | mask_target = torch.Tensor(item[self.TARGET_COLUMN]).long() 68 | mask_target = 
mask_target.masked_fill_(token_mask, 0) 69 | 70 | attention_mask = (inp == self.vocab[self.PAD]).unsqueeze(0) 71 | 72 | if item[self.NSP_TARGET_COLUMN] == 0: 73 | t = [1, 0] 74 | else: 75 | t = [0, 1] 76 | 77 | nsp_target = torch.Tensor(t) 78 | 79 | return ( 80 | inp.to(device), 81 | attention_mask.to(device), 82 | token_mask.to(device), 83 | mask_target.to(device), 84 | nsp_target.to(device) 85 | ) 86 | 87 | def prepare_dataset(self) -> pd.DataFrame: 88 | sentences = [] 89 | nsp = [] 90 | sentence_lens = [] 91 | 92 | # Split dataset on sentences 93 | for review in self.ds: 94 | review_sentences = review.split('. ') 95 | sentences += review_sentences 96 | self._update_length(review_sentences, sentence_lens) 97 | self.optimal_sentence_length = self._find_optimal_sentence_length(sentence_lens) 98 | 99 | print("Create vocabulary") 100 | for sentence in tqdm(sentences): 101 | s = self.tokenizer(sentence) 102 | self.counter.update(s) 103 | 104 | self._fill_vocab() 105 | 106 | print("Preprocessing dataset") 107 | for review in tqdm(self.ds): 108 | review_sentences = review.split('. ') 109 | if len(review_sentences) > 1: 110 | for i in range(len(review_sentences) - 1): 111 | # True NSP item 112 | first, second = self.tokenizer(review_sentences[i]), self.tokenizer(review_sentences[i + 1]) 113 | nsp.append(self._create_item(first, second, 1)) 114 | 115 | # False NSP item 116 | first, second = self._select_false_nsp_sentences(sentences) 117 | first, second = self.tokenizer(first), self.tokenizer(second) 118 | nsp.append(self._create_item(first, second, 0)) 119 | df = pd.DataFrame(nsp, columns=self.columns) 120 | return df 121 | 122 | def _update_length(self, sentences: typing.List[str], lengths: typing.List[int]): 123 | for v in sentences: 124 | l = len(v.split()) 125 | lengths.append(l) 126 | return lengths 127 | 128 | def _find_optimal_sentence_length(self, lengths: typing.List[int]): 129 | arr = np.array(lengths) 130 | return int(np.percentile(arr, self.OPTIMAL_LENGTH_PERCENTILE)) 131 | 132 | def _fill_vocab(self): 133 | # specials= argument is only in 0.12.0 version 134 | # specials=[self.CLS, self.PAD, self.MASK, self.SEP, self.UNK] 135 | self.vocab = vocab(self.counter, min_freq=2) 136 | 137 | # 0.11.0 uses this approach to insert specials 138 | self.vocab.insert_token(self.CLS, 0) 139 | self.vocab.insert_token(self.PAD, 1) 140 | self.vocab.insert_token(self.MASK, 2) 141 | self.vocab.insert_token(self.SEP, 3) 142 | self.vocab.insert_token(self.UNK, 4) 143 | self.vocab.set_default_index(4) 144 | 145 | def _create_item(self, first: typing.List[str], second: typing.List[str], target: int = 1): 146 | # Create masked sentence item 147 | updated_first, first_mask = self._preprocess_sentence(first.copy()) 148 | updated_second, second_mask = self._preprocess_sentence(second.copy()) 149 | 150 | nsp_sentence = updated_first + [self.SEP] + updated_second 151 | nsp_indices = self.vocab.lookup_indices(nsp_sentence) 152 | inverse_token_mask = first_mask + [True] + second_mask 153 | 154 | # Create sentence item without masking random words 155 | first, _ = self._preprocess_sentence(first.copy(), should_mask=False) 156 | second, _ = self._preprocess_sentence(second.copy(), should_mask=False) 157 | original_nsp_sentence = first + [self.SEP] + second 158 | original_nsp_indices = self.vocab.lookup_indices(original_nsp_sentence) 159 | 160 | if self.should_include_text: 161 | return ( 162 | nsp_sentence, 163 | nsp_indices, 164 | original_nsp_sentence, 165 | original_nsp_indices, 166 | inverse_token_mask, 167 | 
target 168 | ) 169 | else: 170 | return ( 171 | nsp_indices, 172 | original_nsp_indices, 173 | inverse_token_mask, 174 | target 175 | ) 176 | 177 | def _select_false_nsp_sentences(self, sentences: typing.List[str]): 178 | """Select sentences to create false NSP item 179 | 180 | Args: 181 | sentences: list of all sentences 182 | 183 | Returns: 184 | tuple of two sentences. The second one NOT the next sentence 185 | """ 186 | sentences_len = len(sentences) 187 | sentence_index = random.randint(0, sentences_len - 1) 188 | next_sentence_index = random.randint(0, sentences_len - 1) 189 | 190 | # To be sure that it's not real next sentence 191 | while next_sentence_index == sentence_index + 1: 192 | next_sentence_index = random.randint(0, sentences_len - 1) 193 | 194 | return sentences[sentence_index], sentences[next_sentence_index] 195 | 196 | def _preprocess_sentence(self, sentence: typing.List[str], should_mask: bool = True): 197 | inverse_token_mask = None 198 | if should_mask: 199 | sentence, inverse_token_mask = self._mask_sentence(sentence) 200 | sentence, inverse_token_mask = self._pad_sentence([self.CLS] + sentence, [True] + inverse_token_mask) 201 | 202 | return sentence, inverse_token_mask 203 | 204 | def _mask_sentence(self, sentence: typing.List[str]): 205 | """Replace MASK_PERCENTAGE (15%) of words with special [MASK] symbol 206 | or with random word from vocabulary 207 | 208 | Args: 209 | sentence: sentence to process 210 | 211 | Returns: 212 | tuple of processed sentence and inverse token mask 213 | """ 214 | len_s = len(sentence) 215 | inverse_token_mask = [True for _ in range(max(len_s, self.optimal_sentence_length))] 216 | 217 | mask_amount = round(len_s * self.MASK_PERCENTAGE) 218 | for _ in range(mask_amount): 219 | i = random.randint(0, len_s - 1) 220 | 221 | if random.random() < 0.8: 222 | sentence[i] = self.MASK 223 | else: 224 | # All is below 5 is special token 225 | # see self._insert_specials method 226 | j = random.randint(5, len(self.vocab) - 1) 227 | sentence[i] = self.vocab.lookup_token(j) 228 | inverse_token_mask[i] = False 229 | return sentence, inverse_token_mask 230 | 231 | def _pad_sentence(self, sentence: typing.List[str], inverse_token_mask: typing.List[bool] = None): 232 | len_s = len(sentence) 233 | 234 | if len_s >= self.optimal_sentence_length: 235 | s = sentence[:self.optimal_sentence_length] 236 | else: 237 | s = sentence + [self.PAD] * (self.optimal_sentence_length - len_s) 238 | 239 | # inverse token mask should be padded as well 240 | if inverse_token_mask: 241 | len_m = len(inverse_token_mask) 242 | if len_m >= self.optimal_sentence_length: 243 | inverse_token_mask = inverse_token_mask[:self.optimal_sentence_length] 244 | else: 245 | inverse_token_mask = inverse_token_mask + [True] * (self.optimal_sentence_length - len_m) 246 | return s, inverse_token_mask 247 | 248 | 249 | if __name__ == '__main__': 250 | BASE_DIR = Path(__file__).resolve().parent.parent 251 | 252 | ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=50000, 253 | should_include_text=True) 254 | print(ds.df) 255 | -------------------------------------------------------------------------------- /bert/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | import torch.nn.functional as f 5 | 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class JointEmbedding(nn.Module): 11 | 12 | def __init__(self, vocab_size, size): 13 | 
super(JointEmbedding, self).__init__() 14 | 15 | self.size = size 16 | 17 | self.token_emb = nn.Embedding(vocab_size, size) 18 | self.segment_emb = nn.Embedding(vocab_size, size) 19 | 20 | self.norm = nn.LayerNorm(size) 21 | 22 | def forward(self, input_tensor): 23 | sentence_size = input_tensor.size(-1) 24 | pos_tensor = self.attention_position(self.size, input_tensor) 25 | 26 | segment_tensor = torch.zeros_like(input_tensor).to(device) 27 | segment_tensor[:, sentence_size // 2 + 1:] = 1 28 | 29 | output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor 30 | return self.norm(output) 31 | 32 | def attention_position(self, dim, input_tensor): 33 | batch_size = input_tensor.size(0) 34 | sentence_size = input_tensor.size(-1) 35 | 36 | pos = torch.arange(sentence_size, dtype=torch.long).to(device) 37 | d = torch.arange(dim, dtype=torch.long).to(device) 38 | d = (2 * d / dim) 39 | 40 | pos = pos.unsqueeze(1) 41 | pos = pos / (1e4 ** d) 42 | 43 | pos[:, ::2] = torch.sin(pos[:, ::2]) 44 | pos[:, 1::2] = torch.cos(pos[:, 1::2]) 45 | 46 | return pos.expand(batch_size, *pos.size()) 47 | 48 | def numeric_position(self, dim, input_tensor): 49 | pos_tensor = torch.arange(dim, dtype=torch.long).to(device) 50 | return pos_tensor.expand_as(input_tensor) 51 | 52 | 53 | class AttentionHead(nn.Module): 54 | 55 | def __init__(self, dim_inp, dim_out): 56 | super(AttentionHead, self).__init__() 57 | 58 | self.dim_inp = dim_inp 59 | 60 | self.q = nn.Linear(dim_inp, dim_out) 61 | self.k = nn.Linear(dim_inp, dim_out) 62 | self.v = nn.Linear(dim_inp, dim_out) 63 | 64 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor = None): 65 | query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor) 66 | 67 | scale = query.size(1) ** 0.5 68 | scores = torch.bmm(query, key.transpose(1, 2)) / scale 69 | 70 | scores = scores.masked_fill_(attention_mask, -1e9) 71 | attn = f.softmax(scores, dim=-1) 72 | context = torch.bmm(attn, value) 73 | 74 | return context 75 | 76 | 77 | class MultiHeadAttention(nn.Module): 78 | 79 | def __init__(self, num_heads, dim_inp, dim_out): 80 | super(MultiHeadAttention, self).__init__() 81 | 82 | self.heads = nn.ModuleList([ 83 | AttentionHead(dim_inp, dim_out) for _ in range(num_heads) 84 | ]) 85 | self.linear = nn.Linear(dim_out * num_heads, dim_inp) 86 | self.norm = nn.LayerNorm(dim_inp) 87 | 88 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 89 | s = [head(input_tensor, attention_mask) for head in self.heads] 90 | scores = torch.cat(s, dim=-1) 91 | scores = self.linear(scores) 92 | return self.norm(scores) 93 | 94 | 95 | class Encoder(nn.Module): 96 | 97 | def __init__(self, dim_inp, dim_out, attention_heads=4, dropout=0.1): 98 | super(Encoder, self).__init__() 99 | 100 | self.attention = MultiHeadAttention(attention_heads, dim_inp, dim_out) # batch_size x sentence size x dim_inp 101 | self.feed_forward = nn.Sequential( 102 | nn.Linear(dim_inp, dim_out), 103 | nn.Dropout(dropout), 104 | nn.GELU(), 105 | nn.Linear(dim_out, dim_inp), 106 | nn.Dropout(dropout) 107 | ) 108 | self.norm = nn.LayerNorm(dim_inp) 109 | 110 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 111 | context = self.attention(input_tensor, attention_mask) 112 | res = self.feed_forward(context) 113 | return self.norm(res) 114 | 115 | 116 | class BERT(nn.Module): 117 | 118 | def __init__(self, vocab_size, dim_inp, dim_out, attention_heads=4): 119 | super(BERT, self).__init__() 120 | 121 | 
self.embedding = JointEmbedding(vocab_size, dim_inp) 122 | self.encoder = Encoder(dim_inp, dim_out, attention_heads) 123 | 124 | self.token_prediction_layer = nn.Linear(dim_inp, vocab_size) 125 | self.softmax = nn.LogSoftmax(dim=-1) 126 | self.classification_layer = nn.Linear(dim_inp, 2) 127 | 128 | def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor): 129 | embedded = self.embedding(input_tensor) 130 | encoded = self.encoder(embedded, attention_mask) 131 | 132 | token_predictions = self.token_prediction_layer(encoded) 133 | 134 | first_word = encoded[:, 0, :] 135 | return self.softmax(token_predictions), self.classification_layer(first_word) 136 | -------------------------------------------------------------------------------- /bert/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from bert.dataset import IMDBBertDataset 12 | from bert.model import BERT 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | def percentage(batch_size: int, max_index: int, current_index: int): 18 | """Calculate epoch progress percentage 19 | 20 | Args: 21 | batch_size: batch size 22 | max_index: max index in epoch 23 | current_index: current index 24 | 25 | Returns: 26 | Passed percentage of dataset 27 | """ 28 | batched_max = max_index // batch_size 29 | return round(current_index / batched_max * 100, 2) 30 | 31 | 32 | def nsp_accuracy(result: torch.Tensor, target: torch.Tensor): 33 | """Calculate NSP accuracy between two tensors 34 | 35 | Args: 36 | result: result calculated by model 37 | target: real target 38 | 39 | Returns: 40 | NSP accuracy 41 | """ 42 | s = (result.argmax(1) == target.argmax(1)).sum() 43 | return round(float(s / result.size(0)), 2) 44 | 45 | 46 | def token_accuracy(result: torch.Tensor, target: torch.Tensor, inverse_token_mask: torch.Tensor): 47 | """Calculate MLM accuracy between ONLY masked words 48 | 49 | Args: 50 | result: result calculated by model 51 | target: real target 52 | inverse_token_mask: well-known inverse token mask 53 | 54 | Returns: 55 | MLM accuracy 56 | """ 57 | r = result.argmax(-1).masked_select(~inverse_token_mask) 58 | t = target.masked_select(~inverse_token_mask) 59 | s = (r == t).sum() 60 | return round(float(s / (result.size(0) * result.size(1))), 2) 61 | 62 | 63 | class BertTrainer: 64 | 65 | def __init__(self, 66 | model: BERT, 67 | dataset: IMDBBertDataset, 68 | log_dir: Path, 69 | checkpoint_dir: Path = None, 70 | print_progress_every: int = 10, 71 | print_accuracy_every: int = 50, 72 | batch_size: int = 24, 73 | learning_rate: float = 0.005, 74 | epochs: int = 5, 75 | ): 76 | self.model = model 77 | self.dataset = dataset 78 | 79 | self.batch_size = batch_size 80 | self.epochs = epochs 81 | self.current_epoch = 0 82 | 83 | self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) 84 | 85 | self.writer = SummaryWriter(str(log_dir)) 86 | self.checkpoint_dir = checkpoint_dir 87 | 88 | self.criterion = nn.BCEWithLogitsLoss().to(device) 89 | self.ml_criterion = nn.NLLLoss(ignore_index=0).to(device) 90 | self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015) 91 | 92 | self._splitter_size = 35 93 | 94 | self._ds_len = len(self.dataset) 95 | self._batched_len = 
self._ds_len // self.batch_size 96 | 97 | self._print_every = print_progress_every 98 | self._accuracy_every = print_accuracy_every 99 | 100 | def print_summary(self): 101 | ds_len = len(self.dataset) 102 | 103 | print("Model Summary\n") 104 | print('=' * self._splitter_size) 105 | print(f"Device: {device}") 106 | print(f"Training dataset len: {ds_len}") 107 | print(f"Max / Optimal sentence len: {self.dataset.optimal_sentence_length}") 108 | print(f"Vocab size: {len(self.dataset.vocab)}") 109 | print(f"Batch size: {self.batch_size}") 110 | print(f"Batched dataset len: {self._batched_len}") 111 | print('=' * self._splitter_size) 112 | print() 113 | 114 | def __call__(self): 115 | for self.current_epoch in range(self.current_epoch, self.epochs): 116 | loss = self.train(self.current_epoch) 117 | self.save_checkpoint(self.current_epoch, step=-1, loss=loss) 118 | 119 | def train(self, epoch: int): 120 | print(f"Begin epoch {epoch}") 121 | 122 | prev = time.time() 123 | average_nsp_loss = 0 124 | average_mlm_loss = 0 125 | for i, value in enumerate(self.loader): 126 | index = i + 1 127 | inp, mask, inverse_token_mask, token_target, nsp_target = value 128 | self.optimizer.zero_grad() 129 | 130 | token, nsp = self.model(inp, mask) 131 | 132 | tm = inverse_token_mask.unsqueeze(-1).expand_as(token) 133 | token = token.masked_fill(tm, 0) 134 | 135 | loss_token = self.ml_criterion(token.transpose(1, 2), token_target) # 1D tensor as target is required 136 | loss_nsp = self.criterion(nsp, nsp_target) 137 | 138 | loss = loss_token + loss_nsp 139 | average_nsp_loss += loss_nsp 140 | average_mlm_loss += loss_token 141 | 142 | loss.backward() 143 | self.optimizer.step() 144 | 145 | if index % self._print_every == 0: 146 | elapsed = time.gmtime(time.time() - prev) 147 | s = self.training_summary(elapsed, index, average_nsp_loss, average_mlm_loss) 148 | 149 | if index % self._accuracy_every == 0: 150 | s += self.accuracy_summary(index, token, nsp, token_target, nsp_target, inverse_token_mask) 151 | 152 | print(s) 153 | 154 | average_nsp_loss = 0 155 | average_mlm_loss = 0 156 | return loss 157 | 158 | def training_summary(self, elapsed, index, average_nsp_loss, average_mlm_loss): 159 | passed = percentage(self.batch_size, self._ds_len, index) 160 | global_step = self.current_epoch * len(self.loader) + index 161 | 162 | print_nsp_loss = average_nsp_loss / self._print_every 163 | print_mlm_loss = average_mlm_loss / self._print_every 164 | 165 | s = f"{time.strftime('%H:%M:%S', elapsed)}" 166 | s += f" | Epoch {self.current_epoch + 1} | {index} / {self._batched_len} ({passed}%) | " \ 167 | f"NSP loss {print_nsp_loss:6.2f} | MLM loss {print_mlm_loss:6.2f}" 168 | 169 | self.writer.add_scalar("NSP loss", print_nsp_loss, global_step=global_step) 170 | self.writer.add_scalar("MLM loss", print_mlm_loss, global_step=global_step) 171 | return s 172 | 173 | def accuracy_summary(self, index, token, nsp, token_target, nsp_target, inverse_token_mask): 174 | global_step = self.current_epoch * len(self.loader) + index 175 | nsp_acc = nsp_accuracy(nsp, nsp_target) 176 | token_acc = token_accuracy(token, token_target, inverse_token_mask) 177 | 178 | self.writer.add_scalar("NSP train accuracy", nsp_acc, global_step=global_step) 179 | self.writer.add_scalar("Token train accuracy", token_acc, global_step=global_step) 180 | 181 | return f" | NSP accuracy {nsp_acc} | Token accuracy {token_acc}" 182 | 183 | def save_checkpoint(self, epoch, step, loss): 184 | if not self.checkpoint_dir: 185 | return 186 | 187 | prev = time.time() 188 
| name = f"bert_epoch{epoch}_step{step}_{datetime.utcnow().timestamp():.0f}.pt" 189 | 190 | torch.save({ 191 | 'epoch': epoch, 192 | 'model_state_dict': self.model.state_dict(), 193 | 'optimizer_state_dict': self.optimizer.state_dict(), 194 | 'loss': loss, 195 | }, self.checkpoint_dir.joinpath(name)) 196 | 197 | print() 198 | print('=' * self._splitter_size) 199 | print(f"Model saved as '{name}' for {time.time() - prev:.2f}s") 200 | print('=' * self._splitter_size) 201 | print() 202 | 203 | def load_checkpoint(self, path: Path): 204 | print('=' * self._splitter_size) 205 | print(f"Restoring model {path}") 206 | checkpoint = torch.load(path) 207 | self.current_epoch = checkpoint['epoch'] 208 | self.model.load_state_dict(checkpoint['model_state_dict']) 209 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 210 | print("Model is restored.") 211 | print('=' * self._splitter_size) 212 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pathlib import Path 4 | from torch.utils.data import DataLoader 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from bert.dataset import IMDBBertDataset 8 | from bert.model import BERT 9 | 10 | BASE_DIR = Path(__file__).resolve().parent 11 | 12 | EMB_SIZE = 64 13 | HIDDEN_SIZE = 36 14 | EPOCHS = 4 15 | BATCH_SIZE = 12 16 | NUM_HEADS = 4 17 | LOG_DIR = BASE_DIR.joinpath(f'data/logs/graph') 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | if torch.cuda.is_available(): 22 | torch.cuda.empty_cache() 23 | 24 | if __name__ == '__main__': 25 | print("Prepare dataset") 26 | ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=5) 27 | loader = DataLoader(ds, batch_size=1, shuffle=False) 28 | 29 | bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device) 30 | writer = SummaryWriter(str(LOG_DIR)) 31 | 32 | inp, mask, inverse_token_mask, token_target, nsp_target = next(iter(loader)) 33 | 34 | writer.add_graph(bert, input_to_model=[inp, mask]) 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import torch 4 | 5 | from pathlib import Path 6 | 7 | from bert.dataset import IMDBBertDataset 8 | from bert.model import BERT 9 | from bert.trainer import BertTrainer 10 | 11 | BASE_DIR = Path(__file__).resolve().parent 12 | 13 | EMB_SIZE = 64 14 | HIDDEN_SIZE = 36 15 | EPOCHS = 4 16 | BATCH_SIZE = 12 17 | NUM_HEADS = 4 18 | 19 | CHECKPOINT_DIR = BASE_DIR.joinpath('data/bert_checkpoints') 20 | 21 | timestamp = datetime.datetime.utcnow().timestamp() 22 | LOG_DIR = BASE_DIR.joinpath(f'data/logs/bert_experiment_{timestamp}') 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | if torch.cuda.is_available(): 27 | torch.cuda.empty_cache() 28 | 29 | if __name__ == '__main__': 30 | print("Prepare dataset") 31 | ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000) 32 | 33 | bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device) 34 | trainer = BertTrainer( 35 | model=bert, 36 | dataset=ds, 37 | log_dir=LOG_DIR, 38 | checkpoint_dir=CHECKPOINT_DIR, 39 | print_progress_every=20, 40 | print_accuracy_every=200, 41 | batch_size=BATCH_SIZE, 42 | learning_rate=0.00007, 43 | epochs=15 44 | ) 45 | 46 | 
trainer.print_summary() 47 | trainer() 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | torch 4 | torchtext 5 | tensorboard --------------------------------------------------------------------------------
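
Below is a minimal sketch, not part of the repository, of how a checkpoint written by `BertTrainer.save_checkpoint` might be restored to continue training. It reuses only the classes defined above; the checkpoint filename, log directory, and dataset slice are hypothetical and must match your own run (the `ds_from`/`ds_to` slice determines the vocabulary, so it should be the same as when the checkpoint was produced).

```python
# restore_checkpoint.py -- hedged example, not included in the repository
from pathlib import Path

import torch

from bert.dataset import IMDBBertDataset
from bert.model import BERT
from bert.trainer import BertTrainer

BASE_DIR = Path(__file__).resolve().parent

# Mirror the hyperparameters from main.py so the state-dict shapes match
EMB_SIZE = 64
HIDDEN_SIZE = 36
NUM_HEADS = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == '__main__':
    # Rebuild the dataset with the same slice as the original run,
    # otherwise len(ds.vocab) and the embedding table size will differ
    ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)

    bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device)
    trainer = BertTrainer(
        model=bert,
        dataset=ds,
        log_dir=BASE_DIR.joinpath('data/logs/restore'),
        checkpoint_dir=BASE_DIR.joinpath('data/bert_checkpoints'),
        batch_size=12,
        learning_rate=0.00007,
        epochs=15,
    )

    # Hypothetical filename: pick an actual file from data/bert_checkpoints
    trainer.load_checkpoint(BASE_DIR.joinpath('data/bert_checkpoints/bert_epoch4_step-1_1650000000.pt'))
    trainer()  # resumes from the epoch stored in the checkpoint
```

`load_checkpoint` restores the model and optimizer state and sets `current_epoch`, so calling the trainer afterwards picks up where the saved run stopped instead of starting from epoch 0.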