├── .gitignore ├── LongBART.ipynb ├── longbart ├── __init__.py ├── configuration_bart.py ├── convert_bart_to_longbart.py ├── modeling_bart.py └── modeling_longbart.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /LongBART.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "LongBART", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm", 10 | "authorship_tag": "ABX9TyMBu/tl3uAemtoSjaCYca2U", 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "accelerator": "TPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "PkmDekoURprl", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "# LongBART" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "FNfNMSjTK61d", 44 | "colab_type": "code", 45 | "outputId": "f2278a45-26ee-48dd-ab62-4a34f8c56662", 46 | "colab": { 47 | "base_uri": "https://localhost:8080/", 48 | "height": 34 49 | } 50 | }, 51 | "source": [ 52 | "!git clone https://github.com/patil-suraj/longbart.git\n", 53 | "%cd longbart" 54 | ], 55 | "execution_count": 1, 56 | "outputs": [ 57 | { 58 | "output_type": "stream", 59 | "text": [ 60 | "/content/longbart\n" 61 | ], 62 | "name": "stdout" 63 | } 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "yOSSddShCaz1", 70 | "colab_type": "code", 71 | "outputId": "8fc06a89-e421-4f67-df2d-004400b60fbc", 72 | "colab": { 73 | "base_uri": "https://localhost:8080/", 74 | "height": 
658 75 | } 76 | }, 77 | "source": [ 78 | "!pip install git+https://github.com/huggingface/transformers.git" 79 | ], 80 | "execution_count": 2, 81 | "outputs": [ 82 | { 83 | "output_type": "stream", 84 | "text": [ 85 | "Collecting git+https://github.com/huggingface/transformers.git\n", 86 | " Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-69qg24g6\n", 87 | " Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-69qg24g6\n", 88 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (1.18.4)\n", 89 | "Collecting tokenizers==0.7.0\n", 90 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)\n", 91 | "\u001b[K |████████████████████████████████| 3.8MB 3.4MB/s \n", 92 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (20.4)\n", 93 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (3.0.12)\n", 94 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2.23.0)\n", 95 | "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (4.41.1)\n", 96 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2019.12.20)\n", 97 | "Collecting sentencepiece\n", 98 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", 99 | "\u001b[K |████████████████████████████████| 1.1MB 59.3MB/s \n", 100 | "\u001b[?25hCollecting sacremoses\n", 101 | "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n", 102 | "\u001b[K |████████████████████████████████| 890kB 60.2MB/s \n", 103 | "\u001b[?25hRequirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (0.7)\n", 104 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (1.12.0)\n", 105 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (2.4.7)\n", 106 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (3.0.4)\n", 107 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (1.24.3)\n", 108 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2020.4.5.1)\n", 109 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2.9)\n", 110 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (7.1.2)\n", 111 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (0.15.1)\n", 112 | "Building wheels for collected packages: transformers, sacremoses\n", 113 | " Building wheel for transformers (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 114 | " Created wheel for transformers: filename=transformers-2.10.0-cp36-none-any.whl size=667026 sha256=1fbfcea1f14b529238dfb962701daa0b4df4df60b6927f0793ca24f52b161af8\n", 115 | " Stored in directory: /tmp/pip-ephem-wheel-cache-gv3xom6x/wheels/33/eb/3b/4bf5dd835e865e472d4fc0754f35ac0edb08fe852e8f21655f\n", 116 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 117 | " Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=87f114a7eda7007f2e152396969d392ab29f68a5812c6a57e3538111bdf7c32e\n", 118 | " Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n", 119 | "Successfully built transformers sacremoses\n", 120 | "Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers\n", 121 | "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.10.0\n" 122 | ], 123 | "name": "stdout" 124 | } 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "3-EQ0F6qCm_d", 131 | "colab_type": "code", 132 | "colab": {} 133 | }, 134 | "source": [ 135 | "import logging\n", 136 | "import os\n", 137 | "import math\n", 138 | "from dataclasses import dataclass, field\n", 139 | "from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TextDataset, DataCollatorForLanguageModeling, Trainer\n", 140 | "from transformers import BartTokenizer\n", 141 | "from transformers import TrainingArguments, HfArgumentParser\n", 142 | "from transformers.modeling_longformer import LongformerSelfAttention\n", 143 | "\n", 144 | "from modeling_bart import BartForConditionalGeneration\n", 145 | "\n", 146 | "logger = logging.getLogger(__name__)\n", 147 | "logging.basicConfig(level=logging.INFO)" 148 | ], 149 | "execution_count": 0, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "id": "oO4RNqIODK9z", 156 | "colab_type": 
"code", 157 | "outputId": "bc3a13ca-03ed-45e9-f593-490fe0b36d61", 158 | "colab": { 159 | "base_uri": "https://localhost:8080/", 160 | "height": 1000 161 | } 162 | }, 163 | "source": [ 164 | "# lets use a tiny version of bart for initial experiment \n", 165 | "tokenizer = BartTokenizer.from_pretrained('sshleifer/bart-tiny-random')\n", 166 | "bart = BartForConditionalGeneration.from_pretrained('sshleifer/bart-tiny-random')\n", 167 | "\n", 168 | "# load ROBERta model to see the difference between bart encoder layer and roberta encoder layer \n", 169 | "roberta = RobertaForMaskedLM.from_pretrained('roberta-base')" 170 | ], 171 | "execution_count": 3, 172 | "outputs": [ 173 | { 174 | "output_type": "stream", 175 | "text": [ 176 | "INFO:transformers.tokenization_utils:Model name 'sshleifer/bart-tiny-random' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'sshleifer/bart-tiny-random' is a path, a model identifier, or url to a directory containing tokenizer files.\n", 177 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/vocab.json from cache at /root/.cache/torch/transformers/70b9426bcc7c2cd96de53c16f7e13eabbc8373cecf5c38d68ced2fcc25e3382a.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n", 178 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/merges.txt from cache at /root/.cache/torch/transformers/dc37af6307b1a17037d2d066cb55af9cc1cf55d38d3b1f862221fc8d87b9a672.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n", 179 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/added_tokens.json from cache at None\n", 180 | "INFO:transformers.tokenization_utils:loading file 
https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/special_tokens_map.json from cache at None\n", 181 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/tokenizer_config.json from cache at None\n", 182 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n", 183 | "INFO:transformers.configuration_utils:Model config BartConfig {\n", 184 | " \"_num_labels\": 3,\n", 185 | " \"activation_dropout\": 0.0,\n", 186 | " \"activation_function\": \"gelu\",\n", 187 | " \"add_bias_logits\": false,\n", 188 | " \"add_final_layer_norm\": false,\n", 189 | " \"architectures\": [\n", 190 | " \"BartForConditionalGeneration\"\n", 191 | " ],\n", 192 | " \"attention_dropout\": 0.0,\n", 193 | " \"bos_token_id\": 0,\n", 194 | " \"classif_dropout\": 0.0,\n", 195 | " \"d_model\": 24,\n", 196 | " \"decoder_attention_heads\": 2,\n", 197 | " \"decoder_ffn_dim\": 16,\n", 198 | " \"decoder_layerdrop\": 0.0,\n", 199 | " \"decoder_layers\": 2,\n", 200 | " \"decoder_max_position_embeddings\": 1024,\n", 201 | " \"decoder_start_token_id\": 2,\n", 202 | " \"dropout\": 0.1,\n", 203 | " \"encoder_attention_heads\": 2,\n", 204 | " \"encoder_ffn_dim\": 16,\n", 205 | " \"encoder_layerdrop\": 0.0,\n", 206 | " \"encoder_layers\": 2,\n", 207 | " \"encoder_max_position_embeddings\": 1024,\n", 208 | " \"eos_token_id\": 2,\n", 209 | " \"id2label\": {\n", 210 | " \"0\": \"LABEL_0\",\n", 211 | " \"1\": \"LABEL_1\",\n", 212 | " \"2\": \"LABEL_2\"\n", 213 | " },\n", 214 | " \"init_std\": 0.02,\n", 215 | " \"is_encoder_decoder\": true,\n", 216 | " \"label2id\": {\n", 217 | " \"LABEL_0\": 0,\n", 218 | " \"LABEL_1\": 1,\n", 219 | " 
\"LABEL_2\": 2\n", 220 | " },\n", 221 | " \"max_position_embeddings\": 1024,\n", 222 | " \"model_type\": \"bart\",\n", 223 | " \"normalize_before\": false,\n", 224 | " \"normalize_embedding\": true,\n", 225 | " \"num_hidden_layers\": 2,\n", 226 | " \"output_past\": true,\n", 227 | " \"pad_token_id\": 1,\n", 228 | " \"prefix\": \" \",\n", 229 | " \"scale_embedding\": false,\n", 230 | " \"static_position_embeddings\": false,\n", 231 | " \"task_specific_params\": {\n", 232 | " \"summarization\": {\n", 233 | " \"early_stopping\": true,\n", 234 | " \"length_penalty\": 2.0,\n", 235 | " \"max_length\": 142,\n", 236 | " \"min_length\": 56,\n", 237 | " \"no_repeat_ngram_size\": 3,\n", 238 | " \"num_beams\": 4\n", 239 | " }\n", 240 | " },\n", 241 | " \"vocab_size\": 50265\n", 242 | "}\n", 243 | "\n", 244 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at /root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n", 245 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n", 246 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /root/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690\n", 247 | "INFO:transformers.configuration_utils:Model config RobertaConfig {\n", 248 | " \"architectures\": [\n", 249 | " \"RobertaForMaskedLM\"\n", 250 | " ],\n", 251 | " \"attention_probs_dropout_prob\": 0.1,\n", 252 | " \"bos_token_id\": 0,\n", 253 | " \"eos_token_id\": 2,\n", 254 | " \"hidden_act\": \"gelu\",\n", 255 | " \"hidden_dropout_prob\": 0.1,\n", 256 | " \"hidden_size\": 768,\n", 257 | " 
\"initializer_range\": 0.02,\n", 258 | " \"intermediate_size\": 3072,\n", 259 | " \"layer_norm_eps\": 1e-05,\n", 260 | " \"max_position_embeddings\": 514,\n", 261 | " \"model_type\": \"roberta\",\n", 262 | " \"num_attention_heads\": 12,\n", 263 | " \"num_hidden_layers\": 12,\n", 264 | " \"pad_token_id\": 1,\n", 265 | " \"type_vocab_size\": 1,\n", 266 | " \"vocab_size\": 50265\n", 267 | "}\n", 268 | "\n", 269 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/roberta-base-pytorch_model.bin from cache at /root/.cache/torch/transformers/80b4a484eddeb259bec2f06a6f2f05d90934111628e0e1c09a33bd4a121358e1.49b88ba7ec2c26a7558dda98ca3884c3b80fa31cf43a1b1f23aef3ff81ba344e\n", 270 | "INFO:transformers.modeling_utils:Weights of RobertaForMaskedLM not initialized from pretrained model: ['lm_head.decoder.bias']\n" 271 | ], 272 | "name": "stderr" 273 | } 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "kZcws3kaMFrp", 280 | "colab_type": "code", 281 | "outputId": "1ef71a3a-5f3c-4ed6-d033-a4b55a259325", 282 | "colab": { 283 | "base_uri": "https://localhost:8080/", 284 | "height": 370 285 | } 286 | }, 287 | "source": [ 288 | "roberta.config" 289 | ], 290 | "execution_count": 5, 291 | "outputs": [ 292 | { 293 | "output_type": "execute_result", 294 | "data": { 295 | "text/plain": [ 296 | "RobertaConfig {\n", 297 | " \"architectures\": [\n", 298 | " \"RobertaForMaskedLM\"\n", 299 | " ],\n", 300 | " \"attention_probs_dropout_prob\": 0.1,\n", 301 | " \"bos_token_id\": 0,\n", 302 | " \"eos_token_id\": 2,\n", 303 | " \"hidden_act\": \"gelu\",\n", 304 | " \"hidden_dropout_prob\": 0.1,\n", 305 | " \"hidden_size\": 768,\n", 306 | " \"initializer_range\": 0.02,\n", 307 | " \"intermediate_size\": 3072,\n", 308 | " \"layer_norm_eps\": 1e-05,\n", 309 | " \"max_position_embeddings\": 514,\n", 310 | " \"model_type\": \"roberta\",\n", 311 | " \"num_attention_heads\": 12,\n", 312 | " \"num_hidden_layers\": 12,\n", 313 | " 
\"pad_token_id\": 1,\n", 314 | " \"type_vocab_size\": 1,\n", 315 | " \"vocab_size\": 50265\n", 316 | "}" 317 | ] 318 | }, 319 | "metadata": { 320 | "tags": [] 321 | }, 322 | "execution_count": 5 323 | } 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "metadata": { 329 | "id": "2ucGw5qrEtp8", 330 | "colab_type": "code", 331 | "outputId": "d844579b-a54a-4799-e783-4945930d6eb7", 332 | "colab": { 333 | "base_uri": "https://localhost:8080/", 334 | "height": 1000 335 | } 336 | }, 337 | "source": [ 338 | "bart.config" 339 | ], 340 | "execution_count": 6, 341 | "outputs": [ 342 | { 343 | "output_type": "execute_result", 344 | "data": { 345 | "text/plain": [ 346 | "BartConfig {\n", 347 | " \"_num_labels\": 3,\n", 348 | " \"activation_dropout\": 0.0,\n", 349 | " \"activation_function\": \"gelu\",\n", 350 | " \"add_bias_logits\": false,\n", 351 | " \"add_final_layer_norm\": false,\n", 352 | " \"architectures\": [\n", 353 | " \"BartForConditionalGeneration\"\n", 354 | " ],\n", 355 | " \"attention_dropout\": 0.0,\n", 356 | " \"bos_token_id\": 0,\n", 357 | " \"classif_dropout\": 0.0,\n", 358 | " \"d_model\": 24,\n", 359 | " \"decoder_attention_heads\": 2,\n", 360 | " \"decoder_ffn_dim\": 16,\n", 361 | " \"decoder_layerdrop\": 0.0,\n", 362 | " \"decoder_layers\": 2,\n", 363 | " \"decoder_max_position_embeddings\": 1024,\n", 364 | " \"decoder_start_token_id\": 2,\n", 365 | " \"dropout\": 0.1,\n", 366 | " \"encoder_attention_heads\": 2,\n", 367 | " \"encoder_ffn_dim\": 16,\n", 368 | " \"encoder_layerdrop\": 0.0,\n", 369 | " \"encoder_layers\": 2,\n", 370 | " \"encoder_max_position_embeddings\": 1024,\n", 371 | " \"eos_token_id\": 2,\n", 372 | " \"id2label\": {\n", 373 | " \"0\": \"LABEL_0\",\n", 374 | " \"1\": \"LABEL_1\",\n", 375 | " \"2\": \"LABEL_2\"\n", 376 | " },\n", 377 | " \"init_std\": 0.02,\n", 378 | " \"is_encoder_decoder\": true,\n", 379 | " \"label2id\": {\n", 380 | " \"LABEL_0\": 0,\n", 381 | " \"LABEL_1\": 1,\n", 382 | " \"LABEL_2\": 2\n", 383 | " },\n", 384 
| " \"max_position_embeddings\": 1024,\n", 385 | " \"model_type\": \"bart\",\n", 386 | " \"normalize_before\": false,\n", 387 | " \"normalize_embedding\": true,\n", 388 | " \"num_hidden_layers\": 2,\n", 389 | " \"output_past\": true,\n", 390 | " \"pad_token_id\": 1,\n", 391 | " \"prefix\": \" \",\n", 392 | " \"scale_embedding\": false,\n", 393 | " \"static_position_embeddings\": false,\n", 394 | " \"task_specific_params\": {\n", 395 | " \"summarization\": {\n", 396 | " \"early_stopping\": true,\n", 397 | " \"length_penalty\": 2.0,\n", 398 | " \"max_length\": 142,\n", 399 | " \"min_length\": 56,\n", 400 | " \"no_repeat_ngram_size\": 3,\n", 401 | " \"num_beams\": 4\n", 402 | " }\n", 403 | " },\n", 404 | " \"vocab_size\": 50265\n", 405 | "}" 406 | ] 407 | }, 408 | "metadata": { 409 | "tags": [] 410 | }, 411 | "execution_count": 6 412 | } 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "metadata": { 418 | "id": "RNA4Z21mEvGt", 419 | "colab_type": "code", 420 | "colab": {} 421 | }, 422 | "source": [ 423 | "bart_layer = bart.model.encoder.layers[0]\n", 424 | "roberta_layer = roberta.roberta.encoder.layer[0]" 425 | ], 426 | "execution_count": 0, 427 | "outputs": [] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "metadata": { 432 | "id": "2KAOBHPdFGn_", 433 | "colab_type": "code", 434 | "outputId": "9121f646-3762-4016-b62d-62523686f8af", 435 | "colab": { 436 | "base_uri": "https://localhost:8080/", 437 | "height": 403 438 | } 439 | }, 440 | "source": [ 441 | "roberta_layer" 442 | ], 443 | "execution_count": 8, 444 | "outputs": [ 445 | { 446 | "output_type": "execute_result", 447 | "data": { 448 | "text/plain": [ 449 | "BertLayer(\n", 450 | " (attention): BertAttention(\n", 451 | " (self): BertSelfAttention(\n", 452 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 453 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 454 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 455 | " (dropout): Dropout(p=0.1, 
inplace=False)\n", 456 | " )\n", 457 | " (output): BertSelfOutput(\n", 458 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 459 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 460 | " (dropout): Dropout(p=0.1, inplace=False)\n", 461 | " )\n", 462 | " )\n", 463 | " (intermediate): BertIntermediate(\n", 464 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 465 | " )\n", 466 | " (output): BertOutput(\n", 467 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 468 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 469 | " (dropout): Dropout(p=0.1, inplace=False)\n", 470 | " )\n", 471 | ")" 472 | ] 473 | }, 474 | "metadata": { 475 | "tags": [] 476 | }, 477 | "execution_count": 8 478 | } 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "metadata": { 484 | "id": "okn8MAVcFDS9", 485 | "colab_type": "code", 486 | "outputId": "2f6a72a3-d7e3-46d6-c81f-a1fc1af85258", 487 | "colab": { 488 | "base_uri": "https://localhost:8080/", 489 | "height": 218 490 | } 491 | }, 492 | "source": [ 493 | "bart_layer" 494 | ], 495 | "execution_count": 9, 496 | "outputs": [ 497 | { 498 | "output_type": "execute_result", 499 | "data": { 500 | "text/plain": [ 501 | "EncoderLayer(\n", 502 | " (self_attn): SelfAttention(\n", 503 | " (k_proj): Linear(in_features=24, out_features=24, bias=True)\n", 504 | " (v_proj): Linear(in_features=24, out_features=24, bias=True)\n", 505 | " (q_proj): Linear(in_features=24, out_features=24, bias=True)\n", 506 | " (out_proj): Linear(in_features=24, out_features=24, bias=True)\n", 507 | " )\n", 508 | " (self_attn_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n", 509 | " (fc1): Linear(in_features=24, out_features=16, bias=True)\n", 510 | " (fc2): Linear(in_features=16, out_features=24, bias=True)\n", 511 | " (final_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n", 512 | ")" 513 | ] 514 | }, 515 | "metadata": { 
516 | "tags": [] 517 | }, 518 | "execution_count": 9 519 | } 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": { 525 | "id": "pURj--xZL7Ef", 526 | "colab_type": "text" 527 | }, 528 | "source": [ 529 | "BART computes the output projection inside the attention layer itself, and the `forward` parameter names of the `SelfAttention` layer used in BART differ from those of `BertSelfAttention`. So we need to wrap `LongformerSelfAttention` to use it in BART." 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "id": "jfPgsJ8YQR4A", 536 | "colab_type": "code", 537 | "colab": {} 538 | }, 539 | "source": [ 540 | "import math\n", 541 | "from typing import Dict, List, Optional, Tuple\n", 542 | "\n", 543 | "import torch\n", 544 | "from torch import Tensor, nn\n", 545 | "\n", 546 | "class LongformerSelfAttentionForBart(nn.Module):\n", 547 | " def __init__(self, config, layer_id):\n", 548 | " super().__init__()\n", 549 | " self.embed_dim = config.d_model\n", 550 | " self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)\n", 551 | " self.output = nn.Linear(self.embed_dim, self.embed_dim)\n", 552 | " \n", 553 | " def forward(\n", 554 | " self,\n", 555 | " query,\n", 556 | " key: Optional[Tensor],\n", 557 | " key_padding_mask: Optional[Tensor] = None,\n", 558 | " layer_state: Optional[Dict[str, Optional[Tensor]]] = None,\n", 559 | " attn_mask: Optional[Tensor] = None,\n", 560 | " need_weights=False,\n", 561 | " ) -> Tuple[Tensor, Optional[Tensor]]:\n", 562 | " \n", 563 | " tgt_len, bsz, embed_dim = query.size()\n", 564 | " assert embed_dim == self.embed_dim\n", 565 | " assert list(query.size()) == [tgt_len, bsz, embed_dim]\n", 566 | "\n", 567 | " # LongformerSelfAttention expects this shape\n", 568 | " query = query.view(bsz, tgt_len, embed_dim)\n", 569 | "\n", 570 | " outputs = self.longformer_self_attn(\n", 571 | " query,\n", 572 | " attention_mask=attn_mask,\n", 573 | " head_mask=None,\n", 574 |
encoder_hidden_states=None,\n", 575 | " encoder_attention_mask=None,\n", 576 | " )\n", 577 | "\n", 578 | " attn_output = outputs[0] \n", 579 | " attn_output = attn_output.contiguous().view(tgt_len, bsz, embed_dim)\n", 580 | " attn_output = self.output(attn_output)\n", 581 | "\n", 582 | " return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)" 583 | ], 584 | "execution_count": 0, 585 | "outputs": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "metadata": { 590 | "id": "6VIx_TmOELqF", 591 | "colab_type": "code", 592 | "colab": {} 593 | }, 594 | "source": [ 595 | "class LongBartForConditionalGeneration(BartForConditionalGeneration):\n", 596 | " def __init__(self, config):\n", 597 | " super().__init__(config)\n", 598 | " for i, layer in enumerate(self.model.encoder.layers):\n", 599 | " # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n", 600 | " layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)" 601 | ], 602 | "execution_count": 0, 603 | "outputs": [] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "metadata": { 608 | "id": "Lx28-eLNEou5", 609 | "colab_type": "code", 610 | "colab": {} 611 | }, 612 | "source": [ 613 | "def create_long_model(save_model_to, base_model='bart-large', attention_window=512, max_pos=4096):\n", 614 | " model = BartForConditionalGeneration.from_pretrained(base_model)\n", 615 | " tokenizer = BartTokenizer.from_pretrained('bart-large', model_max_length=max_pos)\n", 616 | " config = model.config\n", 617 | "\n", 618 | " # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention\n", 619 | " # expects attention_probs_dropout_prob, so set it here \n", 620 | " config.attention_probs_dropout_prob = config.attention_dropout\n", 621 | "\n", 622 | " # extend position embeddings\n", 623 | " tokenizer.model_max_length = max_pos\n", 624 | " tokenizer.init_kwargs['model_max_length'] = max_pos\n", 625 | " current_max_pos, embed_size = 
model.model.encoder.embed_positions.weight.shape\n", 626 | " # config.max_position_embeddings = max_pos\n", 627 | " config.encoder_max_position_embeddings = max_pos\n", 628 | " max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2\n", 629 | " assert max_pos > current_max_pos\n", 630 | " # allocate a larger position embedding matrix\n", 631 | " new_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)\n", 632 | " # copy position embeddings over and over to initialize the new position embeddings\n", 633 | " k = 2\n", 634 | " step = current_max_pos - 2\n", 635 | " while k < max_pos - 1:\n", 636 | " new_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]\n", 637 | " k += step\n", 638 | " model.model.encoder.embed_positions.weight.data = new_pos_embed\n", 639 | "\n", 640 | " # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n", 641 | " config.attention_window = [attention_window] * config.num_hidden_layers\n", 642 | " for i, layer in enumerate(model.model.encoder.layers):\n", 643 | " longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)\n", 644 | " \n", 645 | " longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj\n", 646 | " longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj\n", 647 | " longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj\n", 648 | "\n", 649 | " longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj\n", 650 | " longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj\n", 651 | " longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj\n", 652 | "\n", 653 | " longformer_self_attn_for_bart.output = layer.self_attn.out_proj\n", 654 | "\n", 655 | " layer.self_attn = longformer_self_attn_for_bart\n", 656 | "\n", 657 | " 
logger.info(f'saving model to {save_model_to}')\n", 658 | " model.save_pretrained(save_model_to)\n", 659 | " tokenizer.save_pretrained(save_model_to)\n", 660 | " return model, tokenizer" 661 | ], 662 | "execution_count": 0, 663 | "outputs": [] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "metadata": { 668 | "id": "XuYkE_kGO-U-", 669 | "colab_type": "code", 670 | "outputId": "ee5176ee-677c-406b-feb2-1d1aa5ea2b23", 671 | "colab": { 672 | "base_uri": "https://localhost:8080/", 673 | "height": 1000 674 | } 675 | }, 676 | "source": [ 677 | "# model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'\n", 678 | "base_model = \"sshleifer/bart-tiny-random\"\n", 679 | "model_path = \"bart-tiny-random-4096\"\n", 680 | "attention_window = 512\n", 681 | "max_pos = 4096\n", 682 | "\n", 683 | "if not os.path.exists(model_path):\n", 684 | " os.makedirs(model_path)\n", 685 | "\n", 686 | "# logger.info(f'Converting roberta-base into roberta-base-{model_args.max_pos}')\n", 687 | "model, tokenizer = create_long_model(\n", 688 | " save_model_to=model_path,\n", 689 | " base_model=base_model,\n", 690 | " attention_window=attention_window,\n", 691 | " max_pos=max_pos\n", 692 | ")" 693 | ], 694 | "execution_count": 13, 695 | "outputs": [ 696 | { 697 | "output_type": "stream", 698 | "text": [ 699 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n", 700 | "INFO:transformers.configuration_utils:Model config BartConfig {\n", 701 | " \"_num_labels\": 3,\n", 702 | " \"activation_dropout\": 0.0,\n", 703 | " \"activation_function\": \"gelu\",\n", 704 | " \"add_bias_logits\": false,\n", 705 | " \"add_final_layer_norm\": false,\n", 706 | " \"architectures\": [\n", 707 | " 
\"BartForConditionalGeneration\"\n", 708 | " ],\n", 709 | " \"attention_dropout\": 0.0,\n", 710 | " \"bos_token_id\": 0,\n", 711 | " \"classif_dropout\": 0.0,\n", 712 | " \"d_model\": 24,\n", 713 | " \"decoder_attention_heads\": 2,\n", 714 | " \"decoder_ffn_dim\": 16,\n", 715 | " \"decoder_layerdrop\": 0.0,\n", 716 | " \"decoder_layers\": 2,\n", 717 | " \"decoder_max_position_embeddings\": 1024,\n", 718 | " \"decoder_start_token_id\": 2,\n", 719 | " \"dropout\": 0.1,\n", 720 | " \"encoder_attention_heads\": 2,\n", 721 | " \"encoder_ffn_dim\": 16,\n", 722 | " \"encoder_layerdrop\": 0.0,\n", 723 | " \"encoder_layers\": 2,\n", 724 | " \"encoder_max_position_embeddings\": 1024,\n", 725 | " \"eos_token_id\": 2,\n", 726 | " \"id2label\": {\n", 727 | " \"0\": \"LABEL_0\",\n", 728 | " \"1\": \"LABEL_1\",\n", 729 | " \"2\": \"LABEL_2\"\n", 730 | " },\n", 731 | " \"init_std\": 0.02,\n", 732 | " \"is_encoder_decoder\": true,\n", 733 | " \"label2id\": {\n", 734 | " \"LABEL_0\": 0,\n", 735 | " \"LABEL_1\": 1,\n", 736 | " \"LABEL_2\": 2\n", 737 | " },\n", 738 | " \"max_position_embeddings\": 1024,\n", 739 | " \"model_type\": \"bart\",\n", 740 | " \"normalize_before\": false,\n", 741 | " \"normalize_embedding\": true,\n", 742 | " \"num_hidden_layers\": 2,\n", 743 | " \"output_past\": true,\n", 744 | " \"pad_token_id\": 1,\n", 745 | " \"prefix\": \" \",\n", 746 | " \"scale_embedding\": false,\n", 747 | " \"static_position_embeddings\": false,\n", 748 | " \"task_specific_params\": {\n", 749 | " \"summarization\": {\n", 750 | " \"early_stopping\": true,\n", 751 | " \"length_penalty\": 2.0,\n", 752 | " \"max_length\": 142,\n", 753 | " \"min_length\": 56,\n", 754 | " \"no_repeat_ngram_size\": 3,\n", 755 | " \"num_beams\": 4\n", 756 | " }\n", 757 | " },\n", 758 | " \"vocab_size\": 50265\n", 759 | "}\n", 760 | "\n", 761 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at 
/root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n", 762 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n", 763 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n", 764 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n", 765 | "INFO:__main__:saving model to bart-tiny-random-4096\n", 766 | "INFO:transformers.configuration_utils:Configuration saved in bart-tiny-random-4096/config.json\n", 767 | "INFO:transformers.modeling_utils:Model weights saved in bart-tiny-random-4096/pytorch_model.bin\n" 768 | ], 769 | "name": "stderr" 770 | } 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "metadata": { 776 | "id": "XsnW-waPU3Ua", 777 | "colab_type": "code", 778 | "outputId": "5f6d97f5-2863-4535-969a-1fa884b5635e", 779 | "colab": { 780 | "base_uri": "https://localhost:8080/", 781 | "height": 1000 782 | } 783 | }, 784 | "source": [ 785 | "long_model_tiny = LongBartForConditionalGeneration.from_pretrained('bart-tiny-random-4096')" 786 | ], 787 | "execution_count": 14, 788 | "outputs": [ 789 | { 790 | "output_type": "stream", 791 | "text": [ 792 | "INFO:transformers.configuration_utils:loading configuration file bart-tiny-random-4096/config.json\n", 793 | "INFO:transformers.configuration_utils:Model config BartConfig {\n", 794 | " \"_num_labels\": 3,\n", 795 | " 
\"activation_dropout\": 0.0,\n", 796 | " \"activation_function\": \"gelu\",\n", 797 | " \"add_bias_logits\": false,\n", 798 | " \"add_final_layer_norm\": false,\n", 799 | " \"architectures\": [\n", 800 | " \"BartForConditionalGeneration\"\n", 801 | " ],\n", 802 | " \"attention_dropout\": 0.0,\n", 803 | " \"attention_probs_dropout_prob\": 0.0,\n", 804 | " \"attention_window\": [\n", 805 | " 512,\n", 806 | " 512\n", 807 | " ],\n", 808 | " \"bos_token_id\": 0,\n", 809 | " \"classif_dropout\": 0.0,\n", 810 | " \"d_model\": 24,\n", 811 | " \"decoder_attention_heads\": 2,\n", 812 | " \"decoder_ffn_dim\": 16,\n", 813 | " \"decoder_layerdrop\": 0.0,\n", 814 | " \"decoder_layers\": 2,\n", 815 | " \"decoder_max_position_embeddings\": 1024,\n", 816 | " \"decoder_start_token_id\": 2,\n", 817 | " \"dropout\": 0.1,\n", 818 | " \"encoder_attention_heads\": 2,\n", 819 | " \"encoder_ffn_dim\": 16,\n", 820 | " \"encoder_layerdrop\": 0.0,\n", 821 | " \"encoder_layers\": 2,\n", 822 | " \"encoder_max_position_embeddings\": 4096,\n", 823 | " \"eos_token_id\": 2,\n", 824 | " \"id2label\": {\n", 825 | " \"0\": \"LABEL_0\",\n", 826 | " \"1\": \"LABEL_1\",\n", 827 | " \"2\": \"LABEL_2\"\n", 828 | " },\n", 829 | " \"init_std\": 0.02,\n", 830 | " \"is_encoder_decoder\": true,\n", 831 | " \"label2id\": {\n", 832 | " \"LABEL_0\": 0,\n", 833 | " \"LABEL_1\": 1,\n", 834 | " \"LABEL_2\": 2\n", 835 | " },\n", 836 | " \"max_position_embeddings\": 1024,\n", 837 | " \"model_type\": \"bart\",\n", 838 | " \"normalize_before\": false,\n", 839 | " \"normalize_embedding\": true,\n", 840 | " \"num_hidden_layers\": 2,\n", 841 | " \"output_past\": true,\n", 842 | " \"pad_token_id\": 1,\n", 843 | " \"prefix\": \" \",\n", 844 | " \"scale_embedding\": false,\n", 845 | " \"static_position_embeddings\": false,\n", 846 | " \"task_specific_params\": {\n", 847 | " \"summarization\": {\n", 848 | " \"early_stopping\": true,\n", 849 | " \"length_penalty\": 2.0,\n", 850 | " \"max_length\": 142,\n", 851 | " 
\"min_length\": 56,\n", 852 | " \"no_repeat_ngram_size\": 3,\n", 853 | " \"num_beams\": 4\n", 854 | " }\n", 855 | " },\n", 856 | " \"vocab_size\": 50265\n", 857 | "}\n", 858 | "\n", 859 | "INFO:transformers.modeling_utils:loading weights file bart-tiny-random-4096/pytorch_model.bin\n" 860 | ], 861 | "name": "stderr" 862 | } 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "metadata": { 868 | "id": "Z5QKIIdeYRDL", 869 | "colab_type": "code", 870 | "outputId": "dd326b65-6bc7-4a91-e670-e41b6f64784f", 871 | "colab": { 872 | "base_uri": "https://localhost:8080/", 873 | "height": 34 874 | } 875 | }, 876 | "source": [ 877 | "TXT = \"My friends are but they eat too many carbs.\"\n", 878 | "\n", 879 | "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=4096, pad_to_max_length=True)['input_ids']\n", 880 | "\n", 881 | "logits = long_model_tiny(input_ids)[0]\n", 882 | "\n", 883 | "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n", 884 | "probs = logits[0, masked_index].softmax(dim=0)\n", 885 | "values, predictions = probs.topk(5)\n", 886 | "tokenizer.decode(predictions).split()" 887 | ], 888 | "execution_count": 15, 889 | "outputs": [ 890 | { 891 | "output_type": "execute_result", 892 | "data": { 893 | "text/plain": [ 894 | "['.']" 895 | ] 896 | }, 897 | "metadata": { 898 | "tags": [] 899 | }, 900 | "execution_count": 15 901 | } 902 | ] 903 | }, 904 | { 905 | "cell_type": "markdown", 906 | "metadata": { 907 | "id": "FNmzwNHAN1AI", 908 | "colab_type": "text" 909 | }, 910 | "source": [ 911 | "Now lets try with bart-large" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "metadata": { 917 | "id": "vAzZdj-1N3Is", 918 | "colab_type": "code", 919 | "outputId": "85cf0a32-fc89-4137-f61e-d2710ea85bc4", 920 | "colab": { 921 | "base_uri": "https://localhost:8080/", 922 | "height": 1000 923 | } 924 | }, 925 | "source": [ 926 | "# model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'\n", 927 | 
"base_model = \"bart-large\"\n", 928 | "model_path = \"bart-large-4096\"\n", 929 | "attention_window = 512\n", 930 | "max_pos = 4096\n", 931 | "\n", 932 | "if not os.path.exists(model_path):\n", 933 | " os.makedirs(model_path)\n", 934 | "\n", 935 | "# logger.info(f'Converting roberta-base into roberta-base-{model_args.max_pos}')\n", 936 | "model, tokenizer = create_long_model(\n", 937 | " save_model_to=model_path,\n", 938 | " base_model=base_model,\n", 939 | " attention_window=attention_window,\n", 940 | " max_pos=max_pos\n", 941 | ")" 942 | ], 943 | "execution_count": 16, 944 | "outputs": [ 945 | { 946 | "output_type": "stream", 947 | "text": [ 948 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json from cache at /root/.cache/torch/transformers/7f6632e580b7d9fd4f611dd96dab877cccfc319867b53b8b72fddca7fd64de5c.40bd49bcec9d93d8b0bfbd020088e2e1b6e6bb03e8e80aa5144638f90ca6bd61\n", 949 | "INFO:transformers.configuration_utils:Model config BartConfig {\n", 950 | " \"_num_labels\": 3,\n", 951 | " \"activation_dropout\": 0.0,\n", 952 | " \"activation_function\": \"gelu\",\n", 953 | " \"add_bias_logits\": false,\n", 954 | " \"add_final_layer_norm\": false,\n", 955 | " \"architectures\": [\n", 956 | " \"BartModel\",\n", 957 | " \"BartForMaskedLM\",\n", 958 | " \"BartForSequenceClassification\"\n", 959 | " ],\n", 960 | " \"attention_dropout\": 0.0,\n", 961 | " \"bos_token_id\": 0,\n", 962 | " \"classif_dropout\": 0.0,\n", 963 | " \"d_model\": 1024,\n", 964 | " \"decoder_attention_heads\": 16,\n", 965 | " \"decoder_ffn_dim\": 4096,\n", 966 | " \"decoder_layerdrop\": 0.0,\n", 967 | " \"decoder_layers\": 12,\n", 968 | " \"decoder_max_position_embeddings\": 1024,\n", 969 | " \"decoder_start_token_id\": 2,\n", 970 | " \"dropout\": 0.1,\n", 971 | " \"encoder_attention_heads\": 16,\n", 972 | " \"encoder_ffn_dim\": 4096,\n", 973 | " \"encoder_layerdrop\": 0.0,\n", 974 | " 
\"encoder_layers\": 12,\n", 975 | " \"encoder_max_position_embeddings\": 1024,\n", 976 | " \"eos_token_id\": 2,\n", 977 | " \"id2label\": {\n", 978 | " \"0\": \"LABEL_0\",\n", 979 | " \"1\": \"LABEL_1\",\n", 980 | " \"2\": \"LABEL_2\"\n", 981 | " },\n", 982 | " \"init_std\": 0.02,\n", 983 | " \"is_encoder_decoder\": true,\n", 984 | " \"label2id\": {\n", 985 | " \"LABEL_0\": 0,\n", 986 | " \"LABEL_1\": 1,\n", 987 | " \"LABEL_2\": 2\n", 988 | " },\n", 989 | " \"max_position_embeddings\": 1024,\n", 990 | " \"model_type\": \"bart\",\n", 991 | " \"normalize_before\": false,\n", 992 | " \"normalize_embedding\": true,\n", 993 | " \"num_hidden_layers\": 12,\n", 994 | " \"output_past\": false,\n", 995 | " \"pad_token_id\": 1,\n", 996 | " \"prefix\": \" \",\n", 997 | " \"scale_embedding\": false,\n", 998 | " \"static_position_embeddings\": false,\n", 999 | " \"task_specific_params\": {\n", 1000 | " \"summarization\": {\n", 1001 | " \"early_stopping\": true,\n", 1002 | " \"length_penalty\": 2.0,\n", 1003 | " \"max_length\": 142,\n", 1004 | " \"min_length\": 56,\n", 1005 | " \"no_repeat_ngram_size\": 3,\n", 1006 | " \"num_beams\": 4\n", 1007 | " }\n", 1008 | " },\n", 1009 | " \"vocab_size\": 50265\n", 1010 | "}\n", 1011 | "\n", 1012 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin from cache at /root/.cache/torch/transformers/2e7cae41bb1dd1f18e498ff4ff0ea85f7e9bc2b637439e2d95c485c5d5bdd579.5be2a88ec29f5969270f98902db392beab8be8a6a7ecc588d410ada3e32c4263\n", 1013 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n", 1014 | "INFO:transformers.modeling_utils:Weights from pretrained model not used in BartForConditionalGeneration: ['encoder.version', 'decoder.version']\n", 1015 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at 
/root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n", 1016 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n", 1017 | "INFO:__main__:saving model to bart-large-4096\n", 1018 | "INFO:transformers.configuration_utils:Configuration saved in bart-large-4096/config.json\n", 1019 | "INFO:transformers.modeling_utils:Model weights saved in bart-large-4096/pytorch_model.bin\n" 1020 | ], 1021 | "name": "stderr" 1022 | } 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "metadata": { 1028 | "id": "9-lgWprfN8QW", 1029 | "colab_type": "code", 1030 | "outputId": "c1a817ea-3a05-4676-acea-81b27b3f6591", 1031 | "colab": { 1032 | "base_uri": "https://localhost:8080/", 1033 | "height": 1000 1034 | } 1035 | }, 1036 | "source": [ 1037 | "long_model = LongBartForConditionalGeneration.from_pretrained('bart-large-4096')\n", 1038 | "tokenizer = BartTokenizer.from_pretrained('bart-large-4096')" 1039 | ], 1040 | "execution_count": 7, 1041 | "outputs": [ 1042 | { 1043 | "output_type": "stream", 1044 | "text": [ 1045 | "INFO:transformers.configuration_utils:loading configuration file bart-large-4096/config.json\n", 1046 | "INFO:transformers.configuration_utils:Model config BartConfig {\n", 1047 | " \"_num_labels\": 3,\n", 1048 | " \"activation_dropout\": 0.0,\n", 1049 | " \"activation_function\": \"gelu\",\n", 1050 | " \"add_bias_logits\": false,\n", 1051 | " \"add_final_layer_norm\": false,\n", 1052 | " \"architectures\": [\n", 1053 | " \"BartForConditionalGeneration\"\n", 1054 | " ],\n", 1055 | " \"attention_dropout\": 0.0,\n", 1056 | " \"attention_probs_dropout_prob\": 0.0,\n", 1057 | " \"attention_window\": [\n", 
1058 | " 512,\n", 1059 | " 512,\n", 1060 | " 512,\n", 1061 | " 512,\n", 1062 | " 512,\n", 1063 | " 512,\n", 1064 | " 512,\n", 1065 | " 512,\n", 1066 | " 512,\n", 1067 | " 512,\n", 1068 | " 512,\n", 1069 | " 512\n", 1070 | " ],\n", 1071 | " \"bos_token_id\": 0,\n", 1072 | " \"classif_dropout\": 0.0,\n", 1073 | " \"d_model\": 1024,\n", 1074 | " \"decoder_attention_heads\": 16,\n", 1075 | " \"decoder_ffn_dim\": 4096,\n", 1076 | " \"decoder_layerdrop\": 0.0,\n", 1077 | " \"decoder_layers\": 12,\n", 1078 | " \"decoder_max_position_embeddings\": 1024,\n", 1079 | " \"decoder_start_token_id\": 2,\n", 1080 | " \"dropout\": 0.1,\n", 1081 | " \"encoder_attention_heads\": 16,\n", 1082 | " \"encoder_ffn_dim\": 4096,\n", 1083 | " \"encoder_layerdrop\": 0.0,\n", 1084 | " \"encoder_layers\": 12,\n", 1085 | " \"encoder_max_position_embeddings\": 4096,\n", 1086 | " \"eos_token_id\": 2,\n", 1087 | " \"id2label\": {\n", 1088 | " \"0\": \"LABEL_0\",\n", 1089 | " \"1\": \"LABEL_1\",\n", 1090 | " \"2\": \"LABEL_2\"\n", 1091 | " },\n", 1092 | " \"init_std\": 0.02,\n", 1093 | " \"is_encoder_decoder\": true,\n", 1094 | " \"label2id\": {\n", 1095 | " \"LABEL_0\": 0,\n", 1096 | " \"LABEL_1\": 1,\n", 1097 | " \"LABEL_2\": 2\n", 1098 | " },\n", 1099 | " \"max_position_embeddings\": 1024,\n", 1100 | " \"model_type\": \"bart\",\n", 1101 | " \"normalize_before\": false,\n", 1102 | " \"normalize_embedding\": true,\n", 1103 | " \"num_hidden_layers\": 12,\n", 1104 | " \"output_past\": false,\n", 1105 | " \"pad_token_id\": 1,\n", 1106 | " \"prefix\": \" \",\n", 1107 | " \"scale_embedding\": false,\n", 1108 | " \"static_position_embeddings\": false,\n", 1109 | " \"task_specific_params\": {\n", 1110 | " \"summarization\": {\n", 1111 | " \"early_stopping\": true,\n", 1112 | " \"length_penalty\": 2.0,\n", 1113 | " \"max_length\": 142,\n", 1114 | " \"min_length\": 56,\n", 1115 | " \"no_repeat_ngram_size\": 3,\n", 1116 | " \"num_beams\": 4\n", 1117 | " }\n", 1118 | " },\n", 1119 | " \"vocab_size\": 
50265\n", 1120 | "}\n", 1121 | "\n", 1122 | "INFO:transformers.modeling_utils:loading weights file bart-large-4096/pytorch_model.bin\n", 1123 | "INFO:transformers.tokenization_utils:Model name 'bart-large-4096' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'bart-large-4096' is a path, a model identifier, or url to a directory containing tokenizer files.\n", 1124 | "INFO:transformers.tokenization_utils:Didn't find file bart-large-4096/added_tokens.json. We won't load it.\n", 1125 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/vocab.json\n", 1126 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/merges.txt\n", 1127 | "INFO:transformers.tokenization_utils:loading file None\n", 1128 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/special_tokens_map.json\n", 1129 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/tokenizer_config.json\n" 1130 | ], 1131 | "name": "stderr" 1132 | } 1133 | ] 1134 | }, 1135 | { 1136 | "cell_type": "code", 1137 | "metadata": { 1138 | "id": "cLhZFQMYONPb", 1139 | "colab_type": "code", 1140 | "colab": { 1141 | "base_uri": "https://localhost:8080/", 1142 | "height": 34 1143 | }, 1144 | "outputId": "f4eda1b3-5333-4144-bdd4-9804046dd30a" 1145 | }, 1146 | "source": [ 1147 | "TXT = \"My friends are <mask> but they eat too many carbs.\"\n", 1148 | "\n", 1149 | "# 4096 seq len crashes even with 35 GB memory\n", 1150 | "# so we also probably need sliding-window attention in decoder as well\n", 1151 | "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=2560, pad_to_max_length=True)['input_ids']\n", 1152 | "\n", 1153 | "logits = long_model(input_ids)[0]\n", 1154 | "\n", 1155 | "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n", 1156 | "probs = logits[0, masked_index].softmax(dim=0)\n", 1157 | "values, predictions = probs.topk(5)\n", 1158 |
"tokenizer.decode(predictions).split()" 1159 | ], 1160 | "execution_count": 8, 1161 | "outputs": [ 1162 | { 1163 | "output_type": "execute_result", 1164 | "data": { 1165 | "text/plain": [ 1166 | "['having', 'still', 'going', 'getting', 'not']" 1167 | ] 1168 | }, 1169 | "metadata": { 1170 | "tags": [] 1171 | }, 1172 | "execution_count": 8 1173 | } 1174 | ] 1175 | } 1176 | ] 1177 | } -------------------------------------------------------------------------------- /longbart/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_longbart import LongformerSelfAttentionForBart, LongBartForConditionalGeneration 2 | from .modeling_bart import BartForConditionalGeneration -------------------------------------------------------------------------------- /longbart/configuration_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ BART configuration """ 16 | 17 | 18 | import logging 19 | 20 | from transformers.configuration_utils import PretrainedConfig 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { 26 | "facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json", 27 | "facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json", 28 | "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", 29 | "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json", 30 | "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", 31 | } 32 | 33 | 34 | class BartConfig(PretrainedConfig): 35 | r""" 36 | Configuration class for Bart. Parameters are renamed from the fairseq implementation 37 | """ 38 | model_type = "bart" 39 | 40 | def __init__( 41 | self, 42 | activation_dropout=0.0, 43 | activation_function="gelu", 44 | vocab_size=50265, 45 | d_model=1024, 46 | encoder_ffn_dim=4096, 47 | encoder_layers=12, 48 | encoder_attention_heads=16, 49 | decoder_ffn_dim=4096, 50 | decoder_layers=12, 51 | decoder_attention_heads=16, 52 | encoder_layerdrop=0.0, 53 | decoder_layerdrop=0.0, 54 | attention_dropout=0.0, 55 | dropout=0.1, 56 | max_position_embeddings=1024, 57 | encoder_max_position_embeddings=None, 58 | decoder_max_position_embeddings=None, 59 | init_std=0.02, 60 | classifier_dropout=0.0, 61 | num_labels=3, 62 | is_encoder_decoder=True, 63 | pad_token_id=1, 64 | bos_token_id=0, 65 | eos_token_id=2, 66 | normalize_before=False, 67 | add_final_layer_norm=False, 68 | scale_embedding=False, 69 | normalize_embedding=True, 70 | static_position_embeddings=False, 71 | add_bias_logits=False, 72 | gradient_checkpointing=False, 73 | **common_kwargs 74 | ): 75 | r""" 76 
| :class:`~transformers.BartConfig` is the configuration class for `BartModel`. 77 | Examples: 78 | config = BartConfig.from_pretrained('bart-large') 79 | model = BartModel(config) 80 | """ 81 | if "hidden_size" in common_kwargs: 82 | raise ValueError("hidden size is called d_model") 83 | super().__init__( 84 | num_labels=num_labels, 85 | pad_token_id=pad_token_id, 86 | bos_token_id=bos_token_id, 87 | eos_token_id=eos_token_id, 88 | is_encoder_decoder=is_encoder_decoder, 89 | **common_kwargs, 90 | ) 91 | self.vocab_size = vocab_size 92 | self.d_model = d_model # encoder_embed_dim and decoder_embed_dim 93 | self.encoder_ffn_dim = encoder_ffn_dim 94 | self.encoder_layers = self.num_hidden_layers = encoder_layers 95 | self.encoder_attention_heads = encoder_attention_heads 96 | self.encoder_layerdrop = encoder_layerdrop 97 | self.decoder_layerdrop = decoder_layerdrop 98 | self.decoder_ffn_dim = decoder_ffn_dim 99 | self.decoder_layers = decoder_layers 100 | self.decoder_attention_heads = decoder_attention_heads 101 | self.init_std = init_std # Normal(0, this parameter) 102 | self.activation_function = activation_function 103 | 104 | self.max_position_embeddings = max_position_embeddings 105 | self.encoder_max_position_embeddings = encoder_max_position_embeddings if encoder_max_position_embeddings else max_position_embeddings 106 | self.decoder_max_position_embeddings = decoder_max_position_embeddings if decoder_max_position_embeddings else max_position_embeddings 107 | 108 | # Params introduced for Mbart 109 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 110 | self.normalize_embedding = normalize_embedding # True for mbart, False otherwise 111 | self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before 112 | self.add_final_layer_norm = add_final_layer_norm 113 | 114 | # Params introduced for Marian 115 | self.add_bias_logits = add_bias_logits 116 | self.static_position_embeddings = 
static_position_embeddings 117 | 118 | # 3 Types of Dropout 119 | self.attention_dropout = attention_dropout 120 | self.activation_dropout = activation_dropout 121 | self.dropout = dropout 122 | 123 | # Classifier stuff 124 | self.classif_dropout = classifier_dropout 125 | 126 | # gradient_checkpointing 127 | self.gradient_checkpointing = gradient_checkpointing 128 | self.output_attentions = True 129 | 130 | @property 131 | def num_attention_heads(self) -> int: 132 | return self.encoder_attention_heads 133 | 134 | @property 135 | def hidden_size(self) -> int: 136 | return self.d_model 137 | 138 | def is_valid_mbart(self) -> bool: 139 | """Is the configuration aligned with the MBART paper.""" 140 | if self.normalize_before and self.add_final_layer_norm and self.scale_embedding: 141 | return True 142 | if self.normalize_before or self.add_final_layer_norm or self.scale_embedding: 143 | logger.info("This configuration is a mixture of MBART and BART settings") 144 | return False -------------------------------------------------------------------------------- /longbart/convert_bart_to_longbart.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | from transformers import BartTokenizer 6 | 7 | from .modeling_bart import BartForConditionalGeneration 8 | from .modeling_longbart import LongformerSelfAttentionForBart 9 | 10 | logger = logging.getLogger(__name__) 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | def create_long_model( 14 | save_model_to, 15 | base_model='facebook/bart-large', 16 | tokenizer_name_or_path='facebook/bart-large', 17 | attention_window=1024, 18 | max_pos=4096 19 | ): 20 | model = BartForConditionalGeneration.from_pretrained(base_model) 21 | tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) 22 | config = model.config 23 | 24 | # in BART attention_probs_dropout_prob is attention_dropout, but 
LongformerSelfAttention 25 | # expects attention_probs_dropout_prob, so set it here 26 | config.attention_probs_dropout_prob = config.attention_dropout 27 | 28 | # extend position embeddings 29 | tokenizer.model_max_length = max_pos 30 | tokenizer.init_kwargs['model_max_length'] = max_pos 31 | current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape 32 | # config.max_position_embeddings = max_pos 33 | config.encoder_max_position_embeddings = max_pos 34 | max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2 35 | assert max_pos > current_max_pos 36 | # allocate a larger position embedding matrix 37 | new_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size) 38 | # copy position embeddings over and over to initialize the new position embeddings 39 | k = 2 40 | step = current_max_pos - 2 41 | while k < max_pos - 1: 42 | new_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:] 43 | k += step 44 | model.model.encoder.embed_positions.weight.data = new_pos_embed 45 | 46 | # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention` 47 | config.attention_window = [attention_window] * config.num_hidden_layers 48 | for i, layer in enumerate(model.model.encoder.layers): 49 | longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i) 50 | 51 | longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj 52 | longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj 53 | longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj 54 | 55 | longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj 56 | longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj 57 | longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj 58 | 59 | longformer_self_attn_for_bart.output = 
layer.self_attn.out_proj 60 | 61 | layer.self_attn = longformer_self_attn_for_bart 62 | 63 | logger.info(f'saving model to {save_model_to}') 64 | model.save_pretrained(save_model_to) 65 | tokenizer.save_pretrained(save_model_to) 66 | return model, tokenizer 67 | 68 | 69 | def main(): 70 | parser = argparse.ArgumentParser(description="Convert BART to LongBART. Replaces BART encoder's SelfAttnetion with LongformerSelfAttention") 71 | parser.add_argument( 72 | 'base_model', 73 | type=str, 74 | default='facebook/bart-large', 75 | help='The name or path of the base model you want to convert' 76 | ) 77 | parser.add_argument( 78 | 'tokenizer_name_or_path', 79 | type=str, 80 | default='facebook/bart-large', 81 | help='The name or path of the tokenizer' 82 | ) 83 | parser.add_argument( 84 | 'save_model_to', 85 | type=str, 86 | required=True, 87 | help='The path to save the converted model' 88 | ) 89 | parser.add_argument( 90 | 'attention_window', 91 | type=int, 92 | default=1024, 93 | help='attention window size for longformer self attention' 94 | ) 95 | parser.add_argument( 96 | 'max_pos', 97 | type=int, 98 | default=4096, 99 | help='maximum encoder positions' 100 | ) 101 | 102 | args = parser.parse_args() 103 | 104 | if not os.path.exists(args.save_model_to): 105 | os.mkdir(args.save_model_to) 106 | 107 | create_long_model( 108 | save_model_to=args.save_model_to, 109 | base_model=args.base_model, 110 | tokenizer_name_or_path=args.tokenizer_name_or_path, 111 | attention_window=args.attention_window, 112 | max_pos=args.max_pos 113 | ) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() -------------------------------------------------------------------------------- /longbart/modeling_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BART model, ported from the fairseq repo.""" 16 | import logging 17 | import math 18 | import random 19 | from typing import Dict, List, Optional, Tuple 20 | 21 | import numpy as np 22 | import torch 23 | import torch.utils.checkpoint 24 | import torch.nn.functional as F 25 | from torch import Tensor, nn 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable 29 | from transformers.modeling_utils import PreTrainedModel, create_position_ids_from_input_ids 30 | 31 | from .configuration_bart import BartConfig 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 37 | "facebook/bart-large", 38 | "facebook/bart-large-mnli", 39 | "facebook/bart-large-cnn", 40 | "facebook/bart-large-xsum", 41 | "facebook/mbart-large-en-ro", 42 | # See all BART models at https://huggingface.co/models?filter=bart 43 | ] 44 | 45 | 46 | BART_START_DOCSTRING = r""" 47 | 48 | This model is a PyTorch `torch.nn.Module `_ sub-class. Use it as a regular PyTorch Module and 49 | refer to the PyTorch documentation for all matters related to general usage and behavior. 50 | 51 | Parameters: 52 | config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model. 
53 | Initializing with a config file does not load the weights associated with the model, only the configuration. 54 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 55 | 56 | """ 57 | BART_GENERATION_EXAMPLE = r""" 58 | Examples:: 59 | 60 | from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig 61 | # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example 62 | model = BartForConditionalGeneration.from_pretrained('bart-large-cnn') 63 | tokenizer = BartTokenizer.from_pretrained('bart-large-cnn') 64 | ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." 65 | inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') 66 | # Generate Summary 67 | summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) 68 | print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) 69 | 70 | """ 71 | 72 | BART_INPUTS_DOCSTRING = r""" 73 | Args: 74 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 75 | Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them. 76 | Padding will be ignored by default should you provide it. 77 | Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`. 78 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 79 | Mask to avoid performing attention on padding token indices in input_ids. 80 | Mask values selected in ``[0, 1]``: 81 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. 
82 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): 83 | Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) 84 | `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. 85 | Used in the cross-attention of the decoder. 86 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): 87 | Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper. 88 | decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): 89 | Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. 90 | If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. 91 | See diagram 1 in the paper for more info on the default strategy 92 | """ 93 | 94 | 95 | def invert_mask(attention_mask): 96 | assert attention_mask.dim() == 2 97 | return attention_mask.eq(0) 98 | 99 | 100 | def _prepare_bart_decoder_inputs( 101 | config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 102 | ): 103 | """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if 104 | none are provided. This mimics the default behavior in fairseq. To override it pass in masks. 
105 | Note: this is not called during generation 106 | """ 107 | pad_token_id = config.pad_token_id 108 | if decoder_input_ids is None: 109 | decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) 110 | bsz, tgt_len = decoder_input_ids.size() 111 | if decoder_padding_mask is None: 112 | decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) 113 | else: 114 | decoder_padding_mask = invert_mask(decoder_padding_mask) 115 | causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( 116 | dtype=causal_mask_dtype, device=decoder_input_ids.device 117 | ) 118 | return decoder_input_ids, decoder_padding_mask, causal_mask 119 | 120 | 121 | class PretrainedBartModel(PreTrainedModel): 122 | config_class = BartConfig 123 | base_model_prefix = "model" 124 | 125 | def _init_weights(self, module): 126 | std = self.config.init_std 127 | if isinstance(module, nn.Linear): 128 | module.weight.data.normal_(mean=0.0, std=std) 129 | if module.bias is not None: 130 | module.bias.data.zero_() 131 | elif isinstance(module, SinusoidalPositionalEmbedding): 132 | pass 133 | elif isinstance(module, nn.Embedding): 134 | module.weight.data.normal_(mean=0.0, std=std) 135 | if module.padding_idx is not None: 136 | module.weight.data[module.padding_idx].zero_() 137 | 138 | @property 139 | def dummy_inputs(self): 140 | pad_token = self.config.pad_token_id 141 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 142 | dummy_inputs = { 143 | "attention_mask": input_ids.ne(pad_token), 144 | "input_ids": input_ids, 145 | } 146 | return dummy_inputs 147 | 148 | 149 | def _make_linear_from_emb(emb): 150 | vocab_size, emb_size = emb.weight.shape 151 | lin_layer = nn.Linear(vocab_size, emb_size, bias=False) 152 | lin_layer.weight.data = emb.weight.data 153 | return lin_layer 154 | 155 | 156 | # Helper Functions, mostly for making masks 157 | def _check_shapes(shape_1, shape2): 158 | if shape_1 != shape2: 159 | raise 
AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) 160 | 161 | 162 | def shift_tokens_right(input_ids, pad_token_id): 163 | """Shift input ids one token to the right, and wrap the last non pad token (usually ).""" 164 | prev_output_tokens = input_ids.clone() 165 | index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) 166 | prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() 167 | prev_output_tokens[:, 1:] = input_ids[:, :-1] 168 | return prev_output_tokens 169 | 170 | 171 | def make_padding_mask(input_ids, padding_idx=1): 172 | """True for pad tokens""" 173 | padding_mask = input_ids.eq(padding_idx) 174 | if not padding_mask.any(): 175 | padding_mask = None 176 | return padding_mask 177 | 178 | 179 | # Helper Modules 180 | 181 | 182 | class EncoderLayer(nn.Module): 183 | def __init__(self, config: BartConfig): 184 | super().__init__() 185 | self.embed_dim = config.d_model 186 | self.output_attentions = config.output_attentions 187 | self.self_attn = SelfAttention( 188 | self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, 189 | ) 190 | self.normalize_before = config.normalize_before 191 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 192 | self.dropout = config.dropout 193 | self.activation_fn = ACT2FN[config.activation_function] 194 | self.activation_dropout = config.activation_dropout 195 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 196 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 197 | self.final_layer_norm = LayerNorm(self.embed_dim) 198 | 199 | def forward(self, x, encoder_padding_mask): 200 | """ 201 | Args: 202 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 203 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape 204 | `(batch, src_len)` where padding elements are indicated by ``1``. 
205 | for t_tgt, t_src is excluded (or masked out), =0 means it is 206 | included in attention 207 | 208 | Returns: 209 | encoded output of shape `(seq_len, batch, embed_dim)` 210 | """ 211 | residual = x 212 | if self.normalize_before: 213 | x = self.self_attn_layer_norm(x) 214 | x, attn_weights = self.self_attn( 215 | query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions 216 | ) 217 | x = F.dropout(x, p=self.dropout, training=self.training) 218 | x = residual + x 219 | if not self.normalize_before: 220 | x = self.self_attn_layer_norm(x) 221 | 222 | residual = x 223 | if self.normalize_before: 224 | x = self.final_layer_norm(x) 225 | x = self.activation_fn(self.fc1(x)) 226 | x = F.dropout(x, p=self.activation_dropout, training=self.training) 227 | x = self.fc2(x) 228 | x = F.dropout(x, p=self.dropout, training=self.training) 229 | x = residual + x 230 | if not self.normalize_before: 231 | x = self.final_layer_norm(x) 232 | return (x, attn_weights) 233 | 234 | 235 | class BartEncoder(nn.Module): 236 | """ 237 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer 238 | is a :class:`EncoderLayer`. 
239 | 240 | Args: 241 | config: BartConfig 242 | """ 243 | 244 | def __init__(self, config: BartConfig, embed_tokens): 245 | super().__init__() 246 | self.config = config 247 | self.dropout = config.dropout 248 | self.layerdrop = config.encoder_layerdrop 249 | self.output_attentions = config.output_attentions 250 | self.output_hidden_states = config.output_hidden_states 251 | 252 | embed_dim = embed_tokens.embedding_dim 253 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 254 | self.padding_idx = embed_tokens.padding_idx 255 | self.max_source_positions = config.encoder_max_position_embeddings 256 | 257 | self.embed_tokens = embed_tokens 258 | if config.static_position_embeddings: 259 | self.embed_positions = SinusoidalPositionalEmbedding( 260 | config.encoder_max_position_embeddings, embed_dim, self.padding_idx 261 | ) 262 | else: 263 | self.embed_positions = LearnedPositionalEmbedding( 264 | config.encoder_max_position_embeddings, embed_dim, self.padding_idx, 265 | ) 266 | self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) 267 | self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() 268 | # mbart has one extra layer_norm 269 | self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None 270 | 271 | def forward( 272 | self, input_ids, attention_mask=None, 273 | ): 274 | """ 275 | Args: 276 | input_ids (LongTensor): tokens in the source language of shape 277 | `(batch, src_len)` 278 | attention_mask (torch.LongTensor): indicating which indices are padding tokens. 279 | Returns: 280 | Tuple comprised of: 281 | - **x** (Tensor): the last encoder layer's output of 282 | shape `(src_len, batch, embed_dim)` 283 | - **encoder_states** (List[Tensor]): all intermediate 284 | hidden states of shape `(src_len, batch, embed_dim)`. 285 | Only populated if *self.output_hidden_states:* is True. 
286 | - **all_attentions** (List[Tensor]): Attention weights for each layer. 287 | During training might not be of length n_layers because of layer dropout. 288 | """ 289 | # check attention mask and invert 290 | if attention_mask is not None: 291 | attention_mask = invert_mask(attention_mask) 292 | 293 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 294 | embed_pos = self.embed_positions(input_ids) 295 | x = inputs_embeds + embed_pos 296 | x = self.layernorm_embedding(x) 297 | x = F.dropout(x, p=self.dropout, training=self.training) 298 | 299 | # B x T x C -> T x B x C 300 | x = x.transpose(0, 1) 301 | 302 | encoder_states, all_attentions = [], [] 303 | for encoder_layer in self.layers: 304 | if self.output_hidden_states: 305 | encoder_states.append(x) 306 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 307 | dropout_probability = random.uniform(0, 1) 308 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 309 | attn = None 310 | else: 311 | if getattr(self.config, "gradient_checkpointing", False): 312 | x, attn = torch.utils.checkpoint.checkpoint( 313 | encoder_layer, 314 | x, 315 | attention_mask 316 | ) 317 | else: 318 | x, attn = encoder_layer(x, attention_mask) 319 | 320 | 321 | if self.output_attentions: 322 | all_attentions.append(attn) 323 | 324 | if self.layer_norm: 325 | x = self.layer_norm(x) 326 | if self.output_hidden_states: 327 | encoder_states.append(x) 328 | 329 | # T x B x C -> B x T x C 330 | encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states] 331 | x = x.transpose(0, 1) 332 | 333 | return x, encoder_states, all_attentions 334 | 335 | 336 | class DecoderLayer(nn.Module): 337 | def __init__(self, config: BartConfig): 338 | super().__init__() 339 | self.embed_dim = config.d_model 340 | self.self_attn = SelfAttention( 341 | embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, 342 | ) 343 | 
self.dropout = config.dropout 344 | self.activation_fn = ACT2FN[config.activation_function] 345 | self.activation_dropout = config.activation_dropout 346 | self.normalize_before = config.normalize_before 347 | 348 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 349 | self.encoder_attn = SelfAttention( 350 | self.embed_dim, 351 | config.decoder_attention_heads, 352 | dropout=config.attention_dropout, 353 | encoder_decoder_attention=True, 354 | ) 355 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) 356 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 357 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 358 | self.final_layer_norm = LayerNorm(self.embed_dim) 359 | 360 | def forward( 361 | self, 362 | x, 363 | encoder_hidden_states, 364 | encoder_attn_mask=None, 365 | layer_state=None, 366 | causal_mask=None, 367 | decoder_padding_mask=None, 368 | output_attentions=False, 369 | ): 370 | residual = x 371 | 372 | if layer_state is None: 373 | layer_state = {} 374 | if self.normalize_before: 375 | x = self.self_attn_layer_norm(x) 376 | # Self Attention 377 | 378 | x, self_attn_weights = self.self_attn( 379 | query=x, 380 | key=x, 381 | layer_state=layer_state, # adds keys to layer state 382 | key_padding_mask=decoder_padding_mask, 383 | attn_mask=causal_mask, 384 | ) 385 | x = F.dropout(x, p=self.dropout, training=self.training) 386 | x = residual + x 387 | if not self.normalize_before: 388 | x = self.self_attn_layer_norm(x) 389 | 390 | # Cross attention 391 | residual = x 392 | assert self.encoder_attn.cache_key != self.self_attn.cache_key 393 | if self.normalize_before: 394 | x = self.encoder_attn_layer_norm(x) 395 | x, _ = self.encoder_attn( 396 | query=x, 397 | key=encoder_hidden_states, 398 | key_padding_mask=encoder_attn_mask, 399 | layer_state=layer_state, # mutates layer state 400 | ) 401 | x = F.dropout(x, p=self.dropout, training=self.training) 402 | x = residual + x 403 | if not self.normalize_before: 404 | x = 
self.encoder_attn_layer_norm(x) 405 | 406 | # Fully Connected 407 | residual = x 408 | if self.normalize_before: 409 | x = self.final_layer_norm(x) 410 | x = self.activation_fn(self.fc1(x)) 411 | x = F.dropout(x, p=self.activation_dropout, training=self.training) 412 | x = self.fc2(x) 413 | x = F.dropout(x, p=self.dropout, training=self.training) 414 | x = residual + x 415 | if not self.normalize_before: 416 | x = self.final_layer_norm(x) 417 | return ( 418 | x, 419 | self_attn_weights, 420 | layer_state, 421 | ) # just self_attn weights for now, following t5, layer_state = cache for decoding 422 | 423 | 424 | class BartDecoder(nn.Module): 425 | """ 426 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer 427 | is a :class:`DecoderLayer`. 428 | Args: 429 | config: BartConfig 430 | embed_tokens (torch.nn.Embedding): output embedding 431 | """ 432 | 433 | def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): 434 | super().__init__() 435 | self.output_hidden_states = config.output_hidden_states 436 | self.dropout = config.dropout 437 | self.layerdrop = config.decoder_layerdrop 438 | self.padding_idx = embed_tokens.padding_idx 439 | self.max_target_positions = config.max_position_embeddings 440 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 441 | self.embed_tokens = embed_tokens 442 | if config.static_position_embeddings: 443 | self.embed_positions = SinusoidalPositionalEmbedding( 444 | config.max_position_embeddings, config.d_model, config.pad_token_id 445 | ) 446 | else: 447 | self.embed_positions = LearnedPositionalEmbedding( 448 | config.max_position_embeddings, config.d_model, self.padding_idx, 449 | ) 450 | self.layers = nn.ModuleList( 451 | [DecoderLayer(config) for _ in range(config.decoder_layers)] 452 | ) # type: List[DecoderLayer] 453 | self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() 454 | self.layer_norm = 
LayerNorm(config.d_model) if config.add_final_layer_norm else None 455 | 456 | def forward( 457 | self, 458 | input_ids, 459 | encoder_hidden_states, 460 | encoder_padding_mask, 461 | decoder_padding_mask, 462 | decoder_causal_mask, 463 | decoder_cached_states=None, 464 | use_cache=False, 465 | output_attentions=False, 466 | **unused, 467 | ): 468 | """ 469 | Includes several features from "Jointly Learning to Align and 470 | Translate with Transformer Models" (Garg et al., EMNLP 2019). 471 | 472 | Args: 473 | input_ids (LongTensor): previous decoder outputs of shape 474 | `(batch, tgt_len)`, for teacher forcing 475 | encoder_hidden_states: output from the encoder, used for 476 | encoder-side attention 477 | encoder_padding_mask: for ignoring pad tokens 478 | decoder_cached_states (dict or None): dictionary used for storing state during generation 479 | 480 | Returns: 481 | tuple: 482 | - the decoder's features of shape `(batch, tgt_len, embed_dim)` 483 | - hidden states 484 | - attentions 485 | """ 486 | # check attention mask and invert 487 | if encoder_padding_mask is not None: 488 | encoder_padding_mask = invert_mask(encoder_padding_mask) 489 | 490 | # embed positions 491 | positions = self.embed_positions(input_ids, use_cache=use_cache) 492 | 493 | if use_cache: 494 | input_ids = input_ids[:, -1:] 495 | positions = positions[:, -1:] # happens after we embed them 496 | # assert input_ids.ne(self.padding_idx).any() 497 | 498 | x = self.embed_tokens(input_ids) * self.embed_scale 499 | x += positions 500 | x = self.layernorm_embedding(x) 501 | x = F.dropout(x, p=self.dropout, training=self.training) 502 | 503 | # Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) 504 | x = x.transpose(0, 1) 505 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1) 506 | 507 | # decoder layers 508 | all_hidden_states = () 509 | all_self_attns = () 510 | next_decoder_cache = [] 511 | for idx, decoder_layer in enumerate(self.layers): 512 | # 
add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 513 | if self.output_hidden_states: 514 | all_hidden_states += (x,) 515 | dropout_probability = random.uniform(0, 1) 516 | if self.training and (dropout_probability < self.layerdrop): 517 | continue 518 | 519 | layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None 520 | 521 | x, layer_self_attn, layer_past = decoder_layer( 522 | x, 523 | encoder_hidden_states, 524 | encoder_attn_mask=encoder_padding_mask, 525 | decoder_padding_mask=decoder_padding_mask, 526 | layer_state=layer_state, 527 | causal_mask=decoder_causal_mask, 528 | output_attentions=output_attentions, 529 | ) 530 | 531 | if use_cache: 532 | next_decoder_cache.append(layer_past.copy()) 533 | 534 | if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart 535 | x = self.layer_norm(x) 536 | if output_attentions: 537 | all_self_attns += (layer_self_attn,) 538 | 539 | # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) 540 | all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states] 541 | x = x.transpose(0, 1) 542 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1) 543 | 544 | if use_cache: 545 | next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache) 546 | else: 547 | next_cache = None 548 | return x, next_cache, all_hidden_states, list(all_self_attns) 549 | 550 | 551 | def _reorder_buffer(attn_cache, new_order): 552 | for k, input_buffer_k in attn_cache.items(): 553 | if input_buffer_k is not None: 554 | attn_cache[k] = input_buffer_k.index_select(0, new_order) 555 | return attn_cache 556 | 557 | 558 | class SelfAttention(nn.Module): 559 | """Multi-headed attention from 'Attention Is All You Need' paper""" 560 | 561 | def __init__( 562 | self, 563 | embed_dim, 564 | num_heads, 565 | dropout=0.0, 566 | bias=True, 567 | encoder_decoder_attention=False, # otherwise self_attention 568 | 
): 569 | super().__init__() 570 | self.embed_dim = embed_dim 571 | self.num_heads = num_heads 572 | self.dropout = dropout 573 | self.head_dim = embed_dim // num_heads 574 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 575 | self.scaling = self.head_dim ** -0.5 576 | 577 | self.encoder_decoder_attention = encoder_decoder_attention 578 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 579 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 580 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 581 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 582 | self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" 583 | 584 | def _shape(self, tensor, dim_0, bsz): 585 | return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1) 586 | 587 | def forward( 588 | self, 589 | query, 590 | key: Optional[Tensor], 591 | key_padding_mask: Optional[Tensor] = None, 592 | layer_state: Optional[Dict[str, Optional[Tensor]]] = None, 593 | attn_mask: Optional[Tensor] = None, 594 | need_weights=False, 595 | ) -> Tuple[Tensor, Optional[Tensor]]: 596 | """Input shape: Time(SeqLen) x Batch x Channel""" 597 | static_kv: bool = self.encoder_decoder_attention 598 | tgt_len, bsz, embed_dim = query.size() 599 | assert embed_dim == self.embed_dim 600 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 601 | # get here for encoder decoder cause of static_kv 602 | if layer_state is not None: # reuse k,v and encoder_padding_mask 603 | saved_state = layer_state.get(self.cache_key, {}) 604 | if "prev_key" in saved_state: 605 | # previous time steps are cached - no need to recompute key and value if they are static 606 | if static_kv: 607 | key = None 608 | else: 609 | saved_state = None 610 | layer_state = {} 611 | 612 | q = self.q_proj(query) * self.scaling 613 | if static_kv: 614 | if key is None: 615 | k = v = None 616 | else: 617 | k = self.k_proj(key) 618 
| v = self.v_proj(key) 619 | else: 620 | k = self.k_proj(query) 621 | v = self.v_proj(query) 622 | 623 | q = self._shape(q, tgt_len, bsz) 624 | if k is not None: 625 | k = self._shape(k, -1, bsz) 626 | if v is not None: 627 | v = self._shape(v, -1, bsz) 628 | 629 | if saved_state is not None: 630 | k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) 631 | 632 | # Update cache 633 | layer_state[self.cache_key] = { 634 | "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), 635 | "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), 636 | "prev_key_padding_mask": key_padding_mask if not static_kv else None, 637 | } 638 | 639 | assert k is not None 640 | src_len = k.size(1) 641 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 642 | assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) 643 | 644 | if attn_mask is not None: 645 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask 646 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 647 | 648 | # This is part of a workaround to get around fork/join parallelism not supporting Optional types. 
649 | if key_padding_mask is not None and key_padding_mask.dim() == 0: 650 | key_padding_mask = None 651 | assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,) 652 | 653 | if key_padding_mask is not None: # don't attend to padding symbols 654 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 655 | reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2) 656 | attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) 657 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 658 | attn_weights = F.softmax(attn_weights, dim=-1) 659 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,) 660 | 661 | assert v is not None 662 | attn_output = torch.bmm(attn_probs, v) 663 | assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) 664 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 665 | attn_output = self.out_proj(attn_output) 666 | if need_weights: 667 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 668 | else: 669 | attn_weights = None 670 | return attn_output, attn_weights 671 | 672 | def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): 673 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 674 | if "prev_key" in saved_state: 675 | _prev_key = saved_state["prev_key"] 676 | assert _prev_key is not None 677 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) 678 | if static_kv: 679 | k = prev_key 680 | else: 681 | assert k is not None 682 | k = torch.cat([prev_key, k], dim=1) 683 | if "prev_value" in saved_state: 684 | _prev_value = saved_state["prev_value"] 685 | assert _prev_value is not None 686 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) 687 | if static_kv: 688 | v = prev_value 689 | else: 690 | assert v is not None 691 | v = torch.cat([prev_value, v], dim=1) 692 | assert k is not None and 
v is not None 693 | prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) 694 | key_padding_mask = self._cat_prev_key_padding_mask( 695 | key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv 696 | ) 697 | return k, v, key_padding_mask 698 | 699 | @staticmethod 700 | def _cat_prev_key_padding_mask( 701 | key_padding_mask: Optional[Tensor], 702 | prev_key_padding_mask: Optional[Tensor], 703 | batch_size: int, 704 | src_len: int, 705 | static_kv: bool, 706 | ) -> Optional[Tensor]: 707 | # saved key padding masks have shape (bsz, seq_len) 708 | if prev_key_padding_mask is not None: 709 | if static_kv: 710 | new_key_padding_mask = prev_key_padding_mask 711 | else: 712 | new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) 713 | 714 | elif key_padding_mask is not None: 715 | filler = torch.zeros( 716 | batch_size, 717 | src_len - key_padding_mask.size(1), 718 | dtype=key_padding_mask.dtype, 719 | device=key_padding_mask.device, 720 | ) 721 | new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1) 722 | else: 723 | new_key_padding_mask = prev_key_padding_mask 724 | return new_key_padding_mask 725 | 726 | 727 | class BartClassificationHead(nn.Module): 728 | """Head for sentence-level classification tasks.""" 729 | 730 | # This can trivially be shared with RobertaClassificationHead 731 | 732 | def __init__( 733 | self, input_dim, inner_dim, num_classes, pooler_dropout, 734 | ): 735 | super().__init__() 736 | self.dense = nn.Linear(input_dim, inner_dim) 737 | self.dropout = nn.Dropout(p=pooler_dropout) 738 | self.out_proj = nn.Linear(inner_dim, num_classes) 739 | 740 | def forward(self, x): 741 | x = self.dropout(x) 742 | x = self.dense(x) 743 | x = torch.tanh(x) 744 | x = self.dropout(x) 745 | x = self.out_proj(x) 746 | return x 747 | 748 | 749 | class LearnedPositionalEmbedding(nn.Embedding): 750 | """ 751 | This module learns positional embeddings up to a fixed maximum size. 
752 | Padding ids are ignored by either offsetting based on padding_idx 753 | or by setting padding_idx to None and ensuring that the appropriate 754 | position ids are passed to the forward function. 755 | """ 756 | 757 | def __init__( 758 | self, num_embeddings: int, embedding_dim: int, padding_idx: int, 759 | ): 760 | # if padding_idx is specified then offset the embedding ids by 761 | # this index and adjust num_embeddings appropriately 762 | assert padding_idx is not None 763 | num_embeddings += padding_idx + 1 # WHY? 764 | super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) 765 | 766 | def forward(self, input, use_cache=False): 767 | """Input is expected to be of size [bsz x seqlen].""" 768 | if use_cache: # the position is our current step in the decoded sequence 769 | pos = int(self.padding_idx + input.size(1)) 770 | positions = input.data.new(1, 1).fill_(pos) 771 | else: 772 | positions = create_position_ids_from_input_ids(input, self.padding_idx) 773 | return super().forward(positions) 774 | 775 | 776 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): 777 | if torch.cuda.is_available(): 778 | try: 779 | from apex.normalization import FusedLayerNorm 780 | 781 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 782 | except ImportError: 783 | pass 784 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 785 | 786 | 787 | def fill_with_neg_inf(t): 788 | """FP16-compatible function that fills a input_ids with -inf.""" 789 | return t.float().fill_(float("-inf")).type_as(t) 790 | 791 | 792 | def _filter_out_falsey_values(tup) -> Tuple: 793 | """Remove entries that are None or [] from an iterable.""" 794 | return tuple(x for x in tup if isinstance(x, torch.Tensor) or x) 795 | 796 | 797 | # Public API 798 | def _get_shape(t): 799 | return getattr(t, "shape", None) 800 | 801 | 802 | @add_start_docstrings( 803 | "The bare BART Model outputting raw hidden-states without any specific head on 
top.", BART_START_DOCSTRING, 804 | ) 805 | class BartModel(PretrainedBartModel): 806 | def __init__(self, config: BartConfig): 807 | super().__init__(config) 808 | self.output_attentions = config.output_attentions 809 | self.output_hidden_states = config.output_hidden_states 810 | 811 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 812 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 813 | 814 | self.encoder = BartEncoder(config, self.shared) 815 | self.decoder = BartDecoder(config, self.shared) 816 | 817 | self.init_weights() 818 | 819 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) 820 | def forward( 821 | self, 822 | input_ids, 823 | attention_mask=None, 824 | decoder_input_ids=None, 825 | encoder_outputs: Optional[Tuple] = None, 826 | decoder_attention_mask=None, 827 | decoder_cached_states=None, 828 | use_cache=False, 829 | ): 830 | 831 | # make masks if user doesn't supply 832 | if not use_cache: 833 | decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( 834 | self.config, 835 | input_ids, 836 | decoder_input_ids=decoder_input_ids, 837 | decoder_padding_mask=decoder_attention_mask, 838 | causal_mask_dtype=self.shared.weight.dtype, 839 | ) 840 | else: 841 | decoder_padding_mask, causal_mask = None, None 842 | 843 | assert decoder_input_ids is not None 844 | if encoder_outputs is None: 845 | encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) 846 | assert isinstance(encoder_outputs, tuple) 847 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 848 | decoder_outputs = self.decoder( 849 | decoder_input_ids, 850 | encoder_outputs[0], 851 | attention_mask, 852 | decoder_padding_mask, 853 | decoder_causal_mask=causal_mask, 854 | decoder_cached_states=decoder_cached_states, 855 | use_cache=use_cache, 856 | ) 857 | # Attention and hidden_states will be [] or None if they aren't needed 858 | decoder_outputs: Tuple = 
_filter_out_falsey_values(decoder_outputs) 859 | assert isinstance(decoder_outputs[0], torch.Tensor) 860 | encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs) 861 | return decoder_outputs + encoder_outputs 862 | 863 | def get_input_embeddings(self): 864 | return self.shared 865 | 866 | def set_input_embeddings(self, value): 867 | self.shared = value 868 | self.encoder.embed_tokens = self.shared 869 | self.decoder.embed_tokens = self.shared 870 | 871 | def get_output_embeddings(self): 872 | return _make_linear_from_emb(self.shared) # make it on the fly 873 | 874 | 875 | @add_start_docstrings( 876 | "The BART Model with a language modeling head. Can be used for summarization.", 877 | BART_START_DOCSTRING + BART_GENERATION_EXAMPLE, 878 | ) 879 | class BartForConditionalGeneration(PretrainedBartModel): 880 | base_model_prefix = "model" 881 | 882 | def __init__(self, config: BartConfig): 883 | super().__init__(config) 884 | base_model = BartModel(config) 885 | self.model = base_model 886 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 887 | 888 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 889 | old_num_tokens = self.model.shared.num_embeddings 890 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 891 | self.model.shared = new_embeddings 892 | self._resize_final_logits_bias(new_num_tokens, old_num_tokens) 893 | return new_embeddings 894 | 895 | def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: 896 | if new_num_tokens <= old_num_tokens: 897 | new_bias = self.final_logits_bias[:, :new_num_tokens] 898 | else: 899 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 900 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 901 | self.register_buffer("final_logits_bias", new_bias) 902 | 903 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) 904 | def forward( 
905 | self, 906 | input_ids, 907 | attention_mask=None, 908 | encoder_outputs=None, 909 | decoder_input_ids=None, 910 | decoder_attention_mask=None, 911 | decoder_cached_states=None, 912 | lm_labels=None, 913 | use_cache=False, 914 | **unused 915 | ): 916 | r""" 917 | lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 918 | Labels for computing the masked language modeling loss. 919 | Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring). 920 | Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens 921 | with labels 922 | in ``[0, ..., config.vocab_size]``. 923 | 924 | Returns: 925 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: 926 | masked_lm_loss (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 927 | Masked language modeling loss. 928 | prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) 929 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 930 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 931 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 932 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 933 | 934 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 935 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 936 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 937 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 
938 | 939 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 940 | heads. 941 | 942 | Examples:: 943 | 944 | # Mask filling only works for bart-large 945 | from transformers import BartTokenizer, BartForConditionalGeneration 946 | tokenizer = BartTokenizer.from_pretrained('bart-large') 947 | TXT = "My friends are but they eat too many carbs." 948 | model = BartForConditionalGeneration.from_pretrained('bart-large') 949 | input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids'] 950 | logits = model(input_ids)[0] 951 | masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 952 | probs = logits[0, masked_index].softmax(dim=0) 953 | values, predictions = probs.topk(5) 954 | tokenizer.decode(predictions).split() 955 | # ['good', 'great', 'all', 'really', 'very'] 956 | """ 957 | outputs = self.model( 958 | input_ids, 959 | attention_mask=attention_mask, 960 | decoder_input_ids=decoder_input_ids, 961 | encoder_outputs=encoder_outputs, 962 | decoder_attention_mask=decoder_attention_mask, 963 | decoder_cached_states=decoder_cached_states, 964 | use_cache=use_cache, 965 | ) 966 | lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) 967 | outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here 968 | if lm_labels is not None: 969 | loss_fct = nn.CrossEntropyLoss() 970 | # TODO(SS): do we need to ignore pad tokens in lm_labels? 
971 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1)) 972 | outputs = (masked_lm_loss,) + outputs 973 | 974 | return outputs 975 | 976 | def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs): 977 | assert past is not None, "past has to be defined for encoder_outputs" 978 | 979 | # first step, decoder_cached_states are empty 980 | if not past[1]: 981 | encoder_outputs, decoder_cached_states = past, None 982 | else: 983 | encoder_outputs, decoder_cached_states = past 984 | return { 985 | "input_ids": None, # encoder_outputs is defined. input_ids not needed 986 | "encoder_outputs": encoder_outputs, 987 | "decoder_cached_states": decoder_cached_states, 988 | "decoder_input_ids": decoder_input_ids, 989 | "attention_mask": attention_mask, 990 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 991 | } 992 | 993 | def prepare_logits_for_generation(self, logits, cur_len, max_length): 994 | if cur_len == 1: 995 | self._force_token_ids_generation(logits, self.config.bos_token_id) 996 | if cur_len == max_length - 1 and self.config.eos_token_id is not None: 997 | self._force_token_ids_generation(logits, self.config.eos_token_id) 998 | return logits 999 | 1000 | def _force_token_ids_generation(self, scores, token_ids) -> None: 1001 | """force one of token_ids to be generated by setting prob of all other tokens to 0""" 1002 | if isinstance(token_ids, int): 1003 | token_ids = [token_ids] 1004 | all_but_token_ids_mask = torch.tensor( 1005 | [x for x in range(self.config.vocab_size) if x not in token_ids], 1006 | dtype=torch.long, 1007 | device=next(self.parameters()).device, 1008 | ) 1009 | assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]" 1010 | scores[:, all_but_token_ids_mask] = -float("inf") 1011 | 1012 | @staticmethod 1013 | def _reorder_cache(past, beam_idx): 1014 | ((enc_out, enc_mask), 
decoder_cached_states) = past 1015 | reordered_past = [] 1016 | for layer_past in decoder_cached_states: 1017 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn 1018 | layer_past_new = { 1019 | attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() 1020 | } 1021 | reordered_past.append(layer_past_new) 1022 | 1023 | new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx) 1024 | new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx) 1025 | 1026 | past = ((new_enc_out, new_enc_mask), reordered_past) 1027 | return past 1028 | 1029 | def get_encoder(self): 1030 | return self.model.encoder 1031 | 1032 | def get_output_embeddings(self): 1033 | return _make_linear_from_emb(self.model.shared) # make it on the fly 1034 | 1035 | 1036 | @add_start_docstrings( 1037 | """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. 
""", 1038 | BART_START_DOCSTRING, 1039 | ) 1040 | class BartForSequenceClassification(PretrainedBartModel): 1041 | def __init__(self, config: BartConfig, **kwargs): 1042 | super().__init__(config, **kwargs) 1043 | self.model = BartModel(config) 1044 | self.classification_head = BartClassificationHead( 1045 | config.d_model, config.d_model, config.num_labels, config.classif_dropout, 1046 | ) 1047 | self.model._init_weights(self.classification_head.dense) 1048 | self.model._init_weights(self.classification_head.out_proj) 1049 | 1050 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) 1051 | def forward( 1052 | self, 1053 | input_ids, 1054 | attention_mask=None, 1055 | encoder_outputs=None, 1056 | decoder_input_ids=None, 1057 | decoder_attention_mask=None, 1058 | labels=None, 1059 | ): 1060 | r""" 1061 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): 1062 | Labels for computing the sequence classification/regression loss. 1063 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 1064 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1065 | 1066 | Returns: 1067 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs: 1068 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): 1069 | Classification loss (cross entropy) 1070 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): 1071 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 1072 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 1073 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 1074 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 
1075 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 1076 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 1077 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 1078 | Attentions weights after the attention softmax, used to compute the weighted average in the 1079 | self-attention 1080 | heads. 1081 | 1082 | Examples:: 1083 | 1084 | from transformers import BartTokenizer, BartForSequenceClassification 1085 | import torch 1086 | 1087 | tokenizer = BartTokenizer.from_pretrained('bart-large') 1088 | model = BartForSequenceClassification.from_pretrained('bart-large') 1089 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", 1090 | add_special_tokens=True)).unsqueeze(0) # Batch size 1 1091 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 1092 | outputs = model(input_ids, labels=labels) 1093 | loss, logits = outputs[:2] 1094 | 1095 | """ 1096 | outputs = self.model( 1097 | input_ids, 1098 | attention_mask=attention_mask, 1099 | decoder_input_ids=decoder_input_ids, 1100 | decoder_attention_mask=decoder_attention_mask, 1101 | encoder_outputs=encoder_outputs, 1102 | ) 1103 | x = outputs[0] # last hidden state 1104 | eos_mask = input_ids.eq(self.config.eos_token_id) 1105 | if len(torch.unique(eos_mask.sum(1))) > 1: 1106 | raise ValueError("All examples must have the same number of tokens.") 1107 | sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] 1108 | logits = self.classification_head(sentence_representation) 1109 | # Prepend logits 1110 | outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here 1111 | if labels is not None: # prepend loss to output, 1112 | loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1)) 1113 | outputs = (loss,) + outputs 1114 | 1115 | return outputs 1116 
class SinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length.

    The embedding table is filled once at construction time and frozen
    (``requires_grad = False``); ``forward`` only looks positions up in it.
    """

    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        # NOTE(review): ``padding_idx`` is accepted for signature compatibility but is
        # not forwarded to nn.Embedding, so no row is treated as padding.
        super().__init__(num_positions, embedding_dim)
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """Overwrite ``out`` with sinusoidal features and freeze it.

        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved:
        the sin features fill ``out[:, :dim // 2]`` and the cos features fill ``out[:, dim // 2:]``.

        Args:
            out: the ``(n_pos, dim)`` embedding weight to overwrite in place.

        Returns:
            The same parameter, detached and with ``requires_grad=False``.
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # Freeze BEFORE the in-place slice writes below: writing into a view of a leaf
        # tensor that requires grad raises a RuntimeError on newer PyTorch versions.
        out.requires_grad = False
        # Even columns feed the sin half, odd columns the cos half. The two halves only
        # line up for an even ``dim``, which ``__init__`` enforces.
        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids, use_cache=False):
        """Look up positional embeddings for a batch of token ids.

        Input is expected to be of size [bsz x seqlen]; only the sequence length is used.

        Args:
            input_ids: token ids; dimension 1 is taken as the sequence length.
            use_cache: if True (incremental decoding, called before slicing), return only
                the embedding of position ``seq_len - 1`` with shape [1 x 1 x dim].

        Returns:
            Embeddings for positions ``0 .. seq_len - 1`` with shape [seq_len x dim], or the
            single-position embedding described above when ``use_cache`` is True.
        """
        _, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)

# --------------------------------------------------------------------------------
# /longbart/modeling_longbart.py:
# --------------------------------------------------------------------------------
from typing import Dict, List, Optional, Tuple

from torch import Tensor, nn

from transformers.modeling_longformer import LongformerSelfAttention

from .modeling_bart import BartForConditionalGeneration


class LongBartForConditionalGeneration(BartForConditionalGeneration):
    """BART for conditional generation with Longformer self-attention in the encoder.

    Every encoder layer's ``self_attn`` is replaced by a ``LongformerSelfAttentionForBart``
    built from the same ``config``. The decoder is left unchanged.
    """

    def __init__(self, config):
        # NOTE(review): ``config`` presumably also carries the Longformer-specific fields
        # (e.g. attention_window) that LongformerSelfAttention reads -- confirm against
        # convert_bart_to_longbart.py.
        super().__init__(config)
        for i, layer in enumerate(self.model.encoder.layers):
            # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
            layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


class LongformerSelfAttentionForBart(nn.Module):
    """Adapter that lets ``LongformerSelfAttention`` stand in for a BART encoder self-attention.

    It accepts the same call signature as the BART attention module it replaces, with
    time-major hidden states of shape ``(seq_len, bsz, embed_dim)``, and returns
    ``(attn_output, attn_weights_or_None)``. The attention itself is delegated to
    Longformer's implementation.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        # Output projection applied after attention (presumably initialised from BART's
        # out_proj by the conversion script -- TODO confirm).
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
        attn_mask: Optional[Tensor] = None,
        need_weights=False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run Longformer self-attention on BART's time-major hidden states.

        Args:
            query: hidden states of shape ``(tgt_len, bsz, embed_dim)``.
            key, key_padding_mask, layer_state, need_weights: accepted for interface
                compatibility with BART's attention module; not used here.
            attn_mask: forwarded unchanged as Longformer's ``attention_mask``.

        Returns:
            ``(attn_output, attn_weights)``, where ``attn_output`` has shape
            ``(tgt_len, bsz, embed_dim)`` and ``attn_weights`` is Longformer's second
            output when it returns one, else ``None``.
        """
        # NOTE(review): key_padding_mask is ignored, so padding is only masked if the caller
        # encodes it in attn_mask -- confirm against the encoder in modeling_bart.py.
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        # BART is time-major, LongformerSelfAttention is batch-major: swap the first two
        # axes. A ``view`` here would only reinterpret memory and interleave tokens from
        # different sequences whenever bsz > 1.
        hidden_states = query.transpose(0, 1)  # (bsz, tgt_len, embed_dim)

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=attn_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
        )

        # Back to BART's time-major layout before the output projection.
        attn_output = outputs[0].transpose(0, 1).contiguous()  # (tgt_len, bsz, embed_dim)
        attn_output = self.output(attn_output)

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='longbart',
    version='0.1',
    description='Long version of the BART model',
    url='https://github.com/patil-suraj/longbart',
    author='Suraj Patil',
    author_email='surajp815@gmail.com',
    packages=['longbart'],
    keywords="NLP deep learning transformer pytorch bart",
    install_requires=[
        'transformers == 2.11.0'
    ],
    python_requires=">=3.6.0",
    classifiers=[
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    zip_safe=False
)
# --------------------------------------------------------------------------------