├── .gitattributes ├── images └── tilt_arch.png ├── requirements.txt ├── mkdocs.yml ├── setup.cfg ├── setup.py ├── LICENSE ├── how_did_i_prepare_the_stuffs ├── README.md ├── tilt_part_3_1_aligning_all_the_parts_to_make_tilt.ipynb └── tilt_part_2_3_sample_preparing_funsd_for_t5_dataset.ipynb ├── .gitignore ├── README.md └── src ├── visual_backbone.py ├── dataset.py └── t5.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /images/tilt_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/TiLT-Implementation/HEAD/images/tilt_arch.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | datasets 3 | sentencepiece 4 | pytorch_lightning 5 | seqeval 6 | evaluate 7 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: TiLT 2 | site_url: https://uakarsh.github.io/tilt 3 | copyright: MIT 4 | theme: 5 | name: "material" 6 | palette: 7 | primary: "red" 8 | accent: "red" 9 | 10 | repo_name: uakarsh/TiLT-Implementation 11 | repo_url: https://github.com/uakarsh/TiLT-Implementation 12 | 13 | nav: 14 | - Home: pad_tokens_start_idx.md -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, B305 3 | max-line-length = 88 4 | 5 | [mypy] 6 | disallow_untyped_defs = True 7 | ignore_missing_imports = True 8 | 9 | [tool:isort] 10 | profile = black 11 | known_first_party = tilt,tests 12 | 13 | [tool:pytest] 14 | testpaths = tests 15 | addopts = 16 | -rxXs 17 | --cov=tilt 18 | --cov=tests 19 | --cov-report=term-missing 20 | --cov-fail-under=80 21 | --cov-config=.coveragerc -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'tilt_transformers', 5 | packages = find_packages(where="src"), 6 | package_dir = {"": "src", "docformer": "src/"}, 7 | version = '0.1.0', 8 | license='MIT', 9 | description = 'Going Full-TILT Boogie on Document Understanding with Text-Image-Layout Transformer:', 10 | author = 'Akarsh Upadhay', 11 | author_email = 'akarshupadhyayabc@gmail.com', 12 | url = 'https://github.com/uakarsh/TiLT-Implementation', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'attention mechanism', 16 | 'document understanding', 17 | ], 18 | install_requires=[ 19 | 'torch>=1.6', 20 | 'torchvision', 21 | 'transformers', 22 | 'sentencepiece', 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python :: 3.7', 30 | ], 31 | 32 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 
2 | 
3 | Copyright (c) 2020 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/how_did_i_prepare_the_stuffs/README.md:
--------------------------------------------------------------------------------
1 | ### A big note on how I prepared everything.
2 | 
3 | * Notebooks 1 through 3.2 use a generative approach to solve the problem. This came from the idea that T5 is, at heart, a generative (text-to-text) model.
4 | 
5 | * However, in the conclusion section of the TiLT paper, the authors mention that they used the extractive approach, i.e. predicting logits, which I had missed until now. I have added the code for preparing the FUNSD abstractive dataset, and the same recipe will be followed for the CORD dataset.
6 | 
7 | * I also have the code for DocVQA ready (an extractive task, which involves predicting the start and end logits of the answer within the context) and will add it soon.
8 | 
9 | * It will take me a while to prepare the modeling code for the abstractive approach (just as I was about to finish the generative approach, I revisited the paper and saw that the authors had used the extractive approach).
10 | 
11 | * The situation was honestly confusing: I was also prepared to use the abstractive approach, but seeing T5's text-to-text formulation pushed me towards the generative one. Although all of that code is ready, I will pause it for now and turn to the abstractive approach. Let's see how this goes.
12 | 
13 | * By the way, if time permits, I will soon add the code for FUNSD, CORD and DocVQA, since I have already worked with them and know how to fine-tune the model on them.
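
To make the distinction concrete, here is a minimal sketch of how the training target differs between the two formulations for a single labelled word. This is my own illustration; the `t5-base` checkpoint, the word and the tag value are placeholders, not code from the notebooks.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
word, ner_tag = "DATE:", 3  # hypothetical token and its NER class index

# Generative formulation: the label itself becomes text, and the decoder
# is trained to generate that text token by token.
generative_target = tokenizer(str(ner_tag), add_special_tokens=False)["input_ids"]

# Extractive formulation: the label stays a class index per input token,
# and a classification head predicts logits over the label set.
input_ids = tokenizer(word, add_special_tokens=False)["input_ids"]
extractive_target = [ner_tag] * len(input_ids)
```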
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Going Full-TILT Boogie on Document Understanding with Text-Image-Layout Transformer: PyTorch Implementation 2 | 3 | ![TiLT architecture](images/tilt_arch.png) 4 | 5 | This repository contains the implementation of the paper: [Going Full-TILT Boogie on Document Understanding with Text-Image-Layout Transformer](https://arxiv.org/pdf/2102.09550v3.pdf). Note that, the authors have not released the original implementation of the paper. 6 | 7 | Abstract: We address the challenging problem of Natural Language Comprehension beyond plain-text documents by introducing the TILT neural network architecture which simultaneously learns layout information, visual features, and textual semantics. 
Contrary to previous approaches, we rely on a decoder capable of unifying a variety of problems involving natural language. The layout is represented as an attention bias and complemented with contextualized visual information, while the core of our model is a pretrained encoder-decoder Transformer. Our novel approach achieves state-of-the-art results in extracting information from documents and answering questions which demand layout understanding (DocVQA, CORD, SROIE). At the same time, we simplify the process by employing an end-to-end model.
8 | 
9 | 
10 | ## Requirements
11 | * See the `requirements.txt` file.
12 | 
13 | 
14 | ## Dataset
15 | * I will be including the [FUNSD Dataset](https://guillaumejaume.github.io/FUNSD/) as well as the [CORD Dataset](https://github.com/clovaai/cord) soon. The full approach is still being implemented, and due to a few mistakes on my part it will take a while to prepare the entire pipeline.
16 | 
17 | 
18 | ## Pretrained Models
19 | * I am not sure whether I will be able to include pretrained models, due to resource constraints, but I will add the fine-tuning code for FUNSD, CORD and DocVQA soon.
20 | 
21 | 
22 | ## Modeling:
23 | * The modeling part of the pipeline is largely based on [HuggingFace's T5 implementation](https://huggingface.co/docs/transformers/model_doc/t5), and the weights are initialized from it. The code is available in the `src/t5.py` file.
24 | 
25 | 
26 | ## Examples:
27 | * For fine-tuning TiLT on CORD, the example along with the results is available [here](https://github.com/uakarsh/TiLT-Implementation/blob/main/experiments/cord-tilt-part-4-1-abstractive-approach-for-t.ipynb)
28 | 
29 | * Similarly, for fine-tuning TiLT on FUNSD, the example along with the results is available [here](https://github.com/uakarsh/TiLT-Implementation/blob/main/experiments/tilt-part-4-1-abstractive-approach-for-training.ipynb)
30 | 
31 | 
32 | ## My Results:
33 | | Model Name      | Dataset Name | Number of Parameters | Overall Precision | Overall Recall | Overall F1 Score | Overall Accuracy |
34 | |-----------------|--------------|----------------------|-------------------|----------------|------------------|------------------|
35 | | TILT            | FUNSD        | 225M                 | 57.58             | 42.25          | 48.87            | 83.60            |
36 | | TILT            | CORD         | 225M                 | 64.81             | 62.64          | 63.71            | 80.52            |
37 | | TILT (Original) | CORD         | 230M                 | ---               | ---            | 95.11            | ---              |
38 | 
39 | Note that in the case of my results on CORD, the model has not been pre-trained (the weights are initialized from Hugging Face's implementation) and has been trained for only 30 epochs, while in the original paper the authors trained for 360,000 steps, which is roughly equivalent to 360,000 / 100 = 3,600 epochs.
(100 comes from 800 / 8, since 8 is the batch size mentioned in the paper, and 800 are the training examples in the CORD dataset) 40 | 41 | ## Citation 42 | If you find this repository useful, please cite the following paper: 43 | ```bibtex 44 | @inproceedings{powalski2021going, 45 | title={Going full-tilt boogie on document understanding with text-image-layout transformer}, 46 | author={Powalski, Rafa{\l} and Borchmann, {\L}ukasz and Jurkiewicz, Dawid and Dwojak, Tomasz and Pietruszka, Micha{\l} and Pa{\l}ka, Gabriela}, 47 | booktitle={Document Analysis and Recognition--ICDAR 2021: 16th International Conference, Lausanne, Switzerland, September 5--10, 2021, Proceedings, Part II 16}, 48 | pages={732--747}, 49 | year={2021}, 50 | organization={Springer} 51 | } 52 | ``` 53 | 54 | ## License 55 | This project is licensed under the MIT License - see the LICENSE file for details 56 | -------------------------------------------------------------------------------- /src/visual_backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from torchvision.ops import roi_pool 5 | 6 | # Convolution block for UNet Encoder 7 | class ConvBlock(nn.Module): 8 | """ 9 | A Convolutional Block that consists of two convolution layers each followed by 10 | instance normalization, LeakyReLU activation and dropout. 11 | """ 12 | 13 | def __init__(self, in_chans: int, out_chans: int, drop_prob: float): 14 | """ 15 | Args: 16 | in_chans: Number of channels in the input. 17 | out_chans: Number of channels in the output. 18 | drop_prob: Dropout probability. 19 | """ 20 | super().__init__() 21 | 22 | self.in_chans = in_chans 23 | self.out_chans = out_chans 24 | self.drop_prob = drop_prob 25 | 26 | self.layers = nn.Sequential( 27 | nn.Conv2d(in_chans, out_chans, kernel_size=3, 28 | padding=1, bias=False), 29 | nn.InstanceNorm2d(out_chans), 30 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 31 | nn.Dropout2d(drop_prob), 32 | nn.Conv2d(out_chans, out_chans, kernel_size=3, 33 | padding=1, bias=False), 34 | nn.InstanceNorm2d(out_chans), 35 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 36 | nn.Dropout2d(drop_prob), 37 | nn.Conv2d(out_chans, out_chans, kernel_size=3, 38 | padding=1, bias=False), 39 | nn.InstanceNorm2d(out_chans), 40 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 41 | nn.Dropout2d(drop_prob), 42 | ) 43 | 44 | def forward(self, image: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Args: 47 | image: Input 4D tensor of shape `(N, in_chans, H, W)`. 48 | Returns: 49 | Output tensor of shape `(N, out_chans, H, W)`. 50 | """ 51 | return self.layers(image) 52 | 53 | 54 | # UNet Encoder 55 | class Unet_encoder(nn.Module): 56 | 57 | def __init__(self, 58 | in_channels: int = 3, 59 | channels: int = 32, 60 | num_pool_layers: int = 4, 61 | drop_prob: float = 0.0 62 | ): 63 | """ 64 | Args: 65 | in_chans: Number of channels in the input to the U-Net model. 66 | out_chans: Number of channels in the output to the U-Net model. 67 | chans: Number of output channels of the first convolution layer. 68 | num_pool_layers: Number of down-sampling and up-sampling layers. 69 | drop_prob: Dropout probability. 
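            Shape sketch (illustrative, assuming the defaults above and the (512, 384) resize used in `dataset.py`):
                an input image tensor of shape (1, 3, 384, 512) goes through four ConvBlock + 2x2 max-pool stages
                (3 -> 32 -> 64 -> 128 -> 256 channels) and a final ConvBlock, giving an output of shape (1, 512, 24, 32).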
70 | """ 71 | super().__init__() 72 | 73 | self.in_channels = in_channels 74 | self.channels = channels 75 | 76 | self.num_pool_layers = num_pool_layers 77 | self.drop_prob = drop_prob 78 | 79 | self.down_sample_layers = nn.ModuleList([ 80 | ConvBlock(in_channels, channels, drop_prob) 81 | ]) 82 | ch = channels 83 | 84 | for _ in range(num_pool_layers - 1): 85 | self.down_sample_layers.append(ConvBlock(ch, ch*2, drop_prob)) 86 | ch *= 2 87 | 88 | self.conv = ConvBlock(ch, ch*2, drop_prob) 89 | 90 | def forward(self, image: torch.Tensor) -> torch.Tensor: 91 | """ 92 | Args: 93 | Image: Input 4D tensor of shape (Batch Size, in channels, H, W) 94 | Returns: 95 | Output tensor of shape (Batch Size, out_channels, H, W) 96 | """ 97 | output = image 98 | 99 | # Appplying down sample layers 100 | for num, layer in enumerate(self.down_sample_layers): 101 | output = layer(output) 102 | output = F.max_pool2d(output, kernel_size=2, stride=2, padding=0) 103 | 104 | output = self.conv(output) 105 | return output 106 | 107 | 108 | # RoI Align, it was a mistake, I assumed RoIPool for RoIALign, but it was not the case 109 | 110 | class RoIAlign(nn.Module): 111 | def __init__(self, output_size=(3, 3), spatial_scale=0.125, sampling_ratio=2): 112 | super().__init__() 113 | 114 | """ 115 | Args 116 | output_size: (h, w) of the output feature map 117 | spatial_scale: ratio of the input feature map height (or w) to the raw image height (or w). 118 | Equals the reciprocal of total stride in convolutional layers 119 | sampling_ratio: number of inputs samples to take for each output sample 120 | """ 121 | 122 | # self.output_size = output_size 123 | # self.spatial_scale = spatial_scale 124 | # self.sampling_ratio = sampling_ratio 125 | self.roi_align = RoIAlign( 126 | output_size, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio) 127 | 128 | def forward(self, image_embedding, bboxes): 129 | """ 130 | Args: 131 | image_embedding: Input 4D tensor of shape (Batch size, in channels, H, W) 132 | bboxes: Input 3D Tensor of shape (Batch Size, max sequence length, 4) (4 corresponding to xmin, ymin, xmax, ymax) 133 | Returns: 134 | feature_maps_bboxes: tensor of shape (batch, max sequence length, in channels, *output_size) 135 | """ 136 | 137 | feature_maps_bboxes = [] 138 | for single_batch_img, single_batch_bbox in zip(image_embedding, bboxes): 139 | feature_map_single_batch = self.roi_align(input=single_batch_img.unsqueeze(0), 140 | rois=torch.cat([torch.zeros(single_batch_bbox.shape[0], 1).to( 141 | single_batch_bbox.device), single_batch_bbox], axis=-1).float() 142 | ) 143 | feature_maps_bboxes.append(feature_map_single_batch) 144 | 145 | return torch.stack(feature_maps_bboxes, axis=0) 146 | 147 | 148 | # RoIPool 149 | 150 | class RoIPool(nn.Module): 151 | 152 | def __init__(self, output_size=(3, 3), spatial_scale=0.125): 153 | super().__init__() 154 | """Args 155 | output_size: (h, w) of the output feature map 156 | spatial_scale: ratio of the input feature map height (or w) to the raw image height (or w). 
157 | Equals the reciprocal of total stride in convolutional layers 158 | """ 159 | 160 | self.output_size = output_size 161 | self.spatial_scale = spatial_scale 162 | self.roi_pool = roi_pool 163 | 164 | def forward(self, image_embedding, bboxes): 165 | """ 166 | Args: 167 | image_embedding: Input 4D tensor of shape (Batch size, in channels, H, W) 168 | bboxes: Input 3D Tensor of shape (Batch Size, max sequence length, 4) (4 corresponding to xmin, ymin, xmax, ymax) 169 | Returns: 170 | feature_maps_bboxes: tensor of shape (batch, max sequence length, in channels, *output_size) 171 | """ 172 | 173 | feature_maps_bboxes = [] 174 | for single_batch_img, single_batch_bbox in zip(image_embedding, bboxes): 175 | feature_map_single_batch = self.roi_pool(input=single_batch_img.unsqueeze(0), 176 | boxes=torch.cat([torch.zeros(single_batch_bbox.shape[0], 1).to( 177 | single_batch_bbox.device), single_batch_bbox], axis=-1).float(), 178 | output_size=self.output_size, 179 | spatial_scale=self.spatial_scale 180 | ) 181 | feature_maps_bboxes.append(feature_map_single_batch) 182 | 183 | return torch.stack(feature_maps_bboxes, axis=0) 184 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.transforms import ToTensor 3 | import torch 4 | 5 | 6 | ## This is a dataset referring to the generative problem 7 | ## The Dataset class for FUNSD Dataset (and I believe the same would be use for CORD Dataset) 8 | class FUNSDDs(Dataset): 9 | 10 | def __init__(self, ds, tokenizer, max_seq_length:int = 512, pad_token_box = [0, 0, 0, 0], resize_scale = (512, 384), transform = None): 11 | 12 | """ 13 | Args: 14 | ds (list): list of dict, each dict contains the following keys: 15 | - image (np.ndarray): the image 16 | - tokens (list): list of tokens 17 | - bboxes (list): list of bboxes 18 | - ner_tags (list): list of ner_tags 19 | tokenizer (Tokenizer): the tokenizer 20 | max_seq_length (int, optional): the maximum length of the sequence. Defaults to 512. 21 | pad_token_box (list, optional): the padding token box. Defaults to [0, 0, 0, 0]. 22 | resize_scale (tuple, optional): the resize scale. Defaults to (512, 384). 23 | transform (callable, optional): the transform. Defaults to None. 24 | """ 25 | 26 | self.ds = ds 27 | self.tokenizer = tokenizer 28 | self.max_seq_length = max_seq_length 29 | self.pad_token_box = pad_token_box 30 | self.resize_scale = resize_scale 31 | self.transform = transform if transform is not None else ToTensor() 32 | 33 | def __len__(self): 34 | """ 35 | Returns: 36 | int: the length of the dataset 37 | """ 38 | return len(self.ds) 39 | 40 | def __getitem__(self, idx): 41 | 42 | """ 43 | Args: 44 | idx (int): the index of the data to be returned. 45 | """ 46 | 47 | encoding = self.ds[idx] 48 | 49 | resized_image = encoding['image'].copy().resize(self.resize_scale) 50 | words = encoding['tokens'] 51 | bboxes = encoding['bboxes'] 52 | labels = encoding['ner_tags'] 53 | 54 | ## 1. Performing the image pre-processing 55 | img_tensor = self.transform(resized_image) ## (3, 384, 512) 56 | 57 | ## 2. 
Performing the semantic pre-processing 58 | encoding = self.tokenizer(words, is_split_into_words = True, add_special_tokens = False) 59 | 60 | # pad_token_box = [0, 0, 0, 0] 61 | max_seq_length = 512 62 | 63 | input_ids = encoding['input_ids'] 64 | attention_mask = encoding['attention_mask'] 65 | 66 | ## Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that 67 | bbox_according_to_tokenizer = [bboxes[i] for i in encoding.word_ids()] 68 | # labels_according_to_tokenizer = [self.tokenizer(str(labels[i] + 1))['input_ids'][0] for i in encoding.word_ids()] 69 | #labels_according_to_tokenizer = [self.tokenizer(str(labels[i] + 1))['input_ids'][0] for i, _ in enumerate(labels)] 70 | 71 | # Truncation of token_boxes + token_labels 72 | special_tokens_count = 1 73 | if len(input_ids) > max_seq_length - special_tokens_count: 74 | bbox_according_to_tokenizer = bbox_according_to_tokenizer[: (max_seq_length - special_tokens_count)] 75 | input_ids = input_ids[: (max_seq_length - special_tokens_count)] 76 | #labels_according_to_tokenizer = labels_according_to_tokenizer[: (max_seq_length - special_tokens_count)] 77 | attention_mask = attention_mask[: (max_seq_length - special_tokens_count)] 78 | 79 | 80 | ## Padding 81 | input_ids = input_ids + [self.tokenizer.eos_token_id] 82 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [[1000, 1000, 1000, 1000]] 83 | #labels_according_to_tokenizer = labels_according_to_tokenizer + [self.tokenizer.eos_token_id] ## For QA, the model requires an end of sentence i.e eos token 84 | attention_mask = attention_mask + [1] 85 | 86 | pad_length = max_seq_length - len(input_ids) 87 | 88 | input_ids = input_ids + [self.tokenizer.pad_token_id] * (pad_length) 89 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [self.pad_token_box] * (pad_length) 90 | #labels_according_to_tokenizer = labels_according_to_tokenizer + [self.tokenizer.pad_token_id] * (pad_length) 91 | attention_mask = attention_mask + [0] * (pad_length) 92 | 93 | ## Converting stuffs to tensor 94 | input_ids = torch.tensor(input_ids) 95 | bbox_according_to_tokenizer = torch.tensor(bbox_according_to_tokenizer) 96 | #labels_according_to_tokenizer = torch.tensor(labels_according_to_tokenizer) 97 | attention_mask = torch.tensor(attention_mask) 98 | 99 | return {"input_ids" : input_ids, "labels" : labels, "attention_mask" : attention_mask, "bboxes" : bbox_according_to_tokenizer, # labels_according_to_tokenizer 100 | "pixel_values" : img_tensor} 101 | 102 | 103 | ## This is a dataset referring to the extractive problem, I believe the same would be used for CORD Dataset, and this was also a mistake, since I didn't read the last part of the paper properly 104 | class ExtFUNSDDs(Dataset): 105 | def __init__(self, ds, tokenizer, max_seq_length:int = 512, pad_token_box = [0, 0, 0, 0], resize_scale = (512, 384), transform = None): 106 | 107 | """ 108 | Args: 109 | ds (list): list of dict, each dict contains the following keys: 110 | - image (np.ndarray): the image 111 | - tokens (list): list of tokens 112 | - bboxes (list): list of bboxes 113 | - ner_tags (list): list of ner_tags 114 | tokenizer (Tokenizer): the tokenizer 115 | max_seq_length (int, optional): the maximum length of the sequence. Defaults to 512. 116 | pad_token_box (list, optional): the padding token box. Defaults to [0, 0, 0, 0]. 117 | resize_scale (tuple, optional): the resize scale. Defaults to (512, 384). 118 | transform (callable, optional): the transform. Defaults to None. 
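        Example (an illustrative sketch; the dataset name and its column names are assumptions
        based on the HuggingFace FUNSD datasets used in the notebooks, not a tested recipe):

            >>> from datasets import load_dataset
            >>> from transformers import AutoTokenizer
            >>> data = load_dataset("nielsr/funsd-layoutlmv3")["train"]  # assumed to expose image/tokens/bboxes/ner_tags
            >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
            >>> train_ds = ExtFUNSDDs(data, tokenizer)
            >>> sample = train_ds[0]
            >>> sorted(sample.keys())
            ['attention_mask', 'bboxes', 'input_ids', 'labels', 'pixel_values']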
119 | """ 120 | 121 | self.ds = ds 122 | self.tokenizer = tokenizer 123 | self.max_seq_length = max_seq_length 124 | self.pad_token_box = pad_token_box 125 | self.resize_scale = resize_scale 126 | self.transform = transform if transform is not None else ToTensor() 127 | 128 | def __len__(self): 129 | """ 130 | Returns: 131 | int: the length of the dataset 132 | """ 133 | return len(self.ds) 134 | 135 | def __getitem__(self, idx): 136 | 137 | """ 138 | Args: 139 | idx (int): the index of the data to be returned. 140 | """ 141 | 142 | encoding = self.ds[idx] 143 | 144 | resized_image = encoding['image'].copy().resize(self.resize_scale) 145 | words = encoding['tokens'] 146 | bboxes = encoding['bboxes'] 147 | labels = encoding['ner_tags'] 148 | 149 | ## 1. Performing the image pre-processing 150 | img_tensor = self.transform(resized_image) ## (3, 384, 512) 151 | 152 | ## 2. Performing the semantic pre-processing 153 | encoding = self.tokenizer(words, is_split_into_words = True, add_special_tokens = False) 154 | 155 | # pad_token_box = [0, 0, 0, 0] 156 | max_seq_length = 512 157 | 158 | input_ids = encoding['input_ids'] 159 | attention_mask = encoding['attention_mask'] 160 | 161 | ## Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that 162 | bbox_according_to_tokenizer = [bboxes[i] for i in encoding.word_ids()] 163 | labels_according_to_tokenizer = [labels[i] for i in encoding.word_ids()] ## Labels have to be in the numerical format 164 | #labels_according_to_tokenizer = [self.tokenizer(str(labels[i] + 1))['input_ids'][0] for i, _ in enumerate(labels)] 165 | 166 | # Truncation of token_boxes + token_labels 167 | special_tokens_count = 1 168 | if len(input_ids) > max_seq_length - special_tokens_count: 169 | bbox_according_to_tokenizer = bbox_according_to_tokenizer[: (max_seq_length - special_tokens_count)] 170 | input_ids = input_ids[: (max_seq_length - special_tokens_count)] 171 | labels_according_to_tokenizer = labels_according_to_tokenizer[: (max_seq_length - special_tokens_count)] 172 | attention_mask = attention_mask[: (max_seq_length - special_tokens_count)] 173 | 174 | 175 | ## Padding 176 | input_ids = input_ids + [self.tokenizer.eos_token_id] 177 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [[1000, 1000, 1000, 1000]] 178 | labels_according_to_tokenizer = labels_according_to_tokenizer + [-100] ## For QA, the model requires an end of sentence i.e eos token 179 | attention_mask = attention_mask + [1] 180 | 181 | pad_length = max_seq_length - len(input_ids) 182 | 183 | input_ids = input_ids + [self.tokenizer.pad_token_id] * (pad_length) 184 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [self.pad_token_box] * (pad_length) 185 | labels_according_to_tokenizer = labels_according_to_tokenizer + [-100] * (pad_length) 186 | attention_mask = attention_mask + [0] * (pad_length) 187 | 188 | ## Converting stuffs to tensor 189 | input_ids = torch.tensor(input_ids) 190 | bbox_according_to_tokenizer = torch.tensor(bbox_according_to_tokenizer) 191 | labels_according_to_tokenizer = torch.tensor(labels_according_to_tokenizer) 192 | attention_mask = torch.tensor(attention_mask) 193 | 194 | return {"input_ids" : input_ids, "labels" : labels_according_to_tokenizer, "attention_mask" : attention_mask, "bboxes" : bbox_according_to_tokenizer, # labels_according_to_tokenizer 195 | "pixel_values" : img_tensor} -------------------------------------------------------------------------------- /src/t5.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import copy 5 | from transformers.models import t5 6 | from transformers import AutoModel 7 | 8 | 9 | class T5LayerNorm(nn.Module): 10 | def __init__(self, hidden_size, eps=1e-6): 11 | """ 12 | Construct a layernorm module in the T5 Style. No bias and no subtraction of mean. 13 | """ 14 | super().__init__() 15 | self.weight = nn.Parameter(torch.ones(hidden_size)) 16 | self.variance_epsilon = eps 17 | 18 | def forward(self, hidden_states): 19 | 20 | # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean 21 | # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated 22 | # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for 23 | # half-precision inputs is done in fp32 24 | 25 | variance = hidden_states.to(torch.float32).pow( 26 | 2).mean(-1, keepdim=True) 27 | hidden_states = hidden_states * \ 28 | torch.rsqrt(variance + self.variance_epsilon) 29 | 30 | # convert into half-precision if necessary 31 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 32 | hidden_states = hidden_states.to(self.weight.dtype) 33 | 34 | return self.weight * hidden_states 35 | 36 | 37 | class T5DenseActDense(nn.Module): 38 | def __init__(self, config): 39 | super().__init__() 40 | self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) 41 | self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) 42 | self.dropout = nn.Dropout(config.dropout_rate) 43 | self.act = nn.ReLU() 44 | 45 | def forward(self, hidden_states): 46 | hidden_states = self.wi(hidden_states) 47 | hidden_states = self.act(hidden_states) 48 | hidden_states = self.dropout(hidden_states) 49 | if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: 50 | hidden_states = hidden_states.to(self.wo.weight.dtype) 51 | hidden_states = self.wo(hidden_states) 52 | return hidden_states 53 | 54 | 55 | class T5DenseGatedActDense(nn.Module): 56 | def __init__(self, config): 57 | super().__init__() 58 | self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) 59 | self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) 60 | self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) 61 | self.dropout = nn.Dropout(config.dropout_rate) 62 | self.act = nn.ReLU() 63 | 64 | def forward(self, hidden_states): 65 | hidden_gelu = self.act(self.wi_0(hidden_states)) 66 | hidden_linear = self.wi_1(hidden_states) 67 | hidden_states = hidden_gelu * hidden_linear 68 | hidden_states = self.dropout(hidden_states) 69 | if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8: 70 | hidden_states = hidden_states.to(self.wo.weight.dtype) 71 | hidden_states = self.wo(hidden_states) 72 | return hidden_states 73 | 74 | 75 | class T5LayerFF(nn.Module): 76 | def __init__(self, config): 77 | super().__init__() 78 | if config.is_gated_act: 79 | self.DenseReluDense = T5DenseGatedActDense(config) 80 | else: 81 | self.DenseReluDense = T5DenseActDense(config) 82 | 83 | self.layer_norm = T5LayerNorm( 84 | config.d_model, eps=config.layer_norm_epsilon) 85 | self.dropout = nn.Dropout(config.dropout_rate) 86 | 87 | def forward(self, hidden_states): 88 | forwarded_states = self.layer_norm(hidden_states) 89 | forwarded_states = self.DenseReluDense(forwarded_states) 90 | hidden_states = hidden_states + self.dropout(forwarded_states) 91 | return hidden_states 92 | 93 
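# Illustrative sanity sketch (a helper added for clarity; not used by the model itself):
# T5LayerNorm above is RMSNorm, i.e. y = weight * x / sqrt(mean(x**2, dim=-1) + eps),
# with no mean subtraction and no bias. A minimal reference check of that formula:
def _rmsnorm_reference_check(hidden_size: int = 8, eps: float = 1e-6) -> bool:
    x = torch.randn(2, 4, hidden_size)
    layer = T5LayerNorm(hidden_size, eps=eps)
    manual = layer.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return torch.allclose(layer(x), manual, atol=1e-6)
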
| 94 | class T5Attention(nn.Module): 95 | def __init__(self, config, has_relative_attention_bias=False): 96 | super().__init__() 97 | self.is_decoder = config.is_decoder 98 | self.has_relative_attention_bias = has_relative_attention_bias 99 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 100 | self.relative_attention_max_distance = config.relative_attention_max_distance 101 | self.d_model = config.d_model 102 | self.key_value_proj_dim = config.d_kv 103 | self.n_heads = config.num_heads 104 | self.dropout = config.dropout_rate 105 | self.inner_dim = self.n_heads * self.key_value_proj_dim 106 | 107 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 108 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 109 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 110 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 111 | 112 | ''' 113 | Here is where the change lies, i.e adding the relative_horizontal_bias as well as the relative_vertical_bias 114 | ''' 115 | if self.has_relative_attention_bias: 116 | self.relative_attention_bias = nn.Embedding( 117 | self.relative_attention_num_buckets, self.n_heads) 118 | self.relative_horizontal_bias = nn.Embedding( 119 | self.relative_attention_num_buckets, self.n_heads) 120 | self.relative_vertical_bias = nn.Embedding( 121 | self.relative_attention_num_buckets, self.n_heads) 122 | 123 | self.gradient_checkpointing = False 124 | 125 | @staticmethod 126 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 127 | """ 128 | Adapted from Mesh Tensorflow: 129 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 130 | Translate relative position to a bucket number for relative attention. The relative position is defined as 131 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to 132 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for 133 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative 134 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
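        As a worked example (with the defaults bidirectional=True, num_buckets=32, max_distance=128): buckets 0-15
        cover zero and negative relative positions and buckets 16-31 cover positive ones, so a relative position of
        -3 maps to bucket 3 while +3 maps to bucket 16 + 3 = 19.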
135 | This should allow for more graceful generalization to longer sequences than the model has been trained on 136 | Args: 137 | relative_position: an int32 Tensor 138 | bidirectional: a boolean - whether the attention is bidirectional 139 | num_buckets: an integer 140 | max_distance: an integer 141 | Returns: 142 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) 143 | """ 144 | relative_buckets = 0 145 | if bidirectional: 146 | num_buckets //= 2 147 | relative_buckets += (relative_position 148 | > 0).to(torch.long) * num_buckets 149 | relative_position = torch.abs(relative_position) 150 | else: 151 | relative_position = - \ 152 | torch.min(relative_position, 153 | torch.zeros_like(relative_position)) 154 | # now relative_position is in the range [0, inf) 155 | 156 | # half of the buckets are for exact increments in positions 157 | max_exact = num_buckets // 2 158 | is_small = relative_position < max_exact 159 | 160 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 161 | relative_position_if_large = max_exact + ( 162 | torch.log(relative_position.float() / max_exact) 163 | / math.log(max_distance / max_exact) 164 | * (num_buckets - max_exact) 165 | ).to(torch.long) 166 | relative_position_if_large = torch.min( 167 | relative_position_if_large, torch.full_like( 168 | relative_position_if_large, num_buckets - 1) 169 | ) 170 | 171 | relative_buckets += torch.where(is_small, 172 | relative_position, relative_position_if_large) 173 | return relative_buckets 174 | 175 | def compute_bias_1d(self, query_length, key_length, device=None): 176 | """Compute binned relative position bias""" 177 | if device is None: 178 | device = self.relative_attention_bias.weight.device 179 | context_position = torch.arange( 180 | query_length, dtype=torch.long, device=device)[:, None] 181 | memory_position = torch.arange( 182 | key_length, dtype=torch.long, device=device)[None, :] 183 | relative_position = memory_position - \ 184 | context_position # shape (query_length, key_length) 185 | relative_position_bucket = self._relative_position_bucket( 186 | relative_position, # shape (query_length, key_length) 187 | bidirectional=(not self.is_decoder), 188 | num_buckets=self.relative_attention_num_buckets, 189 | max_distance=self.relative_attention_max_distance, 190 | ) 191 | # shape (query_length, key_length, num_heads) 192 | values = self.relative_attention_bias(relative_position_bucket) 193 | # shape (1, num_heads, query_length, key_length) 194 | values = values.permute([2, 0, 1]).unsqueeze(0) 195 | return values 196 | 197 | def compute_vertical_horizontal_bias(self, total_boxes: int = 512, device=None): 198 | 199 | denominator_to_divide = total_boxes // self.relative_attention_num_buckets 200 | 201 | """Compute the vertical and horizontal bias""" 202 | if device is None: 203 | device = self.relative_attention_bias.weight.device 204 | indices = torch.arange(total_boxes, dtype=torch.long, device=device) 205 | h_distances = (indices % self.relative_attention_num_buckets)[ 206 | :, None] - (indices % self.relative_attention_num_buckets)[None, :] 207 | v_distances = ( 208 | indices // denominator_to_divide)[:, None] - (indices // denominator_to_divide)[None, :] 209 | 210 | h_distances_bucket = self._relative_position_bucket( 211 | h_distances, # shape (query_length, key_length) 212 | bidirectional=(not self.is_decoder), 213 | num_buckets=self.relative_attention_num_buckets, 214 | 
max_distance=self.relative_attention_max_distance, 215 | ) 216 | 217 | ## It has to be like this : https://github.com/microsoft/i-Code/blob/d933ae53eb9dec057e605fa4c89ea701629c5b9d/i-Code-Doc/core/models/embedding/relative/relative.py#L175 218 | ## so change is needed here 219 | v_distances_bucket = self._relative_position_bucket( 220 | v_distances, # shape (query_length, key_length) 221 | bidirectional=(not self.is_decoder), 222 | num_buckets=self.relative_attention_num_buckets, 223 | max_distance=self.relative_attention_max_distance, 224 | ) 225 | 226 | h_distances_values = self.relative_horizontal_bias( 227 | h_distances_bucket) # shape (query_length, key_length, num_heads) 228 | h_distances_values = h_distances_values.permute([2, 0, 1]).unsqueeze( 229 | 0) # shape (1, num_heads, query_length, key_length) 230 | 231 | v_distances_values = self.relative_vertical_bias( 232 | v_distances_bucket) # shape (query_length, key_length, num_heads) 233 | v_distances_values = v_distances_values.permute([2, 0, 1]).unsqueeze( 234 | 0) # shape (1, num_heads, query_length, key_length) 235 | 236 | return h_distances_values, v_distances_values 237 | 238 | def forward(self, hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, 239 | use_cache=False, output_attentions=False): 240 | """ 241 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 242 | """ 243 | # Input is (batch_size, seq_length, dim) 244 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 245 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 246 | batch_size, seq_length = hidden_states.shape[:2] 247 | 248 | real_seq_length = seq_length 249 | 250 | if past_key_value is not None: 251 | assert(len(past_key_value) 252 | == 2), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" 253 | real_seq_length += past_key_value[0].shape[2] if query_length is None else key_value_states.shape[1] 254 | 255 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[ 256 | 1] 257 | 258 | def shape(states): 259 | "projection" 260 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 261 | 262 | def unshape(states): 263 | """reshape""" 264 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 265 | 266 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 267 | """project hidden states correctly to key/query states""" 268 | if key_value_states is None: 269 | # self-attn 270 | # (batch_size, n_heads, seq_length, dim_per_head) 271 | hidden_states = shape(proj_layer(hidden_states)) 272 | elif past_key_value is None: 273 | # cross-attn 274 | # (batch_size, n_heads, seq_length, dim_per_head) 275 | hidden_states = shape(proj_layer(key_value_states)) 276 | 277 | if past_key_value is not None: 278 | if key_value_states is None: 279 | # self-attn 280 | # (batch_size, n_heads, key_length, dim_per_head) 281 | hidden_states = torch.cat( 282 | [past_key_value, hidden_states], dim=2) 283 | elif past_key_value.shape[2] != key_value_states.shape[1]: 284 | # checking that the `sequence_length` of the `past_key_value` is the same as 285 | # the provided `key_value_states` to support prefix tuning 286 | # cross-attn 287 | # (batch_size, n_heads, seq_length, dim_per_head) 288 | hidden_states = shape(proj_layer(key_value_states)) 289 | else: 290 | # cross-attn 291 | hidden_states = past_key_value 292 | return hidden_states 293 | 294 | # get query states 295 | query_states = shape(self.q(hidden_states)) 296 | 297 | # get key/value states 298 | key_states = project(hidden_states, self.k, key_value_states, 299 | past_key_value[0] if past_key_value is not None else None) 300 | value_states = project(hidden_states, self.v, key_value_states, 301 | past_key_value[0] if past_key_value is not None else None) 302 | 303 | # compute score 304 | # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 305 | scores = torch.matmul(query_states, key_states.transpose(3, 2)) 306 | 307 | # Sequential Part 308 | if position_bias is None: 309 | if not self.has_relative_attention_bias: 310 | position_bias = torch.zeros( 311 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 312 | ) 313 | if self.gradient_checkpointing and self.training: 314 | position_bias.requires_grad = True 315 | else: 316 | position_bias = self.compute_bias_1d( 317 | real_seq_length, key_length, device=scores.device) 318 | h_distances_values, v_distances_values = self.compute_vertical_horizontal_bias( 319 | total_boxes=real_seq_length, device=scores.device) 320 | position_bias = position_bias + h_distances_values + v_distances_values 321 | 322 | # if key and values are already calculated 323 | # we want only the last query position bias 324 | if past_key_value is not None: 325 | position_bias = position_bias[:, :, -hidden_states.size(1):, :] 326 | 327 | if mask is not None: 328 | # (batch_size, n_heads, seq_length, key_length) 329 | position_bias = position_bias + mask 330 | 331 | position_bias_masked = position_bias # No pruning right now 332 | 333 | scores += position_bias_masked 334 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( 335 | scores 336 | ) # (batch_size, n_heads, seq_length, 
key_length) 337 | attn_weights = nn.functional.dropout( 338 | attn_weights, p=self.dropout, training=self.training 339 | ) # (batch_size, n_heads, seq_length, key_length) 340 | 341 | # Mask heads if we want to 342 | if layer_head_mask is not None: 343 | attn_weights = attn_weights * layer_head_mask 344 | 345 | # (batch_size, seq_length, dim) 346 | attn_output = unshape(torch.matmul(attn_weights, value_states)) 347 | attn_output = self.o(attn_output) 348 | 349 | present_key_value_state = (key_states, value_states) if ( 350 | self.is_decoder and use_cache) else None 351 | outputs = (attn_output,) + \ 352 | (present_key_value_state,) + (position_bias,) 353 | 354 | if output_attentions: 355 | outputs = outputs + (attn_weights,) 356 | return outputs 357 | 358 | 359 | class T5LayerSelfAttention(nn.Module): 360 | def __init__(self, config, has_relative_attention_bias=False): 361 | super().__init__() 362 | self.SelfAttention = T5Attention( 363 | config, has_relative_attention_bias=has_relative_attention_bias) 364 | self.layer_norm = T5LayerNorm( 365 | config.d_model, eps=config.layer_norm_epsilon) 366 | self.dropout = nn.Dropout(config.dropout_rate) 367 | 368 | def forward(self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False): 369 | normed_hidden_states = self.layer_norm(hidden_states) 370 | attention_output = self.SelfAttention(normed_hidden_states, mask=attention_mask, position_bias=position_bias, 371 | layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions,) 372 | hidden_states = hidden_states + self.dropout(attention_output[0]) 373 | # add attentions if we output them 374 | outputs = (hidden_states,) + attention_output[1:] 375 | return outputs 376 | 377 | 378 | class T5LayerCrossAttention(nn.Module): 379 | def __init__(self, config): 380 | super().__init__() 381 | self.EncDecAttention = T5Attention( 382 | config, has_relative_attention_bias=False) 383 | self.layer_norm = T5LayerNorm( 384 | config.d_model, eps=config.layer_norm_epsilon) 385 | self.dropout = nn.Dropout(config.dropout_rate) 386 | 387 | def forward(self, hidden_states, key_value_states, attention_mask=None, position_bias=None, layer_head_mask=None, past_key_value=None, use_cache=False, query_length=None, output_attentions=False, ): 388 | normed_hidden_states = self.layer_norm(hidden_states) 389 | attention_output = self.EncDecAttention(normed_hidden_states, mask=attention_mask, 390 | key_value_states=key_value_states, position_bias=position_bias, 391 | layer_head_mask=layer_head_mask, 392 | past_key_value=past_key_value, 393 | use_cache=use_cache, 394 | query_length=query_length, 395 | output_attentions=output_attentions,) 396 | layer_output = hidden_states + self.dropout(attention_output[0]) 397 | # add attention if we output them 398 | outputs = (layer_output, ) + attention_output[1:] 399 | return outputs 400 | 401 | 402 | class T5Block(nn.Module): 403 | def __init__(self, config, has_relative_attention_bias=False): 404 | super().__init__() 405 | self.is_decoder = config.is_decoder 406 | self.layer = nn.ModuleList() 407 | self.layer.append(T5LayerSelfAttention( 408 | config, has_relative_attention_bias=has_relative_attention_bias)) 409 | if self.is_decoder: 410 | self.layer.append(T5LayerCrossAttention(config)) 411 | 412 | self.layer.append(T5LayerFF(config)) 413 | 414 | def forward(self, hidden_states, attention_mask=None, position_bias=None, encoder_hidden_states=None, 
415 | encoder_attention_mask=None, encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, 416 | past_key_value=None, use_cache=False, output_attentions=False, return_dict=True): 417 | 418 | if past_key_value is not None: 419 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 420 | 421 | if len(past_key_value) != expected_num_past_key_values: 422 | raise ValueError( 423 | f"There should be {expected_num_past_key_values} past states. " 424 | f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" 425 | f"Got {len(past_key_value)} past key / value states" 426 | ) 427 | 428 | self_attn_past_key_value = past_key_value[:2] 429 | cross_attn_past_key_value = past_key_value[2:] 430 | else: 431 | self_attn_past_key_value, cross_attn_past_key_value = None, None 432 | 433 | self_attention_outputs = self.layer[0]( 434 | hidden_states, 435 | attention_mask=attention_mask, 436 | position_bias=position_bias, 437 | layer_head_mask=layer_head_mask, 438 | past_key_value=self_attn_past_key_value, 439 | use_cache=use_cache, 440 | output_attentions=output_attentions, 441 | ) 442 | hidden_states, present_key_value_state = self_attention_outputs[:2] 443 | # Keep self-attention outputs and relative position weights 444 | attention_outputs = self_attention_outputs[2:] 445 | 446 | # clamp inf values to enable fp16 training 447 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 448 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 449 | hidden_states = torch.clamp( 450 | hidden_states, min=-clamp_value, max=clamp_value) 451 | 452 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None 453 | if do_cross_attention: 454 | # the actual query length is unknown for cross attention 455 | # if using past key value states. 
Need to inject it here 456 | if present_key_value_state is not None: 457 | query_length = present_key_value_state[0].shape[2] 458 | else: 459 | query_length = None 460 | 461 | cross_attention_outputs = self.layer[1]( 462 | hidden_states, 463 | key_value_states=encoder_hidden_states, 464 | attention_mask=encoder_attention_mask, 465 | position_bias=encoder_decoder_position_bias, 466 | layer_head_mask=cross_attn_layer_head_mask, 467 | past_key_value=cross_attn_past_key_value, 468 | query_length=query_length, 469 | use_cache=use_cache, 470 | output_attentions=output_attentions, 471 | ) 472 | hidden_states = cross_attention_outputs[0] 473 | 474 | # clamp inf values to enable fp16 training 475 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 476 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 477 | hidden_states = torch.clamp( 478 | hidden_states, min=-clamp_value, max=clamp_value) 479 | 480 | # Combine self attn and cross attn key value states 481 | if present_key_value_state is not None: 482 | present_key_value_state = present_key_value_state + \ 483 | cross_attention_outputs[1] 484 | 485 | # Keep cross-attention outputs and relative position weights 486 | attention_outputs = attention_outputs + cross_attention_outputs[2:] 487 | 488 | # Apply Feed Forward layer 489 | hidden_states = self.layer[-1](hidden_states) 490 | 491 | # clamp inf values to enable fp16 training 492 | if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 493 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 494 | hidden_states = torch.clamp( 495 | hidden_states, min=-clamp_value, max=clamp_value) 496 | 497 | outputs = (hidden_states,) 498 | 499 | if use_cache: 500 | outputs = outputs + (present_key_value_state,) + attention_outputs 501 | else: 502 | outputs = outputs + attention_outputs 503 | 504 | # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 505 | return outputs 506 | 507 | 508 | class T5Stack(t5.modeling_t5.T5Stack): 509 | def __init__(self, config, embed_tokens=None): 510 | '''Just changes in the `T5Block`, so have to update it as per our implementation''' 511 | super().__init__(config=config, embed_tokens=embed_tokens) 512 | self.block = nn.ModuleList( 513 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) 514 | for i in range(config.num_layers)] 515 | ) 516 | 517 | def forward( 518 | self, 519 | input_ids=None, 520 | attention_mask=None, 521 | encoder_hidden_states=None, 522 | encoder_attention_mask=None, 523 | inputs_embeds=None, 524 | head_mask=None, 525 | cross_attn_head_mask=None, 526 | past_key_values=None, 527 | use_cache=None, 528 | output_attentions=None, 529 | output_hidden_states=None, 530 | return_dict=None, 531 | ): 532 | 533 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, 534 | inputs_embeds=inputs_embeds, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, 535 | use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) 536 | 537 | 538 | class T5Model(t5.modeling_t5.T5Model): 539 | def __init__(self, config): 540 | super().__init__(config=config) 541 | 542 | self.config = config 543 | encoder_config = copy.deepcopy(config) 544 | decoder_config = copy.deepcopy(config) 545 | 
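        # Both stacks share the token embedding table (self.shared). Marking the decoder copy
        # with is_decoder=True makes its self-attention causal and adds the cross-attention
        # layer in each T5Block (see T5Block.__init__ above).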
decoder_config.update(dict(is_decoder=True)) 546 | 547 | self.encoder = T5Stack(encoder_config, self.shared) 548 | self.decoder = T5Stack(decoder_config, self.shared) 549 | 550 | self.post_init() 551 | 552 | def forward(self, **kwargs): 553 | return super().forward(**kwargs) 554 | 555 | def load_weights(self): 556 | dummy_model = AutoModel.from_pretrained(self.config._name_or_path) 557 | self.load_state_dict(dummy_model.state_dict(), strict=False) 558 | print("Weights loaded successfully!") 559 | 560 | 561 | class T5ForConditionalGeneration(t5.modeling_t5.T5ForConditionalGeneration): 562 | def __init__(self, config): 563 | ''' 564 | It is similar to the T5ForConditionalGeneration described in the `hugging_face` repository, however I had to tweak it a bit, 565 | since there is an addition of the `relative_horizontal_bias` as well as `relative_vertical_bias` in the `T5Attention` class, and also 566 | the entire approach is generative in nature, so maybe it can be used in some other dataset, such as Question Answering 567 | ''' 568 | 569 | super().__init__(config=config) 570 | 571 | self.config = config 572 | encoder_config = copy.deepcopy(config) 573 | decoder_config = copy.deepcopy(config) 574 | # In the pretrained version, the decoder config, the `is_decoder` option is True 575 | decoder_config.update(dict(is_decoder=True)) 576 | 577 | self.encoder = T5Stack(encoder_config, self.shared) 578 | self.decoder = T5Stack(decoder_config, self.shared) 579 | 580 | if config.load_weights: 581 | self.load_weights() 582 | else: 583 | self.post_init() 584 | print("Initialization done without loading the weights") 585 | 586 | def forward(self, **kwargs): 587 | '''Same as mentioned in the hugging face's implementation''' 588 | return super().forward(**kwargs) 589 | 590 | def load_weights(self): 591 | ''' 592 | Loads the weights from the pretrained model 593 | ''' 594 | dummy_model = AutoModel.from_pretrained(self.config._name_or_path) 595 | self.load_state_dict(dummy_model.state_dict(), strict=False) 596 | print("Weights loaded successfully!") 597 | 598 | 599 | class T5EncoderModel(t5.modeling_t5.T5ForConditionalGeneration): 600 | def __init__(self, config): 601 | ''' 602 | It is similar to the T5EncoderModel described in the `hugging_face` repository, however I had to tweak it a bit, 603 | since there is an addition of the `relative_horizontal_bias` as well as `relative_vertical_bias` in the `T5Attention` class 604 | ''' 605 | super().__init__(config=config) 606 | self.encoder = T5Stack(config, self.shared) 607 | self.post_init() 608 | 609 | def forward(self, **kwargs): 610 | '''Similar to the `T5EncoderModel` mentioned in the hugging face's t5 implementation''' 611 | return super().forward(**kwargs) 612 | 613 | 614 | class T5ForConditionalGenerationAbstractive(t5.modeling_t5.T5ForConditionalGeneration): 615 | def __init__(self, config): 616 | ''' 617 | T5ForConditionalGenerationAbstractive is a T5ForConditionalGeneration model with a linear layer on top of the decoder output, 618 | where the decoder output is the output of the last layer of the decoder, followed by a linear layer projection. 
619 | 620 | It is similar to T5ForConditionalGeneration, however, it is based on a concept of generative answer, and this was what I did earlier, 621 | however, the authors have used an abstractive approach, and so I had to tweak somethinig, and essentially, it is the `self.lm_head` 622 | ''' 623 | 624 | super().__init__(config=config) 625 | 626 | self.config = config 627 | encoder_config = copy.deepcopy(config) 628 | decoder_config = copy.deepcopy(config) 629 | # In the pretrained version, the decoder config, the `is_decoder` option is True 630 | decoder_config.update(dict(is_decoder=True)) 631 | 632 | self.encoder = T5Stack(encoder_config, self.shared) 633 | self.decoder = T5Stack(decoder_config, self.shared) 634 | self.lm_head = nn.Linear(in_features=config.d_model, 635 | out_features=config.num_classes, bias=False) 636 | 637 | if config.load_weights: 638 | self.load_weights() 639 | else: 640 | self.post_init() 641 | print("Initialization done without loading the weights") 642 | 643 | def forward(self, **kwargs): 644 | ''' 645 | Forward pass of T5ForConditionalGenerationAbstractive. It is similar to T5ForConditionalGeneration, however, it is based on a concept of generative answer, 646 | and this was what I did earlier, 647 | ''' 648 | return super().forward(**kwargs) 649 | 650 | def load_weights(self): 651 | ''' 652 | Load the weights of the T5ForConditionalGenerationAbstractive model 653 | It is adaptable to both the `t5-base` and `t5-large` configuration settings 654 | ''' 655 | dummy_model = AutoModel.from_pretrained(self.config._name_or_path) 656 | self.load_state_dict(dummy_model.state_dict(), strict=False) 657 | print("Weights loaded successfully!") 658 | -------------------------------------------------------------------------------- /how_did_i_prepare_the_stuffs/tilt_part_3_1_aligning_all_the_parts_to_make_tilt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyNdtYJJE/gqRdJD2aJ7nI9l", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "250938dfb71d45e0858cd757df64cd8b": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "model_module_version": "1.5.0", 23 | "state": { 24 | "_dom_classes": [], 25 | "_model_module": "@jupyter-widgets/controls", 26 | "_model_module_version": "1.5.0", 27 | "_model_name": "HBoxModel", 28 | "_view_count": null, 29 | "_view_module": "@jupyter-widgets/controls", 30 | "_view_module_version": "1.5.0", 31 | "_view_name": "HBoxView", 32 | "box_style": "", 33 | "children": [ 34 | "IPY_MODEL_20ea30cbafb84c618979bd41e8d921d6", 35 | "IPY_MODEL_ff6f0b419450464a9daf73a41e357930", 36 | "IPY_MODEL_681f7ed3e39442cd8f987db7a131f4e1" 37 | ], 38 | "layout": "IPY_MODEL_33af6b7b9d7e4c8d90a9e4d66159f951" 39 | } 40 | }, 41 | "20ea30cbafb84c618979bd41e8d921d6": { 42 | "model_module": "@jupyter-widgets/controls", 43 | "model_name": "HTMLModel", 44 | "model_module_version": "1.5.0", 45 | "state": { 46 | "_dom_classes": [], 47 | "_model_module": "@jupyter-widgets/controls", 48 | "_model_module_version": "1.5.0", 49 | "_model_name": "HTMLModel", 50 | "_view_count": null, 51 | "_view_module": "@jupyter-widgets/controls", 52 | "_view_module_version": "1.5.0", 53 | 
"_view_name": "HTMLView", 54 | "description": "", 55 | "description_tooltip": null, 56 | "layout": "IPY_MODEL_3959923118c84010b4b0024b28ca2734", 57 | "placeholder": "​", 58 | "style": "IPY_MODEL_b3fd33e82ba7455981c0499df81cb0d5", 59 | "value": "100%" 60 | } 61 | }, 62 | "ff6f0b419450464a9daf73a41e357930": { 63 | "model_module": "@jupyter-widgets/controls", 64 | "model_name": "FloatProgressModel", 65 | "model_module_version": "1.5.0", 66 | "state": { 67 | "_dom_classes": [], 68 | "_model_module": "@jupyter-widgets/controls", 69 | "_model_module_version": "1.5.0", 70 | "_model_name": "FloatProgressModel", 71 | "_view_count": null, 72 | "_view_module": "@jupyter-widgets/controls", 73 | "_view_module_version": "1.5.0", 74 | "_view_name": "ProgressView", 75 | "bar_style": "success", 76 | "description": "", 77 | "description_tooltip": null, 78 | "layout": "IPY_MODEL_59955fdf275549ca8873c9b053419fd7", 79 | "max": 2, 80 | "min": 0, 81 | "orientation": "horizontal", 82 | "style": "IPY_MODEL_5e288264d36f49c58f09b95ae3587e26", 83 | "value": 2 84 | } 85 | }, 86 | "681f7ed3e39442cd8f987db7a131f4e1": { 87 | "model_module": "@jupyter-widgets/controls", 88 | "model_name": "HTMLModel", 89 | "model_module_version": "1.5.0", 90 | "state": { 91 | "_dom_classes": [], 92 | "_model_module": "@jupyter-widgets/controls", 93 | "_model_module_version": "1.5.0", 94 | "_model_name": "HTMLModel", 95 | "_view_count": null, 96 | "_view_module": "@jupyter-widgets/controls", 97 | "_view_module_version": "1.5.0", 98 | "_view_name": "HTMLView", 99 | "description": "", 100 | "description_tooltip": null, 101 | "layout": "IPY_MODEL_66bcfafd26e14fb8b333d0da2efc222f", 102 | "placeholder": "​", 103 | "style": "IPY_MODEL_c3deff7939634161bca16addf85406a4", 104 | "value": " 2/2 [00:00<00:00, 21.71it/s]" 105 | } 106 | }, 107 | "33af6b7b9d7e4c8d90a9e4d66159f951": { 108 | "model_module": "@jupyter-widgets/base", 109 | "model_name": "LayoutModel", 110 | "model_module_version": "1.2.0", 111 | "state": { 112 | "_model_module": "@jupyter-widgets/base", 113 | "_model_module_version": "1.2.0", 114 | "_model_name": "LayoutModel", 115 | "_view_count": null, 116 | "_view_module": "@jupyter-widgets/base", 117 | "_view_module_version": "1.2.0", 118 | "_view_name": "LayoutView", 119 | "align_content": null, 120 | "align_items": null, 121 | "align_self": null, 122 | "border": null, 123 | "bottom": null, 124 | "display": null, 125 | "flex": null, 126 | "flex_flow": null, 127 | "grid_area": null, 128 | "grid_auto_columns": null, 129 | "grid_auto_flow": null, 130 | "grid_auto_rows": null, 131 | "grid_column": null, 132 | "grid_gap": null, 133 | "grid_row": null, 134 | "grid_template_areas": null, 135 | "grid_template_columns": null, 136 | "grid_template_rows": null, 137 | "height": null, 138 | "justify_content": null, 139 | "justify_items": null, 140 | "left": null, 141 | "margin": null, 142 | "max_height": null, 143 | "max_width": null, 144 | "min_height": null, 145 | "min_width": null, 146 | "object_fit": null, 147 | "object_position": null, 148 | "order": null, 149 | "overflow": null, 150 | "overflow_x": null, 151 | "overflow_y": null, 152 | "padding": null, 153 | "right": null, 154 | "top": null, 155 | "visibility": null, 156 | "width": null 157 | } 158 | }, 159 | "3959923118c84010b4b0024b28ca2734": { 160 | "model_module": "@jupyter-widgets/base", 161 | "model_name": "LayoutModel", 162 | "model_module_version": "1.2.0", 163 | "state": { 164 | "_model_module": "@jupyter-widgets/base", 165 | "_model_module_version": "1.2.0", 166 | "_model_name": 
"LayoutModel", 167 | "_view_count": null, 168 | "_view_module": "@jupyter-widgets/base", 169 | "_view_module_version": "1.2.0", 170 | "_view_name": "LayoutView", 171 | "align_content": null, 172 | "align_items": null, 173 | "align_self": null, 174 | "border": null, 175 | "bottom": null, 176 | "display": null, 177 | "flex": null, 178 | "flex_flow": null, 179 | "grid_area": null, 180 | "grid_auto_columns": null, 181 | "grid_auto_flow": null, 182 | "grid_auto_rows": null, 183 | "grid_column": null, 184 | "grid_gap": null, 185 | "grid_row": null, 186 | "grid_template_areas": null, 187 | "grid_template_columns": null, 188 | "grid_template_rows": null, 189 | "height": null, 190 | "justify_content": null, 191 | "justify_items": null, 192 | "left": null, 193 | "margin": null, 194 | "max_height": null, 195 | "max_width": null, 196 | "min_height": null, 197 | "min_width": null, 198 | "object_fit": null, 199 | "object_position": null, 200 | "order": null, 201 | "overflow": null, 202 | "overflow_x": null, 203 | "overflow_y": null, 204 | "padding": null, 205 | "right": null, 206 | "top": null, 207 | "visibility": null, 208 | "width": null 209 | } 210 | }, 211 | "b3fd33e82ba7455981c0499df81cb0d5": { 212 | "model_module": "@jupyter-widgets/controls", 213 | "model_name": "DescriptionStyleModel", 214 | "model_module_version": "1.5.0", 215 | "state": { 216 | "_model_module": "@jupyter-widgets/controls", 217 | "_model_module_version": "1.5.0", 218 | "_model_name": "DescriptionStyleModel", 219 | "_view_count": null, 220 | "_view_module": "@jupyter-widgets/base", 221 | "_view_module_version": "1.2.0", 222 | "_view_name": "StyleView", 223 | "description_width": "" 224 | } 225 | }, 226 | "59955fdf275549ca8873c9b053419fd7": { 227 | "model_module": "@jupyter-widgets/base", 228 | "model_name": "LayoutModel", 229 | "model_module_version": "1.2.0", 230 | "state": { 231 | "_model_module": "@jupyter-widgets/base", 232 | "_model_module_version": "1.2.0", 233 | "_model_name": "LayoutModel", 234 | "_view_count": null, 235 | "_view_module": "@jupyter-widgets/base", 236 | "_view_module_version": "1.2.0", 237 | "_view_name": "LayoutView", 238 | "align_content": null, 239 | "align_items": null, 240 | "align_self": null, 241 | "border": null, 242 | "bottom": null, 243 | "display": null, 244 | "flex": null, 245 | "flex_flow": null, 246 | "grid_area": null, 247 | "grid_auto_columns": null, 248 | "grid_auto_flow": null, 249 | "grid_auto_rows": null, 250 | "grid_column": null, 251 | "grid_gap": null, 252 | "grid_row": null, 253 | "grid_template_areas": null, 254 | "grid_template_columns": null, 255 | "grid_template_rows": null, 256 | "height": null, 257 | "justify_content": null, 258 | "justify_items": null, 259 | "left": null, 260 | "margin": null, 261 | "max_height": null, 262 | "max_width": null, 263 | "min_height": null, 264 | "min_width": null, 265 | "object_fit": null, 266 | "object_position": null, 267 | "order": null, 268 | "overflow": null, 269 | "overflow_x": null, 270 | "overflow_y": null, 271 | "padding": null, 272 | "right": null, 273 | "top": null, 274 | "visibility": null, 275 | "width": null 276 | } 277 | }, 278 | "5e288264d36f49c58f09b95ae3587e26": { 279 | "model_module": "@jupyter-widgets/controls", 280 | "model_name": "ProgressStyleModel", 281 | "model_module_version": "1.5.0", 282 | "state": { 283 | "_model_module": "@jupyter-widgets/controls", 284 | "_model_module_version": "1.5.0", 285 | "_model_name": "ProgressStyleModel", 286 | "_view_count": null, 287 | "_view_module": "@jupyter-widgets/base", 288 | 
"_view_module_version": "1.2.0", 289 | "_view_name": "StyleView", 290 | "bar_color": null, 291 | "description_width": "" 292 | } 293 | }, 294 | "66bcfafd26e14fb8b333d0da2efc222f": { 295 | "model_module": "@jupyter-widgets/base", 296 | "model_name": "LayoutModel", 297 | "model_module_version": "1.2.0", 298 | "state": { 299 | "_model_module": "@jupyter-widgets/base", 300 | "_model_module_version": "1.2.0", 301 | "_model_name": "LayoutModel", 302 | "_view_count": null, 303 | "_view_module": "@jupyter-widgets/base", 304 | "_view_module_version": "1.2.0", 305 | "_view_name": "LayoutView", 306 | "align_content": null, 307 | "align_items": null, 308 | "align_self": null, 309 | "border": null, 310 | "bottom": null, 311 | "display": null, 312 | "flex": null, 313 | "flex_flow": null, 314 | "grid_area": null, 315 | "grid_auto_columns": null, 316 | "grid_auto_flow": null, 317 | "grid_auto_rows": null, 318 | "grid_column": null, 319 | "grid_gap": null, 320 | "grid_row": null, 321 | "grid_template_areas": null, 322 | "grid_template_columns": null, 323 | "grid_template_rows": null, 324 | "height": null, 325 | "justify_content": null, 326 | "justify_items": null, 327 | "left": null, 328 | "margin": null, 329 | "max_height": null, 330 | "max_width": null, 331 | "min_height": null, 332 | "min_width": null, 333 | "object_fit": null, 334 | "object_position": null, 335 | "order": null, 336 | "overflow": null, 337 | "overflow_x": null, 338 | "overflow_y": null, 339 | "padding": null, 340 | "right": null, 341 | "top": null, 342 | "visibility": null, 343 | "width": null 344 | } 345 | }, 346 | "c3deff7939634161bca16addf85406a4": { 347 | "model_module": "@jupyter-widgets/controls", 348 | "model_name": "DescriptionStyleModel", 349 | "model_module_version": "1.5.0", 350 | "state": { 351 | "_model_module": "@jupyter-widgets/controls", 352 | "_model_module_version": "1.5.0", 353 | "_model_name": "DescriptionStyleModel", 354 | "_view_count": null, 355 | "_view_module": "@jupyter-widgets/base", 356 | "_view_module_version": "1.2.0", 357 | "_view_name": "StyleView", 358 | "description_width": "" 359 | } 360 | } 361 | } 362 | }, 363 | "accelerator": "GPU", 364 | "gpuClass": "standard" 365 | }, 366 | "cells": [ 367 | { 368 | "cell_type": "markdown", 369 | "metadata": { 370 | "id": "view-in-github", 371 | "colab_type": "text" 372 | }, 373 | "source": [ 374 | "\"Open" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": { 381 | "colab": { 382 | "base_uri": "https://localhost:8080/" 383 | }, 384 | "id": "15ohrRAbDZ4v", 385 | "outputId": "5d5eec1e-0227-45c0-c9f8-337e02486c1b" 386 | }, 387 | "outputs": [ 388 | { 389 | "output_type": "stream", 390 | "name": "stdout", 391 | "text": [ 392 | "fatal: destination path 'TiLT-Implementation' already exists and is not an empty directory.\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "!git clone https://github.com/uakarsh/TiLT-Implementation.git" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "source": [ 403 | "!pip install -r /content/TiLT-Implementation/requirements.txt" 404 | ], 405 | "metadata": { 406 | "id": "IlsCNhv3D0hx" 407 | }, 408 | "execution_count": null, 409 | "outputs": [] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "source": [ 414 | "import sys\n", 415 | "sys.path.append(\"/content/TiLT-Implementation/src/\")" 416 | ], 417 | "metadata": { 418 | "id": "1W4eZnAcD3Pg" 419 | }, 420 | "execution_count": null, 421 | "outputs": [] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "source": [ 426 | "from 
transformers import AutoTokenizer, AutoConfig\n", 427 | "from datasets import load_dataset\n", 428 | "import torch\n", 429 | "import torch.nn as nn\n", 430 | "\n", 431 | "from dataset import FUNSDDs\n", 432 | "from torchvision import transforms\n", 433 | "from tqdm.auto import tqdm\n", 434 | "\n", 435 | "## Custom imports\n", 436 | "from visual_backbone import Unet_encoder, RoIPool\n", 437 | "from t5 import T5ForConditionalGeneration, T5Stack\n", 438 | "from transformers import AutoModel" 439 | ], 440 | "metadata": { 441 | "id": "hshzmmrID39p" 442 | }, 443 | "execution_count": null, 444 | "outputs": [] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "source": [ 449 | "## 1.1. Preparing the dataset" 450 | ], 451 | "metadata": { 452 | "id": "uPeVZmMeEbyf" 453 | } 454 | }, 455 | { 456 | "cell_type": "code", 457 | "source": [ 458 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 459 | "\n", 460 | "hf_ds = load_dataset(\"nielsr/funsd-layoutlmv3\")\n", 461 | "model_name = \"t5-base\"\n", 462 | "## Visual Embedding extractor's parameters\n", 463 | "in_channels = 3\n", 464 | "num_pool_layers = 3\n", 465 | "channels = 16\n", 466 | "sampling_ratio = 2\n", 467 | "spatial_scale = 48 / 384\n", 468 | "output_size = (3,3)\n", 469 | "load_weights = True\n", 470 | "\n", 471 | "## Tokenizer's parameter\n", 472 | "model_max_length = 512\n", 473 | "\n", 474 | "t5_config = AutoConfig.from_pretrained(model_name)\n", 475 | "## Adding new parameters\n", 476 | "t5_config.update(dict(in_channels = in_channels, num_pool_layers = num_pool_layers, channels = channels, model_max_length = model_max_length,\n", 477 | " output_size = output_size, spatial_scale = spatial_scale, sampling_ratio = sampling_ratio, use_cache = False, load_weights = load_weights))\n", 478 | "\n", 479 | "## Tokenizer\n", 480 | "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True, model_max_length = model_max_length)" 481 | ], 482 | "metadata": { 483 | "colab": { 484 | "base_uri": "https://localhost:8080/", 485 | "height": 86, 486 | "referenced_widgets": [ 487 | "250938dfb71d45e0858cd757df64cd8b", 488 | "20ea30cbafb84c618979bd41e8d921d6", 489 | "ff6f0b419450464a9daf73a41e357930", 490 | "681f7ed3e39442cd8f987db7a131f4e1", 491 | "33af6b7b9d7e4c8d90a9e4d66159f951", 492 | "3959923118c84010b4b0024b28ca2734", 493 | "b3fd33e82ba7455981c0499df81cb0d5", 494 | "59955fdf275549ca8873c9b053419fd7", 495 | "5e288264d36f49c58f09b95ae3587e26", 496 | "66bcfafd26e14fb8b333d0da2efc222f", 497 | "c3deff7939634161bca16addf85406a4" 498 | ] 499 | }, 500 | "id": "qzcvnEyYD-KC", 501 | "outputId": "a7f9bab8-de97-400c-edbf-36ba169ab33d" 502 | }, 503 | "execution_count": null, 504 | "outputs": [ 505 | { 506 | "output_type": "stream", 507 | "name": "stderr", 508 | "text": [ 509 | "WARNING:datasets.builder:Found cached dataset funsd-layoutlmv3 (/root/.cache/huggingface/datasets/nielsr___funsd-layoutlmv3/funsd/1.0.0/0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9)\n" 510 | ] 511 | }, 512 | { 513 | "output_type": "display_data", 514 | "data": { 515 | "text/plain": [ 516 | " 0%| | 0/2 [00:00\"Open" 2767 | ] 2768 | }, 2769 | { 2770 | "cell_type": "code", 2771 | "execution_count": null, 2772 | "metadata": { 2773 | "colab": { 2774 | "base_uri": "https://localhost:8080/" 2775 | }, 2776 | "id": "IR5WxNj-dSN_", 2777 | "outputId": "ba18d456-9945-46cb-8f02-16cf0a46bfa1" 2778 | }, 2779 | "outputs": [ 2780 | { 2781 | "output_type": "stream", 2782 | "name": "stdout", 2783 | "text": [ 2784 | "Cloning into 'TiLT-Implementation'...\n", 
2785 | "remote: Enumerating objects: 77, done.\u001b[K\n", 2786 | "remote: Counting objects: 100% (77/77), done.\u001b[K\n", 2787 | "remote: Compressing objects: 100% (58/58), done.\u001b[K\n", 2788 | "remote: Total 77 (delta 31), reused 45 (delta 11), pack-reused 0\u001b[K\n", 2789 | "Unpacking objects: 100% (77/77), 2.77 MiB | 7.51 MiB/s, done.\n" 2790 | ] 2791 | } 2792 | ], 2793 | "source": [ 2794 | "!git clone https://github.com/uakarsh/TiLT-Implementation.git" 2795 | ] 2796 | }, 2797 | { 2798 | "cell_type": "code", 2799 | "source": [ 2800 | "!pip install -r /content/TiLT-Implementation/requirements.txt" 2801 | ], 2802 | "metadata": { 2803 | "id": "oBrxq2BLiuXw" 2804 | }, 2805 | "execution_count": null, 2806 | "outputs": [] 2807 | }, 2808 | { 2809 | "cell_type": "code", 2810 | "source": [ 2811 | "import sys\n", 2812 | "sys.path.append(\"/content/TiLT-Implementation/src/\")" 2813 | ], 2814 | "metadata": { 2815 | "id": "fiGVO1aFiv-h" 2816 | }, 2817 | "execution_count": null, 2818 | "outputs": [] 2819 | }, 2820 | { 2821 | "cell_type": "code", 2822 | "source": [ 2823 | "from transformers import AutoTokenizer\n", 2824 | "from datasets import load_dataset\n", 2825 | "import torch\n", 2826 | "import torch.nn\n", 2827 | "\n", 2828 | "model_name = \"t5-base\"\n", 2829 | "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True)" 2830 | ], 2831 | "metadata": { 2832 | "colab": { 2833 | "base_uri": "https://localhost:8080/", 2834 | "height": 217, 2835 | "referenced_widgets": [ 2836 | "2c08e449950a4d3ea9049a4aaf9f4de7", 2837 | "759b967f50d84ffd903fab90c0969a6d", 2838 | "f97e169933e648378b59c146da88bc3b", 2839 | "3432461a6f7446dda1e28ac7870591c6", 2840 | "e41f7aa014de44dca338647c329d2e02", 2841 | "35b3ce19787a47c8b6c2a9e685550108", 2842 | "ba5cff7a36bf4ae080f76ec53dfce07c", 2843 | "c203dd7e0dcc42d889b05755a018c998", 2844 | "a40c5f7a73ab422ca89258fa2c98bc52", 2845 | "1dc415b05d0d4b9e999fddf1dea84562", 2846 | "cecac60443284810bac22a9a10819572", 2847 | "33ec7ec081d24433a72b65c2c6085aaf", 2848 | "954d298b8a8d43628fc2add59266541e", 2849 | "75208c1d1d3247c19baca7c526eb7382", 2850 | "6e1cb705ef9d4341b6d61f4cccd6c009", 2851 | "9f9fa275896b446b8e6bd511190fe922", 2852 | "ffa2df3f95e241fe90f6514049a32669", 2853 | "993ea1ee236b4b77bce5def3cb4eba68", 2854 | "01b087a18a2346cc97aeaa712e54cef4", 2855 | "c0903c5551224beea5edf4609b07514c", 2856 | "3ecf6715c92042ccbe672b5e1f758de5", 2857 | "3e826539414a4e798e4a6232174e1066", 2858 | "242efb0ea706421c994793f87ece24eb", 2859 | "56e42bcb2df74e2a8e990da26760da6f", 2860 | "9aef638dd3db437ab712c99cbb57433d", 2861 | "c5371258333d471abc1a9e73b23573a0", 2862 | "69cb8cdc043144148dd599569da87d25", 2863 | "3a970557b24940d29eea9ab77b620a0b", 2864 | "aece0cbdf6d04e29a3484a71ae74a971", 2865 | "9e1768f313ef4f8f9503a6e9c64bb707", 2866 | "09d9547e80e44ce5981f66190bd3bb5c", 2867 | "41c267d7ae7f4de183d8ffbd60d47b36", 2868 | "656c59f5ea564ff994112595236213b9" 2869 | ] 2870 | }, 2871 | "id": "lR-QX-JbiwRW", 2872 | "outputId": "ccac92d0-61d7-4749-a076-31e6bf9cc3bb" 2873 | }, 2874 | "execution_count": null, 2875 | "outputs": [ 2876 | { 2877 | "output_type": "display_data", 2878 | "data": { 2879 | "text/plain": [ 2880 | "Downloading (…)lve/main/config.json: 0%| | 0.00/1.21k [00:00