├── .gitignore
├── LongBART.ipynb
├── longbart
│   ├── __init__.py
│   ├── configuration_bart.py
│   ├── convert_bart_to_longbart.py
│   ├── modeling_bart.py
│   └── modeling_longbart.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/LongBART.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "LongBART",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "machine_shape": "hm",
10 | "authorship_tag": "ABX9TyMBu/tl3uAemtoSjaCYca2U",
11 | "include_colab_link": true
12 | },
13 | "kernelspec": {
14 | "name": "python3",
15 | "display_name": "Python 3"
16 | },
17 | "accelerator": "TPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
27 |         ""
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "PkmDekoURprl",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "# LongBART"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "FNfNMSjTK61d",
44 | "colab_type": "code",
45 | "outputId": "f2278a45-26ee-48dd-ab62-4a34f8c56662",
46 | "colab": {
47 | "base_uri": "https://localhost:8080/",
48 | "height": 34
49 | }
50 | },
51 | "source": [
52 | "!git clone https://github.com/patil-suraj/longbart.git\n",
53 | "%cd longbart"
54 | ],
55 | "execution_count": 1,
56 | "outputs": [
57 | {
58 | "output_type": "stream",
59 | "text": [
60 | "/content/longbart\n"
61 | ],
62 | "name": "stdout"
63 | }
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "metadata": {
69 | "id": "yOSSddShCaz1",
70 | "colab_type": "code",
71 | "outputId": "8fc06a89-e421-4f67-df2d-004400b60fbc",
72 | "colab": {
73 | "base_uri": "https://localhost:8080/",
74 | "height": 658
75 | }
76 | },
77 | "source": [
78 | "!pip install git+https://github.com/huggingface/transformers.git"
79 | ],
80 | "execution_count": 2,
81 | "outputs": [
82 | {
83 | "output_type": "stream",
84 | "text": [
85 | "Collecting git+https://github.com/huggingface/transformers.git\n",
86 | " Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-69qg24g6\n",
87 | " Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-69qg24g6\n",
88 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (1.18.4)\n",
89 | "Collecting tokenizers==0.7.0\n",
90 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)\n",
91 | "\u001b[K |████████████████████████████████| 3.8MB 3.4MB/s \n",
92 | "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (20.4)\n",
93 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (3.0.12)\n",
94 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2.23.0)\n",
95 | "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (4.41.1)\n",
96 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2019.12.20)\n",
97 | "Collecting sentencepiece\n",
98 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
99 | "\u001b[K |████████████████████████████████| 1.1MB 59.3MB/s \n",
100 | "\u001b[?25hCollecting sacremoses\n",
101 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
102 | "\u001b[K |████████████████████████████████| 890kB 60.2MB/s \n",
103 | "\u001b[?25hRequirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (0.7)\n",
104 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (1.12.0)\n",
105 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (2.4.7)\n",
106 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (3.0.4)\n",
107 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (1.24.3)\n",
108 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2020.4.5.1)\n",
109 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2.9)\n",
110 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (7.1.2)\n",
111 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (0.15.1)\n",
112 | "Building wheels for collected packages: transformers, sacremoses\n",
113 | " Building wheel for transformers (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
114 | " Created wheel for transformers: filename=transformers-2.10.0-cp36-none-any.whl size=667026 sha256=1fbfcea1f14b529238dfb962701daa0b4df4df60b6927f0793ca24f52b161af8\n",
115 | " Stored in directory: /tmp/pip-ephem-wheel-cache-gv3xom6x/wheels/33/eb/3b/4bf5dd835e865e472d4fc0754f35ac0edb08fe852e8f21655f\n",
116 | " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
117 | " Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=87f114a7eda7007f2e152396969d392ab29f68a5812c6a57e3538111bdf7c32e\n",
118 | " Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
119 | "Successfully built transformers sacremoses\n",
120 | "Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers\n",
121 | "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.10.0\n"
122 | ],
123 | "name": "stdout"
124 | }
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "metadata": {
130 | "id": "3-EQ0F6qCm_d",
131 | "colab_type": "code",
132 | "colab": {}
133 | },
134 | "source": [
135 | "import logging\n",
136 | "import os\n",
137 | "import math\n",
138 | "from dataclasses import dataclass, field\n",
139 | "from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TextDataset, DataCollatorForLanguageModeling, Trainer\n",
140 | "from transformers import BartTokenizer\n",
141 | "from transformers import TrainingArguments, HfArgumentParser\n",
142 | "from transformers.modeling_longformer import LongformerSelfAttention\n",
143 | "\n",
144 | "from modeling_bart import BartForConditionalGeneration\n",
145 | "\n",
146 | "logger = logging.getLogger(__name__)\n",
147 | "logging.basicConfig(level=logging.INFO)"
148 | ],
149 | "execution_count": 0,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "code",
154 | "metadata": {
155 | "id": "oO4RNqIODK9z",
156 | "colab_type": "code",
157 | "outputId": "bc3a13ca-03ed-45e9-f593-490fe0b36d61",
158 | "colab": {
159 | "base_uri": "https://localhost:8080/",
160 | "height": 1000
161 | }
162 | },
163 | "source": [
164 |         "# let's use a tiny version of BART for the initial experiment\n",
165 | "tokenizer = BartTokenizer.from_pretrained('sshleifer/bart-tiny-random')\n",
166 | "bart = BartForConditionalGeneration.from_pretrained('sshleifer/bart-tiny-random')\n",
167 | "\n",
168 |         "# load a RoBERTa model to compare BART's encoder layer with RoBERTa's encoder layer\n",
169 | "roberta = RobertaForMaskedLM.from_pretrained('roberta-base')"
170 | ],
171 | "execution_count": 3,
172 | "outputs": [
173 | {
174 | "output_type": "stream",
175 | "text": [
176 | "INFO:transformers.tokenization_utils:Model name 'sshleifer/bart-tiny-random' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'sshleifer/bart-tiny-random' is a path, a model identifier, or url to a directory containing tokenizer files.\n",
177 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/vocab.json from cache at /root/.cache/torch/transformers/70b9426bcc7c2cd96de53c16f7e13eabbc8373cecf5c38d68ced2fcc25e3382a.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
178 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/merges.txt from cache at /root/.cache/torch/transformers/dc37af6307b1a17037d2d066cb55af9cc1cf55d38d3b1f862221fc8d87b9a672.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
179 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/added_tokens.json from cache at None\n",
180 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/special_tokens_map.json from cache at None\n",
181 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/tokenizer_config.json from cache at None\n",
182 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n",
183 | "INFO:transformers.configuration_utils:Model config BartConfig {\n",
184 | " \"_num_labels\": 3,\n",
185 | " \"activation_dropout\": 0.0,\n",
186 | " \"activation_function\": \"gelu\",\n",
187 | " \"add_bias_logits\": false,\n",
188 | " \"add_final_layer_norm\": false,\n",
189 | " \"architectures\": [\n",
190 | " \"BartForConditionalGeneration\"\n",
191 | " ],\n",
192 | " \"attention_dropout\": 0.0,\n",
193 | " \"bos_token_id\": 0,\n",
194 | " \"classif_dropout\": 0.0,\n",
195 | " \"d_model\": 24,\n",
196 | " \"decoder_attention_heads\": 2,\n",
197 | " \"decoder_ffn_dim\": 16,\n",
198 | " \"decoder_layerdrop\": 0.0,\n",
199 | " \"decoder_layers\": 2,\n",
200 | " \"decoder_max_position_embeddings\": 1024,\n",
201 | " \"decoder_start_token_id\": 2,\n",
202 | " \"dropout\": 0.1,\n",
203 | " \"encoder_attention_heads\": 2,\n",
204 | " \"encoder_ffn_dim\": 16,\n",
205 | " \"encoder_layerdrop\": 0.0,\n",
206 | " \"encoder_layers\": 2,\n",
207 | " \"encoder_max_position_embeddings\": 1024,\n",
208 | " \"eos_token_id\": 2,\n",
209 | " \"id2label\": {\n",
210 | " \"0\": \"LABEL_0\",\n",
211 | " \"1\": \"LABEL_1\",\n",
212 | " \"2\": \"LABEL_2\"\n",
213 | " },\n",
214 | " \"init_std\": 0.02,\n",
215 | " \"is_encoder_decoder\": true,\n",
216 | " \"label2id\": {\n",
217 | " \"LABEL_0\": 0,\n",
218 | " \"LABEL_1\": 1,\n",
219 | " \"LABEL_2\": 2\n",
220 | " },\n",
221 | " \"max_position_embeddings\": 1024,\n",
222 | " \"model_type\": \"bart\",\n",
223 | " \"normalize_before\": false,\n",
224 | " \"normalize_embedding\": true,\n",
225 | " \"num_hidden_layers\": 2,\n",
226 | " \"output_past\": true,\n",
227 | " \"pad_token_id\": 1,\n",
228 | " \"prefix\": \" \",\n",
229 | " \"scale_embedding\": false,\n",
230 | " \"static_position_embeddings\": false,\n",
231 | " \"task_specific_params\": {\n",
232 | " \"summarization\": {\n",
233 | " \"early_stopping\": true,\n",
234 | " \"length_penalty\": 2.0,\n",
235 | " \"max_length\": 142,\n",
236 | " \"min_length\": 56,\n",
237 | " \"no_repeat_ngram_size\": 3,\n",
238 | " \"num_beams\": 4\n",
239 | " }\n",
240 | " },\n",
241 | " \"vocab_size\": 50265\n",
242 | "}\n",
243 | "\n",
244 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at /root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n",
245 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
246 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /root/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690\n",
247 | "INFO:transformers.configuration_utils:Model config RobertaConfig {\n",
248 | " \"architectures\": [\n",
249 | " \"RobertaForMaskedLM\"\n",
250 | " ],\n",
251 | " \"attention_probs_dropout_prob\": 0.1,\n",
252 | " \"bos_token_id\": 0,\n",
253 | " \"eos_token_id\": 2,\n",
254 | " \"hidden_act\": \"gelu\",\n",
255 | " \"hidden_dropout_prob\": 0.1,\n",
256 | " \"hidden_size\": 768,\n",
257 | " \"initializer_range\": 0.02,\n",
258 | " \"intermediate_size\": 3072,\n",
259 | " \"layer_norm_eps\": 1e-05,\n",
260 | " \"max_position_embeddings\": 514,\n",
261 | " \"model_type\": \"roberta\",\n",
262 | " \"num_attention_heads\": 12,\n",
263 | " \"num_hidden_layers\": 12,\n",
264 | " \"pad_token_id\": 1,\n",
265 | " \"type_vocab_size\": 1,\n",
266 | " \"vocab_size\": 50265\n",
267 | "}\n",
268 | "\n",
269 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/roberta-base-pytorch_model.bin from cache at /root/.cache/torch/transformers/80b4a484eddeb259bec2f06a6f2f05d90934111628e0e1c09a33bd4a121358e1.49b88ba7ec2c26a7558dda98ca3884c3b80fa31cf43a1b1f23aef3ff81ba344e\n",
270 | "INFO:transformers.modeling_utils:Weights of RobertaForMaskedLM not initialized from pretrained model: ['lm_head.decoder.bias']\n"
271 | ],
272 | "name": "stderr"
273 | }
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "metadata": {
279 | "id": "kZcws3kaMFrp",
280 | "colab_type": "code",
281 | "outputId": "1ef71a3a-5f3c-4ed6-d033-a4b55a259325",
282 | "colab": {
283 | "base_uri": "https://localhost:8080/",
284 | "height": 370
285 | }
286 | },
287 | "source": [
288 | "roberta.config"
289 | ],
290 | "execution_count": 5,
291 | "outputs": [
292 | {
293 | "output_type": "execute_result",
294 | "data": {
295 | "text/plain": [
296 | "RobertaConfig {\n",
297 | " \"architectures\": [\n",
298 | " \"RobertaForMaskedLM\"\n",
299 | " ],\n",
300 | " \"attention_probs_dropout_prob\": 0.1,\n",
301 | " \"bos_token_id\": 0,\n",
302 | " \"eos_token_id\": 2,\n",
303 | " \"hidden_act\": \"gelu\",\n",
304 | " \"hidden_dropout_prob\": 0.1,\n",
305 | " \"hidden_size\": 768,\n",
306 | " \"initializer_range\": 0.02,\n",
307 | " \"intermediate_size\": 3072,\n",
308 | " \"layer_norm_eps\": 1e-05,\n",
309 | " \"max_position_embeddings\": 514,\n",
310 | " \"model_type\": \"roberta\",\n",
311 | " \"num_attention_heads\": 12,\n",
312 | " \"num_hidden_layers\": 12,\n",
313 | " \"pad_token_id\": 1,\n",
314 | " \"type_vocab_size\": 1,\n",
315 | " \"vocab_size\": 50265\n",
316 | "}"
317 | ]
318 | },
319 | "metadata": {
320 | "tags": []
321 | },
322 | "execution_count": 5
323 | }
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "metadata": {
329 | "id": "2ucGw5qrEtp8",
330 | "colab_type": "code",
331 | "outputId": "d844579b-a54a-4799-e783-4945930d6eb7",
332 | "colab": {
333 | "base_uri": "https://localhost:8080/",
334 | "height": 1000
335 | }
336 | },
337 | "source": [
338 | "bart.config"
339 | ],
340 | "execution_count": 6,
341 | "outputs": [
342 | {
343 | "output_type": "execute_result",
344 | "data": {
345 | "text/plain": [
346 | "BartConfig {\n",
347 | " \"_num_labels\": 3,\n",
348 | " \"activation_dropout\": 0.0,\n",
349 | " \"activation_function\": \"gelu\",\n",
350 | " \"add_bias_logits\": false,\n",
351 | " \"add_final_layer_norm\": false,\n",
352 | " \"architectures\": [\n",
353 | " \"BartForConditionalGeneration\"\n",
354 | " ],\n",
355 | " \"attention_dropout\": 0.0,\n",
356 | " \"bos_token_id\": 0,\n",
357 | " \"classif_dropout\": 0.0,\n",
358 | " \"d_model\": 24,\n",
359 | " \"decoder_attention_heads\": 2,\n",
360 | " \"decoder_ffn_dim\": 16,\n",
361 | " \"decoder_layerdrop\": 0.0,\n",
362 | " \"decoder_layers\": 2,\n",
363 | " \"decoder_max_position_embeddings\": 1024,\n",
364 | " \"decoder_start_token_id\": 2,\n",
365 | " \"dropout\": 0.1,\n",
366 | " \"encoder_attention_heads\": 2,\n",
367 | " \"encoder_ffn_dim\": 16,\n",
368 | " \"encoder_layerdrop\": 0.0,\n",
369 | " \"encoder_layers\": 2,\n",
370 | " \"encoder_max_position_embeddings\": 1024,\n",
371 | " \"eos_token_id\": 2,\n",
372 | " \"id2label\": {\n",
373 | " \"0\": \"LABEL_0\",\n",
374 | " \"1\": \"LABEL_1\",\n",
375 | " \"2\": \"LABEL_2\"\n",
376 | " },\n",
377 | " \"init_std\": 0.02,\n",
378 | " \"is_encoder_decoder\": true,\n",
379 | " \"label2id\": {\n",
380 | " \"LABEL_0\": 0,\n",
381 | " \"LABEL_1\": 1,\n",
382 | " \"LABEL_2\": 2\n",
383 | " },\n",
384 | " \"max_position_embeddings\": 1024,\n",
385 | " \"model_type\": \"bart\",\n",
386 | " \"normalize_before\": false,\n",
387 | " \"normalize_embedding\": true,\n",
388 | " \"num_hidden_layers\": 2,\n",
389 | " \"output_past\": true,\n",
390 | " \"pad_token_id\": 1,\n",
391 | " \"prefix\": \" \",\n",
392 | " \"scale_embedding\": false,\n",
393 | " \"static_position_embeddings\": false,\n",
394 | " \"task_specific_params\": {\n",
395 | " \"summarization\": {\n",
396 | " \"early_stopping\": true,\n",
397 | " \"length_penalty\": 2.0,\n",
398 | " \"max_length\": 142,\n",
399 | " \"min_length\": 56,\n",
400 | " \"no_repeat_ngram_size\": 3,\n",
401 | " \"num_beams\": 4\n",
402 | " }\n",
403 | " },\n",
404 | " \"vocab_size\": 50265\n",
405 | "}"
406 | ]
407 | },
408 | "metadata": {
409 | "tags": []
410 | },
411 | "execution_count": 6
412 | }
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "metadata": {
418 | "id": "RNA4Z21mEvGt",
419 | "colab_type": "code",
420 | "colab": {}
421 | },
422 | "source": [
423 | "bart_layer = bart.model.encoder.layers[0]\n",
424 | "roberta_layer = roberta.roberta.encoder.layer[0]"
425 | ],
426 | "execution_count": 0,
427 | "outputs": []
428 | },
429 | {
430 | "cell_type": "code",
431 | "metadata": {
432 | "id": "2KAOBHPdFGn_",
433 | "colab_type": "code",
434 | "outputId": "9121f646-3762-4016-b62d-62523686f8af",
435 | "colab": {
436 | "base_uri": "https://localhost:8080/",
437 | "height": 403
438 | }
439 | },
440 | "source": [
441 | "roberta_layer"
442 | ],
443 | "execution_count": 8,
444 | "outputs": [
445 | {
446 | "output_type": "execute_result",
447 | "data": {
448 | "text/plain": [
449 | "BertLayer(\n",
450 | " (attention): BertAttention(\n",
451 | " (self): BertSelfAttention(\n",
452 | " (query): Linear(in_features=768, out_features=768, bias=True)\n",
453 | " (key): Linear(in_features=768, out_features=768, bias=True)\n",
454 | " (value): Linear(in_features=768, out_features=768, bias=True)\n",
455 | " (dropout): Dropout(p=0.1, inplace=False)\n",
456 | " )\n",
457 | " (output): BertSelfOutput(\n",
458 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
459 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
460 | " (dropout): Dropout(p=0.1, inplace=False)\n",
461 | " )\n",
462 | " )\n",
463 | " (intermediate): BertIntermediate(\n",
464 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
465 | " )\n",
466 | " (output): BertOutput(\n",
467 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
468 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
469 | " (dropout): Dropout(p=0.1, inplace=False)\n",
470 | " )\n",
471 | ")"
472 | ]
473 | },
474 | "metadata": {
475 | "tags": []
476 | },
477 | "execution_count": 8
478 | }
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "metadata": {
484 | "id": "okn8MAVcFDS9",
485 | "colab_type": "code",
486 | "outputId": "2f6a72a3-d7e3-46d6-c81f-a1fc1af85258",
487 | "colab": {
488 | "base_uri": "https://localhost:8080/",
489 | "height": 218
490 | }
491 | },
492 | "source": [
493 | "bart_layer"
494 | ],
495 | "execution_count": 9,
496 | "outputs": [
497 | {
498 | "output_type": "execute_result",
499 | "data": {
500 | "text/plain": [
501 | "EncoderLayer(\n",
502 | " (self_attn): SelfAttention(\n",
503 | " (k_proj): Linear(in_features=24, out_features=24, bias=True)\n",
504 | " (v_proj): Linear(in_features=24, out_features=24, bias=True)\n",
505 | " (q_proj): Linear(in_features=24, out_features=24, bias=True)\n",
506 | " (out_proj): Linear(in_features=24, out_features=24, bias=True)\n",
507 | " )\n",
508 | " (self_attn_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n",
509 | " (fc1): Linear(in_features=24, out_features=16, bias=True)\n",
510 | " (fc2): Linear(in_features=16, out_features=24, bias=True)\n",
511 | " (final_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n",
512 | ")"
513 | ]
514 | },
515 | "metadata": {
516 | "tags": []
517 | },
518 | "execution_count": 9
519 | }
520 | ]
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "metadata": {
525 | "id": "pURj--xZL7Ef",
526 | "colab_type": "text"
527 | },
528 | "source": [
529 |         "BART computes the output projection inside the attention layer itself, and the `forward` parameter names of BART's `SelfAttention` differ from those of `BertSelfAttention`, so we'll need to wrap `LongformerSelfAttention` before we can use it in BART."
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "metadata": {
535 | "id": "jfPgsJ8YQR4A",
536 | "colab_type": "code",
537 | "colab": {}
538 | },
539 | "source": [
540 | "import math\n",
541 | "from typing import Dict, List, Optional, Tuple\n",
542 | "\n",
543 | "import torch\n",
544 | "from torch import Tensor, nn\n",
545 | "\n",
546 | "class LongformerSelfAttentionForBart(nn.Module):\n",
547 | " def __init__(self, config, layer_id):\n",
548 | " super().__init__()\n",
549 | " self.embed_dim = config.d_model\n",
550 | " self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)\n",
551 | " self.output = nn.Linear(self.embed_dim, self.embed_dim)\n",
552 | " \n",
553 | " def forward(\n",
554 | " self,\n",
555 | " query,\n",
556 | " key: Optional[Tensor],\n",
557 | " key_padding_mask: Optional[Tensor] = None,\n",
558 | " layer_state: Optional[Dict[str, Optional[Tensor]]] = None,\n",
559 | " attn_mask: Optional[Tensor] = None,\n",
560 | " need_weights=False,\n",
561 | " ) -> Tuple[Tensor, Optional[Tensor]]:\n",
562 | " \n",
563 | " tgt_len, bsz, embed_dim = query.size()\n",
564 | " assert embed_dim == self.embed_dim\n",
565 | " assert list(query.size()) == [tgt_len, bsz, embed_dim]\n",
566 | "\n",
567 | " # LongformerSelfAttention expects this shape\n",
568 | " query = query.view(bsz, tgt_len, embed_dim)\n",
569 | "\n",
570 | " outputs = self.longformer_self_attn(\n",
571 | " query,\n",
572 | " attention_mask=attn_mask,\n",
573 | " head_mask=None,\n",
574 | " encoder_hidden_states=None,\n",
575 | " encoder_attention_mask=None,\n",
576 | " )\n",
577 | "\n",
578 | " attn_output = outputs[0] \n",
579 | " attn_output = attn_output.contiguous().view(tgt_len, bsz, embed_dim)\n",
580 | " attn_output = self.output(attn_output)\n",
581 | "\n",
582 | " return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)"
583 | ],
584 | "execution_count": 0,
585 | "outputs": []
586 | },
587 | {
588 | "cell_type": "code",
589 | "metadata": {
590 | "id": "6VIx_TmOELqF",
591 | "colab_type": "code",
592 | "colab": {}
593 | },
594 | "source": [
595 | "class LongBartForConditionalGeneration(BartForConditionalGeneration):\n",
596 | " def __init__(self, config):\n",
597 | " super().__init__(config)\n",
598 | " for i, layer in enumerate(self.model.encoder.layers):\n",
599 | " # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n",
600 | " layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)"
601 | ],
602 | "execution_count": 0,
603 | "outputs": []
604 | },
605 | {
606 | "cell_type": "code",
607 | "metadata": {
608 | "id": "Lx28-eLNEou5",
609 | "colab_type": "code",
610 | "colab": {}
611 | },
612 | "source": [
613 | "def create_long_model(save_model_to, base_model='bart-large', attention_window=512, max_pos=4096):\n",
614 | " model = BartForConditionalGeneration.from_pretrained(base_model)\n",
615 | " tokenizer = BartTokenizer.from_pretrained('bart-large', model_max_length=max_pos)\n",
616 | " config = model.config\n",
617 | "\n",
618 | " # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention\n",
619 | " # expects attention_probs_dropout_prob, so set it here \n",
620 | " config.attention_probs_dropout_prob = config.attention_dropout\n",
621 | "\n",
622 | " # extend position embeddings\n",
623 | " tokenizer.model_max_length = max_pos\n",
624 | " tokenizer.init_kwargs['model_max_length'] = max_pos\n",
625 | " current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape\n",
626 | " # config.max_position_embeddings = max_pos\n",
627 | " config.encoder_max_position_embeddings = max_pos\n",
628 | " max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2\n",
629 | " assert max_pos > current_max_pos\n",
630 | " # allocate a larger position embedding matrix\n",
631 | " new_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)\n",
632 | " # copy position embeddings over and over to initialize the new position embeddings\n",
633 | " k = 2\n",
634 | " step = current_max_pos - 2\n",
635 | " while k < max_pos - 1:\n",
636 | " new_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]\n",
637 | " k += step\n",
638 | " model.model.encoder.embed_positions.weight.data = new_pos_embed\n",
639 | "\n",
640 | " # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n",
641 | " config.attention_window = [attention_window] * config.num_hidden_layers\n",
642 | " for i, layer in enumerate(model.model.encoder.layers):\n",
643 | " longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)\n",
644 | " \n",
645 | " longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj\n",
646 | " longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj\n",
647 | " longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj\n",
648 | "\n",
649 | " longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj\n",
650 | " longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj\n",
651 | " longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj\n",
652 | "\n",
653 | " longformer_self_attn_for_bart.output = layer.self_attn.out_proj\n",
654 | "\n",
655 | " layer.self_attn = longformer_self_attn_for_bart\n",
656 | "\n",
657 | " logger.info(f'saving model to {save_model_to}')\n",
658 | " model.save_pretrained(save_model_to)\n",
659 | " tokenizer.save_pretrained(save_model_to)\n",
660 | " return model, tokenizer"
661 | ],
662 | "execution_count": 0,
663 | "outputs": []
664 | },
665 | {
666 | "cell_type": "code",
667 | "metadata": {
668 | "id": "XuYkE_kGO-U-",
669 | "colab_type": "code",
670 | "outputId": "ee5176ee-677c-406b-feb2-1d1aa5ea2b23",
671 | "colab": {
672 | "base_uri": "https://localhost:8080/",
673 | "height": 1000
674 | }
675 | },
676 | "source": [
677 | "# model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'\n",
678 | "base_model = \"sshleifer/bart-tiny-random\"\n",
679 | "model_path = \"bart-tiny-random-4096\"\n",
680 | "attention_window = 512\n",
681 | "max_pos = 4096\n",
682 | "\n",
683 | "if not os.path.exists(model_path):\n",
684 | " os.makedirs(model_path)\n",
685 | "\n",
686 | "# logger.info(f'Converting roberta-base into roberta-base-{model_args.max_pos}')\n",
687 | "model, tokenizer = create_long_model(\n",
688 | " save_model_to=model_path,\n",
689 | " base_model=base_model,\n",
690 | " attention_window=attention_window,\n",
691 | " max_pos=max_pos\n",
692 | ")"
693 | ],
694 | "execution_count": 13,
695 | "outputs": [
696 | {
697 | "output_type": "stream",
698 | "text": [
699 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n",
700 | "INFO:transformers.configuration_utils:Model config BartConfig {\n",
701 | " \"_num_labels\": 3,\n",
702 | " \"activation_dropout\": 0.0,\n",
703 | " \"activation_function\": \"gelu\",\n",
704 | " \"add_bias_logits\": false,\n",
705 | " \"add_final_layer_norm\": false,\n",
706 | " \"architectures\": [\n",
707 | " \"BartForConditionalGeneration\"\n",
708 | " ],\n",
709 | " \"attention_dropout\": 0.0,\n",
710 | " \"bos_token_id\": 0,\n",
711 | " \"classif_dropout\": 0.0,\n",
712 | " \"d_model\": 24,\n",
713 | " \"decoder_attention_heads\": 2,\n",
714 | " \"decoder_ffn_dim\": 16,\n",
715 | " \"decoder_layerdrop\": 0.0,\n",
716 | " \"decoder_layers\": 2,\n",
717 | " \"decoder_max_position_embeddings\": 1024,\n",
718 | " \"decoder_start_token_id\": 2,\n",
719 | " \"dropout\": 0.1,\n",
720 | " \"encoder_attention_heads\": 2,\n",
721 | " \"encoder_ffn_dim\": 16,\n",
722 | " \"encoder_layerdrop\": 0.0,\n",
723 | " \"encoder_layers\": 2,\n",
724 | " \"encoder_max_position_embeddings\": 1024,\n",
725 | " \"eos_token_id\": 2,\n",
726 | " \"id2label\": {\n",
727 | " \"0\": \"LABEL_0\",\n",
728 | " \"1\": \"LABEL_1\",\n",
729 | " \"2\": \"LABEL_2\"\n",
730 | " },\n",
731 | " \"init_std\": 0.02,\n",
732 | " \"is_encoder_decoder\": true,\n",
733 | " \"label2id\": {\n",
734 | " \"LABEL_0\": 0,\n",
735 | " \"LABEL_1\": 1,\n",
736 | " \"LABEL_2\": 2\n",
737 | " },\n",
738 | " \"max_position_embeddings\": 1024,\n",
739 | " \"model_type\": \"bart\",\n",
740 | " \"normalize_before\": false,\n",
741 | " \"normalize_embedding\": true,\n",
742 | " \"num_hidden_layers\": 2,\n",
743 | " \"output_past\": true,\n",
744 | " \"pad_token_id\": 1,\n",
745 | " \"prefix\": \" \",\n",
746 | " \"scale_embedding\": false,\n",
747 | " \"static_position_embeddings\": false,\n",
748 | " \"task_specific_params\": {\n",
749 | " \"summarization\": {\n",
750 | " \"early_stopping\": true,\n",
751 | " \"length_penalty\": 2.0,\n",
752 | " \"max_length\": 142,\n",
753 | " \"min_length\": 56,\n",
754 | " \"no_repeat_ngram_size\": 3,\n",
755 | " \"num_beams\": 4\n",
756 | " }\n",
757 | " },\n",
758 | " \"vocab_size\": 50265\n",
759 | "}\n",
760 | "\n",
761 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at /root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n",
762 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
763 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
764 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
765 | "INFO:__main__:saving model to bart-tiny-random-4096\n",
766 | "INFO:transformers.configuration_utils:Configuration saved in bart-tiny-random-4096/config.json\n",
767 | "INFO:transformers.modeling_utils:Model weights saved in bart-tiny-random-4096/pytorch_model.bin\n"
768 | ],
769 | "name": "stderr"
770 | }
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "metadata": {
776 | "id": "XsnW-waPU3Ua",
777 | "colab_type": "code",
778 | "outputId": "5f6d97f5-2863-4535-969a-1fa884b5635e",
779 | "colab": {
780 | "base_uri": "https://localhost:8080/",
781 | "height": 1000
782 | }
783 | },
784 | "source": [
785 | "long_model_tiny = LongBartForConditionalGeneration.from_pretrained('bart-tiny-random-4096')"
786 | ],
787 | "execution_count": 14,
788 | "outputs": [
789 | {
790 | "output_type": "stream",
791 | "text": [
792 | "INFO:transformers.configuration_utils:loading configuration file bart-tiny-random-4096/config.json\n",
793 | "INFO:transformers.configuration_utils:Model config BartConfig {\n",
794 | " \"_num_labels\": 3,\n",
795 | " \"activation_dropout\": 0.0,\n",
796 | " \"activation_function\": \"gelu\",\n",
797 | " \"add_bias_logits\": false,\n",
798 | " \"add_final_layer_norm\": false,\n",
799 | " \"architectures\": [\n",
800 | " \"BartForConditionalGeneration\"\n",
801 | " ],\n",
802 | " \"attention_dropout\": 0.0,\n",
803 | " \"attention_probs_dropout_prob\": 0.0,\n",
804 | " \"attention_window\": [\n",
805 | " 512,\n",
806 | " 512\n",
807 | " ],\n",
808 | " \"bos_token_id\": 0,\n",
809 | " \"classif_dropout\": 0.0,\n",
810 | " \"d_model\": 24,\n",
811 | " \"decoder_attention_heads\": 2,\n",
812 | " \"decoder_ffn_dim\": 16,\n",
813 | " \"decoder_layerdrop\": 0.0,\n",
814 | " \"decoder_layers\": 2,\n",
815 | " \"decoder_max_position_embeddings\": 1024,\n",
816 | " \"decoder_start_token_id\": 2,\n",
817 | " \"dropout\": 0.1,\n",
818 | " \"encoder_attention_heads\": 2,\n",
819 | " \"encoder_ffn_dim\": 16,\n",
820 | " \"encoder_layerdrop\": 0.0,\n",
821 | " \"encoder_layers\": 2,\n",
822 | " \"encoder_max_position_embeddings\": 4096,\n",
823 | " \"eos_token_id\": 2,\n",
824 | " \"id2label\": {\n",
825 | " \"0\": \"LABEL_0\",\n",
826 | " \"1\": \"LABEL_1\",\n",
827 | " \"2\": \"LABEL_2\"\n",
828 | " },\n",
829 | " \"init_std\": 0.02,\n",
830 | " \"is_encoder_decoder\": true,\n",
831 | " \"label2id\": {\n",
832 | " \"LABEL_0\": 0,\n",
833 | " \"LABEL_1\": 1,\n",
834 | " \"LABEL_2\": 2\n",
835 | " },\n",
836 | " \"max_position_embeddings\": 1024,\n",
837 | " \"model_type\": \"bart\",\n",
838 | " \"normalize_before\": false,\n",
839 | " \"normalize_embedding\": true,\n",
840 | " \"num_hidden_layers\": 2,\n",
841 | " \"output_past\": true,\n",
842 | " \"pad_token_id\": 1,\n",
843 | " \"prefix\": \" \",\n",
844 | " \"scale_embedding\": false,\n",
845 | " \"static_position_embeddings\": false,\n",
846 | " \"task_specific_params\": {\n",
847 | " \"summarization\": {\n",
848 | " \"early_stopping\": true,\n",
849 | " \"length_penalty\": 2.0,\n",
850 | " \"max_length\": 142,\n",
851 | " \"min_length\": 56,\n",
852 | " \"no_repeat_ngram_size\": 3,\n",
853 | " \"num_beams\": 4\n",
854 | " }\n",
855 | " },\n",
856 | " \"vocab_size\": 50265\n",
857 | "}\n",
858 | "\n",
859 | "INFO:transformers.modeling_utils:loading weights file bart-tiny-random-4096/pytorch_model.bin\n"
860 | ],
861 | "name": "stderr"
862 | }
863 | ]
864 | },
865 | {
866 | "cell_type": "code",
867 | "metadata": {
868 | "id": "Z5QKIIdeYRDL",
869 | "colab_type": "code",
870 | "outputId": "dd326b65-6bc7-4a91-e670-e41b6f64784f",
871 | "colab": {
872 | "base_uri": "https://localhost:8080/",
873 | "height": 34
874 | }
875 | },
876 | "source": [
877 |         "TXT = \"My friends are <mask> but they eat too many carbs.\"\n",
878 | "\n",
879 | "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=4096, pad_to_max_length=True)['input_ids']\n",
880 | "\n",
881 | "logits = long_model_tiny(input_ids)[0]\n",
882 | "\n",
883 | "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n",
884 | "probs = logits[0, masked_index].softmax(dim=0)\n",
885 | "values, predictions = probs.topk(5)\n",
886 | "tokenizer.decode(predictions).split()"
887 | ],
888 | "execution_count": 15,
889 | "outputs": [
890 | {
891 | "output_type": "execute_result",
892 | "data": {
893 | "text/plain": [
894 | "['.']"
895 | ]
896 | },
897 | "metadata": {
898 | "tags": []
899 | },
900 | "execution_count": 15
901 | }
902 | ]
903 | },
904 | {
905 | "cell_type": "markdown",
906 | "metadata": {
907 | "id": "FNmzwNHAN1AI",
908 | "colab_type": "text"
909 | },
910 | "source": [
911 |         "Now let's try with bart-large."
912 | ]
913 | },
914 | {
915 | "cell_type": "code",
916 | "metadata": {
917 | "id": "vAzZdj-1N3Is",
918 | "colab_type": "code",
919 | "outputId": "85cf0a32-fc89-4137-f61e-d2710ea85bc4",
920 | "colab": {
921 | "base_uri": "https://localhost:8080/",
922 | "height": 1000
923 | }
924 | },
925 | "source": [
926 | "# model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'\n",
927 | "base_model = \"bart-large\"\n",
928 | "model_path = \"bart-large-4096\"\n",
929 | "attention_window = 512\n",
930 | "max_pos = 4096\n",
931 | "\n",
932 | "if not os.path.exists(model_path):\n",
933 | " os.makedirs(model_path)\n",
934 | "\n",
935 | "# logger.info(f'Converting roberta-base into roberta-base-{model_args.max_pos}')\n",
936 | "model, tokenizer = create_long_model(\n",
937 | " save_model_to=model_path,\n",
938 | " base_model=base_model,\n",
939 | " attention_window=attention_window,\n",
940 | " max_pos=max_pos\n",
941 | ")"
942 | ],
943 | "execution_count": 16,
944 | "outputs": [
945 | {
946 | "output_type": "stream",
947 | "text": [
948 | "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json from cache at /root/.cache/torch/transformers/7f6632e580b7d9fd4f611dd96dab877cccfc319867b53b8b72fddca7fd64de5c.40bd49bcec9d93d8b0bfbd020088e2e1b6e6bb03e8e80aa5144638f90ca6bd61\n",
949 | "INFO:transformers.configuration_utils:Model config BartConfig {\n",
950 | " \"_num_labels\": 3,\n",
951 | " \"activation_dropout\": 0.0,\n",
952 | " \"activation_function\": \"gelu\",\n",
953 | " \"add_bias_logits\": false,\n",
954 | " \"add_final_layer_norm\": false,\n",
955 | " \"architectures\": [\n",
956 | " \"BartModel\",\n",
957 | " \"BartForMaskedLM\",\n",
958 | " \"BartForSequenceClassification\"\n",
959 | " ],\n",
960 | " \"attention_dropout\": 0.0,\n",
961 | " \"bos_token_id\": 0,\n",
962 | " \"classif_dropout\": 0.0,\n",
963 | " \"d_model\": 1024,\n",
964 | " \"decoder_attention_heads\": 16,\n",
965 | " \"decoder_ffn_dim\": 4096,\n",
966 | " \"decoder_layerdrop\": 0.0,\n",
967 | " \"decoder_layers\": 12,\n",
968 | " \"decoder_max_position_embeddings\": 1024,\n",
969 | " \"decoder_start_token_id\": 2,\n",
970 | " \"dropout\": 0.1,\n",
971 | " \"encoder_attention_heads\": 16,\n",
972 | " \"encoder_ffn_dim\": 4096,\n",
973 | " \"encoder_layerdrop\": 0.0,\n",
974 | " \"encoder_layers\": 12,\n",
975 | " \"encoder_max_position_embeddings\": 1024,\n",
976 | " \"eos_token_id\": 2,\n",
977 | " \"id2label\": {\n",
978 | " \"0\": \"LABEL_0\",\n",
979 | " \"1\": \"LABEL_1\",\n",
980 | " \"2\": \"LABEL_2\"\n",
981 | " },\n",
982 | " \"init_std\": 0.02,\n",
983 | " \"is_encoder_decoder\": true,\n",
984 | " \"label2id\": {\n",
985 | " \"LABEL_0\": 0,\n",
986 | " \"LABEL_1\": 1,\n",
987 | " \"LABEL_2\": 2\n",
988 | " },\n",
989 | " \"max_position_embeddings\": 1024,\n",
990 | " \"model_type\": \"bart\",\n",
991 | " \"normalize_before\": false,\n",
992 | " \"normalize_embedding\": true,\n",
993 | " \"num_hidden_layers\": 12,\n",
994 | " \"output_past\": false,\n",
995 | " \"pad_token_id\": 1,\n",
996 | " \"prefix\": \" \",\n",
997 | " \"scale_embedding\": false,\n",
998 | " \"static_position_embeddings\": false,\n",
999 | " \"task_specific_params\": {\n",
1000 | " \"summarization\": {\n",
1001 | " \"early_stopping\": true,\n",
1002 | " \"length_penalty\": 2.0,\n",
1003 | " \"max_length\": 142,\n",
1004 | " \"min_length\": 56,\n",
1005 | " \"no_repeat_ngram_size\": 3,\n",
1006 | " \"num_beams\": 4\n",
1007 | " }\n",
1008 | " },\n",
1009 | " \"vocab_size\": 50265\n",
1010 | "}\n",
1011 | "\n",
1012 | "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin from cache at /root/.cache/torch/transformers/2e7cae41bb1dd1f18e498ff4ff0ea85f7e9bc2b637439e2d95c485c5d5bdd579.5be2a88ec29f5969270f98902db392beab8be8a6a7ecc588d410ada3e32c4263\n",
1013 | "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
1014 | "INFO:transformers.modeling_utils:Weights from pretrained model not used in BartForConditionalGeneration: ['encoder.version', 'decoder.version']\n",
1015 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
1016 | "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
1017 | "INFO:__main__:saving model to bart-large-4096\n",
1018 | "INFO:transformers.configuration_utils:Configuration saved in bart-large-4096/config.json\n",
1019 | "INFO:transformers.modeling_utils:Model weights saved in bart-large-4096/pytorch_model.bin\n"
1020 | ],
1021 | "name": "stderr"
1022 | }
1023 | ]
1024 | },
1025 | {
1026 | "cell_type": "code",
1027 | "metadata": {
1028 | "id": "9-lgWprfN8QW",
1029 | "colab_type": "code",
1030 | "outputId": "c1a817ea-3a05-4676-acea-81b27b3f6591",
1031 | "colab": {
1032 | "base_uri": "https://localhost:8080/",
1033 | "height": 1000
1034 | }
1035 | },
1036 | "source": [
1037 | "long_model = LongBartForConditionalGeneration.from_pretrained('bart-large-4096')\n",
1038 | "tokenizer = BartTokenizer.from_pretrained('bart-large-4096')"
1039 | ],
1040 | "execution_count": 7,
1041 | "outputs": [
1042 | {
1043 | "output_type": "stream",
1044 | "text": [
1045 | "INFO:transformers.configuration_utils:loading configuration file bart-large-4096/config.json\n",
1046 | "INFO:transformers.configuration_utils:Model config BartConfig {\n",
1047 | " \"_num_labels\": 3,\n",
1048 | " \"activation_dropout\": 0.0,\n",
1049 | " \"activation_function\": \"gelu\",\n",
1050 | " \"add_bias_logits\": false,\n",
1051 | " \"add_final_layer_norm\": false,\n",
1052 | " \"architectures\": [\n",
1053 | " \"BartForConditionalGeneration\"\n",
1054 | " ],\n",
1055 | " \"attention_dropout\": 0.0,\n",
1056 | " \"attention_probs_dropout_prob\": 0.0,\n",
1057 | " \"attention_window\": [\n",
1058 | " 512,\n",
1059 | " 512,\n",
1060 | " 512,\n",
1061 | " 512,\n",
1062 | " 512,\n",
1063 | " 512,\n",
1064 | " 512,\n",
1065 | " 512,\n",
1066 | " 512,\n",
1067 | " 512,\n",
1068 | " 512,\n",
1069 | " 512\n",
1070 | " ],\n",
1071 | " \"bos_token_id\": 0,\n",
1072 | " \"classif_dropout\": 0.0,\n",
1073 | " \"d_model\": 1024,\n",
1074 | " \"decoder_attention_heads\": 16,\n",
1075 | " \"decoder_ffn_dim\": 4096,\n",
1076 | " \"decoder_layerdrop\": 0.0,\n",
1077 | " \"decoder_layers\": 12,\n",
1078 | " \"decoder_max_position_embeddings\": 1024,\n",
1079 | " \"decoder_start_token_id\": 2,\n",
1080 | " \"dropout\": 0.1,\n",
1081 | " \"encoder_attention_heads\": 16,\n",
1082 | " \"encoder_ffn_dim\": 4096,\n",
1083 | " \"encoder_layerdrop\": 0.0,\n",
1084 | " \"encoder_layers\": 12,\n",
1085 | " \"encoder_max_position_embeddings\": 4096,\n",
1086 | " \"eos_token_id\": 2,\n",
1087 | " \"id2label\": {\n",
1088 | " \"0\": \"LABEL_0\",\n",
1089 | " \"1\": \"LABEL_1\",\n",
1090 | " \"2\": \"LABEL_2\"\n",
1091 | " },\n",
1092 | " \"init_std\": 0.02,\n",
1093 | " \"is_encoder_decoder\": true,\n",
1094 | " \"label2id\": {\n",
1095 | " \"LABEL_0\": 0,\n",
1096 | " \"LABEL_1\": 1,\n",
1097 | " \"LABEL_2\": 2\n",
1098 | " },\n",
1099 | " \"max_position_embeddings\": 1024,\n",
1100 | " \"model_type\": \"bart\",\n",
1101 | " \"normalize_before\": false,\n",
1102 | " \"normalize_embedding\": true,\n",
1103 | " \"num_hidden_layers\": 12,\n",
1104 | " \"output_past\": false,\n",
1105 | " \"pad_token_id\": 1,\n",
1106 | " \"prefix\": \" \",\n",
1107 | " \"scale_embedding\": false,\n",
1108 | " \"static_position_embeddings\": false,\n",
1109 | " \"task_specific_params\": {\n",
1110 | " \"summarization\": {\n",
1111 | " \"early_stopping\": true,\n",
1112 | " \"length_penalty\": 2.0,\n",
1113 | " \"max_length\": 142,\n",
1114 | " \"min_length\": 56,\n",
1115 | " \"no_repeat_ngram_size\": 3,\n",
1116 | " \"num_beams\": 4\n",
1117 | " }\n",
1118 | " },\n",
1119 | " \"vocab_size\": 50265\n",
1120 | "}\n",
1121 | "\n",
1122 | "INFO:transformers.modeling_utils:loading weights file bart-large-4096/pytorch_model.bin\n",
1123 | "INFO:transformers.tokenization_utils:Model name 'bart-large-4096' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'bart-large-4096' is a path, a model identifier, or url to a directory containing tokenizer files.\n",
1124 | "INFO:transformers.tokenization_utils:Didn't find file bart-large-4096/added_tokens.json. We won't load it.\n",
1125 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/vocab.json\n",
1126 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/merges.txt\n",
1127 | "INFO:transformers.tokenization_utils:loading file None\n",
1128 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/special_tokens_map.json\n",
1129 | "INFO:transformers.tokenization_utils:loading file bart-large-4096/tokenizer_config.json\n"
1130 | ],
1131 | "name": "stderr"
1132 | }
1133 | ]
1134 | },
1135 | {
1136 | "cell_type": "code",
1137 | "metadata": {
1138 | "id": "cLhZFQMYONPb",
1139 | "colab_type": "code",
1140 | "colab": {
1141 | "base_uri": "https://localhost:8080/",
1142 | "height": 34
1143 | },
1144 | "outputId": "f4eda1b3-5333-4144-bdd4-9804046dd30a"
1145 | },
1146 | "source": [
1147 |         "TXT = \"My friends are <mask> but they eat too many carbs.\"\n",
1148 | "\n",
1149 |         "# a 4096-token sequence crashes even with 35 GB of memory,\n",
1150 |         "# so we probably need sliding-window attention in the decoder as well\n",
1151 | "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=2560, pad_to_max_length=True)['input_ids']\n",
1152 | "\n",
1153 | "logits = long_model(input_ids)[0]\n",
1154 | "\n",
1155 | "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n",
1156 | "probs = logits[0, masked_index].softmax(dim=0)\n",
1157 | "values, predictions = probs.topk(5)\n",
1158 | "tokenizer.decode(predictions).split()"
1159 | ],
1160 | "execution_count": 8,
1161 | "outputs": [
1162 | {
1163 | "output_type": "execute_result",
1164 | "data": {
1165 | "text/plain": [
1166 | "['having', 'still', 'going', 'getting', 'not']"
1167 | ]
1168 | },
1169 | "metadata": {
1170 | "tags": []
1171 | },
1172 | "execution_count": 8
1173 | }
1174 | ]
1175 | }
1176 | ]
1177 | }
--------------------------------------------------------------------------------
/longbart/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_longbart import LongformerSelfAttentionForBart, LongBartForConditionalGeneration
2 | from .modeling_bart import BartForConditionalGeneration
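3 | 
4 | # Typical usage (a sketch; 'bart-large-4096' refers to a local checkpoint produced by
5 | # convert_bart_to_longbart.py, not a Hugging Face Hub model):
6 | #
7 | #   from longbart import LongBartForConditionalGeneration
8 | #   model = LongBartForConditionalGeneration.from_pretrained('bart-large-4096')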
--------------------------------------------------------------------------------
/longbart/configuration_bart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ BART configuration """
16 |
17 |
18 | import logging
19 |
20 | from transformers.configuration_utils import PretrainedConfig
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26 | "facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
27 | "facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
28 | "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
29 | "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
30 | "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
31 | }
32 |
33 |
34 | class BartConfig(PretrainedConfig):
35 | r"""
36 | Configuration class for Bart. Parameters are renamed from the fairseq implementation
37 | """
38 | model_type = "bart"
39 |
40 | def __init__(
41 | self,
42 | activation_dropout=0.0,
43 | activation_function="gelu",
44 | vocab_size=50265,
45 | d_model=1024,
46 | encoder_ffn_dim=4096,
47 | encoder_layers=12,
48 | encoder_attention_heads=16,
49 | decoder_ffn_dim=4096,
50 | decoder_layers=12,
51 | decoder_attention_heads=16,
52 | encoder_layerdrop=0.0,
53 | decoder_layerdrop=0.0,
54 | attention_dropout=0.0,
55 | dropout=0.1,
56 | max_position_embeddings=1024,
57 | encoder_max_position_embeddings=None,
58 | decoder_max_position_embeddings=None,
59 | init_std=0.02,
60 | classifier_dropout=0.0,
61 | num_labels=3,
62 | is_encoder_decoder=True,
63 | pad_token_id=1,
64 | bos_token_id=0,
65 | eos_token_id=2,
66 | normalize_before=False,
67 | add_final_layer_norm=False,
68 | scale_embedding=False,
69 | normalize_embedding=True,
70 | static_position_embeddings=False,
71 | add_bias_logits=False,
72 | gradient_checkpointing=False,
73 | **common_kwargs
74 | ):
75 | r"""
76 | :class:`~transformers.BartConfig` is the configuration class for `BartModel`.
77 | Examples:
78 | config = BartConfig.from_pretrained('bart-large')
79 | model = BartModel(config)
80 | """
81 | if "hidden_size" in common_kwargs:
82 | raise ValueError("hidden size is called d_model")
83 | super().__init__(
84 | num_labels=num_labels,
85 | pad_token_id=pad_token_id,
86 | bos_token_id=bos_token_id,
87 | eos_token_id=eos_token_id,
88 | is_encoder_decoder=is_encoder_decoder,
89 | **common_kwargs,
90 | )
91 | self.vocab_size = vocab_size
92 | self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
93 | self.encoder_ffn_dim = encoder_ffn_dim
94 | self.encoder_layers = self.num_hidden_layers = encoder_layers
95 | self.encoder_attention_heads = encoder_attention_heads
96 | self.encoder_layerdrop = encoder_layerdrop
97 | self.decoder_layerdrop = decoder_layerdrop
98 | self.decoder_ffn_dim = decoder_ffn_dim
99 | self.decoder_layers = decoder_layers
100 | self.decoder_attention_heads = decoder_attention_heads
101 | self.init_std = init_std # Normal(0, this parameter)
102 | self.activation_function = activation_function
103 |
104 | self.max_position_embeddings = max_position_embeddings
105 | self.encoder_max_position_embeddings = encoder_max_position_embeddings if encoder_max_position_embeddings else max_position_embeddings
106 | self.decoder_max_position_embeddings = decoder_max_position_embeddings if decoder_max_position_embeddings else max_position_embeddings
107 |
108 | # Params introduced for Mbart
109 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
110 | self.normalize_embedding = normalize_embedding # True for mbart, False otherwise
111 | self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
112 | self.add_final_layer_norm = add_final_layer_norm
113 |
114 | # Params introduced for Marian
115 | self.add_bias_logits = add_bias_logits
116 | self.static_position_embeddings = static_position_embeddings
117 |
118 | # 3 Types of Dropout
119 | self.attention_dropout = attention_dropout
120 | self.activation_dropout = activation_dropout
121 | self.dropout = dropout
122 |
123 | # Classifier stuff
124 | self.classif_dropout = classifier_dropout
125 |
126 | # gradient_checkpointing
127 | self.gradient_checkpointing = gradient_checkpointing
128 | self.output_attentions = True
129 |
130 | @property
131 | def num_attention_heads(self) -> int:
132 | return self.encoder_attention_heads
133 |
134 | @property
135 | def hidden_size(self) -> int:
136 | return self.d_model
137 |
138 | def is_valid_mbart(self) -> bool:
139 | """Is the configuration aligned with the MBART paper."""
140 | if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
141 | return True
142 | if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
143 | logger.info("This configuration is a mixture of MBART and BART settings")
144 | return False
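145 | 
146 | 
147 | # A minimal usage sketch for the split position-embedding limits added above
148 | # (values are illustrative; each limit falls back to `max_position_embeddings`
149 | # when not given explicitly):
150 | #
151 | #   config = BartConfig(encoder_max_position_embeddings=4096,
152 | #                       decoder_max_position_embeddings=1024)
153 | #   assert config.encoder_max_position_embeddings == 4096
154 | #   assert config.decoder_max_position_embeddings == 1024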
--------------------------------------------------------------------------------
/longbart/convert_bart_to_longbart.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | from transformers import BartTokenizer
6 |
7 | from .modeling_bart import BartForConditionalGeneration
8 | from .modeling_longbart import LongformerSelfAttentionForBart
9 |
10 | logger = logging.getLogger(__name__)
11 | logging.basicConfig(level=logging.INFO)
12 |
13 | def create_long_model(
14 | save_model_to,
15 | base_model='facebook/bart-large',
16 | tokenizer_name_or_path='facebook/bart-large',
17 | attention_window=1024,
18 | max_pos=4096
19 | ):
20 | model = BartForConditionalGeneration.from_pretrained(base_model)
21 | tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
22 | config = model.config
23 |
24 | # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention
25 | # expects attention_probs_dropout_prob, so set it here
26 | config.attention_probs_dropout_prob = config.attention_dropout
27 |
28 | # extend position embeddings
29 | tokenizer.model_max_length = max_pos
30 | tokenizer.init_kwargs['model_max_length'] = max_pos
31 | current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape
32 | # config.max_position_embeddings = max_pos
33 | config.encoder_max_position_embeddings = max_pos
34 | max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
35 | assert max_pos > current_max_pos
36 | # allocate a larger position embedding matrix
37 | new_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
38 | # copy position embeddings over and over to initialize the new position embeddings
39 | k = 2
40 | step = current_max_pos - 2
41 | while k < max_pos - 1:
42 | new_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
43 | k += step
44 | model.model.encoder.embed_positions.weight.data = new_pos_embed
45 |
46 | # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
47 | config.attention_window = [attention_window] * config.num_hidden_layers
48 | for i, layer in enumerate(model.model.encoder.layers):
49 | longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)
50 |
51 | longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
52 | longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
53 | longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj
54 |
55 | longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj
56 | longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj
57 | longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj
58 |
59 | longformer_self_attn_for_bart.output = layer.self_attn.out_proj
60 |
61 | layer.self_attn = longformer_self_attn_for_bart
62 |
63 | logger.info(f'saving model to {save_model_to}')
64 | model.save_pretrained(save_model_to)
65 | tokenizer.save_pretrained(save_model_to)
66 | return model, tokenizer
67 |
68 |
69 | def main():
70 |     parser = argparse.ArgumentParser(description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention")
71 |     parser.add_argument(
72 |         '--base_model',
73 |         type=str,
74 |         default='facebook/bart-large',
75 |         help='The name or path of the base model you want to convert'
76 |     )
77 |     parser.add_argument(
78 |         '--tokenizer_name_or_path',
79 |         type=str,
80 |         default='facebook/bart-large',
81 |         help='The name or path of the tokenizer'
82 |     )
83 |     parser.add_argument(
84 |         '--save_model_to',
85 |         type=str,
86 |         required=True,
87 |         help='The path to save the converted model'
88 |     )
89 |     parser.add_argument(
90 |         '--attention_window',
91 |         type=int,
92 |         default=1024,
93 |         help='attention window size for longformer self attention'
94 |     )
95 |     parser.add_argument(
96 |         '--max_pos',
97 |         type=int,
98 |         default=4096,
99 |         help='maximum encoder positions'
100 |     )
101 |
102 | args = parser.parse_args()
103 |
104 | if not os.path.exists(args.save_model_to):
105 | os.mkdir(args.save_model_to)
106 |
107 | create_long_model(
108 | save_model_to=args.save_model_to,
109 | base_model=args.base_model,
110 | tokenizer_name_or_path=args.tokenizer_name_or_path,
111 | attention_window=args.attention_window,
112 | max_pos=args.max_pos
113 | )
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
--------------------------------------------------------------------------------
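A minimal usage sketch for the conversion script above, assuming the package has been installed (e.g. with `pip install -e .`) so that it is importable as `longbart`; the output directory name is hypothetical and the `facebook/bart-large` weights are assumed to be downloadable:

from longbart.convert_bart_to_longbart import create_long_model
from longbart.modeling_longbart import LongBartForConditionalGeneration

# convert facebook/bart-large into a 4096-position variant and save it
model, tokenizer = create_long_model(
    save_model_to='longbart-large-4096',           # hypothetical output directory
    base_model='facebook/bart-large',
    tokenizer_name_or_path='facebook/bart-large',
    attention_window=1024,
    max_pos=4096,
)

# the saved checkpoint can then be reloaded with Longformer-style encoder self-attention
long_model = LongBartForConditionalGeneration.from_pretrained('longbart-large-4096')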
/longbart/modeling_bart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch BART model, ported from the fairseq repo."""
16 | import logging
17 | import math
18 | import random
19 | from typing import Dict, List, Optional, Tuple
20 |
21 | import numpy as np
22 | import torch
23 | import torch.utils.checkpoint
24 | import torch.nn.functional as F
25 | from torch import Tensor, nn
26 |
27 | from transformers.activations import ACT2FN
28 | from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
29 | from transformers.modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
30 |
31 | from .configuration_bart import BartConfig
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
37 | "facebook/bart-large",
38 | "facebook/bart-large-mnli",
39 | "facebook/bart-large-cnn",
40 | "facebook/bart-large-xsum",
41 | "facebook/mbart-large-en-ro",
42 | # See all BART models at https://huggingface.co/models?filter=bart
43 | ]
44 |
45 |
46 | BART_START_DOCSTRING = r"""
47 |
48 |     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
49 | refer to the PyTorch documentation for all matters related to general usage and behavior.
50 |
51 | Parameters:
52 | config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
53 | Initializing with a config file does not load the weights associated with the model, only the configuration.
54 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
55 |
56 | """
57 | BART_GENERATION_EXAMPLE = r"""
58 | Examples::
59 |
60 | from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
61 | # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
62 | model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
63 | tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
64 | ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
65 | inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
66 | # Generate Summary
67 | summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
68 | print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
69 |
70 | """
71 |
72 | BART_INPUTS_DOCSTRING = r"""
73 | Args:
74 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
75 | Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
76 | Padding will be ignored by default should you provide it.
77 | Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
78 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
79 | Mask to avoid performing attention on padding token indices in input_ids.
80 | Mask values selected in ``[0, 1]``:
81 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
82 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
83 | Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
84 |             `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
85 | Used in the cross-attention of the decoder.
86 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
87 | Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
88 | decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
89 | Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
90 | If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
91 | See diagram 1 in the paper for more info on the default strategy
92 | """
93 |
94 |
95 | def invert_mask(attention_mask):
96 | assert attention_mask.dim() == 2
97 | return attention_mask.eq(0)
98 |
99 |
100 | def _prepare_bart_decoder_inputs(
101 | config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
102 | ):
103 | """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
104 | none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
105 | Note: this is not called during generation
106 | """
107 | pad_token_id = config.pad_token_id
108 | if decoder_input_ids is None:
109 | decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
110 | bsz, tgt_len = decoder_input_ids.size()
111 | if decoder_padding_mask is None:
112 | decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
113 | else:
114 | decoder_padding_mask = invert_mask(decoder_padding_mask)
115 | causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
116 | dtype=causal_mask_dtype, device=decoder_input_ids.device
117 | )
118 | return decoder_input_ids, decoder_padding_mask, causal_mask
119 |
120 |
121 | class PretrainedBartModel(PreTrainedModel):
122 | config_class = BartConfig
123 | base_model_prefix = "model"
124 |
125 | def _init_weights(self, module):
126 | std = self.config.init_std
127 | if isinstance(module, nn.Linear):
128 | module.weight.data.normal_(mean=0.0, std=std)
129 | if module.bias is not None:
130 | module.bias.data.zero_()
131 | elif isinstance(module, SinusoidalPositionalEmbedding):
132 | pass
133 | elif isinstance(module, nn.Embedding):
134 | module.weight.data.normal_(mean=0.0, std=std)
135 | if module.padding_idx is not None:
136 | module.weight.data[module.padding_idx].zero_()
137 |
138 | @property
139 | def dummy_inputs(self):
140 | pad_token = self.config.pad_token_id
141 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
142 | dummy_inputs = {
143 | "attention_mask": input_ids.ne(pad_token),
144 | "input_ids": input_ids,
145 | }
146 | return dummy_inputs
147 |
148 |
149 | def _make_linear_from_emb(emb):
150 | vocab_size, emb_size = emb.weight.shape
151 | lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
152 | lin_layer.weight.data = emb.weight.data
153 | return lin_layer
154 |
155 |
156 | # Helper Functions, mostly for making masks
157 | def _check_shapes(shape_1, shape2):
158 | if shape_1 != shape2:
159 | raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
160 |
161 |
162 | def shift_tokens_right(input_ids, pad_token_id):
163 | """Shift input ids one token to the right, and wrap the last non pad token (usually )."""
164 | prev_output_tokens = input_ids.clone()
165 | index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
166 | prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
167 | prev_output_tokens[:, 1:] = input_ids[:, :-1]
168 | return prev_output_tokens
169 |
170 |
171 | def make_padding_mask(input_ids, padding_idx=1):
172 | """True for pad tokens"""
173 | padding_mask = input_ids.eq(padding_idx)
174 | if not padding_mask.any():
175 | padding_mask = None
176 | return padding_mask
177 |
178 |
179 | # Helper Modules
180 |
181 |
182 | class EncoderLayer(nn.Module):
183 | def __init__(self, config: BartConfig):
184 | super().__init__()
185 | self.embed_dim = config.d_model
186 | self.output_attentions = config.output_attentions
187 | self.self_attn = SelfAttention(
188 | self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
189 | )
190 | self.normalize_before = config.normalize_before
191 | self.self_attn_layer_norm = LayerNorm(self.embed_dim)
192 | self.dropout = config.dropout
193 | self.activation_fn = ACT2FN[config.activation_function]
194 | self.activation_dropout = config.activation_dropout
195 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
196 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
197 | self.final_layer_norm = LayerNorm(self.embed_dim)
198 |
199 | def forward(self, x, encoder_padding_mask):
200 | """
201 | Args:
202 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
203 | encoder_padding_mask (ByteTensor): binary ByteTensor of shape
204 | `(batch, src_len)` where padding elements are indicated by ``1``.
205 |                 Positions marked ``1`` are excluded (masked out) from attention,
206 |                 while positions marked ``0`` are included in attention.
207 |
208 | Returns:
209 | encoded output of shape `(seq_len, batch, embed_dim)`
210 | """
211 | residual = x
212 | if self.normalize_before:
213 | x = self.self_attn_layer_norm(x)
214 | x, attn_weights = self.self_attn(
215 | query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
216 | )
217 | x = F.dropout(x, p=self.dropout, training=self.training)
218 | x = residual + x
219 | if not self.normalize_before:
220 | x = self.self_attn_layer_norm(x)
221 |
222 | residual = x
223 | if self.normalize_before:
224 | x = self.final_layer_norm(x)
225 | x = self.activation_fn(self.fc1(x))
226 | x = F.dropout(x, p=self.activation_dropout, training=self.training)
227 | x = self.fc2(x)
228 | x = F.dropout(x, p=self.dropout, training=self.training)
229 | x = residual + x
230 | if not self.normalize_before:
231 | x = self.final_layer_norm(x)
232 | return (x, attn_weights)
233 |
234 |
235 | class BartEncoder(nn.Module):
236 | """
237 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
238 | is a :class:`EncoderLayer`.
239 |
240 | Args:
241 | config: BartConfig
242 | """
243 |
244 | def __init__(self, config: BartConfig, embed_tokens):
245 | super().__init__()
246 | self.config = config
247 | self.dropout = config.dropout
248 | self.layerdrop = config.encoder_layerdrop
249 | self.output_attentions = config.output_attentions
250 | self.output_hidden_states = config.output_hidden_states
251 |
252 | embed_dim = embed_tokens.embedding_dim
253 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
254 | self.padding_idx = embed_tokens.padding_idx
255 | self.max_source_positions = config.encoder_max_position_embeddings
256 |
257 | self.embed_tokens = embed_tokens
258 | if config.static_position_embeddings:
259 | self.embed_positions = SinusoidalPositionalEmbedding(
260 | config.encoder_max_position_embeddings, embed_dim, self.padding_idx
261 | )
262 | else:
263 | self.embed_positions = LearnedPositionalEmbedding(
264 | config.encoder_max_position_embeddings, embed_dim, self.padding_idx,
265 | )
266 | self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
267 | self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
268 | # mbart has one extra layer_norm
269 | self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
270 |
271 | def forward(
272 | self, input_ids, attention_mask=None,
273 | ):
274 | """
275 | Args:
276 | input_ids (LongTensor): tokens in the source language of shape
277 | `(batch, src_len)`
278 | attention_mask (torch.LongTensor): indicating which indices are padding tokens.
279 | Returns:
280 | Tuple comprised of:
281 | - **x** (Tensor): the last encoder layer's output of
282 |                   shape `(batch, src_len, embed_dim)`
283 | - **encoder_states** (List[Tensor]): all intermediate
284 |                   hidden states of shape `(batch, src_len, embed_dim)`.
285 | Only populated if *self.output_hidden_states:* is True.
286 | - **all_attentions** (List[Tensor]): Attention weights for each layer.
287 | During training might not be of length n_layers because of layer dropout.
288 | """
289 | # check attention mask and invert
290 | if attention_mask is not None:
291 | attention_mask = invert_mask(attention_mask)
292 |
293 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
294 | embed_pos = self.embed_positions(input_ids)
295 | x = inputs_embeds + embed_pos
296 | x = self.layernorm_embedding(x)
297 | x = F.dropout(x, p=self.dropout, training=self.training)
298 |
299 | # B x T x C -> T x B x C
300 | x = x.transpose(0, 1)
301 |
302 | encoder_states, all_attentions = [], []
303 | for encoder_layer in self.layers:
304 | if self.output_hidden_states:
305 | encoder_states.append(x)
306 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
307 | dropout_probability = random.uniform(0, 1)
308 | if self.training and (dropout_probability < self.layerdrop): # skip the layer
309 | attn = None
310 | else:
311 | if getattr(self.config, "gradient_checkpointing", False):
312 | x, attn = torch.utils.checkpoint.checkpoint(
313 | encoder_layer,
314 | x,
315 | attention_mask
316 | )
317 | else:
318 | x, attn = encoder_layer(x, attention_mask)
319 |
320 |
321 | if self.output_attentions:
322 | all_attentions.append(attn)
323 |
324 | if self.layer_norm:
325 | x = self.layer_norm(x)
326 | if self.output_hidden_states:
327 | encoder_states.append(x)
328 |
329 | # T x B x C -> B x T x C
330 | encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
331 | x = x.transpose(0, 1)
332 |
333 | return x, encoder_states, all_attentions
334 |
335 |
336 | class DecoderLayer(nn.Module):
337 | def __init__(self, config: BartConfig):
338 | super().__init__()
339 | self.embed_dim = config.d_model
340 | self.self_attn = SelfAttention(
341 | embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
342 | )
343 | self.dropout = config.dropout
344 | self.activation_fn = ACT2FN[config.activation_function]
345 | self.activation_dropout = config.activation_dropout
346 | self.normalize_before = config.normalize_before
347 |
348 | self.self_attn_layer_norm = LayerNorm(self.embed_dim)
349 | self.encoder_attn = SelfAttention(
350 | self.embed_dim,
351 | config.decoder_attention_heads,
352 | dropout=config.attention_dropout,
353 | encoder_decoder_attention=True,
354 | )
355 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
356 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
357 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
358 | self.final_layer_norm = LayerNorm(self.embed_dim)
359 |
360 | def forward(
361 | self,
362 | x,
363 | encoder_hidden_states,
364 | encoder_attn_mask=None,
365 | layer_state=None,
366 | causal_mask=None,
367 | decoder_padding_mask=None,
368 | output_attentions=False,
369 | ):
370 | residual = x
371 |
372 | if layer_state is None:
373 | layer_state = {}
374 | if self.normalize_before:
375 | x = self.self_attn_layer_norm(x)
376 | # Self Attention
377 |
378 | x, self_attn_weights = self.self_attn(
379 | query=x,
380 | key=x,
381 | layer_state=layer_state, # adds keys to layer state
382 | key_padding_mask=decoder_padding_mask,
383 | attn_mask=causal_mask,
384 | )
385 | x = F.dropout(x, p=self.dropout, training=self.training)
386 | x = residual + x
387 | if not self.normalize_before:
388 | x = self.self_attn_layer_norm(x)
389 |
390 | # Cross attention
391 | residual = x
392 | assert self.encoder_attn.cache_key != self.self_attn.cache_key
393 | if self.normalize_before:
394 | x = self.encoder_attn_layer_norm(x)
395 | x, _ = self.encoder_attn(
396 | query=x,
397 | key=encoder_hidden_states,
398 | key_padding_mask=encoder_attn_mask,
399 | layer_state=layer_state, # mutates layer state
400 | )
401 | x = F.dropout(x, p=self.dropout, training=self.training)
402 | x = residual + x
403 | if not self.normalize_before:
404 | x = self.encoder_attn_layer_norm(x)
405 |
406 | # Fully Connected
407 | residual = x
408 | if self.normalize_before:
409 | x = self.final_layer_norm(x)
410 | x = self.activation_fn(self.fc1(x))
411 | x = F.dropout(x, p=self.activation_dropout, training=self.training)
412 | x = self.fc2(x)
413 | x = F.dropout(x, p=self.dropout, training=self.training)
414 | x = residual + x
415 | if not self.normalize_before:
416 | x = self.final_layer_norm(x)
417 | return (
418 | x,
419 | self_attn_weights,
420 | layer_state,
421 | ) # just self_attn weights for now, following t5, layer_state = cache for decoding
422 |
423 |
424 | class BartDecoder(nn.Module):
425 | """
426 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer
427 | is a :class:`DecoderLayer`.
428 | Args:
429 | config: BartConfig
430 | embed_tokens (torch.nn.Embedding): output embedding
431 | """
432 |
433 | def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
434 | super().__init__()
435 | self.output_hidden_states = config.output_hidden_states
436 | self.dropout = config.dropout
437 | self.layerdrop = config.decoder_layerdrop
438 | self.padding_idx = embed_tokens.padding_idx
439 | self.max_target_positions = config.max_position_embeddings
440 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
441 | self.embed_tokens = embed_tokens
442 | if config.static_position_embeddings:
443 | self.embed_positions = SinusoidalPositionalEmbedding(
444 | config.max_position_embeddings, config.d_model, config.pad_token_id
445 | )
446 | else:
447 | self.embed_positions = LearnedPositionalEmbedding(
448 | config.max_position_embeddings, config.d_model, self.padding_idx,
449 | )
450 | self.layers = nn.ModuleList(
451 | [DecoderLayer(config) for _ in range(config.decoder_layers)]
452 | ) # type: List[DecoderLayer]
453 | self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
454 | self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
455 |
456 | def forward(
457 | self,
458 | input_ids,
459 | encoder_hidden_states,
460 | encoder_padding_mask,
461 | decoder_padding_mask,
462 | decoder_causal_mask,
463 | decoder_cached_states=None,
464 | use_cache=False,
465 | output_attentions=False,
466 | **unused,
467 | ):
468 | """
469 | Includes several features from "Jointly Learning to Align and
470 | Translate with Transformer Models" (Garg et al., EMNLP 2019).
471 |
472 | Args:
473 | input_ids (LongTensor): previous decoder outputs of shape
474 | `(batch, tgt_len)`, for teacher forcing
475 | encoder_hidden_states: output from the encoder, used for
476 | encoder-side attention
477 | encoder_padding_mask: for ignoring pad tokens
478 | decoder_cached_states (dict or None): dictionary used for storing state during generation
479 |
480 | Returns:
481 | tuple:
482 | - the decoder's features of shape `(batch, tgt_len, embed_dim)`
483 | - hidden states
484 | - attentions
485 | """
486 | # check attention mask and invert
487 | if encoder_padding_mask is not None:
488 | encoder_padding_mask = invert_mask(encoder_padding_mask)
489 |
490 | # embed positions
491 | positions = self.embed_positions(input_ids, use_cache=use_cache)
492 |
493 | if use_cache:
494 | input_ids = input_ids[:, -1:]
495 | positions = positions[:, -1:] # happens after we embed them
496 | # assert input_ids.ne(self.padding_idx).any()
497 |
498 | x = self.embed_tokens(input_ids) * self.embed_scale
499 | x += positions
500 | x = self.layernorm_embedding(x)
501 | x = F.dropout(x, p=self.dropout, training=self.training)
502 |
503 |         # Convert to the T x B x C layout used by the decoder layers: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
504 | x = x.transpose(0, 1)
505 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
506 |
507 | # decoder layers
508 | all_hidden_states = ()
509 | all_self_attns = ()
510 | next_decoder_cache = []
511 | for idx, decoder_layer in enumerate(self.layers):
512 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
513 | if self.output_hidden_states:
514 | all_hidden_states += (x,)
515 | dropout_probability = random.uniform(0, 1)
516 | if self.training and (dropout_probability < self.layerdrop):
517 | continue
518 |
519 | layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
520 |
521 | x, layer_self_attn, layer_past = decoder_layer(
522 | x,
523 | encoder_hidden_states,
524 | encoder_attn_mask=encoder_padding_mask,
525 | decoder_padding_mask=decoder_padding_mask,
526 | layer_state=layer_state,
527 | causal_mask=decoder_causal_mask,
528 | output_attentions=output_attentions,
529 | )
530 |
531 | if use_cache:
532 | next_decoder_cache.append(layer_past.copy())
533 |
534 | if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
535 | x = self.layer_norm(x)
536 | if output_attentions:
537 | all_self_attns += (layer_self_attn,)
538 |
539 | # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
540 | all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
541 | x = x.transpose(0, 1)
542 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
543 |
544 | if use_cache:
545 | next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
546 | else:
547 | next_cache = None
548 | return x, next_cache, all_hidden_states, list(all_self_attns)
549 |
550 |
551 | def _reorder_buffer(attn_cache, new_order):
552 | for k, input_buffer_k in attn_cache.items():
553 | if input_buffer_k is not None:
554 | attn_cache[k] = input_buffer_k.index_select(0, new_order)
555 | return attn_cache
556 |
557 |
558 | class SelfAttention(nn.Module):
559 | """Multi-headed attention from 'Attention Is All You Need' paper"""
560 |
561 | def __init__(
562 | self,
563 | embed_dim,
564 | num_heads,
565 | dropout=0.0,
566 | bias=True,
567 | encoder_decoder_attention=False, # otherwise self_attention
568 | ):
569 | super().__init__()
570 | self.embed_dim = embed_dim
571 | self.num_heads = num_heads
572 | self.dropout = dropout
573 | self.head_dim = embed_dim // num_heads
574 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
575 | self.scaling = self.head_dim ** -0.5
576 |
577 | self.encoder_decoder_attention = encoder_decoder_attention
578 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
579 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
580 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
581 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
582 | self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
583 |
584 | def _shape(self, tensor, dim_0, bsz):
585 | return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)
586 |
587 | def forward(
588 | self,
589 | query,
590 | key: Optional[Tensor],
591 | key_padding_mask: Optional[Tensor] = None,
592 | layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
593 | attn_mask: Optional[Tensor] = None,
594 | need_weights=False,
595 | ) -> Tuple[Tensor, Optional[Tensor]]:
596 | """Input shape: Time(SeqLen) x Batch x Channel"""
597 | static_kv: bool = self.encoder_decoder_attention
598 | tgt_len, bsz, embed_dim = query.size()
599 | assert embed_dim == self.embed_dim
600 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
601 |         # reached for encoder-decoder attention because key/value are static (static_kv)
602 | if layer_state is not None: # reuse k,v and encoder_padding_mask
603 | saved_state = layer_state.get(self.cache_key, {})
604 | if "prev_key" in saved_state:
605 | # previous time steps are cached - no need to recompute key and value if they are static
606 | if static_kv:
607 | key = None
608 | else:
609 | saved_state = None
610 | layer_state = {}
611 |
612 | q = self.q_proj(query) * self.scaling
613 | if static_kv:
614 | if key is None:
615 | k = v = None
616 | else:
617 | k = self.k_proj(key)
618 | v = self.v_proj(key)
619 | else:
620 | k = self.k_proj(query)
621 | v = self.v_proj(query)
622 |
623 | q = self._shape(q, tgt_len, bsz)
624 | if k is not None:
625 | k = self._shape(k, -1, bsz)
626 | if v is not None:
627 | v = self._shape(v, -1, bsz)
628 |
629 | if saved_state is not None:
630 | k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
631 |
632 | # Update cache
633 | layer_state[self.cache_key] = {
634 | "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
635 | "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
636 | "prev_key_padding_mask": key_padding_mask if not static_kv else None,
637 | }
638 |
639 | assert k is not None
640 | src_len = k.size(1)
641 | attn_weights = torch.bmm(q, k.transpose(1, 2))
642 | assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
643 |
644 | if attn_mask is not None:
645 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
646 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
647 |
648 | # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
649 | if key_padding_mask is not None and key_padding_mask.dim() == 0:
650 | key_padding_mask = None
651 | assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
652 |
653 | if key_padding_mask is not None: # don't attend to padding symbols
654 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
655 | reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
656 | attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
657 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
658 | attn_weights = F.softmax(attn_weights, dim=-1)
659 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
660 |
661 | assert v is not None
662 | attn_output = torch.bmm(attn_probs, v)
663 | assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
664 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
665 | attn_output = self.out_proj(attn_output)
666 | if need_weights:
667 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
668 | else:
669 | attn_weights = None
670 | return attn_output, attn_weights
671 |
672 | def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
673 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
674 | if "prev_key" in saved_state:
675 | _prev_key = saved_state["prev_key"]
676 | assert _prev_key is not None
677 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
678 | if static_kv:
679 | k = prev_key
680 | else:
681 | assert k is not None
682 | k = torch.cat([prev_key, k], dim=1)
683 | if "prev_value" in saved_state:
684 | _prev_value = saved_state["prev_value"]
685 | assert _prev_value is not None
686 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
687 | if static_kv:
688 | v = prev_value
689 | else:
690 | assert v is not None
691 | v = torch.cat([prev_value, v], dim=1)
692 | assert k is not None and v is not None
693 | prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
694 | key_padding_mask = self._cat_prev_key_padding_mask(
695 | key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
696 | )
697 | return k, v, key_padding_mask
698 |
699 | @staticmethod
700 | def _cat_prev_key_padding_mask(
701 | key_padding_mask: Optional[Tensor],
702 | prev_key_padding_mask: Optional[Tensor],
703 | batch_size: int,
704 | src_len: int,
705 | static_kv: bool,
706 | ) -> Optional[Tensor]:
707 | # saved key padding masks have shape (bsz, seq_len)
708 | if prev_key_padding_mask is not None:
709 | if static_kv:
710 | new_key_padding_mask = prev_key_padding_mask
711 | else:
712 | new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
713 |
714 | elif key_padding_mask is not None:
715 | filler = torch.zeros(
716 | batch_size,
717 | src_len - key_padding_mask.size(1),
718 | dtype=key_padding_mask.dtype,
719 | device=key_padding_mask.device,
720 | )
721 | new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
722 | else:
723 | new_key_padding_mask = prev_key_padding_mask
724 | return new_key_padding_mask
725 |
726 |
727 | class BartClassificationHead(nn.Module):
728 | """Head for sentence-level classification tasks."""
729 |
730 | # This can trivially be shared with RobertaClassificationHead
731 |
732 | def __init__(
733 | self, input_dim, inner_dim, num_classes, pooler_dropout,
734 | ):
735 | super().__init__()
736 | self.dense = nn.Linear(input_dim, inner_dim)
737 | self.dropout = nn.Dropout(p=pooler_dropout)
738 | self.out_proj = nn.Linear(inner_dim, num_classes)
739 |
740 | def forward(self, x):
741 | x = self.dropout(x)
742 | x = self.dense(x)
743 | x = torch.tanh(x)
744 | x = self.dropout(x)
745 | x = self.out_proj(x)
746 | return x
747 |
748 |
749 | class LearnedPositionalEmbedding(nn.Embedding):
750 | """
751 | This module learns positional embeddings up to a fixed maximum size.
752 | Padding ids are ignored by either offsetting based on padding_idx
753 | or by setting padding_idx to None and ensuring that the appropriate
754 | position ids are passed to the forward function.
755 | """
756 |
757 | def __init__(
758 | self, num_embeddings: int, embedding_dim: int, padding_idx: int,
759 | ):
760 | # if padding_idx is specified then offset the embedding ids by
761 | # this index and adjust num_embeddings appropriately
762 | assert padding_idx is not None
763 |         num_embeddings += padding_idx + 1  # position ids start at padding_idx + 1, so the table needs padding_idx + 1 extra rows
764 | super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
765 |
766 | def forward(self, input, use_cache=False):
767 | """Input is expected to be of size [bsz x seqlen]."""
768 | if use_cache: # the position is our current step in the decoded sequence
769 | pos = int(self.padding_idx + input.size(1))
770 | positions = input.data.new(1, 1).fill_(pos)
771 | else:
772 | positions = create_position_ids_from_input_ids(input, self.padding_idx)
773 | return super().forward(positions)
774 |
775 |
776 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
777 | if torch.cuda.is_available():
778 | try:
779 | from apex.normalization import FusedLayerNorm
780 |
781 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
782 | except ImportError:
783 | pass
784 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
785 |
786 |
787 | def fill_with_neg_inf(t):
788 | """FP16-compatible function that fills a input_ids with -inf."""
789 | return t.float().fill_(float("-inf")).type_as(t)
790 |
791 |
792 | def _filter_out_falsey_values(tup) -> Tuple:
793 | """Remove entries that are None or [] from an iterable."""
794 | return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
795 |
796 |
797 | # Public API
798 | def _get_shape(t):
799 | return getattr(t, "shape", None)
800 |
801 |
802 | @add_start_docstrings(
803 | "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
804 | )
805 | class BartModel(PretrainedBartModel):
806 | def __init__(self, config: BartConfig):
807 | super().__init__(config)
808 | self.output_attentions = config.output_attentions
809 | self.output_hidden_states = config.output_hidden_states
810 |
811 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size
812 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
813 |
814 | self.encoder = BartEncoder(config, self.shared)
815 | self.decoder = BartDecoder(config, self.shared)
816 |
817 | self.init_weights()
818 |
819 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
820 | def forward(
821 | self,
822 | input_ids,
823 | attention_mask=None,
824 | decoder_input_ids=None,
825 | encoder_outputs: Optional[Tuple] = None,
826 | decoder_attention_mask=None,
827 | decoder_cached_states=None,
828 | use_cache=False,
829 | ):
830 |
831 | # make masks if user doesn't supply
832 | if not use_cache:
833 | decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
834 | self.config,
835 | input_ids,
836 | decoder_input_ids=decoder_input_ids,
837 | decoder_padding_mask=decoder_attention_mask,
838 | causal_mask_dtype=self.shared.weight.dtype,
839 | )
840 | else:
841 | decoder_padding_mask, causal_mask = None, None
842 |
843 | assert decoder_input_ids is not None
844 | if encoder_outputs is None:
845 | encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
846 | assert isinstance(encoder_outputs, tuple)
847 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
848 | decoder_outputs = self.decoder(
849 | decoder_input_ids,
850 | encoder_outputs[0],
851 | attention_mask,
852 | decoder_padding_mask,
853 | decoder_causal_mask=causal_mask,
854 | decoder_cached_states=decoder_cached_states,
855 | use_cache=use_cache,
856 | )
857 | # Attention and hidden_states will be [] or None if they aren't needed
858 | decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
859 | assert isinstance(decoder_outputs[0], torch.Tensor)
860 | encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
861 | return decoder_outputs + encoder_outputs
862 |
863 | def get_input_embeddings(self):
864 | return self.shared
865 |
866 | def set_input_embeddings(self, value):
867 | self.shared = value
868 | self.encoder.embed_tokens = self.shared
869 | self.decoder.embed_tokens = self.shared
870 |
871 | def get_output_embeddings(self):
872 | return _make_linear_from_emb(self.shared) # make it on the fly
873 |
874 |
875 | @add_start_docstrings(
876 | "The BART Model with a language modeling head. Can be used for summarization.",
877 | BART_START_DOCSTRING + BART_GENERATION_EXAMPLE,
878 | )
879 | class BartForConditionalGeneration(PretrainedBartModel):
880 | base_model_prefix = "model"
881 |
882 | def __init__(self, config: BartConfig):
883 | super().__init__(config)
884 | base_model = BartModel(config)
885 | self.model = base_model
886 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
887 |
888 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
889 | old_num_tokens = self.model.shared.num_embeddings
890 | new_embeddings = super().resize_token_embeddings(new_num_tokens)
891 | self.model.shared = new_embeddings
892 | self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
893 | return new_embeddings
894 |
895 | def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
896 | if new_num_tokens <= old_num_tokens:
897 | new_bias = self.final_logits_bias[:, :new_num_tokens]
898 | else:
899 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
900 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
901 | self.register_buffer("final_logits_bias", new_bias)
902 |
903 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
904 | def forward(
905 | self,
906 | input_ids,
907 | attention_mask=None,
908 | encoder_outputs=None,
909 | decoder_input_ids=None,
910 | decoder_attention_mask=None,
911 | decoder_cached_states=None,
912 | lm_labels=None,
913 | use_cache=False,
914 | **unused
915 | ):
916 | r"""
917 | lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
918 | Labels for computing the masked language modeling loss.
919 | Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
920 | Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
921 | with labels
922 | in ``[0, ..., config.vocab_size]``.
923 |
924 | Returns:
925 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
926 | masked_lm_loss (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
927 | Masked language modeling loss.
928 | prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
929 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
930 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
931 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
932 | of shape :obj:`(batch_size, sequence_length, hidden_size)`.
933 |
934 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
935 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
936 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
937 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
938 |
939 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
940 | heads.
941 |
942 | Examples::
943 |
944 | # Mask filling only works for bart-large
945 | from transformers import BartTokenizer, BartForConditionalGeneration
946 | tokenizer = BartTokenizer.from_pretrained('bart-large')
947 |             TXT = "My friends are <mask> but they eat too many carbs."
948 | model = BartForConditionalGeneration.from_pretrained('bart-large')
949 | input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
950 | logits = model(input_ids)[0]
951 | masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
952 | probs = logits[0, masked_index].softmax(dim=0)
953 | values, predictions = probs.topk(5)
954 | tokenizer.decode(predictions).split()
955 | # ['good', 'great', 'all', 'really', 'very']
956 | """
957 | outputs = self.model(
958 | input_ids,
959 | attention_mask=attention_mask,
960 | decoder_input_ids=decoder_input_ids,
961 | encoder_outputs=encoder_outputs,
962 | decoder_attention_mask=decoder_attention_mask,
963 | decoder_cached_states=decoder_cached_states,
964 | use_cache=use_cache,
965 | )
966 | lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
967 | outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
968 | if lm_labels is not None:
969 | loss_fct = nn.CrossEntropyLoss()
970 | # TODO(SS): do we need to ignore pad tokens in lm_labels?
971 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1))
972 | outputs = (masked_lm_loss,) + outputs
973 |
974 | return outputs
975 |
976 | def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
977 | assert past is not None, "past has to be defined for encoder_outputs"
978 |
979 | # first step, decoder_cached_states are empty
980 | if not past[1]:
981 | encoder_outputs, decoder_cached_states = past, None
982 | else:
983 | encoder_outputs, decoder_cached_states = past
984 | return {
985 | "input_ids": None, # encoder_outputs is defined. input_ids not needed
986 | "encoder_outputs": encoder_outputs,
987 | "decoder_cached_states": decoder_cached_states,
988 | "decoder_input_ids": decoder_input_ids,
989 | "attention_mask": attention_mask,
990 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
991 | }
992 |
993 | def prepare_logits_for_generation(self, logits, cur_len, max_length):
994 | if cur_len == 1:
995 | self._force_token_ids_generation(logits, self.config.bos_token_id)
996 | if cur_len == max_length - 1 and self.config.eos_token_id is not None:
997 | self._force_token_ids_generation(logits, self.config.eos_token_id)
998 | return logits
999 |
1000 | def _force_token_ids_generation(self, scores, token_ids) -> None:
1001 | """force one of token_ids to be generated by setting prob of all other tokens to 0"""
1002 | if isinstance(token_ids, int):
1003 | token_ids = [token_ids]
1004 | all_but_token_ids_mask = torch.tensor(
1005 | [x for x in range(self.config.vocab_size) if x not in token_ids],
1006 | dtype=torch.long,
1007 | device=next(self.parameters()).device,
1008 | )
1009 | assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
1010 | scores[:, all_but_token_ids_mask] = -float("inf")
1011 |
1012 | @staticmethod
1013 | def _reorder_cache(past, beam_idx):
1014 | ((enc_out, enc_mask), decoder_cached_states) = past
1015 | reordered_past = []
1016 | for layer_past in decoder_cached_states:
1017 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn
1018 | layer_past_new = {
1019 | attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
1020 | }
1021 | reordered_past.append(layer_past_new)
1022 |
1023 | new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
1024 | new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
1025 |
1026 | past = ((new_enc_out, new_enc_mask), reordered_past)
1027 | return past
1028 |
1029 | def get_encoder(self):
1030 | return self.model.encoder
1031 |
1032 | def get_output_embeddings(self):
1033 | return _make_linear_from_emb(self.model.shared) # make it on the fly
1034 |
1035 |
1036 | @add_start_docstrings(
1037 | """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
1038 | BART_START_DOCSTRING,
1039 | )
1040 | class BartForSequenceClassification(PretrainedBartModel):
1041 | def __init__(self, config: BartConfig, **kwargs):
1042 | super().__init__(config, **kwargs)
1043 | self.model = BartModel(config)
1044 | self.classification_head = BartClassificationHead(
1045 | config.d_model, config.d_model, config.num_labels, config.classif_dropout,
1046 | )
1047 | self.model._init_weights(self.classification_head.dense)
1048 | self.model._init_weights(self.classification_head.out_proj)
1049 |
1050 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1051 | def forward(
1052 | self,
1053 | input_ids,
1054 | attention_mask=None,
1055 | encoder_outputs=None,
1056 | decoder_input_ids=None,
1057 | decoder_attention_mask=None,
1058 | labels=None,
1059 | ):
1060 | r"""
1061 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1062 | Labels for computing the sequence classification/regression loss.
1063 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1064 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1065 |
1066 | Returns:
1067 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
1068 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
1069 | Classification loss (cross entropy)
1070 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
1071 | Classification (or regression if config.num_labels==1) scores (before SoftMax).
1072 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
1073 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1074 | of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1075 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1076 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
1077 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1078 | Attentions weights after the attention softmax, used to compute the weighted average in the
1079 | self-attention
1080 | heads.
1081 |
1082 | Examples::
1083 |
1084 | from transformers import BartTokenizer, BartForSequenceClassification
1085 | import torch
1086 |
1087 | tokenizer = BartTokenizer.from_pretrained('bart-large')
1088 | model = BartForSequenceClassification.from_pretrained('bart-large')
1089 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
1090 | add_special_tokens=True)).unsqueeze(0) # Batch size 1
1091 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
1092 | outputs = model(input_ids, labels=labels)
1093 | loss, logits = outputs[:2]
1094 |
1095 | """
1096 | outputs = self.model(
1097 | input_ids,
1098 | attention_mask=attention_mask,
1099 | decoder_input_ids=decoder_input_ids,
1100 | decoder_attention_mask=decoder_attention_mask,
1101 | encoder_outputs=encoder_outputs,
1102 | )
1103 | x = outputs[0] # last hidden state
1104 | eos_mask = input_ids.eq(self.config.eos_token_id)
1105 | if len(torch.unique(eos_mask.sum(1))) > 1:
1106 | raise ValueError("All examples must have the same number of tokens.")
1107 | sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
1108 | logits = self.classification_head(sentence_representation)
1109 | # Prepend logits
1110 | outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
1111 | if labels is not None: # prepend loss to output,
1112 | loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
1113 | outputs = (loss,) + outputs
1114 |
1115 | return outputs
1116 |
1117 |
1118 | class SinusoidalPositionalEmbedding(nn.Embedding):
1119 | """This module produces sinusoidal positional embeddings of any length."""
1120 |
1121 | def __init__(self, num_positions, embedding_dim, padding_idx=None):
1122 | super().__init__(num_positions, embedding_dim)
1123 | if embedding_dim % 2 != 0:
1124 | raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
1125 | self.weight = self._init_weight(self.weight)
1126 |
1127 | @staticmethod
1128 | def _init_weight(out: nn.Parameter):
1129 | """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
1130 | The cos features are in the 2nd half of the vector. [dim // 2:]
1131 | """
1132 | n_pos, dim = out.shape
1133 | position_enc = np.array(
1134 | [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
1135 | )
1136 | out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
1137 | out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1138 | out.detach_()
1139 | out.requires_grad = False
1140 | return out
1141 |
1142 | @torch.no_grad()
1143 | def forward(self, input_ids, use_cache=False):
1144 | """Input is expected to be of size [bsz x seqlen]."""
1145 | bsz, seq_len = input_ids.shape[:2]
1146 | if use_cache:
1147 | positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
1148 | else:
1149 |             # starts at 0, ends at seq_len - 1
1150 | positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
1151 | return super().forward(positions)
--------------------------------------------------------------------------------
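To make the decoder-input convention used above concrete, here is a small sketch of `shift_tokens_right` (token ids are arbitrary and 1 is used as the pad id, as in BART; the package is assumed to be importable as `longbart`): the last non-pad token, usually </s>, is wrapped around to position 0 and the remaining tokens are shifted one step to the right.

import torch

from longbart.modeling_bart import shift_tokens_right

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])   # <s> ... </s> <pad> <pad>
decoder_input_ids = shift_tokens_right(input_ids, 1)   # pad_token_id = 1
print(decoder_input_ids)
# tensor([[    2,     0, 31414,   232,     2,     1]])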
/longbart/modeling_longbart.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 |
3 | from torch import Tensor, nn
4 |
5 | from transformers.modeling_longformer import LongformerSelfAttention
6 |
7 | from .modeling_bart import BartForConditionalGeneration
8 |
9 | class LongBartForConditionalGeneration(BartForConditionalGeneration):
10 | def __init__(self, config):
11 | super().__init__(config)
12 | for i, layer in enumerate(self.model.encoder.layers):
13 | # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
14 | layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
15 |
16 |
17 | class LongformerSelfAttentionForBart(nn.Module):
18 | def __init__(self, config, layer_id):
19 | super().__init__()
20 | self.embed_dim = config.d_model
21 | self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
22 | self.output = nn.Linear(self.embed_dim, self.embed_dim)
23 |
24 | def forward(
25 | self,
26 | query,
27 | key: Optional[Tensor],
28 | key_padding_mask: Optional[Tensor] = None,
29 | layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
30 | attn_mask: Optional[Tensor] = None,
31 | need_weights=False,
32 | ) -> Tuple[Tensor, Optional[Tensor]]:
33 |
34 | tgt_len, bsz, embed_dim = query.size()
35 | assert embed_dim == self.embed_dim
36 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
37 |
38 |         # LongformerSelfAttention expects (bsz, seq_len, embed_dim); transpose here (a .view() would scramble the batch and time dimensions)
39 |         query = query.transpose(0, 1)
40 |
41 | outputs = self.longformer_self_attn(
42 | query,
43 | attention_mask=attn_mask,
44 | head_mask=None,
45 | encoder_hidden_states=None,
46 | encoder_attention_mask=None,
47 | )
48 |
49 | attn_output = outputs[0]
50 |         attn_output = attn_output.transpose(0, 1)  # back to (tgt_len, bsz, embed_dim)
51 | attn_output = self.output(attn_output)
52 |
53 |         return ((attn_output,) + outputs[1:]) if len(outputs) == 2 else (attn_output, None)
54 |
--------------------------------------------------------------------------------
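A rough end-to-end sketch of running summarization with a converted checkpoint (the checkpoint path is hypothetical and should point at a directory produced by `convert_bart_to_longbart.py`; the input is padded out to the full 4096 positions so the sequence length stays compatible with the chunked sliding-window attention):

from transformers import BartTokenizer

from longbart.modeling_longbart import LongBartForConditionalGeneration

model_path = 'longbart-large-4096'                     # hypothetical conversion output
tokenizer = BartTokenizer.from_pretrained(model_path)
model = LongBartForConditionalGeneration.from_pretrained(model_path)

long_article = ' '.join(['Long documents need sparse attention.'] * 600)
inputs = tokenizer.batch_encode_plus(
    [long_article], max_length=4096, pad_to_max_length=True, return_tensors='pt'
)
summary_ids = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    num_beams=4,
    max_length=142,
    early_stopping=True,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))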
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='longbart',
5 | version='0.1',
6 | description='Long version of the BART model',
7 | url='https://github.com/patil-suraj/longbart',
8 | author='Suraj Patil',
9 | author_email='surajp815@gmail.com',
10 | packages=['longbart'],
11 | keywords="NLP deep learning transformer pytorch bart",
12 | install_requires=[
13 | 'transformers == 2.11.0'
14 | ],
15 | python_requires=">=3.6.0",
16 | classifiers=[
17 | "Intended Audience :: Developers",
18 | "Intended Audience :: Education",
19 | "Intended Audience :: Science/Research",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | "Programming Language :: Python :: 3",
23 | "Programming Language :: Python :: 3.6",
24 | "Programming Language :: Python :: 3.7",
25 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 | ],
27 | zip_safe=False
28 | )
--------------------------------------------------------------------------------