├── MANIFEST.in
├── assets
│   ├── tableQA.gif
│   └── ezgif.com-gif-maker.gif:Zone.Identifier
├── requirements.txt
├── src
│   └── ratransformers
│       ├── roberta.py
│       ├── gpt2.py
│       ├── bert.py
│       ├── bart.py
│       ├── __init__.py
│       ├── t5.py
│       └── longt5.py
├── setup.py
├── .github
│   └── workflows
│       └── build_and_test.yml
├── LICENSE
├── tests
│   └── test_ratransformer_forward.py
├── .gitignore
├── README.md
└── notebooks
    ├── NER_conll_example.ipynb
    ├── TableQA_tabfact_example.ipynb
    └── text2sql_spider_example.ipynb
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include LICENSE
3 | include README.md
4 |
--------------------------------------------------------------------------------
/assets/tableQA.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoaoLages/RATransformers/HEAD/assets/tableQA.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.19.1
2 | transformers>=4.6.1
3 | setuptools>=49.6.0
4 | torch>=1.9.1
5 |
--------------------------------------------------------------------------------
/assets/ezgif.com-gif-maker.gif:Zone.Identifier:
--------------------------------------------------------------------------------
1 | [ZoneTransfer]
2 | ZoneId=3
3 | ReferrerUrl=https://ezgif.com/
4 | HostUrl=https://s7.ezgif.com/save/ezgif-7-1c4c0d0f22.gif
5 |
--------------------------------------------------------------------------------
/src/ratransformers/roberta.py:
--------------------------------------------------------------------------------
1 | from ratransformers.bert import BertRelationalSelfAttention
2 |
3 |
4 | class RobertaRelationalSelfAttention(BertRelationalSelfAttention):
5 |     # RoBERTa's self-attention is identical to BERT's, so the relational BERT implementation is reused unchanged.
6 |     pass
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | with open('README.md', encoding='utf-8') as f:
5 | long_description = f.read()
6 |
7 | with open('requirements.txt', encoding='utf-8') as f:
8 | required = f.read().splitlines()
9 |
10 |
11 | setup(
12 | name='ratransformers',
13 | version='1.3.2',
14 | description='RATransformer - make a transformer model learn implicit relations passed in the input',
15 | long_description=long_description,
16 | long_description_content_type='text/markdown',
17 | url='https://github.com/JoaoLages/RATransformers',
18 | author='Joao Lages',
19 | author_email='joaop.glages@gmail.com',
20 | license='MIT',
21 | packages=find_packages('src'),
22 | package_dir={'': 'src'},
23 | install_requires=required
24 | )
25 |
--------------------------------------------------------------------------------
/.github/workflows/build_and_test.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.7", "3.8", "3.9", "3.10"]
12 |
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install flake8 pytest
23 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
24 | - name: Install package
25 | run: |
26 | pip install -e .
27 | - name: Test with pytest
28 | run: |
29 | python -m pytest
30 | - name: Save artifacts
31 | uses: actions/upload-artifact@v2
32 | with:
33 | name: plot-screenshots
34 | path: tmp/
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 João Lages
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/test_ratransformer_forward.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from packaging import version
3 | from ratransformers import RATransformer
4 | from transformers import AutoModelForSeq2SeqLM
5 | import transformers
6 |
7 |
8 | class TestModelsForward:
9 | """
10 |     Test small supported models to check that nothing breaks in the forward pass
11 | """
12 | def ratransformer_forward(self, ratransformer: RATransformer):
13 | model = ratransformer.model
14 | tokenizer = ratransformer.tokenizer
15 | encoding = tokenizer(
16 | "this is just a dummy text",
17 | return_tensors="pt",
18 | input_relations=None
19 | )
20 | if 'decoder_input_ids' in inspect.signature(model.forward).parameters:
21 | if version.parse(transformers.__version__) >= version.parse('4.13'):
22 | decoder_input_ids = model._prepare_decoder_input_ids_for_generation(encoding['input_ids'].shape[0], None, None)
23 | else:
24 | decoder_input_ids = model._prepare_decoder_input_ids_for_generation(encoding['input_ids'], None, None)
25 | encoding['decoder_input_ids'] = decoder_input_ids
26 | _ = model(**encoding)
27 |
28 | def test_t5(self):
29 | ratransformer = RATransformer(
30 | "t5-small",
31 | relation_kinds=['dummy1', 'dummy2'],
32 | model_cls=AutoModelForSeq2SeqLM
33 | )
34 | self.ratransformer_forward(ratransformer)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # RATransformers 🐭
4 |
5 |  
6 |
7 | **RATransformers**, short for Relation-Aware Transformers, is a package built on top of [transformers 🤗](https://github.com/huggingface/transformers)
8 | that enables the training/fine-tuning of models with extra relation-aware input features.
9 |
10 |
11 | ### Example - Encoding a table in TableQA (Question Answering on Tabular Data)
12 | 
13 |
14 | [[Notebook Link](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/TableQA_tabfact_example.ipynb)]
15 |
16 | This example shows that passing the table to the model as plain text, with no additional information, is a poor representation.
17 |
18 | With RATransformers 🐭 you can encode the table in a more structured way by passing specific relations within the input.
19 | RATransformers 🐭 also lets you attach extra features to each input word/token.
20 |
21 | Check out more examples [[here](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/)].
22 |
23 | ## Installation
24 |
25 | Install directly from PyPI:
26 |
27 | pip install ratransformers
28 |
29 | ## Usage
30 |
31 | ```python
32 | from ratransformers import RATransformer
33 | from transformers import AutoModelForSequenceClassification
34 |
35 |
36 | ratransformer = RATransformer(
37 | "nielsr/tapex-large-finetuned-tabfact", # define the 🤗 model you want to load
38 | relation_kinds=['is_value_of_column', 'is_from_same_row'], # define the relations that you want to model in the input
39 | model_cls=AutoModelForSequenceClassification, # define the model class
40 | pretrained_tokenizer_name_or_path='facebook/bart-large' # define the tokenizer you want to load (in case it is not the same as the model)
41 | )
42 | model = ratransformer.model
43 | tokenizer = ratransformer.tokenizer
44 | ```
45 |
46 | With just these steps, your RATransformer 🐭 is ready to be trained.
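
The tokenizer returned above accepts an extra `input_relations` argument at encoding time: it maps a character span in the input text to the other spans it relates to and the name of each relation (see the notebooks for the full workflow). A minimal sketch, where the example text and character spans are purely illustrative:

```python
# Illustrative only: a flattened table row where "Brad Pitt" and "87" share a row.
text = "Brad Pitt | 87"
word_relations = {
    (0, 9): {(12, 14): 'is_from_same_row'},   # span of "Brad Pitt" -> span of "87"
    (12, 14): {(0, 9): 'is_from_same_row'},
}

encoding = tokenizer(text, return_tensors="pt", input_relations=word_relations)
outputs = model(**encoding)
```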
47 |
48 | More implementation details in [the examples here](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/).
49 |
50 | ## How does it work?
51 | We modify the self-attention layers of the transformer model as explained in Section 3 of [the RAT-SQL paper](https://arxiv.org/pdf/1911.04942.pdf).
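
Concretely, each patched attention layer holds an embedding per relation kind: the attention scores get an extra query·relation term, and the attention output gets an extra probability-weighted relation term. A minimal single-head sketch of this idea (a simplification, not the package's internal API):

```python
import torch

def relation_aware_attention(q, k, v, r_k, r_v):
    # q, k, v: (seq, head_dim); r_k, r_v: (seq, seq, head_dim) relation embeddings per token pair
    scores = q @ k.T                                  # standard dot-product scores
    scores = scores + (q.unsqueeze(1) * r_k).sum(-1)  # add q_i · r_k[i, j] to score (i, j)
    probs = (scores / q.size(-1) ** 0.5).softmax(-1)
    out = probs @ v                                   # standard weighted sum of values
    out = out + (probs.unsqueeze(-1) * r_v).sum(1)    # add sum_j probs[i, j] · r_v[i, j]
    return out
```

In the package this is done per head inside the patched relational attention modules (see `src/ratransformers/`).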
52 |
53 | ## Supported Models
54 | Currently we support a limited number of transformer models:
55 | - [BART](https://huggingface.co/docs/transformers/model_doc/bart)
56 | - [BERT](https://huggingface.co/docs/transformers/model_doc/bert)
57 | - [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)
58 | - [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)
59 | - [T5](https://huggingface.co/docs/transformers/model_doc/t5)
60 | - [LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)
61 |
62 | Want another model? Feel free to open an [Issue](https://github.com/JoaoLages/RATransformers/issues) or create a [Pull Request](https://github.com/JoaoLages/RATransformers/pulls) and let's get started 🚀
63 |
--------------------------------------------------------------------------------
/src/ratransformers/gpt2.py:
--------------------------------------------------------------------------------
1 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class GPT2RelationalAttention(GPT2Attention):
7 | def __init__(self, *args, num_relation_kinds: int, use_same_relation_kv_emb: bool = True, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | self.num_relation_kinds = num_relation_kinds
10 | self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
11 | if use_same_relation_kv_emb:
12 | self.relation_v_emb = self.relation_k_emb
13 | else:
14 | self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
15 |         self.input_relation_kinds = []  # filled in externally; holds a single tensor of relation kind ids with shape (batch, seq_length, seq_length)
16 |
17 | def forward(
18 | self,
19 | hidden_states,
20 | layer_past=None,
21 | attention_mask=None,
22 | head_mask=None,
23 | encoder_hidden_states=None,
24 | encoder_attention_mask=None,
25 | use_cache=False,
26 | output_attentions=False,
27 | ):
28 | if encoder_hidden_states is not None:
29 | if not hasattr(self, "q_attn"):
30 | raise ValueError(
31 | "If class is used as cross attention, the weights `q_attn` have to be defined. "
32 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
33 | )
34 |
35 | query = self.q_attn(hidden_states)
36 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
37 | attention_mask = encoder_attention_mask
38 | else:
39 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
40 |
41 | batch_size, seq_length = hidden_states.shape[:2]
42 |
43 | assert len(self.input_relation_kinds) == 1
44 | input_relation_kinds = self.input_relation_kinds[0]
45 | assert input_relation_kinds.shape == (batch_size, seq_length, seq_length)
46 |
47 | query = self._split_heads(query, self.num_heads, self.head_dim)
48 | key = self._split_heads(key, self.num_heads, self.head_dim)
49 | value = self._split_heads(value, self.num_heads, self.head_dim)
50 |
51 | if layer_past is not None:
52 | past_key, past_value = layer_past
53 | key = torch.cat((past_key, key), dim=-2)
54 | value = torch.cat((past_value, value), dim=-2)
55 |
56 | if use_cache is True:
57 | present = (key, value)
58 | else:
59 | present = None
60 |
61 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
62 |
63 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
64 | attn_output = self.c_proj(attn_output)
65 | attn_output = self.resid_dropout(attn_output)
66 |
67 | outputs = (attn_output, present)
68 | if output_attentions:
69 | outputs += (attn_weights,)
70 |
71 | return outputs # a, present, (attentions)
72 |
73 | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
74 |
75 | input_relation_kinds = self.input_relation_kinds[0]
76 | relation_k_embeds = self.relation_k_emb(input_relation_kinds)
77 |
78 | attn_weights = torch.matmul(query, key.transpose(-1, -2))
79 |
80 | # q_t is [batch, seq_length, n_heads, dim_per_head]
81 | q_t = query.permute(0, 2, 1, 3)
82 |
83 | # r_t is [batch, seq_length, dim_per_head, seq_length]
84 | r_t = relation_k_embeds.transpose(-2, -1)
85 |
86 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
87 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
88 |
89 | # Add to scores
90 | attn_weights += q_tr_tmatmul_t
91 |
92 | if self.scale_attn_weights:
93 | attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
94 |
95 | if not self.is_cross_attention:
96 | # if only "normal" attention layer implements causal mask
97 | query_length, key_length = query.size(-2), key.size(-2)
98 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
99 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
100 |
101 | if attention_mask is not None:
102 | # Apply the attention mask
103 | attn_weights = attn_weights + attention_mask
104 |
105 | attn_weights = nn.Softmax(dim=-1)(attn_weights)
106 | attn_weights = self.attn_dropout(attn_weights)
107 |
108 | # Mask heads if we want to
109 | if head_mask is not None:
110 | attn_weights = attn_weights * head_mask
111 |
112 | attn_output = torch.matmul(attn_weights, value)
113 |
114 | relation_v_embeds = self.relation_v_emb(input_relation_kinds)
115 |
116 | # w_t is [batch, seq_length, n_heads, seq_length]
117 | w_t = attn_weights.permute(0, 2, 1, 3)
118 |
119 |         # w_tr_matmul is [batch, seq_length, n_heads, dim_per_head]
120 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
121 |
122 | attn_output += w_tr_matmul.permute(0, 2, 1, 3)
123 |
124 | return attn_output, attn_weights
--------------------------------------------------------------------------------
/src/ratransformers/bert.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from transformers.models.bert.modeling_bert import BertSelfAttention
4 | import torch.nn as nn
5 | import torch
6 |
7 |
8 | class BertRelationalSelfAttention(BertSelfAttention):
9 | def __init__(self, *args, num_relation_kinds: int, use_same_relation_kv_emb: bool = True, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.num_relation_kinds = num_relation_kinds
12 | self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.attention_head_size, padding_idx=0)
13 | if use_same_relation_kv_emb:
14 | self.relation_v_emb = self.relation_k_emb
15 | else:
16 | self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.attention_head_size, padding_idx=0)
17 |         self.input_relation_kinds = []  # filled in externally; holds a single tensor of relation kind ids with shape (batch, seq_length, seq_length)
18 |
19 | def forward(
20 | self,
21 | hidden_states,
22 | attention_mask=None,
23 | head_mask=None,
24 | encoder_hidden_states=None,
25 | encoder_attention_mask=None,
26 | past_key_value=None,
27 | output_attentions=False,
28 | ):
29 |
30 | batch_size, seq_length = hidden_states.shape[:2]
31 |
32 | assert len(self.input_relation_kinds) == 1
33 | input_relation_kinds = self.input_relation_kinds[0]
34 | assert input_relation_kinds.shape == (batch_size, seq_length, seq_length)
35 |
36 |         # (batch_size, seq_length, seq_length, self.attention_head_size)
37 | relation_k_embeds = self.relation_k_emb(input_relation_kinds)
38 | relation_v_embeds = self.relation_v_emb(input_relation_kinds)
39 |
40 | mixed_query_layer = self.query(hidden_states)
41 |
42 | # If this is instantiated as a cross-attention module, the keys
43 | # and values come from an encoder; the attention mask needs to be
44 | # such that the encoder's padding tokens are not attended to.
45 | is_cross_attention = encoder_hidden_states is not None
46 |
47 | if is_cross_attention and past_key_value is not None:
48 | # reuse k,v, cross_attentions
49 | key_layer = past_key_value[0]
50 | value_layer = past_key_value[1]
51 | attention_mask = encoder_attention_mask
52 | elif is_cross_attention:
53 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
54 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
55 | attention_mask = encoder_attention_mask
56 | elif past_key_value is not None:
57 | key_layer = self.transpose_for_scores(self.key(hidden_states))
58 | value_layer = self.transpose_for_scores(self.value(hidden_states))
59 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
60 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
61 | else:
62 | key_layer = self.transpose_for_scores(self.key(hidden_states))
63 | value_layer = self.transpose_for_scores(self.value(hidden_states))
64 |
65 | query_layer = self.transpose_for_scores(mixed_query_layer)
66 |
67 | if self.is_decoder:
68 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
69 | # Further calls to cross_attention layer can then reuse all cross-attention
70 | # key/value_states (first "if" case)
71 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
72 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
73 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
74 | # if encoder bi-directional self-attention `past_key_value` is always `None`
75 | past_key_value = (key_layer, value_layer)
76 |
77 | # Take the dot product between "query" and "key" to get the raw attention scores.
78 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
79 |
80 | # q_t is [batch, seq_length, n_heads, dim_per_head]
81 | q_t = query_layer.permute(0, 2, 1, 3)
82 |
83 | # r_t is [batch, seq_length, dim_per_head, seq_length]
84 | r_t = relation_k_embeds.transpose(-2, -1)
85 |
86 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
87 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
88 |
89 | # Add to scores
90 | attention_scores += q_tr_tmatmul_t
91 |
92 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
93 | seq_length = hidden_states.size()[1]
94 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
95 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
96 | distance = position_ids_l - position_ids_r
97 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
98 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
99 |
100 | if self.position_embedding_type == "relative_key":
101 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
102 | attention_scores = attention_scores + relative_position_scores
103 | elif self.position_embedding_type == "relative_key_query":
104 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
105 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
106 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
107 |
108 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
109 | if attention_mask is not None:
110 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
111 | attention_scores = attention_scores + attention_mask
112 |
113 | # Normalize the attention scores to probabilities.
114 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
115 |
116 | # This is actually dropping out entire tokens to attend to, which might
117 | # seem a bit unusual, but is taken from the original Transformer paper.
118 | attention_probs = self.dropout(attention_probs)
119 |
120 | # Mask heads if we want to
121 | if head_mask is not None:
122 | attention_probs = attention_probs * head_mask
123 |
124 | context_layer = torch.matmul(attention_probs, value_layer)
125 |
126 | # w_t is [batch, seq_length, n_heads, seq_length]
127 | w_t = attention_probs.permute(0, 2, 1, 3)
128 |
129 |         # w_tr_matmul is [batch, seq_length, n_heads, attention_head_size]
130 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
131 |
132 | context_layer += w_tr_matmul.permute(0, 2, 1, 3)
133 |
134 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
135 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
136 | context_layer = context_layer.view(*new_context_layer_shape)
137 |
138 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
139 |
140 | if self.is_decoder:
141 | outputs = outputs + (past_key_value,)
142 | return outputs
143 |
--------------------------------------------------------------------------------
/notebooks/NER_conll_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "0b4a5857",
7 | "metadata": {
8 | "pycharm": {
9 | "name": "#%%\n"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "!pip install datasets"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "id": "636ff184",
21 | "metadata": {
22 | "pycharm": {
23 | "name": "#%%\n"
24 | }
25 | },
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "Reusing dataset conll2003 (/home/ola/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
32 | ]
33 | },
34 | {
35 | "data": {
36 | "application/vnd.jupyter.widget-view+json": {
37 | "model_id": "da46312ea7114898bc51966b2fe78766",
38 | "version_major": 2,
39 | "version_minor": 0
40 | },
41 | "text/plain": [
42 | " 0%| | 0/3 [00:00, ?it/s]"
43 | ]
44 | },
45 | "metadata": {},
46 | "output_type": "display_data"
47 | }
48 | ],
49 | "source": [
50 | "from datasets import load_dataset\n",
51 | "import ratransformers\n",
52 | "\n",
53 | "# Load dataset\n",
54 | "dataset = load_dataset('conll2003')\n",
55 | "\n",
56 | "# copied from https://huggingface.co/datasets/conll2003\n",
57 | "pos_tag_to_id = {'\"': 0, \"''\": 1, '#': 2, '$': 3, '(': 4, ')': 5, ',': 6, '.': 7, ':': 8, '``': 9, 'CC': 10, 'CD': 11, 'DT': 12,\n",
58 | " 'EX': 13, 'FW': 14, 'IN': 15, 'JJ': 16, 'JJR': 17, 'JJS': 18, 'LS': 19, 'MD': 20, 'NN': 21, 'NNP': 22, 'NNPS': 23,\n",
59 | " 'NNS': 24, 'NN|SYM': 25, 'PDT': 26, 'POS': 27, 'PRP': 28, 'PRP$': 29, 'RB': 30, 'RBR': 31, 'RBS': 32, 'RP': 33,\n",
60 | " 'SYM': 34, 'TO': 35, 'UH': 36, 'VB': 37, 'VBD': 38, 'VBG': 39, 'VBN': 40, 'VBP': 41, 'VBZ': 42, 'WDT': 43,\n",
61 | " 'WP': 44, 'WP$': 45, 'WRB': 46}\n",
62 | "\n",
63 | "id_to_pos_tag = {v: k for k, v in pos_tag_to_id.items()}"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "id": "a27dfb70",
70 | "metadata": {
71 | "pycharm": {
72 | "name": "#%%\n"
73 | }
74 | },
75 | "outputs": [],
76 | "source": [
77 | "from transformers import AutoModelForTokenClassification\n",
78 | "\n",
79 | "# Load ratransformer model and tokenizer\n",
80 | "ratransformer = ratransformers.RATransformer(\n",
81 | " \"dslim/bert-base-NER\", \n",
82 | " relation_kinds=list(pos_tag_to_id),\n",
83 | " model_cls=AutoModelForTokenClassification\n",
84 | ")\n",
85 | "model = ratransformer.model\n",
86 | "tokenizer = ratransformer.tokenizer"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 4,
92 | "id": "c1c2ab13",
93 | "metadata": {
94 | "pycharm": {
95 | "name": "#%%\n"
96 | }
97 | },
98 | "outputs": [
99 | {
100 | "data": {
101 | "text/plain": [
102 | "{'id': '0',\n",
103 | " 'tokens': ['EU',\n",
104 | " 'rejects',\n",
105 | " 'German',\n",
106 | " 'call',\n",
107 | " 'to',\n",
108 | " 'boycott',\n",
109 | " 'British',\n",
110 | " 'lamb',\n",
111 | " '.'],\n",
112 | " 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],\n",
113 | " 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],\n",
114 | " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
115 | ]
116 | },
117 | "execution_count": 4,
118 | "metadata": {},
119 | "output_type": "execute_result"
120 | }
121 | ],
122 | "source": [
123 | "dataset['train'][0]"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 5,
129 | "id": "32253dd7",
130 | "metadata": {
131 | "pycharm": {
132 | "name": "#%%\n"
133 | }
134 | },
135 | "outputs": [
136 | {
137 | "data": {
138 | "text/plain": [
139 | "defaultdict(dict,\n",
140 | " {(0, 2): {(0, 2): 'NNP'},\n",
141 | " (3, 10): {(3, 10): 'VBZ'},\n",
142 | " (11, 17): {(11, 17): 'JJ'},\n",
143 | " (18, 22): {(18, 22): 'NN'},\n",
144 | " (23, 25): {(23, 25): 'TO'},\n",
145 | " (26, 33): {(26, 33): 'VB'},\n",
146 | " (34, 41): {(34, 41): 'JJ'},\n",
147 | " (42, 46): {(42, 46): 'NN'},\n",
148 | " (47, 48): {(47, 48): '.'}})"
149 | ]
150 | },
151 | "execution_count": 5,
152 | "metadata": {},
153 | "output_type": "execute_result"
154 | }
155 | ],
156 | "source": [
157 | "from collections import defaultdict\n",
158 | "\n",
159 | "# Construct a map from span in text to POS_TAG\n",
160 | "word_relations = defaultdict(dict)\n",
161 | "span_init = 0\n",
162 | "for tok, pos_tag_id in zip(dataset['train'][0]['tokens'], dataset['train'][0]['pos_tags']):\n",
163 | " span = (span_init, span_init + len(tok))\n",
164 | " word_relations[span][span] = id_to_pos_tag[pos_tag_id]\n",
165 | " span_init = span_init + len(tok + ' ')\n",
166 | "word_relations"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 6,
172 | "id": "2ec8aca9",
173 | "metadata": {
174 | "pycharm": {
175 | "name": "#%%\n"
176 | }
177 | },
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "EU B-ORG\n",
184 | "rejects O\n",
185 | "German B-MISC\n",
186 | "call O\n",
187 | "to O\n",
188 | "boycott O\n",
189 | "British B-MISC\n",
190 | "la O\n",
191 | "mb O\n",
192 | ". O\n"
193 | ]
194 | }
195 | ],
196 | "source": [
197 | "# encode \n",
198 | "text = \" \".join(dataset['train'][0]['tokens'])\n",
199 | "encoding = tokenizer(\n",
200 | " text, \n",
201 | " return_tensors=\"pt\", \n",
202 | " input_relations=word_relations\n",
203 | ")\n",
204 | "\n",
205 | "# forward pass\n",
206 | "outputs = model(**encoding)\n",
207 | "\n",
208 | "# get labels ids and convert to label tags\n",
209 | "labels = outputs.logits.argmax(-1)\n",
210 | "tokens_to_labels = [model.config.id2label[label_id.item()] for label_id in labels[0]]\n",
211 | "\n",
212 | "# print tokens with their predicted NER tags\n",
213 | "for i, token_i_map in enumerate(encoding['offset_mapping'][0]):\n",
214 | " span = token_i_map.tolist()\n",
215 | " token = text[span[0]:span[1]]\n",
216 | " if token: # skip special tokens\n",
217 | " print(token, tokens_to_labels[i])"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "id": "9246c501",
223 | "metadata": {
224 | "pycharm": {
225 | "name": "#%% md\n"
226 | }
227 | },
228 | "source": [
229 | "**Your model is now ready to be trained with relational information in the input!**\n",
230 | "\n",
231 | "Check the standard procedure to train HuggingFace 🤗 models in [here](https://huggingface.co/docs/transformers/training)."
232 | ]
233 | }
234 | ],
235 | "metadata": {
236 | "kernelspec": {
237 | "display_name": "Python 3",
238 | "language": "python",
239 | "name": "python3"
240 | },
241 | "language_info": {
242 | "codemirror_mode": {
243 | "name": "ipython",
244 | "version": 3
245 | },
246 | "file_extension": ".py",
247 | "mimetype": "text/x-python",
248 | "name": "python",
249 | "nbconvert_exporter": "python",
250 | "pygments_lexer": "ipython3",
251 | "version": "3.7.12"
252 | }
253 | },
254 | "nbformat": 4,
255 | "nbformat_minor": 5
256 | }
--------------------------------------------------------------------------------
/src/ratransformers/bart.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | from transformers.models.bart.modeling_bart import BartAttention
3 | import torch.nn as nn
4 | import torch
5 |
6 |
7 | class BartRelationalAttention(BartAttention):
8 | def __init__(self, *args, num_relation_kinds: int, use_same_relation_kv_emb: bool = True, **kwargs):
9 | super().__init__(*args, **kwargs)
10 | self.num_relation_kinds = num_relation_kinds
11 | self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
12 | if use_same_relation_kv_emb:
13 | self.relation_v_emb = self.relation_k_emb
14 | else:
15 | self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
16 |         self.input_relation_kinds = []  # filled in externally; holds a single tensor of relation kind ids with shape (batch, seq_length, seq_length)
17 |
18 | def forward(
19 | self,
20 | hidden_states: torch.Tensor,
21 | key_value_states: Optional[torch.Tensor] = None,
22 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | layer_head_mask: Optional[torch.Tensor] = None,
25 | output_attentions: bool = False,
26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
27 | """Input shape: Batch x Time x Channel"""
28 |
29 | # if key_value_states are provided this layer is used as a cross-attention layer
30 | # for the decoder
31 | is_cross_attention = key_value_states is not None
32 | bsz, tgt_len, embed_dim = hidden_states.size()
33 |
34 | assert len(self.input_relation_kinds) == 1
35 | input_relation_kinds = self.input_relation_kinds[0]
36 | assert input_relation_kinds.shape == (bsz, tgt_len, tgt_len)
37 |
38 |         # (batch_size, seq_length, seq_length, self.head_dim)
39 | relation_k_embeds = self.relation_k_emb(input_relation_kinds)
40 | relation_v_embeds = self.relation_v_emb(input_relation_kinds)
41 |
42 | # get query proj
43 | query_states = self.q_proj(hidden_states) * self.scaling
44 | # get key, value proj
45 | if is_cross_attention and past_key_value is not None:
46 | # reuse k,v, cross_attentions
47 | key_states = past_key_value[0]
48 | value_states = past_key_value[1]
49 | elif is_cross_attention:
50 | # cross_attentions
51 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
52 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
53 | elif past_key_value is not None:
54 | # reuse k, v, self_attention
55 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
56 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
57 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
58 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
59 | else:
60 | # self_attention
61 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
62 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
63 |
64 | if self.is_decoder:
65 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
66 | # Further calls to cross_attention layer can then reuse all cross-attention
67 | # key/value_states (first "if" case)
68 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
69 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
70 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
71 | # if encoder bi-directional self-attention `past_key_value` is always `None`
72 | past_key_value = (key_states, value_states)
73 |
74 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
75 | query_states = self._shape(query_states, tgt_len, bsz)
76 | src_len = key_states.size(2)
77 |
78 | # compute scores
79 | attn_weights = torch.matmul(
80 | query_states, key_states.transpose(3, 2)
81 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
82 |
83 | # q_t is [batch, seq_length, n_heads, dim_per_head]
84 | q_t = query_states.permute(0, 2, 1, 3)
85 |
86 | # r_t is [batch, seq_length, dim_per_head, seq_length]
87 | r_t = relation_k_embeds.transpose(-2, -1)
88 |
89 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
90 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
91 |
92 | # Add to scores
93 | attn_weights += q_tr_tmatmul_t
94 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
95 |
96 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
97 | raise ValueError(
98 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
99 | )
100 |
101 | if attention_mask is not None:
102 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
103 | raise ValueError(
104 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
105 | )
106 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
107 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
108 |
109 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
110 |
111 | if layer_head_mask is not None:
112 | if layer_head_mask.size() != (self.num_heads,):
113 | raise ValueError(
114 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
115 | )
116 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
117 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
118 |
119 | if output_attentions:
120 | # this operation is a bit awkward, but it's required to
121 | # make sure that attn_weights keeps its gradient.
122 | # In order to do so, attn_weights have to be reshaped
123 | # twice and have to be reused in the following
124 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
125 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
126 | else:
127 | attn_weights_reshaped = None
128 |
129 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
130 |
131 | attn_output = torch.bmm(attn_probs, value_states.view(*proj_shape))
132 |
133 | # w_t is [batch, seq_length, n_heads, seq_length]
134 | w_t = attn_probs.view(bsz, self.num_heads, tgt_len, src_len).permute(0, 2, 1, 3)
135 |
136 |         # w_tr_matmul is [batch, tgt_len, n_heads, head_dim]
137 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
138 |
139 |         attn_output += w_tr_matmul.permute(0, 2, 1, 3).reshape(bsz * self.num_heads, tgt_len, self.head_dim)
140 |
141 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
142 | raise ValueError(
143 |                 f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
144 | )
145 |
146 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
147 | attn_output = attn_output.transpose(1, 2)
148 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
149 |
150 | attn_output = self.out_proj(attn_output)
151 |
152 | return attn_output, attn_weights_reshaped, past_key_value
--------------------------------------------------------------------------------
/notebooks/TableQA_tabfact_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6ac23e34",
6 | "metadata": {},
7 | "source": [
8 | "In this notebook we will show how to use RATransformers 🐭 to encode your data as the image shows.\n",
9 | "\n",
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "98ee93fc",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "!pip install pandas"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "id": "7841fcd4",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import ratransformers\n",
31 | "import pandas as pd\n",
32 | "from transformers import BartTokenizerFast, BartForSequenceClassification\n",
33 | "\n",
34 | "\n",
35 | "ratransformer = ratransformers.RATransformer(\n",
36 | " \"nielsr/tapex-large-finetuned-tabfact\", # define the 🤗 model you want to load\n",
37 | " relation_kinds=['is_value_of_column', 'is_from_same_row'], # define the relations that you want to model in the input\n",
38 | " tokenizer_cls=BartTokenizerFast, # define the tokenizer class \n",
39 | " model_cls=BartForSequenceClassification, # define the model class\n",
40 | " pretrained_tokenizer_name_or_path='facebook/bart-large' # define the tokenizer you want to load (in case it is not the same as the model)\n",
41 | ")\n",
42 | "model = ratransformer.model\n",
43 | "tokenizer = ratransformer.tokenizer\n",
44 | "\n",
45 | "# create table\n",
46 | "data = {'Actors': [\"Brad Pitt\", \"Leonardo Di Caprio\", \"George Clooney\"], 'Number of movies': [\"87\", \"53\", \"69\"]}\n",
47 | "table = pd.DataFrame.from_dict(data)\n",
48 | "\n",
49 | "# turn into dict\n",
50 | "table_dict = {\"header\": list(table.columns), \"rows\": [list(row.values) for i,row in table.iterrows()]}"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "id": "df24f04d",
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "data": {
61 | "text/html": [
62 | ""
254 | ]
255 | },
256 | "metadata": {},
257 | "output_type": "display_data"
258 | },
259 | {
260 | "data": {
261 | "text/plain": [
262 | "{'eval_loss': 1.0085068941116333,\n",
263 | " 'eval_runtime': 1324.2969,\n",
264 | " 'eval_samples_per_second': 0.781}"
265 | ]
266 | },
267 | "execution_count": 10,
268 | "metadata": {},
269 | "output_type": "execute_result"
270 | }
271 | ],
272 | "source": [
273 | "# Set trainer\n",
274 | "trainer = Seq2SeqTrainer(\n",
275 | " model=model, \n",
276 | " args=training_args,\n",
277 | " train_dataset=train_d, \n",
278 | " eval_dataset=val_d, \n",
279 | " tokenizer=tokenizer,\n",
280 | " callbacks=[EarlyStoppingCallback()]\n",
281 | ")\n",
282 | "\n",
283 | "# get performance before training\n",
284 | "trainer.evaluate()"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 11,
290 | "id": "23b5a658",
291 | "metadata": {
292 | "scrolled": false
293 | },
294 | "outputs": [
295 | {
296 | "data": {
297 | "text/html": [
298 |         "<p>[ 2000/100000 2:56:39 < 144:24:58, 0.19 it/s, Epoch 0/24]</p>\n",
299 |         "<table>\n",
300 |         "  <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th></tr>\n",
301 |         "  <tr><td>1000</td><td>0.099700</td><td>0.325624</td></tr>\n",
302 |         "  <tr><td>2000</td><td>0.060500</td><td>0.348871</td></tr>\n",
303 |         "</table>"
325 | ],
326 | "text/plain": [
327 | ""
328 | ]
329 | },
330 | "metadata": {},
331 | "output_type": "display_data"
332 | },
333 | {
334 | "data": {
335 | "text/plain": [
336 | "TrainOutput(global_step=2000, training_loss=0.10654444885253907, metrics={'train_runtime': 10601.4604, 'train_samples_per_second': 9.433, 'total_flos': 0, 'epoch': 0.47})"
337 | ]
338 | },
339 | "execution_count": 11,
340 | "metadata": {},
341 | "output_type": "execute_result"
342 | }
343 | ],
344 | "source": [
345 | "# train until early stopping\n",
346 | "trainer.train()"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 12,
352 | "id": "06da266f",
353 | "metadata": {},
354 | "outputs": [
355 | {
356 | "data": {
357 | "text/html": [
358 |         "<p>[259/259 20:01]</p>"
365 | ],
366 | "text/plain": [
367 | ""
368 | ]
369 | },
370 | "metadata": {},
371 | "output_type": "display_data"
372 | },
373 | {
374 | "data": {
375 | "text/plain": [
376 | "{'eval_loss': 0.32562369108200073,\n",
377 | " 'eval_runtime': 1206.6218,\n",
378 | " 'eval_samples_per_second': 0.857,\n",
379 | " 'epoch': 0.47}"
380 | ]
381 | },
382 | "execution_count": 12,
383 | "metadata": {},
384 | "output_type": "execute_result"
385 | }
386 | ],
387 | "source": [
388 | "# get performance after training\n",
389 | "trainer.evaluate()"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 13,
395 | "id": "9854bb38",
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "# Save model\n",
400 | "trainer.save_model('ra-tscholak/1zha5ono')"
401 | ]
402 | },
403 | {
404 | "cell_type": "markdown",
405 | "id": "48673297",
406 | "metadata": {},
407 | "source": [
408 |     "Training done! After saving, you can reload the model with the ratransformers package."
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 9,
414 | "id": "5c6251b6",
415 | "metadata": {
416 | "scrolled": false
417 | },
418 | "outputs": [
419 | {
420 | "data": {
421 | "text/html": [
422 |         "<p>[259/259 15:35]</p>"
429 | ],
430 | "text/plain": [
431 | ""
432 | ]
433 | },
434 | "metadata": {},
435 | "output_type": "display_data"
436 | },
437 | {
438 | "data": {
439 | "text/plain": [
440 | "{'eval_loss': 0.32562369108200073,\n",
441 | " 'eval_runtime': 938.7328,\n",
442 | " 'eval_samples_per_second': 1.101}"
443 | ]
444 | },
445 | "execution_count": 9,
446 | "metadata": {},
447 | "output_type": "execute_result"
448 | }
449 | ],
450 | "source": [
451 | "# Reload model again\n",
452 | "ratransformer = ratransformers.RATransformer(\n",
453 | " 'ra-tscholak/1zha5ono', \n",
454 | " relation_kinds=['table_column_link', 'column_table_link'],\n",
455 | " alias_model_name='t5'\n",
456 | ")\n",
457 | "model = ratransformer.model\n",
458 | "tokenizer = ratransformer.tokenizer\n",
459 | "\n",
460 | "trainer = Seq2SeqTrainer(\n",
461 | " model=model, \n",
462 | " args=training_args,\n",
463 | " train_dataset=train_d, \n",
464 | " eval_dataset=val_d, \n",
465 | " tokenizer=tokenizer,\n",
466 | " callbacks=[EarlyStoppingCallback()]\n",
467 | ")\n",
468 | "trainer.evaluate()"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "id": "ee0202ca",
475 | "metadata": {},
476 | "outputs": [],
477 | "source": []
478 | }
479 | ],
480 | "metadata": {
481 | "kernelspec": {
482 | "display_name": "Python 3",
483 | "language": "python",
484 | "name": "python3"
485 | },
486 | "language_info": {
487 | "codemirror_mode": {
488 | "name": "ipython",
489 | "version": 3
490 | },
491 | "file_extension": ".py",
492 | "mimetype": "text/x-python",
493 | "name": "python",
494 | "nbconvert_exporter": "python",
495 | "pygments_lexer": "ipython3",
496 | "version": "3.7.12"
497 | }
498 | },
499 | "nbformat": 4,
500 | "nbformat_minor": 5
501 | }
502 |
--------------------------------------------------------------------------------
/src/ratransformers/t5.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from types import MethodType
4 | from typing import Optional, Tuple, Union
5 |
6 | from torch.nn import CrossEntropyLoss
7 | from torch.utils.checkpoint import checkpoint
8 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqModelOutput, \
9 | BaseModelOutput, Seq2SeqLMOutput
10 | from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention, T5LayerCrossAttention, T5Block, \
11 | logger, T5Stack, T5Model, T5_INPUTS_DOCSTRING, _CONFIG_FOR_DOC, T5ForConditionalGeneration, T5EncoderModel, \
12 | T5_ENCODER_INPUTS_DOCSTRING
13 | import torch.nn as nn
14 | import torch
15 | from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
16 |
17 |
18 | def change_t5_module_and_get_relational_emb_dim(module_name: str, module: nn.Module) -> Optional[int]:
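    """Patch a T5 module in place so the encoder self-attention becomes relation-aware.

    Only encoder ``T5Attention`` modules get the relational attention forward; for those,
    the module's ``key_value_proj_dim`` is returned (the relational embedding size).
    The remaining T5 modules get forwards extended to accept and route ``input_relations``,
    and None is returned for them.
    """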
19 | relational_embedding_dim = None
20 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, T5Attention):
21 | module.forward = MethodType(RelationalT5Attention.forward, module)
22 | relational_embedding_dim = module.key_value_proj_dim
23 | elif isinstance(module, T5Stack):
24 | module.forward = MethodType(RelationalT5Stack.forward, module)
25 | elif isinstance(module, T5ForConditionalGeneration):
26 | module.forward = MethodType(RelationalT5ForConditionalGeneration.forward, module)
27 | elif isinstance(module, T5EncoderModel):
28 | module.forward = MethodType(RelationalT5EncoderModel.forward, module)
29 | elif isinstance(module, T5Model):
30 | module.forward = MethodType(RelationalT5Model.forward, module)
31 | elif isinstance(module, T5Block):
32 | module.forward = MethodType(RelationalT5Block.forward, module)
33 | elif isinstance(module, T5LayerCrossAttention):
34 | module.forward = MethodType(RelationalT5LayerCrossAttention.forward, module)
35 | elif isinstance(module, T5LayerSelfAttention):
36 | module.forward = MethodType(RelationalT5LayerSelfAttention.forward, module)
37 | return relational_embedding_dim
38 |
39 |
40 | class RelationalT5ForConditionalGeneration(T5ForConditionalGeneration):
41 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
42 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
43 | def forward(
44 | self,
45 | input_ids: Optional[torch.LongTensor] = None,
46 | attention_mask: Optional[torch.FloatTensor] = None,
47 | decoder_input_ids: Optional[torch.LongTensor] = None,
48 | decoder_attention_mask: Optional[torch.BoolTensor] = None,
49 | head_mask: Optional[torch.FloatTensor] = None,
50 | decoder_head_mask: Optional[torch.FloatTensor] = None,
51 | cross_attn_head_mask: Optional[torch.Tensor] = None,
52 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
53 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
54 | inputs_embeds: Optional[torch.FloatTensor] = None,
55 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
56 | labels: Optional[torch.LongTensor] = None,
57 | use_cache: Optional[bool] = None,
58 | output_attentions: Optional[bool] = None,
59 | output_hidden_states: Optional[bool] = None,
60 | return_dict: Optional[bool] = None,
61 |         input_relations=None  # relation kind ids per token pair, shape (batch, seq_length, seq_length)
62 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
63 | r"""
64 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
65 | Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
66 | config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
67 | labels in `[0, ..., config.vocab_size]`
68 |
69 | Returns:
70 |
71 | Examples:
72 |
73 | ```python
74 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
75 |
76 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
77 | >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
78 |
79 | >>> # training
80 | >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids
81 | >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids
82 | >>> outputs = model(input_ids=input_ids, labels=labels)
83 | >>> loss = outputs.loss
84 | >>> logits = outputs.logits
85 |
86 | >>> # inference
87 | >>> input_ids = tokenizer(
88 | ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
89 | ... ).input_ids # Batch size 1
90 | >>> outputs = model.generate(input_ids)
91 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
92 | >>> # studies have shown that owning a dog is good for you.
93 | ```"""
94 | use_cache = use_cache if use_cache is not None else self.config.use_cache
95 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
96 |
97 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
98 | if head_mask is not None and decoder_head_mask is None:
99 | if self.config.num_layers == self.config.num_decoder_layers:
100 | warnings.warn(
101 | """
102 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
103 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
104 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
105 | num_heads)`.
106 | """
107 | , FutureWarning
108 | )
109 | decoder_head_mask = head_mask
110 |
111 | # Encode if needed (training, first prediction pass)
112 | if encoder_outputs is None:
113 | # Convert encoder inputs in embeddings if needed
114 | encoder_outputs = self.encoder(
115 | input_ids=input_ids,
116 | attention_mask=attention_mask,
117 | inputs_embeds=inputs_embeds,
118 | head_mask=head_mask,
119 | output_attentions=output_attentions,
120 | output_hidden_states=output_hidden_states,
121 | return_dict=return_dict,
122 | input_relations=input_relations
123 | )
124 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
125 | encoder_outputs = BaseModelOutput(
126 | last_hidden_state=encoder_outputs[0],
127 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
128 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
129 | )
130 |
131 | hidden_states = encoder_outputs[0]
132 |
133 | if self.model_parallel:
134 | torch.cuda.set_device(self.decoder.first_device)
135 |
136 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
137 | # get decoder inputs from shifting lm labels to the right
138 | decoder_input_ids = self._shift_right(labels)
139 |
140 | # Set device for model parallelism
141 | if self.model_parallel:
142 | torch.cuda.set_device(self.decoder.first_device)
143 | hidden_states = hidden_states.to(self.decoder.first_device)
144 | if decoder_input_ids is not None:
145 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
146 | if attention_mask is not None:
147 | attention_mask = attention_mask.to(self.decoder.first_device)
148 | if decoder_attention_mask is not None:
149 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
150 |
151 | # Decode
152 | decoder_outputs = self.decoder(
153 | input_ids=decoder_input_ids,
154 | attention_mask=decoder_attention_mask,
155 | inputs_embeds=decoder_inputs_embeds,
156 | past_key_values=past_key_values,
157 | encoder_hidden_states=hidden_states,
158 | encoder_attention_mask=attention_mask,
159 | head_mask=decoder_head_mask,
160 | cross_attn_head_mask=cross_attn_head_mask,
161 | use_cache=use_cache,
162 | output_attentions=output_attentions,
163 | output_hidden_states=output_hidden_states,
164 | return_dict=return_dict,
165 | input_relations=input_relations
166 | )
167 |
168 | sequence_output = decoder_outputs[0]
169 |
170 | # Set device for model parallelism
171 | if self.model_parallel:
172 | torch.cuda.set_device(self.encoder.first_device)
173 | self.lm_head = self.lm_head.to(self.encoder.first_device)
174 | sequence_output = sequence_output.to(self.lm_head.weight.device)
175 |
176 | if self.config.tie_word_embeddings:
177 | # Rescale output before projecting on vocab
178 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
179 | sequence_output = sequence_output * (self.model_dim ** -0.5)
180 |
181 | lm_logits = self.lm_head(sequence_output)
182 |
183 | loss = None
184 | if labels is not None:
185 | loss_fct = CrossEntropyLoss(ignore_index=-100)
186 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
187 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
188 |
189 | if not return_dict:
190 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
191 | return ((loss,) + output) if loss is not None else output
192 |
193 | return Seq2SeqLMOutput(
194 | loss=loss,
195 | logits=lm_logits,
196 | past_key_values=decoder_outputs.past_key_values,
197 | decoder_hidden_states=decoder_outputs.hidden_states,
198 | decoder_attentions=decoder_outputs.attentions,
199 | cross_attentions=decoder_outputs.cross_attentions,
200 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
201 | encoder_hidden_states=encoder_outputs.hidden_states,
202 | encoder_attentions=encoder_outputs.attentions,
203 | )
204 |
205 |
206 | class RelationalT5EncoderModel(T5EncoderModel):
207 | @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
208 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
209 | def forward(
210 | self,
211 | input_ids: Optional[torch.LongTensor] = None,
212 | attention_mask: Optional[torch.FloatTensor] = None,
213 | head_mask: Optional[torch.FloatTensor] = None,
214 | inputs_embeds: Optional[torch.FloatTensor] = None,
215 | output_attentions: Optional[bool] = None,
216 | output_hidden_states: Optional[bool] = None,
217 | return_dict: Optional[bool] = None,
218 |         input_relations=None  # relation kind ids per token pair, shape (batch, seq_length, seq_length)
219 | ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
220 | r"""
221 | Returns:
222 |
223 | Example:
224 |
225 | ```python
226 | >>> from transformers import T5Tokenizer, T5EncoderModel
227 |
228 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
229 | >>> model = T5EncoderModel.from_pretrained("t5-small")
230 | >>> input_ids = tokenizer(
231 | ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
232 | ... ).input_ids # Batch size 1
233 | >>> outputs = model(input_ids=input_ids)
234 | >>> last_hidden_states = outputs.last_hidden_state
235 | ```"""
236 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
237 |
238 | encoder_outputs = self.encoder(
239 | input_ids=input_ids,
240 | attention_mask=attention_mask,
241 | inputs_embeds=inputs_embeds,
242 | head_mask=head_mask,
243 | output_attentions=output_attentions,
244 | output_hidden_states=output_hidden_states,
245 | return_dict=return_dict,
246 | input_relations=input_relations
247 | )
248 |
249 | return encoder_outputs
250 |
251 |
252 | class RelationalT5Model(T5Model):
253 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
254 | @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
255 | def forward(
256 | self,
257 | input_ids: Optional[torch.LongTensor] = None,
258 | attention_mask: Optional[torch.FloatTensor] = None,
259 | decoder_input_ids: Optional[torch.LongTensor] = None,
260 | decoder_attention_mask: Optional[torch.BoolTensor] = None,
261 | head_mask: Optional[torch.FloatTensor] = None,
262 | decoder_head_mask: Optional[torch.FloatTensor] = None,
263 | cross_attn_head_mask: Optional[torch.Tensor] = None,
264 | encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
265 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
266 | inputs_embeds: Optional[torch.Tensor] = None,
267 | decoder_inputs_embeds: Optional[torch.Tensor] = None,
268 | use_cache: Optional[bool] = None,
269 | output_attentions: Optional[bool] = None,
270 | output_hidden_states: Optional[bool] = None,
271 | return_dict: Optional[bool] = None,
272 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
273 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
274 | r"""
275 | Returns:
276 |
277 | Example:
278 |
279 | ```python
280 | >>> from transformers import T5Tokenizer, T5Model
281 |
282 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
283 | >>> model = T5Model.from_pretrained("t5-small")
284 |
285 | >>> input_ids = tokenizer(
286 | ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
287 | ... ).input_ids # Batch size 1
288 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
289 |
290 | >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
291 | >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
292 | >>> decoder_input_ids = model._shift_right(decoder_input_ids)
293 |
294 | >>> # forward pass
295 | >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
296 | >>> last_hidden_states = outputs.last_hidden_state
297 | ```"""
298 | use_cache = use_cache if use_cache is not None else self.config.use_cache
299 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
300 |
301 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
302 | if head_mask is not None and decoder_head_mask is None:
303 | if self.config.num_layers == self.config.num_decoder_layers:
304 | warnings.warn(
305 | """
306 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
307 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
308 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
309 | num_heads)`.
310 | """, FutureWarning
311 | )
312 | decoder_head_mask = head_mask
313 |
314 | # Encode if needed (training, first prediction pass)
315 | if encoder_outputs is None:
316 | encoder_outputs = self.encoder(
317 | input_ids=input_ids,
318 | attention_mask=attention_mask,
319 | inputs_embeds=inputs_embeds,
320 | head_mask=head_mask,
321 | output_attentions=output_attentions,
322 | output_hidden_states=output_hidden_states,
323 | return_dict=return_dict,
324 | input_relations=input_relations
325 | )
326 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
327 | encoder_outputs = BaseModelOutput(
328 | last_hidden_state=encoder_outputs[0],
329 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
330 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
331 | )
332 |
333 | hidden_states = encoder_outputs[0]
334 |
335 | # Set device for model parallelism
336 | if self.model_parallel:
337 | torch.cuda.set_device(self.decoder.first_device)
338 | hidden_states = hidden_states.to(self.decoder.first_device)
339 | if decoder_input_ids is not None:
340 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
341 | if attention_mask is not None:
342 | attention_mask = attention_mask.to(self.decoder.first_device)
343 | if decoder_attention_mask is not None:
344 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
345 |
346 | # Decode
347 | decoder_outputs = self.decoder(
348 | input_ids=decoder_input_ids,
349 | attention_mask=decoder_attention_mask,
350 | inputs_embeds=decoder_inputs_embeds,
351 | past_key_values=past_key_values,
352 | encoder_hidden_states=hidden_states,
353 | encoder_attention_mask=attention_mask,
354 | head_mask=decoder_head_mask,
355 | cross_attn_head_mask=cross_attn_head_mask,
356 | use_cache=use_cache,
357 | output_attentions=output_attentions,
358 | output_hidden_states=output_hidden_states,
359 | return_dict=return_dict,
360 | input_relations=input_relations
361 | )
362 |
363 | if not return_dict:
364 | return decoder_outputs + encoder_outputs
365 |
366 | return Seq2SeqModelOutput(
367 | last_hidden_state=decoder_outputs.last_hidden_state,
368 | past_key_values=decoder_outputs.past_key_values,
369 | decoder_hidden_states=decoder_outputs.hidden_states,
370 | decoder_attentions=decoder_outputs.attentions,
371 | cross_attentions=decoder_outputs.cross_attentions,
372 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
373 | encoder_hidden_states=encoder_outputs.hidden_states,
374 | encoder_attentions=encoder_outputs.attentions,
375 | )
376 |
377 |
378 | class RelationalT5Stack(T5Stack):
379 | def forward(
380 | self,
381 | input_ids=None,
382 | attention_mask=None,
383 | encoder_hidden_states=None,
384 | encoder_attention_mask=None,
385 | inputs_embeds=None,
386 | head_mask=None,
387 | cross_attn_head_mask=None,
388 | past_key_values=None,
389 | use_cache=None,
390 | output_attentions=None,
391 | output_hidden_states=None,
392 | return_dict=None,
393 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
394 | ):
395 | # Model parallel
396 | if self.model_parallel:
397 | torch.cuda.set_device(self.first_device)
398 | self.embed_tokens = self.embed_tokens.to(self.first_device)
399 | use_cache = use_cache if use_cache is not None else self.config.use_cache
400 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
401 | output_hidden_states = (
402 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
403 | )
404 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
405 |
406 | if input_ids is not None and inputs_embeds is not None:
407 | err_msg_prefix = "decoder_" if self.is_decoder else ""
408 | raise ValueError(
409 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
410 | )
411 | elif input_ids is not None:
412 | input_shape = input_ids.size()
413 | input_ids = input_ids.view(-1, input_shape[-1])
414 | elif inputs_embeds is not None:
415 | input_shape = inputs_embeds.size()[:-1]
416 | else:
417 | err_msg_prefix = "decoder_" if self.is_decoder else ""
418 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
419 |
420 | if inputs_embeds is None:
421 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
422 | inputs_embeds = self.embed_tokens(input_ids)
423 |
424 | batch_size, seq_length = input_shape
425 |
426 | # required mask seq length can be calculated via length of past
427 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
428 |
429 | if use_cache is True:
430 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
431 |
432 | if attention_mask is None:
433 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
434 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
435 | encoder_seq_length = encoder_hidden_states.shape[1]
436 | encoder_attention_mask = torch.ones(
437 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
438 | )
439 |
440 | # initialize past_key_values with `None` if past does not exist
441 | if past_key_values is None:
442 | past_key_values = [None] * len(self.block)
443 |
444 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
445 | # ourselves in which case we just need to make it broadcastable to all heads.
446 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
447 |
448 | # If a 2D or 3D attention mask is provided for the cross-attention
449 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
450 | if self.is_decoder and encoder_hidden_states is not None:
451 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
452 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
453 | if encoder_attention_mask is None:
454 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
455 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
456 | else:
457 | encoder_extended_attention_mask = None
458 |
459 | # Prepare head mask if needed
460 | head_mask = self.get_head_mask(head_mask, self.config.num_layers)
461 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
462 | present_key_value_states = () if use_cache else None
463 | all_hidden_states = () if output_hidden_states else None
464 | all_attentions = () if output_attentions else None
465 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None
466 | position_bias = None
467 | encoder_decoder_position_bias = None
468 |
469 | hidden_states = self.dropout(inputs_embeds)
470 |
471 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
472 | layer_head_mask = head_mask[i]
473 | cross_attn_layer_head_mask = cross_attn_head_mask[i]
474 | # Model parallel
475 | if self.model_parallel:
476 | torch.cuda.set_device(hidden_states.device)
477 | # Ensure that attention_mask is always on the same device as hidden_states
478 | if attention_mask is not None:
479 | attention_mask = attention_mask.to(hidden_states.device)
480 | if position_bias is not None:
481 | position_bias = position_bias.to(hidden_states.device)
482 | if encoder_hidden_states is not None:
483 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
484 | if encoder_extended_attention_mask is not None:
485 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
486 | if encoder_decoder_position_bias is not None:
487 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
488 | if layer_head_mask is not None:
489 | layer_head_mask = layer_head_mask.to(hidden_states.device)
490 | if cross_attn_layer_head_mask is not None:
491 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
492 | if output_hidden_states:
493 | all_hidden_states = all_hidden_states + (hidden_states,)
494 |
495 | if self.gradient_checkpointing and self.training:
496 | if use_cache:
497 | logger.warning(
498 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
499 | )
500 | use_cache = False
501 |
502 | def create_custom_forward(module):
503 | def custom_forward(*inputs, **kwd):
504 | return tuple(module(*inputs, **kwd, use_cache=use_cache, output_attentions=output_attentions))
505 |
506 | return custom_forward
507 |
508 | layer_outputs = checkpoint(
509 | create_custom_forward(layer_module),
510 | hidden_states,
511 | extended_attention_mask,
512 | position_bias,
513 | encoder_hidden_states,
514 | encoder_extended_attention_mask,
515 | encoder_decoder_position_bias,
516 | layer_head_mask,
517 | cross_attn_layer_head_mask,
518 | None, # past_key_value is always None with gradient checkpointing
519 | input_relations=input_relations
520 | )
521 | else:
522 | layer_outputs = layer_module(
523 | hidden_states,
524 | attention_mask=extended_attention_mask,
525 | position_bias=position_bias,
526 | encoder_hidden_states=encoder_hidden_states,
527 | encoder_attention_mask=encoder_extended_attention_mask,
528 | encoder_decoder_position_bias=encoder_decoder_position_bias,
529 | layer_head_mask=layer_head_mask,
530 | cross_attn_layer_head_mask=cross_attn_layer_head_mask,
531 | past_key_value=past_key_value,
532 | use_cache=use_cache,
533 | output_attentions=output_attentions,
534 | input_relations=input_relations
535 | )
536 |
537 | # layer_outputs is a tuple with:
538 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
539 | if use_cache is False:
540 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
541 |
542 | hidden_states, present_key_value_state = layer_outputs[:2]
543 |
544 |             # We share the position biases between the layers - the first layer stores them
545 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
546 | # (cross-attention position bias), (cross-attention weights)
547 | position_bias = layer_outputs[2]
548 | if self.is_decoder and encoder_hidden_states is not None:
549 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
550 | # append next layer key value states
551 | if use_cache:
552 | present_key_value_states = present_key_value_states + (present_key_value_state,)
553 |
554 | if output_attentions:
555 | all_attentions = all_attentions + (layer_outputs[3],)
556 | if self.is_decoder:
557 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
558 |
559 | # Model Parallel: If it's the last layer for that device, put things on the next device
560 | if self.model_parallel:
561 | for k, v in self.device_map.items():
562 | if i == v[-1] and "cuda:" + str(k) != self.last_device:
563 | hidden_states = hidden_states.to("cuda:" + str(k + 1))
564 |
565 | hidden_states = self.final_layer_norm(hidden_states)
566 | hidden_states = self.dropout(hidden_states)
567 |
568 | # Add last layer
569 | if output_hidden_states:
570 | all_hidden_states = all_hidden_states + (hidden_states,)
571 |
572 | if not return_dict:
573 | return tuple(
574 | v
575 | for v in [
576 | hidden_states,
577 | present_key_value_states,
578 | all_hidden_states,
579 | all_attentions,
580 | all_cross_attentions,
581 | ]
582 | if v is not None
583 | )
584 | return BaseModelOutputWithPastAndCrossAttentions(
585 | last_hidden_state=hidden_states,
586 | past_key_values=present_key_value_states,
587 | hidden_states=all_hidden_states,
588 | attentions=all_attentions,
589 | cross_attentions=all_cross_attentions,
590 | )
591 |
592 |
593 | class RelationalT5Block(T5Block):
594 | def forward(
595 | self,
596 | hidden_states,
597 | attention_mask=None,
598 | position_bias=None,
599 | encoder_hidden_states=None,
600 | encoder_attention_mask=None,
601 | encoder_decoder_position_bias=None,
602 | layer_head_mask=None,
603 | cross_attn_layer_head_mask=None,
604 | past_key_value=None,
605 | use_cache=False,
606 | output_attentions=False,
607 | return_dict=True,
608 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
609 | ):
610 |
611 | if past_key_value is not None:
612 | if not self.is_decoder:
613 | logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
614 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
615 |
616 | if len(past_key_value) != expected_num_past_key_values:
617 | raise ValueError(
618 | f"There should be {expected_num_past_key_values} past states. "
619 | f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
620 | f"Got {len(past_key_value)} past key / value states"
621 | )
622 |
623 | self_attn_past_key_value = past_key_value[:2]
624 | cross_attn_past_key_value = past_key_value[2:]
625 | else:
626 | self_attn_past_key_value, cross_attn_past_key_value = None, None
627 |
628 | self_attention_outputs = self.layer[0](
629 | hidden_states,
630 | attention_mask=attention_mask,
631 | position_bias=position_bias,
632 | layer_head_mask=layer_head_mask,
633 | past_key_value=self_attn_past_key_value,
634 | use_cache=use_cache,
635 | output_attentions=output_attentions,
636 | input_relations=input_relations
637 | )
638 | hidden_states, present_key_value_state = self_attention_outputs[:2]
639 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
640 |
641 | # clamp inf values to enable fp16 training
642 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
643 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
644 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
645 |
646 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None
647 | if do_cross_attention:
648 | # the actual query length is unknown for cross attention
649 | # if using past key value states. Need to inject it here
650 | if present_key_value_state is not None:
651 | query_length = present_key_value_state[0].shape[2]
652 | else:
653 | query_length = None
654 |
655 | cross_attention_outputs = self.layer[1](
656 | hidden_states,
657 | key_value_states=encoder_hidden_states,
658 | attention_mask=encoder_attention_mask,
659 | position_bias=encoder_decoder_position_bias,
660 | layer_head_mask=cross_attn_layer_head_mask,
661 | past_key_value=cross_attn_past_key_value,
662 | query_length=query_length,
663 | use_cache=use_cache,
664 | output_attentions=output_attentions,
665 | input_relations=input_relations
666 | )
667 | hidden_states = cross_attention_outputs[0]
668 |
669 | # clamp inf values to enable fp16 training
670 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
671 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
672 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
673 |
674 | # Combine self attn and cross attn key value states
675 | if present_key_value_state is not None:
676 | present_key_value_state = present_key_value_state + cross_attention_outputs[1]
677 |
678 | # Keep cross-attention outputs and relative position weights
679 | attention_outputs = attention_outputs + cross_attention_outputs[2:]
680 |
681 | # Apply Feed Forward layer
682 | hidden_states = self.layer[-1](hidden_states)
683 |
684 | # clamp inf values to enable fp16 training
685 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
686 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
687 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
688 |
689 | outputs = (hidden_states,)
690 |
691 | if use_cache:
692 | outputs = outputs + (present_key_value_state,) + attention_outputs
693 | else:
694 | outputs = outputs + attention_outputs
695 |
696 | return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
697 |
698 |
699 | class RelationalT5LayerSelfAttention(T5LayerSelfAttention):
700 | def forward(
701 | self,
702 | hidden_states,
703 | attention_mask=None,
704 | position_bias=None,
705 | layer_head_mask=None,
706 | past_key_value=None,
707 | use_cache=False,
708 | output_attentions=False,
709 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
710 | ):
711 | normed_hidden_states = self.layer_norm(hidden_states)
712 | kwargs = {'input_relations': input_relations} \
713 | if 'input_relations' in inspect.getfullargspec(self.SelfAttention.forward)[0] else {}
714 | attention_output = self.SelfAttention(
715 | normed_hidden_states,
716 | mask=attention_mask,
717 | position_bias=position_bias,
718 | layer_head_mask=layer_head_mask,
719 | past_key_value=past_key_value,
720 | use_cache=use_cache,
721 | output_attentions=output_attentions,
722 | **kwargs
723 | )
724 | hidden_states = hidden_states + self.dropout(attention_output[0])
725 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
726 | return outputs
727 |
728 |
729 | class RelationalT5LayerCrossAttention(T5LayerCrossAttention):
730 | def forward(
731 | self,
732 | hidden_states,
733 | key_value_states,
734 | attention_mask=None,
735 | position_bias=None,
736 | layer_head_mask=None,
737 | past_key_value=None,
738 | use_cache=False,
739 | query_length=None,
740 | output_attentions=False,
741 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
742 | ):
743 | normed_hidden_states = self.layer_norm(hidden_states)
744 | kwargs = {'input_relations': input_relations} \
745 | if 'input_relations' in inspect.getfullargspec(self.EncDecAttention.forward)[0] else {}
746 | attention_output = self.EncDecAttention(
747 | normed_hidden_states,
748 | mask=attention_mask,
749 | key_value_states=key_value_states,
750 | position_bias=position_bias,
751 | layer_head_mask=layer_head_mask,
752 | past_key_value=past_key_value,
753 | use_cache=use_cache,
754 | query_length=query_length,
755 | output_attentions=output_attentions,
756 | **kwargs
757 | )
758 | layer_output = hidden_states + self.dropout(attention_output[0])
759 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
760 | return outputs
761 |
762 |
763 | class RelationalT5Attention(T5Attention):
764 | def forward(
765 | self,
766 | hidden_states,
767 | mask=None,
768 | key_value_states=None,
769 | position_bias=None,
770 | past_key_value=None,
771 | layer_head_mask=None,
772 | query_length=None,
773 | use_cache=False,
774 | output_attentions=False,
775 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
776 | ):
777 | """
778 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
779 | """
780 |
781 | # Input is (batch_size, seq_length, dim)
782 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
783 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
784 | batch_size, seq_length = hidden_states.shape[:2]
785 |
786 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model"
787 | assert input_relations.shape == (batch_size, seq_length, seq_length)
788 |
789 |         # (batch_size, seq_length, seq_length, self.inner_dim // self.n_heads)
790 | relation_k_embeds = self.relation_k_emb(input_relations)
791 | relation_v_embeds = self.relation_v_emb(input_relations)
792 |
793 | real_seq_length = seq_length
794 |
795 | if past_key_value is not None:
796 | assert (
797 | len(past_key_value) == 2
798 |             ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
799 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
800 |
801 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
802 |
803 | def shape(states):
804 | """projection"""
805 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
806 |
807 | def unshape(states):
808 | """reshape"""
809 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
810 |
811 | def project(hidden_states, proj_layer, key_value_states, past_key_value):
812 | """projects hidden states correctly to key/query states"""
813 | if key_value_states is None:
814 | # self-attn
815 | # (batch_size, n_heads, seq_length, dim_per_head)
816 | hidden_states = shape(proj_layer(hidden_states))
817 | elif past_key_value is None:
818 | # cross-attn
819 | # (batch_size, n_heads, seq_length, dim_per_head)
820 | hidden_states = shape(proj_layer(key_value_states))
821 |
822 | if past_key_value is not None:
823 | if key_value_states is None:
824 | # self-attn
825 | # (batch_size, n_heads, key_length, dim_per_head)
826 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
827 | else:
828 | # cross-attn
829 | hidden_states = past_key_value
830 | return hidden_states
831 |
832 | # get query states
833 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, key_value_proj_dim)
834 |
835 | # get key/value states
836 | key_states = project(
837 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
838 | )
839 | value_states = project(
840 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
841 | )
842 |
843 | # compute scores
844 | scores = torch.matmul(
845 | query_states, key_states.transpose(3, 2)
846 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
847 |
848 | # q_t is [batch, seq_length, n_heads, key_value_proj_dim]
849 | q_t = query_states.permute(0, 2, 1, 3)
850 |
851 | # r_t is [batch, seq_length, key_value_proj_dim, seq_length]
852 | r_t = relation_k_embeds.transpose(-2, -1)
853 |
854 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
855 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
856 |
857 | # Add to scores
858 | scores += q_tr_tmatmul_t
859 |
860 | if position_bias is None:
861 | if not self.has_relative_attention_bias:
862 | position_bias = torch.zeros(
863 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
864 | )
865 | if self.gradient_checkpointing and self.training:
866 | position_bias.requires_grad = True
867 | else:
868 | position_bias = self.compute_bias(real_seq_length, key_length)
869 |
870 | # if key and values are already calculated
871 | # we want only the last query position bias
872 | if past_key_value is not None:
873 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
874 |
875 | if mask is not None:
876 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
877 |
878 | if self.pruned_heads:
879 | mask = torch.ones(position_bias.shape[1])
880 | mask[list(self.pruned_heads)] = 0
881 | position_bias_masked = position_bias[:, mask.bool()]
882 | else:
883 | position_bias_masked = position_bias
884 |
885 | scores += position_bias_masked
886 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
887 | scores
888 | ) # (batch_size, n_heads, seq_length, key_length)
889 | attn_weights = nn.functional.dropout(
890 | attn_weights, p=self.dropout, training=self.training
891 | ) # (batch_size, n_heads, seq_length, key_length)
892 |
893 | # Mask heads if we want to
894 | if layer_head_mask is not None:
895 | attn_weights = attn_weights * layer_head_mask
896 |
897 |         # [batch, n_heads, seq_length, key_value_proj_dim]
898 | wv_matmul = torch.matmul(attn_weights, value_states)
899 |
900 | # w_t is [batch, seq_length, n_heads, seq_length]
901 | w_t = attn_weights.permute(0, 2, 1, 3)
902 |
903 |         # [batch, seq_length, n_heads, key_value_proj_dim]
904 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
905 |
906 | attn_output = unshape(wv_matmul + w_tr_matmul.permute(0, 2, 1, 3)) # (batch_size, seq_length, dim)
907 | attn_output = self.o(attn_output)
908 |
909 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
910 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
911 |
912 | if output_attentions:
913 | outputs = outputs + (attn_weights,)
914 | return outputs
915 |
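
# Illustrative sketch (hypothetical names, not used at runtime): the relation-aware
# update implemented in RelationalT5Attention.forward above, reduced to plain tensors.
# Relation ids of shape (batch, seq, seq) are embedded per token pair and added once
# to the attention scores (through the queries) and once to the attention output
# (through the attention weights); position bias, masking and dropout are omitted.
def _relational_attention_sketch():
    import torch
    import torch.nn as nn

    batch, seq, n_heads, d_head, n_rel = 2, 5, 4, 8, 3
    relation_k_emb = nn.Embedding(n_rel, d_head)
    relation_v_emb = nn.Embedding(n_rel, d_head)

    q = torch.randn(batch, n_heads, seq, d_head)                  # query states
    k = torch.randn(batch, n_heads, seq, d_head)                  # key states
    v = torch.randn(batch, n_heads, seq, d_head)                  # value states
    input_relations = torch.randint(0, n_rel, (batch, seq, seq))  # relation type ids

    scores = torch.matmul(q, k.transpose(-1, -2))                 # (batch, n_heads, seq, seq)
    rel_k = relation_k_emb(input_relations)                       # (batch, seq, seq, d_head)
    q_t = q.permute(0, 2, 1, 3)                                   # (batch, seq, n_heads, d_head)
    scores = scores + torch.matmul(q_t, rel_k.transpose(-2, -1)).permute(0, 2, 1, 3)

    attn = scores.softmax(dim=-1)                                 # (batch, n_heads, seq, seq)
    rel_v = relation_v_emb(input_relations)                       # (batch, seq, seq, d_head)
    w_t = attn.permute(0, 2, 1, 3)                                # (batch, seq, n_heads, seq)
    return torch.matmul(attn, v) + torch.matmul(w_t, rel_v).permute(0, 2, 1, 3)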
--------------------------------------------------------------------------------
/src/ratransformers/longt5.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from types import MethodType
4 | from typing import Optional, Tuple, Union, Any
5 |
6 | from torch.nn import CrossEntropyLoss
7 | from torch.utils.checkpoint import checkpoint
8 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqModelOutput, \
9 | BaseModelOutput, Seq2SeqLMOutput
10 | from transformers.models.longt5.modeling_longt5 import LongT5Attention, LongT5LayerSelfAttention, \
11 | LongT5LayerCrossAttention, LongT5Block, \
12 | logger, LongT5Stack, LongT5Model, _CONFIG_FOR_DOC, LONGT5_INPUTS_DOCSTRING, _get_local_attention_mask, \
13 | LongT5LocalAttention, _split_into_blocks, _concatenate_3_blocks, _pad_to_multiple, LongT5TransientGlobalAttention, \
14 | _make_global_fixed_block_ids, _create_global_aggregates, LongT5LayerTransientGlobalSelfAttention, \
15 | LongT5LayerLocalSelfAttention, LongT5PreTrainedModel, LongT5ForConditionalGeneration, LongT5EncoderModel, \
16 | LONGT5_ENCODER_INPUTS_DOCSTRING
17 | import torch.nn as nn
18 | import torch
19 | from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
20 |
21 |
22 | def change_longt5_module_and_get_relational_emb_dim(module_name: str, module: nn.Module) -> Optional[int]:
23 | relational_embedding_dim = None
24 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5Attention):
25 | module.forward = MethodType(RelationalLongT5Attention.forward, module)
26 | relational_embedding_dim = module.inner_dim // module.n_heads
27 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5LocalAttention):
28 | module.forward = MethodType(RelationalLongT5LocalAttention.forward, module)
29 | relational_embedding_dim = module.inner_dim // module.n_heads
30 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5TransientGlobalAttention):
31 | module.forward = MethodType(RelationalLongT5TransientGlobalAttention.forward, module)
32 | relational_embedding_dim = module.inner_dim // module.n_heads
33 | elif isinstance(module, LongT5Stack):
34 | module.forward = MethodType(RelationalLongT5Stack.forward, module)
35 | elif isinstance(module, LongT5ForConditionalGeneration):
36 | module.forward = MethodType(RelationalLongT5ForConditionalGeneration.forward, module)
37 | elif isinstance(module, LongT5Model):
38 | module.forward = MethodType(RelationalLongT5Model.forward, module)
39 | elif isinstance(module, LongT5EncoderModel):
40 | module.forward = MethodType(RelationalLongT5EncoderModel.forward, module)
41 | elif isinstance(module, LongT5Block):
42 | module.forward = MethodType(RelationalLongT5Block.forward, module)
43 | elif isinstance(module, LongT5LayerCrossAttention):
44 | module.forward = MethodType(RelationalLongT5LayerCrossAttention.forward, module)
45 | elif isinstance(module, LongT5LayerSelfAttention):
46 | module.forward = MethodType(RelationalLongT5LayerSelfAttention.forward, module)
47 | elif isinstance(module, LongT5LayerLocalSelfAttention):
48 | module.forward = MethodType(RelationalLongT5LayerLocalSelfAttention.forward, module)
49 | elif isinstance(module, LongT5LayerTransientGlobalSelfAttention):
50 | module.forward = MethodType(RelationalLongT5LayerTransientGlobalSelfAttention.forward, module)
51 | return relational_embedding_dim
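

# Illustrative sketch (hypothetical names, not part of this module's API): the
# MethodType rebinding above swaps only the bound `forward` of an existing,
# already-initialised module, so pretrained weights stay in place while the
# replacement forward can accept extra arguments such as `input_relations`.
def _methodtype_patch_sketch():
    from types import MethodType

    import torch
    import torch.nn as nn

    class ScaledLinear(nn.Linear):
        # replacement forward with an extra keyword argument; it reuses the
        # patched instance's own weight and bias
        def forward(self, x, scale=1.0):
            return nn.functional.linear(x, self.weight, self.bias) * scale

    layer = nn.Linear(4, 4)                                   # "pretrained" module
    layer.forward = MethodType(ScaledLinear.forward, layer)   # rebind in place
    return layer(torch.randn(2, 4), scale=2.0)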
52 |
53 |
54 | class RelationalLongT5ForConditionalGeneration(LongT5ForConditionalGeneration):
55 | @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
56 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
57 | def forward(
58 | self,
59 | input_ids: Optional[torch.LongTensor] = None,
60 | attention_mask: Optional[torch.FloatTensor] = None,
61 | decoder_input_ids: Optional[torch.LongTensor] = None,
62 | decoder_attention_mask: Optional[torch.BoolTensor] = None,
63 | head_mask: Optional[torch.FloatTensor] = None,
64 | decoder_head_mask: Optional[torch.FloatTensor] = None,
65 | cross_attn_head_mask: Optional[torch.Tensor] = None,
66 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
67 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
68 | inputs_embeds: Optional[torch.FloatTensor] = None,
69 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
70 | labels: Optional[torch.LongTensor] = None,
71 | use_cache: Optional[bool] = None,
72 | output_attentions: Optional[bool] = None,
73 | output_hidden_states: Optional[bool] = None,
74 | return_dict: Optional[bool] = None,
75 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
76 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
77 | r"""
78 |         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
79 |             Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
80 |             config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
81 |             labels in `[0, ..., config.vocab_size - 1]`.
82 |
83 | Returns:
84 |
85 | Examples:
86 |
87 | ```python
88 | >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
89 |
90 | >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
91 | >>> model = LongT5ForConditionalGeneration.from_pretrained(
92 | ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
93 | ... )
94 |
95 | >>> # Let's try a very long input.
96 | >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
97 | >>> input_ids = inputs.input_ids
98 |
99 | >>> outputs = model.generate(input_ids)
100 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
101 | abstractthe aim of this article is to provide an overview of the literature on the role of dog
102 | ```"""
103 | use_cache = use_cache if use_cache is not None else self.config.use_cache
104 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
105 |
106 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
107 | if head_mask is not None and decoder_head_mask is None:
108 | if self.config.num_layers == self.config.num_decoder_layers:
109 | warnings.warn(
110 | """
111 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
112 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
113 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
114 | num_heads)`.
115 | """
116 | , FutureWarning
117 | )
118 | decoder_head_mask = head_mask
119 |
120 | # Encode if needed (training, first prediction pass)
121 | if encoder_outputs is None:
122 | # Convert encoder inputs in embeddings if needed
123 | encoder_outputs = self.encoder(
124 | input_ids=input_ids,
125 | attention_mask=attention_mask,
126 | inputs_embeds=inputs_embeds,
127 | head_mask=head_mask,
128 | output_attentions=output_attentions,
129 | output_hidden_states=output_hidden_states,
130 | return_dict=return_dict,
131 | input_relations=input_relations
132 | )
133 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
134 | encoder_outputs = BaseModelOutput(
135 | last_hidden_state=encoder_outputs[0],
136 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
137 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
138 | )
139 |
140 | hidden_states = encoder_outputs[0]
141 |
142 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
143 | # get decoder inputs from shifting lm labels to the right
144 | decoder_input_ids = self._shift_right(labels)
145 |
146 | # Decode
147 | decoder_outputs = self.decoder(
148 | input_ids=decoder_input_ids,
149 | attention_mask=decoder_attention_mask,
150 | inputs_embeds=decoder_inputs_embeds,
151 | past_key_values=past_key_values,
152 | encoder_hidden_states=hidden_states,
153 | encoder_attention_mask=attention_mask,
154 | head_mask=decoder_head_mask,
155 | cross_attn_head_mask=cross_attn_head_mask,
156 | use_cache=use_cache,
157 | output_attentions=output_attentions,
158 | output_hidden_states=output_hidden_states,
159 | return_dict=return_dict,
160 | input_relations=input_relations
161 | )
162 |
163 | sequence_output = decoder_outputs[0]
164 |
165 | if self.config.tie_word_embeddings:
166 | # Rescale output before projecting on vocab
167 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
168 | sequence_output = sequence_output * (self.model_dim ** -0.5)
169 |
170 | lm_logits = self.lm_head(sequence_output)
171 |
172 | loss = None
173 | if labels is not None:
174 | loss_fct = CrossEntropyLoss(ignore_index=-100)
175 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
176 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
177 |
178 | if not return_dict:
179 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
180 | return ((loss,) + output) if loss is not None else output
181 |
182 | return Seq2SeqLMOutput(
183 | loss=loss,
184 | logits=lm_logits,
185 | past_key_values=decoder_outputs.past_key_values,
186 | decoder_hidden_states=decoder_outputs.hidden_states,
187 | decoder_attentions=decoder_outputs.attentions,
188 | cross_attentions=decoder_outputs.cross_attentions,
189 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
190 | encoder_hidden_states=encoder_outputs.hidden_states,
191 | encoder_attentions=encoder_outputs.attentions,
192 | )
193 |
194 | class RelationalLongT5Model(LongT5Model):
195 | @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
196 | @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
197 | def forward(
198 | self,
199 | input_ids: Optional[torch.LongTensor] = None,
200 | attention_mask: Optional[torch.FloatTensor] = None,
201 | decoder_input_ids: Optional[torch.LongTensor] = None,
202 | decoder_attention_mask: Optional[torch.BoolTensor] = None,
203 | head_mask: Optional[torch.FloatTensor] = None,
204 | decoder_head_mask: Optional[torch.FloatTensor] = None,
205 | cross_attn_head_mask: Optional[torch.Tensor] = None,
206 | encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
207 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
208 | inputs_embeds: Optional[torch.Tensor] = None,
209 | decoder_inputs_embeds: Optional[torch.Tensor] = None,
210 | use_cache: Optional[bool] = None,
211 | output_attentions: Optional[bool] = None,
212 | output_hidden_states: Optional[bool] = None,
213 | return_dict: Optional[bool] = None,
214 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
215 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
216 | r"""
217 | Returns:
218 |
219 | Example:
220 |
221 | ```python
222 | >>> from transformers import T5Tokenizer, LongT5Model
223 |
224 | >>> tokenizer = T5Tokenizer.from_pretrained("google/long-t5-local-base")
225 | >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")
226 |
227 | >>> # Let's try a very long encoder input.
228 | >>> input_ids = tokenizer(
229 | ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
230 | ... ).input_ids # Batch size 1
231 |
232 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
233 |
234 | >>> # forward pass
235 | >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
236 | >>> last_hidden_states = outputs.last_hidden_state
237 | ```"""
238 | use_cache = use_cache if use_cache is not None else self.config.use_cache
239 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240 |
241 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
242 | if head_mask is not None and decoder_head_mask is None:
243 | if self.config.num_layers == self.config.num_decoder_layers:
244 | warnings.warn(
245 | """
246 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
247 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
248 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
249 | num_heads)`.
250 | """
251 | , FutureWarning
252 | )
253 | decoder_head_mask = head_mask
254 |
255 | # Encode if needed (training, first prediction pass)
256 | if encoder_outputs is None:
257 | encoder_outputs = self.encoder(
258 | input_ids=input_ids,
259 | attention_mask=attention_mask,
260 | inputs_embeds=inputs_embeds,
261 | head_mask=head_mask,
262 | output_attentions=output_attentions,
263 | output_hidden_states=output_hidden_states,
264 | return_dict=return_dict,
265 | input_relations=input_relations
266 | )
267 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
268 | encoder_outputs = BaseModelOutput(
269 | last_hidden_state=encoder_outputs[0],
270 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
271 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
272 | )
273 |
274 | hidden_states = encoder_outputs[0]
275 |
276 | # Decode
277 | decoder_outputs = self.decoder(
278 | input_ids=decoder_input_ids,
279 | attention_mask=decoder_attention_mask,
280 | inputs_embeds=decoder_inputs_embeds,
281 | past_key_values=past_key_values,
282 | encoder_hidden_states=hidden_states,
283 | encoder_attention_mask=attention_mask,
284 | head_mask=decoder_head_mask,
285 | cross_attn_head_mask=cross_attn_head_mask,
286 | use_cache=use_cache,
287 | output_attentions=output_attentions,
288 | output_hidden_states=output_hidden_states,
289 | return_dict=return_dict,
290 | input_relations=input_relations
291 | )
292 |
293 | if not return_dict:
294 | return decoder_outputs + encoder_outputs
295 |
296 | return Seq2SeqModelOutput(
297 | last_hidden_state=decoder_outputs.last_hidden_state,
298 | past_key_values=decoder_outputs.past_key_values,
299 | decoder_hidden_states=decoder_outputs.hidden_states,
300 | decoder_attentions=decoder_outputs.attentions,
301 | cross_attentions=decoder_outputs.cross_attentions,
302 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
303 | encoder_hidden_states=encoder_outputs.hidden_states,
304 | encoder_attentions=encoder_outputs.attentions,
305 | )
306 |
307 |
308 | class RelationalLongT5EncoderModel(LongT5EncoderModel):
309 | @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING)
310 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
311 | def forward(
312 | self,
313 | input_ids: Optional[torch.LongTensor] = None,
314 | attention_mask: Optional[torch.FloatTensor] = None,
315 | head_mask: Optional[torch.FloatTensor] = None,
316 | inputs_embeds: Optional[torch.FloatTensor] = None,
317 | output_attentions: Optional[bool] = None,
318 | output_hidden_states: Optional[bool] = None,
319 | return_dict: Optional[bool] = None,
320 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
321 | ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
322 | r"""
323 | Returns:
324 |
325 | Example:
326 |
327 | ```python
328 |         >>> from transformers import AutoTokenizer, LongT5EncoderModel
329 |
330 | >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
331 | >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
332 | >>> input_ids = tokenizer(
333 | ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
334 | ... ).input_ids # Batch size 1
335 | >>> outputs = model(input_ids=input_ids)
336 | >>> last_hidden_states = outputs.last_hidden_state
337 | ```"""
338 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
339 |
340 | encoder_outputs = self.encoder(
341 | input_ids=input_ids,
342 | attention_mask=attention_mask,
343 | inputs_embeds=inputs_embeds,
344 | head_mask=head_mask,
345 | output_attentions=output_attentions,
346 | output_hidden_states=output_hidden_states,
347 | return_dict=return_dict,
348 | input_relations=input_relations
349 | )
350 |
351 | return encoder_outputs
352 |
353 | class RelationalLongT5Stack(LongT5Stack):
354 | def forward(
355 | self,
356 | input_ids=None,
357 | attention_mask=None,
358 | encoder_hidden_states=None,
359 | encoder_attention_mask=None,
360 | inputs_embeds=None,
361 | head_mask=None,
362 | cross_attn_head_mask=None,
363 | past_key_values=None,
364 | use_cache=None,
365 | output_attentions=None,
366 | output_hidden_states=None,
367 | return_dict=None,
368 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
369 | ):
370 | use_cache = use_cache if use_cache is not None else self.config.use_cache
371 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
372 | output_hidden_states = (
373 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
374 | )
375 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
376 |
377 | if input_ids is not None and inputs_embeds is not None:
378 | err_msg_prefix = "decoder_" if self.is_decoder else ""
379 | raise ValueError(
380 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
381 | )
382 | elif input_ids is not None:
383 | input_shape = input_ids.size()
384 | input_ids = input_ids.view(-1, input_shape[-1])
385 | elif inputs_embeds is not None:
386 | input_shape = inputs_embeds.size()[:-1]
387 | else:
388 | err_msg_prefix = "decoder_" if self.is_decoder else ""
389 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
390 |
391 | if inputs_embeds is None:
392 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
393 | inputs_embeds = self.embed_tokens(input_ids)
394 |
395 | batch_size, seq_length = input_shape
396 |
397 | # required mask seq length can be calculated via length of past
398 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
399 |
400 | if use_cache is True:
401 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
402 |
403 | if attention_mask is None:
404 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
405 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
406 | encoder_seq_length = encoder_hidden_states.shape[1]
407 | encoder_attention_mask = torch.ones(
408 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
409 | )
410 |
411 | # initialize past_key_values with `None` if past does not exist
412 | if past_key_values is None:
413 | past_key_values = [None] * len(self.block)
414 |
415 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
416 | # ourselves in which case we just need to make it broadcastable to all heads.
417 | # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
418 | if self.is_decoder:
419 | extended_attention_mask = self.get_extended_attention_mask(
420 | attention_mask, input_shape, inputs_embeds.device
421 | )
422 | elif self.config.encoder_attention_type == "local":
423 | extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
424 | else: # we need to use both local attention mask and standard extended mask for transient-global attention
425 | extended_attention_mask = attention_mask
426 |
427 | # If a 2D or 3D attention mask is provided for the cross-attention
428 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
429 | if self.is_decoder and encoder_hidden_states is not None:
430 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
431 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
432 | if encoder_attention_mask is None:
433 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
434 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
435 | else:
436 | encoder_extended_attention_mask = None
437 |
438 | # Prepare head mask if needed
439 | head_mask = self.get_head_mask(head_mask, self.config.num_layers)
440 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
441 | present_key_value_states = () if use_cache else None
442 | all_hidden_states = () if output_hidden_states else None
443 | all_attentions = () if output_attentions else None
444 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None
445 | position_bias = None
446 | encoder_decoder_position_bias = None
447 |
448 | hidden_states = self.dropout(inputs_embeds)
449 |
450 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
451 | layer_head_mask = head_mask[i]
452 | cross_attn_layer_head_mask = cross_attn_head_mask[i]
453 |
454 | if output_hidden_states:
455 | all_hidden_states = all_hidden_states + (hidden_states,)
456 |
457 | if self.gradient_checkpointing and self.training:
458 | if use_cache:
459 | use_cache = False
460 |
461 | def create_custom_forward(module):
462 | def custom_forward(*inputs, **kwd):
463 | return tuple(module(*inputs, **kwd, use_cache=use_cache, output_attentions=output_attentions))
464 |
465 | return custom_forward
466 |
467 | layer_outputs = checkpoint(
468 | create_custom_forward(layer_module),
469 | hidden_states,
470 | extended_attention_mask,
471 | position_bias,
472 | encoder_hidden_states,
473 | encoder_extended_attention_mask,
474 | encoder_decoder_position_bias,
475 | layer_head_mask,
476 | cross_attn_layer_head_mask,
477 | None, # past_key_value is always None with gradient checkpointing
478 | input_relations=input_relations
479 | )
480 | else:
481 | layer_outputs = layer_module(
482 | hidden_states,
483 | attention_mask=extended_attention_mask,
484 | position_bias=position_bias,
485 | encoder_hidden_states=encoder_hidden_states,
486 | encoder_attention_mask=encoder_extended_attention_mask,
487 | encoder_decoder_position_bias=encoder_decoder_position_bias,
488 | layer_head_mask=layer_head_mask,
489 | cross_attn_layer_head_mask=cross_attn_layer_head_mask,
490 | past_key_value=past_key_value,
491 | use_cache=use_cache,
492 | output_attentions=output_attentions,
493 | input_relations=input_relations
494 | )
495 |
496 | # layer_outputs is a tuple with:
497 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
498 | if use_cache is False:
499 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
500 |
501 | hidden_states, present_key_value_state = layer_outputs[:2]
502 |
503 |             # We share the position biases between the layers - the first layer stores them
504 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
505 | # (cross-attention position bias), (cross-attention weights)
506 | position_bias = layer_outputs[2]
507 | if self.is_decoder and encoder_hidden_states is not None:
508 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
509 | # append next layer key value states
510 | if use_cache:
511 | present_key_value_states = present_key_value_states + (present_key_value_state,)
512 |
513 | if output_attentions:
514 | all_attentions = all_attentions + (layer_outputs[3],)
515 | if self.is_decoder:
516 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
517 |
518 | hidden_states = self.final_layer_norm(hidden_states)
519 | hidden_states = self.dropout(hidden_states)
520 |
521 | # Add last layer
522 | if output_hidden_states:
523 | all_hidden_states = all_hidden_states + (hidden_states,)
524 |
525 | if not return_dict:
526 | return tuple(
527 | v
528 | for v in [
529 | hidden_states,
530 | present_key_value_states,
531 | all_hidden_states,
532 | all_attentions,
533 | all_cross_attentions,
534 | ]
535 | if v is not None
536 | )
537 | return BaseModelOutputWithPastAndCrossAttentions(
538 | last_hidden_state=hidden_states,
539 | past_key_values=present_key_value_states,
540 | hidden_states=all_hidden_states,
541 | attentions=all_attentions,
542 | cross_attentions=all_cross_attentions,
543 | )
544 |
545 |
546 | class RelationalLongT5Block(LongT5Block):
547 | def forward(
548 | self,
549 | hidden_states,
550 | attention_mask=None,
551 | position_bias=None,
552 | encoder_hidden_states=None,
553 | encoder_attention_mask=None,
554 | encoder_decoder_position_bias=None,
555 | layer_head_mask=None,
556 | cross_attn_layer_head_mask=None,
557 | past_key_value=None,
558 | use_cache=False,
559 | output_attentions=False,
560 | return_dict=True,
561 |         input_relations=None # will hold (batch, seq_length, seq_length), values in [0, num_relation_kinds)
562 | ):
563 |
564 | if past_key_value is not None:
565 | if not self.is_decoder:
566 | logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
567 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
568 |
569 | if len(past_key_value) != expected_num_past_key_values:
570 | raise ValueError(
571 | f"There should be {expected_num_past_key_values} past states. "
572 | f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
573 | f"Got {len(past_key_value)} past key / value states"
574 | )
575 |
576 | self_attn_past_key_value = past_key_value[:2]
577 | cross_attn_past_key_value = past_key_value[2:]
578 | else:
579 | self_attn_past_key_value, cross_attn_past_key_value = None, None
580 |
581 | self_attention_outputs = self.layer[0](
582 | hidden_states,
583 | attention_mask=attention_mask,
584 | position_bias=position_bias,
585 | layer_head_mask=layer_head_mask,
586 | past_key_value=self_attn_past_key_value,
587 | use_cache=use_cache,
588 | output_attentions=output_attentions,
589 | input_relations=input_relations
590 | )
591 | hidden_states, present_key_value_state = self_attention_outputs[:2]
592 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
593 |
594 | # clamp inf values to enable fp16 training
595 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
596 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
597 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
598 |
599 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None
600 | if do_cross_attention:
601 | # the actual query length is unknown for cross attention
602 | # if using past key value states. Need to inject it here
603 | if present_key_value_state is not None:
604 | query_length = present_key_value_state[0].shape[2]
605 | else:
606 | query_length = None
607 |
608 | cross_attention_outputs = self.layer[1](
609 | hidden_states,
610 | key_value_states=encoder_hidden_states,
611 | attention_mask=encoder_attention_mask,
612 | position_bias=encoder_decoder_position_bias,
613 | layer_head_mask=cross_attn_layer_head_mask,
614 | past_key_value=cross_attn_past_key_value,
615 | query_length=query_length,
616 | use_cache=use_cache,
617 | output_attentions=output_attentions,
618 | input_relations=input_relations
619 | )
620 | hidden_states = cross_attention_outputs[0]
621 |
622 | # clamp inf values to enable fp16 training
623 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
624 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
625 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
626 |
627 | # Combine self attn and cross attn key value states
628 | if present_key_value_state is not None:
629 | present_key_value_state = present_key_value_state + cross_attention_outputs[1]
630 |
631 | # Keep cross-attention outputs and relative position weights
632 | attention_outputs = attention_outputs + cross_attention_outputs[2:]
633 |
634 | # Apply Feed Forward layer
635 | hidden_states = self.layer[-1](hidden_states)
636 |
637 | # clamp inf values to enable fp16 training
638 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
639 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
640 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
641 |
642 | outputs = (hidden_states,)
643 |
644 | if use_cache:
645 | outputs = outputs + (present_key_value_state,) + attention_outputs
646 | else:
647 | outputs = outputs + attention_outputs
648 |
649 | return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
650 |
651 |
652 | class RelationalLongT5LayerSelfAttention(LongT5LayerSelfAttention):
653 | def forward(
654 | self,
655 | hidden_states,
656 | attention_mask=None,
657 | position_bias=None,
658 | layer_head_mask=None,
659 | past_key_value=None,
660 | use_cache=False,
661 | output_attentions=False,
662 |         input_relations=None # will hold relation-kind ids of shape (batch, seq_length, seq_length)
663 | ):
664 | normed_hidden_states = self.layer_norm(hidden_states)
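    |         # Forward `input_relations` only when the wrapped attention module accepts it in its
    |         # forward signature (i.e. it was swapped for a relational attention class); otherwise
    |         # the vanilla LongT5 attention is called unchanged.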
665 | kwargs = {'input_relations': input_relations} \
666 | if 'input_relations' in inspect.getfullargspec(self.SelfAttention.forward)[0] else {}
667 | attention_output = self.SelfAttention(
668 | normed_hidden_states,
669 | mask=attention_mask,
670 | position_bias=position_bias,
671 | layer_head_mask=layer_head_mask,
672 | past_key_value=past_key_value,
673 | use_cache=use_cache,
674 | output_attentions=output_attentions,
675 | **kwargs
676 | )
677 | hidden_states = hidden_states + self.dropout(attention_output[0])
678 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
679 | return outputs
680 |
681 |
682 | class RelationalLongT5LayerCrossAttention(LongT5LayerCrossAttention):
683 | def forward(
684 | self,
685 | hidden_states,
686 | key_value_states,
687 | attention_mask=None,
688 | position_bias=None,
689 | layer_head_mask=None,
690 | past_key_value=None,
691 | use_cache=False,
692 | query_length=None,
693 | output_attentions=False,
694 |         input_relations=None # will hold relation-kind ids of shape (batch, seq_length, seq_length)
695 | ):
696 | normed_hidden_states = self.layer_norm(hidden_states)
697 | kwargs = {'input_relations': input_relations} \
698 | if 'input_relations' in inspect.getfullargspec(self.EncDecAttention.forward)[0] else {}
699 | attention_output = self.EncDecAttention(
700 | normed_hidden_states,
701 | mask=attention_mask,
702 | key_value_states=key_value_states,
703 | position_bias=position_bias,
704 | layer_head_mask=layer_head_mask,
705 | past_key_value=past_key_value,
706 | use_cache=use_cache,
707 | query_length=query_length,
708 | output_attentions=output_attentions,
709 | **kwargs
710 | )
711 | layer_output = hidden_states + self.dropout(attention_output[0])
712 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
713 | return outputs
714 |
715 |
716 | class RelationalLongT5LayerLocalSelfAttention(LongT5LayerLocalSelfAttention):
717 | def forward(
718 | self,
719 | hidden_states,
720 | attention_mask=None,
721 | position_bias=None,
722 | layer_head_mask=None,
723 | output_attentions=False,
724 |         input_relations=None, # will hold relation-kind ids of shape (batch, seq_length, seq_length)
725 | **kwargs: Any, # to accept past_key_value and use_cache kwargs
726 | ):
727 | normed_hidden_states = self.layer_norm(hidden_states)
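    |         # Rebuilding `kwargs` below intentionally drops the past_key_value/use_cache arguments
    |         # captured above: local attention keeps no key/value cache. `input_relations` is only
    |         # forwarded when the wrapped attention module accepts it.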
728 | kwargs = {'input_relations': input_relations} \
729 | if 'input_relations' in inspect.getfullargspec(self.LocalSelfAttention.forward)[0] else {}
730 | attention_output = self.LocalSelfAttention(
731 | normed_hidden_states,
732 | mask=attention_mask,
733 | position_bias=position_bias,
734 | layer_head_mask=layer_head_mask,
735 | output_attentions=output_attentions,
736 | **kwargs
737 | )
738 | hidden_states = hidden_states + self.dropout(attention_output[0])
739 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
740 | return outputs
741 |
742 |
743 | class RelationalLongT5LayerTransientGlobalSelfAttention(LongT5LayerTransientGlobalSelfAttention):
744 | def forward(
745 | self,
746 | hidden_states,
747 | attention_mask=None,
748 | position_bias=None,
749 | layer_head_mask=None,
750 | output_attentions=False,
751 |         input_relations=None, # will hold relation-kind ids of shape (batch, seq_length, seq_length)
752 | **kwargs: Any, # to accept past_key_value and use_cache kwargs
753 | ):
754 | normed_hidden_states = self.layer_norm(hidden_states)
755 | kwargs = {'input_relations': input_relations} \
756 | if 'input_relations' in inspect.getfullargspec(self.TransientGlobalSelfAttention.forward)[0] else {}
757 | attention_output = self.TransientGlobalSelfAttention(
758 | normed_hidden_states,
759 | mask=attention_mask,
760 | position_bias=position_bias,
761 | layer_head_mask=layer_head_mask,
762 | output_attentions=output_attentions,
763 | **kwargs
764 | )
765 | hidden_states = hidden_states + self.dropout(attention_output[0])
766 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
767 | return outputs
768 |
769 |
770 | class RelationalLongT5Attention(LongT5Attention):
771 | def forward(
772 | self,
773 | hidden_states,
774 | mask=None,
775 | key_value_states=None,
776 | position_bias=None,
777 | past_key_value=None,
778 | layer_head_mask=None,
779 | query_length=None,
780 | use_cache=False,
781 | output_attentions=False,
782 |         input_relations=None # will hold relation-kind ids of shape (batch, seq_length, seq_length)
783 | ):
784 | """
785 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
786 | """
787 |
788 | # Input is (batch_size, seq_length, dim)
789 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
790 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
791 | batch_size, seq_length = hidden_states.shape[:2]
792 |
793 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model"
794 | assert input_relations.shape == (batch_size, seq_length, seq_length)
795 |
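    |         # Relation-aware attention: each (query, key) token pair carries a relation id that is
    |         # embedded into a key-space vector (used below to bias the attention logits) and a
    |         # value-space vector (mixed into the attended output before the output projection).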
796 |         # (batch_size, seq_length, seq_length, dim_per_head)
797 | relation_k_embeds = self.relation_k_emb(input_relations)
798 | relation_v_embeds = self.relation_v_emb(input_relations)
799 |
800 | real_seq_length = seq_length
801 |
802 | if past_key_value is not None:
803 | assert (
804 | len(past_key_value) == 2
805 |             ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
806 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
807 |
808 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
809 |
810 | def shape(states):
811 | """projection"""
812 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
813 |
814 | def unshape(states):
815 | """reshape"""
816 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
817 |
818 | def project(hidden_states, proj_layer, key_value_states, past_key_value):
819 | """projects hidden states correctly to key/query states"""
820 | if key_value_states is None:
821 | # self-attn
822 | # (batch_size, n_heads, seq_length, dim_per_head)
823 | hidden_states = shape(proj_layer(hidden_states))
824 | elif past_key_value is None:
825 | # cross-attn
826 | # (batch_size, n_heads, seq_length, dim_per_head)
827 | hidden_states = shape(proj_layer(key_value_states))
828 |
829 | if past_key_value is not None:
830 | if key_value_states is None:
831 | # self-attn
832 | # (batch_size, n_heads, key_length, dim_per_head)
833 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
834 | else:
835 | # cross-attn
836 | hidden_states = past_key_value
837 | return hidden_states
838 |
839 | # get query states
840 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
841 |
842 | # get key/value states
843 | key_states = project(
844 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
845 | )
846 | value_states = project(
847 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
848 | )
849 |
850 | # compute scores
851 | scores = torch.matmul(
852 | query_states, key_states.transpose(3, 2)
853 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
854 |
855 | # q_t is [batch, seq_length, n_heads, dim_per_head]
856 | q_t = query_states.permute(0, 2, 1, 3)
857 |
858 | # r_t is [batch, seq_length, dim_per_head, seq_length]
859 | r_t = relation_k_embeds.transpose(-2, -1)
860 |
861 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
862 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
863 |
864 | # Add to scores
865 | scores += q_tr_tmatmul_t
866 |
867 | if position_bias is None:
868 | if not self.has_relative_attention_bias:
869 | position_bias = torch.zeros(
870 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
871 | )
872 | if self.gradient_checkpointing and self.training:
873 | position_bias.requires_grad = True
874 | else:
875 | position_bias = self.compute_bias(real_seq_length, key_length)
876 |
877 | # if key and values are already calculated
878 | # we want only the last query position bias
879 | if past_key_value is not None:
880 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
881 |
882 | if mask is not None:
883 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
884 |
885 | if self.pruned_heads:
886 | mask = torch.ones(position_bias.shape[1])
887 | mask[list(self.pruned_heads)] = 0
888 | position_bias_masked = position_bias[:, mask.bool()]
889 | else:
890 | position_bias_masked = position_bias
891 |
892 | scores += position_bias_masked
893 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
894 | scores
895 | ) # (batch_size, n_heads, seq_length, key_length)
896 | attn_weights = nn.functional.dropout(
897 | attn_weights, p=self.dropout, training=self.training
898 | ) # (batch_size, n_heads, seq_length, key_length)
899 |
900 | # Mask heads if we want to
901 | if layer_head_mask is not None:
902 | attn_weights = attn_weights * layer_head_mask
903 |
904 |         # wv_matmul is [batch, n_heads, seq_length, dim_per_head]
905 | wv_matmul = torch.matmul(attn_weights, value_states)
906 |
907 | # w_t is [batch, seq_length, n_heads, seq_length]
908 | w_t = attn_weights.permute(0, 2, 1, 3)
909 |
910 |         # w_tr_matmul is [batch, seq_length, n_heads, dim_per_head]
911 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
912 |
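    |         # Combine the relation-value contribution with the standard attention output before the
    |         # output projection.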
913 | attn_output = unshape(wv_matmul + w_tr_matmul.permute(0, 2, 1, 3)) # (batch_size, seq_length, dim)
914 | attn_output = self.o(attn_output)
915 |
916 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
917 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
918 |
919 | if output_attentions:
920 | outputs = outputs + (attn_weights,)
921 | return outputs
922 |
923 |
924 | class RelationalLongT5LocalAttention(LongT5LocalAttention):
925 | def forward(
926 | self,
927 | hidden_states,
928 | mask=None,
929 | position_bias=None,
930 | layer_head_mask=None,
931 | output_attentions=False,
932 |         input_relations=None # will hold relation-kind ids of shape (batch, seq_length, seq_length)
933 | ):
934 | batch_size, seq_length = hidden_states.shape[:2]
935 |
936 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model"
937 | assert input_relations.shape == (batch_size, seq_length, seq_length)
938 |
939 | # (batch_size, seq_length, seq_length, dim_per_head)
940 | relation_k_embeds = self.relation_k_emb(input_relations)
941 | relation_v_embeds = self.relation_v_emb(input_relations)
942 |
943 | def shape(states):
944 | """projection"""
945 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
946 |
947 | def unshape(states):
948 | """reshape"""
949 | return states.contiguous().view(batch_size, -1, self.inner_dim)
950 |
951 | # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
952 | query_states = shape(self.q(hidden_states))
953 | key_states = shape(self.k(hidden_states))
954 | value_states = shape(self.v(hidden_states))
955 |
956 | # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
957 | query_states = _split_into_blocks(query_states, self.block_len, dim=1)
958 | key_states = _split_into_blocks(key_states, self.block_len, dim=1)
959 | value_states = _split_into_blocks(value_states, self.block_len, dim=1)
960 |
961 | # Split relation_k_embeds, relation_v_embeds into blocks -> (batch_size, num_blocks, block_len, seq_length, dim_per_head)
962 | relation_k_embeds = _split_into_blocks(relation_k_embeds, self.block_len, dim=1)
963 | relation_v_embeds = _split_into_blocks(relation_v_embeds, self.block_len, dim=1)
964 |
965 | # Resize relation_k_embeds and relation_v_embeds -> (batch_size, num_blocks, block_len, block_len, dim_per_head)
966 |         # each block only keeps relations between tokens of the same block, not across the full sequence length
967 | num_blocks = query_states.shape[1]
968 | def get_new_relation_embeds(relation_embeds):
969 | new_relation_embeds = []
970 | for block_i in range(num_blocks):
971 | new_relation_embeds.append(
972 | relation_embeds[:, block_i, :, block_i * self.block_len: (block_i + 1) * self.block_len,
973 | :].unsqueeze(1)
974 | )
975 | if new_relation_embeds[-1].shape[-2] % self.block_len != 0:
976 | # pad tensor to multiple of block_len
977 | new_relation_embeds[-1] = _pad_to_multiple(
978 | new_relation_embeds[-1], self.block_len, dim=3, pad_value=0
979 | )
980 | return torch.cat(new_relation_embeds, dim=1)
981 | relation_k_embeds = get_new_relation_embeds(relation_k_embeds)
982 | relation_v_embeds = get_new_relation_embeds(relation_v_embeds)
983 |
984 | # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
985 | key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
986 | value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
987 |
988 | # Concatenate 3 blocks -> (batch_size, num_blocks, 3 * block_len, block_len, dim_per_head)
989 | relation_k_embeds = _concatenate_3_blocks(relation_k_embeds, block_dim=1, sequence_dim=2)
990 | relation_v_embeds = _concatenate_3_blocks(relation_v_embeds, block_dim=1, sequence_dim=2)
991 |
992 | # Compute scores
993 | scores = torch.einsum(
994 | "...qhd,...khd->...hqk", query_states, key_states
995 | ) # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
996 |
997 | # q_t is (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
998 | q_t = _concatenate_3_blocks(query_states, block_dim=1, sequence_dim=2)
999 |
1000 | # r_t is (batch_size, num_blocks, 3 * block_len, dim_per_head, block_len)
1001 | r_t = relation_k_embeds.transpose(-2, -1)
1002 |
1003 | # (batch_size, num_blocks, 3 * block_len, n_heads, block_len)
1004 | q_tr_t_matmul = torch.matmul(q_t, r_t)
1005 |
1006 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
1007 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 1, 3, 4, 2)
1008 |
1009 | # Add to scores
1010 | scores += q_tr_tmatmul_t
1011 |
1012 | if position_bias is None:
1013 | # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
1014 | if not self.has_relative_attention_bias:
1015 | position_bias = torch.zeros(
1016 | (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
1017 | )
1018 | if self.gradient_checkpointing and self.training:
1019 | position_bias.requires_grad = True
1020 | else:
1021 | position_bias = self.compute_bias(self.block_len)
1022 |
1023 | if mask is not None:
1024 | # Replace masked positions with -1e10 (according to the original implementation)
1025 | mask = torch.where(mask > 0, 0.0, -1e10)
1026 | # We need to adjust position bias shape to be sum with mask
1027 | position_bias = position_bias + mask.transpose(1, 2)
1028 |
1029 | scores += position_bias
1030 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
1031 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
1032 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
1033 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
1034 |
1035 | # Mask heads if we want to
1036 | if layer_head_mask is not None:
1037 | attn_weights = attn_weights * layer_head_mask
1038 |
1039 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head)
1040 | attn_weights = attn_weights.type(value_states.dtype)
1041 | wv_matmul = torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
1042 |
1043 | # (batch_size, num_blocks, block_len, n_heads, 3 * block_len)
1044 | w_t = attn_weights.permute(0, 1, 3, 2, 4)
1045 |
1046 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head)
1047 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds.transpose(-3, -2))
1048 |
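    |         # Un-block the (attention + relation-value) output and trim the padding added by
    |         # _split_into_blocks back to the original sequence length.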
1049 |         # (batch_size, seq_length, inner_dim)
1050 | attn_output = unshape(wv_matmul + w_tr_matmul)[:, :seq_length, :]
1051 |
1052 | # (batch_size, seq_length, d_model)
1053 | attn_output = self.o(attn_output)
1054 |
1055 | present_key_value_state = None
1056 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
1057 |
1058 | if output_attentions:
1059 | outputs = outputs + (attn_weights,)
1060 | return outputs
1061 |
1062 |
1063 | class RelationalLongT5TransientGlobalAttention(LongT5TransientGlobalAttention):
1064 | def forward(
1065 | self,
1066 | hidden_states,
1067 | mask=None,
1068 | position_bias=None,
1069 | layer_head_mask=None,
1070 | output_attentions=False,
1071 |         input_relations=None # will hold relation-kind ids of shape (batch, seq_length, seq_length)
1072 | ):
1073 | batch_size, seq_length = hidden_states.shape[:2]
1074 |
1075 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model"
1076 | assert input_relations.shape == (batch_size, seq_length, seq_length)
1077 |
1078 | # (batch_size, seq_length, seq_length, dim_per_head)
1079 | relation_k_embeds = self.relation_k_emb(input_relations)
1080 | relation_v_embeds = self.relation_v_emb(input_relations)
1081 |
1082 | def shape(states):
1083 | """projection"""
1084 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
1085 |
1086 | def unshape(states):
1087 | """reshape"""
1088 | return states.contiguous().view(batch_size, -1, self.inner_dim)
1089 |
1090 | # Prepare components for transient-global attention
1091 | # Obtain block_ids and global_segment_ids
1092 | # global_seq_len := seq_len // self.global_block_size
1093 | # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
1094 | block_ids, global_segment_ids = _make_global_fixed_block_ids(
1095 | mask if mask is not None else torch.ones(hidden_states.shape[:-1]),
1096 | self.global_block_size,
1097 | )
1098 | # Create global inputs
1099 | _global_seq_len = global_segment_ids.shape[-1]
1100 | global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
1101 | global_inputs = self.global_input_layer_norm(global_inputs)
1102 |
1103 | # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
1104 | query_states = shape(self.q(hidden_states))
1105 | key_states = shape(self.k(hidden_states))
1106 | value_states = shape(self.v(hidden_states))
1107 | # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
1108 | side_key_states = shape(self.k(global_inputs))
1109 | side_value_states = shape(self.v(global_inputs))
1110 |
1111 | # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
1112 | query_states = _split_into_blocks(query_states, self.block_len, dim=1)
1113 | key_states = _split_into_blocks(key_states, self.block_len, dim=1)
1114 | value_states = _split_into_blocks(value_states, self.block_len, dim=1)
1115 |
1116 | # Split relation_k_embeds, relation_v_embeds into blocks -> (batch_size, num_blocks, block_len, seq_length, dim_per_head)
1117 | relation_k_embeds = _split_into_blocks(relation_k_embeds, self.block_len, dim=1)
1118 | relation_v_embeds = _split_into_blocks(relation_v_embeds, self.block_len, dim=1)
1119 |
1120 | # Resize relation_k_embeds and relation_v_embeds -> (batch_size, num_blocks, block_len, block_len, dim_per_head)
1121 |         # each block only keeps relations between tokens of the same block, not across the full sequence length
1122 | num_blocks = query_states.shape[1]
1123 | def get_new_relation_embeds(relation_embeds):
1124 | new_relation_embeds = []
1125 | for block_i in range(num_blocks):
1126 | new_relation_embeds.append(
1127 | relation_embeds[:, block_i, :, block_i * self.block_len: (block_i + 1) * self.block_len, :].unsqueeze(1)
1128 | )
1129 | if new_relation_embeds[-1].shape[-2] % self.block_len != 0:
1130 | # pad tensor to multiple of block_len
1131 | new_relation_embeds[-1] = _pad_to_multiple(
1132 | new_relation_embeds[-1], self.block_len, dim=3, pad_value=0
1133 | )
1134 | return torch.cat(new_relation_embeds, dim=1)
1135 | relation_k_embeds = get_new_relation_embeds(relation_k_embeds)
1136 | relation_v_embeds = get_new_relation_embeds(relation_v_embeds)
1137 |
1138 | # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
1139 | key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
1140 | value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
1141 |
1142 | # Concatenate 3 blocks -> (batch_size, num_blocks, 3 * block_len, block_len, dim_per_head)
1143 | relation_k_embeds = _concatenate_3_blocks(relation_k_embeds, block_dim=1, sequence_dim=2)
1144 | relation_v_embeds = _concatenate_3_blocks(relation_v_embeds, block_dim=1, sequence_dim=2)
1145 |
1146 | # Tile side inputs across local key/value blocks
1147 | # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
1148 | reps = [1] * (side_key_states.ndim + 1)
1149 | reps[1] = key_states.shape[1]
1150 | side_key_states = side_key_states.unsqueeze(1).repeat(reps)
1151 | side_value_states = side_value_states.unsqueeze(1).repeat(reps)
1152 |
1153 | # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
1154 | # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
1155 | key_states = torch.cat([key_states, side_key_states], dim=2)
1156 | value_states = torch.cat([value_states, side_value_states], dim=2)
1157 |
1158 | # Add zeros to relation_k_embeds and relation_v_embeds
1159 | # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, block_len, dim_per_head)
1160 | zeros_shape_to_add = list(relation_k_embeds.shape[:2]) + [side_value_states.shape[2]] + list(relation_k_embeds.shape[3:])
1161 | relation_k_embeds = torch.cat(
1162 |             [relation_k_embeds, torch.zeros(zeros_shape_to_add, device=relation_k_embeds.device, dtype=relation_k_embeds.dtype)], dim=2
1163 | )
1164 | relation_v_embeds = torch.cat(
1165 |             [relation_v_embeds, torch.zeros(zeros_shape_to_add, device=relation_v_embeds.device, dtype=relation_v_embeds.dtype)], dim=2
1166 | )
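    |         # The appended zeros mean the global/side positions carry no relation information:
    |         # `input_relations` is only defined between pairs of the original sequence tokens.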
1167 |
1168 |         # Compute scores -> (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
1169 | scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
1170 |
1171 | # q_t is (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
1172 | q_t = _concatenate_3_blocks(query_states, block_dim=1, sequence_dim=2)
1173 | q_t = torch.cat([q_t, side_value_states], dim=2)
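    |         # The extra rows only make q_t's length match the zero-extended relation tensor; since those
    |         # relation embeddings are zeros, the global/side key columns receive no relation bias below.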
1174 |
1175 | # r_t is (batch_size, num_blocks, 3 * block_len + global_seq_len, dim_per_head, block_len)
1176 | r_t = relation_k_embeds.transpose(-2, -1)
1177 |
1178 | # (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, block_len)
1179 | q_tr_t_matmul = torch.matmul(q_t, r_t)
1180 |
1181 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
1182 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 1, 3, 4, 2)
1183 |
1184 | # Add to scores
1185 | scores += q_tr_tmatmul_t
1186 |
1187 | if mask is not None:
1188 | # We need to adjust position bias shape to be sum with mask
1189 | local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
1190 |             # Replace masked positions with -1e10 (according to the original implementation)
1191 | local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
1192 | else:
1193 | local_attention_mask = None
1194 |
1195 | if position_bias is None:
1196 | # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
1197 | if not self.has_relative_attention_bias:
1198 | position_bias = torch.zeros(
1199 | (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
1200 | device=scores.device,
1201 | dtype=scores.dtype,
1202 | )
1203 | if self.gradient_checkpointing and self.training:
1204 | position_bias.requires_grad = True
1205 | else:
1206 | position_bias = self.compute_bias(self.block_len)
1207 |
1208 | if local_attention_mask is not None:
1209 | # (batch_size, 1, n_heads, block_len, 3 * block_len)
1210 | position_bias = position_bias + local_attention_mask.transpose(1, 2)
1211 | position_bias = position_bias.type(scores.dtype)
1212 |
1213 | # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
1214 | if mask is None:
1215 | mask = torch.ones(batch_size, seq_length)
1216 | # (batch_size, num_heads, seq_len, global_seq_len)
1217 | side_position_bias = self.compute_side_bias(mask, global_segment_ids)
1218 | # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
1219 | side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
1220 | side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
1221 | # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
1222 | position_bias = torch.cat([position_bias, side_position_bias], dim=-1)
1223 |
1224 | scores += position_bias
1225 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
1226 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
1227 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
1228 |
1229 | # Mask heads if we want to
1230 | if layer_head_mask is not None:
1231 | attn_weights = attn_weights * layer_head_mask
1232 |
1233 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head)
1234 | attn_weights = attn_weights.type(value_states.dtype)
1235 | wv_matmul = torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
1236 |
1237 | # (batch_size, num_blocks, block_len, n_heads, 3 * block_len + global_seq_len)
1238 | w_t = attn_weights.permute(0, 1, 3, 2, 4)
1239 |
1240 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head)
1241 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds.transpose(-3, -2))
1242 |
1243 |         # (batch_size, seq_length, inner_dim)
1244 | attn_output = unshape(wv_matmul + w_tr_matmul)[:, :seq_length, :]
1245 |
1246 | # (batch_size, seq_length, d_model)
1247 | attn_output = self.o(attn_output)
1248 |
1249 | present_key_value_state = None
1250 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
1251 |
1252 | if output_attentions:
1253 | outputs = outputs + (attn_weights,)
1254 | return outputs
1255 |
1256 |
--------------------------------------------------------------------------------