├── MANIFEST.in ├── assets ├── tableQA.gif └── ezgif.com-gif-maker.gif:Zone.Identifier ├── requirements.txt ├── src └── ratransformers │ ├── roberta.py │ ├── gpt2.py │ ├── bert.py │ ├── bart.py │ ├── __init__.py │ ├── t5.py │ └── longt5.py ├── setup.py ├── .github └── workflows │ └── build_and_test.yml ├── LICENSE ├── tests └── test_ratransformer_forward.py ├── .gitignore ├── README.md └── notebooks ├── NER_conll_example.ipynb ├── TableQA_tabfact_example.ipynb └── text2sql_spider_example.ipynb /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include LICENSE 3 | include README.md 4 | -------------------------------------------------------------------------------- /assets/tableQA.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoaoLages/RATransformers/HEAD/assets/tableQA.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.1 2 | transformers>=4.6.1 3 | setuptools>=49.6.0 4 | torch>=1.9.1 5 | -------------------------------------------------------------------------------- /assets/ezgif.com-gif-maker.gif:Zone.Identifier: -------------------------------------------------------------------------------- 1 | [ZoneTransfer] 2 | ZoneId=3 3 | ReferrerUrl=https://ezgif.com/ 4 | HostUrl=https://s7.ezgif.com/save/ezgif-7-1c4c0d0f22.gif 5 | -------------------------------------------------------------------------------- /src/ratransformers/roberta.py: -------------------------------------------------------------------------------- 1 | from transformers.models.bert.modeling_bert import BertSelfAttention 2 | from transformers.models.roberta.modeling_roberta import RobertaSelfAttention 3 | 4 | 5 | class RobertaRelationalSelfAttention(BertSelfAttention): 6 | pass -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | with open('README.md', encoding='utf-8') as f: 5 | long_description = f.read() 6 | 7 | with open('requirements.txt', encoding='utf-8') as f: 8 | required = f.read().splitlines() 9 | 10 | 11 | setup( 12 | name='ratransformers', 13 | version='1.3.2', 14 | description='RATransformer - make a transformer model learn implicit relations passed in the input', 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | url='https://github.com/JoaoLages/RATransformers', 18 | author='Joao Lages', 19 | author_email='joaop.glages@gmail.com', 20 | license='MIT', 21 | packages=find_packages('src'), 22 | package_dir={'': 'src'}, 23 | install_requires=required 24 | ) 25 | -------------------------------------------------------------------------------- /.github/workflows/build_and_test.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.7", "3.8", "3.9", "3.10"] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | 
python -m pip install --upgrade pip 22 | pip install flake8 pytest 23 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 24 | - name: Install package 25 | run: | 26 | pip install -e . 27 | - name: Test with pytest 28 | run: | 29 | python -m pytest 30 | - name: Save artifacts 31 | uses: actions/upload-artifact@v2 32 | with: 33 | name: plot-screenshots 34 | path: tmp/ 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 João Lages 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_ratransformer_forward.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from packaging import version 3 | from ratransformers import RATransformer 4 | from transformers import AutoModelForSeq2SeqLM 5 | import transformers 6 | 7 | 8 | class TestModelsForward: 9 | """ 10 | Test small supported models to see if nothing is breaking in forward step 11 | """ 12 | def ratransformer_forward(self, ratransformer: RATransformer): 13 | model = ratransformer.model 14 | tokenizer = ratransformer.tokenizer 15 | encoding = tokenizer( 16 | "this is just a dummy text", 17 | return_tensors="pt", 18 | input_relations=None 19 | ) 20 | if 'decoder_input_ids' in inspect.signature(model.forward).parameters: 21 | if version.parse(transformers.__version__) >= version.parse('4.13'): 22 | decoder_input_ids = model._prepare_decoder_input_ids_for_generation(encoding['input_ids'].shape[0], None, None) 23 | else: 24 | decoder_input_ids = model._prepare_decoder_input_ids_for_generation(encoding['input_ids'], None, None) 25 | encoding['decoder_input_ids'] = decoder_input_ids 26 | _ = model(**encoding) 27 | 28 | def test_t5(self): 29 | ratransformer = RATransformer( 30 | "t5-small", 31 | relation_kinds=['dummy1', 'dummy2'], 32 | model_cls=AutoModelForSeq2SeqLM 33 | ) 34 | self.ratransformer_forward(ratransformer) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | 
develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # RATransformers 🐭 4 | 5 | ![PyPI - Latest Package Version](https://img.shields.io/pypi/v/ratransformers?logo=pypi&style=flat&color=orange) ![GitHub - License](https://img.shields.io/github/license/JoaoLages/ratransformers?logo=github&style=flat&color=green) 6 | 7 | **RATransformers**, short for Relation-Aware Transformers, is a package built on top of [transformers 🤗](https://github.com/huggingface/transformers) 8 | that enables the training/fine-tuning of models with extra relation-aware input features. 9 |
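At a glance, you tell the tokenizer which character spans of the input text are related and through which relation kind, and the model consumes that extra signal in its attention layers. The snippet below is only a minimal sketch of that flow, using the small `t5-small` checkpoint from this repo's tests; the sentence, spans and relation name are made up for illustration (see the Usage section and the notebooks below for real use cases).

```python
from ratransformers import RATransformer
from transformers import AutoModelForSeq2SeqLM

ratransformer = RATransformer(
    "t5-small",                             # any supported 🤗 model
    relation_kinds=['is_value_of_column'],  # every relation kind you plan to use later
    model_cls=AutoModelForSeq2SeqLM,
)
model, tokenizer = ratransformer.model, ratransformer.tokenizer

text = "Actors : George Clooney"
# {(start_i, end_i): {(start_j, end_j): relation_kind}} with character offsets into `text`
word_relations = {(9, 23): {(0, 6): 'is_value_of_column'}}

encoding = tokenizer(text, return_tensors="pt", input_relations=word_relations)
labels = tokenizer("George Clooney", return_tensors="pt", input_relations=None)["input_ids"]
outputs = model(**encoding, labels=labels)  # regular 🤗 forward pass, now relation-aware
```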
10 | 11 | ### Example - Encoding a table in TableQA (Question Answering on Tabular Data) 12 | ![](assets/tableQA.gif) 13 | 14 | [[Notebook Link](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/TableQA_tabfact_example.ipynb)] 15 | 16 | In this example we can see that passing the table as text with no additional information to the model is a poor representation. 17 | 18 | With RATransformers 🐭 you are able to encode the table in a more structured way by passing specific relations within the input. 19 | RATransformers 🐭 also allows you to pass further features related with each input word/token. 20 | 21 | Check more examples in [[here](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/)]. 22 | 23 | ## Installation 24 | 25 | Install directly from PyPI: 26 | 27 | pip install ratransformers 28 | 29 | ## Usage 30 | 31 | ```python 32 | from ratransformers import RATransformer 33 | from transformers import AutoModelForSequenceClassification 34 | 35 | 36 | ratransformer = RATransformer( 37 | "nielsr/tapex-large-finetuned-tabfact", # define the 🤗 model you want to load 38 | relation_kinds=['is_value_of_column', 'is_from_same_row'], # define the relations that you want to model in the input 39 | model_cls=AutoModelForSequenceClassification, # define the model class 40 | pretrained_tokenizer_name_or_path='facebook/bart-large' # define the tokenizer you want to load (in case it is not the same as the model) 41 | ) 42 | model = ratransformer.model 43 | tokenizer = ratransformer.tokenizer 44 | ``` 45 | 46 | With only these steps your RATransformer 🐭 is ready to be trained. 47 | 48 | More implementation details in [the examples here](https://github.com/JoaoLages/RATransformers/blob/main/notebooks/). 49 | 50 | ## How does it work? 51 | We modify the self-attention layers of the transformer model as explained in the section 3 of [the RAT-SQL paper](https://arxiv.org/pdf/1911.04942.pdf). 52 | 53 | ## Supported Models 54 | Currently we support a limited number of transformer models: 55 | - [BART](https://huggingface.co/docs/transformers/model_doc/bart) 56 | - [BERT](https://huggingface.co/docs/transformers/model_doc/bert) 57 | - [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) 58 | - [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta) 59 | - [T5](https://huggingface.co/docs/transformers/model_doc/t5) 60 | - [LongT5](https://huggingface.co/docs/transformers/model_doc/longt5) 61 | 62 | Want another model? 
Feel free to open an [Issue](https://github.com/JoaoLages/RATransformers/issues) or create a [Pull Request](https://github.com/JoaoLages/RATransformers/pulls) and let's get started 🚀 63 | -------------------------------------------------------------------------------- /src/ratransformers/gpt2.py: -------------------------------------------------------------------------------- 1 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class GPT2RelationalAttention(GPT2Attention): 7 | def __init__(self, *args, num_relation_kinds: int, use_same_relation_kv_emb: bool = True, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.num_relation_kinds = num_relation_kinds 10 | self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0) 11 | if use_same_relation_kv_emb: 12 | self.relation_v_emb = self.relation_k_emb 13 | else: 14 | self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0) 15 | self.input_relation_kinds = [] # will hold (batch, seq_length, seq_length, num_relation_kinds) 16 | 17 | def forward( 18 | self, 19 | hidden_states, 20 | layer_past=None, 21 | attention_mask=None, 22 | head_mask=None, 23 | encoder_hidden_states=None, 24 | encoder_attention_mask=None, 25 | use_cache=False, 26 | output_attentions=False, 27 | ): 28 | if encoder_hidden_states is not None: 29 | if not hasattr(self, "q_attn"): 30 | raise ValueError( 31 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 32 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 33 | ) 34 | 35 | query = self.q_attn(hidden_states) 36 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 37 | attention_mask = encoder_attention_mask 38 | else: 39 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 40 | 41 | batch_size, seq_length = hidden_states.shape[:2] 42 | 43 | assert len(self.input_relation_kinds) == 1 44 | input_relation_kinds = self.input_relation_kinds[0] 45 | assert input_relation_kinds.shape == (batch_size, seq_length, seq_length) 46 | 47 | query = self._split_heads(query, self.num_heads, self.head_dim) 48 | key = self._split_heads(key, self.num_heads, self.head_dim) 49 | value = self._split_heads(value, self.num_heads, self.head_dim) 50 | 51 | if layer_past is not None: 52 | past_key, past_value = layer_past 53 | key = torch.cat((past_key, key), dim=-2) 54 | value = torch.cat((past_value, value), dim=-2) 55 | 56 | if use_cache is True: 57 | present = (key, value) 58 | else: 59 | present = None 60 | 61 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 62 | 63 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 64 | attn_output = self.c_proj(attn_output) 65 | attn_output = self.resid_dropout(attn_output) 66 | 67 | outputs = (attn_output, present) 68 | if output_attentions: 69 | outputs += (attn_weights,) 70 | 71 | return outputs # a, present, (attentions) 72 | 73 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 74 | 75 | input_relation_kinds = self.input_relation_kinds[0] 76 | relation_k_embeds = self.relation_k_emb(input_relation_kinds) 77 | 78 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 79 | 80 | # q_t is [batch, seq_length, n_heads, dim_per_head] 81 | q_t = query.permute(0, 2, 1, 3) 82 | 83 | # r_t is [batch, seq_length, dim_per_head, seq_length] 84 | r_t = 
relation_k_embeds.transpose(-2, -1) 85 | 86 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length] 87 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length] 88 | 89 | # Add to scores 90 | attn_weights += q_tr_tmatmul_t 91 | 92 | if self.scale_attn_weights: 93 | attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) 94 | 95 | if not self.is_cross_attention: 96 | # if only "normal" attention layer implements causal mask 97 | query_length, key_length = query.size(-2), key.size(-2) 98 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 99 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) 100 | 101 | if attention_mask is not None: 102 | # Apply the attention mask 103 | attn_weights = attn_weights + attention_mask 104 | 105 | attn_weights = nn.Softmax(dim=-1)(attn_weights) 106 | attn_weights = self.attn_dropout(attn_weights) 107 | 108 | # Mask heads if we want to 109 | if head_mask is not None: 110 | attn_weights = attn_weights * head_mask 111 | 112 | attn_output = torch.matmul(attn_weights, value) 113 | 114 | relation_v_embeds = self.relation_v_emb(input_relation_kinds) 115 | 116 | # w_t is [batch, seq_length, n_heads, seq_length] 117 | w_t = attn_weights.permute(0, 2, 1, 3) 118 | 119 | # [batch, seq_length, n_heads, seq_length] 120 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds) 121 | 122 | attn_output += w_tr_matmul.permute(0, 2, 1, 3) 123 | 124 | return attn_output, attn_weights -------------------------------------------------------------------------------- /src/ratransformers/bert.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from transformers.models.bert.modeling_bert import BertSelfAttention 4 | import torch.nn as nn 5 | import torch 6 | 7 | 8 | class BertRelationalSelfAttention(BertSelfAttention): 9 | def __init__(self, *args, num_relation_kinds: int, use_same_relation_kv_emb: bool = True, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.num_relation_kinds = num_relation_kinds 12 | self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.attention_head_size, padding_idx=0) 13 | if use_same_relation_kv_emb: 14 | self.relation_v_emb = self.relation_k_emb 15 | else: 16 | self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.attention_head_size, padding_idx=0) 17 | self.input_relation_kinds = [] # will hold (batch, seq_length, seq_length, num_relation_kinds) 18 | 19 | def forward( 20 | self, 21 | hidden_states, 22 | attention_mask=None, 23 | head_mask=None, 24 | encoder_hidden_states=None, 25 | encoder_attention_mask=None, 26 | past_key_value=None, 27 | output_attentions=False, 28 | ): 29 | 30 | batch_size, seq_length = hidden_states.shape[:2] 31 | 32 | assert len(self.input_relation_kinds) == 1 33 | input_relation_kinds = self.input_relation_kinds[0] 34 | assert input_relation_kinds.shape == (batch_size, seq_length, seq_length) 35 | 36 | # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds) 37 | relation_k_embeds = self.relation_k_emb(input_relation_kinds) 38 | relation_v_embeds = self.relation_v_emb(input_relation_kinds) 39 | 40 | mixed_query_layer = self.query(hidden_states) 41 | 42 | # If this is instantiated as a cross-attention module, the keys 43 | # and values come from an encoder; the attention mask needs to be 44 | # such that the encoder's padding tokens are not attended to. 
45 | is_cross_attention = encoder_hidden_states is not None 46 | 47 | if is_cross_attention and past_key_value is not None: 48 | # reuse k,v, cross_attentions 49 | key_layer = past_key_value[0] 50 | value_layer = past_key_value[1] 51 | attention_mask = encoder_attention_mask 52 | elif is_cross_attention: 53 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 54 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 55 | attention_mask = encoder_attention_mask 56 | elif past_key_value is not None: 57 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 58 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 59 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 60 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 61 | else: 62 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 63 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 64 | 65 | query_layer = self.transpose_for_scores(mixed_query_layer) 66 | 67 | if self.is_decoder: 68 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 69 | # Further calls to cross_attention layer can then reuse all cross-attention 70 | # key/value_states (first "if" case) 71 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 72 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 73 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 74 | # if encoder bi-directional self-attention `past_key_value` is always `None` 75 | past_key_value = (key_layer, value_layer) 76 | 77 | # Take the dot product between "query" and "key" to get the raw attention scores. 
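        # A quick sketch of the relation-aware self-attention implemented below (section 3 of the
        # RAT-SQL paper linked in the README), where r_ij^K / r_ij^V are the learned relation
        # embeddings looked up from relation_k_emb / relation_v_emb:
        #   e_ij = (x_i W_Q)(x_j W_K + r_ij^K)^T / sqrt(d_head)
        #   a_ij = softmax_j(e_ij)
        #   z_i  = sum_j a_ij (x_j W_V + r_ij^V)
        # The plain QK^T scores are computed first, the Q (r^K)^T term is added right after, the
        # division by sqrt(d_head) happens further down, and the r^V term is added after the softmax.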
78 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 79 | 80 | # q_t is [batch, seq_length, n_heads, dim_per_head] 81 | q_t = query_layer.permute(0, 2, 1, 3) 82 | 83 | # r_t is [batch, seq_length, dim_per_head, seq_length] 84 | r_t = relation_k_embeds.transpose(-2, -1) 85 | 86 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length] 87 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length] 88 | 89 | # Add to scores 90 | attention_scores += q_tr_tmatmul_t 91 | 92 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 93 | seq_length = hidden_states.size()[1] 94 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 95 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 96 | distance = position_ids_l - position_ids_r 97 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 98 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 99 | 100 | if self.position_embedding_type == "relative_key": 101 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 102 | attention_scores = attention_scores + relative_position_scores 103 | elif self.position_embedding_type == "relative_key_query": 104 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 105 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 106 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 107 | 108 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 109 | if attention_mask is not None: 110 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 111 | attention_scores = attention_scores + attention_mask 112 | 113 | # Normalize the attention scores to probabilities. 114 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 115 | 116 | # This is actually dropping out entire tokens to attend to, which might 117 | # seem a bit unusual, but is taken from the original Transformer paper. 
118 | attention_probs = self.dropout(attention_probs) 119 | 120 | # Mask heads if we want to 121 | if head_mask is not None: 122 | attention_probs = attention_probs * head_mask 123 | 124 | context_layer = torch.matmul(attention_probs, value_layer) 125 | 126 | # w_t is [batch, seq_length, n_heads, seq_length] 127 | w_t = attention_probs.permute(0, 2, 1, 3) 128 | 129 | # [batch, seq_length, n_heads, seq_length] 130 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds) 131 | 132 | context_layer += w_tr_matmul.permute(0, 2, 1, 3) 133 | 134 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 135 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 136 | context_layer = context_layer.view(*new_context_layer_shape) 137 | 138 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 139 | 140 | if self.is_decoder: 141 | outputs = outputs + (past_key_value,) 142 | return outputs 143 | -------------------------------------------------------------------------------- /notebooks/NER_conll_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0b4a5857", 7 | "metadata": { 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "!pip install datasets" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "636ff184", 21 | "metadata": { 22 | "pycharm": { 23 | "name": "#%%\n" 24 | } 25 | }, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "Reusing dataset conll2003 (/home/ola/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" 32 | ] 33 | }, 34 | { 35 | "data": { 36 | "application/vnd.jupyter.widget-view+json": { 37 | "model_id": "da46312ea7114898bc51966b2fe78766", 38 | "version_major": 2, 39 | "version_minor": 0 40 | }, 41 | "text/plain": [ 42 | " 0%| | 0/3 [00:00 Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 27 | """Input shape: Batch x Time x Channel""" 28 | 29 | # if key_value_states are provided this layer is used as a cross-attention layer 30 | # for the decoder 31 | is_cross_attention = key_value_states is not None 32 | bsz, tgt_len, embed_dim = hidden_states.size() 33 | 34 | assert len(self.input_relation_kinds) == 1 35 | input_relation_kinds = self.input_relation_kinds[0] 36 | assert input_relation_kinds.shape == (bsz, tgt_len, tgt_len) 37 | 38 | # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds) 39 | relation_k_embeds = self.relation_k_emb(input_relation_kinds) 40 | relation_v_embeds = self.relation_v_emb(input_relation_kinds) 41 | 42 | # get query proj 43 | query_states = self.q_proj(hidden_states) * self.scaling 44 | # get key, value proj 45 | if is_cross_attention and past_key_value is not None: 46 | # reuse k,v, cross_attentions 47 | key_states = past_key_value[0] 48 | value_states = past_key_value[1] 49 | elif is_cross_attention: 50 | # cross_attentions 51 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 52 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 53 | elif past_key_value is not None: 54 | # reuse k, v, self_attention 55 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 56 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 57 | key_states = 
torch.cat([past_key_value[0], key_states], dim=2) 58 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 59 | else: 60 | # self_attention 61 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 62 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 63 | 64 | if self.is_decoder: 65 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 66 | # Further calls to cross_attention layer can then reuse all cross-attention 67 | # key/value_states (first "if" case) 68 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 69 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 70 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 71 | # if encoder bi-directional self-attention `past_key_value` is always `None` 72 | past_key_value = (key_states, value_states) 73 | 74 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 75 | query_states = self._shape(query_states, tgt_len, bsz) 76 | src_len = key_states.size(2) 77 | 78 | # compute scores 79 | attn_weights = torch.matmul( 80 | query_states, key_states.transpose(3, 2) 81 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 82 | 83 | # q_t is [batch, seq_length, n_heads, dim_per_head] 84 | q_t = query_states.permute(0, 2, 1, 3) 85 | 86 | # r_t is [batch, seq_length, dim_per_head, seq_length] 87 | r_t = relation_k_embeds.transpose(-2, -1) 88 | 89 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length] 90 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length] 91 | 92 | # Add to scores 93 | attn_weights += q_tr_tmatmul_t 94 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 95 | 96 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 97 | raise ValueError( 98 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 99 | ) 100 | 101 | if attention_mask is not None: 102 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 103 | raise ValueError( 104 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 105 | ) 106 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 107 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 108 | 109 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 110 | 111 | if layer_head_mask is not None: 112 | if layer_head_mask.size() != (self.num_heads,): 113 | raise ValueError( 114 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 115 | ) 116 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 117 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 118 | 119 | if output_attentions: 120 | # this operation is a bit awkward, but it's required to 121 | # make sure that attn_weights keeps its gradient. 
122 | # In order to do so, attn_weights have to be reshaped 123 | # twice and have to be reused in the following 124 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 125 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 126 | else: 127 | attn_weights_reshaped = None 128 | 129 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 130 | 131 | attn_output = torch.bmm(attn_probs, value_states.view(*proj_shape)) 132 | 133 | # w_t is [batch, seq_length, n_heads, seq_length] 134 | w_t = attn_probs.view(bsz, self.num_heads, tgt_len, src_len).permute(0, 2, 1, 3) 135 | 136 | # [batch, seq_length, n_heads, seq_length] 137 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds) 138 | 139 | attn_output += w_tr_matmul.permute(0, 2, 1, 3).view(bsz * self.num_heads, tgt_len, self.head_dim) 140 | 141 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 142 | raise ValueError( 143 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 144 | ) 145 | 146 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 147 | attn_output = attn_output.transpose(1, 2) 148 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 149 | 150 | attn_output = self.out_proj(attn_output) 151 | 152 | return attn_output, attn_weights_reshaped, past_key_value -------------------------------------------------------------------------------- /notebooks/TableQA_tabfact_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6ac23e34", 6 | "metadata": {}, 7 | "source": [ 8 | "In this notebook we will show how to use RATransformers 🐭 to encode your data as the image shows.\n", 9 | "\n", 10 | "![](../assets/tableQA.gif)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "98ee93fc", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "!pip install pandas" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "7841fcd4", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import ratransformers\n", 31 | "import pandas as pd\n", 32 | "from transformers import BartTokenizerFast, BartForSequenceClassification\n", 33 | "\n", 34 | "\n", 35 | "ratransformer = ratransformers.RATransformer(\n", 36 | " \"nielsr/tapex-large-finetuned-tabfact\", # define the 🤗 model you want to load\n", 37 | " relation_kinds=['is_value_of_column', 'is_from_same_row'], # define the relations that you want to model in the input\n", 38 | " tokenizer_cls=BartTokenizerFast, # define the tokenizer class \n", 39 | " model_cls=BartForSequenceClassification, # define the model class\n", 40 | " pretrained_tokenizer_name_or_path='facebook/bart-large' # define the tokenizer you want to load (in case it is not the same as the model)\n", 41 | ")\n", 42 | "model = ratransformer.model\n", 43 | "tokenizer = ratransformer.tokenizer\n", 44 | "\n", 45 | "# create table\n", 46 | "data = {'Actors': [\"Brad Pitt\", \"Leonardo Di Caprio\", \"George Clooney\"], 'Number of movies': [\"87\", \"53\", \"69\"]}\n", 47 | "table = pd.DataFrame.from_dict(data)\n", 48 | "\n", 49 | "# turn into dict\n", 50 | "table_dict = {\"header\": list(table.columns), \"rows\": [list(row.values) for i,row in table.iterrows()]}" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "id": 
"df24f04d", 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/html": [ 62 | "
<div>\n",
 63 |        "<style scoped>\n",
 64 |        "    .dataframe tbody tr th:only-of-type {\n",
 65 |        "        vertical-align: middle;\n",
 66 |        "    }\n",
 67 |        "\n",
 68 |        "    .dataframe tbody tr th {\n",
 69 |        "        vertical-align: top;\n",
 70 |        "    }\n",
 71 |        "\n",
 72 |        "    .dataframe thead th {\n",
 73 |        "        text-align: right;\n",
 74 |        "    }\n",
 75 |        "</style>\n",
 76 |        "<table border=\"1\" class=\"dataframe\">\n",
 77 |        "  <thead>\n",
 78 |        "    <tr style=\"text-align: right;\">\n",
 79 |        "      <th></th>\n",
 80 |        "      <th>Actors</th>\n",
 81 |        "      <th>Number of movies</th>\n",
 82 |        "    </tr>\n",
 83 |        "  </thead>\n",
 84 |        "  <tbody>\n",
 85 |        "    <tr>\n",
 86 |        "      <th>0</th>\n",
 87 |        "      <td>Brad Pitt</td>\n",
 88 |        "      <td>87</td>\n",
 89 |        "    </tr>\n",
 90 |        "    <tr>\n",
 91 |        "      <th>1</th>\n",
 92 |        "      <td>Leonardo Di Caprio</td>\n",
 93 |        "      <td>53</td>\n",
 94 |        "    </tr>\n",
 95 |        "    <tr>\n",
 96 |        "      <th>2</th>\n",
 97 |        "      <td>George Clooney</td>\n",
 98 |        "      <td>69</td>\n",
 99 |        "    </tr>\n",
 100 |        "  </tbody>\n",
 101 |        "</table>\n",
 102 |        "</div>
" 103 | ], 104 | "text/plain": [ 105 | " Actors Number of movies\n", 106 | "0 Brad Pitt 87\n", 107 | "1 Leonardo Di Caprio 53\n", 108 | "2 George Clooney 69" 109 | ] 110 | }, 111 | "execution_count": 2, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "table" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 3, 123 | "id": "3526ff9d", 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "{'header': ['Actors', 'Number of movies'],\n", 130 | " 'rows': [['Brad Pitt', '87'],\n", 131 | " ['Leonardo Di Caprio', '53'],\n", 132 | " ['George Clooney', '69']]}" 133 | ] 134 | }, 135 | "execution_count": 3, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "table_dict" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 4, 147 | "id": "9cb67881", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "from collections import defaultdict \n", 152 | "import itertools\n", 153 | "\n", 154 | "\n", 155 | "class IndexedRowTableLinearize:\n", 156 | " # adapted from https://github.com/microsoft/Table-Pretraining/blob/main/tapex/processor/table_linearize.py\n", 157 | " \"\"\"\n", 158 | " FORMAT: col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ...\n", 159 | " \"\"\"\n", 160 | "\n", 161 | " def process_input(self, sentence, table_content):\n", 162 | " \"\"\"\n", 163 | " Given a sentence+ table, converts it into a flatten sequence with special symbols.\n", 164 | " Also returns the word relations\n", 165 | " \"\"\"\n", 166 | " assert \"header\" in table_content and \"rows\" in table_content\n", 167 | " \n", 168 | " input_text = sentence\n", 169 | " word_relations = defaultdict(dict)\n", 170 | " \n", 171 | " # process header\n", 172 | " input_text += \"col : \"\n", 173 | " col_id_to_span = {}\n", 174 | " for i, col in enumerate(table_content[\"header\"]):\n", 175 | " col_id_to_span[i] = (len(input_text), len(input_text) + len(col))\n", 176 | " input_text += f\"{col} | \" \n", 177 | " \n", 178 | " # process rows\n", 179 | " for row_index, row in enumerate(table_content[\"rows\"]):\n", 180 | " input_text += f\"row {row_index + 1} : \"\n", 181 | " \n", 182 | " all_cell_spans = []\n", 183 | " for i, cell_value in enumerate(row):\n", 184 | " cell_value = str(cell_value)\n", 185 | " cell_span = (len(input_text), len(input_text) + len(cell_value))\n", 186 | " all_cell_spans.append(cell_span)\n", 187 | " \n", 188 | " # save word relation - row value belong to specific column\n", 189 | " word_relations[cell_span][col_id_to_span[i]] = \"is_value_of_column\"\n", 190 | "\n", 191 | " input_text += f\"{cell_value} | \"\n", 192 | " \n", 193 | " # save word relation - all values belong to same row\n", 194 | " for (span_i, span_j) in itertools.permutations(all_cell_spans, 2):\n", 195 | " word_relations[span_i][span_j] = \"is_from_same_row\"\n", 196 | " \n", 197 | " if input_text.endswith(' | '): # remove trailing characters\n", 198 | " input_text = input_text[:-len(' | ')]\n", 199 | " \n", 200 | " return input_text, word_relations" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 5, 206 | "id": "136ca212", 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "tensor([0])\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "linearizer = IndexedRowTableLinearize()\n", 219 | "\n", 220 | "sentence = \"George Clooney has 69 
movies\"\n", 221 | "joint_input, word_relations = linearizer.process_input(sentence, table_dict)\n", 222 | "\n", 223 | "# encode \n", 224 | "encoding = tokenizer(joint_input, return_tensors=\"pt\", input_relations=word_relations)\n", 225 | "\n", 226 | "# forward pass\n", 227 | "outputs = model(**encoding)\n", 228 | "\n", 229 | "# print prediction\n", 230 | "logits = outputs.logits\n", 231 | "print(logits.argmax(-1))" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 6, 237 | "id": "5aff40ff", 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "'George Clooney has 69 moviescol : Actors | Number of movies | row 1 : Brad Pitt | 87 | row 2 : Leonardo Di Caprio | 53 | row 3 : George Clooney | 69'" 244 | ] 245 | }, 246 | "execution_count": 6, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "joint_input" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 7, 258 | "id": "c5131427", 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "data": { 263 | "text/plain": [ 264 | "defaultdict(dict,\n", 265 | " {(70, 79): {(34, 40): 'is_value_of_column',\n", 266 | " (82, 84): 'is_from_same_row'},\n", 267 | " (82, 84): {(43, 59): 'is_value_of_column',\n", 268 | " (70, 79): 'is_from_same_row'},\n", 269 | " (95, 113): {(34, 40): 'is_value_of_column',\n", 270 | " (116, 118): 'is_from_same_row'},\n", 271 | " (116, 118): {(43, 59): 'is_value_of_column',\n", 272 | " (95, 113): 'is_from_same_row'},\n", 273 | " (129, 143): {(34, 40): 'is_value_of_column',\n", 274 | " (146, 148): 'is_from_same_row'},\n", 275 | " (146, 148): {(43, 59): 'is_value_of_column',\n", 276 | " (129, 143): 'is_from_same_row'}})" 277 | ] 278 | }, 279 | "execution_count": 7, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "word_relations" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "id": "53aa1366", 291 | "metadata": {}, 292 | "source": [ 293 | "**Your model is now ready to be trained with relational information in the input!**\n", 294 | "\n", 295 | "Check the standard procedure to train HuggingFace 🤗 models in [here](https://huggingface.co/docs/transformers/training)." 
296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.7.12" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 5 320 | } 321 | -------------------------------------------------------------------------------- /src/ratransformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.2' 2 | 3 | import inspect 4 | 5 | from transformers import AutoTokenizer, AutoModel, BertPreTrainedModel, BartPretrainedModel, T5PreTrainedModel, \ 6 | PreTrainedTokenizer, BatchEncoding, GPT2PreTrainedModel, PreTrainedModel, LongT5PreTrainedModel, PretrainedConfig, \ 7 | AutoConfig 8 | from transformers.dynamic_module_utils import get_class_from_dynamic_module 9 | from transformers.models.auto.auto_factory import _get_model_class, _BaseAutoModelClass 10 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel 11 | from typing import Dict, Optional, List, Tuple, Type 12 | import torch.nn as nn 13 | from transformers.utils import logging 14 | 15 | from ratransformers.longt5 import change_longt5_module_and_get_relational_emb_dim 16 | from ratransformers.t5 import change_t5_module_and_get_relational_emb_dim 17 | import torch 18 | import functools 19 | import numpy as np 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | def _change_this_module(model: PreTrainedModel, module_name: str, module: nn.Module, num_relation_kinds: int, 26 | use_same_relation_kv_emb: bool = True) -> None: 27 | relational_embedding_dim = None 28 | if isinstance(model, T5PreTrainedModel): 29 | relational_embedding_dim = change_t5_module_and_get_relational_emb_dim(module_name=module_name, module=module) 30 | 31 | elif isinstance(model, LongT5PreTrainedModel): 32 | relational_embedding_dim = change_longt5_module_and_get_relational_emb_dim(module_name=module_name, 33 | module=module) 34 | 35 | elif isinstance(model, BertPreTrainedModel): 36 | raise NotImplementedError(f"Not implemented for this version, downgrade to ratransformers==0.1.1") 37 | 38 | elif isinstance(model, BartPretrainedModel): 39 | raise NotImplementedError(f"Not implemented for this version, downgrade to ratransformers==0.1.1") 40 | 41 | elif isinstance(model, RobertaPreTrainedModel): 42 | raise NotImplementedError(f"Not implemented for this version, downgrade to ratransformers==0.1.1") 43 | 44 | elif isinstance(model, GPT2PreTrainedModel): 45 | raise NotImplementedError(f"Not implemented for this version, downgrade to ratransformers==0.1.1") 46 | 47 | else: 48 | raise NotImplementedError(f"Could not find implementation for the model type: '{type(model)}'. 
" 49 | f"Feel free to open an issue in GitHub to ask for its addition!") 50 | 51 | if relational_embedding_dim is None: 52 | return 53 | 54 | module.num_relation_kinds = num_relation_kinds 55 | module.relation_k_emb = nn.Embedding(num_relation_kinds + 1, relational_embedding_dim, padding_idx=0) 56 | if use_same_relation_kv_emb: 57 | module.relation_v_emb = module.relation_k_emb 58 | else: 59 | module.relation_v_emb = nn.Embedding(num_relation_kinds + 1, relational_embedding_dim, padding_idx=0) 60 | 61 | 62 | def _get_model_class_from_auto_class(cls, pretrained_model_name_or_path, **kwargs): 63 | if not cls.__name__.startswith('AutoModel'): 64 | return cls 65 | 66 | config = kwargs.pop("config", None) 67 | trust_remote_code = kwargs.pop("trust_remote_code", False) 68 | kwargs["_from_auto"] = True 69 | hub_kwargs_names = [ 70 | "cache_dir", 71 | "force_download", 72 | "local_files_only", 73 | "proxies", 74 | "resume_download", 75 | "revision", 76 | "subfolder", 77 | "use_auth_token", 78 | ] 79 | hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} 80 | if not isinstance(config, PretrainedConfig): 81 | config, kwargs = AutoConfig.from_pretrained( 82 | pretrained_model_name_or_path, 83 | return_unused_kwargs=True, 84 | trust_remote_code=trust_remote_code, 85 | **hub_kwargs, 86 | **kwargs, 87 | ) 88 | if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: 89 | if not trust_remote_code: 90 | raise ValueError( 91 | f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo " 92 | "on your local machine. Make sure you have read the code there to avoid malicious use, then set " 93 | "the option `trust_remote_code=True` to remove this error." 94 | ) 95 | if hub_kwargs.get("revision", None) is None: 96 | logger.warning( 97 | "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " 98 | "no malicious code has been contributed in a newer revision." 99 | ) 100 | class_ref = config.auto_map[cls.__name__] 101 | module_file, class_name = class_ref.split(".") 102 | model_class = get_class_from_dynamic_module( 103 | pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs 104 | ) 105 | return model_class 106 | 107 | elif type(config) in cls._model_mapping.keys(): 108 | model_class = _get_model_class(config, cls._model_mapping) 109 | return model_class 110 | 111 | raise ValueError( 112 | f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" 113 | f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." 114 | ) 115 | 116 | 117 | class RATransformer: 118 | 119 | def __init__(self, 120 | pretrained_model_name_or_path: Optional[str], 121 | relation_kinds: List[str], 122 | tokenizer_cls: Type[PreTrainedTokenizer] = AutoTokenizer, 123 | model_cls: Type[PreTrainedModel] = AutoModel, 124 | pretrained_tokenizer_name_or_path: Optional[str] = None, **kwargs): 125 | """ 126 | Returns an initialized and ready to test/train RATransformer 127 | Args: 128 | pretrained_model_name_or_path: model name or path to pass directly to Huggingface's `model_cls` class 129 | relation_kinds: list with all the possible relation kinds that can exist within the input 130 | tokenizer_cls: pass your own AutoTokenizer class to initialize the tokenizer 131 | model_cls: pass your own AutoModel class to initialize the model 132 | pretrained_tokenizer_name_or_path: Optional. 
Tokenizer name or path to pass directly 133 | to Huggingface's `tokenizer_cls` class. By default, will be equal to pretrained_model_name_or_path 134 | kwargs: other arguments to be passed to model_cls.from_pretrained and tokenizer_cls.from_pretrained methods 135 | """ 136 | 137 | model_cls_args = set(inspect.getfullargspec(model_cls.from_pretrained)[0]) 138 | model_kwargs = {k:v for k, v in kwargs.items() if k in model_cls_args} 139 | tokenizer_args = set(inspect.getfullargspec(tokenizer_cls.from_pretrained)[0]) 140 | tokenizer_kwargs = {k:v for k, v in kwargs.items() if k in tokenizer_args} 141 | 142 | self.relational_kind_to_index = {t: i + 1 for i, t in enumerate(relation_kinds)} 143 | 144 | pretrained_tokenizer_name_or_path = pretrained_tokenizer_name_or_path or pretrained_model_name_or_path 145 | if pretrained_tokenizer_name_or_path is not None: 146 | self.tokenizer = tokenizer_cls.from_pretrained( 147 | pretrained_model_name_or_path=pretrained_tokenizer_name_or_path, **tokenizer_kwargs 148 | ) 149 | else: 150 | logger.error( 151 | "`pretrained_model_name_or_path=None` and `pretrained_tokenizer_name_or_path=None` is not supported. " 152 | "Please pass at least `pretrained_tokenizer_name_or_path` to initialize the tokenizer" 153 | ) 154 | 155 | self.model = None 156 | if pretrained_model_name_or_path is not None: 157 | 158 | # change auto class to model class 159 | model_cls = _get_model_class_from_auto_class( 160 | cls=model_cls, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs 161 | ) 162 | 163 | def model_cls_load_pretrained_model_prefix_function(function): 164 | @functools.wraps(function) 165 | def run(model, *args, **kwargs): 166 | # change attention layers with relational ones, if not done before 167 | for module_name, module in model.named_modules(): 168 | _change_this_module( 169 | model=model, module_name=module_name, module=module, 170 | num_relation_kinds=len(relation_kinds) 171 | ) 172 | return function(model, *args, **kwargs) 173 | return run 174 | model_cls._load_pretrained_model = model_cls_load_pretrained_model_prefix_function( 175 | model_cls._load_pretrained_model 176 | ) 177 | 178 | self.model = model_cls.from_pretrained( 179 | pretrained_model_name_or_path=pretrained_model_name_or_path, **model_kwargs 180 | ) 181 | 182 | else: 183 | logger.warning( 184 | "`pretrained_model_name_or_path=None` which means that your RATransformer model won't be initialized, " 185 | "only its tokenizer." 
186 | ) 187 | 188 | def model_prefix_function(function): 189 | @functools.wraps(function) 190 | def run(*args, **kwargs): 191 | if 'offset_mapping' in kwargs: 192 | del kwargs['offset_mapping'] 193 | return function(*args, **kwargs) 194 | return run 195 | 196 | def tokenizer_suffix_function(function): 197 | @functools.wraps(function) 198 | def run(*args, **kwargs): 199 | if 'return_tensors' not in kwargs or kwargs['return_tensors'] != 'pt': 200 | raise Exception("RATransformer's tokenizer expects `return_tensors='pt'`") 201 | if 'input_relations' not in kwargs: 202 | raise Exception("tokenizer expects 'input_relations' argument") 203 | input_relations = kwargs.pop('input_relations') 204 | kwargs['return_offsets_mapping'] = True 205 | out = function(*args, **kwargs) 206 | out['input_relations'] = self.get_new_input_relation_kinds( 207 | tokenizer_outputs=out, input_relations=input_relations 208 | ) 209 | return out 210 | return run 211 | 212 | if self.model is not None: 213 | # change model's call, forward and generate method 214 | self.model.__call__ = model_prefix_function(self.model.__call__) 215 | self.model.forward = model_prefix_function(self.model.forward) 216 | self.model.generate = model_prefix_function(self.model.generate) 217 | 218 | # change tokenizer's call and encode plus methods 219 | self.tokenizer.__call__ = tokenizer_suffix_function(self.tokenizer.__call__) 220 | self.tokenizer.batch_encode_plus = tokenizer_suffix_function(self.tokenizer.batch_encode_plus) 221 | self.tokenizer.encode_plus = tokenizer_suffix_function(self.tokenizer.encode_plus) 222 | 223 | def get_new_input_relation_kinds( 224 | self, 225 | tokenizer_outputs: BatchEncoding, 226 | input_relations: Optional[List[Dict[Tuple[int, int], Dict[Tuple[int, int], str]]]] = None 227 | ) -> torch.Tensor: 228 | 229 | assert 'offset_mapping' in tokenizer_outputs, "Run tokenizer with return_offsets_mapping=True" 230 | 231 | aux_input_relation_kinds = np.zeros( 232 | (len(tokenizer_outputs['input_ids']), len(tokenizer_outputs['input_ids'][0]), len(tokenizer_outputs['input_ids'][0])), 233 | dtype=np.int64 234 | ) 235 | if input_relations is not None: 236 | if isinstance(input_relations, dict): 237 | input_relations = [input_relations] 238 | assert len(tokenizer_outputs['offset_mapping']) == len(input_relations) 239 | for batch_idx, (token_mappings, relations) in enumerate(zip(tokenizer_outputs['offset_mapping'], input_relations)): 240 | 241 | for word_i_span, word_relations in relations.items(): 242 | word_i_token_ids = [ 243 | token_idx for token_idx, token_span in enumerate(token_mappings) 244 | if max(0, min(token_span[1], word_i_span[1]) - max(token_span[0], word_i_span[0])) > 0 # check for word/token overlaps 245 | ] 246 | for word_j_span, relation_kind in word_relations.items(): 247 | if relation_kind not in self.relational_kind_to_index: 248 | raise AttributeError( 249 | f"relation of type '{relation_kind}' not found, " 250 | f"RATransformer was initialized with these relation types: {list(self.relational_kind_to_index)}" 251 | ) 252 | for token_j_idx, token_span in enumerate(token_mappings): 253 | if max(0, min(token_span[1], word_j_span[1]) - max(token_span[0], word_j_span[0])) > 0: # check for word/token overlaps 254 | for token_i_idx in word_i_token_ids: 255 | try: 256 | aux_input_relation_kinds[batch_idx, token_i_idx, token_j_idx] = \ 257 | self.relational_kind_to_index[relation_kind] 258 | 259 | except IndexError: 260 | raise IndexError(f"Could not find relation kind '{relation_kind}'") 261 | 262 | return 
torch.from_numpy(aux_input_relation_kinds).to(tokenizer_outputs['input_ids'].device) 263 | -------------------------------------------------------------------------------- /notebooks/text2sql_spider_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ac2b9942", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install gdown" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "8223d01b", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "Downloading...\n", 24 | "From: https://drive.google.com/u/1/uc?export=download&confirm=k3T5&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0\n", 25 | "To: /home/ubuntu/RATransformers/notebooks/spider.zip\n", 26 | "100%|██████████| 99.7M/99.7M [00:01<00:00, 94.9MB/s]\n" 27 | ] 28 | }, 29 | { 30 | "data": { 31 | "text/plain": [ 32 | "'spider.zip'" 33 | ] 34 | }, 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "import gdown\n", 42 | "\n", 43 | "spider_url = 'https://drive.google.com/u/1/uc?export=download&confirm=k3T5&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0'\n", 44 | "output = 'spider.zip'\n", 45 | "gdown.download(spider_url, output, quiet=False)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "fa4faf7b", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "!unzip -o spider.zip" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "id": "c2ae14d0", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import json\n", 66 | "\n", 67 | "with open('spider/tables.json') as fp:\n", 68 | " tables = {t['db_id']: t for t in json.load(fp)}\n", 69 | "\n", 70 | "with open('spider/train_spider.json') as fp:\n", 71 | " train_data = json.load(fp)\n", 72 | "\n", 73 | "with open('spider/train_others.json') as fp:\n", 74 | " train_data += json.load(fp)\n", 75 | "\n", 76 | "with open('spider/dev.json') as fp:\n", 77 | " test_data = json.load(fp)\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "8242d534", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Train: 8577 Skipped 82 samples with too long input.\n", 91 | "Test: 1034 Skipped 0 samples with too long input.\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "from collections import defaultdict\n", 97 | "\n", 98 | "def get_processed_data(raw_data):\n", 99 | " X, y, X_word_relations = [], [], []\n", 100 | " n_skip = 0\n", 101 | " for d in raw_data:\n", 102 | " input_text = d['question'] + f\" | {d['db_id']}\"\n", 103 | " \n", 104 | " word_relations = defaultdict(dict)\n", 105 | "\n", 106 | " table_span, table_i = None, None\n", 107 | " for i, c_name in tables[d['db_id']]['column_names_original']:\n", 108 | " if i < 0: continue\n", 109 | " if table_i != i:\n", 110 | " table_i = i\n", 111 | " table_span = (len(input_text + ' | '), len(input_text + ' | ') + len(tables[d['db_id']]['table_names_original'][i]))\n", 112 | " input_text += f\" | {tables[d['db_id']]['table_names_original'][i]} : \"\n", 113 | "\n", 114 | " c_span = (len(input_text), len(input_text) + len(c_name))\n", 115 | " input_text += c_name\n", 116 | "\n", 117 | " else:\n", 118 | " c_span = (len(input_text + ', '), len(input_text + ', ') + len(c_name))\n", 119 | 
" input_text += f', {c_name}'\n", 120 | "\n", 121 | " word_relations[table_span][c_span] = 'table_column_link'\n", 122 | " word_relations[c_span][table_span] = 'column_table_link'\n", 123 | "\n", 124 | " if len(input_text.split()) > 200:\n", 125 | " # Skipped sample with too long input\n", 126 | " n_skip += 1\n", 127 | " continue\n", 128 | " \n", 129 | " X.append(input_text.lower())\n", 130 | " y.append((d['db_id'] + ' | ' + d['query']).lower())\n", 131 | " X_word_relations.append(word_relations)\n", 132 | " \n", 133 | " return X, y, X_word_relations, n_skip\n", 134 | "\n", 135 | "train_X, train_y, train_X_word_relations, n_skip = get_processed_data(train_data)\n", 136 | "print(\"Train:\", len(train_X), f\" Skipped {n_skip} samples with too long input.\")\n", 137 | "test_X, test_y, test_X_word_relations, n_skip = get_processed_data(test_data)\n", 138 | "print(\"Test:\", len(test_X), f\" Skipped {n_skip} samples with too long input.\")\n", 139 | "\n" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "id": "68bc75d5", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "from transformers import AutoModelForSeq2SeqLM\n", 150 | "import ratransformers\n", 151 | "\n", 152 | "ratransformer = ratransformers.RATransformer(\n", 153 | " 'tscholak/1zha5ono', \n", 154 | " relation_kinds=['table_column_link', 'column_table_link'],\n", 155 | " model_cls=AutoModelForSeq2SeqLM\n", 156 | ")\n", 157 | "model = ratransformer.model\n", 158 | "tokenizer = ratransformer.tokenizer" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "id": "acf8577a", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "import torch\n", 169 | "\n", 170 | "class Text2SQLDataset(torch.utils.data.Dataset):\n", 171 | " def __init__(self, X, y, tokenizer, X_word_relations=None):\n", 172 | " self.X = X\n", 173 | " self.y = y\n", 174 | " self.X_word_relations = X_word_relations or [None] * len(X)\n", 175 | " self.tokenizer = tokenizer\n", 176 | " \n", 177 | " def __getitem__(self, index: int) -> dict:\n", 178 | " \n", 179 | " source = self.tokenizer(self.X[index], padding='max_length', input_relations=self.X_word_relations[index], return_tensors=\"pt\")\n", 180 | " target = self.tokenizer(self.y[index], padding='max_length', input_relations=None, return_tensors=\"pt\")\n", 181 | " \n", 182 | " source_ids = source[\"input_ids\"].squeeze()\n", 183 | " source_input_relations = source[\"input_relations\"].squeeze()\n", 184 | " target_ids = target[\"input_ids\"].squeeze()\n", 185 | " target_ids[target_ids == 0] = -100\n", 186 | "\n", 187 | " src_mask = source[\"attention_mask\"].squeeze()\n", 188 | " target_mask = target[\"attention_mask\"].squeeze()\n", 189 | "\n", 190 | " return {\n", 191 | " \"input_ids\": source_ids,\n", 192 | " \"attention_mask\": src_mask,\n", 193 | " \"label\": target_ids,\n", 194 | " \"decoder_attention_mask\": target_mask,\n", 195 | " 'input_relations': source_input_relations\n", 196 | " }\n", 197 | "\n", 198 | " def __len__(self):\n", 199 | " return len(self.X)\n", 200 | "\n", 201 | "# Get datasets with word relations\n", 202 | "train_d = Text2SQLDataset(train_X, train_y, tokenizer, train_X_word_relations)\n", 203 | "val_d = Text2SQLDataset(test_X, test_y, tokenizer, test_X_word_relations)\n", 204 | "\n", 205 | "# Get datasets without word relations\n", 206 | "train_d_without_relations = Text2SQLDataset(train_X, train_y, tokenizer)\n", 207 | "val_d_without_relations = Text2SQLDataset(test_X, test_y, tokenizer)" 208 | 
] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "id": "1fe8de0f", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback\n", 218 | "\n", 219 | "# Set training arguments\n", 220 | "training_args = Seq2SeqTrainingArguments(\n", 221 | " output_dir='checkpoints',\n", 222 | " per_device_train_batch_size=1,\n", 223 | " gradient_accumulation_steps=2,\n", 224 | " per_device_eval_batch_size=4,\n", 225 | " evaluation_strategy='steps',\n", 226 | " max_steps=100000,\n", 227 | " eval_steps=1000,\n", 228 | " seed=42,\n", 229 | " save_total_limit=1,\n", 230 | " predict_with_generate=True,\n", 231 | " load_best_model_at_end=True\n", 232 | ")" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 10, 238 | "id": "36c0684f", 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "data": { 243 | "text/html": [ 244 | "\n", 245 | "
\n", 246 | " \n", 247 | " \n", 248 | " [259/259 1:50:31]\n", 249 | "
\n", 250 | " " 251 | ], 252 | "text/plain": [ 253 | "" 254 | ] 255 | }, 256 | "metadata": {}, 257 | "output_type": "display_data" 258 | }, 259 | { 260 | "data": { 261 | "text/plain": [ 262 | "{'eval_loss': 1.0085068941116333,\n", 263 | " 'eval_runtime': 1324.2969,\n", 264 | " 'eval_samples_per_second': 0.781}" 265 | ] 266 | }, 267 | "execution_count": 10, 268 | "metadata": {}, 269 | "output_type": "execute_result" 270 | } 271 | ], 272 | "source": [ 273 | "# Set trainer\n", 274 | "trainer = Seq2SeqTrainer(\n", 275 | " model=model, \n", 276 | " args=training_args,\n", 277 | " train_dataset=train_d, \n", 278 | " eval_dataset=val_d, \n", 279 | " tokenizer=tokenizer,\n", 280 | " callbacks=[EarlyStoppingCallback()]\n", 281 | ")\n", 282 | "\n", 283 | "# get performance before training\n", 284 | "trainer.evaluate()" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 11, 290 | "id": "23b5a658", 291 | "metadata": { 292 | "scrolled": false 293 | }, 294 | "outputs": [ 295 | { 296 | "data": { 297 | "text/html": [ 298 | "\n", 299 | "
\n", 300 | " \n", 301 | " \n", 302 | " [ 2000/100000 2:56:39 < 144:24:58, 0.19 it/s, Epoch 0/24]\n", 303 | "
\n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | "
Step | Training Loss | Validation Loss
1000 | 0.099700      | 0.325624
2000 | 0.060500      | 0.348871
" 325 | ], 326 | "text/plain": [ 327 | "" 328 | ] 329 | }, 330 | "metadata": {}, 331 | "output_type": "display_data" 332 | }, 333 | { 334 | "data": { 335 | "text/plain": [ 336 | "TrainOutput(global_step=2000, training_loss=0.10654444885253907, metrics={'train_runtime': 10601.4604, 'train_samples_per_second': 9.433, 'total_flos': 0, 'epoch': 0.47})" 337 | ] 338 | }, 339 | "execution_count": 11, 340 | "metadata": {}, 341 | "output_type": "execute_result" 342 | } 343 | ], 344 | "source": [ 345 | "# train until early stopping\n", 346 | "trainer.train()" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 12, 352 | "id": "06da266f", 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "text/html": [ 358 | "\n", 359 | "
\n", 360 | " \n", 361 | " \n", 362 | " [259/259 20:01]\n", 363 | "
\n", 364 | " " 365 | ], 366 | "text/plain": [ 367 | "" 368 | ] 369 | }, 370 | "metadata": {}, 371 | "output_type": "display_data" 372 | }, 373 | { 374 | "data": { 375 | "text/plain": [ 376 | "{'eval_loss': 0.32562369108200073,\n", 377 | " 'eval_runtime': 1206.6218,\n", 378 | " 'eval_samples_per_second': 0.857,\n", 379 | " 'epoch': 0.47}" 380 | ] 381 | }, 382 | "execution_count": 12, 383 | "metadata": {}, 384 | "output_type": "execute_result" 385 | } 386 | ], 387 | "source": [ 388 | "# get performance after training\n", 389 | "trainer.evaluate()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 13, 395 | "id": "9854bb38", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "# Save model\n", 400 | "trainer.save_model('ra-tscholak/1zha5ono')" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "id": "48673297", 406 | "metadata": {}, 407 | "source": [ 408 | "Training done! After saving, you can then reload the model with the ratransformers package again!" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 9, 414 | "id": "5c6251b6", 415 | "metadata": { 416 | "scrolled": false 417 | }, 418 | "outputs": [ 419 | { 420 | "data": { 421 | "text/html": [ 422 | "\n", 423 | "
\n", 424 | " \n", 425 | " \n", 426 | " [259/259 15:35]\n", 427 | "
\n", 428 | " " 429 | ], 430 | "text/plain": [ 431 | "" 432 | ] 433 | }, 434 | "metadata": {}, 435 | "output_type": "display_data" 436 | }, 437 | { 438 | "data": { 439 | "text/plain": [ 440 | "{'eval_loss': 0.32562369108200073,\n", 441 | " 'eval_runtime': 938.7328,\n", 442 | " 'eval_samples_per_second': 1.101}" 443 | ] 444 | }, 445 | "execution_count": 9, 446 | "metadata": {}, 447 | "output_type": "execute_result" 448 | } 449 | ], 450 | "source": [ 451 | "# Reload model again\n", 452 | "ratransformer = ratransformers.RATransformer(\n", 453 | " 'ra-tscholak/1zha5ono', \n", 454 | " relation_kinds=['table_column_link', 'column_table_link'],\n", 455 | " alias_model_name='t5'\n", 456 | ")\n", 457 | "model = ratransformer.model\n", 458 | "tokenizer = ratransformer.tokenizer\n", 459 | "\n", 460 | "trainer = Seq2SeqTrainer(\n", 461 | " model=model, \n", 462 | " args=training_args,\n", 463 | " train_dataset=train_d, \n", 464 | " eval_dataset=val_d, \n", 465 | " tokenizer=tokenizer,\n", 466 | " callbacks=[EarlyStoppingCallback()]\n", 467 | ")\n", 468 | "trainer.evaluate()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "id": "ee0202ca", 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [] 478 | } 479 | ], 480 | "metadata": { 481 | "kernelspec": { 482 | "display_name": "Python 3", 483 | "language": "python", 484 | "name": "python3" 485 | }, 486 | "language_info": { 487 | "codemirror_mode": { 488 | "name": "ipython", 489 | "version": 3 490 | }, 491 | "file_extension": ".py", 492 | "mimetype": "text/x-python", 493 | "name": "python", 494 | "nbconvert_exporter": "python", 495 | "pygments_lexer": "ipython3", 496 | "version": "3.7.12" 497 | } 498 | }, 499 | "nbformat": 4, 500 | "nbformat_minor": 5 501 | } 502 | -------------------------------------------------------------------------------- /src/ratransformers/t5.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from types import MethodType 4 | from typing import Optional, Tuple, Union 5 | 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.checkpoint import checkpoint 8 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqModelOutput, \ 9 | BaseModelOutput, Seq2SeqLMOutput 10 | from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention, T5LayerCrossAttention, T5Block, \ 11 | logger, T5Stack, T5Model, T5_INPUTS_DOCSTRING, _CONFIG_FOR_DOC, T5ForConditionalGeneration, T5EncoderModel, \ 12 | T5_ENCODER_INPUTS_DOCSTRING 13 | import torch.nn as nn 14 | import torch 15 | from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings 16 | 17 | 18 | def change_t5_module_and_get_relational_emb_dim(module_name: str, module: nn.Module) -> Optional[int]: 19 | relational_embedding_dim = None 20 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, T5Attention): 21 | module.forward = MethodType(RelationalT5Attention.forward, module) 22 | relational_embedding_dim = module.key_value_proj_dim 23 | elif isinstance(module, T5Stack): 24 | module.forward = MethodType(RelationalT5Stack.forward, module) 25 | elif isinstance(module, T5ForConditionalGeneration): 26 | module.forward = MethodType(RelationalT5ForConditionalGeneration.forward, module) 27 | elif isinstance(module, T5EncoderModel): 28 | module.forward = MethodType(RelationalT5EncoderModel.forward, module) 29 | elif isinstance(module, T5Model): 30 | 
module.forward = MethodType(RelationalT5Model.forward, module) 31 | elif isinstance(module, T5Block): 32 | module.forward = MethodType(RelationalT5Block.forward, module) 33 | elif isinstance(module, T5LayerCrossAttention): 34 | module.forward = MethodType(RelationalT5LayerCrossAttention.forward, module) 35 | elif isinstance(module, T5LayerSelfAttention): 36 | module.forward = MethodType(RelationalT5LayerSelfAttention.forward, module) 37 | return relational_embedding_dim 38 | 39 | 40 | class RelationalT5ForConditionalGeneration(T5ForConditionalGeneration): 41 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) 42 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 43 | def forward( 44 | self, 45 | input_ids: Optional[torch.LongTensor] = None, 46 | attention_mask: Optional[torch.FloatTensor] = None, 47 | decoder_input_ids: Optional[torch.LongTensor] = None, 48 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 49 | head_mask: Optional[torch.FloatTensor] = None, 50 | decoder_head_mask: Optional[torch.FloatTensor] = None, 51 | cross_attn_head_mask: Optional[torch.Tensor] = None, 52 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, 53 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 54 | inputs_embeds: Optional[torch.FloatTensor] = None, 55 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 56 | labels: Optional[torch.LongTensor] = None, 57 | use_cache: Optional[bool] = None, 58 | output_attentions: Optional[bool] = None, 59 | output_hidden_states: Optional[bool] = None, 60 | return_dict: Optional[bool] = None, 61 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 62 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: 63 | r""" 64 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 65 | Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., 66 | config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for 67 | labels in `[0, ..., config.vocab_size]` 68 | 69 | Returns: 70 | 71 | Examples: 72 | 73 | ```python 74 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration 75 | 76 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") 77 | >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") 78 | 79 | >>> # training 80 | >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids 81 | >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids 82 | >>> outputs = model(input_ids=input_ids, labels=labels) 83 | >>> loss = outputs.loss 84 | >>> logits = outputs.logits 85 | 86 | >>> # inference 87 | >>> input_ids = tokenizer( 88 | ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" 89 | ... ).input_ids # Batch size 1 90 | >>> outputs = model.generate(input_ids) 91 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) 92 | >>> # studies have shown that owning a dog is good for you. 
93 | ```""" 94 | use_cache = use_cache if use_cache is not None else self.config.use_cache 95 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 96 | 97 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 98 | if head_mask is not None and decoder_head_mask is None: 99 | if self.config.num_layers == self.config.num_decoder_layers: 100 | warnings.warn( 101 | """ 102 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 103 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 104 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 105 | num_heads)`. 106 | """ 107 | , FutureWarning 108 | ) 109 | decoder_head_mask = head_mask 110 | 111 | # Encode if needed (training, first prediction pass) 112 | if encoder_outputs is None: 113 | # Convert encoder inputs in embeddings if needed 114 | encoder_outputs = self.encoder( 115 | input_ids=input_ids, 116 | attention_mask=attention_mask, 117 | inputs_embeds=inputs_embeds, 118 | head_mask=head_mask, 119 | output_attentions=output_attentions, 120 | output_hidden_states=output_hidden_states, 121 | return_dict=return_dict, 122 | input_relations=input_relations 123 | ) 124 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 125 | encoder_outputs = BaseModelOutput( 126 | last_hidden_state=encoder_outputs[0], 127 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 128 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 129 | ) 130 | 131 | hidden_states = encoder_outputs[0] 132 | 133 | if self.model_parallel: 134 | torch.cuda.set_device(self.decoder.first_device) 135 | 136 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 137 | # get decoder inputs from shifting lm labels to the right 138 | decoder_input_ids = self._shift_right(labels) 139 | 140 | # Set device for model parallelism 141 | if self.model_parallel: 142 | torch.cuda.set_device(self.decoder.first_device) 143 | hidden_states = hidden_states.to(self.decoder.first_device) 144 | if decoder_input_ids is not None: 145 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 146 | if attention_mask is not None: 147 | attention_mask = attention_mask.to(self.decoder.first_device) 148 | if decoder_attention_mask is not None: 149 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 150 | 151 | # Decode 152 | decoder_outputs = self.decoder( 153 | input_ids=decoder_input_ids, 154 | attention_mask=decoder_attention_mask, 155 | inputs_embeds=decoder_inputs_embeds, 156 | past_key_values=past_key_values, 157 | encoder_hidden_states=hidden_states, 158 | encoder_attention_mask=attention_mask, 159 | head_mask=decoder_head_mask, 160 | cross_attn_head_mask=cross_attn_head_mask, 161 | use_cache=use_cache, 162 | output_attentions=output_attentions, 163 | output_hidden_states=output_hidden_states, 164 | return_dict=return_dict, 165 | input_relations=input_relations 166 | ) 167 | 168 | sequence_output = decoder_outputs[0] 169 | 170 | # Set device for model parallelism 171 | if self.model_parallel: 172 | torch.cuda.set_device(self.encoder.first_device) 173 | self.lm_head = self.lm_head.to(self.encoder.first_device) 174 | sequence_output = sequence_output.to(self.lm_head.weight.device) 175 | 176 | if self.config.tie_word_embeddings: 
177 | # Rescale output before projecting on vocab 178 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 179 | sequence_output = sequence_output * (self.model_dim ** -0.5) 180 | 181 | lm_logits = self.lm_head(sequence_output) 182 | 183 | loss = None 184 | if labels is not None: 185 | loss_fct = CrossEntropyLoss(ignore_index=-100) 186 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 187 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 188 | 189 | if not return_dict: 190 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs 191 | return ((loss,) + output) if loss is not None else output 192 | 193 | return Seq2SeqLMOutput( 194 | loss=loss, 195 | logits=lm_logits, 196 | past_key_values=decoder_outputs.past_key_values, 197 | decoder_hidden_states=decoder_outputs.hidden_states, 198 | decoder_attentions=decoder_outputs.attentions, 199 | cross_attentions=decoder_outputs.cross_attentions, 200 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 201 | encoder_hidden_states=encoder_outputs.hidden_states, 202 | encoder_attentions=encoder_outputs.attentions, 203 | ) 204 | 205 | 206 | class RelationalT5EncoderModel(T5EncoderModel): 207 | @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) 208 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) 209 | def forward( 210 | self, 211 | input_ids: Optional[torch.LongTensor] = None, 212 | attention_mask: Optional[torch.FloatTensor] = None, 213 | head_mask: Optional[torch.FloatTensor] = None, 214 | inputs_embeds: Optional[torch.FloatTensor] = None, 215 | output_attentions: Optional[bool] = None, 216 | output_hidden_states: Optional[bool] = None, 217 | return_dict: Optional[bool] = None, 218 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 219 | ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: 220 | r""" 221 | Returns: 222 | 223 | Example: 224 | 225 | ```python 226 | >>> from transformers import T5Tokenizer, T5EncoderModel 227 | 228 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") 229 | >>> model = T5EncoderModel.from_pretrained("t5-small") 230 | >>> input_ids = tokenizer( 231 | ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" 232 | ... 
).input_ids # Batch size 1 233 | >>> outputs = model(input_ids=input_ids) 234 | >>> last_hidden_states = outputs.last_hidden_state 235 | ```""" 236 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 237 | 238 | encoder_outputs = self.encoder( 239 | input_ids=input_ids, 240 | attention_mask=attention_mask, 241 | inputs_embeds=inputs_embeds, 242 | head_mask=head_mask, 243 | output_attentions=output_attentions, 244 | output_hidden_states=output_hidden_states, 245 | return_dict=return_dict, 246 | input_relations=input_relations 247 | ) 248 | 249 | return encoder_outputs 250 | 251 | 252 | class RelationalT5Model(T5Model): 253 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) 254 | @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) 255 | def forward( 256 | self, 257 | input_ids: Optional[torch.LongTensor] = None, 258 | attention_mask: Optional[torch.FloatTensor] = None, 259 | decoder_input_ids: Optional[torch.LongTensor] = None, 260 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 261 | head_mask: Optional[torch.FloatTensor] = None, 262 | decoder_head_mask: Optional[torch.FloatTensor] = None, 263 | cross_attn_head_mask: Optional[torch.Tensor] = None, 264 | encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 265 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 266 | inputs_embeds: Optional[torch.Tensor] = None, 267 | decoder_inputs_embeds: Optional[torch.Tensor] = None, 268 | use_cache: Optional[bool] = None, 269 | output_attentions: Optional[bool] = None, 270 | output_hidden_states: Optional[bool] = None, 271 | return_dict: Optional[bool] = None, 272 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 273 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: 274 | r""" 275 | Returns: 276 | 277 | Example: 278 | 279 | ```python 280 | >>> from transformers import T5Tokenizer, T5Model 281 | 282 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") 283 | >>> model = T5Model.from_pretrained("t5-small") 284 | 285 | >>> input_ids = tokenizer( 286 | ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" 287 | ... ).input_ids # Batch size 1 288 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 289 | 290 | >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. 291 | >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. 292 | >>> decoder_input_ids = model._shift_right(decoder_input_ids) 293 | 294 | >>> # forward pass 295 | >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 296 | >>> last_hidden_states = outputs.last_hidden_state 297 | ```""" 298 | use_cache = use_cache if use_cache is not None else self.config.use_cache 299 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 300 | 301 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 302 | if head_mask is not None and decoder_head_mask is None: 303 | if self.config.num_layers == self.config.num_decoder_layers: 304 | warnings.warn( 305 | """ 306 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 307 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 
308 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 309 | num_heads)`. 310 | """, FutureWarning 311 | ) 312 | decoder_head_mask = head_mask 313 | 314 | # Encode if needed (training, first prediction pass) 315 | if encoder_outputs is None: 316 | encoder_outputs = self.encoder( 317 | input_ids=input_ids, 318 | attention_mask=attention_mask, 319 | inputs_embeds=inputs_embeds, 320 | head_mask=head_mask, 321 | output_attentions=output_attentions, 322 | output_hidden_states=output_hidden_states, 323 | return_dict=return_dict, 324 | input_relations=input_relations 325 | ) 326 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 327 | encoder_outputs = BaseModelOutput( 328 | last_hidden_state=encoder_outputs[0], 329 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 330 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 331 | ) 332 | 333 | hidden_states = encoder_outputs[0] 334 | 335 | # Set device for model parallelism 336 | if self.model_parallel: 337 | torch.cuda.set_device(self.decoder.first_device) 338 | hidden_states = hidden_states.to(self.decoder.first_device) 339 | if decoder_input_ids is not None: 340 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 341 | if attention_mask is not None: 342 | attention_mask = attention_mask.to(self.decoder.first_device) 343 | if decoder_attention_mask is not None: 344 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 345 | 346 | # Decode 347 | decoder_outputs = self.decoder( 348 | input_ids=decoder_input_ids, 349 | attention_mask=decoder_attention_mask, 350 | inputs_embeds=decoder_inputs_embeds, 351 | past_key_values=past_key_values, 352 | encoder_hidden_states=hidden_states, 353 | encoder_attention_mask=attention_mask, 354 | head_mask=decoder_head_mask, 355 | cross_attn_head_mask=cross_attn_head_mask, 356 | use_cache=use_cache, 357 | output_attentions=output_attentions, 358 | output_hidden_states=output_hidden_states, 359 | return_dict=return_dict, 360 | input_relations=input_relations 361 | ) 362 | 363 | if not return_dict: 364 | return decoder_outputs + encoder_outputs 365 | 366 | return Seq2SeqModelOutput( 367 | last_hidden_state=decoder_outputs.last_hidden_state, 368 | past_key_values=decoder_outputs.past_key_values, 369 | decoder_hidden_states=decoder_outputs.hidden_states, 370 | decoder_attentions=decoder_outputs.attentions, 371 | cross_attentions=decoder_outputs.cross_attentions, 372 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 373 | encoder_hidden_states=encoder_outputs.hidden_states, 374 | encoder_attentions=encoder_outputs.attentions, 375 | ) 376 | 377 | 378 | class RelationalT5Stack(T5Stack): 379 | def forward( 380 | self, 381 | input_ids=None, 382 | attention_mask=None, 383 | encoder_hidden_states=None, 384 | encoder_attention_mask=None, 385 | inputs_embeds=None, 386 | head_mask=None, 387 | cross_attn_head_mask=None, 388 | past_key_values=None, 389 | use_cache=None, 390 | output_attentions=None, 391 | output_hidden_states=None, 392 | return_dict=None, 393 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 394 | ): 395 | # Model parallel 396 | if self.model_parallel: 397 | torch.cuda.set_device(self.first_device) 398 | self.embed_tokens = self.embed_tokens.to(self.first_device) 399 | use_cache = use_cache if use_cache is not None else self.config.use_cache 400 | output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions 401 | output_hidden_states = ( 402 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 403 | ) 404 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 405 | 406 | if input_ids is not None and inputs_embeds is not None: 407 | err_msg_prefix = "decoder_" if self.is_decoder else "" 408 | raise ValueError( 409 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 410 | ) 411 | elif input_ids is not None: 412 | input_shape = input_ids.size() 413 | input_ids = input_ids.view(-1, input_shape[-1]) 414 | elif inputs_embeds is not None: 415 | input_shape = inputs_embeds.size()[:-1] 416 | else: 417 | err_msg_prefix = "decoder_" if self.is_decoder else "" 418 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 419 | 420 | if inputs_embeds is None: 421 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 422 | inputs_embeds = self.embed_tokens(input_ids) 423 | 424 | batch_size, seq_length = input_shape 425 | 426 | # required mask seq length can be calculated via length of past 427 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 428 | 429 | if use_cache is True: 430 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 431 | 432 | if attention_mask is None: 433 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 434 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 435 | encoder_seq_length = encoder_hidden_states.shape[1] 436 | encoder_attention_mask = torch.ones( 437 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 438 | ) 439 | 440 | # initialize past_key_values with `None` if past does not exist 441 | if past_key_values is None: 442 | past_key_values = [None] * len(self.block) 443 | 444 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 445 | # ourselves in which case we just need to make it broadcastable to all heads. 
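        # Rough shape sketch for the mask handling below (based on the stock Hugging Face
        # `get_extended_attention_mask` helper; stated as an assumption, not code from this repo):
        #   attention_mask          : (batch_size, mask_seq_length)        1 = attend, 0 = pad
        #   extended_attention_mask : (batch_size, 1, 1, mask_seq_length)  0.0 = attend, large negative = pad
        # Decoder stacks additionally fold in a causal mask, giving (batch_size, 1, seq_length, seq_length).
        # The large negative entries are simply added to the raw attention scores, so softmax
        # drives masked positions towards zero, e.g.:
        #   mask = torch.tensor([[1, 1, 0]])
        #   ext  = (1.0 - mask[:, None, None, :]) * torch.finfo(torch.float32).min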
446 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) 447 | 448 | # If a 2D or 3D attention mask is provided for the cross-attention 449 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 450 | if self.is_decoder and encoder_hidden_states is not None: 451 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 452 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 453 | if encoder_attention_mask is None: 454 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 455 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 456 | else: 457 | encoder_extended_attention_mask = None 458 | 459 | # Prepare head mask if needed 460 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 461 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 462 | present_key_value_states = () if use_cache else None 463 | all_hidden_states = () if output_hidden_states else None 464 | all_attentions = () if output_attentions else None 465 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 466 | position_bias = None 467 | encoder_decoder_position_bias = None 468 | 469 | hidden_states = self.dropout(inputs_embeds) 470 | 471 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 472 | layer_head_mask = head_mask[i] 473 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 474 | # Model parallel 475 | if self.model_parallel: 476 | torch.cuda.set_device(hidden_states.device) 477 | # Ensure that attention_mask is always on the same device as hidden_states 478 | if attention_mask is not None: 479 | attention_mask = attention_mask.to(hidden_states.device) 480 | if position_bias is not None: 481 | position_bias = position_bias.to(hidden_states.device) 482 | if encoder_hidden_states is not None: 483 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 484 | if encoder_extended_attention_mask is not None: 485 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 486 | if encoder_decoder_position_bias is not None: 487 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 488 | if layer_head_mask is not None: 489 | layer_head_mask = layer_head_mask.to(hidden_states.device) 490 | if cross_attn_layer_head_mask is not None: 491 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 492 | if output_hidden_states: 493 | all_hidden_states = all_hidden_states + (hidden_states,) 494 | 495 | if self.gradient_checkpointing and self.training: 496 | if use_cache: 497 | logger.warning( 498 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
499 | ) 500 | use_cache = False 501 | 502 | def create_custom_forward(module): 503 | def custom_forward(*inputs, **kwd): 504 | return tuple(module(*inputs, **kwd, use_cache=use_cache, output_attentions=output_attentions)) 505 | 506 | return custom_forward 507 | 508 | layer_outputs = checkpoint( 509 | create_custom_forward(layer_module), 510 | hidden_states, 511 | extended_attention_mask, 512 | position_bias, 513 | encoder_hidden_states, 514 | encoder_extended_attention_mask, 515 | encoder_decoder_position_bias, 516 | layer_head_mask, 517 | cross_attn_layer_head_mask, 518 | None, # past_key_value is always None with gradient checkpointing 519 | input_relations=input_relations 520 | ) 521 | else: 522 | layer_outputs = layer_module( 523 | hidden_states, 524 | attention_mask=extended_attention_mask, 525 | position_bias=position_bias, 526 | encoder_hidden_states=encoder_hidden_states, 527 | encoder_attention_mask=encoder_extended_attention_mask, 528 | encoder_decoder_position_bias=encoder_decoder_position_bias, 529 | layer_head_mask=layer_head_mask, 530 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 531 | past_key_value=past_key_value, 532 | use_cache=use_cache, 533 | output_attentions=output_attentions, 534 | input_relations=input_relations 535 | ) 536 | 537 | # layer_outputs is a tuple with: 538 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 539 | if use_cache is False: 540 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 541 | 542 | hidden_states, present_key_value_state = layer_outputs[:2] 543 | 544 | # We share the position biases between the layers - the first layer store them 545 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 546 | # (cross-attention position bias), (cross-attention weights) 547 | position_bias = layer_outputs[2] 548 | if self.is_decoder and encoder_hidden_states is not None: 549 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 550 | # append next layer key value states 551 | if use_cache: 552 | present_key_value_states = present_key_value_states + (present_key_value_state,) 553 | 554 | if output_attentions: 555 | all_attentions = all_attentions + (layer_outputs[3],) 556 | if self.is_decoder: 557 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 558 | 559 | # Model Parallel: If it's the last layer for that device, put things on the next device 560 | if self.model_parallel: 561 | for k, v in self.device_map.items(): 562 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 563 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 564 | 565 | hidden_states = self.final_layer_norm(hidden_states) 566 | hidden_states = self.dropout(hidden_states) 567 | 568 | # Add last layer 569 | if output_hidden_states: 570 | all_hidden_states = all_hidden_states + (hidden_states,) 571 | 572 | if not return_dict: 573 | return tuple( 574 | v 575 | for v in [ 576 | hidden_states, 577 | present_key_value_states, 578 | all_hidden_states, 579 | all_attentions, 580 | all_cross_attentions, 581 | ] 582 | if v is not None 583 | ) 584 | return BaseModelOutputWithPastAndCrossAttentions( 585 | last_hidden_state=hidden_states, 586 | past_key_values=present_key_value_states, 587 | hidden_states=all_hidden_states, 588 | attentions=all_attentions, 589 | cross_attentions=all_cross_attentions, 590 | ) 591 | 592 | 593 | class RelationalT5Block(T5Block): 594 | 
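    # Note: this block mirrors T5Block.forward but threads `input_relations` through to its
    # sub-layers. Only attention modules whose forward was patched to accept `input_relations`
    # (the encoder self-attention, per the dispatch at the top of this file) actually consume it;
    # the Relational* layer wrappers below check the attention forward's signature with `inspect`
    # before passing it on. Illustrative shapes for a toy batch (an assumption, not taken from
    # this repo's tests): hidden_states is (batch, seq_len, d_model) and input_relations is
    # (batch, seq_len, seq_len) integer relation-kind ids, which is what RelationalT5Attention
    # asserts and feeds to its nn.Embedding lookups.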
def forward( 595 | self, 596 | hidden_states, 597 | attention_mask=None, 598 | position_bias=None, 599 | encoder_hidden_states=None, 600 | encoder_attention_mask=None, 601 | encoder_decoder_position_bias=None, 602 | layer_head_mask=None, 603 | cross_attn_layer_head_mask=None, 604 | past_key_value=None, 605 | use_cache=False, 606 | output_attentions=False, 607 | return_dict=True, 608 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 609 | ): 610 | 611 | if past_key_value is not None: 612 | if not self.is_decoder: 613 | logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") 614 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 615 | 616 | if len(past_key_value) != expected_num_past_key_values: 617 | raise ValueError( 618 | f"There should be {expected_num_past_key_values} past states. " 619 | f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" 620 | f"Got {len(past_key_value)} past key / value states" 621 | ) 622 | 623 | self_attn_past_key_value = past_key_value[:2] 624 | cross_attn_past_key_value = past_key_value[2:] 625 | else: 626 | self_attn_past_key_value, cross_attn_past_key_value = None, None 627 | 628 | self_attention_outputs = self.layer[0]( 629 | hidden_states, 630 | attention_mask=attention_mask, 631 | position_bias=position_bias, 632 | layer_head_mask=layer_head_mask, 633 | past_key_value=self_attn_past_key_value, 634 | use_cache=use_cache, 635 | output_attentions=output_attentions, 636 | input_relations=input_relations 637 | ) 638 | hidden_states, present_key_value_state = self_attention_outputs[:2] 639 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights 640 | 641 | # clamp inf values to enable fp16 training 642 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 643 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 644 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 645 | 646 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None 647 | if do_cross_attention: 648 | # the actual query length is unknown for cross attention 649 | # if using past key value states. 
Need to inject it here 650 | if present_key_value_state is not None: 651 | query_length = present_key_value_state[0].shape[2] 652 | else: 653 | query_length = None 654 | 655 | cross_attention_outputs = self.layer[1]( 656 | hidden_states, 657 | key_value_states=encoder_hidden_states, 658 | attention_mask=encoder_attention_mask, 659 | position_bias=encoder_decoder_position_bias, 660 | layer_head_mask=cross_attn_layer_head_mask, 661 | past_key_value=cross_attn_past_key_value, 662 | query_length=query_length, 663 | use_cache=use_cache, 664 | output_attentions=output_attentions, 665 | input_relations=input_relations 666 | ) 667 | hidden_states = cross_attention_outputs[0] 668 | 669 | # clamp inf values to enable fp16 training 670 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 671 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 672 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 673 | 674 | # Combine self attn and cross attn key value states 675 | if present_key_value_state is not None: 676 | present_key_value_state = present_key_value_state + cross_attention_outputs[1] 677 | 678 | # Keep cross-attention outputs and relative position weights 679 | attention_outputs = attention_outputs + cross_attention_outputs[2:] 680 | 681 | # Apply Feed Forward layer 682 | hidden_states = self.layer[-1](hidden_states) 683 | 684 | # clamp inf values to enable fp16 training 685 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 686 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 687 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 688 | 689 | outputs = (hidden_states,) 690 | 691 | if use_cache: 692 | outputs = outputs + (present_key_value_state,) + attention_outputs 693 | else: 694 | outputs = outputs + attention_outputs 695 | 696 | return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 697 | 698 | 699 | class RelationalT5LayerSelfAttention(T5LayerSelfAttention): 700 | def forward( 701 | self, 702 | hidden_states, 703 | attention_mask=None, 704 | position_bias=None, 705 | layer_head_mask=None, 706 | past_key_value=None, 707 | use_cache=False, 708 | output_attentions=False, 709 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 710 | ): 711 | normed_hidden_states = self.layer_norm(hidden_states) 712 | kwargs = {'input_relations': input_relations} \ 713 | if 'input_relations' in inspect.getfullargspec(self.SelfAttention.forward)[0] else {} 714 | attention_output = self.SelfAttention( 715 | normed_hidden_states, 716 | mask=attention_mask, 717 | position_bias=position_bias, 718 | layer_head_mask=layer_head_mask, 719 | past_key_value=past_key_value, 720 | use_cache=use_cache, 721 | output_attentions=output_attentions, 722 | **kwargs 723 | ) 724 | hidden_states = hidden_states + self.dropout(attention_output[0]) 725 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them 726 | return outputs 727 | 728 | 729 | class RelationalT5LayerCrossAttention(T5LayerCrossAttention): 730 | def forward( 731 | self, 732 | hidden_states, 733 | key_value_states, 734 | attention_mask=None, 735 | position_bias=None, 736 | layer_head_mask=None, 737 | past_key_value=None, 738 | use_cache=False, 739 | query_length=None, 740 | output_attentions=False, 741 | input_relations=None # will hold (batch, seq_length, 
seq_length, num_relation_kinds) 742 | ): 743 | normed_hidden_states = self.layer_norm(hidden_states) 744 | kwargs = {'input_relations': input_relations} \ 745 | if 'input_relations' in inspect.getfullargspec(self.EncDecAttention.forward)[0] else {} 746 | attention_output = self.EncDecAttention( 747 | normed_hidden_states, 748 | mask=attention_mask, 749 | key_value_states=key_value_states, 750 | position_bias=position_bias, 751 | layer_head_mask=layer_head_mask, 752 | past_key_value=past_key_value, 753 | use_cache=use_cache, 754 | query_length=query_length, 755 | output_attentions=output_attentions, 756 | **kwargs 757 | ) 758 | layer_output = hidden_states + self.dropout(attention_output[0]) 759 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them 760 | return outputs 761 | 762 | 763 | class RelationalT5Attention(T5Attention): 764 | def forward( 765 | self, 766 | hidden_states, 767 | mask=None, 768 | key_value_states=None, 769 | position_bias=None, 770 | past_key_value=None, 771 | layer_head_mask=None, 772 | query_length=None, 773 | use_cache=False, 774 | output_attentions=False, 775 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 776 | ): 777 | """ 778 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 779 | """ 780 | 781 | # Input is (batch_size, seq_length, dim) 782 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 783 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 784 | batch_size, seq_length = hidden_states.shape[:2] 785 | 786 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model" 787 | assert input_relations.shape == (batch_size, seq_length, seq_length) 788 | 789 | # (batch_size, seq_length, seq_length, self.inner_dim // num_relation_kinds) 790 | relation_k_embeds = self.relation_k_emb(input_relations) 791 | relation_v_embeds = self.relation_v_emb(input_relations) 792 | 793 | real_seq_length = seq_length 794 | 795 | if past_key_value is not None: 796 | assert ( 797 | len(past_key_value) == 2 798 | ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" 799 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length 800 | 801 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] 802 | 803 | def shape(states): 804 | """projection""" 805 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 806 | 807 | def unshape(states): 808 | """reshape""" 809 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 810 | 811 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 812 | """projects hidden states correctly to key/query states""" 813 | if key_value_states is None: 814 | # self-attn 815 | # (batch_size, n_heads, seq_length, dim_per_head) 816 | hidden_states = shape(proj_layer(hidden_states)) 817 | elif past_key_value is None: 818 | # cross-attn 819 | # (batch_size, n_heads, seq_length, dim_per_head) 820 | hidden_states = shape(proj_layer(key_value_states)) 821 | 822 | if past_key_value is not None: 823 | if key_value_states is None: 824 | # self-attn 825 | # (batch_size, n_heads, key_length, dim_per_head) 826 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 827 | else: 828 | # cross-attn 829 | hidden_states = past_key_value 830 | return hidden_states 831 | 832 | # get query states 833 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, key_value_proj_dim) 834 | 835 | # get key/value states 836 | key_states = project( 837 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None 838 | ) 839 | value_states = project( 840 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None 841 | ) 842 | 843 | # compute scores 844 | scores = torch.matmul( 845 | query_states, key_states.transpose(3, 2) 846 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 847 | 848 | # q_t is [batch, seq_length, n_heads, key_value_proj_dim] 849 | q_t = query_states.permute(0, 2, 1, 3) 850 | 851 | # r_t is [batch, seq_length, key_value_proj_dim, seq_length] 852 | r_t = relation_k_embeds.transpose(-2, -1) 853 | 854 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length] 855 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length] 856 | 857 | # Add to scores 858 | scores += q_tr_tmatmul_t 859 | 860 | if position_bias is None: 861 | if not self.has_relative_attention_bias: 862 | position_bias = torch.zeros( 863 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 864 | ) 865 | if self.gradient_checkpointing and self.training: 866 | position_bias.requires_grad = True 867 | else: 868 | position_bias = self.compute_bias(real_seq_length, key_length) 869 | 870 | # if key and values are already calculated 871 | # we want only the last query position bias 872 | if past_key_value is not None: 873 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :] 874 | 875 | if mask is not None: 876 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) 877 | 878 | if self.pruned_heads: 879 | mask = torch.ones(position_bias.shape[1]) 880 | mask[list(self.pruned_heads)] = 0 881 | position_bias_masked = position_bias[:, mask.bool()] 882 | else: 883 | position_bias_masked = position_bias 884 | 885 | scores += position_bias_masked 886 | attn_weights = 
nn.functional.softmax(scores.float(), dim=-1).type_as( 887 | scores 888 | ) # (batch_size, n_heads, seq_length, key_length) 889 | attn_weights = nn.functional.dropout( 890 | attn_weights, p=self.dropout, training=self.training 891 | ) # (batch_size, n_heads, seq_length, key_length) 892 | 893 | # Mask heads if we want to 894 | if layer_head_mask is not None: 895 | attn_weights = attn_weights * layer_head_mask 896 | 897 | # [batch, n_heads, seq_length, seq_length] 898 | wv_matmul = torch.matmul(attn_weights, value_states) 899 | 900 | # w_t is [batch, seq_length, n_heads, seq_length] 901 | w_t = attn_weights.permute(0, 2, 1, 3) 902 | 903 | # [batch, seq_length, n_heads, seq_length] 904 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds) 905 | 906 | attn_output = unshape(wv_matmul + w_tr_matmul.permute(0, 2, 1, 3)) # (batch_size, seq_length, dim) 907 | attn_output = self.o(attn_output) 908 | 909 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None 910 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 911 | 912 | if output_attentions: 913 | outputs = outputs + (attn_weights,) 914 | return outputs 915 | -------------------------------------------------------------------------------- /src/ratransformers/longt5.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from types import MethodType 4 | from typing import Optional, Tuple, Union, Any 5 | 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.checkpoint import checkpoint 8 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqModelOutput, \ 9 | BaseModelOutput, Seq2SeqLMOutput 10 | from transformers.models.longt5.modeling_longt5 import LongT5Attention, LongT5LayerSelfAttention, \ 11 | LongT5LayerCrossAttention, LongT5Block, \ 12 | logger, LongT5Stack, LongT5Model, _CONFIG_FOR_DOC, LONGT5_INPUTS_DOCSTRING, _get_local_attention_mask, \ 13 | LongT5LocalAttention, _split_into_blocks, _concatenate_3_blocks, _pad_to_multiple, LongT5TransientGlobalAttention, \ 14 | _make_global_fixed_block_ids, _create_global_aggregates, LongT5LayerTransientGlobalSelfAttention, \ 15 | LongT5LayerLocalSelfAttention, LongT5PreTrainedModel, LongT5ForConditionalGeneration, LongT5EncoderModel, \ 16 | LONGT5_ENCODER_INPUTS_DOCSTRING 17 | import torch.nn as nn 18 | import torch 19 | from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings 20 | 21 | 22 | def change_longt5_module_and_get_relational_emb_dim(module_name: str, module: nn.Module) -> Optional[int]: 23 | relational_embedding_dim = None 24 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5Attention): 25 | module.forward = MethodType(RelationalLongT5Attention.forward, module) 26 | relational_embedding_dim = module.inner_dim // module.n_heads 27 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5LocalAttention): 28 | module.forward = MethodType(RelationalLongT5LocalAttention.forward, module) 29 | relational_embedding_dim = module.inner_dim // module.n_heads 30 | if 'encoder' in module_name and 'decoder' not in module_name and isinstance(module, LongT5TransientGlobalAttention): 31 | module.forward = MethodType(RelationalLongT5TransientGlobalAttention.forward, module) 32 | relational_embedding_dim = module.inner_dim // module.n_heads 33 | elif isinstance(module, LongT5Stack): 34 | module.forward = 
MethodType(RelationalLongT5Stack.forward, module) 35 | elif isinstance(module, LongT5ForConditionalGeneration): 36 | module.forward = MethodType(RelationalLongT5ForConditionalGeneration.forward, module) 37 | elif isinstance(module, LongT5Model): 38 | module.forward = MethodType(RelationalLongT5Model.forward, module) 39 | elif isinstance(module, LongT5EncoderModel): 40 | module.forward = MethodType(RelationalLongT5EncoderModel.forward, module) 41 | elif isinstance(module, LongT5Block): 42 | module.forward = MethodType(RelationalLongT5Block.forward, module) 43 | elif isinstance(module, LongT5LayerCrossAttention): 44 | module.forward = MethodType(RelationalLongT5LayerCrossAttention.forward, module) 45 | elif isinstance(module, LongT5LayerSelfAttention): 46 | module.forward = MethodType(RelationalLongT5LayerSelfAttention.forward, module) 47 | elif isinstance(module, LongT5LayerLocalSelfAttention): 48 | module.forward = MethodType(RelationalLongT5LayerLocalSelfAttention.forward, module) 49 | elif isinstance(module, LongT5LayerTransientGlobalSelfAttention): 50 | module.forward = MethodType(RelationalLongT5LayerTransientGlobalSelfAttention.forward, module) 51 | return relational_embedding_dim 52 | 53 | 54 | class RelationalLongT5ForConditionalGeneration(LongT5ForConditionalGeneration): 55 | @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) 56 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 57 | def forward( 58 | self, 59 | input_ids: Optional[torch.LongTensor] = None, 60 | attention_mask: Optional[torch.FloatTensor] = None, 61 | decoder_input_ids: Optional[torch.LongTensor] = None, 62 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 63 | head_mask: Optional[torch.FloatTensor] = None, 64 | decoder_head_mask: Optional[torch.FloatTensor] = None, 65 | cross_attn_head_mask: Optional[torch.Tensor] = None, 66 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, 67 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 70 | labels: Optional[torch.LongTensor] = None, 71 | use_cache: Optional[bool] = None, 72 | output_attentions: Optional[bool] = None, 73 | output_hidden_states: Optional[bool] = None, 74 | return_dict: Optional[bool] = None, 75 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 76 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: 77 | r""" 78 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 79 | Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., 80 | config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for 81 | labels in `[0, ..., config.vocab_size]` 82 | 83 | Returns: 84 | 85 | Examples: 86 | 87 | ```python 88 | >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration 89 | 90 | >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps") 91 | >>> model = LongT5ForConditionalGeneration.from_pretrained( 92 | ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps" 93 | ... ) 94 | 95 | >>> # Let's try a very long input. 
96 | >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt") 97 | >>> input_ids = inputs.input_ids 98 | 99 | >>> outputs = model.generate(input_ids) 100 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) 101 | abstractthe aim of this article is to provide an overview of the literature on the role of dog 102 | ```""" 103 | use_cache = use_cache if use_cache is not None else self.config.use_cache 104 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 105 | 106 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 107 | if head_mask is not None and decoder_head_mask is None: 108 | if self.config.num_layers == self.config.num_decoder_layers: 109 | warnings.warn( 110 | """ 111 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 112 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 113 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 114 | num_heads)`. 115 | """ 116 | , FutureWarning 117 | ) 118 | decoder_head_mask = head_mask 119 | 120 | # Encode if needed (training, first prediction pass) 121 | if encoder_outputs is None: 122 | # Convert encoder inputs in embeddings if needed 123 | encoder_outputs = self.encoder( 124 | input_ids=input_ids, 125 | attention_mask=attention_mask, 126 | inputs_embeds=inputs_embeds, 127 | head_mask=head_mask, 128 | output_attentions=output_attentions, 129 | output_hidden_states=output_hidden_states, 130 | return_dict=return_dict, 131 | input_relations=input_relations 132 | ) 133 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 134 | encoder_outputs = BaseModelOutput( 135 | last_hidden_state=encoder_outputs[0], 136 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 137 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 138 | ) 139 | 140 | hidden_states = encoder_outputs[0] 141 | 142 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 143 | # get decoder inputs from shifting lm labels to the right 144 | decoder_input_ids = self._shift_right(labels) 145 | 146 | # Decode 147 | decoder_outputs = self.decoder( 148 | input_ids=decoder_input_ids, 149 | attention_mask=decoder_attention_mask, 150 | inputs_embeds=decoder_inputs_embeds, 151 | past_key_values=past_key_values, 152 | encoder_hidden_states=hidden_states, 153 | encoder_attention_mask=attention_mask, 154 | head_mask=decoder_head_mask, 155 | cross_attn_head_mask=cross_attn_head_mask, 156 | use_cache=use_cache, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | return_dict=return_dict, 160 | input_relations=input_relations 161 | ) 162 | 163 | sequence_output = decoder_outputs[0] 164 | 165 | if self.config.tie_word_embeddings: 166 | # Rescale output before projecting on vocab 167 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 168 | sequence_output = sequence_output * (self.model_dim ** -0.5) 169 | 170 | lm_logits = self.lm_head(sequence_output) 171 | 172 | loss = None 173 | if labels is not None: 174 | loss_fct = CrossEntropyLoss(ignore_index=-100) 175 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 176 | # TODO(thom): Add 
z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 177 | 178 | if not return_dict: 179 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs 180 | return ((loss,) + output) if loss is not None else output 181 | 182 | return Seq2SeqLMOutput( 183 | loss=loss, 184 | logits=lm_logits, 185 | past_key_values=decoder_outputs.past_key_values, 186 | decoder_hidden_states=decoder_outputs.hidden_states, 187 | decoder_attentions=decoder_outputs.attentions, 188 | cross_attentions=decoder_outputs.cross_attentions, 189 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 190 | encoder_hidden_states=encoder_outputs.hidden_states, 191 | encoder_attentions=encoder_outputs.attentions, 192 | ) 193 | 194 | class RelationalLongT5Model(LongT5Model): 195 | @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) 196 | @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) 197 | def forward( 198 | self, 199 | input_ids: Optional[torch.LongTensor] = None, 200 | attention_mask: Optional[torch.FloatTensor] = None, 201 | decoder_input_ids: Optional[torch.LongTensor] = None, 202 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 203 | head_mask: Optional[torch.FloatTensor] = None, 204 | decoder_head_mask: Optional[torch.FloatTensor] = None, 205 | cross_attn_head_mask: Optional[torch.Tensor] = None, 206 | encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 207 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 208 | inputs_embeds: Optional[torch.Tensor] = None, 209 | decoder_inputs_embeds: Optional[torch.Tensor] = None, 210 | use_cache: Optional[bool] = None, 211 | output_attentions: Optional[bool] = None, 212 | output_hidden_states: Optional[bool] = None, 213 | return_dict: Optional[bool] = None, 214 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 215 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: 216 | r""" 217 | Returns: 218 | 219 | Example: 220 | 221 | ```python 222 | >>> from transformers import T5Tokenizer, LongT5Model 223 | 224 | >>> tokenizer = T5Tokenizer.from_pretrained("google/long-t5-local-base") 225 | >>> model = LongT5Model.from_pretrained("google/long-t5-local-base") 226 | 227 | >>> # Let's try a very long encoder input. 228 | >>> input_ids = tokenizer( 229 | ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt" 230 | ... ).input_ids # Batch size 1 231 | 232 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 233 | 234 | >>> # forward pass 235 | >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 236 | >>> last_hidden_states = outputs.last_hidden_state 237 | ```""" 238 | use_cache = use_cache if use_cache is not None else self.config.use_cache 239 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 240 | 241 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 242 | if head_mask is not None and decoder_head_mask is None: 243 | if self.config.num_layers == self.config.num_decoder_layers: 244 | warnings.warn( 245 | """ 246 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 247 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 
248 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 249 | num_heads)`. 250 | """ 251 | , FutureWarning 252 | ) 253 | decoder_head_mask = head_mask 254 | 255 | # Encode if needed (training, first prediction pass) 256 | if encoder_outputs is None: 257 | encoder_outputs = self.encoder( 258 | input_ids=input_ids, 259 | attention_mask=attention_mask, 260 | inputs_embeds=inputs_embeds, 261 | head_mask=head_mask, 262 | output_attentions=output_attentions, 263 | output_hidden_states=output_hidden_states, 264 | return_dict=return_dict, 265 | input_relations=input_relations 266 | ) 267 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 268 | encoder_outputs = BaseModelOutput( 269 | last_hidden_state=encoder_outputs[0], 270 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 271 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 272 | ) 273 | 274 | hidden_states = encoder_outputs[0] 275 | 276 | # Decode 277 | decoder_outputs = self.decoder( 278 | input_ids=decoder_input_ids, 279 | attention_mask=decoder_attention_mask, 280 | inputs_embeds=decoder_inputs_embeds, 281 | past_key_values=past_key_values, 282 | encoder_hidden_states=hidden_states, 283 | encoder_attention_mask=attention_mask, 284 | head_mask=decoder_head_mask, 285 | cross_attn_head_mask=cross_attn_head_mask, 286 | use_cache=use_cache, 287 | output_attentions=output_attentions, 288 | output_hidden_states=output_hidden_states, 289 | return_dict=return_dict, 290 | input_relations=input_relations 291 | ) 292 | 293 | if not return_dict: 294 | return decoder_outputs + encoder_outputs 295 | 296 | return Seq2SeqModelOutput( 297 | last_hidden_state=decoder_outputs.last_hidden_state, 298 | past_key_values=decoder_outputs.past_key_values, 299 | decoder_hidden_states=decoder_outputs.hidden_states, 300 | decoder_attentions=decoder_outputs.attentions, 301 | cross_attentions=decoder_outputs.cross_attentions, 302 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 303 | encoder_hidden_states=encoder_outputs.hidden_states, 304 | encoder_attentions=encoder_outputs.attentions, 305 | ) 306 | 307 | 308 | class RelationalLongT5EncoderModel(LongT5EncoderModel): 309 | @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING) 310 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) 311 | def forward( 312 | self, 313 | input_ids: Optional[torch.LongTensor] = None, 314 | attention_mask: Optional[torch.FloatTensor] = None, 315 | head_mask: Optional[torch.FloatTensor] = None, 316 | inputs_embeds: Optional[torch.FloatTensor] = None, 317 | output_attentions: Optional[bool] = None, 318 | output_hidden_states: Optional[bool] = None, 319 | return_dict: Optional[bool] = None, 320 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 321 | ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: 322 | r""" 323 | Returns: 324 | 325 | Example: 326 | 327 | ```python 328 | >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration 329 | 330 | >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base") 331 | >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base") 332 | >>> input_ids = tokenizer( 333 | ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt" 334 | ... 
).input_ids # Batch size 1 335 | >>> outputs = model(input_ids=input_ids) 336 | >>> last_hidden_states = outputs.last_hidden_state 337 | ```""" 338 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 339 | 340 | encoder_outputs = self.encoder( 341 | input_ids=input_ids, 342 | attention_mask=attention_mask, 343 | inputs_embeds=inputs_embeds, 344 | head_mask=head_mask, 345 | output_attentions=output_attentions, 346 | output_hidden_states=output_hidden_states, 347 | return_dict=return_dict, 348 | input_relations=input_relations 349 | ) 350 | 351 | return encoder_outputs 352 | 353 | class RelationalLongT5Stack(LongT5Stack): 354 | def forward( 355 | self, 356 | input_ids=None, 357 | attention_mask=None, 358 | encoder_hidden_states=None, 359 | encoder_attention_mask=None, 360 | inputs_embeds=None, 361 | head_mask=None, 362 | cross_attn_head_mask=None, 363 | past_key_values=None, 364 | use_cache=None, 365 | output_attentions=None, 366 | output_hidden_states=None, 367 | return_dict=None, 368 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 369 | ): 370 | use_cache = use_cache if use_cache is not None else self.config.use_cache 371 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 372 | output_hidden_states = ( 373 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 374 | ) 375 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 376 | 377 | if input_ids is not None and inputs_embeds is not None: 378 | err_msg_prefix = "decoder_" if self.is_decoder else "" 379 | raise ValueError( 380 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 381 | ) 382 | elif input_ids is not None: 383 | input_shape = input_ids.size() 384 | input_ids = input_ids.view(-1, input_shape[-1]) 385 | elif inputs_embeds is not None: 386 | input_shape = inputs_embeds.size()[:-1] 387 | else: 388 | err_msg_prefix = "decoder_" if self.is_decoder else "" 389 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 390 | 391 | if inputs_embeds is None: 392 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 393 | inputs_embeds = self.embed_tokens(input_ids) 394 | 395 | batch_size, seq_length = input_shape 396 | 397 | # required mask seq length can be calculated via length of past 398 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 399 | 400 | if use_cache is True: 401 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 402 | 403 | if attention_mask is None: 404 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 405 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 406 | encoder_seq_length = encoder_hidden_states.shape[1] 407 | encoder_attention_mask = torch.ones( 408 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 409 | ) 410 | 411 | # initialize past_key_values with `None` if past does not exist 412 | if past_key_values is None: 413 | past_key_values = [None] * len(self.block) 414 | 415 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 416 | # ourselves in which case we just need to make it broadcastable 
to all heads. 417 | # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used 418 | if self.is_decoder: 419 | extended_attention_mask = self.get_extended_attention_mask( 420 | attention_mask, input_shape, inputs_embeds.device 421 | ) 422 | elif self.config.encoder_attention_type == "local": 423 | extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) 424 | else: # we need to use both local attention mask and standard extended mask for transient-global attention 425 | extended_attention_mask = attention_mask 426 | 427 | # If a 2D or 3D attention mask is provided for the cross-attention 428 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 429 | if self.is_decoder and encoder_hidden_states is not None: 430 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 431 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 432 | if encoder_attention_mask is None: 433 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 434 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 435 | else: 436 | encoder_extended_attention_mask = None 437 | 438 | # Prepare head mask if needed 439 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 440 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 441 | present_key_value_states = () if use_cache else None 442 | all_hidden_states = () if output_hidden_states else None 443 | all_attentions = () if output_attentions else None 444 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 445 | position_bias = None 446 | encoder_decoder_position_bias = None 447 | 448 | hidden_states = self.dropout(inputs_embeds) 449 | 450 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 451 | layer_head_mask = head_mask[i] 452 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 453 | 454 | if output_hidden_states: 455 | all_hidden_states = all_hidden_states + (hidden_states,) 456 | 457 | if self.gradient_checkpointing and self.training: 458 | if use_cache: 459 | use_cache = False 460 | 461 | def create_custom_forward(module): 462 | def custom_forward(*inputs, **kwd): 463 | return tuple(module(*inputs, **kwd, use_cache=use_cache, output_attentions=output_attentions)) 464 | 465 | return custom_forward 466 | 467 | layer_outputs = checkpoint( 468 | create_custom_forward(layer_module), 469 | hidden_states, 470 | extended_attention_mask, 471 | position_bias, 472 | encoder_hidden_states, 473 | encoder_extended_attention_mask, 474 | encoder_decoder_position_bias, 475 | layer_head_mask, 476 | cross_attn_layer_head_mask, 477 | None, # past_key_value is always None with gradient checkpointing 478 | input_relations=input_relations 479 | ) 480 | else: 481 | layer_outputs = layer_module( 482 | hidden_states, 483 | attention_mask=extended_attention_mask, 484 | position_bias=position_bias, 485 | encoder_hidden_states=encoder_hidden_states, 486 | encoder_attention_mask=encoder_extended_attention_mask, 487 | encoder_decoder_position_bias=encoder_decoder_position_bias, 488 | layer_head_mask=layer_head_mask, 489 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 490 | past_key_value=past_key_value, 491 | use_cache=use_cache, 492 | output_attentions=output_attentions, 493 | input_relations=input_relations 494 | ) 495 | 496 | # layer_outputs is 
a tuple with: 497 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 498 | if use_cache is False: 499 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 500 | 501 | hidden_states, present_key_value_state = layer_outputs[:2] 502 | 503 | # We share the position biases between the layers - the first layer store them 504 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 505 | # (cross-attention position bias), (cross-attention weights) 506 | position_bias = layer_outputs[2] 507 | if self.is_decoder and encoder_hidden_states is not None: 508 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 509 | # append next layer key value states 510 | if use_cache: 511 | present_key_value_states = present_key_value_states + (present_key_value_state,) 512 | 513 | if output_attentions: 514 | all_attentions = all_attentions + (layer_outputs[3],) 515 | if self.is_decoder: 516 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 517 | 518 | hidden_states = self.final_layer_norm(hidden_states) 519 | hidden_states = self.dropout(hidden_states) 520 | 521 | # Add last layer 522 | if output_hidden_states: 523 | all_hidden_states = all_hidden_states + (hidden_states,) 524 | 525 | if not return_dict: 526 | return tuple( 527 | v 528 | for v in [ 529 | hidden_states, 530 | present_key_value_states, 531 | all_hidden_states, 532 | all_attentions, 533 | all_cross_attentions, 534 | ] 535 | if v is not None 536 | ) 537 | return BaseModelOutputWithPastAndCrossAttentions( 538 | last_hidden_state=hidden_states, 539 | past_key_values=present_key_value_states, 540 | hidden_states=all_hidden_states, 541 | attentions=all_attentions, 542 | cross_attentions=all_cross_attentions, 543 | ) 544 | 545 | 546 | class RelationalLongT5Block(LongT5Block): 547 | def forward( 548 | self, 549 | hidden_states, 550 | attention_mask=None, 551 | position_bias=None, 552 | encoder_hidden_states=None, 553 | encoder_attention_mask=None, 554 | encoder_decoder_position_bias=None, 555 | layer_head_mask=None, 556 | cross_attn_layer_head_mask=None, 557 | past_key_value=None, 558 | use_cache=False, 559 | output_attentions=False, 560 | return_dict=True, 561 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 562 | ): 563 | 564 | if past_key_value is not None: 565 | if not self.is_decoder: 566 | logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") 567 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 568 | 569 | if len(past_key_value) != expected_num_past_key_values: 570 | raise ValueError( 571 | f"There should be {expected_num_past_key_values} past states. " 572 | f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" 573 | f"Got {len(past_key_value)} past key / value states" 574 | ) 575 | 576 | self_attn_past_key_value = past_key_value[:2] 577 | cross_attn_past_key_value = past_key_value[2:] 578 | else: 579 | self_attn_past_key_value, cross_attn_past_key_value = None, None 580 | 581 | self_attention_outputs = self.layer[0]( 582 | hidden_states, 583 | attention_mask=attention_mask, 584 | position_bias=position_bias, 585 | layer_head_mask=layer_head_mask, 586 | past_key_value=self_attn_past_key_value, 587 | use_cache=use_cache, 588 | output_attentions=output_attentions, 589 | input_relations=input_relations 590 | ) 591 | hidden_states, present_key_value_state = self_attention_outputs[:2] 592 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights 593 | 594 | # clamp inf values to enable fp16 training 595 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 596 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 597 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 598 | 599 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None 600 | if do_cross_attention: 601 | # the actual query length is unknown for cross attention 602 | # if using past key value states. Need to inject it here 603 | if present_key_value_state is not None: 604 | query_length = present_key_value_state[0].shape[2] 605 | else: 606 | query_length = None 607 | 608 | cross_attention_outputs = self.layer[1]( 609 | hidden_states, 610 | key_value_states=encoder_hidden_states, 611 | attention_mask=encoder_attention_mask, 612 | position_bias=encoder_decoder_position_bias, 613 | layer_head_mask=cross_attn_layer_head_mask, 614 | past_key_value=cross_attn_past_key_value, 615 | query_length=query_length, 616 | use_cache=use_cache, 617 | output_attentions=output_attentions, 618 | input_relations=input_relations 619 | ) 620 | hidden_states = cross_attention_outputs[0] 621 | 622 | # clamp inf values to enable fp16 training 623 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 624 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 625 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 626 | 627 | # Combine self attn and cross attn key value states 628 | if present_key_value_state is not None: 629 | present_key_value_state = present_key_value_state + cross_attention_outputs[1] 630 | 631 | # Keep cross-attention outputs and relative position weights 632 | attention_outputs = attention_outputs + cross_attention_outputs[2:] 633 | 634 | # Apply Feed Forward layer 635 | hidden_states = self.layer[-1](hidden_states) 636 | 637 | # clamp inf values to enable fp16 training 638 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 639 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 640 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 641 | 642 | outputs = (hidden_states,) 643 | 644 | if use_cache: 645 | outputs = outputs + (present_key_value_state,) + attention_outputs 646 | else: 647 | outputs = outputs + attention_outputs 648 | 649 | return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 650 | 651 | 652 | class RelationalLongT5LayerSelfAttention(LongT5LayerSelfAttention): 653 | def forward( 654 | self, 655 | hidden_states, 
656 | attention_mask=None, 657 | position_bias=None, 658 | layer_head_mask=None, 659 | past_key_value=None, 660 | use_cache=False, 661 | output_attentions=False, 662 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 663 | ): 664 | normed_hidden_states = self.layer_norm(hidden_states) 665 | kwargs = {'input_relations': input_relations} \ 666 | if 'input_relations' in inspect.getfullargspec(self.SelfAttention.forward)[0] else {} 667 | attention_output = self.SelfAttention( 668 | normed_hidden_states, 669 | mask=attention_mask, 670 | position_bias=position_bias, 671 | layer_head_mask=layer_head_mask, 672 | past_key_value=past_key_value, 673 | use_cache=use_cache, 674 | output_attentions=output_attentions, 675 | **kwargs 676 | ) 677 | hidden_states = hidden_states + self.dropout(attention_output[0]) 678 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them 679 | return outputs 680 | 681 | 682 | class RelationalLongT5LayerCrossAttention(LongT5LayerCrossAttention): 683 | def forward( 684 | self, 685 | hidden_states, 686 | key_value_states, 687 | attention_mask=None, 688 | position_bias=None, 689 | layer_head_mask=None, 690 | past_key_value=None, 691 | use_cache=False, 692 | query_length=None, 693 | output_attentions=False, 694 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 695 | ): 696 | normed_hidden_states = self.layer_norm(hidden_states) 697 | kwargs = {'input_relations': input_relations} \ 698 | if 'input_relations' in inspect.getfullargspec(self.EncDecAttention.forward)[0] else {} 699 | attention_output = self.EncDecAttention( 700 | normed_hidden_states, 701 | mask=attention_mask, 702 | key_value_states=key_value_states, 703 | position_bias=position_bias, 704 | layer_head_mask=layer_head_mask, 705 | past_key_value=past_key_value, 706 | use_cache=use_cache, 707 | query_length=query_length, 708 | output_attentions=output_attentions, 709 | **kwargs 710 | ) 711 | layer_output = hidden_states + self.dropout(attention_output[0]) 712 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them 713 | return outputs 714 | 715 | 716 | class RelationalLongT5LayerLocalSelfAttention(LongT5LayerLocalSelfAttention): 717 | def forward( 718 | self, 719 | hidden_states, 720 | attention_mask=None, 721 | position_bias=None, 722 | layer_head_mask=None, 723 | output_attentions=False, 724 | input_relations=None, # will hold (batch, seq_length, seq_length, num_relation_kinds) 725 | **kwargs: Any, # to accept past_key_value and use_cache kwargs 726 | ): 727 | normed_hidden_states = self.layer_norm(hidden_states) 728 | kwargs = {'input_relations': input_relations} \ 729 | if 'input_relations' in inspect.getfullargspec(self.LocalSelfAttention.forward)[0] else {} 730 | attention_output = self.LocalSelfAttention( 731 | normed_hidden_states, 732 | mask=attention_mask, 733 | position_bias=position_bias, 734 | layer_head_mask=layer_head_mask, 735 | output_attentions=output_attentions, 736 | **kwargs 737 | ) 738 | hidden_states = hidden_states + self.dropout(attention_output[0]) 739 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them 740 | return outputs 741 | 742 | 743 | class RelationalLongT5LayerTransientGlobalSelfAttention(LongT5LayerTransientGlobalSelfAttention): 744 | def forward( 745 | self, 746 | hidden_states, 747 | attention_mask=None, 748 | position_bias=None, 749 | layer_head_mask=None, 750 | output_attentions=False, 751 | input_relations=None, 
# will hold (batch, seq_length, seq_length, num_relation_kinds) 752 | **kwargs: Any, # to accept past_key_value and use_cache kwargs 753 | ): 754 | normed_hidden_states = self.layer_norm(hidden_states) 755 | kwargs = {'input_relations': input_relations} \ 756 | if 'input_relations' in inspect.getfullargspec(self.TransientGlobalSelfAttention.forward)[0] else {} 757 | attention_output = self.TransientGlobalSelfAttention( 758 | normed_hidden_states, 759 | mask=attention_mask, 760 | position_bias=position_bias, 761 | layer_head_mask=layer_head_mask, 762 | output_attentions=output_attentions, 763 | **kwargs 764 | ) 765 | hidden_states = hidden_states + self.dropout(attention_output[0]) 766 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them 767 | return outputs 768 | 769 | 770 | class RelationalLongT5Attention(LongT5Attention): 771 | def forward( 772 | self, 773 | hidden_states, 774 | mask=None, 775 | key_value_states=None, 776 | position_bias=None, 777 | past_key_value=None, 778 | layer_head_mask=None, 779 | query_length=None, 780 | use_cache=False, 781 | output_attentions=False, 782 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 783 | ): 784 | """ 785 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 786 | """ 787 | 788 | # Input is (batch_size, seq_length, dim) 789 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 790 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 791 | batch_size, seq_length = hidden_states.shape[:2] 792 | 793 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model" 794 | assert input_relations.shape == (batch_size, seq_length, seq_length) 795 | 796 | # (batch_size, seq_length, seq_length, dim_per_head) 797 | relation_k_embeds = self.relation_k_emb(input_relations) 798 | relation_v_embeds = self.relation_v_emb(input_relations) 799 | 800 | real_seq_length = seq_length 801 | 802 | if past_key_value is not None: 803 | assert ( 804 | len(past_key_value) == 2 805 | ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" 806 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length 807 | 808 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] 809 | 810 | def shape(states): 811 | """projection""" 812 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 813 | 814 | def unshape(states): 815 | """reshape""" 816 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 817 | 818 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 819 | """projects hidden states correctly to key/query states""" 820 | if key_value_states is None: 821 | # self-attn 822 | # (batch_size, n_heads, seq_length, dim_per_head) 823 | hidden_states = shape(proj_layer(hidden_states)) 824 | elif past_key_value is None: 825 | # cross-attn 826 | # (batch_size, n_heads, seq_length, dim_per_head) 827 | hidden_states = shape(proj_layer(key_value_states)) 828 | 829 | if past_key_value is not None: 830 | if key_value_states is None: 831 | # self-attn 832 | # (batch_size, n_heads, key_length, dim_per_head) 833 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 834 | else: 835 | # cross-attn 836 | hidden_states = past_key_value 837 | return hidden_states 838 | 839 | # get query states 840 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) 841 | 842 | # get key/value states 843 | key_states = project( 844 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None 845 | ) 846 | value_states = project( 847 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None 848 | ) 849 | 850 | # compute scores 851 | scores = torch.matmul( 852 | query_states, key_states.transpose(3, 2) 853 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 854 | 855 | # q_t is [batch, seq_length, n_heads, dim_per_head] 856 | q_t = query_states.permute(0, 2, 1, 3) 857 | 858 | # r_t is [batch, seq_length, dim_per_head, seq_length] 859 | r_t = relation_k_embeds.transpose(-2, -1) 860 | 861 | q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length] 862 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length] 863 | 864 | # Add to scores 865 | scores += q_tr_tmatmul_t 866 | 867 | if position_bias is None: 868 | if not self.has_relative_attention_bias: 869 | position_bias = torch.zeros( 870 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 871 | ) 872 | if self.gradient_checkpointing and self.training: 873 | position_bias.requires_grad = True 874 | else: 875 | position_bias = self.compute_bias(real_seq_length, key_length) 876 | 877 | # if key and values are already calculated 878 | # we want only the last query position bias 879 | if past_key_value is not None: 880 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :] 881 | 882 | if mask is not None: 883 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) 884 | 885 | if self.pruned_heads: 886 | mask = torch.ones(position_bias.shape[1]) 887 | mask[list(self.pruned_heads)] = 0 888 | position_bias_masked = position_bias[:, mask.bool()] 889 | else: 890 | position_bias_masked = position_bias 891 | 892 | scores += position_bias_masked 893 | attn_weights = nn.functional.softmax(scores.float(), 
dim=-1).type_as( 894 | scores 895 | ) # (batch_size, n_heads, seq_length, key_length) 896 | attn_weights = nn.functional.dropout( 897 | attn_weights, p=self.dropout, training=self.training 898 | ) # (batch_size, n_heads, seq_length, key_length) 899 | 900 | # Mask heads if we want to 901 | if layer_head_mask is not None: 902 | attn_weights = attn_weights * layer_head_mask 903 | 904 | # [batch, n_heads, seq_length, seq_length] 905 | wv_matmul = torch.matmul(attn_weights, value_states) 906 | 907 | # w_t is [batch, seq_length, n_heads, seq_length] 908 | w_t = attn_weights.permute(0, 2, 1, 3) 909 | 910 | # [batch, seq_length, n_heads, seq_length] 911 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds) 912 | 913 | attn_output = unshape(wv_matmul + w_tr_matmul.permute(0, 2, 1, 3)) # (batch_size, seq_length, dim) 914 | attn_output = self.o(attn_output) 915 | 916 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None 917 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 918 | 919 | if output_attentions: 920 | outputs = outputs + (attn_weights,) 921 | return outputs 922 | 923 | 924 | class RelationalLongT5LocalAttention(LongT5LocalAttention): 925 | def forward( 926 | self, 927 | hidden_states, 928 | mask=None, 929 | position_bias=None, 930 | layer_head_mask=None, 931 | output_attentions=False, 932 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 933 | ): 934 | batch_size, seq_length = hidden_states.shape[:2] 935 | 936 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model" 937 | assert input_relations.shape == (batch_size, seq_length, seq_length) 938 | 939 | # (batch_size, seq_length, seq_length, dim_per_head) 940 | relation_k_embeds = self.relation_k_emb(input_relations) 941 | relation_v_embeds = self.relation_v_emb(input_relations) 942 | 943 | def shape(states): 944 | """projection""" 945 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) 946 | 947 | def unshape(states): 948 | """reshape""" 949 | return states.contiguous().view(batch_size, -1, self.inner_dim) 950 | 951 | # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head) 952 | query_states = shape(self.q(hidden_states)) 953 | key_states = shape(self.k(hidden_states)) 954 | value_states = shape(self.v(hidden_states)) 955 | 956 | # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head) 957 | query_states = _split_into_blocks(query_states, self.block_len, dim=1) 958 | key_states = _split_into_blocks(key_states, self.block_len, dim=1) 959 | value_states = _split_into_blocks(value_states, self.block_len, dim=1) 960 | 961 | # Split relation_k_embeds, relation_v_embeds into blocks -> (batch_size, num_blocks, block_len, seq_length, dim_per_head) 962 | relation_k_embeds = _split_into_blocks(relation_k_embeds, self.block_len, dim=1) 963 | relation_v_embeds = _split_into_blocks(relation_v_embeds, self.block_len, dim=1) 964 | 965 | # Resize relation_k_embeds and relation_v_embeds -> (batch_size, num_blocks, block_len, block_len, dim_per_head) 966 | # each block can only have relations with another block, not with the full length 967 | num_blocks = query_states.shape[1] 968 | def get_new_relation_embeds(relation_embeds): 969 | new_relation_embeds = [] 970 | for block_i in range(num_blocks): 971 | new_relation_embeds.append( 972 | relation_embeds[:, block_i, :, block_i * self.block_len: (block_i + 1) * self.block_len, 973 | 
:].unsqueeze(1) 974 | ) 975 | if new_relation_embeds[-1].shape[-2] % self.block_len != 0: 976 | # pad tensor to multiple of block_len 977 | new_relation_embeds[-1] = _pad_to_multiple( 978 | new_relation_embeds[-1], self.block_len, dim=3, pad_value=0 979 | ) 980 | return torch.cat(new_relation_embeds, dim=1) 981 | relation_k_embeds = get_new_relation_embeds(relation_k_embeds) 982 | relation_v_embeds = get_new_relation_embeds(relation_v_embeds) 983 | 984 | # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) 985 | key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2) 986 | value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2) 987 | 988 | # Concatenate 3 blocks -> (batch_size, num_blocks, 3 * block_len, block_len, dim_per_head) 989 | relation_k_embeds = _concatenate_3_blocks(relation_k_embeds, block_dim=1, sequence_dim=2) 990 | relation_v_embeds = _concatenate_3_blocks(relation_v_embeds, block_dim=1, sequence_dim=2) 991 | 992 | # Compute scores 993 | scores = torch.einsum( 994 | "...qhd,...khd->...hqk", query_states, key_states 995 | ) # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) 996 | 997 | # q_t is (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) 998 | q_t = _concatenate_3_blocks(query_states, block_dim=1, sequence_dim=2) 999 | 1000 | # r_t is (batch_size, num_blocks, 3 * block_len, dim_per_head, block_len) 1001 | r_t = relation_k_embeds.transpose(-2, -1) 1002 | 1003 | # (batch_size, num_blocks, 3 * block_len, n_heads, block_len) 1004 | q_tr_t_matmul = torch.matmul(q_t, r_t) 1005 | 1006 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) 1007 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 1, 3, 4, 2) 1008 | 1009 | # Add to scores 1010 | scores += q_tr_tmatmul_t 1011 | 1012 | if position_bias is None: 1013 | # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) 1014 | if not self.has_relative_attention_bias: 1015 | position_bias = torch.zeros( 1016 | (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype 1017 | ) 1018 | if self.gradient_checkpointing and self.training: 1019 | position_bias.requires_grad = True 1020 | else: 1021 | position_bias = self.compute_bias(self.block_len) 1022 | 1023 | if mask is not None: 1024 | # Replace masked positions with -1e10 (according to the original implementation) 1025 | mask = torch.where(mask > 0, 0.0, -1e10) 1026 | # We need to adjust position bias shape to be sum with mask 1027 | position_bias = position_bias + mask.transpose(1, 2) 1028 | 1029 | scores += position_bias 1030 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) 1031 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) 1032 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len) 1033 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 1034 | 1035 | # Mask heads if we want to 1036 | if layer_head_mask is not None: 1037 | attn_weights = attn_weights * layer_head_mask 1038 | 1039 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head) 1040 | attn_weights = attn_weights.type(value_states.dtype) 1041 | wv_matmul = torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states) 1042 | 1043 | # (batch_size, num_blocks, block_len, n_heads, 3 * block_len) 1044 | w_t = attn_weights.permute(0, 1, 3, 2, 4) 1045 | 1046 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head) 1047 | w_tr_matmul = 
torch.matmul(w_t, relation_v_embeds.transpose(-3, -2)) 1048 | 1049 | # (batch_size, seq_length, dim_per_head) 1050 | attn_output = unshape(wv_matmul + w_tr_matmul)[:, :seq_length, :] 1051 | 1052 | # (batch_size, seq_length, d_model) 1053 | attn_output = self.o(attn_output) 1054 | 1055 | present_key_value_state = None 1056 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 1057 | 1058 | if output_attentions: 1059 | outputs = outputs + (attn_weights,) 1060 | return outputs 1061 | 1062 | 1063 | class RelationalLongT5TransientGlobalAttention(LongT5TransientGlobalAttention): 1064 | def forward( 1065 | self, 1066 | hidden_states, 1067 | mask=None, 1068 | position_bias=None, 1069 | layer_head_mask=None, 1070 | output_attentions=False, 1071 | input_relations=None # will hold (batch, seq_length, seq_length, num_relation_kinds) 1072 | ): 1073 | batch_size, seq_length = hidden_states.shape[:2] 1074 | 1075 | assert input_relations is not None, "Using RATransformer but no 'input_relations' were passed to the model" 1076 | assert input_relations.shape == (batch_size, seq_length, seq_length) 1077 | 1078 | # (batch_size, seq_length, seq_length, dim_per_head) 1079 | relation_k_embeds = self.relation_k_emb(input_relations) 1080 | relation_v_embeds = self.relation_v_emb(input_relations) 1081 | 1082 | def shape(states): 1083 | """projection""" 1084 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) 1085 | 1086 | def unshape(states): 1087 | """reshape""" 1088 | return states.contiguous().view(batch_size, -1, self.inner_dim) 1089 | 1090 | # Prepare components for transient-global attention 1091 | # Obtain block_ids and global_segment_ids 1092 | # global_seq_len := seq_len // self.global_block_size 1093 | # shapes: (batch_size, seq_len) & (batch_size, global_seq_len) 1094 | block_ids, global_segment_ids = _make_global_fixed_block_ids( 1095 | mask if mask is not None else torch.ones(hidden_states.shape[:-1]), 1096 | self.global_block_size, 1097 | ) 1098 | # Create global inputs 1099 | _global_seq_len = global_segment_ids.shape[-1] 1100 | global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len) 1101 | global_inputs = self.global_input_layer_norm(global_inputs) 1102 | 1103 | # get query states -> (batch_size, seq_length, n_heads, dim_per_head) 1104 | query_states = shape(self.q(hidden_states)) 1105 | key_states = shape(self.k(hidden_states)) 1106 | value_states = shape(self.v(hidden_states)) 1107 | # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head) 1108 | side_key_states = shape(self.k(global_inputs)) 1109 | side_value_states = shape(self.v(global_inputs)) 1110 | 1111 | # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head) 1112 | query_states = _split_into_blocks(query_states, self.block_len, dim=1) 1113 | key_states = _split_into_blocks(key_states, self.block_len, dim=1) 1114 | value_states = _split_into_blocks(value_states, self.block_len, dim=1) 1115 | 1116 | # Split relation_k_embeds, relation_v_embeds into blocks -> (batch_size, num_blocks, block_len, seq_length, dim_per_head) 1117 | relation_k_embeds = _split_into_blocks(relation_k_embeds, self.block_len, dim=1) 1118 | relation_v_embeds = _split_into_blocks(relation_v_embeds, self.block_len, dim=1) 1119 | 1120 | # Resize relation_k_embeds and relation_v_embeds -> (batch_size, num_blocks, block_len, block_len, dim_per_head) 1121 | # each block can only have relations with another block, not with the full length 
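        # NOTE: the block-wise slicing below keeps, for each block of query positions, only the relation
        # embeddings whose key position falls inside that same block (zero-padding the last, shorter block
        # when the sequence length is not a multiple of block_len). Relations that point outside a token's
        # own block are therefore dropped from these block-local tensors, an approximation that keeps the
        # extra relation term compatible with LongT5's block-local attention layout.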
1122 | num_blocks = query_states.shape[1] 1123 | def get_new_relation_embeds(relation_embeds): 1124 | new_relation_embeds = [] 1125 | for block_i in range(num_blocks): 1126 | new_relation_embeds.append( 1127 | relation_embeds[:, block_i, :, block_i * self.block_len: (block_i + 1) * self.block_len, :].unsqueeze(1) 1128 | ) 1129 | if new_relation_embeds[-1].shape[-2] % self.block_len != 0: 1130 | # pad tensor to multiple of block_len 1131 | new_relation_embeds[-1] = _pad_to_multiple( 1132 | new_relation_embeds[-1], self.block_len, dim=3, pad_value=0 1133 | ) 1134 | return torch.cat(new_relation_embeds, dim=1) 1135 | relation_k_embeds = get_new_relation_embeds(relation_k_embeds) 1136 | relation_v_embeds = get_new_relation_embeds(relation_v_embeds) 1137 | 1138 | # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) 1139 | key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2) 1140 | value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2) 1141 | 1142 | # Concatenate 3 blocks -> (batch_size, num_blocks, 3 * block_len, block_len, dim_per_head) 1143 | relation_k_embeds = _concatenate_3_blocks(relation_k_embeds, block_dim=1, sequence_dim=2) 1144 | relation_v_embeds = _concatenate_3_blocks(relation_v_embeds, block_dim=1, sequence_dim=2) 1145 | 1146 | # Tile side inputs across local key/value blocks 1147 | # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head) 1148 | reps = [1] * (side_key_states.ndim + 1) 1149 | reps[1] = key_states.shape[1] 1150 | side_key_states = side_key_states.unsqueeze(1).repeat(reps) 1151 | side_value_states = side_value_states.unsqueeze(1).repeat(reps) 1152 | 1153 | # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones 1154 | # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head) 1155 | key_states = torch.cat([key_states, side_key_states], dim=2) 1156 | value_states = torch.cat([value_states, side_value_states], dim=2) 1157 | 1158 | # Add zeros to relation_k_embeds and relation_v_embeds 1159 | # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, block_len, dim_per_head) 1160 | zeros_shape_to_add = list(relation_k_embeds.shape[:2]) + [side_value_states.shape[2]] + list(relation_k_embeds.shape[3:]) 1161 | relation_k_embeds = torch.cat( 1162 | [relation_k_embeds, torch.zeros(zeros_shape_to_add, device=relation_k_embeds.device)], dim=2 1163 | ) 1164 | relation_v_embeds = torch.cat( 1165 | [relation_v_embeds, torch.zeros(zeros_shape_to_add, device=relation_v_embeds.device)], dim=2 1166 | ) 1167 | 1168 | # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len) 1169 | scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states) 1170 | 1171 | # q_t is (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head) 1172 | q_t = _concatenate_3_blocks(query_states, block_dim=1, sequence_dim=2) 1173 | q_t = torch.cat([q_t, side_value_states], dim=2) 1174 | 1175 | # r_t is (batch_size, num_blocks, 3 * block_len + global_seq_len, dim_per_head, block_len) 1176 | r_t = relation_k_embeds.transpose(-2, -1) 1177 | 1178 | # (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, block_len) 1179 | q_tr_t_matmul = torch.matmul(q_t, r_t) 1180 | 1181 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len) 1182 | q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 1, 3, 4, 2) 
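        # q_tr_tmatmul_t above is the relation-aware term that gets added to the attention logits just
        # below: for every query/key pair inside a block, the query vector is dotted with the embedding of
        # their relation kind, in the style of relation-aware self-attention with relative representations.
        # Because zero embeddings were appended for the global/side key positions further up, the score
        # columns belonging to those transient-global keys receive no contribution from this term.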
1183 | 1184 | # Add to scores 1185 | scores += q_tr_tmatmul_t 1186 | 1187 | if mask is not None: 1188 | # We need to adjust position bias shape to be sum with mask 1189 | local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device) 1190 | # Replace masked positions with -10_000 (according to the original implementation) 1191 | local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10) 1192 | else: 1193 | local_attention_mask = None 1194 | 1195 | if position_bias is None: 1196 | # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) 1197 | if not self.has_relative_attention_bias: 1198 | position_bias = torch.zeros( 1199 | (1, 1, self.n_heads, self.block_len, 3 * self.block_len), 1200 | device=scores.device, 1201 | dtype=scores.dtype, 1202 | ) 1203 | if self.gradient_checkpointing and self.training: 1204 | position_bias.requires_grad = True 1205 | else: 1206 | position_bias = self.compute_bias(self.block_len) 1207 | 1208 | if local_attention_mask is not None: 1209 | # (batch_size, 1, n_heads, block_len, 3 * block_len) 1210 | position_bias = position_bias + local_attention_mask.transpose(1, 2) 1211 | position_bias = position_bias.type(scores.dtype) 1212 | 1213 | # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len) 1214 | if mask is None: 1215 | mask = torch.ones(batch_size, seq_length) 1216 | # (batch_size, num_heads, seq_len, global_seq_len) 1217 | side_position_bias = self.compute_side_bias(mask, global_segment_ids) 1218 | # (batch_size, num_blocks, num_heads, block_len, global_seq_len) 1219 | side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2) 1220 | side_position_bias = side_position_bias.type(scores.dtype).to(scores.device) 1221 | # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len) 1222 | position_bias = torch.cat([position_bias, side_position_bias], dim=-1) 1223 | 1224 | scores += position_bias 1225 | # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len) 1226 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) 1227 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 1228 | 1229 | # Mask heads if we want to 1230 | if layer_head_mask is not None: 1231 | attn_weights = attn_weights * layer_head_mask 1232 | 1233 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head) 1234 | attn_weights = attn_weights.type(value_states.dtype) 1235 | wv_matmul = torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states) 1236 | 1237 | # (batch_size, num_blocks, block_len, n_heads, 3 * block_len + global_seq_len) 1238 | w_t = attn_weights.permute(0, 1, 3, 2, 4) 1239 | 1240 | # (batch_size, num_blocks, block_len, n_heads, dim_per_head) 1241 | w_tr_matmul = torch.matmul(w_t, relation_v_embeds.transpose(-3, -2)) 1242 | 1243 | # (batch_size, seq_length, dim_per_head) 1244 | attn_output = unshape(wv_matmul + w_tr_matmul)[:, :seq_length, :] 1245 | 1246 | # (batch_size, seq_length, d_model) 1247 | attn_output = self.o(attn_output) 1248 | 1249 | present_key_value_state = None 1250 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 1251 | 1252 | if output_attentions: 1253 | outputs = outputs + (attn_weights,) 1254 | return outputs 1255 | 1256 | --------------------------------------------------------------------------------
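For readers tracing the tensor algebra in `RelationalLongT5Attention.forward` above: stripped of caching, masking and position bias, the relation-aware update is just two extra matmuls on top of standard multi-head attention, one against the key-relation embeddings (added to the logits) and one against the value-relation embeddings (added to the output). Below is a minimal, self-contained sketch with dummy tensors; the dimension names and the two `nn.Embedding` layers are illustrative stand-ins, not the package's actual modules.

```python
# Standalone illustration of the relation-aware attention math used in RelationalLongT5Attention
# (dummy tensors only; sizes and embedding layers here are illustrative, not package API).
import torch
import torch.nn as nn

batch, n_heads, seq_len, d_head = 2, 4, 6, 8
num_relation_kinds = 3

# integer relation ids between every pair of tokens, shape (batch, seq_len, seq_len)
input_relations = torch.randint(0, num_relation_kinds, (batch, seq_len, seq_len))

# one embedding per relation kind for keys and for values, dim = d_head
relation_k_emb = nn.Embedding(num_relation_kinds, d_head)
relation_v_emb = nn.Embedding(num_relation_kinds, d_head)
relation_k = relation_k_emb(input_relations)   # (batch, seq, seq, d_head)
relation_v = relation_v_emb(input_relations)   # (batch, seq, seq, d_head)

# dummy projected queries/keys/values, shape (batch, n_heads, seq, d_head)
q = torch.randn(batch, n_heads, seq_len, d_head)
k = torch.randn(batch, n_heads, seq_len, d_head)
v = torch.randn(batch, n_heads, seq_len, d_head)

# ordinary content-content logits
scores = torch.matmul(q, k.transpose(3, 2))                      # (batch, n_heads, seq, seq)

# relation-aware additive term: q_i dot r_ij^K for every query i and key j
q_t = q.permute(0, 2, 1, 3)                                      # (batch, seq, n_heads, d_head)
scores = scores + torch.matmul(q_t, relation_k.transpose(-2, -1)).permute(0, 2, 1, 3)

attn = torch.softmax(scores, dim=-1)                             # (batch, n_heads, seq, seq)

# output: usual attn @ V plus the relation-value term attn_ij * r_ij^V
out = torch.matmul(attn, v)                                      # (batch, n_heads, seq, d_head)
w_t = attn.permute(0, 2, 1, 3)                                   # (batch, seq, n_heads, seq)
out = out + torch.matmul(w_t, relation_v).permute(0, 2, 1, 3)    # (batch, n_heads, seq, d_head)

print(out.shape)  # torch.Size([2, 4, 6, 8])
```

This mirrors how `relation_k_emb` and `relation_v_emb` are consumed in the full implementation; the local and transient-global attention variants apply the same two terms block by block instead of over the whole sequence.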