├── img
│   └── twinbert.png
├── README.md
└── TwinBert.ipynb

/img/twinbert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishsalunkhe/TwinBert/master/img/twinbert.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TwinBert-Pytorch
PyTorch implementation of the TwinBERT paper (https://arxiv.org/pdf/2002.06275v1.pdf).

This notebook was created to train a Siamese BERT architecture to find similar pairs of text documents. The authors of the paper used this architecture as the backend model of a sponsored search system, whose goal is to display the list of ads that best matches a user's intent.

Because the paper's data is not publicly available, I have trained this model on the Quora Question Pairs dataset (https://www.kaggle.com/c/quora-question-pairs).

--------------------------------------------------------------------------------
/TwinBert.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### TwinBert https://arxiv.org/pdf/2002.06275v1.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn import metrics\n",
    "import transformers\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from transformers import BertTokenizer, BertModel, BertConfig\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Pick the device up front so that model.to(device) below works on the first run.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The dataset is the Quora Question Pairs dataset. In the paper, the authors trained the model as the backend of a sponsored search engine that delivers ads alongside the organic search results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"train.csv\")  # Dataset: https://www.kaggle.com/c/quora-question-pairs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Loader for the Siamese Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SiameseNetworkDataset(Dataset):\n",
    "    def __init__(self, dataframe, tokenizer, max_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.data = dataframe\n",
    "        self.question1 = dataframe.question1\n",
    "        self.question2 = dataframe.question2\n",
    "        self.targets = dataframe.is_duplicate\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def tokenize(self, input_text):\n",
    "        # Collapse runs of whitespace before tokenizing.\n",
    "        input_text = \" \".join(input_text.split())\n",
    "        inputs = self.tokenizer.encode_plus(\n",
    "            input_text,\n",
    "            None,\n",
    "            add_special_tokens=True,\n",
    "            max_length=self.max_len,\n",
    "            padding='max_length',   # pad_to_max_length is deprecated in recent transformers\n",
    "            truncation=True,\n",
    "            return_token_type_ids=True\n",
    "        )\n",
    "        ids = inputs['input_ids']\n",
    "        mask = inputs['attention_mask']\n",
    "        token_type_ids = inputs['token_type_ids']\n",
    "        return ids, mask, token_type_ids\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Each item is a pair of independently encoded questions plus the label.\n",
    "        ids1, mask1, token_type_ids1 = self.tokenize(str(self.question1[index]))\n",
    "        ids2, mask2, token_type_ids2 = self.tokenize(str(self.question2[index]))\n",
    "\n",
    "        return {\n",
    "            'ids': [torch.tensor(ids1, dtype=torch.long), torch.tensor(ids2, dtype=torch.long)],\n",
    "            'mask': [torch.tensor(mask1, dtype=torch.long), torch.tensor(mask2, dtype=torch.long)],\n",
    "            'token_type_ids': [torch.tensor(token_type_ids1, dtype=torch.long), torch.tensor(token_type_ids2, dtype=torch.long)],\n",
    "            'targets': torch.tensor(self.targets[index], dtype=torch.float)\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TwinBert architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TwinBert(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TwinBert, self).__init__()\n",
    "        # The two twins share weights: a single encoder is applied to each input.\n",
    "        self.model = transformers.BertModel.from_pretrained('bert-base-uncased')\n",
    "\n",
    "    def forward_once(self, ids, mask, token_type_ids):\n",
    "        # return_dict=True (transformers >= 3.x) gives named outputs;\n",
    "        # pooler_output is the pooled [CLS] embedding used as the sentence vector.\n",
    "        output = self.model(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=True)\n",
    "        return output.pooler_output\n",
    "\n",
    "    def forward(self, ids, mask, token_type_ids):\n",
    "        output1 = self.forward_once(ids[0], mask[0], token_type_ids[0])\n",
    "        output2 = self.forward_once(ids[1], mask[1], token_type_ids[1])\n",
    "        return output1, output2\n",
    "\n",
    "\n",
    "model = TwinBert()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "TRAIN_BATCH_SIZE = 8\n",
    "VALID_BATCH_SIZE = 4\n",
    "EPOCHS = 1\n",
    "LEARNING_RATE = 1e-05\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
   ]
  },
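  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the loader (a minimal sketch, assuming `train.csv` has the Kaggle columns `question1`, `question2` and `is_duplicate`): build a dataset over the first two rows and inspect the shape of one item."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: encode the first question pair and check the tensor shapes.\n",
    "demo_set = SiameseNetworkDataset(df.head(2).reset_index(drop=True), tokenizer, MAX_LEN)\n",
    "item = demo_set[0]\n",
    "print(item['ids'][0].shape, item['mask'][0].shape)  # both should be torch.Size([MAX_LEN])\n",
    "print(item['targets'])"
   ]
  },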
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_size = 0.8\n",
    "# Sample the training rows first, then drop them from the full frame,\n",
    "# and only reset the indices afterwards so the two splits stay disjoint.\n",
    "train_dataset = df.sample(frac=train_size, random_state=200)\n",
    "test_dataset = df.drop(train_dataset.index).reset_index(drop=True)\n",
    "train_dataset = train_dataset.reset_index(drop=True)\n",
    "\n",
    "print(\"FULL Dataset: {}\".format(df.shape))\n",
    "print(\"TRAIN Dataset: {}\".format(train_dataset.shape))\n",
    "print(\"TEST Dataset: {}\".format(test_dataset.shape))\n",
    "\n",
    "training_set = SiameseNetworkDataset(train_dataset, tokenizer, MAX_LEN)\n",
    "testing_set = SiameseNetworkDataset(test_dataset, tokenizer, MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n",
    "                'shuffle': True,\n",
    "                'num_workers': 0\n",
    "                }\n",
    "\n",
    "test_params = {'batch_size': VALID_BATCH_SIZE,\n",
    "               'shuffle': False,  # keep evaluation order deterministic\n",
    "               'num_workers': 0\n",
    "               }\n",
    "\n",
    "training_loader = DataLoader(training_set, **train_params)\n",
    "testing_loader = DataLoader(testing_set, **test_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss Function\n",
    "### A contrastive loss that uses cosine similarity between the twin embeddings as the distance measure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CosineContrastiveLoss(nn.Module):\n",
    "    def __init__(self, margin=0.4):\n",
    "        super(CosineContrastiveLoss, self).__init__()\n",
    "        self.margin = margin\n",
    "\n",
    "    def forward(self, output1, output2, label):\n",
    "        # label == 1 marks a duplicate pair, matching Quora's is_duplicate convention.\n",
    "        cos_sim = F.cosine_similarity(output1, output2)\n",
    "        # Duplicate pairs are pulled towards cos_sim = 1; non-duplicate pairs are\n",
    "        # only penalised while their similarity is still above the margin.\n",
    "        loss_cos_con = torch.mean(label * torch.div(torch.pow(1.0 - cos_sim, 2), 4) +\n",
    "                                  (1 - label) * torch.pow(torch.clamp(cos_sim - self.margin, min=0.0), 2))\n",
    "        return loss_cos_con"
   ]
  },
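  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy check of the loss on hand-made embeddings (illustrative only, not from the paper): a nearly aligned pair with label 1 should give a near-zero loss, while the same pair with label 0 should be penalised because its similarity exceeds the margin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two nearly identical embeddings -> cosine similarity close to 1.\n",
    "a = torch.tensor([[1.0, 0.0, 0.0]])\n",
    "b = torch.tensor([[0.99, 0.01, 0.0]])\n",
    "loss_fn = CosineContrastiveLoss()\n",
    "print(loss_fn(a, b, torch.tensor([1.0])))  # similar pair, label 1 -> ~0\n",
    "print(loss_fn(a, b, torch.tensor([0.0])))  # similar pair, label 0 -> penalised"
   ]
  },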
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = CosineContrastiveLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    model.train()\n",
    "    for step, data in enumerate(training_loader, 0):\n",
    "        ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids']\n",
    "        targets = data['targets'].to(device, dtype=torch.float)\n",
    "        ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)]\n",
    "        mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)]\n",
    "        token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)]\n",
    "        output1, output2 = model(ids, mask, token_type_ids)\n",
    "        loss = criterion(output1, output2, targets)\n",
    "        if step % 5000 == 0:\n",
    "            print(f'Step: {step}, Epoch: {epoch}, Loss: {loss.item()}')\n",
    "\n",
    "        # Zero the gradients exactly once per step, just before backprop.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    train(epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validation():\n",
    "    model.eval()\n",
    "    fin_targets = []\n",
    "    fin_outputs = []\n",
    "    with torch.no_grad():\n",
    "        for _, data in enumerate(testing_loader, 0):\n",
    "            ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids']\n",
    "            targets = data['targets'].to(device, dtype=torch.float)\n",
    "            ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)]\n",
    "            mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)]\n",
    "            token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)]\n",
    "            output1, output2 = model(ids, mask, token_type_ids)\n",
    "            cos_sim = F.cosine_similarity(output1, output2)\n",
    "            fin_targets.extend(targets.cpu().detach().numpy().tolist())\n",
    "            fin_outputs.extend(torch.sigmoid(cos_sim).cpu().detach().numpy().tolist())\n",
    "    return fin_outputs, fin_targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs, targets = validation()\n",
    "# sigmoid(cos_sim) >= 0.5 is equivalent to cos_sim >= 0.\n",
    "outputs = np.array(outputs) >= 0.5\n",
    "accuracy = metrics.accuracy_score(targets, outputs)\n",
    "f1_score_micro = metrics.f1_score(targets, outputs, average='micro')\n",
    "f1_score_macro = metrics.f1_score(targets, outputs, average='macro')\n",
    "print(f\"Accuracy Score = {accuracy}\")\n",
    "print(f\"F1 Score (Micro) = {f1_score_micro}\")\n",
    "print(f\"F1 Score (Macro) = {f1_score_macro}\")"
   ]
  },
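  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a small helper that scores an arbitrary question pair with the trained twins. This is a hypothetical convenience wrapper (the function name `score_pair` is not from the original code): it encodes both questions through the dataset class, runs them through the shared encoder, and returns their cosine similarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pair(question1, question2):\n",
    "    # Hypothetical helper built from the components defined above.\n",
    "    ds = SiameseNetworkDataset(\n",
    "        pd.DataFrame({'question1': [question1], 'question2': [question2], 'is_duplicate': [0]}),\n",
    "        tokenizer, MAX_LEN)\n",
    "    item = ds[0]\n",
    "    # Add a batch dimension and move the pair to the model's device.\n",
    "    ids = [t.unsqueeze(0).to(device) for t in item['ids']]\n",
    "    mask = [t.unsqueeze(0).to(device) for t in item['mask']]\n",
    "    token_type_ids = [t.unsqueeze(0).to(device) for t in item['token_type_ids']]\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        output1, output2 = model(ids, mask, token_type_ids)\n",
    "    return F.cosine_similarity(output1, output2).item()\n",
    "\n",
    "score_pair('How do I learn PyTorch?', 'What is the best way to learn PyTorch?')"
   ]
  }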
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
--------------------------------------------------------------------------------