├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── src ├── docllm │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── finetuning │ │ │ ├── __init__.py │ │ │ └── reading_order.py │ │ ├── precompute_epoch.py │ │ ├── precomputed_epoch_loader.py │ │ ├── preprocessing │ │ │ ├── __init__.py │ │ │ ├── data_structure.py │ │ │ ├── doc_data_from_hocr.py │ │ │ └── doc_data_to_pickle.py │ │ └── pretraining │ │ │ ├── __init__.py │ │ │ ├── build_hf_dataset.py │ │ │ ├── collate.py │ │ │ ├── config.py │ │ │ ├── data_packing_pipe.py │ │ │ ├── dataset.py │ │ │ ├── masker.py │ │ │ ├── pipeline.py │ │ │ ├── precomputation_pipeline.py │ │ │ ├── tensor_data_loader_pipe.py │ │ │ └── traindata_pipe.py │ ├── modules │ │ ├── __init__.py │ │ ├── llama │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── causal_docllm.py │ │ │ ├── config.py │ │ │ ├── decoder_layer.py │ │ │ ├── llama_docllm.py │ │ │ └── pretrained_model.py │ │ ├── part_freezable_embedding.py │ │ ├── part_freezable_linear.py │ │ └── spatial_embedder.py │ ├── pretraining_loss.py │ ├── py.typed │ ├── scripts │ │ ├── convert_fsdp_to_hf.py │ │ └── json_to_pickle.py │ └── training │ │ └── pretraining │ │ ├── accelerate │ │ └── train_llama.py │ │ └── huggingface │ │ └── train_llama.py └── external_scripts │ └── document_tokenization │ ├── LICENSE │ └── document_tokenization_pymupdf.py └── test ├── __init__.py ├── conftest.py ├── data ├── __init__.py ├── finetuning │ ├── __init__.py │ └── test_reading_order.py ├── preprocessing │ ├── __init__.py │ ├── example_json.py │ ├── test_doc_data_from_hocr.py │ └── test_doc_data_to_pickle.py └── pretraining │ ├── __init__.py │ ├── conftest.py │ ├── test_config.py │ ├── test_data_packing_pipe.py │ ├── test_dataset.py │ ├── test_masker.py │ ├── test_pipeline.py │ ├── test_precomputation_pipeline.py │ ├── test_tensor_data_loader_pipe.py │ └── test_traindata_pipe.py ├── modules ├── llama │ ├── __init__.py │ ├── conftest.py │ ├── helpers.py │ ├── test_attention.py │ ├── test_causal_docllm.py │ ├── test_decoder_layer.py │ └── test_llama_docllm.py ├── part_freezable_embedding_test_helpers.py ├── part_freezable_linear_test_helpers.py ├── test_part_freezable_embedding.py └── test_part_freezable_linear.py ├── test_pretraining_loss.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 BlueCrescent 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocLLM 2 | 3 | This is an implementation of DocLLM for Llama models, based on the paper "[DocLLM: A layout-aware generative language model for multimodal document understanding](https://arxiv.org/abs/2401.00908)". 4 | 5 | ## License 6 | 7 | Most of the code in this repository is published under the MIT license. 8 | However, the script "src/external_scripts/document_tokenization/document_tokenization_pymupdf.py" is published under the [GNU Affero General Public License](https://www.gnu.org/licenses/agpl-3.0.html) because it uses [PyMuPDF](https://github.com/pymupdf/PyMuPDF). If another license for PyMuPDF is acquired, the script may also be used under that license.
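## Usage

The snippet below is a minimal sketch of how the pretraining data pipeline can be wired together via `DocLLMPreTrainDataConfig` and `build_docllm_datapipeline` from `docllm.data`. The special token ids, the masking ratio and the data directory are placeholders; they depend on your tokenizer and on the preprocessed `.pt` files produced by the preprocessing utilities.

```python
from docllm.data import DocLLMPreTrainDataConfig, build_docllm_datapipeline

# Placeholder values: adapt the special token ids to your tokenizer and point
# `directory` at a folder of preprocessed .pt tensor files.
config = DocLLMPreTrainDataConfig(
    batch_size=2,
    max_seq_length=1024,
    num_masked_blocks=0.15,        # mask roughly 15% of the blocks of each document
    block_start_text_token=32000,  # hypothetical special token id
    block_end_text_token=32001,    # hypothetical special token id
    bos_text_token=1,
    directory="path/to/preprocessed/tensors",
)

datapipeline = build_docllm_datapipeline(config)
for batch in datapipeline:
    # Each batch wraps the collated tensors for the masked-block pretraining task.
    input_ids, bboxes, loss_mask, labels = batch
    ...
```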
9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "DocLLM" 3 | version = "0.1.0" 4 | requires-python = ">=3.11,<3.13" 5 | description = "DocLLM implementation" 6 | dependencies = [ 7 | "torch", 8 | "transformers==4.46.3", 9 | "pydantic", 10 | "sentencepiece", 11 | "tensorboard", 12 | "numpy==1.*", 13 | "accelerate==0.34.2", 14 | "datasets", 15 | ] 16 | 17 | [project.optional-dependencies] 18 | linting = ["pre-commit"] 19 | tests = ["pytest"] 20 | 21 | [project.scripts] 22 | docllm = "docllm.__main__:main" 23 | 24 | [build-system] 25 | requires = ["setuptools >= 61.0.0"] 26 | build-backend = "setuptools.build_meta" 27 | 28 | [tool.black] 29 | target-version = ["py311"] 30 | line-length = 120 31 | 32 | [tool.isort] 33 | profile = "black" 34 | line_length = 120 35 | -------------------------------------------------------------------------------- /src/docllm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/src/docllm/__init__.py -------------------------------------------------------------------------------- /src/docllm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing.data_structure import Block, Coord, Document, Line, Page, Rect, Token, Word, WritingMode 2 | from .preprocessing.doc_data_to_pickle import document_tokenization_from_data, document_tokenization_from_file 3 | from .pretraining.config import DocLLMPreTrainDataConfig, NumMaskedBlocksType 4 | from .pretraining.data_packing_pipe import DataPackingPipe 5 | from .pretraining.pipeline import build_docllm_datapipeline 6 | from .pretraining.tensor_data_loader_pipe import TensorDataLoaderPipe 7 | from .pretraining.traindata_pipe import DocLLMTrainDataPipe 8 | -------------------------------------------------------------------------------- /src/docllm/data/finetuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/src/docllm/data/finetuning/__init__.py -------------------------------------------------------------------------------- /src/docllm/data/finetuning/reading_order.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, List, Optional, Tuple 2 | 3 | import torch 4 | from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt 5 | 6 | ZeroOneFloat = Annotated[float, Field(strict=True, ge=0, le=1)] 7 | CoordToken = ZeroOneFloat 8 | TokenRect = Tuple[CoordToken, CoordToken, CoordToken, CoordToken] 9 | 10 | 11 | class ReadingOrderConfig(BaseModel): 12 | max_seq_length: PositiveInt 13 | min_num_blocks_available: NonNegativeInt = 4 14 | min_num_tokens_available: NonNegativeInt = 1 15 | bos_text_token: Optional[int] 16 | bos_bbox_token: TokenRect = (0.0, 0.0, 0.0, 0.0) 17 | eos_text_token: Optional[int] 18 | eos_bbox_token: TokenRect = (1.0, 1.0, 1.0, 1.0) 19 | output_bbox_token: TokenRect = (0.0, 0.0, 0.0, 0.0) 20 | 21 | prompt_token_ids: List[NonNegativeInt] = Field(default_factory=list) 22 | 23 | 24 | class ReadingOrderSampleBuilder: 25 | def __init__(self, config: ReadingOrderConfig) -> None: 26 | self._max_seq_length = config.max_seq_length 27 | 
self._min_num_blocks_available = config.min_num_blocks_available 28 | self._min_num_tokens_available = config.min_num_tokens_available 29 | 30 | self._bos_text_token = [torch.tensor([config.bos_text_token])] if config.bos_text_token is not None else [] 31 | self._bos_bbox_token = [torch.tensor([config.bos_bbox_token])] if config.bos_text_token is not None else [] 32 | assert get_num_tokens(self._bos_text_token) == get_num_tokens(self._bos_bbox_token) 33 | self._eos_text_token = [torch.tensor([config.eos_text_token])] if config.eos_text_token is not None else [] 34 | self._eos_bbox_token = [torch.tensor([config.eos_bbox_token])] if config.eos_text_token is not None else [] 35 | assert get_num_tokens(self._eos_text_token) == get_num_tokens(self._eos_bbox_token) 36 | self._prompt_token_ids = [torch.tensor([tid]) for tid in config.prompt_token_ids] 37 | self._prompt_bboxes = [torch.tensor([[0.0, 0.0, 1.0, 1.0]]) for _ in config.prompt_token_ids] 38 | assert get_num_tokens(self._prompt_token_ids) == get_num_tokens(self._prompt_bboxes) 39 | self._output_bbox_token = torch.tensor([config.output_bbox_token]) 40 | 41 | def __call__( 42 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor] 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]: 44 | input_tensors = list(input_tensors) if isinstance(input_tensors, tuple) else input_tensors 45 | bbox_tensors = list(bbox_tensors) if isinstance(bbox_tensors, tuple) else bbox_tensors 46 | 47 | num_blocks = len(input_tensors) 48 | num_tokens = get_num_tokens(input_tensors) 49 | self._check_inputs(bbox_tensors, num_blocks, num_tokens) 50 | 51 | shuffled_ids, shuffled_bboxes = self._build_shuffled_input_sequence(input_tensors, bbox_tensors, num_blocks) 52 | input_ids = self._build_input_ids(input_tensors, shuffled_ids) 53 | bboxes = self._build_bboxes(shuffled_bboxes, num_tokens) 54 | labels = self._build_labels(input_tensors, shuffled_ids) 55 | loss_mask, num_target_tokens = self._build_loss_mask(num_tokens) 56 | 57 | return self._truncate_to_max_seq_len(input_ids, bboxes, loss_mask, labels, num_target_tokens) 58 | 59 | def _build_shuffled_input_sequence( 60 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor], num_blocks: int 61 | ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor]]: 62 | permutation = torch.randperm(num_blocks) 63 | shuffled_inputs = [input_tensors[i] for i in permutation] 64 | shuffled_bboxes = [bbox_tensors[i] for i in permutation] 65 | return shuffled_inputs, shuffled_bboxes 66 | 67 | def _build_input_ids( 68 | self, input_tensors: List[torch.LongTensor], shuffled_ids: List[torch.LongTensor] 69 | ) -> torch.LongTensor: 70 | input_ids = torch.cat( 71 | self._bos_text_token 72 | + shuffled_ids 73 | + self._eos_text_token 74 | + self._prompt_token_ids 75 | + self._bos_text_token 76 | + input_tensors 77 | ) 78 | return input_ids 79 | 80 | def _build_bboxes( 81 | self, shuffled_bboxes: List[torch.FloatTensor], num_tokens: int 82 | ) -> torch.FloatTensor: 83 | target_bbox_tokens = [self._output_bbox_token] * num_tokens 84 | bboxes = torch.cat( 85 | self._bos_bbox_token 86 | + shuffled_bboxes 87 | + self._eos_bbox_token 88 | + self._prompt_bboxes 89 | + self._bos_bbox_token 90 | + target_bbox_tokens 91 | ) 92 | return bboxes 93 | 94 | def _build_labels( 95 | self, input_tensors: List[torch.LongTensor], shuffled_ids: List[torch.LongTensor] 96 | ) -> torch.LongTensor: 97 | labels = torch.cat( 98 | shuffled_ids 99 | + self._eos_text_token 100 | + 
self._prompt_token_ids 101 | + self._bos_text_token 102 | + input_tensors 103 | + self._eos_text_token 104 | ) 105 | return labels 106 | 107 | def _build_loss_mask(self, num_tokens: int) -> Tuple[torch.BoolTensor, int]: 108 | zero_mask = torch.zeros( 109 | num_tokens + len(self._eos_text_token) + len(self._prompt_token_ids) + len(self._bos_text_token), 110 | dtype=torch.bool, 111 | ) 112 | one_mask = torch.ones(num_tokens + len(self._eos_text_token), dtype=torch.bool) 113 | loss_mask = torch.cat([zero_mask, one_mask]) 114 | num_target_tokens = one_mask.size(0) 115 | return loss_mask, num_target_tokens - len(self._bos_text_token) 116 | 117 | def _check_inputs(self, bbox_tensors: List[torch.FloatTensor], num_blocks: int, num_tokens: int) -> None: 118 | if num_blocks < self._min_num_blocks_available: 119 | raise ValueError( 120 | f"Number of blocks ({num_blocks}) is less than the minimum " 121 | f"number of blocks required ({self._min_num_blocks_available})." 122 | ) 123 | if num_tokens < self._min_num_tokens_available: 124 | raise ValueError( 125 | f"Number of tokens ({num_tokens}) is less than the minimum " 126 | f"number of tokens required ({self._min_num_tokens_available})." 127 | ) 128 | if num_blocks != len(bbox_tensors): 129 | raise ValueError( 130 | f"Number of blocks ({num_blocks}) does not match the number of bounding boxes " 131 | f"({len(bbox_tensors)})." 132 | ) 133 | if num_tokens != get_num_tokens(bbox_tensors): 134 | raise ValueError( 135 | f"Number of tokens ({num_tokens}) does not match the number of bounding boxes " 136 | f"({get_num_tokens(bbox_tensors)})." 137 | ) 138 | 139 | def _truncate_to_max_seq_len( 140 | self, 141 | text_inputs: torch.LongTensor, 142 | spatial_inputs: torch.FloatTensor, 143 | loss_mask: torch.BoolTensor, 144 | label_tokens: torch.LongTensor, 145 | num_target_tokens: int, 146 | ) -> Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]: 147 | if text_inputs.size(0) <= self._max_seq_length: 148 | return text_inputs, spatial_inputs, loss_mask, label_tokens 149 | if text_inputs.size(0) - num_target_tokens + 1 >= self._max_seq_length: 150 | raise ValueError( 151 | f"Number of tokens ({text_inputs.size(0)}) exceeds maximal sequence " 152 | f"length ({self._max_seq_length}) with {num_target_tokens=}. " 153 | "Nothing would be learned. Skipping..." 154 | ) 155 | return ( 156 | text_inputs[: self._max_seq_length], 157 | spatial_inputs[: self._max_seq_length], 158 | loss_mask[: self._max_seq_length], 159 | label_tokens[: self._max_seq_length], 160 | ) 161 | 162 | 163 | def get_num_tokens(list_of_tensors: List[torch.Tensor]) -> int: 164 | """ 165 | Calculate the number of tokens/entries in a list of tensors. 166 | 167 | Args: 168 | list_of_tensors: List of tensors where each entry in dimension 0 corresponds to one token/entry. 169 | 170 | Returns: 171 | Number of tokens. 
172 | """ 173 | return sum(tensor.shape[0] for tensor in list_of_tensors) 174 | -------------------------------------------------------------------------------- /src/docllm/data/precompute_epoch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import os 4 | import random 5 | from typing import Iterable, List, Protocol, Tuple 6 | 7 | import torch 8 | 9 | 10 | class SampleBuilder(Protocol): 11 | def __call__( 12 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor] 13 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]: 14 | pass 15 | 16 | 17 | class SampleBuilderFactory(Protocol): 18 | def __call__(self) -> SampleBuilder: 19 | pass 20 | 21 | 22 | def precompute_epoch_files( 23 | directory: str, 24 | output_directory: str, 25 | max_seq_length: int, 26 | num_processes: int, 27 | sample_builder_factory: SampleBuilderFactory, 28 | start_idx: int = 0, 29 | end_idx: int = -1, 30 | ) -> None: 31 | files = sorted(list_pt_files(directory))[start_idx:end_idx] 32 | random.shuffle(files) 33 | queue = multiprocessing.Queue() 34 | for file in files: 35 | queue.put(file) 36 | 37 | workers = [ 38 | multiprocessing.Process( 39 | target=precompute_worker, 40 | args=(queue, output_directory, max_seq_length, f"subset_{i}_file", sample_builder_factory), 41 | ) 42 | for i in range(num_processes) 43 | ] 44 | try: 45 | for worker in workers: 46 | worker.start() 47 | for _ in range(num_processes): 48 | queue.put(None) 49 | except Exception as e: 50 | print(f"Error while precomputing files: {e}") 51 | for worker in workers: 52 | worker.terminate() 53 | raise e 54 | finally: 55 | for worker in workers: 56 | worker.join() 57 | 58 | 59 | def precompute_worker( 60 | queue: multiprocessing.Queue, 61 | output_directory: str, 62 | max_seq_length: int, 63 | prefix: str, 64 | sample_builder_factory: SampleBuilderFactory, 65 | ) -> None: 66 | precomputer = SubsetPrecomputer(output_directory, max_seq_length, prefix, sample_builder_factory) 67 | precomputer.process(IterableQueue(queue)) 68 | 69 | 70 | class SubsetPrecomputer: 71 | def __init__( 72 | self, output_directory: str, max_seq_length: int, prefix: str, sample_builder_factory: SampleBuilderFactory 73 | ): 74 | self._output_directory = output_directory 75 | self._max_sequence_length = max_seq_length 76 | self._prefix = prefix 77 | self._sample_builder: SampleBuilder = sample_builder_factory() 78 | self._elements = [] 79 | self._total_size = 0 80 | 81 | def process(self, files: Iterable[str]) -> None: 82 | self._save_packed_data(self._build_packed_data(self._build_training_inputs(self._load_pickle_file(files)))) 83 | 84 | def _load_pickle_file(self, file_paths: Iterable[str]) -> Iterable[Tuple[List[torch.Tensor], List[torch.Tensor]]]: 85 | for path in file_paths: 86 | file_data = torch.load(path, map_location="cpu", weights_only=True) 87 | if len(file_data) == 0: 88 | print(f"File '{path}' is empty.") 89 | continue 90 | text_tokens, bbox_tokens = tuple(map(list, zip(*file_data))) 91 | if self._check_data(text_tokens, bbox_tokens, path): 92 | yield text_tokens, bbox_tokens 93 | 94 | def _check_data(self, text_tokens: List[torch.Tensor], bbox_tokens: List[torch.Tensor], path: str) -> bool: 95 | if len(text_tokens) != len(bbox_tokens): 96 | logging.warning( 97 | f"Length of text tokens ({len(text_tokens)}) and bbox tokens " 98 | f"({len(bbox_tokens)}) must match. 
(In file {path})" 99 | ) 100 | return False 101 | if not all(bbox_tokens[i].size() == (text_tokens[i].size()[0], 4) for i in range(len(text_tokens))): 102 | logging.warning(f"Unexpected bounding box shape. (In file {path})") 103 | return False 104 | return True 105 | 106 | def _build_training_inputs( 107 | self, input_and_bbox_tensors: Iterable[Tuple[List[torch.Tensor], List[torch.Tensor]]] 108 | ) -> Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 109 | for input_tensors, bbox_tensors in input_and_bbox_tensors: 110 | try: 111 | yield self._sample_builder(input_tensors, bbox_tensors) 112 | except ValueError as e: 113 | logging.warning(e) 114 | 115 | def _build_packed_data( 116 | self, file_data: Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]] 117 | ) -> Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 118 | for item in file_data: 119 | item_size = item[0].size(0) 120 | if self._total_size + item_size > self._max_sequence_length: 121 | yield self._build_packed_entry(self._elements) 122 | self._elements = [] 123 | self._total_size = 0 124 | self._elements.append(item) 125 | self._total_size += item_size 126 | 127 | if self._elements: 128 | yield self._build_packed_entry(self._elements) 129 | 130 | def _build_packed_entry( 131 | self, elements: List[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]] 132 | ) -> Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]: 133 | inputs, bboxes, mask, labels = zip(*elements) 134 | sizes = torch.LongTensor([i.size(0) for i in inputs]) 135 | return torch.cat(inputs), torch.cat(bboxes), torch.cat(mask), torch.cat(labels), sizes 136 | 137 | def _save_packed_data( 138 | self, 139 | packed_data: Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]], 140 | ) -> None: 141 | for i, data in enumerate(packed_data): 142 | new_fn = os.path.join(self._output_directory, self._prefix + f"_{i}.pt") 143 | torch.save(data, new_fn) 144 | 145 | 146 | class IterableQueue: 147 | def __init__(self, queue: multiprocessing.Queue): 148 | self.queue = queue 149 | 150 | def __iter__(self): 151 | while True: 152 | task = self.queue.get() 153 | if task is None: 154 | break 155 | yield task 156 | 157 | 158 | def list_pt_files(directory: str) -> List[str]: 159 | return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".pt")] 160 | -------------------------------------------------------------------------------- /src/docllm/data/precomputed_epoch_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from typing import Dict, List, Tuple 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from docllm.data.pretraining.collate import collate 9 | 10 | PackedData = Tuple[ 11 | torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.BoolTensor, torch.LongTensor 12 | ] 13 | 14 | 15 | def build_epoch_dataloader(directory: str, batch_size: int, padding_value: float, offset_idx: int = 0) -> DataLoader: 16 | dataset = PrecomputedEpoch(directory, offset_idx=offset_idx) 17 | return torch.utils.data.DataLoader( 18 | dataset, batch_size=batch_size, collate_fn=partial(collate_with_packing, padding_value=padding_value) 19 | ) 20 | 21 | 22 | class PrecomputedEpoch(Dataset): 23 | def __init__(self, directory: str, offset_idx: int = 0): 
24 | self._directory = directory 25 | self._file_names = [f for f in os.listdir(directory) if f.endswith(".pt")] 26 | self._offset_idx = offset_idx 27 | 28 | def __len__(self): 29 | return len(self._file_names) + self._offset_idx 30 | 31 | def __getitem__(self, idx: int) -> PackedData: 32 | idx -= self._offset_idx 33 | if idx < 0: 34 | raise IndexError(f"Index {idx + self._offset_idx} is below the given offset.") 35 | file_path = os.path.join(self._directory, self._file_names[idx]) 36 | input_ids, bbox, loss_mask, labels, sizes = torch.load(file_path) 37 | 38 | attention_mask, position_ids = compute_packing_attention_mask_and_pos_ids(sizes) 39 | 40 | return input_ids, bbox, loss_mask, labels, position_ids, attention_mask 41 | 42 | 43 | def compute_packing_attention_mask_and_pos_ids(sizes: torch.LongTensor) -> Tuple[torch.BoolTensor, torch.LongTensor]: 44 | total_len = sizes.sum() 45 | mask = torch.zeros((total_len, total_len), dtype=torch.bool) 46 | pos_ids = torch.zeros(total_len, dtype=torch.long) 47 | start = 0 48 | for size in sizes: 49 | end = start + size 50 | mask[start:end, start:end] = torch.tril(torch.ones((size, size), dtype=torch.bool)) 51 | pos_ids[start:end] = torch.arange(size) 52 | start = end 53 | return mask, pos_ids 54 | 55 | 56 | def collate_as_dict( 57 | inputs: List[PackedData], padding_value: float, additional_col: bool = False 58 | ) -> Dict[str, torch.Tensor]: 59 | input_ids, bbox, loss_mask, labels, position_ids, attention_mask = collate_with_packing( 60 | inputs, padding_value, additional_col 61 | ) 62 | return { 63 | "input_ids": input_ids, 64 | "input_coordinates": bbox.clamp(min=0.0, max=1.0), # FIXME 65 | "loss_mask": loss_mask, 66 | "labels": labels, 67 | "position_ids": position_ids, 68 | "attention_mask": attention_mask, 69 | } 70 | 71 | 72 | def collate_with_packing(inputs: List[PackedData], padding_value: float, additional_col: bool = False) -> PackedData: 73 | att_mask_idx = 5 74 | res = collate(inputs, padding_value, ignore_indices={att_mask_idx}) 75 | input_ids, bbox, loss_mask, labels, position_ids = res 76 | max_length = input_ids.size(1) 77 | padded_attention_masks = torch.stack( 78 | [_pad_attention_mask(t[att_mask_idx], max_length, additional_col) for t in inputs] 79 | ) 80 | assert padded_attention_masks.size(1) == max_length 81 | assert padded_attention_masks.size(2) == max_length + 1 82 | return input_ids, bbox, loss_mask, labels, position_ids, padded_attention_masks 83 | 84 | 85 | def _pad_attention_mask(mask: torch.BoolTensor, to_length: int, additional_col: bool) -> torch.BoolTensor: 86 | # if mask.size(0) == to_length: 87 | # return mask 88 | return torch.cat( 89 | [ 90 | torch.cat([mask, torch.zeros(to_length - mask.size(0), mask.size(1), device=mask.device)], dim=0), 91 | torch.zeros(to_length, to_length - mask.size(1) + int(additional_col), device=mask.device), 92 | ], 93 | dim=1, 94 | ) 95 | 96 | 97 | # attention_mask = torch.cat([attention_mask, torch.zeros(2, attention_mask.size(1), 1, dtype=torch.bool)], dim=2) 98 | -------------------------------------------------------------------------------- /src/docllm/data/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/docllm/data/preprocessing/data_structure.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | from typing import List, Tuple 4 | 5 | from pydantic import 
BaseModel 6 | 7 | Coord = float 8 | Rect = Tuple[Coord, Coord, Coord, Coord] 9 | 10 | 11 | class WritingMode(Enum): 12 | horizontal = 0 13 | vertical = 1 14 | 15 | 16 | class Token(BaseModel): 17 | text: str 18 | token_id: int 19 | bbox: Rect 20 | 21 | 22 | class Word(BaseModel): 23 | text: str 24 | tokens: List[Token] 25 | bbox: Rect 26 | 27 | 28 | class Line(BaseModel): 29 | text: str 30 | words: List[Word] 31 | direction: Tuple[float, float] 32 | writing_mode: WritingMode 33 | bbox: Rect 34 | 35 | 36 | class Block(BaseModel): 37 | lines: List[Line] 38 | bbox: Rect 39 | 40 | def get_text(self) -> str: 41 | return "\n".join(line.text for line in self.lines) 42 | 43 | 44 | class Page(BaseModel): 45 | blocks: List[Block] 46 | width: Coord 47 | height: Coord 48 | 49 | def get_text(self, block_separator: str = "\n") -> str: 50 | return block_separator.join(block.get_text() for block in self.blocks) 51 | 52 | 53 | class Document(BaseModel): 54 | pages: List[Page] 55 | filename: str 56 | 57 | def get_text(self, page: int, block_separator: str = "\n") -> str: 58 | return self.pages[page].get_text(block_separator=block_separator) 59 | 60 | def save_as_json(self, directory: str): 61 | basename = os.path.splitext(os.path.basename(self.filename))[0] 62 | filename = os.path.join(directory, basename + ".json") 63 | if os.path.exists(filename): 64 | raise IOError(f"Output file '{filename}' already exists. Cannot save document as JSON.") 65 | as_json = self.model_dump_json() 66 | with open(filename, "w") as f: 67 | f.write(as_json) 68 | -------------------------------------------------------------------------------- /src/docllm/data/preprocessing/doc_data_from_hocr.py: -------------------------------------------------------------------------------- 1 | import re 2 | import xml.etree.ElementTree as ET 3 | from collections import Counter 4 | from itertools import groupby 5 | from typing import Iterable, List, Tuple 6 | 7 | from pydantic import BaseModel, ConfigDict 8 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 9 | 10 | from docllm.data.preprocessing.data_structure import Block, Document, Line, Page, Token, Word, WritingMode 11 | 12 | 13 | class Setup(BaseModel): 14 | model_config = ConfigDict(arbitrary_types_allowed=True) 15 | 16 | tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast 17 | build_token_bboxes: bool = False 18 | begin_of_word: str = "▁" 19 | 20 | 21 | def parse_hocr_document(hocr_path: str, setup: Setup) -> Document: 22 | root = parse_without_namespaces(hocr_path) 23 | body = root.findall(".//body") 24 | pages = body[0].findall(".//div[@class='ocr_page']") 25 | return Document(filename=hocr_path, pages=[pp for page in pages if (pp := parse_hocr_page(page, setup)).blocks]) 26 | 27 | 28 | def parse_without_namespaces(hocr_path: str) -> ET.Element: 29 | it = ET.iterparse(hocr_path) 30 | for _, el in it: 31 | _, _, el.tag = el.tag.rpartition("}") 32 | return it.root 33 | 34 | 35 | def parse_hocr_page(page: ET.Element, setup: Setup) -> Page: 36 | blocks = page.findall(".//content_area") 37 | _, _, width, height = page.attrib["title"].split("bbox")[1].split(";")[0].strip().split() 38 | return Page( 39 | blocks=[pb for block in blocks if (pb := parse_hocr_block(block, setup)).lines], 40 | width=float(width), 41 | height=float(height), 42 | ) 43 | 44 | 45 | def parse_hocr_block(block: ET.Element, setup: Setup) -> Block: 46 | paragraphs = block.findall(".//paragraph") 47 | lines = [l for p in paragraphs for l in p.findall(".//span[@class='ocr_line']")] 48 | 
parsed_lines = [pl for line in lines if (pl := parse_hocr_line(line, setup)).words] 49 | if not parsed_lines: 50 | return Block(lines=[], bbox=(0.0, 0.0, 0.0, 0.0)) 51 | return Block(lines=parsed_lines, bbox=join_rects(*(l.bbox for l in parsed_lines))) 52 | 53 | 54 | def parse_hocr_line(line: ET.Element, setup: Setup) -> Line: 55 | words = [w for w in line.findall(".//span[@class='ocrx_word']") if w.text is not None] 56 | bbox = parse_bbox(line) 57 | writing_mode = determine_writing_mode(words) 58 | return Line( 59 | text=" ".join(w.text for w in words), 60 | words=[parse_hocr_word(word, writing_mode, setup) for word in words], 61 | direction=(float(bbox[2]) - float(bbox[0]), float(bbox[3]) - float(bbox[1])), 62 | writing_mode=writing_mode, 63 | bbox=bbox, 64 | ) 65 | 66 | 67 | def determine_writing_mode(words: List[ET.Element]) -> WritingMode: 68 | if len(words) == 0: 69 | return WritingMode.horizontal 70 | word_writing_modes = Counter(parse_writing_mode_from_baseline(word) for word in words) 71 | writing_mode = word_writing_modes.most_common(1)[0][0] 72 | return writing_mode 73 | 74 | 75 | def parse_writing_mode_from_baseline(word: ET.Element) -> WritingMode: 76 | baseline_str = [el for el in word.attrib["title"].split(";") if el.strip().startswith("baseline")][0].strip() 77 | minx, miny, maxx, maxy = map(float, baseline_str.split(" ")[1:]) 78 | if maxx - minx > maxy - miny: 79 | return WritingMode.horizontal 80 | return WritingMode.vertical 81 | 82 | 83 | def parse_hocr_word(word: ET.Element, writing_mode: WritingMode, setup: Setup) -> Word: 84 | bbox = parse_bbox(word) 85 | if setup.build_token_bboxes: 86 | tokens = list(build_word_tokens(word.text, setup.tokenizer, setup.begin_of_word, bbox, writing_mode)) 87 | else: 88 | tokens = list(build_word_tokens_with_same_bounding_box(word.text, setup.tokenizer, bbox)) 89 | return Word(text=word.text, tokens=tokens, bbox=bbox) 90 | 91 | 92 | def parse_bbox(element: ET.Element) -> Tuple[float, float, float, float]: 93 | bbox_str = [el for el in element.attrib["title"].split(";") if el.strip().startswith("bbox")][0].strip() 94 | return tuple(map(float, bbox_str.split(" ")[1:])) 95 | 96 | 97 | def join_rects(*rects: Tuple[float, float, float, float]) -> Tuple[float, float, float, float]: 98 | lefts, tops, rights, bottoms = zip(*rects) 99 | return min(lefts), min(tops), max(rights), max(bottoms) 100 | 101 | 102 | def build_word_tokens_with_same_bounding_box( 103 | text: str, tokenizer: PreTrainedTokenizer, word_bbox: Tuple[float, float, float, float] 104 | ) -> Iterable[Token]: 105 | token_ids = tokenize_with_leading_space(text, tokenizer) 106 | tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True) 107 | for t, t_id in zip(tokens, token_ids): 108 | yield Token(text=t, token_id=t_id, bbox=word_bbox) 109 | 110 | 111 | def tokenize_with_leading_space( 112 | text: str, 113 | tokenizer: PreTrainedTokenizer, 114 | ) -> List[int]: 115 | # Always treat words as part of an ongoing text sequence by adding a leading space. 
116 | return tokenizer(" " + text, return_attention_mask=False)["input_ids"] 117 | 118 | 119 | def build_word_tokens( 120 | text: str, 121 | tokenizer: PreTrainedTokenizer, 122 | begin_of_word: str, 123 | word_bbox: Tuple[float, float, float, float], 124 | writing_mode: WritingMode, 125 | ) -> Iterable[Token]: 126 | grouped_tokens_and_ids = tokenize_word(text, tokenizer, begin_of_word) 127 | if writing_mode == WritingMode.horizontal: 128 | yield from build_horizontal_tokens(text, word_bbox, grouped_tokens_and_ids) 129 | else: 130 | yield from build_vertical_tokens(text, word_bbox, grouped_tokens_and_ids) 131 | 132 | 133 | def tokenize_word(text: str, tokenizer: PreTrainedTokenizer, begin_of_word: str) -> Iterable[List[Tuple[str, int]]]: 134 | token_ids = tokenizer(text, return_attention_mask=False)["input_ids"] 135 | tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True) 136 | tokens = map(lambda t: re.sub(f"^{begin_of_word}", "", t), tokens) 137 | tokens_and_ids = zip(tokens, token_ids) 138 | tokens_and_ids = filter(lambda t: t[0], tokens_and_ids) 139 | grouped_tokens_and_ids = group_consecutive_escaped_tokens(tokens_and_ids) 140 | return grouped_tokens_and_ids 141 | 142 | 143 | def build_horizontal_tokens( 144 | text: str, word_bbox: Tuple[float, float, float, float], grouped_tokens_and_ids: Iterable[List[Tuple[str, int]]] 145 | ) -> Iterable[Token]: 146 | left, top, right, bottom = word_bbox 147 | char_width = (right - left) / len(text) 148 | for group in grouped_tokens_and_ids: 149 | if len(group) == 1: 150 | t, tid = group[0] 151 | token, left = build_horizontal_token(left, top, bottom, char_width, t, tid, 1.0) 152 | yield token 153 | else: 154 | scale = compute_token_scale(group) 155 | for t, tid in group: 156 | token, left = build_horizontal_token(left, top, bottom, char_width, t, tid, scale) 157 | yield token 158 | 159 | 160 | def build_horizontal_token( 161 | left: float, top: float, bottom: float, char_width: float, t: str, tid: int, scale: float 162 | ) -> Tuple[Token, float]: 163 | token_width = token_length(t) * char_width * scale 164 | bbox = (left, top, left + token_width, bottom) 165 | left += token_width 166 | token = Token(text=t, token_id=tid, bbox=bbox) 167 | return token, left 168 | 169 | 170 | def build_vertical_tokens( 171 | text: str, word_bbox: Tuple[float, float, float, float], grouped_tokens_and_ids: Iterable[List[Tuple[str, int]]] 172 | ) -> Iterable[Token]: 173 | left, top, right, bottom = word_bbox 174 | char_width = (bottom - top) / len(text) 175 | for group in grouped_tokens_and_ids: 176 | if len(group) == 1: 177 | t, tid = group[0] 178 | token, top = build_vertical_token(left, top, right, char_width, t, tid, 1.0) 179 | yield token 180 | else: 181 | scale = compute_token_scale(group) 182 | for t, tid in group: 183 | token, top = build_vertical_token(left, top, right, char_width, t, tid, scale) 184 | yield token 185 | 186 | 187 | def build_vertical_token( 188 | left: float, top: float, right: float, char_width: float, t: str, tid: int, scale: float 189 | ) -> Tuple[Token, float]: 190 | token_width = token_length(t) * char_width * scale 191 | bbox = (left, top, right, top + token_width) 192 | top += token_width 193 | token = Token(text=t, token_id=tid, bbox=bbox) 194 | return token, top 195 | 196 | 197 | def compute_token_scale(group: List[Tuple[str, int]]) -> float: 198 | token_bytes = "".join(map(extract_bytes, (t for t, _ in group))) 199 | detok_text = bytes.fromhex(token_bytes).decode("utf-8") 200 | scale = len(detok_text) / 
len(group) 201 | return scale 202 | 203 | 204 | def token_length(token: str) -> int: 205 | if is_escaped(token): 206 | return 1 207 | return len(token) 208 | 209 | 210 | def extract_bytes(token: str) -> bytes: 211 | assert len(token) == 6 212 | return token[3:5] 213 | 214 | 215 | def group_consecutive_escaped_tokens(tokens_and_ids: Iterable[Tuple[str, int]]) -> Iterable[List[Tuple[str, int]]]: 216 | get_idx = lambda idx_tok_id: idx_tok_id[0] 217 | get_token = lambda idx_tok_id: idx_tok_id[1][0] 218 | get_group_index = lambda idx_tok_id: -1 if is_escaped(get_token(idx_tok_id)) else get_idx(idx_tok_id) 219 | drop_index = lambda idx_tok_id: idx_tok_id[1] 220 | return (list(map(drop_index, g)) for _, g in groupby(enumerate(tokens_and_ids), key=get_group_index)) 221 | 222 | 223 | def is_escaped(token: str) -> bool: 224 | return token.startswith("<") and token.endswith(">") 225 | -------------------------------------------------------------------------------- /src/docllm/data/preprocessing/doc_data_to_pickle.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from typing import Any, Dict, Iterable, List, Tuple 6 | 7 | import torch 8 | 9 | from docllm.data.preprocessing.data_structure import Document 10 | 11 | Coord = float 12 | Rect = Tuple[Coord, Coord, Coord, Coord] 13 | 14 | 15 | def document_tokenization_from_file(filename: str, output_dir: str, use_page_dimensions: bool) -> None: 16 | basename = os.path.splitext(os.path.basename(filename))[0] 17 | with open(filename, "r") as file: 18 | document_tokenization = json.load(file) 19 | for i, page_block_tokens in enumerate( 20 | document_tokenization_to_pagewise_block_tokens(document_tokenization, use_page_dimensions) 21 | ): 22 | torch.save(page_block_tokens, os.path.join(output_dir, f"{basename}_{i}.pt")) 23 | 24 | 25 | def document_tokenization_from_data(doc_data: Document, output_dir: str, use_page_dimensions: bool) -> None: 26 | basename = os.path.splitext(os.path.basename(doc_data.filename))[0] 27 | document_tokenization = doc_data.model_dump() 28 | for i, page_block_tokens in enumerate( 29 | document_tokenization_to_pagewise_block_tokens(document_tokenization, use_page_dimensions) 30 | ): 31 | torch.save(page_block_tokens, os.path.join(output_dir, f"{basename}_{i}.pt")) 32 | 33 | 34 | def document_tokenization_to_pagewise_block_tokens( 35 | document_tokenization: str, use_page_dimensions: bool 36 | ) -> Iterable[List[torch.Tensor]]: 37 | for page in document_tokenization["pages"]: 38 | tokenizer = BoundingBoxTokenizer(page, use_page_dimensions) 39 | yield list(page_tokenization_to_block_tokens(page, tokenizer)) 40 | 41 | 42 | def page_tokenization_to_block_tokens( 43 | page_tokenization: Dict[str, Any], tokenizer: BoundingBoxTokenizer 44 | ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: 45 | for block in page_tokenization["blocks"]: 46 | tokens = [ 47 | (token["token_id"], tokenizer(token["bbox"])) 48 | for line in block["lines"] 49 | for word in line["words"] 50 | for token in word["tokens"] 51 | ] 52 | if len(tokens) == 0: 53 | continue 54 | text_tokens, bb_tokens = zip(*tokens) 55 | text_tokens = torch.tensor(text_tokens, dtype=torch.long) 56 | bb_tokens = torch.stack(bb_tokens) 57 | yield text_tokens, bb_tokens 58 | 59 | 60 | class BoundingBoxTokenizer: 61 | def __init__(self, page: Dict[str, Any], use_page_dimensions: bool) -> None: 62 | if use_page_dimensions: 63 | self._width = page["width"] 64 | self._height = 
page["height"] 65 | self._minx = 0.0 66 | self._miny = 0.0 67 | else: 68 | self._minx, self._miny, self._width, self._height = self._compute_min_max_dimensions(page) 69 | 70 | def __call__(self, bounding_box: Rect) -> torch.FloatTensor: 71 | return torch.tensor( 72 | [ 73 | (bounding_box[0] - self._minx) / self._width, 74 | (bounding_box[1] - self._miny) / self._height, 75 | (bounding_box[2] - self._minx) / self._width, 76 | (bounding_box[3] - self._miny) / self._height, 77 | ] 78 | ) 79 | 80 | def _compute_min_max_dimensions(self, page: Dict[str, Any]) -> Tuple[Coord, Coord, Coord, Coord]: 81 | minx, miny, maxx, maxy = float("inf"), float("inf"), float("-inf"), float("-inf") 82 | for block in page["blocks"]: 83 | for line in block["lines"]: 84 | for word in line["words"]: 85 | for token in word["tokens"]: 86 | minx = min(minx, token["bbox"][0]) 87 | miny = min(miny, token["bbox"][1]) 88 | maxx = max(maxx, token["bbox"][2]) 89 | maxy = max(maxy, token["bbox"][3]) 90 | return minx, miny, maxx - minx, maxy - miny 91 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/build_hf_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from functools import partial 4 | from typing import Dict, List, Tuple 5 | 6 | import torch 7 | from datasets import Array2D, Dataset, Features, Sequence, Value 8 | 9 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 10 | from docllm.data.pretraining.precomputation_pipeline import generate_precomputed_data 11 | 12 | 13 | def precompute_dataset(config: DocLLMPreTrainDataConfig) -> Dataset: 14 | features = Features( 15 | { 16 | "input_ids": Sequence(feature=Value(dtype="int64")), 17 | "input_coordinates": Array2D(dtype="int64", shape=(config.max_seq_length, 4)), 18 | "loss_mask": Sequence(feature=Value(dtype="int64")), 19 | "labels": Sequence(feature=Value(dtype="int64")), 20 | } 21 | ) 22 | return Dataset.from_generator( 23 | lambda: map(to_model_input, generate_precomputed_data(config=config)), 24 | # features=features # TODO: Can this be used? 25 | ) 26 | 27 | 28 | def to_model_input( 29 | data: Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor] | List[torch.Tensor], 30 | ) -> Dict[str, torch.Tensor]: 31 | return {"input_ids": data[0], "input_coordinates": data[1], "loss_mask": data[2], "labels": data[3]} 32 | 33 | 34 | def prepare_docllm_data(dirs: List[str], max_sequence_length: int): 35 | for dir in dirs: 36 | for file_name in os.listdir(dir): 37 | file_data = torch.load(file_name, map_location="cpu", weights_only=True) 38 | if len(file_data) > 0 and len(file_data[0]) > 0 and len(file_data[0][-1]) < max_sequence_length: 39 | text_tokens, bbox_tokens, loss_mask, labels = file_data 40 | yield {"input_ids": text_tokens, "attention_mask": loss_mask, "bbox": bbox_tokens, "labels": labels} 41 | else: 42 | logging.warning( 43 | f"File {file_name} is not valid. 
({len(file_data)=}, {len(file_data[0])=}, {len(file_data[0][-1])=})" 44 | ) 45 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/collate.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Tuple 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | 7 | def collate( 8 | items: List[Tuple[torch.Tensor, ...]], padding_value: float, ignore_indices: Set[int] = {} 9 | ) -> Tuple[torch.Tensor, ...]: 10 | """Turns a list of tuples of tensors into a tuple of collated tensors. 11 | The nth entry of each tuple in the list is collated together as the nth entry of the output tuple. 12 | 13 | Args: 14 | items (List[Tuple[torch.Tensor, ...]]): A list of tuples containing tensors to be collated. 15 | padding_value (float): The value to be used for all padding. Probably zero. 16 | ignore_indices (Set[int], optional): A set of indices to ignore. Defaults to {}. 17 | 18 | Returns: 19 | Tuple[torch.Tensor, ...]: A tuple of collated tensors. 20 | """ 21 | return tuple( 22 | collate_tensors(tensors, padding_value=padding_value) 23 | for i, tensors in enumerate(zip(*items)) 24 | if i not in ignore_indices 25 | ) 26 | 27 | 28 | def collate_tensors(inputs: List[torch.Tensor], padding_value: float) -> torch.Tensor: 29 | return pad_sequence(inputs, batch_first=True, padding_value=padding_value) 30 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from collections.abc import Sequence 5 | from typing import Annotated, Optional, Protocol, Tuple, runtime_checkable 6 | 7 | import torch 8 | from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, ValidationInfo, field_validator 9 | 10 | ZeroOneFloat = Annotated[float, Field(strict=True, ge=0, le=1)] 11 | CoordToken = ZeroOneFloat 12 | TokenRect = Tuple[CoordToken, CoordToken, CoordToken, CoordToken] 13 | Interval = Tuple[PositiveInt, PositiveInt] | Tuple[ZeroOneFloat, ZeroOneFloat] 14 | NumMaskedBlocksType = PositiveInt | ZeroOneFloat | Interval 15 | 16 | 17 | class DocLLMPreTrainDataConfig(BaseModel): 18 | batch_size: PositiveInt 19 | drop_last_batch_if_not_full: bool = False 20 | shuffle: bool = True 21 | shuffle_buffer_size: PositiveInt = 10000 22 | use_sharding_filter: bool = True 23 | use_packing: bool = False 24 | max_seq_length: PositiveInt 25 | 26 | # This can encode either an absolute or relative count. 27 | # If an interval is given, the number of masked blocks 28 | # will be sampled uniformly from the interval each time. 29 | num_masked_blocks: NumMaskedBlocksType 30 | # Num blocks times the value will be ceiled to get the 31 | # maximum number of blocks that can be masked. 32 | max_percentage_masked_blocks: ZeroOneFloat = 0.15 33 | # Minimum of number of blocks that should be available 34 | # in the document. This is used to prevent masking 35 | # too many blocks in a document that is too small. 
36 | min_num_blocks_available: NonNegativeInt = 6 37 | 38 | mask_text_token: int = 0 39 | mask_bbox_token: TokenRect = (0.0, 0.0, 0.0, 0.0) 40 | block_start_text_token: int 41 | block_end_text_token: int 42 | bos_text_token: Optional[int] 43 | bos_bbox_token: Optional[TokenRect] = None 44 | padding_value: float = 0.0 45 | 46 | directory: str | Sequence[str] 47 | 48 | @property 49 | def num_masked_blocks_callable(self) -> NumMaskedBlocks: 50 | return _build_num_masked_blocks_callable(self.num_masked_blocks) 51 | 52 | @field_validator("num_masked_blocks", mode="after") 53 | def check_interval_is_not_empty(cls, v: NumMaskedBlocksType) -> NumMaskedBlocksType: 54 | if isinstance(v, tuple) and v[0] >= v[1]: 55 | raise ValueError("num_masked_blocks intervals cannot be empty or negative") 56 | return v 57 | 58 | @field_validator("bos_bbox_token", mode="after") 59 | def set_bos_bbox_token_if_not_set( 60 | cls, v: Optional[TokenRect], info: ValidationInfo # , values: Dict[str, Any] 61 | ) -> TokenRect: 62 | if v is None: 63 | return info.data["mask_bbox_token"] 64 | return v 65 | 66 | 67 | @runtime_checkable 68 | class NumMaskedBlocks(Protocol): 69 | def __call__(self, num_blocks: int) -> int: ... 70 | 71 | 72 | def _build_num_masked_blocks_callable(num_masked_blocks: NumMaskedBlocksType) -> NumMaskedBlocks: 73 | if isinstance(num_masked_blocks, int): 74 | return lambda _: num_masked_blocks 75 | elif isinstance(num_masked_blocks, float): 76 | return lambda num_blocks: math.ceil(num_masked_blocks * num_blocks) 77 | elif isinstance(num_masked_blocks, Tuple) and isinstance(num_masked_blocks[0], int): 78 | 79 | def num_masked_blocks_callable(num_blocks: int) -> int: 80 | return torch.randint(num_masked_blocks[0], num_masked_blocks[1], (1,)).item() 81 | 82 | return num_masked_blocks_callable 83 | elif isinstance(num_masked_blocks, Tuple) and isinstance(num_masked_blocks[0], float): 84 | ival_length = num_masked_blocks[1] - num_masked_blocks[0] 85 | 86 | def num_masked_blocks_callable(num_blocks: int) -> int: 87 | return math.ceil((num_masked_blocks[0] + torch.rand(1).item() * ival_length) * num_blocks) 88 | 89 | return num_masked_blocks_callable 90 | else: 91 | raise ValueError(f"Invalid num_masked_blocks type: {num_masked_blocks}") 92 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/data_packing_pipe.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Tuple 2 | 3 | import torch 4 | from torch.utils.data import IterDataPipe, functional_datapipe 5 | 6 | 7 | @functional_datapipe("pack_data") 8 | class DataPackingPipe(IterDataPipe): 9 | def __init__(self, source_datapipe: IterDataPipe, max_sequence_length: int) -> None: 10 | self._source_datapipe = source_datapipe 11 | self._max_sequence_length = max_sequence_length 12 | 13 | def __iter__(self) -> Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 14 | elements = [] 15 | total_size = 0 16 | 17 | for item in self._source_datapipe: 18 | item_size = item[0].size(0) 19 | if total_size + item_size > self._max_sequence_length: 20 | inputs, bboxes, mask, labels = zip(*elements) 21 | yield torch.cat(inputs), torch.cat(bboxes), torch.cat(mask), torch.cat(labels) 22 | elements = [] 23 | total_size = 0 24 | elements.append(item) 25 | total_size += item_size 26 | 27 | if elements: 28 | inputs, bboxes, mask, labels = zip(*elements) 29 | yield torch.cat(inputs), torch.cat(bboxes), torch.cat(mask), 
torch.cat(labels) 30 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Iterable, List, Tuple 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 9 | from docllm.data.pretraining.masker import DocLLMPretrainingMasker 10 | 11 | 12 | class DocLLMPretrainDataset(Dataset): 13 | def __init__(self, config: DocLLMPreTrainDataConfig, skip_files: Iterable[str] = []): 14 | self._masker = DocLLMPretrainingMasker(config) 15 | self._files = sorted( 16 | set(glob(os.path.join(config.directory, "**/*.pt"), recursive=True)).difference(skip_files) 17 | ) 18 | 19 | def __len__(self): 20 | return len(self._files) 21 | 22 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]: 23 | file_data = torch.load(self._files[idx], map_location="cpu") 24 | if len(file_data) == 0: 25 | raise ValueError(f"File is empty.") 26 | text_tokens, bbox_tokens = tuple(map(list, zip(*file_data))) 27 | self._check_data(text_tokens, bbox_tokens) 28 | return self._masker(text_tokens, bbox_tokens) 29 | 30 | def _check_data(self, text_tokens: List[torch.Tensor], bbox_tokens: List[torch.Tensor]) -> None: 31 | if len(text_tokens) != len(bbox_tokens): 32 | raise ValueError( 33 | f"Length of text tokens ({len(text_tokens)}) and bbox tokens ({len(bbox_tokens)}) must match." 34 | ) 35 | if not all(bbox_tokens[i].size() == (text_tokens[i].size()[0], 4) for i in range(len(text_tokens))): 36 | raise ValueError("Unexpected bounding box shape.") 37 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/masker.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import List, Tuple 4 | 5 | import torch 6 | 7 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 8 | 9 | 10 | class DocLLMPretrainingMasker: 11 | def __init__(self, config: DocLLMPreTrainDataConfig) -> None: 12 | self._config = config.model_copy() 13 | self._get_num_masks = self._config.num_masked_blocks_callable 14 | self._min_num_blocks_available = self._config.min_num_blocks_available 15 | self._mask_text_token = torch.tensor([self._config.mask_text_token]) 16 | self._mask_bbox_token = torch.tensor([self._config.mask_bbox_token]) 17 | self._block_start_text_token = torch.tensor([self._config.block_start_text_token]) 18 | self._block_end_text_token = torch.tensor([self._config.block_end_text_token]) 19 | self._bos_text_token = [torch.tensor([self._config.bos_text_token])] if self._config.bos_text_token else [] 20 | self._bos_bbox_token = [torch.tensor([self._config.bos_bbox_token])] if self._config.bos_text_token else [] 21 | 22 | def __call__( 23 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor] 24 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]: 25 | num_blocks = len(input_tensors) 26 | if num_blocks < self._min_num_blocks_available: 27 | raise ValueError( 28 | f"Number of blocks ({num_blocks}) is less than the minimum " 29 | f"number of blocks required ({self._min_num_blocks_available})." 
30 | ) 31 | num_masks = self._get_valid_num_masks(num_blocks) 32 | mask_indices = self._get_mask_indices(num_blocks, num_masks) 33 | return self._create_masked_input(input_tensors, bbox_tensors, mask_indices) 34 | 35 | def _get_valid_num_masks(self, num_blocks): 36 | num_masks = self._get_num_masks(num_blocks) 37 | if num_masks > (max_num_masks := self._compute_max_num_masks(num_blocks)): 38 | logging.warning( 39 | f"Number of masks ({num_masks}) cannot exceed maximal allowed number of " 40 | f"blocks (ceil({num_blocks} * {self._config.max_percentage_masked_blocks}))." 41 | ) 42 | num_masks = max_num_masks 43 | if num_masks <= 0: 44 | raise ValueError("Number of masks cannot be zero.") 45 | return num_masks 46 | 47 | def _compute_max_num_masks(self, num_blocks): 48 | return min(math.ceil(num_blocks * self._config.max_percentage_masked_blocks), num_blocks) 49 | 50 | def _get_mask_indices(self, num_blocks: int, num_masks: int) -> List[int]: 51 | # FIXME: Should we prevent adjacent blocks from being masked? 52 | weights = torch.ones(num_blocks) 53 | mask_indices = torch.multinomial(weights, num_masks, replacement=False) 54 | return sorted(map(torch.Tensor.item, mask_indices)) 55 | 56 | def _create_masked_input( 57 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor], mask_indices: List[int] 58 | ) -> Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]: 59 | input_tensors, bbox_tensors, target_tokens, label_tensors = self._extract_masked_tokens( 60 | input_tensors, bbox_tensors, mask_indices 61 | ) 62 | num_target_tokens = sum(tensor.shape[0] for tensor in target_tokens) 63 | text_inputs = torch.cat(self._bos_text_token + input_tensors + target_tokens) 64 | target_bbox_tokens = [self._mask_bbox_token] * num_target_tokens 65 | spatial_inputs = torch.cat(self._bos_bbox_token + bbox_tensors + target_bbox_tokens) 66 | loss_mask = torch.zeros_like(text_inputs, dtype=torch.bool) 67 | loss_mask[-num_target_tokens:] = 1 68 | label_tokens = torch.cat(input_tensors + label_tensors) 69 | return self._truncate_to_max_seq_len(text_inputs, spatial_inputs, loss_mask, label_tokens, num_target_tokens) 70 | 71 | def _extract_masked_tokens( 72 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor], mask_indices: List[int] 73 | ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor], List[torch.BoolTensor], List[torch.LongTensor]]: 74 | target_tokens = [] 75 | label_tensors = [self._block_end_text_token] 76 | for i in mask_indices: 77 | target_tokens.append(self._block_start_text_token) 78 | target_tokens.append(input_tensors[i]) 79 | label_tensors.append(input_tensors[i]) 80 | label_tensors.append(self._block_end_text_token) 81 | input_tensors[i] = self._mask_text_token 82 | bbox_tensors[i] = self._mask_bbox_token 83 | return input_tensors, bbox_tensors, target_tokens, label_tensors 84 | 85 | def _truncate_to_max_seq_len( 86 | self, 87 | text_inputs: torch.LongTensor, 88 | spatial_inputs: torch.FloatTensor, 89 | loss_mask: torch.BoolTensor, 90 | label_tokens: torch.LongTensor, 91 | num_target_tokens: int, 92 | ) -> Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]: 93 | if text_inputs.size(0) <= self._config.max_seq_length: 94 | return text_inputs, spatial_inputs, loss_mask, label_tokens 95 | if text_inputs.size(0) - num_target_tokens + 1 >= self._config.max_seq_length: 96 | raise ValueError( 97 | f"Number of tokens ({text_inputs.size(0)}) exceeds maximal sequence " 98 | f"length 
({self._config.max_seq_length}) with {num_target_tokens=}. " 99 | "Nothing would be learned. Skipping..." 100 | ) 101 | return ( 102 | text_inputs[: self._config.max_seq_length], 103 | spatial_inputs[: self._config.max_seq_length], 104 | loss_mask[: self._config.max_seq_length], 105 | label_tokens[: self._config.max_seq_length], 106 | ) 107 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/pipeline.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from torch.utils.data import DataChunk, IterDataPipe 6 | from torch.utils.data.datapipes.iter import FileLister 7 | 8 | from docllm.data.pretraining.collate import collate 9 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 10 | 11 | 12 | def build_docllm_datapipeline(config: DocLLMPreTrainDataConfig) -> IterDataPipe: 13 | datapipe = FileLister(config.directory, masks="*.pt", recursive=True) 14 | if config.shuffle: 15 | datapipe = datapipe.shuffle(buffer_size=config.shuffle_buffer_size) 16 | if config.use_sharding_filter: 17 | datapipe = datapipe.sharding_filter() 18 | datapipe = datapipe.load_docllm_tensor_data() 19 | datapipe = datapipe.build_docllm_train_data(config=config) 20 | if config.use_packing: 21 | datapipe = datapipe.pack_data(max_sequence_length=config.max_seq_length) 22 | datapipe = datapipe.batch( 23 | config.batch_size, 24 | drop_last=config.drop_last_batch_if_not_full, 25 | wrapper_class=partial(BatchWithPaddingWrapperClass, padding_value=config.padding_value), 26 | ) 27 | return datapipe 28 | 29 | 30 | class BatchWithPaddingWrapperClass(DataChunk): 31 | def __init__( 32 | self, 33 | items: List[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]], 34 | padding_value: float, 35 | ): 36 | super().__init__(collate(items, padding_value)) 37 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/precomputation_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Iterable, List, Tuple 4 | 5 | import torch 6 | from torch.utils.data import IterDataPipe, functional_datapipe 7 | from torch.utils.data.datapipes.iter import FileLister 8 | 9 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 10 | from docllm.data.pretraining.pipeline import build_docllm_datapipeline 11 | 12 | 13 | def precompute_data(config: DocLLMPreTrainDataConfig, output_dir: str): 14 | for i, data in enumerate(generate_precomputed_data(config)): 15 | torch.save(data, os.path.join(output_dir, f"eval_{i}.pt")) 16 | 17 | 18 | def generate_precomputed_data(config: DocLLMPreTrainDataConfig) -> Iterable[List[torch.Tensor]]: 19 | config = config.model_copy() 20 | config.batch_size = 1 21 | data_gen_pipeline = build_docllm_datapipeline(config) 22 | for i, dat in enumerate(data_gen_pipeline): 23 | if len(dat[0]) == 0: 24 | logging.warning(f"Skipping empty data {i}...") 25 | continue 26 | yield list(debatch(dat)) 27 | 28 | 29 | def debatch( 30 | one_element_batch_sample: ( 31 | Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor] | List[torch.Tensor] 32 | ), 33 | ) -> Iterable[torch.Tensor]: 34 | for s in one_element_batch_sample: 35 | assert s.size(0) == 1 36 | yield s[0] 37 | 38 | 39 | def build_precomputed_data_pipeline(input_dir: str) -> 
IterDataPipe: 40 | datapipe = FileLister(input_dir, masks="*.pt", recursive=True, non_deterministic=False) 41 | datapipe = datapipe.load_docllm_precomputed_data() 42 | return datapipe 43 | 44 | 45 | @functional_datapipe("load_docllm_precomputed_data") 46 | class PrecomputedDataLoaderPipe(IterDataPipe): 47 | def __init__(self, source_datapipe: IterDataPipe) -> None: 48 | self._source_datapipe = source_datapipe 49 | 50 | def __iter__(self) -> Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 51 | for file_name in self._source_datapipe: 52 | file_data = torch.load(file_name, map_location="cpu", weights_only=True) 53 | if len(file_data) > 0 and len(file_data[0]) > 0: 54 | text_tokens, bbox_tokens, loss_mask, labels = file_data 55 | yield text_tokens, bbox_tokens, loss_mask, labels 56 | else: 57 | logging.warning(f"File {file_name} is empty.") 58 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/tensor_data_loader_pipe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Iterable, List, Tuple 3 | 4 | import torch 5 | from torch.utils.data import IterDataPipe, functional_datapipe 6 | 7 | 8 | @functional_datapipe("load_docllm_tensor_data") 9 | class TensorDataLoaderPipe(IterDataPipe): 10 | def __init__(self, source_datapipe: IterDataPipe) -> None: 11 | self._source_datapipe = source_datapipe 12 | 13 | def __iter__(self) -> Iterable[Tuple[List[torch.Tensor], List[torch.Tensor]]]: 14 | for file_name in self._source_datapipe: 15 | try: 16 | yield self._try_parse_file(file_name) 17 | except Exception as e: 18 | logging.warning(f"Error while parsing file {file_name}: {e}") 19 | 20 | def _try_parse_file(self, file_name: str) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: 21 | file_data = torch.load(file_name, map_location="cpu", weights_only=True) 22 | if len(file_data) == 0: 23 | raise ValueError(f"File is empty.") 24 | text_tokens, bbox_tokens = tuple(map(list, zip(*file_data))) 25 | self._check_data(text_tokens, bbox_tokens) 26 | return text_tokens, bbox_tokens 27 | 28 | def _check_data(self, text_tokens: List[torch.Tensor], bbox_tokens: List[torch.Tensor]) -> None: 29 | if len(text_tokens) != len(bbox_tokens): 30 | raise ValueError( 31 | f"Length of text tokens ({len(text_tokens)}) and bbox tokens ({len(bbox_tokens)}) must match." 
32 | ) 33 | if not all(bbox_tokens[i].size() == (text_tokens[i].size()[0], 4) for i in range(len(text_tokens))): 34 | raise ValueError("Unexpected bounding box shape.") 35 | -------------------------------------------------------------------------------- /src/docllm/data/pretraining/traindata_pipe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Iterable, List, Tuple 4 | 5 | import torch 6 | from torch.utils.data import IterDataPipe, functional_datapipe 7 | 8 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 9 | from docllm.data.pretraining.masker import DocLLMPretrainingMasker 10 | 11 | 12 | @functional_datapipe("build_docllm_train_data") 13 | class DocLLMTrainDataPipe(IterDataPipe): 14 | def __init__(self, source_datapipe: IterDataPipe, config: DocLLMPreTrainDataConfig) -> None: 15 | self._masker = DocLLMPretrainingMasker(config) 16 | self._source_datapipe = source_datapipe 17 | 18 | def __iter__(self) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]]: 19 | for input_tensors, bbox_tensors in self._source_datapipe: 20 | yield from self._try_build_inputs(input_tensors, bbox_tensors) 21 | 22 | def _try_build_inputs( 23 | self, input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor] 24 | ) -> Iterable[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 25 | try: 26 | yield self._masker(input_tensors, bbox_tensors) 27 | except ValueError as e: 28 | logging.warning(e) 29 | -------------------------------------------------------------------------------- /src/docllm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/src/docllm/modules/__init__.py -------------------------------------------------------------------------------- /src/docllm/modules/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import DocLLMAttention 2 | from .config import DocLLMLlamaConfig 3 | from .decoder_layer import DocLLMLlamaDecoderLayer 4 | from .llama_docllm import DocLLMModelOutputWithPast, LlamaDocLLM 5 | from .pretrained_model import DocLLMLlamaPreTrainedModel 6 | -------------------------------------------------------------------------------- /src/docllm/modules/llama/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | from transformers.cache_utils import Cache 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 9 | from transformers.utils import logging 10 | 11 | from docllm.modules.llama.config import DocLLMLlamaConfig, PositionalEmbeddingMode 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | class DocLLMAttention(nn.Module): 17 | def __init__(self, config: DocLLMLlamaConfig, layer_idx: Optional[int] = None): 18 | if config.pretraining_tp > 1: 19 | raise NotImplementedError() 20 | super().__init__() 21 | self.config = config 22 | self.attention_dropout = config.attention_dropout 23 | self.hidden_size = config.hidden_size 24 | self.num_heads = config.num_attention_heads 25 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) 26 | 
self.num_key_value_heads = config.num_key_value_heads 27 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 28 | self.layer_idx = layer_idx 29 | self.scaling = self.head_dim**-0.5 30 | 31 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 32 | self.k_proj = nn.Linear( 33 | self.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias 34 | ) 35 | self.v_proj = nn.Linear( 36 | self.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias 37 | ) 38 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 39 | 40 | self.spatial_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 41 | self.spatial_k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 42 | self._lambda_ts = config.lambda_ts 43 | self._lambda_st = config.lambda_st 44 | self._lambda_ss = config.lambda_ss 45 | self._positional_embedding_mode = config.positional_embedding_mode 46 | if config._attn_implementation == "sdpa": 47 | self._attn_fct = self._sdpa_attention 48 | elif config._attn_implementation == "eager": 49 | self._attn_fct = self._eager_attention 50 | else: 51 | raise ValueError(f"Unknown attention implementation: {config._attn_implementation}") 52 | 53 | def set_freeze_llama_layers(self, freeze: bool): 54 | self.q_proj.weight.requires_grad_(not freeze) 55 | self.k_proj.weight.requires_grad_(not freeze) 56 | self.v_proj.weight.requires_grad_(not freeze) 57 | self.o_proj.weight.requires_grad_(not freeze) 58 | 59 | @torch.no_grad() 60 | def init_additional_weights(self, init_func: Callable[[torch.Tensor], None] = torch.nn.init.xavier_normal_): 61 | init_func(self.spatial_q_proj.weight) 62 | init_func(self.spatial_k_proj.weight) 63 | 64 | def forward( 65 | self, 66 | hidden_states: Tensor, 67 | bounding_box_embeddings: Tensor, 68 | attention_mask: Optional[torch.Tensor] = None, 69 | past_key_value: Cache | None = None, 70 | spatial_past_key_value: Cache | None = None, 71 | output_attentions: bool = False, 72 | cache_position: Optional[torch.LongTensor] = None, 73 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 74 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[Cache]]: 75 | query_states, key_states, value_states = self._compute_q_k_v( 76 | hidden_states, 77 | bounding_box_embeddings, 78 | past_key_value, 79 | spatial_past_key_value, 80 | cache_position, 81 | position_embeddings, 82 | ) 83 | 84 | attn_weights, attn_output = self._attn_fct( 85 | query_states, 86 | key_states, 87 | value_states, 88 | attention_mask, 89 | output_attentions, 90 | ) 91 | attn_output = attn_output.transpose(1, 2) # .contiguous() 92 | bsz, q_len, _ = hidden_states.size() 93 | if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): 94 | raise ValueError( 95 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 96 | f" {attn_output.size()}" 97 | ) 98 | attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1).contiguous() 99 | attn_output = self.o_proj(attn_output) 100 | 101 | return attn_output, attn_weights, past_key_value, spatial_past_key_value 102 | 103 | def _compute_q_k_v( 104 | self, 105 | hidden_states: torch.Tensor, 106 | bounding_box_embeddings: torch.Tensor, 107 | past_key_value: Optional[Cache], 108 | spatial_past_key_value: Optional[Cache], 109 | cache_position: 
Optional[torch.LongTensor], 110 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], 111 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 112 | query_states, key_states, value_states = self._compute_q_k_v_for_projections( 113 | hidden_states=hidden_states, 114 | q_proj=self.q_proj, 115 | k_proj=self.k_proj, 116 | v_proj=self.v_proj, 117 | past_key_value=past_key_value, 118 | position_embeddings=position_embeddings, 119 | cache_position=cache_position, 120 | ) 121 | 122 | spatial_query_states, spatial_key_states, _ = self._compute_q_k_v_for_projections( 123 | hidden_states=bounding_box_embeddings, 124 | q_proj=self.spatial_q_proj, 125 | k_proj=self.spatial_k_proj, 126 | v_proj=None, 127 | past_key_value=spatial_past_key_value, 128 | position_embeddings=position_embeddings, 129 | cache_position=cache_position, 130 | ) 131 | 132 | query_states = query_states + spatial_query_states 133 | key_states = key_states + spatial_key_states 134 | 135 | key_states = repeat_kv(key_states, self.num_key_value_groups) 136 | value_states = repeat_kv(value_states, self.num_key_value_groups) 137 | return query_states, key_states, value_states 138 | 139 | def _compute_q_k_v_for_projections( 140 | self, 141 | hidden_states: torch.Tensor, 142 | q_proj: nn.Linear, 143 | k_proj: nn.Linear, 144 | v_proj: Optional[nn.Linear], 145 | past_key_value: Optional[Cache], 146 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], 147 | cache_position: Optional[torch.LongTensor] = None, 148 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 149 | bsz, q_len, _ = hidden_states.size() 150 | 151 | query_states = q_proj(hidden_states) 152 | key_states = k_proj(hidden_states) 153 | if v_proj is not None: 154 | value_states = v_proj(hidden_states) 155 | else: 156 | value_states = None 157 | 158 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 159 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 160 | if v_proj is not None: 161 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 162 | 163 | if v_proj is not None or self._positional_embedding_mode != PositionalEmbeddingMode.NONE: 164 | cos, sin = position_embeddings 165 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 166 | else: 167 | cos = sin = None 168 | 169 | if past_key_value is not None: 170 | cache_kwargs = ( 171 | {"cache_position": cache_position} 172 | if cos is not None and sin is not None 173 | else {"sin": sin, "cos": cos, "cache_position": cache_position} 174 | ) # Specific to RoPE models 175 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 176 | 177 | return query_states, key_states, value_states 178 | 179 | def _eager_attention( 180 | self, 181 | query_states: torch.Tensor, 182 | key_states: torch.Tensor, 183 | value_states: torch.Tensor, 184 | attention_mask: Optional[torch.Tensor], 185 | output_attentions: bool, 186 | ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: 187 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling 188 | 189 | if attention_mask is not None: 190 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 191 | attn_weights = attn_weights + causal_mask 192 | 193 | # upcast attention to fp32 194 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 195 | attn_weights = 
nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 196 | attn_output = torch.matmul(attn_weights, value_states) 197 | 198 | if not output_attentions: 199 | attn_weights = None 200 | 201 | return attn_weights, attn_output 202 | 203 | def _sdpa_attention( 204 | self, 205 | query_states: torch.Tensor, 206 | key_states: torch.Tensor, 207 | value_states: torch.Tensor, 208 | attention_mask: Optional[torch.Tensor], 209 | output_attentions: bool, 210 | ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: 211 | if output_attentions: 212 | logger.warning_once( 213 | "Setting `output_attentions=True` is not supported for `sdpa` attention. Ignoring the argument." 214 | ) 215 | if attention_mask is not None: 216 | attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] 217 | attn_output = torch.nn.functional.scaled_dot_product_attention( 218 | query=query_states.contiguous(), 219 | key=key_states.contiguous(), 220 | value=value_states.contiguous(), 221 | attn_mask=attention_mask, 222 | dropout_p=self.attention_dropout, 223 | scale=self.scaling, 224 | is_causal=True, 225 | ) 226 | return None, attn_output 227 | -------------------------------------------------------------------------------- /src/docllm/modules/llama/config.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | 3 | from transformers import LlamaConfig 4 | 5 | from docllm.modules.spatial_embedder import SpatialEmbeddingType 6 | 7 | 8 | class PositionalEmbeddingMode(StrEnum): 9 | NONE = "none" 10 | TEXT_ONLY = "text_only" 11 | TEXT_AND_SPATIAL = "text_and_spatial" 12 | 13 | 14 | class DocLLMLlamaConfig(LlamaConfig): 15 | 16 | def __init__( 17 | self, 18 | vocab_size=32000, 19 | additional_training_vocab_size=0, 20 | hidden_size=4096, 21 | intermediate_size=11008, 22 | num_hidden_layers=32, 23 | num_attention_heads=32, 24 | num_key_value_heads=None, 25 | hidden_act="silu", 26 | max_position_embeddings=2048, 27 | initializer_range=0.02, 28 | rms_norm_eps=1e-6, 29 | use_cache=False, 30 | pad_token_id=None, 31 | bos_token_id=1, 32 | eos_token_id=2, 33 | pretraining_tp=1, 34 | tie_word_embeddings=False, 35 | rope_theta=10000.0, 36 | rope_scaling=None, 37 | attention_bias=False, 38 | attention_dropout=0.0, 39 | lambda_ts: float = 1.0, 40 | lambda_st: float = 1.0, 41 | lambda_ss: float = 1.0, 42 | positional_embedding_mode: PositionalEmbeddingMode = PositionalEmbeddingMode.TEXT_AND_SPATIAL, 43 | embedding_type: SpatialEmbeddingType = SpatialEmbeddingType.PROJECTION, 44 | embed_include_width_height: bool = False, 45 | embed_max_coord: int = 1000, 46 | _attn_implementation: str = "eager", 47 | **kwargs, 48 | ): 49 | super().__init__( 50 | vocab_size, 51 | hidden_size, 52 | intermediate_size, 53 | num_hidden_layers, 54 | num_attention_heads, 55 | num_key_value_heads, 56 | hidden_act, 57 | max_position_embeddings, 58 | initializer_range, 59 | rms_norm_eps, 60 | use_cache, 61 | pad_token_id, 62 | bos_token_id, 63 | eos_token_id, 64 | pretraining_tp, 65 | tie_word_embeddings, 66 | rope_theta, 67 | rope_scaling, 68 | attention_bias, 69 | attention_dropout, 70 | _attn_implementation=_attn_implementation, 71 | **kwargs, 72 | ) 73 | self.additional_training_vocab_size = additional_training_vocab_size 74 | self.lambda_ts = lambda_ts 75 | self.lambda_st = lambda_st 76 | self.lambda_ss = lambda_ss 77 | self.positional_embedding_mode = positional_embedding_mode 78 | self.embedding_type = embedding_type 79 | self.embed_include_width_height = 
embed_include_width_height 80 | self.embed_max_coord = embed_max_coord 81 | 82 | if self.pretraining_tp > 1: 83 | raise NotImplementedError() 84 | -------------------------------------------------------------------------------- /src/docllm/modules/llama/decoder_layer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Callable, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | from transformers.cache_utils import Cache 8 | from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm 9 | 10 | from docllm.modules.llama.attention import DocLLMAttention 11 | from docllm.modules.llama.config import DocLLMLlamaConfig 12 | 13 | 14 | class DocLLMLlamaDecoderLayer(nn.Module): 15 | def __init__(self, config: DocLLMLlamaConfig, layer_idx: int): 16 | super().__init__() 17 | self.hidden_size = config.hidden_size 18 | 19 | self.self_attn = DocLLMAttention(config=config, layer_idx=layer_idx) 20 | 21 | self.mlp = LlamaMLP(config) 22 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 23 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 24 | 25 | def set_freeze_llama_layers(self, freeze: bool): 26 | self.self_attn.set_freeze_llama_layers(freeze) 27 | self.mlp.requires_grad_(not freeze) 28 | self.input_layernorm.requires_grad_(not freeze) 29 | self.post_attention_layernorm.requires_grad_(not freeze) 30 | 31 | @torch.no_grad() 32 | def init_additional_weights(self, init_func: Callable[[torch.Tensor], None] = torch.nn.init.xavier_normal_): 33 | self.self_attn.init_additional_weights(init_func) 34 | 35 | def forward( 36 | self, 37 | hidden_states: torch.Tensor, 38 | bounding_box_embeddings: Tensor, 39 | attention_mask: Optional[torch.Tensor] = None, 40 | past_key_value: Optional[Cache] = None, 41 | spatial_past_key_value: Optional[Cache] = None, 42 | output_attentions: Optional[bool] = False, 43 | cache_position: Optional[torch.LongTensor] = None, 44 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 45 | **kwargs, 46 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Cache], Optional[Cache]]: 47 | """ 48 | Args: 49 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 50 | bounding_box_embeddings (`torch.FloatTensor`): bounding box embeddings of shape `(batch, seq_len, embed_dim)` 51 | position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)` 52 | attention_mask (`torch.FloatTensor`, *optional*): 53 | attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, 54 | query_sequence_length, key_sequence_length)` if default attention is used. 55 | output_attentions (`bool`, *optional*): 56 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 57 | returned tensors for more detail. 58 | past_key_value (`Cache`, *optional*): cached past key and value projection states 59 | spatial_past_key_value (`Cache`, *optional*): cached past key and value projection states 60 | """ 61 | if "padding_mask" in kwargs: 62 | warnings.warn( 63 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 64 | ) 65 | 66 | residual = hidden_states 67 | 68 | hidden_states = self.input_layernorm(hidden_states) 69 | 70 | # Self Attention 71 | ( 72 | hidden_states, 73 | self_attn_weights, 74 | present_key_value, 75 | spatial_present_key_value, 76 | ) = self.self_attn( 77 | hidden_states=hidden_states, 78 | bounding_box_embeddings=bounding_box_embeddings, 79 | attention_mask=attention_mask, 80 | past_key_value=past_key_value, 81 | spatial_past_key_value=spatial_past_key_value, 82 | output_attentions=output_attentions, 83 | cache_position=cache_position, 84 | position_embeddings=position_embeddings, 85 | **kwargs, 86 | ) 87 | hidden_states = residual + hidden_states 88 | 89 | # Fully Connected 90 | residual = hidden_states 91 | hidden_states = self.post_attention_layernorm(hidden_states) 92 | hidden_states = self.mlp(hidden_states) 93 | hidden_states = residual + hidden_states 94 | 95 | if not output_attentions: 96 | self_attn_weights = None 97 | 98 | return hidden_states, self_attn_weights, present_key_value, spatial_present_key_value 99 | -------------------------------------------------------------------------------- /src/docllm/modules/llama/pretrained_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import PreTrainedModel 3 | from transformers.utils import add_start_docstrings 4 | 5 | from docllm.modules.llama.config import DocLLMLlamaConfig 6 | 7 | DOCLLM_LLAMA_START_DOCSTRING = r""" 8 | ... 9 | 10 | Parameters: 11 | config ([`DocLLMLlamaConfig`]): 12 | Model configuration class with all the parameters of the model. Initializing with a config file does not 13 | load the weights associated with the model, only the configuration. Check out the 14 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
15 | """ 16 | 17 | 18 | @add_start_docstrings( 19 | "The bare LLaMA DocLLM Model outputting raw hidden-states without any specific head on top.", 20 | DOCLLM_LLAMA_START_DOCSTRING, 21 | ) 22 | class DocLLMLlamaPreTrainedModel(PreTrainedModel): 23 | config_class = DocLLMLlamaConfig 24 | base_model_prefix = "model" 25 | supports_gradient_checkpointing = True 26 | _no_split_modules = ["DocLLMLlamaDecoderLayer"] 27 | _skip_keys_device_placement = ["past_key_values", "spatial_past_key_value"] 28 | _supports_flash_attn_2 = False 29 | _supports_sdpa = False 30 | _supports_cache_class = True 31 | 32 | def _init_weights(self, module): 33 | std = self.config.initializer_range 34 | if isinstance(module, nn.Linear): 35 | module.weight.data.normal_(mean=0.0, std=std) 36 | if module.bias is not None: 37 | module.bias.data.zero_() 38 | elif isinstance(module, nn.Embedding): 39 | module.weight.data.normal_(mean=0.0, std=std) 40 | if module.padding_idx is not None: 41 | module.weight.data[module.padding_idx].zero_() 42 | -------------------------------------------------------------------------------- /src/docllm/modules/part_freezable_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PartFreezableEmbedding(nn.Embedding): 8 | def __init__( 9 | self, 10 | num_embeddings: int, 11 | embedding_dim: int, 12 | padding_idx: int | None = None, 13 | max_norm: float | None = None, 14 | norm_type: float = 2, 15 | scale_grad_by_freq: bool = False, 16 | sparse: bool = False, 17 | _weight: torch.Tensor | None = None, 18 | _freeze: bool = False, 19 | device=None, 20 | dtype=None, 21 | num_additional_tokens: int = 0, 22 | ) -> None: 23 | super().__init__( 24 | num_embeddings=num_embeddings, 25 | embedding_dim=embedding_dim, 26 | padding_idx=padding_idx, 27 | max_norm=max_norm, 28 | norm_type=norm_type, 29 | scale_grad_by_freq=scale_grad_by_freq, 30 | sparse=sparse, 31 | _weight=_weight, 32 | _freeze=_freeze, 33 | device=device, 34 | dtype=dtype, 35 | ) 36 | self._num_additional_tokens = num_additional_tokens 37 | if self._num_additional_tokens > 0: 38 | self.additional_embeddings = nn.Embedding( 39 | num_embeddings=self._num_additional_tokens, 40 | embedding_dim=embedding_dim, 41 | padding_idx=padding_idx, 42 | max_norm=max_norm, 43 | norm_type=norm_type, 44 | scale_grad_by_freq=scale_grad_by_freq, 45 | sparse=sparse, 46 | _weight=_weight, 47 | _freeze=_freeze, 48 | device=device, 49 | dtype=dtype, 50 | ) 51 | 52 | def set_freeze_original_embeddings(self, freeze: bool): 53 | self.requires_grad_(not freeze) 54 | if self._num_additional_tokens > 0: 55 | self.additional_embeddings.requires_grad_(True) 56 | 57 | @torch.no_grad() 58 | def init_additional_embeddings(self, init_func: Callable[[torch.Tensor], None] = torch.nn.init.xavier_normal_): 59 | if self._num_additional_tokens > 0: 60 | init_func(self.additional_embeddings.weight) 61 | 62 | @torch.no_grad() 63 | def fuse_additional_embeddings(self): 64 | if self._num_additional_tokens > 0: 65 | self.weight = nn.Parameter(torch.cat([self.weight, self.additional_embeddings.weight], dim=0)) 66 | self.num_embeddings += self._num_additional_tokens 67 | self._num_additional_tokens = 0 68 | del self.additional_embeddings 69 | 70 | def forward(self, input: torch.Tensor) -> torch.Tensor: 71 | input = input.clone() 72 | add_token_mask = input >= self.num_embeddings 73 | input[add_token_mask] -= self.num_embeddings 74 | embeddings = 
super().forward(input) 75 | if self._num_additional_tokens > 0: 76 | add_embeddings = self.additional_embeddings(input[add_token_mask]) 77 | embeddings.view(-1, self.embedding_dim)[add_token_mask.view(-1)] = add_embeddings 78 | return embeddings 79 | -------------------------------------------------------------------------------- /src/docllm/modules/part_freezable_linear.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PartFreezableLinear(nn.Linear): 8 | """ 9 | A linear layer with the ability to freeze the original outputs and add additional outputs. 10 | 11 | Args: 12 | in_features (int): Size of each input sample. 13 | out_features (int): Size of each output sample. 14 | bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True. 15 | num_additional_outputs (int, optional): Number of additional outputs to be added. Default is 0. 16 | 17 | Methods: 18 | set_freeze_original_outputs(freeze: bool): Sets whether to freeze the original outputs. 19 | forward(input: torch.Tensor) -> torch.Tensor: Performs forward pass of the linear layer. 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | in_features: int, 26 | out_features: int, 27 | bias: bool = True, 28 | num_additional_outputs: int = 0, 29 | ) -> None: 30 | super().__init__( 31 | in_features=in_features, 32 | out_features=out_features, 33 | bias=bias, 34 | ) 35 | self._num_additional_outputs = num_additional_outputs 36 | if self._num_additional_outputs > 0: 37 | self._additional_outputs = nn.Linear( 38 | in_features=in_features, 39 | out_features=self._num_additional_outputs, 40 | bias=bias, 41 | ) 42 | 43 | def set_freeze_original_outputs(self, freeze: bool): 44 | """ 45 | Sets whether to freeze the original outputs. 46 | 47 | Args: 48 | freeze (bool): If True, the original outputs will be frozen. If False, the original outputs will be trainable. 49 | 50 | """ 51 | self.requires_grad_(not freeze) 52 | if self._num_additional_outputs > 0: 53 | self._additional_outputs.requires_grad_(True) 54 | 55 | @torch.no_grad() 56 | def init_additional_outputs(self, init_func: Callable[[torch.Tensor], None] = torch.nn.init.xavier_normal_): 57 | """ 58 | Initializes the additional outputs. 59 | """ 60 | if self._num_additional_outputs > 0: 61 | init_func(self._additional_outputs.weight) 62 | if self.bias is not None: 63 | torch.nn.init.zeros_(self._additional_outputs.bias) 64 | 65 | @torch.no_grad() 66 | def fuse_additional_outputs(self): 67 | """ 68 | Fuses the additional outputs with the original outputs. 69 | Afterwards, the additional outputs will be removed. 70 | """ 71 | if self._num_additional_outputs > 0: 72 | self.weight = nn.Parameter(torch.cat([self.weight, self._additional_outputs.weight], dim=0)) 73 | if self.bias is not None: 74 | self.bias = nn.Parameter(torch.cat([self.bias, self._additional_outputs.bias], dim=0)) 75 | self.out_features += self._num_additional_outputs 76 | self._num_additional_outputs = 0 77 | del self._additional_outputs 78 | 79 | def forward(self, input: torch.Tensor) -> torch.Tensor: 80 | """ 81 | Performs forward pass of the linear layer. 82 | 83 | Args: 84 | input (torch.Tensor): Input tensor. 85 | 86 | Returns: 87 | torch.Tensor: Output tensor. 
88 | 89 | """ 90 | output = super().forward(input) 91 | if self._num_additional_outputs > 0: 92 | additional_output = self._additional_outputs(input) 93 | output = torch.cat([output, additional_output], dim=-1) 94 | return output 95 | -------------------------------------------------------------------------------- /src/docllm/modules/spatial_embedder.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SpatialEmbeddingType(StrEnum): 8 | PROJECTION = "projection" 9 | EMBED = "embed" 10 | 11 | 12 | class SpatialEmbedder(nn.Module): 13 | def __init__( 14 | self, 15 | hidden_size: int, 16 | embedding_type: SpatialEmbeddingType, 17 | include_width_height: bool, 18 | max_coord: int = 1000, 19 | ): 20 | super().__init__() 21 | self._embedding_type = embedding_type 22 | self._include_width_height = include_width_height 23 | self._num_coords = max_coord + 1 24 | 25 | self._setup(hidden_size) 26 | 27 | def forward(self, input_coordinates: torch.FloatTensor) -> torch.FloatTensor: 28 | if self._include_width_height: 29 | input_coordinates = self._add_width_and_height(input_coordinates) 30 | 31 | if self._embedding_type == SpatialEmbeddingType.PROJECTION: 32 | return self.proj(input_coordinates) 33 | 34 | input_coordinates = self._scale(input_coordinates) 35 | 36 | embedded = [ 37 | self.embed_x(input_coordinates[:, :, 0]), 38 | self.embed_y(input_coordinates[:, :, 1]), 39 | self.embed_x(input_coordinates[:, :, 2]), 40 | self.embed_y(input_coordinates[:, :, 3]), 41 | ] 42 | if self._include_width_height: 43 | embedded += [ 44 | self.embed_h(input_coordinates[:, :, 4]), 45 | self.embed_w(input_coordinates[:, :, 5]), 46 | ] 47 | return torch.cat(embedded, dim=-1) 48 | 49 | def _setup(self, hidden_size: int): 50 | num_feat = 6 if self._include_width_height else 4 51 | 52 | match self._embedding_type: 53 | case SpatialEmbeddingType.PROJECTION: 54 | self.proj = nn.Linear(num_feat, hidden_size, bias=False) 55 | case SpatialEmbeddingType.EMBED: 56 | self._setup_dimension_embeddings(hidden_size, num_feat) 57 | 58 | def _setup_dimension_embeddings(self, hidden_size: int, num_feat: int): 59 | if hidden_size % num_feat != 0: 60 | raise Exception(f"Hidden size must be divisible by {num_feat} for spatial embeddings.") 61 | self.embed_x = nn.Embedding(self._num_coords, hidden_size // num_feat) 62 | self.embed_y = nn.Embedding(self._num_coords, hidden_size // num_feat) 63 | if self._include_width_height: 64 | self.embed_h = nn.Embedding(self._num_coords, hidden_size // num_feat) 65 | self.embed_w = nn.Embedding(self._num_coords, hidden_size // num_feat) 66 | 67 | def _add_width_and_height(self, input_coordinates: torch.Tensor) -> torch.Tensor: 68 | heights = (input_coordinates[:, :, 3] - input_coordinates[:, :, 1]).unsqueeze(2) 69 | widths = (input_coordinates[:, :, 2] - input_coordinates[:, :, 0]).unsqueeze(2) 70 | return torch.cat((input_coordinates, heights, widths), dim=-1) 71 | 72 | def _scale(self, input_coordinates: torch.FloatTensor) -> torch.LongTensor: 73 | max_coord = self._num_coords - 1 74 | return torch.round(input_coordinates * max_coord).int() 75 | -------------------------------------------------------------------------------- /src/docllm/pretraining_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class 
DocLLMCrossEntropyLoss(nn.CrossEntropyLoss): 8 | def __init__( 9 | self, 10 | weight: torch.Tensor | None = None, 11 | size_average=None, 12 | ignore_index: int = -100, 13 | reduce=None, 14 | label_smoothing: float = 0, 15 | ) -> None: 16 | super().__init__( 17 | weight=weight, 18 | size_average=size_average, 19 | ignore_index=ignore_index, 20 | reduce=reduce, 21 | reduction="none", 22 | label_smoothing=label_smoothing, 23 | ) 24 | 25 | def forward( 26 | self, logits: torch.FloatTensor, labels: torch.LongTensor, loss_mask: Optional[torch.BoolTensor] 27 | ) -> torch.FloatTensor: 28 | losses = super().forward(logits.view(-1, logits.size(-1)), labels.view(-1)) 29 | if loss_mask is None:  # without a mask, fall back to the unmasked mean loss 30 | return losses.mean() 31 | return (losses * loss_mask.view(-1).float()).sum() / loss_mask.sum().float() 32 | -------------------------------------------------------------------------------- /src/docllm/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/src/docllm/py.typed -------------------------------------------------------------------------------- /src/docllm/scripts/convert_fsdp_to_hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch.distributed._shard.checkpoint as dist_cp 5 | 6 | from docllm.modules.llama.causal_docllm import CausalLlamaDocLLM 7 | from docllm.modules.llama.config import DocLLMLlamaConfig 8 | 9 | 10 | def convert_fsdp_to_hf(fsdp_checkpoint_dir: str, hf_checkpoint_dir: str, base_model_name: str) -> None: 11 | model_config: DocLLMLlamaConfig = DocLLMLlamaConfig.from_pretrained( 12 | base_model_name, additional_training_vocab_size=3, use_cache=False, _attn_implementation="sdpa" 13 | ) 14 | model: CausalLlamaDocLLM = CausalLlamaDocLLM.from_pretrained(base_model_name, config=model_config) 15 | model_state_dict = {"model": model.state_dict()} 16 | dist_cp.load(state_dict=model_state_dict, storage_reader=dist_cp.FileSystemReader(fsdp_checkpoint_dir)) 17 | model_state_dict = model_state_dict["model"] 18 | model.load_state_dict(model_state_dict) 19 | model.save_pretrained(hf_checkpoint_dir) 20 | 21 | 22 | def main(): 23 | args = argparse.ArgumentParser(description="Convert FSDP checkpoint to HF format") 24 | args.add_argument("--fsdp_checkpoint_dir", type=str, required=True, help="Directory containing the FSDP checkpoint") 25 | args.add_argument("--hf_checkpoint_dir", type=str, required=True, help="Directory to save the HF checkpoint") 26 | args.add_argument("--base_model_name", type=str, required=True, help="Base model name for configuration") 27 | args = args.parse_args() 28 | fsdp_checkpoint_dir = args.fsdp_checkpoint_dir 29 | hf_checkpoint_dir = args.hf_checkpoint_dir 30 | base_model_name = args.base_model_name 31 | convert_fsdp_to_hf(fsdp_checkpoint_dir, hf_checkpoint_dir, base_model_name) 32 | print(f"Converted FSDP checkpoint from {fsdp_checkpoint_dir} to HF format in {hf_checkpoint_dir}") 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /src/docllm/scripts/json_to_pickle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from docllm.data.preprocessing.doc_data_to_pickle import document_tokenization_from_file 5 | 6 | 7 | def main(input_file_or_dir: str, output_dir: str, use_page_dimensions:
bool) -> None: 8 | if os.path.isfile(input_file_or_dir): 9 | input_files = [input_file_or_dir] 10 | else: 11 | input_files = [ 12 | os.path.join(input_file_or_dir, fn) for fn in os.listdir(input_file_or_dir) if fn.endswith(".json") 13 | ] 14 | for filename in input_files: 15 | document_tokenization_from_file(filename, output_dir, use_page_dimensions) 16 | 17 | 18 | if __name__ == "__main__": 19 | print(sys.argv) 20 | if len(sys.argv) < 3: 21 | print("Usage: python json_to_pickle.py <input_file_or_dir> <output_dir> [<use_page_dimensions>]") 22 | print("Default value for use_page_dimensions is True.") 23 | sys.exit(1) 24 | input_file_or_dir = sys.argv[1] 25 | output_dir = sys.argv[2] 26 | use_page_dimensions = sys.argv[3].lower() in ("true", "1", "yes") if len(sys.argv) > 3 else True  # bool() of a non-empty string is always True 27 | main(input_file_or_dir, output_dir, use_page_dimensions) 28 | -------------------------------------------------------------------------------- /src/docllm/training/pretraining/accelerate/train_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from functools import partial 4 | from typing import Iterable 5 | 6 | import torch 7 | from accelerate import Accelerator 8 | from accelerate.utils import LoggerType, ProjectConfiguration 9 | from pydantic import BaseModel 10 | from torch.utils.data import BatchSampler, DataLoader, RandomSampler 11 | 12 | from docllm.data.pretraining.collate import collate 13 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 14 | from docllm.data.pretraining.dataset import DocLLMPretrainDataset 15 | from docllm.data.pretraining.precomputation_pipeline import build_precomputed_data_pipeline, precompute_data 16 | from docllm.modules.llama.causal_docllm import CausalDocLLMOutputWithPast, CausalLlamaDocLLM 17 | from docllm.modules.llama.config import DocLLMLlamaConfig 18 | from docllm.pretraining_loss import DocLLMCrossEntropyLoss 19 | 20 | 21 | class LlamaDocLLMTrainerConfig(BaseModel): 22 | initial_model_checkpoint: str 23 | num_epochs: int = 1 24 | batch_size: int 25 | freeze_base_weights: bool = False 26 | train_data_dir: str 27 | eval_data_dir_in: str 28 | eval_data_dir_out: str 29 | output_dir: str 30 | resume_from_checkpoint: bool 31 | eval_steps: int 32 | 33 | max_lr: float = 6e-4 34 | 35 | 36 | class Trainer: 37 | def __init__( 38 | self, 39 | config: LlamaDocLLMTrainerConfig, 40 | skip_files: Iterable[str] = [], 41 | ) -> None: 42 | self._config = config.model_copy() 43 | self._model_config = DocLLMLlamaConfig( 44 | vocab_size=32000, 45 | additional_training_vocab_size=3, 46 | hidden_size=512, 47 | num_hidden_layers=2, 48 | num_attention_heads=2, 49 | max_position_embeddings=4096, 50 | use_cache=False, 51 | ) 52 | self._model = CausalLlamaDocLLM(self._model_config) 53 | 54 | # TODO Init new weight 55 | 56 | if self._config.freeze_base_weights: 57 | self._model.set_freeze_llama_layers(True) 58 | 59 | self._build_train_dataloader(skip_files) 60 | self._build_eval_dataloader() 61 | 62 | self._loss = DocLLMCrossEntropyLoss() 63 | self._optimizer = torch.optim.AdamW(self._model.parameters(), lr=self._config.max_lr) # TODO 64 | self._scheduler = torch.optim.lr_scheduler.OneCycleLR( 65 | self._optimizer, 66 | max_lr=self._config.max_lr, 67 | div_factor=10.0, 68 | final_div_factor=1.0, 69 | total_steps=math.ceil(self._config.num_epochs * len(self._train_dataloader)),  # len(dataloader) already counts batches; one scheduler step per batch per epoch 70 | pct_start=0.01, 71 | anneal_strategy="cos", 72 | ) 73 | 74 | project_config = ProjectConfiguration( 75 | project_dir=self._config.output_dir, 76 | automatic_checkpoint_naming=True, 77 |
total_limit=5, 78 | iteration=0, # TODO should this be adapted when loading a checkpoint? 79 | ) 80 | 81 | # https://huggingface.co/docs/accelerate/usage_guides/checkpoint 82 | self._accelerator = Accelerator(project_config=project_config, log_with=[LoggerType.TENSORBOARD]) 83 | 84 | config_to_log = { 85 | "num_samples": len(self._train_dataloader.dataset), 86 | "max_lr": self._config.max_lr, 87 | "loss_function": str(self._loss), 88 | } 89 | 90 | self._accelerator.init_trackers("DocLLM_Llama-2_7B", config=config_to_log) 91 | 92 | self._model, self._optimizer, self._train_dataloader = self._accelerator.prepare( 93 | self._model, self._optimizer, self._train_dataloader 94 | ) 95 | 96 | # Register the LR scheduler 97 | self._accelerator.register_for_checkpointing(self._scheduler) 98 | 99 | if self._config.resume_from_checkpoint: 100 | self._accelerator.load_state(get_last_checkpoint(self._config.output_dir)) 101 | else: 102 | self._accelerator.save_state() 103 | 104 | self._device = self._accelerator.device 105 | self._model.to(self._device) 106 | 107 | def train(self): 108 | num_tokens_seen = 0 109 | 110 | # Perform training 111 | for epoch in range(self._config.num_epochs): 112 | for step, batch in enumerate(self._train_dataloader): 113 | self._accelerator.log({"lr": self._scheduler.get_lr()}, step=step) 114 | self._optimizer.zero_grad() 115 | inputs, bboxes, mask, labels = batch 116 | num_tokens_seen += inputs.size(1) 117 | self._accelerator.log({"num_tokens_seen": num_tokens_seen}, step=step) 118 | inputs = inputs.to(self._device) 119 | bboxes = bboxes.to(self._device) 120 | mask = mask.to(self._device) 121 | labels = labels.to(self._device) 122 | outputs: CausalDocLLMOutputWithPast = self._model(input_ids=inputs, input_coordinates=bboxes) 123 | loss: torch.Tensor = self._loss(outputs.logits, labels, mask) 124 | self._accelerator.log({"train_loss": loss.item()}, step=step) 125 | self._accelerator.backward(loss) 126 | grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0) 127 | self._accelerator.log({"gradient_norm": grad_norm.item()}, step=step) 128 | self._optimizer.step() 129 | self._scheduler.step() 130 | self._accelerator.log({"learning_rate": self._scheduler.get_last_lr()}, step=step) 131 | self._checkpoint_and_evaluate(step) 132 | self._checkpoint_and_evaluate() 133 | 134 | self._accelerator.end_training() 135 | 136 | def _build_train_dataloader(self, skip_files: Iterable[str]): 137 | data_config = self._build_config(self._config.train_data_dir) 138 | dataset = DocLLMPretrainDataset(data_config, skip_files) 139 | self._train_dataloader = DataLoader( 140 | dataset, 141 | batch_size=data_config.batch_size, 142 | shuffle=data_config.shuffle, 143 | num_workers=0, 144 | collate_fn=partial(collate, padding_value=data_config.padding_value), 145 | drop_last=data_config.drop_last_batch_if_not_full, 146 | ) 147 | 148 | def _build_eval_dataloader(self): 149 | data_config = self._precompute_eval_data_for_determinism() 150 | data_pipeline_eval = build_precomputed_data_pipeline(self._config.eval_data_dir_out) 151 | self._eval_dataloader = DataLoader( 152 | data_pipeline_eval, 153 | batch_size=data_config.batch_size, 154 | shuffle=False, 155 | collate_fn=partial(collate, padding_value=data_config.padding_value), 156 | num_workers=0, 157 | drop_last=False, 158 | ) 159 | # examples_iter_eval = ExamplesIterable(generate_examples_fn=iter_fct, kwargs={"pipeline": data_pipeline_eval}) 160 | # gen = np.random.Generator(np.random.PCG64(42)) 161 | # shuffle_config = 
ShufflingConfig(generator=gen, _original_seed=42) 162 | # eval_dataset = IterableDataset(examples_iter_eval, shuffling=shuffle_config) 163 | # self._eval_dataset = eval_dataset.with_format(type="torch") 164 | 165 | def _precompute_eval_data_for_determinism(self) -> DocLLMPreTrainDataConfig: 166 | data_config = self._build_config(self._config.eval_data_dir_in) 167 | os.makedirs(self._config.eval_data_dir_out, exist_ok=True) 168 | if len(os.listdir(self._config.eval_data_dir_out)) == 0: 169 | precompute_data(data_config, self._config.eval_data_dir_out) 170 | return data_config 171 | 172 | def _build_config(self, data_dir: str) -> DocLLMPreTrainDataConfig: 173 | return DocLLMPreTrainDataConfig( 174 | batch_size=self._config.batch_size, 175 | drop_last_batch_if_not_full=False, 176 | shuffle=True, 177 | use_sharding_filter=False, 178 | use_packing=False, 179 | max_seq_length=self._model_config.max_position_embeddings, 180 | num_masked_blocks=(0.05, 0.20), 181 | max_percentage_masked_blocks=0.20, 182 | mask_text_token=self._model_config.vocab_size + 2, 183 | mask_bbox_token=(0.0, 0.0, 0.0, 0.0), 184 | block_start_text_token=self._model_config.vocab_size + 1, 185 | block_end_text_token=self._model_config.vocab_size, 186 | bos_text_token=self._model_config.bos_token_id, 187 | bos_bbox_token=None, 188 | padding_value=0.0, 189 | directory=data_dir, 190 | ) 191 | 192 | def _checkpoint_and_evaluate(self, step: int | None = None): 193 | if step is None or step % self._config.eval_steps == 0: 194 | self._accelerator.save_state() 195 | self._run_evaluation() 196 | 197 | def _run_evaluation(self, step: int | None = None): 198 | with torch.no_grad(): 199 | self._model.eval() 200 | total_loss = 0.0 201 | for batch in self._eval_dataloader: 202 | inputs, bboxes, mask, labels = batch 203 | inputs = inputs.to(self._accelerator.device) 204 | bboxes = bboxes.to(self._accelerator.device) 205 | mask = mask.to(self._accelerator.device) 206 | labels = labels.to(self._accelerator.device) 207 | outputs: CausalDocLLMOutputWithPast = self._model(inputs, bboxes) 208 | loss: torch.Tensor = self._loss(outputs.logits, labels, mask) 209 | total_loss += loss.item() 210 | self._accelerator.log({"eval_loss": total_loss}, step=step) 211 | self._model.train() 212 | 213 | 214 | def get_last_checkpoint(project_dir: str) -> str: 215 | return max(os.listdir(os.path.join(project_dir, "checkpoints")), key=lambda x: int(x.split("_")[-11])) 216 | -------------------------------------------------------------------------------- /src/external_scripts/document_tokenization/LICENSE: -------------------------------------------------------------------------------- 1 | GNU Affero General Public License 2 | 3 | https://www.gnu.org/licenses/agpl-3.0.html -------------------------------------------------------------------------------- /src/external_scripts/document_tokenization/document_tokenization_pymupdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | from functools import partial 5 | from itertools import groupby 6 | from typing import Dict, Iterable, List, Optional, Tuple, Union 7 | 8 | import fitz 9 | from transformers import AutoTokenizer, PreTrainedTokenizer 10 | 11 | from docllm.data.preprocessing.data_structure import Block, Document, Line, Page, Rect, Token, Word, WritingMode 12 | from docllm.data.preprocessing.doc_data_from_hocr import ( 13 | Setup, 14 | build_word_tokens_with_same_bounding_box, 15 | tokenize_with_leading_space, 16 | ) 17 | 18 | 19 | def 
main(args: List[str]): 20 | print(args) 21 | if len(args) < 3: 22 | print("Usage: python document_tokenization.py []") 23 | sys.exit(1) 24 | pdf_path = args[1] 25 | pdf_file_base = os.path.splitext(os.path.basename(pdf_path))[0] 26 | output_dir = args[2] 27 | tokenizer_name_or_path = "LeoLM/leo-hessianai-7b" if len(args) < 4 else args[3] 28 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=True, add_bos_token=False) 29 | setup = Setup(tokenizer=tokenizer) 30 | 31 | doc = parse_doc_from_pdf_path(pdf_path, setup) 32 | 33 | with open(os.path.join(output_dir, pdf_file_base + ".json"), "w") as f: 34 | f.write(doc.model_dump_json()) 35 | 36 | 37 | def parse_doc_from_pdf_path(pdf_path: str, setup: Setup) -> Document: 38 | doc = fitz.open(pdf_path) 39 | pages = map(partial(parse_page_from_page_dict, setup=setup), (p.get_text("rawdict") for p in doc)) 40 | return Document(pages=list(pages), filename=pdf_path) 41 | 42 | 43 | def parse_page_from_page_dict( 44 | page_dict: Dict[str, Union[str, List[Dict[str, Union[str, List[Dict[str, Union[int, Rect]]]]]]]], setup: Setup 45 | ) -> Page: 46 | blocks = [ 47 | block 48 | for block in map(partial(parse_block_from_block_dict, setup=setup), page_dict["blocks"]) 49 | if block is not None 50 | ] 51 | return Page(blocks=blocks, width=page_dict["width"], height=page_dict["height"]) 52 | 53 | 54 | def parse_block_from_block_dict( 55 | block_dict: Dict[str, Union[str, List[Dict[str, Union[str, List[Dict[str, Union[int, Rect]]]]]]]], setup: Setup 56 | ) -> Optional[Block]: 57 | if block_dict["type"] != 0: 58 | return None 59 | lines = map(partial(parse_line_from_line_dict, setup=setup), block_dict["lines"]) 60 | return Block(lines=list(lines), bbox=block_dict["bbox"]) 61 | 62 | 63 | def parse_line_from_line_dict( 64 | line_dict: Dict[str, Union[str, List[Dict[str, Union[int, Rect]]]]], setup: Setup 65 | ) -> Line: 66 | chars = [c for span in line_dict["spans"] for c in span["chars"]] 67 | text = "".join(c["c"] for c in chars) 68 | writing_mode = WritingMode(line_dict["wmode"]) 69 | split_chars = (list(group) for k, group in groupby(chars, key=lambda c: re.match(r"\s", c["c"])) if not k) 70 | words = map(partial(parse_word_from_char_dicts, writing_mode=writing_mode, setup=setup), split_chars) 71 | return Line( 72 | text=text, 73 | words=list(words), 74 | direction=(line_dict["dir"][0], line_dict["dir"][1]), 75 | writing_mode=writing_mode, 76 | bbox=line_dict["bbox"], 77 | ) 78 | 79 | 80 | def parse_word_from_char_dicts( 81 | char_dicts: List[Dict[str, Union[int, Rect]]], writing_mode: WritingMode, setup: Setup 82 | ) -> Word: 83 | text = "".join(char_dict["c"] for char_dict in char_dicts) 84 | if setup.build_token_bboxes: 85 | tokens = list(_build_tokens(text, char_dicts, writing_mode, setup)) 86 | bbox = _join_rects(*(t.bbox for t in tokens)) 87 | else: 88 | bbox = _join_rects(*(c["bbox"] for c in char_dicts)) 89 | tokens = list(build_word_tokens_with_same_bounding_box(text, setup.tokenizer, bbox)) 90 | return Word(text=text, tokens=tokens, bbox=bbox) 91 | 92 | 93 | def _build_tokens( 94 | text: str, char_dicts: List[Dict[str, Union[int, Rect]]], writing_mode: WritingMode, setup: Setup 95 | ) -> Iterable[Token]: 96 | token_ids = tokenize_with_leading_space(text, setup.tokenizer) 97 | token_groups = _build_token_strings_with_groups_of_escaped_tokens(token_ids, setup) 98 | for i, (t, b) in zip(token_ids, _build_token_boxes(token_groups, text, char_dicts, writing_mode)): 99 | yield Token(text=t, token_id=i, bbox=b) 100 | 101 | 102 | def 
_build_token_strings_with_groups_of_escaped_tokens(token_ids: List[int], setup: Setup): 103 | tokens = setup.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True) 104 | tokens = _remove_begin_of_word_markers(tokens, setup.begin_of_word) 105 | token_groups = _group_consecutive_escaped_tokens(tokens) 106 | return token_groups 107 | 108 | 109 | def _remove_begin_of_word_markers(tokens: Iterable[str], begin_of_word: str) -> Iterable[str]: 110 | return filter(None, map(lambda t: re.sub(f"^{begin_of_word}", "", t), tokens)) 111 | 112 | 113 | def _group_consecutive_escaped_tokens(tokens: Iterable[str]) -> Iterable[List[str]]: 114 | return ( 115 | list(map(lambda t: t[1], g)) 116 | for _, g in groupby(enumerate(tokens), key=lambda idx_tok: -1 if _is_escaped(idx_tok[1]) else idx_tok[0]) 117 | ) 118 | 119 | 120 | def _build_token_boxes( 121 | token_groups: Iterable[List[str]], 122 | original_text: str, 123 | char_dicts: List[Dict[str, Union[int, Rect]]], 124 | writing_mode: WritingMode, 125 | ) -> Iterable[Tuple[str, Rect]]: 126 | start_idx = 0 127 | for tg in token_groups: 128 | result_iter, start_idx = _create_token_boxes_for_token_group( 129 | tg, original_text, start_idx, char_dicts, writing_mode 130 | ) 131 | yield from result_iter 132 | 133 | 134 | def _create_token_boxes_for_token_group( 135 | token_group: List[str], 136 | original_text: str, 137 | start_idx: int, 138 | char_dicts: List[Dict[str, Union[int, Rect]]], 139 | writing_mode: WritingMode, 140 | ) -> Tuple[Iterable[Tuple[str, Rect]], int]: 141 | if len(token_group) == 1 and not _is_escaped(token_group[0]): 142 | detok_text = token_group[0] 143 | token_box, start_idx = _extract_token_box(detok_text, original_text, char_dicts, start_idx) 144 | return [(detok_text, token_box)], start_idx 145 | else: 146 | token_boxes, start_idx = _create_boxes_for_escaped_tokens( 147 | token_group, original_text, start_idx, char_dicts, writing_mode 148 | ) 149 | return zip(token_group, token_boxes), start_idx 150 | 151 | 152 | def _is_escaped(token: str) -> bool: 153 | return token.startswith("<") and token.endswith(">") 154 | 155 | 156 | def _create_boxes_for_escaped_tokens( 157 | tokens: List[str], 158 | original_text: str, 159 | start_idx: int, 160 | char_dicts: List[Dict[str, Union[int, Rect]]], 161 | writing_mode: WritingMode, 162 | ) -> Tuple[List[Rect], int]: 163 | token_bytes = "".join(map(_extract_bytes, tokens)) 164 | detok_text = bytes.fromhex(token_bytes).decode("utf-8") 165 | token_box, start_idx = _extract_token_box(detok_text, original_text, char_dicts, start_idx) 166 | token_boxes = _split_token_box(token_box, len(tokens), writing_mode) 167 | return token_boxes, start_idx 168 | 169 | 170 | def _extract_token_box( 171 | detokenized_text: str, original_text: str, char_dicts: List[Dict[str, Union[int, Rect]]], start_idx: int 172 | ) -> Tuple[Rect, int]: 173 | end_idx = start_idx + len(detokenized_text) 174 | assert detokenized_text == original_text[start_idx:end_idx] 175 | token_box = _join_rects(*(c["bbox"] for c in char_dicts[start_idx:end_idx])) 176 | return token_box, end_idx 177 | 178 | 179 | def _join_rects(*rects: Tuple[float, float, float, float]) -> Tuple[float, float, float, float]: 180 | lefts, tops, rights, bottoms = zip(*rects) 181 | return min(lefts), min(tops), max(rights), max(bottoms) 182 | 183 | 184 | def _extract_bytes(token: str) -> bytes: 185 | assert len(token) == 6 186 | return token[3:5] 187 | 188 | 189 | def _split_token_box(token_box: Rect, num_tokens: int, direction: WritingMode) -> Iterable[Rect]: 
190 | left, top, right, bottom = token_box 191 | if direction == WritingMode.horizontal: 192 | width = (right - left) / num_tokens 193 | return ((left + i * width, top, left + (i + 1) * width, bottom) for i in range(num_tokens)) 194 | else: 195 | height = (bottom - top) / num_tokens 196 | return ((left, top + i * height, right, top + (i + 1) * height) for i in range(num_tokens)) 197 | 198 | 199 | if __name__ == "__main__": 200 | main(sys.argv) 201 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/__init__.py -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(autouse=True) 7 | def set_log_level(): 8 | logging.getLogger().setLevel(logging.DEBUG) 9 | -------------------------------------------------------------------------------- /test/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/data/__init__.py -------------------------------------------------------------------------------- /test/data/finetuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/data/finetuning/__init__.py -------------------------------------------------------------------------------- /test/data/finetuning/test_reading_order.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | 4 | import pytest 5 | import torch 6 | 7 | from docllm.data.finetuning.reading_order import ReadingOrderConfig, ReadingOrderSampleBuilder, get_num_tokens 8 | 9 | 10 | def test_instantiation(reading_order_sample_builder: ReadingOrderSampleBuilder): 11 | assert isinstance(reading_order_sample_builder, ReadingOrderSampleBuilder) 12 | 13 | 14 | def test_produces_four_tuple( 15 | input_tokens: List[torch.LongTensor], 16 | bbox_tokens: List[torch.FloatTensor], 17 | reading_order_sample_builder: ReadingOrderSampleBuilder, 18 | ): 19 | res = reading_order_sample_builder(input_tokens, bbox_tokens) 20 | assert isinstance(res, tuple) and len(res) == 4 21 | 22 | 23 | def test_all_tuple_entries_are_tensors( 24 | input_tokens: List[torch.LongTensor], 25 | bbox_tokens: List[torch.FloatTensor], 26 | reading_order_sample_builder: ReadingOrderSampleBuilder, 27 | ): 28 | res = reading_order_sample_builder(input_tokens, bbox_tokens) 29 | assert all(isinstance(tensor, torch.Tensor) for tensor in res) 30 | 31 | 32 | def test_input_tokens_are_long_tensor( 33 | input_tokens: List[torch.LongTensor], 34 | bbox_tokens: List[torch.FloatTensor], 35 | reading_order_sample_builder: ReadingOrderSampleBuilder, 36 | ): 37 | input_tokens, _, _, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 38 | assert isinstance(input_tokens, torch.LongTensor) 39 | 40 | 41 | def test_bounding_boxes_are_float_tensor( 42 | input_tokens: List[torch.LongTensor], 43 | bbox_tokens: List[torch.FloatTensor], 44 | reading_order_sample_builder: ReadingOrderSampleBuilder, 45 | ): 46 | _, 
bounding_boxes, _, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 47 | assert isinstance(bounding_boxes, torch.FloatTensor) 48 | 49 | 50 | def test_loss_mask_is_bool_tensor( 51 | input_tokens: List[torch.LongTensor], 52 | bbox_tokens: List[torch.FloatTensor], 53 | reading_order_sample_builder: ReadingOrderSampleBuilder, 54 | ): 55 | _, _, loss_mask, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 56 | assert isinstance(loss_mask, torch.BoolTensor) 57 | 58 | 59 | def test_label_is_long_tensor( 60 | input_tokens: List[torch.LongTensor], 61 | bbox_tokens: List[torch.FloatTensor], 62 | reading_order_sample_builder: ReadingOrderSampleBuilder, 63 | ): 64 | _, _, _, label = reading_order_sample_builder(input_tokens, bbox_tokens) 65 | assert isinstance(label, torch.LongTensor) 66 | 67 | 68 | def test_result_shapes_match( 69 | input_tokens: List[torch.LongTensor], 70 | bbox_tokens: List[torch.FloatTensor], 71 | reading_order_sample_builder: ReadingOrderSampleBuilder, 72 | ): 73 | input_ids, bounding_boxes, loss_mask, label = reading_order_sample_builder(input_tokens, bbox_tokens) 74 | seq_len = input_ids.size(0) 75 | assert bounding_boxes.size() == (seq_len, 4) 76 | assert label.size() == (seq_len,) 77 | assert loss_mask.size() == (seq_len,) 78 | 79 | 80 | def test_input_tokens_shape( 81 | input_tokens: List[torch.LongTensor], 82 | bbox_tokens: List[torch.FloatTensor], 83 | reading_order_sample_builder: ReadingOrderSampleBuilder, 84 | expected_sequence_length: int, 85 | ): 86 | input_ids, _, _, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 87 | assert input_ids.size() == (expected_sequence_length,) 88 | 89 | 90 | def test_expected_output_in_input_ids_has_original_sequence( 91 | input_tokens: List[torch.LongTensor], 92 | bbox_tokens: List[torch.FloatTensor], 93 | reading_order_sample_builder: ReadingOrderSampleBuilder, 94 | ): 95 | input_ids, _, _, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 96 | original_token_sequence = torch.cat(input_tokens) 97 | num_tokens = len(original_token_sequence) 98 | assert torch.equal(input_ids[-(num_tokens):], original_token_sequence) 99 | 100 | 101 | def test_expected_output_in_bounding_boxes_has_original_sequence( 102 | input_tokens: List[torch.LongTensor], 103 | bbox_tokens: List[torch.FloatTensor], 104 | reading_order_sample_builder: ReadingOrderSampleBuilder, 105 | ): 106 | _, bounding_boxes, _, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 107 | original_bounding_boxes = torch.cat(bbox_tokens) 108 | num_tokens = len(original_bounding_boxes) 109 | assert torch.equal(bounding_boxes[-(num_tokens):], original_bounding_boxes) 110 | 111 | 112 | def test_expected_output_in_labels_has_original_sequence( 113 | input_tokens: List[torch.LongTensor], 114 | bbox_tokens: List[torch.FloatTensor], 115 | reading_order_sample_builder: ReadingOrderSampleBuilder, 116 | ): 117 | _, _, _, label = reading_order_sample_builder(input_tokens, bbox_tokens) 118 | original_token_sequence = torch.cat(input_tokens) 119 | num_tokens = len(original_token_sequence) 120 | assert torch.equal(label[-(num_tokens + 1) : -1], original_token_sequence) 121 | 122 | 123 | def test_loss_mask_masks_expected_output( 124 | input_tokens: List[torch.LongTensor], 125 | bbox_tokens: List[torch.FloatTensor], 126 | reading_order_sample_builder: ReadingOrderSampleBuilder, 127 | ): 128 | _, _, loss_mask, _ = reading_order_sample_builder(input_tokens, bbox_tokens) 129 | original_token_sequence = torch.cat(input_tokens) 130 | num_tokens = 
len(original_token_sequence) 131 | assert torch.equal(loss_mask[-num_tokens - 1 :], torch.ones(num_tokens + 1, dtype=torch.bool)) 132 | assert torch.equal(loss_mask[: -num_tokens - 1], torch.zeros_like(loss_mask[: -num_tokens - 1], dtype=torch.bool)) 133 | 134 | 135 | @pytest.fixture 136 | def expected_sequence_length(input_tokens: List[torch.LongTensor], readingorder_config: ReadingOrderConfig) -> int: 137 | bos_add = 1 if readingorder_config.bos_text_token is not None else 0 138 | eos_add = 1 if readingorder_config.eos_text_token is not None else 0 139 | num_tokens = get_num_tokens(input_tokens) 140 | prompt_length = len(readingorder_config.prompt_token_ids) 141 | 142 | return bos_add + num_tokens + eos_add + prompt_length + bos_add + num_tokens 143 | 144 | 145 | @pytest.fixture 146 | def reading_order_sample_builder(readingorder_config: ReadingOrderConfig) -> ReadingOrderSampleBuilder: 147 | return ReadingOrderSampleBuilder(readingorder_config) 148 | 149 | 150 | @pytest.fixture 151 | def readingorder_config(max_sequence_length: int) -> ReadingOrderConfig: 152 | return ReadingOrderConfig( 153 | max_seq_length=max_sequence_length, 154 | min_num_blocks_available=3, 155 | min_num_tokens_available=2, 156 | bos_text_token=1, 157 | bos_bbox_token=(0.0, 0.0, 0.0, 0.0), 158 | eos_text_token=2, 159 | eos_bbox_token=(1.0, 1.0, 1.0, 1.0), 160 | prompt_token_ids=[4, 5, 6], 161 | ) 162 | 163 | 164 | @pytest.fixture 165 | def input_tokens(test_data: Tuple[List[torch.LongTensor], List[torch.FloatTensor]]) -> List[torch.LongTensor]: 166 | return test_data[0] 167 | 168 | 169 | @pytest.fixture 170 | def bbox_tokens(test_data: Tuple[List[torch.LongTensor], List[torch.FloatTensor]]) -> List[torch.FloatTensor]: 171 | return test_data[1] 172 | 173 | 174 | @pytest.fixture 175 | def test_data( 176 | num_blocks: int, 177 | range_block_size: Tuple[int, int], 178 | ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor]]: 179 | max_token = 1111 180 | return tuple( 181 | map( 182 | list, 183 | zip( 184 | *( 185 | ( 186 | torch.randint(max_token, (b := random.randint(range_block_size[0], range_block_size[1]),)), 187 | torch.rand(b, 4), 188 | ) 189 | for _ in range(num_blocks) 190 | ) 191 | ), 192 | ) 193 | ) 194 | 195 | 196 | @pytest.fixture 197 | def max_sequence_length() -> int: 198 | return 128 199 | 200 | 201 | @pytest.fixture 202 | def num_blocks() -> int: 203 | return 10 204 | 205 | 206 | @pytest.fixture 207 | def range_block_size() -> Tuple[int, int]: 208 | return (2, 5) 209 | -------------------------------------------------------------------------------- /test/data/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/data/preprocessing/__init__.py -------------------------------------------------------------------------------- /test/data/preprocessing/example_json.py: -------------------------------------------------------------------------------- 1 | document_dict = { 2 | "pages": [ 3 | { 4 | "blocks": [ 5 | { 6 | "lines": [ 7 | { 8 | "text": "Some example sentence.", 9 | "words": [ 10 | { 11 | "text": "Some", 12 | "tokens": [ 13 | { 14 | "text": "Some", 15 | "token_id": 2416, 16 | "bbox": [48.0, 788.0, 85.0, 804.0], 17 | }, 18 | ], 19 | "bbox": [48.0, 788.0, 85.0, 804.0], 20 | }, 21 | { 22 | "text": "example", 23 | "tokens": [ 24 | { 25 | "text": "ex", 26 | "token_id": 192, 27 | "bbox": [85.0, 788.0, 107.0, 804.0], 28 | }, 29 | { 30 | "text": 
"ample", 31 | "token_id": 2989, 32 | "bbox": [107.0, 788.0, 111.0, 804.0], 33 | }, 34 | ], 35 | "bbox": [85.0, 788.0, 111.0, 804.0], 36 | }, 37 | { 38 | "text": "sentence.", 39 | "tokens": [ 40 | { 41 | "text": "sen", 42 | "token_id": 1984, 43 | "bbox": [111.0, 788.0, 132.0, 804.0], 44 | }, 45 | { 46 | "text": "tence", 47 | "token_id": 545, 48 | "bbox": [132.0, 788.0, 160.0, 804.0], 49 | }, 50 | { 51 | "text": ".", 52 | "token_id": 2989, 53 | "bbox": [160.0, 788.0, 163.0, 804.0], 54 | }, 55 | ], 56 | "bbox": [111.0, 788.0, 163.0, 804.0], 57 | }, 58 | ], 59 | } 60 | ], 61 | "bbox": [48.0, 788.0, 163.0, 804.0], 62 | } 63 | ], 64 | "width": 595.0, 65 | "height": 841.0, 66 | }, 67 | { 68 | "blocks": [ 69 | { 70 | "lines": [ 71 | { 72 | "text": "Other text example!", 73 | "words": [ 74 | { 75 | "text": "Other", 76 | "tokens": [ 77 | { 78 | "text": "Some", 79 | "token_id": 1, 80 | "bbox": [10.0, 15.0, 20.0, 25.0], 81 | }, 82 | ], 83 | "bbox": [10.0, 15.0, 20.0, 25.0], 84 | }, 85 | { 86 | "text": "te", 87 | "tokens": [ 88 | { 89 | "text": "te", 90 | "token_id": 2, 91 | "bbox": [30.0, 15.0, 40.0, 25.0], 92 | }, 93 | { 94 | "text": "xt", 95 | "token_id": 3, 96 | "bbox": [40.0, 15, 50.0, 25.0], 97 | }, 98 | ], 99 | "bbox": [30.0, 15.0, 50.0, 25.0], 100 | }, 101 | { 102 | "text": "example!", 103 | "tokens": [ 104 | { 105 | "text": "ex", 106 | "token_id": 4, 107 | "bbox": [10.0, 55.0, 30.0, 65.0], 108 | }, 109 | { 110 | "text": "ample", 111 | "token_id": 5, 112 | "bbox": [30.0, 55.0, 60.0, 65.0], 113 | }, 114 | { 115 | "text": "!", 116 | "token_id": 6, 117 | "bbox": [60.0, 53.0, 62.0, 65.0], 118 | }, 119 | ], 120 | "bbox": [10.0, 53.0, 62.0, 65.0], 121 | }, 122 | ], 123 | } 124 | ], 125 | "bbox": [10.0, 15.0, 62.0, 65.0], 126 | } 127 | ], 128 | "width": 100.0, 129 | "height": 95.0, 130 | }, 131 | ] 132 | } 133 | -------------------------------------------------------------------------------- /test/data/preprocessing/test_doc_data_to_pickle.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tempfile import NamedTemporaryFile, TemporaryDirectory 4 | from test.data.preprocessing.example_json import document_dict 5 | from typing import Callable, Iterable, List, Tuple 6 | 7 | import pytest 8 | import torch 9 | 10 | from docllm.data.preprocessing.doc_data_to_pickle import ( 11 | BoundingBoxTokenizer, 12 | document_tokenization_from_file, 13 | page_tokenization_to_block_tokens, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def input_file() -> Iterable[str]: 19 | with NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".json") as file: 20 | json.dump(document_dict, file) 21 | file.flush() 22 | yield file.name 23 | 24 | 25 | @pytest.fixture 26 | def output_dir() -> Iterable[str]: 27 | with TemporaryDirectory() as output_dir: 28 | yield output_dir 29 | 30 | 31 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 32 | def test_document_tokenization_produces_correct_number_of_files( 33 | input_file: str, output_dir: str, use_page_dimensions: bool 34 | ): 35 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 36 | assert len(os.listdir(output_dir)) == len(document_dict["pages"]) 37 | 38 | 39 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 40 | def test_document_tokenization_produces_correct_number_of_tensors( 41 | input_file: str, output_dir: str, use_page_dimensions: bool 42 | ): 43 | def test(tensors: List[Tuple[torch.Tensor, torch.Tensor]], page_num: int): 44 | assert len(tensors) == 
len(document_dict["pages"][page_num]["blocks"]) 45 | 46 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 47 | _run_test_for_each_page_result_file(output_dir, test) 48 | 49 | 50 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 51 | def test_document_tokenization_produces_two_tensors_for_each_block( 52 | input_file: str, output_dir: str, use_page_dimensions: bool 53 | ): 54 | def test(tensors: List[Tuple[torch.Tensor, torch.Tensor]], page_num: int): 55 | assert all(len(tensor) == 2 for tensor in tensors) 56 | 57 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 58 | _run_test_for_each_page_result_file(output_dir, test) 59 | 60 | 61 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 62 | def test_document_tokenization_tensor_lengths_are_number_of_tokens_in_block( 63 | input_file: str, output_dir: str, use_page_dimensions: bool 64 | ): 65 | def test(tensors: List[Tuple[torch.Tensor, torch.Tensor]], page_num: int): 66 | for (text_tensor, bbox_tensor), block in zip(tensors, document_dict["pages"][page_num]["blocks"]): 67 | num_tokens = len([t for l in block["lines"] for w in l["words"] for t in w["tokens"]]) 68 | assert text_tensor.shape == (num_tokens,) and bbox_tensor.shape[0] == num_tokens 69 | 70 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 71 | _run_test_for_each_page_result_file(output_dir, test) 72 | 73 | 74 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 75 | def test_document_tokenization_produces_bbox_tensors_with_size_four( 76 | input_file: str, output_dir: str, use_page_dimensions: bool 77 | ): 78 | def test(tensors: List[Tuple[torch.Tensor, torch.Tensor]], page_num: int): 79 | for _, bbox_tensor in tensors: 80 | assert bbox_tensor.shape[-1] == 4 81 | 82 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 83 | _run_test_for_each_page_result_file(output_dir, test) 84 | 85 | 86 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 87 | def test_document_tokenization_produces_bbox_with_expected_shape( 88 | input_file: str, output_dir: str, use_page_dimensions: bool 89 | ): 90 | def test(tensors: List[Tuple[torch.Tensor, torch.Tensor]], page_num: int): 91 | for (_, bbox_tensor), block in zip(tensors, document_dict["pages"][page_num]["blocks"]): 92 | num_tokens = len([t for l in block["lines"] for w in l["words"] for t in w["tokens"]]) 93 | assert bbox_tensor.shape == (num_tokens, 4) 94 | 95 | document_tokenization_from_file(input_file, output_dir, use_page_dimensions) 96 | _run_test_for_each_page_result_file(output_dir, test) 97 | 98 | 99 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 100 | def test_page_tokenization_to_block_tokens_produces_correct_number_of_tensors(use_page_dimensions): 101 | page = document_dict["pages"][1] 102 | tokenizer = BoundingBoxTokenizer(page, use_page_dimensions=use_page_dimensions) 103 | tensors = list(page_tokenization_to_block_tokens(page, tokenizer)) 104 | for (text_tensor, bbox_tensor), b in zip(tensors, page["blocks"]): 105 | num_tokens = len([t for l in b["lines"] for w in l["words"] for t in w["tokens"]]) 106 | assert text_tensor.shape == (num_tokens,) and len(bbox_tensor) == num_tokens 107 | 108 | 109 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 110 | def test_page_tokenization_to_block_tokens_produces_bbox_tensors_with_size_four(use_page_dimensions): 111 | page = document_dict["pages"][1] 112 | tokenizer = BoundingBoxTokenizer(page, 
use_page_dimensions=use_page_dimensions) 113 | tensors = list(page_tokenization_to_block_tokens(page, tokenizer)) 114 | assert all(bbox_tensor.shape[-1] == 4 for _, bbox_tensor in tensors) 115 | 116 | 117 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 118 | def test_bounding_box_tokenizer_produces_correct_output_shape(use_page_dimensions: bool): 119 | tokenizer = BoundingBoxTokenizer(document_dict["pages"][1], use_page_dimensions=use_page_dimensions) 120 | bounding_box = [10.0, 15.0, 20.0, 25.0] 121 | assert tokenizer(bounding_box).shape == (4,) 122 | 123 | 124 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 125 | def test_tokenized_bounding_boxes_in_zero_one(use_page_dimensions: bool): 126 | tokenizer = BoundingBoxTokenizer(document_dict["pages"][1], use_page_dimensions=use_page_dimensions) 127 | token_bbs = [ 128 | tokenizer(t["bbox"]) 129 | for block in document_dict["pages"][1]["blocks"] 130 | for line in block["lines"] 131 | for word in line["words"] 132 | for t in word["tokens"] 133 | ] 134 | assert all((t >= 0).all() and (t <= 1).all() for t in token_bbs) 135 | 136 | 137 | @pytest.mark.parametrize("use_page_dimensions", [True, False]) 138 | def test_tokenized_bounding_boxes_have_non_negative_lengths(use_page_dimensions: bool): 139 | tokenizer = BoundingBoxTokenizer(document_dict["pages"][1], use_page_dimensions=use_page_dimensions) 140 | token_bbs = [ 141 | tokenizer(t["bbox"]) 142 | for block in document_dict["pages"][1]["blocks"] 143 | for line in block["lines"] 144 | for word in line["words"] 145 | for t in word["tokens"] 146 | ] 147 | assert all(t[0] <= t[2] and t[1] <= t[3] for t in token_bbs) 148 | 149 | 150 | def test_bounding_box_tokenizer_produces_scaled_coordinates_using_page_dimensions(): 151 | page = document_dict["pages"][1] 152 | 153 | use_page_dimensions = True 154 | tokenizer = BoundingBoxTokenizer(page, use_page_dimensions=use_page_dimensions) 155 | bounding_box = [10.0, 15.0, 20.0, 25.0] 156 | token_bb = [ 157 | bounding_box[0] / page["width"], 158 | bounding_box[1] / page["height"], 159 | bounding_box[2] / page["width"], 160 | bounding_box[3] / page["height"], 161 | ] 162 | assert tokenizer(bounding_box).tolist() == pytest.approx(token_bb, 0.0001) 163 | 164 | 165 | def test_bounding_box_tokenizer_produces_scaled_coordinates_using_min_max(): 166 | page = document_dict["pages"][1] 167 | coords = [t["bbox"] for b in page["blocks"] for l in b["lines"] for w in l["words"] for t in w["tokens"]] 168 | minx_coords, miny_coords, maxx_coords, maxy_coords = zip(*coords) 169 | minx, miny, maxx, maxy = min(minx_coords), min(miny_coords), max(maxx_coords), max(maxy_coords) 170 | width, height = maxx - minx, maxy - miny 171 | 172 | use_page_dimensions = False 173 | tokenizer = BoundingBoxTokenizer(page, use_page_dimensions=use_page_dimensions) 174 | bounding_box = [10.0, 15.0, 20.0, 25.0] 175 | token_bb = [ 176 | (bounding_box[0] - minx) / width, 177 | (bounding_box[1] - miny) / height, 178 | (bounding_box[2] - minx) / width, 179 | (bounding_box[3] - miny) / height, 180 | ] 181 | assert tokenizer(bounding_box).tolist() == pytest.approx(token_bb, 0.0001) 182 | 183 | 184 | def _run_test_for_each_page_result_file( 185 | output_dir: str, test: Callable[[List[Tuple[torch.Tensor, torch.Tensor]], int], None] 186 | ): 187 | for filename in os.listdir(output_dir): 188 | with open(os.path.join(output_dir, filename), "rb") as file: 189 | tensors = torch.load(file) 190 | page_num = int(filename.split("_")[-1].split(".")[0]) 191 | test(tensors, page_num) 
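For orientation, here is a minimal usage sketch of the preprocessing entry point exercised above. It only uses calls that appear in these tests; the JSON path, output directory and the exact page file name are illustrative assumptions, not taken from the module.

# Illustrative sketch, not part of the repository. Paths and the page file name
# pattern are assumptions; the call signature mirrors the tests above.
import torch

from docllm.data.preprocessing.doc_data_to_pickle import document_tokenization_from_file

# Convert a tokenized document JSON (structured like `document_dict` above) into one
# tensor file per page; the third argument toggles scaling by the page width/height.
document_tokenization_from_file("my_document.json", "out_dir", True)

# Each page file holds a list of (token_ids, bboxes) tuples, one tuple per block.
blocks = torch.load("out_dir/my_document_0.pt")  # hypothetical file name
for token_ids, bboxes in blocks:
    assert bboxes.shape == (token_ids.shape[0], 4)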
192 | -------------------------------------------------------------------------------- /test/data/pretraining/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/data/pretraining/__init__.py -------------------------------------------------------------------------------- /test/data/pretraining/conftest.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from tempfile import TemporaryDirectory 4 | from typing import Iterable, List, Tuple 5 | 6 | import pytest 7 | import torch 8 | 9 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig, NumMaskedBlocksType 10 | 11 | 12 | @pytest.fixture 13 | def num_docs() -> int: 14 | return 10 15 | 16 | 17 | @pytest.fixture 18 | def multiple_docs_test_data(num_docs: int) -> List[List[Tuple[torch.LongTensor, torch.FloatTensor]]]: 19 | docs = [] 20 | for _ in range(num_docs): 21 | num_blocks = random.randint(4, 11) 22 | data = [ 23 | ( 24 | torch.randint(1, 10, (block_len := random.randint(1, 7),), dtype=torch.long), 25 | torch.rand(block_len, 4), 26 | ) 27 | for _ in range(num_blocks) 28 | ] 29 | docs.append(data) 30 | return docs 31 | 32 | 33 | @pytest.fixture 34 | def input_dir_with_data( 35 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 36 | ) -> Iterable[str]: 37 | with TemporaryDirectory() as input_dir: 38 | for i, data in enumerate(multiple_docs_test_data): 39 | torch.save(data, f"{input_dir}/doc_{i}.pt") 40 | yield input_dir 41 | 42 | 43 | @pytest.fixture 44 | def num_blocks() -> int: 45 | return 10 46 | 47 | 48 | @pytest.fixture 49 | def batch_size(request) -> int: 50 | if not hasattr(request, "param") or request.param is None: 51 | return 1 52 | return request.param 53 | 54 | 55 | @pytest.fixture 56 | def range_block_size() -> Tuple[int, int]: 57 | return (2, 5) 58 | 59 | 60 | @pytest.fixture 61 | def max_sequence_length( 62 | request, num_masked_blocks: NumMaskedBlocksType, range_block_size: Tuple[int, int], num_blocks: int 63 | ) -> int: 64 | if hasattr(request, "param") and request.param and isinstance(request.param, int): 65 | return request.param 66 | if isinstance(num_masked_blocks, tuple): 67 | num_masked_blocks = num_masked_blocks[1] 68 | if isinstance(num_masked_blocks, int): 69 | max_masked_blocks = num_masked_blocks 70 | elif isinstance(num_masked_blocks, float): 71 | max_masked_blocks = math.ceil(num_masked_blocks * range_block_size[1]) 72 | return range_block_size[1] * num_blocks + (range_block_size[1] + 1) * max_masked_blocks 73 | 74 | 75 | @pytest.fixture(params=[1, 2, (3, 5), 0.2, (0.2, 0.4)]) 76 | def num_masked_blocks(request) -> NumMaskedBlocksType: 77 | return request.param 78 | 79 | 80 | @pytest.fixture 81 | def max_percentage_masked_blocks(request) -> float: 82 | if not hasattr(request, "param") or request.param is None: 83 | return 0.8 84 | return request.param 85 | 86 | 87 | @pytest.fixture 88 | def use_sharding_filter(request) -> float: 89 | if not hasattr(request, "param") or request.param is None: 90 | return False 91 | return request.param 92 | 93 | 94 | @pytest.fixture 95 | def pretraining_config( 96 | batch_size: int, 97 | use_sharding_filter: bool, 98 | max_sequence_length: int, 99 | num_masked_blocks: NumMaskedBlocksType, 100 | max_percentage_masked_blocks: float, 101 | ) -> DocLLMPreTrainDataConfig: 102 | return DocLLMPreTrainDataConfig( 103 | 
batch_size=batch_size, 104 | use_sharding_filter=use_sharding_filter, 105 | max_seq_length=max_sequence_length, 106 | num_masked_blocks=num_masked_blocks, 107 | max_percentage_masked_blocks=max_percentage_masked_blocks, 108 | mask_text_token=0, 109 | mask_bbox_token=(0.0, 0.0, 0.0, 0.0), 110 | block_start_text_token=1337, 111 | block_start_bbox_token=(0.0, 0.0, 0.0, 0.0), 112 | block_end_text_token=1338, 113 | bos_text_token=1, 114 | bos_bbox_token=(0.0, 0.0, 0.0, 0.0), 115 | directory="", 116 | ) 117 | 118 | 119 | @pytest.fixture 120 | def pretraining_config_with_input_dir( 121 | pretraining_config: DocLLMPreTrainDataConfig, input_dir_with_data: str 122 | ) -> DocLLMPreTrainDataConfig: 123 | pretraining_config = pretraining_config.model_copy() 124 | pretraining_config.directory = input_dir_with_data 125 | return pretraining_config 126 | -------------------------------------------------------------------------------- /test/data/pretraining/test_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from docllm.data import DocLLMPreTrainDataConfig 4 | from docllm.data.pretraining.config import NumMaskedBlocksType 5 | 6 | 7 | def test_instantiation(pretraining_config: DocLLMPreTrainDataConfig): 8 | assert isinstance(pretraining_config, DocLLMPreTrainDataConfig) 9 | 10 | 11 | def test_num_masked_blocks_callable_callable_produces_values_defined_in_config( 12 | pretraining_config: DocLLMPreTrainDataConfig, num_blocks: int 13 | ): 14 | num_masked_blocks_callable = pretraining_config.num_masked_blocks_callable 15 | if isinstance(pretraining_config.num_masked_blocks, int): 16 | assert num_masked_blocks_callable(num_blocks) == pretraining_config.num_masked_blocks 17 | elif isinstance(pretraining_config.num_masked_blocks, float): 18 | assert num_masked_blocks_callable(num_blocks) == math.ceil(num_blocks * pretraining_config.num_masked_blocks) 19 | elif isinstance(pretraining_config.num_masked_blocks, tuple): 20 | lower = pretraining_config.num_masked_blocks[0] 21 | upper = pretraining_config.num_masked_blocks[1] - 1 22 | if isinstance(pretraining_config.num_masked_blocks[0], float): 23 | lower = math.ceil(num_blocks * lower) 24 | upper = math.ceil(num_blocks * (upper + 1)) 25 | for _ in range(10): 26 | assert lower <= num_masked_blocks_callable(num_blocks) <= upper 27 | else: 28 | raise ValueError("Invalid num_masked_blocks type") 29 | -------------------------------------------------------------------------------- /test/data/pretraining/test_data_packing_pipe.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | from unittest.mock import MagicMock 4 | 5 | import pytest 6 | import torch 7 | from torch.utils.data import IterDataPipe 8 | 9 | from docllm.data.pretraining.data_packing_pipe import DataPackingPipe 10 | 11 | 12 | @pytest.fixture 13 | def test_data( 14 | num_docs: int, 15 | range_block_size: Tuple[int, int], 16 | ) -> List[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]]: 17 | max_token = 1111 18 | return [ 19 | ( 20 | torch.randint(max_token, (b := random.randint(range_block_size[0], range_block_size[1]),)), 21 | torch.rand(b, 4), 22 | torch.ones(b, dtype=torch.bool), 23 | torch.randint(max_token, (b,)), 24 | ) 25 | for _ in range(num_docs) 26 | ] 27 | 28 | 29 | @pytest.fixture 30 | def min_test_data_sequence_length( 31 | test_data: List[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, 
torch.LongTensor]], 32 | ) -> int: 33 | return min(t.size(0) for t, _, _, _ in test_data) 34 | 35 | 36 | @pytest.fixture 37 | def source_datapipe( 38 | test_data: List[Tuple[torch.LongTensor, torch.FloatTensor, torch.BoolTensor, torch.LongTensor]], 39 | ) -> IterDataPipe: 40 | pipe = MagicMock(spec=IterDataPipe) 41 | pipe.__iter__.return_value = test_data 42 | return pipe 43 | 44 | 45 | @pytest.fixture 46 | def data_packing_pipe( 47 | source_datapipe: IterDataPipe, 48 | max_sequence_length: int, 49 | ) -> DataPackingPipe: 50 | return DataPackingPipe(source_datapipe, max_sequence_length) 51 | 52 | 53 | @pytest.mark.parametrize("max_sequence_length", (80, 150, 1000), indirect=True) 54 | def test_instantiation(data_packing_pipe: DataPackingPipe): 55 | assert isinstance(data_packing_pipe, DataPackingPipe) 56 | 57 | 58 | @pytest.mark.parametrize("max_sequence_length", (10000,), indirect=True) 59 | def test_datapipe_produces_single_output_for_long_sequence_length(data_packing_pipe: DataPackingPipe): 60 | assert len(list(data_packing_pipe)) == 1 61 | 62 | 63 | @pytest.mark.parametrize("max_sequence_length", (80, 150), indirect=True) 64 | def test_output_tensors_are_at_most_max_sequence_length_long( 65 | data_packing_pipe: DataPackingPipe, max_sequence_length: int 66 | ): 67 | assert all(t.size(0) <= max_sequence_length for results in data_packing_pipe for t in results) 68 | 69 | 70 | @pytest.mark.parametrize("max_sequence_length", (80, 150), indirect=True) 71 | def test_output_tensors_are_longer_than_min_test_sequence_length_long( 72 | data_packing_pipe: DataPackingPipe, min_test_data_sequence_length: int 73 | ): 74 | assert all(min_test_data_sequence_length < t.size(0) for results in data_packing_pipe for t in results) 75 | -------------------------------------------------------------------------------- /test/data/pretraining/test_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | import torch 6 | 7 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 8 | from docllm.data.pretraining.dataset import DocLLMPretrainDataset 9 | from docllm.data.pretraining.masker import DocLLMPretrainingMasker 10 | 11 | 12 | def test_initialization(dataset: DocLLMPretrainDataset): 13 | assert isinstance(dataset, DocLLMPretrainDataset) 14 | 15 | 16 | def test_has_expected_length( 17 | dataset: DocLLMPretrainDataset, multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]] 18 | ): 19 | assert len(dataset) == len(multiple_docs_test_data) 20 | 21 | 22 | def test_produces_expected_number_of_tensors(dataset: DocLLMPretrainDataset): 23 | for i in range(len(dataset)): 24 | assert len(dataset[i]) == 4 25 | 26 | 27 | def test_produces_data_with_expected_length( 28 | dataset: DocLLMPretrainDataset, 29 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 30 | dummy_masker: DocLLMPretrainingMasker, 31 | ): 32 | dataset._masker = dummy_masker 33 | for i in range(len(dataset)): 34 | inputs, bboxes, _, _ = dataset[i] 35 | expected_inputs, expected_bboxes = zip(*multiple_docs_test_data[i]) 36 | assert len(inputs) == sum(len(block) for block in expected_inputs) 37 | assert len(bboxes) == sum(len(block) for block in expected_bboxes) 38 | 39 | 40 | def test_produces_expected_data( 41 | dataset: DocLLMPretrainDataset, 42 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 43 | dummy_masker: 
DocLLMPretrainingMasker, 44 | ): 45 | dataset._masker = dummy_masker 46 | for i in range(len(dataset)): 47 | inputs, bboxes, _, _ = dataset[i] 48 | expected_inputs, expected_bboxes = zip(*multiple_docs_test_data[i]) 49 | assert (inputs == torch.cat(expected_inputs)).all() 50 | assert (bboxes == torch.cat(expected_bboxes)).all() 51 | 52 | 53 | def test_input_tokens_have_expected_type(dataset: DocLLMPretrainDataset): 54 | for i in range(len(dataset)): 55 | inputs, _, _, _ = dataset[i] 56 | assert isinstance(inputs, torch.LongTensor) 57 | 58 | 59 | def test_input_bboxes_have_expected_type(dataset: DocLLMPretrainDataset): 60 | for i in range(len(dataset)): 61 | _, bboxes, _, _ = dataset[i] 62 | assert isinstance(bboxes, torch.FloatTensor) 63 | 64 | 65 | def test_loss_mask_has_expected_type(dataset: DocLLMPretrainDataset): 66 | for i in range(len(dataset)): 67 | _, _, mask, _ = dataset[i] 68 | assert isinstance(mask, torch.BoolTensor) 69 | 70 | 71 | def test_labels_have_expected_type(dataset: DocLLMPretrainDataset): 72 | for i in range(len(dataset)): 73 | _, _, _, labels = dataset[i] 74 | assert isinstance(labels, torch.LongTensor) 75 | 76 | 77 | def test_input_tokens_have_expected_shape(dataset: DocLLMPretrainDataset): 78 | for i in range(len(dataset)): 79 | inputs, _, _, _ = dataset[i] 80 | assert len(inputs.shape) == 1 81 | 82 | 83 | def test_input_bboxes_have_expected_shape(dataset: DocLLMPretrainDataset): 84 | for i in range(len(dataset)): 85 | _, bboxes, _, _ = dataset[i] 86 | assert len(bboxes.shape) == 2 87 | assert bboxes.size(1) == 4 88 | 89 | 90 | def test_loss_mask_has_expected_shape(dataset: DocLLMPretrainDataset): 91 | for i in range(len(dataset)): 92 | _, _, mask, _ = dataset[i] 93 | assert len(mask.shape) == 1 94 | 95 | 96 | def test_labels_have_expected_shape(dataset: DocLLMPretrainDataset): 97 | for i in range(len(dataset)): 98 | _, _, _, labels = dataset[i] 99 | assert len(labels.shape) == 1 100 | 101 | 102 | def test_all_sequence_lengths_match(dataset: DocLLMPretrainDataset): 103 | for i in range(len(dataset)): 104 | inputs, bboxes, mask, labels = dataset[i] 105 | assert len(inputs) == len(bboxes) == len(mask) == len(labels) 106 | 107 | 108 | @pytest.fixture 109 | def dataset(pretraining_config_with_input_dir: DocLLMPreTrainDataConfig) -> DocLLMPretrainDataset: 110 | return DocLLMPretrainDataset(pretraining_config_with_input_dir) 111 | 112 | 113 | @pytest.fixture 114 | def dummy_masker() -> DocLLMPretrainingMasker: 115 | def no_mask( 116 | input_tensors: List[torch.LongTensor], bbox_tensors: List[torch.FloatTensor] 117 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor, torch.LongTensor]: 118 | inputs = torch.cat(input_tensors) 119 | bboxes = torch.cat(bbox_tensors) 120 | loss_mask = torch.ones_like(inputs, dtype=torch.bool) 121 | labels = torch.randint(0, 10, (inputs.size(0),), dtype=torch.long) 122 | return inputs, bboxes, loss_mask, labels 123 | 124 | masker = MagicMock(spec=DocLLMPretrainingMasker) 125 | masker.side_effect = no_mask 126 | return masker 127 | -------------------------------------------------------------------------------- /test/data/pretraining/test_masker.py: -------------------------------------------------------------------------------- 1 | import random 2 | from test.util import extract_return_value 3 | from typing import List, Tuple 4 | from unittest.mock import MagicMock 5 | 6 | import pytest 7 | import torch 8 | 9 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 10 | from docllm.data.pretraining.masker import 
DocLLMPretrainingMasker 11 | 12 | 13 | def test_instantiation(docllm_pretrain_masker: DocLLMPretrainingMasker): 14 | assert isinstance(docllm_pretrain_masker, DocLLMPretrainingMasker) 15 | 16 | 17 | def test_produces_four_tuple( 18 | input_tokens: List[torch.LongTensor], 19 | bbox_tokens: List[torch.FloatTensor], 20 | docllm_pretrain_masker: DocLLMPretrainingMasker, 21 | ): 22 | res = docllm_pretrain_masker(input_tokens, bbox_tokens) 23 | assert isinstance(res, tuple) and len(res) == 4 24 | 25 | 26 | def test_all_tuple_entries_are_tensors( 27 | input_tokens: List[torch.LongTensor], 28 | bbox_tokens: List[torch.FloatTensor], 29 | docllm_pretrain_masker: DocLLMPretrainingMasker, 30 | ): 31 | res = docllm_pretrain_masker(input_tokens, bbox_tokens) 32 | assert all(isinstance(tensor, torch.Tensor) for tensor in res) 33 | 34 | 35 | def test_input_tokens_are_long_tensor( 36 | input_tokens: List[torch.LongTensor], 37 | bbox_tokens: List[torch.FloatTensor], 38 | docllm_pretrain_masker: DocLLMPretrainingMasker, 39 | ): 40 | input_tokens, _, _, _ = docllm_pretrain_masker(input_tokens, bbox_tokens) 41 | assert isinstance(input_tokens, torch.LongTensor) 42 | 43 | 44 | def test_bounding_boxes_are_float_tensor( 45 | input_tokens: List[torch.LongTensor], 46 | bbox_tokens: List[torch.FloatTensor], 47 | docllm_pretrain_masker: DocLLMPretrainingMasker, 48 | ): 49 | _, bounding_boxes, _, _ = docllm_pretrain_masker(input_tokens, bbox_tokens) 50 | assert isinstance(bounding_boxes, torch.FloatTensor) 51 | 52 | 53 | def test_loss_mask_is_bool_tensor( 54 | input_tokens: List[torch.LongTensor], 55 | bbox_tokens: List[torch.FloatTensor], 56 | docllm_pretrain_masker: DocLLMPretrainingMasker, 57 | ): 58 | _, _, loss_mask, _ = docllm_pretrain_masker(input_tokens, bbox_tokens) 59 | assert isinstance(loss_mask, torch.BoolTensor) 60 | 61 | 62 | def test_label_is_long_tensor( 63 | input_tokens: List[torch.LongTensor], 64 | bbox_tokens: List[torch.FloatTensor], 65 | docllm_pretrain_masker: DocLLMPretrainingMasker, 66 | ): 67 | _, _, _, label = docllm_pretrain_masker(input_tokens, bbox_tokens) 68 | assert isinstance(label, torch.LongTensor) 69 | 70 | 71 | def test_get_mask_indices(): 72 | train_data_pipe = DocLLMPretrainingMasker(MagicMock(spec=DocLLMPreTrainDataConfig)) 73 | num_blocks = 10 74 | num_masks = 4 75 | for _ in range(20): 76 | result = train_data_pipe._get_mask_indices(num_blocks, num_masks) 77 | assert len(result) == num_masks 78 | assert all(0 <= index < num_blocks for index in result) 79 | 80 | 81 | @pytest.fixture 82 | def docllm_pretrain_masker(pretraining_config: DocLLMPreTrainDataConfig) -> DocLLMPretrainingMasker: 83 | return DocLLMPretrainingMasker(pretraining_config) 84 | 85 | 86 | @pytest.fixture 87 | def input_tokens(test_data: Tuple[List[torch.LongTensor], List[torch.FloatTensor]]) -> List[torch.LongTensor]: 88 | return test_data[0] 89 | 90 | 91 | @pytest.fixture 92 | def bbox_tokens(test_data: Tuple[List[torch.LongTensor], List[torch.FloatTensor]]) -> List[torch.FloatTensor]: 93 | return test_data[1] 94 | 95 | 96 | @pytest.fixture 97 | def test_data( 98 | num_blocks: int, 99 | range_block_size: Tuple[int, int], 100 | ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor]]: 101 | max_token = 1111 102 | return tuple( 103 | map( 104 | list, 105 | zip( 106 | *( 107 | ( 108 | torch.randint(max_token, (b := random.randint(range_block_size[0], range_block_size[1]),)), 109 | torch.rand(b, 4), 110 | ) 111 | for _ in range(num_blocks) 112 | ) 113 | ), 114 | ) 115 | ) 116 | 
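For orientation, here is a minimal sketch of how the masker under test is driven. The configuration values mirror the shared pretraining_config fixture in conftest.py; the block sizes and token values are arbitrary illustrative data.

# Illustrative sketch, not part of the test module. Config values mirror the shared
# pretraining_config fixture; block contents are random.
import torch

from docllm.data.pretraining.config import DocLLMPreTrainDataConfig
from docllm.data.pretraining.masker import DocLLMPretrainingMasker

config = DocLLMPreTrainDataConfig(
    batch_size=1,
    use_sharding_filter=False,
    max_seq_length=128,
    num_masked_blocks=1,
    max_percentage_masked_blocks=0.8,
    mask_text_token=0,
    mask_bbox_token=(0.0, 0.0, 0.0, 0.0),
    block_start_text_token=1337,
    block_start_bbox_token=(0.0, 0.0, 0.0, 0.0),
    block_end_text_token=1338,
    bos_text_token=1,
    bos_bbox_token=(0.0, 0.0, 0.0, 0.0),
    directory="",
)
masker = DocLLMPretrainingMasker(config)

# Four "blocks" of token ids with one bounding box per token (as in the test_data fixture).
input_tokens = [torch.randint(1, 1111, (n,)) for n in (3, 5, 2, 4)]
bbox_tokens = [torch.rand(n, 4) for n in (3, 5, 2, 4)]

# The masker flattens the blocks and returns input ids, bounding boxes, a loss mask
# and labels, all with the same sequence length.
inputs, bboxes, loss_mask, labels = masker(input_tokens, bbox_tokens)
assert inputs.shape[0] == bboxes.shape[0] == loss_mask.shape[0] == labels.shape[0]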
-------------------------------------------------------------------------------- /test/data/pretraining/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Tuple 3 | 4 | import pytest 5 | import torch 6 | from torch.utils.data import DataLoader, IterDataPipe 7 | 8 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 9 | from docllm.data.pretraining.pipeline import build_docllm_datapipeline 10 | 11 | 12 | @pytest.fixture 13 | def num_batches(num_docs: int, batch_size: int) -> int: 14 | return num_docs // batch_size + (1 if num_docs % batch_size else 0) 15 | 16 | 17 | @pytest.fixture 18 | def pipeline(pretraining_config_with_input_dir: DocLLMPreTrainDataConfig) -> IterDataPipe: 19 | return build_docllm_datapipeline(pretraining_config_with_input_dir) 20 | 21 | 22 | def test_initialization(pipeline: IterDataPipe): 23 | assert isinstance(pipeline, IterDataPipe) 24 | 25 | 26 | def test_datapipe_produces_expected_number_of_document_results(pipeline: IterDataPipe, num_docs: int): 27 | # FIXME fails sometimes because sequence length is to long 28 | assert len(list(pipeline)) == num_docs 29 | 30 | 31 | @pytest.mark.parametrize("batch_size", (2, 4, 7), indirect=True) 32 | def test_datapipe_produces_expected_number_of_document_results_with_varying_batch_sizes( 33 | pipeline: IterDataPipe, num_batches: int 34 | ): 35 | assert len(list(pipeline)) == num_batches 36 | 37 | 38 | def test_datapipe_produces_result_list_with_batch_size_entries(pipeline: IterDataPipe): 39 | assert all(isinstance(res, list) and len(res) == 4 for res in list(pipeline)[:-1]) 40 | 41 | 42 | def test_datapipe_first_list_entry_is_long_tensor(pipeline: IterDataPipe): 43 | assert all(isinstance(res[0], torch.LongTensor) for res in pipeline) 44 | 45 | 46 | def test_datapipe_second_list_entry_is_float_tensor(pipeline: IterDataPipe): 47 | assert all(isinstance(res[1], torch.FloatTensor) for res in pipeline) 48 | 49 | 50 | def test_datapipe_third_list_entry_is_bool_tensor(pipeline: IterDataPipe): 51 | assert all(isinstance(res[2], torch.BoolTensor) for res in pipeline) 52 | 53 | 54 | @pytest.mark.parametrize("batch_size", (2, 4, 7), indirect=True) 55 | def test_datapipe_tensors_have_length_batch_size( 56 | pipeline: IterDataPipe, 57 | batch_size: int, 58 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 59 | ): 60 | for batch_start_idx, res in zip(range(0, len(multiple_docs_test_data), batch_size), pipeline): 61 | effective_batch_size = min(batch_size, len(multiple_docs_test_data) - batch_start_idx) 62 | assert all(r.size(0) == effective_batch_size for r in res) 63 | 64 | 65 | def test_datapipe_in_dataloader_loads_expected_number_of_files( 66 | pipeline: IterDataPipe, 67 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 68 | ): 69 | dataloader = DataLoader(pipeline, batch_size=None, num_workers=0) 70 | assert len(list(dataloader)) == len(multiple_docs_test_data) 71 | 72 | 73 | @pytest.mark.parametrize("use_sharding_filter", (True,), indirect=True) 74 | def test_datapipe_in_dataloader_with_multiple_workers_loads_expected_number_of_files( 75 | pipeline: IterDataPipe, 76 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 77 | ): 78 | dataloader = DataLoader(pipeline, batch_size=None, num_workers=4) 79 | assert len(list(dataloader)) == len(multiple_docs_test_data) 80 | 81 | 82 | @pytest.mark.parametrize("use_sharding_filter,batch_size", 
((True, 2), (True, 4)), indirect=True) 83 | def test_datapipe_with_batch_size_in_dataloader_with_single_workers_loads_expected_number_of_files( 84 | pipeline: IterDataPipe, 85 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 86 | batch_size: int, 87 | ): 88 | dataloader = DataLoader(pipeline, batch_size=None, num_workers=0) 89 | data = list(dataloader) 90 | assert len(data) == math.ceil(len(multiple_docs_test_data) / batch_size) 91 | 92 | 93 | @pytest.mark.parametrize("use_sharding_filter,batch_size", ((True, 2), (True, 4)), indirect=True) 94 | def test_datapipe_with_batch_size_in_dataloader_with_multiple_workers_loads_expected_number_of_files( 95 | pipeline: IterDataPipe, 96 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 97 | batch_size: int, 98 | ): 99 | dataloader = DataLoader(pipeline, batch_size=None, num_workers=2) 100 | data = list(dataloader) 101 | # Note: More batches than expected are produced because multiple workers produce multiple partial batches 102 | assert len(data) == math.ceil(len(multiple_docs_test_data) / batch_size) + 1 103 | -------------------------------------------------------------------------------- /test/data/pretraining/test_precomputation_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tempfile import TemporaryDirectory 3 | from typing import Iterable, List, Tuple 4 | 5 | import pytest 6 | import torch 7 | 8 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 9 | from docllm.data.pretraining.precomputation_pipeline import build_precomputed_data_pipeline, precompute_data 10 | 11 | 12 | @pytest.fixture 13 | def output_dir() -> Iterable[str]: 14 | with TemporaryDirectory() as output_dir: 15 | yield output_dir 16 | 17 | 18 | @pytest.fixture 19 | def precomputation_config( 20 | input_dir_with_data: str, pretraining_config: DocLLMPreTrainDataConfig 21 | ) -> DocLLMPreTrainDataConfig: 22 | config = pretraining_config.model_copy( 23 | update={ 24 | "batch_size": 1, 25 | "drop_last_batch_if_not_full": False, 26 | "shuffle": False, 27 | "use_sharding_filter": False, 28 | "directory": input_dir_with_data, 29 | } 30 | ) 31 | return config 32 | 33 | 34 | def test_precompute_data_produces_expected_number_of_files( 35 | precomputation_config: DocLLMPreTrainDataConfig, output_dir: str, num_docs: int 36 | ): 37 | precompute_data(precomputation_config, output_dir) 38 | assert len(os.listdir(output_dir)) == num_docs 39 | 40 | 41 | def test_precompute_data_produces_expected_files( 42 | precomputation_config: DocLLMPreTrainDataConfig, output_dir: str, num_docs: int 43 | ): 44 | precompute_data(precomputation_config, output_dir) 45 | for i in range(num_docs): 46 | assert os.path.exists(os.path.join(output_dir, f"eval_{i}.pt")) 47 | 48 | 49 | def test_loading_precomputed_files_yields_expected_number_of_files( 50 | precomputation_config: DocLLMPreTrainDataConfig, output_dir: str, num_docs: int 51 | ): 52 | precompute_data(precomputation_config, output_dir) 53 | datapipe = build_precomputed_data_pipeline(output_dir) 54 | assert len(list(datapipe)) == num_docs 55 | 56 | 57 | def test_loaded_files_have_consistent_shape( 58 | precomputation_config: DocLLMPreTrainDataConfig, output_dir: str, num_docs: int 59 | ): 60 | precompute_data(precomputation_config, output_dir) 61 | datapipe = build_precomputed_data_pipeline(output_dir) 62 | for i, data in enumerate(datapipe): 63 | assert len(data) == 4 64 | assert len(data[0]) == len(data[1]) 
== len(data[2]) == len(data[3]) 65 | for j in range(len(data[0])): 66 | assert data[0][j].shape[0] == data[1][j].shape[0] 67 | assert data[0][j].shape[0] == data[2][j].shape[0] 68 | assert data[0][j].shape[0] == data[3][j].shape[0] 69 | 70 | 71 | def test_loaded_files_have_original_text_tokens_where_not_masked( 72 | precomputation_config: DocLLMPreTrainDataConfig, 73 | output_dir: str, 74 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 75 | ): 76 | precomputation_config.use_packing = False 77 | precompute_data(precomputation_config, output_dir) 78 | datapipe = build_precomputed_data_pipeline(output_dir) 79 | for i, data in enumerate(datapipe): 80 | pipe_data_idx = 1 81 | for block in multiple_docs_test_data[i]: 82 | if data[0][0][pipe_data_idx] == precomputation_config.mask_text_token: 83 | pipe_data_idx += 1 84 | else: 85 | block_size = len(block[0]) 86 | new_idx = pipe_data_idx + block_size 87 | assert torch.equal(data[0][0][pipe_data_idx:new_idx], block[0]) 88 | pipe_data_idx = new_idx 89 | 90 | 91 | def test_loaded_files_have_original_bbox_tokens_where_not_masked( 92 | precomputation_config: DocLLMPreTrainDataConfig, 93 | output_dir: str, 94 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 95 | ): 96 | precomputation_config.use_packing = False 97 | precompute_data(precomputation_config, output_dir) 98 | datapipe = build_precomputed_data_pipeline(output_dir) 99 | for i, data in enumerate(datapipe): 100 | pipe_data_idx = 1 101 | for block in multiple_docs_test_data[i]: 102 | if data[0][0][pipe_data_idx] == precomputation_config.mask_text_token: 103 | pipe_data_idx += 1 104 | else: 105 | block_size = len(block[0]) 106 | new_idx = pipe_data_idx + block_size 107 | assert torch.equal(data[1][0][pipe_data_idx:new_idx], block[1]) 108 | pipe_data_idx = new_idx 109 | 110 | 111 | def test_loaded_files_test_and_bbox_tokens_are_masked_at_same_positions( 112 | precomputation_config: DocLLMPreTrainDataConfig, 113 | output_dir: str, 114 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 115 | ): 116 | precomputation_config.use_packing = False 117 | precompute_data(precomputation_config, output_dir) 118 | datapipe = build_precomputed_data_pipeline(output_dir) 119 | for i, data in enumerate(datapipe): 120 | pipe_data_idx = 1 121 | for block in multiple_docs_test_data[i]: 122 | if data[0][0][pipe_data_idx] == precomputation_config.mask_text_token: 123 | assert torch.equal(data[1][0][pipe_data_idx], torch.tensor(precomputation_config.mask_bbox_token)) 124 | pipe_data_idx += 1 125 | else: 126 | assert not torch.equal(data[1][0][pipe_data_idx], torch.tensor(precomputation_config.mask_bbox_token)) 127 | pipe_data_idx += len(block[0]) 128 | -------------------------------------------------------------------------------- /test/data/pretraining/test_tensor_data_loader_pipe.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import List, Tuple 4 | from unittest.mock import MagicMock 5 | 6 | import pytest 7 | import torch 8 | from torch.utils.data import IterDataPipe 9 | 10 | from docllm.data.pretraining.tensor_data_loader_pipe import TensorDataLoaderPipe 11 | 12 | 13 | @pytest.fixture 14 | def files_datapipe(input_dir_with_data: str) -> IterDataPipe: 15 | pipe = MagicMock(spec=IterDataPipe) 16 | pipe.__iter__.return_value = sorted( 17 | glob.glob(os.path.join(input_dir_with_data, "*.pt")), 18 | key=lambda x: 
int(os.path.basename(x).split("_")[1].split(".")[0]), 19 | ) 20 | return pipe 21 | 22 | 23 | @pytest.fixture 24 | def tensor_data_loader_pipe(files_datapipe: IterDataPipe) -> TensorDataLoaderPipe: 25 | return TensorDataLoaderPipe(files_datapipe) 26 | 27 | 28 | @pytest.fixture 29 | def add_empty_file(input_dir_with_data: str): 30 | torch.save([], f"{input_dir_with_data}/empty_doc.py") 31 | 32 | 33 | def test_initialization(tensor_data_loader_pipe: TensorDataLoaderPipe): 34 | assert isinstance(tensor_data_loader_pipe, TensorDataLoaderPipe) 35 | 36 | 37 | def test_iter_produces_two_tuples(tensor_data_loader_pipe: TensorDataLoaderPipe): 38 | assert all(isinstance(res, tuple) and len(res) == 2 for res in tensor_data_loader_pipe) 39 | 40 | 41 | def test_iter_tuples_consist_of_lists(tensor_data_loader_pipe: TensorDataLoaderPipe): 42 | for res in tensor_data_loader_pipe: 43 | assert all(isinstance(lst, list) for lst in res) 44 | 45 | 46 | def test_iter_first_list_consist_of_long_tensors(tensor_data_loader_pipe: TensorDataLoaderPipe): 47 | for lst1, _ in tensor_data_loader_pipe: 48 | assert all(isinstance(t, torch.LongTensor) for t in lst1) 49 | 50 | 51 | def test_iter_second_list_consist_of_long_tensors(tensor_data_loader_pipe: TensorDataLoaderPipe): 52 | for _, lst2 in tensor_data_loader_pipe: 53 | assert all(isinstance(t, torch.FloatTensor) for t in lst2) 54 | 55 | 56 | def test_iter_matching_tensors_have_same_length(tensor_data_loader_pipe: TensorDataLoaderPipe): 57 | for lst1, lst2 in tensor_data_loader_pipe: 58 | for t1, t2 in zip(lst1, lst2): 59 | assert t1.size(0) == t2.size(0) 60 | 61 | 62 | def test_first_tensors_are_same_as_initial_data( 63 | tensor_data_loader_pipe: TensorDataLoaderPipe, 64 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 65 | ): 66 | for (lst1, _), data in zip(tensor_data_loader_pipe, multiple_docs_test_data): 67 | dat1, _ = zip(*data) 68 | assert all((t1 == d1).all() for t1, d1 in zip(lst1, dat1)) 69 | 70 | 71 | def test_second_tensors_are_same_as_initial_data( 72 | tensor_data_loader_pipe: TensorDataLoaderPipe, 73 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 74 | ): 75 | for (_, lst2), data in zip(tensor_data_loader_pipe, multiple_docs_test_data): 76 | _, dat2 = zip(*data) 77 | assert all((t2 == d2).all() for t2, d2 in zip(lst2, dat2)) 78 | 79 | 80 | @pytest.mark.usefixtures("add_empty_file") 81 | def test_empty_file_is_skipped( 82 | tensor_data_loader_pipe: TensorDataLoaderPipe, 83 | multiple_docs_test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 84 | ): 85 | assert len(list(tensor_data_loader_pipe)) == len(multiple_docs_test_data) 86 | -------------------------------------------------------------------------------- /test/data/pretraining/test_traindata_pipe.py: -------------------------------------------------------------------------------- 1 | import random 2 | from test.util import extract_return_value 3 | from typing import List, Tuple 4 | from unittest.mock import MagicMock 5 | 6 | import pytest 7 | import torch 8 | from torch.utils.data import IterDataPipe 9 | 10 | from docllm.data import DocLLMPreTrainDataConfig, DocLLMTrainDataPipe 11 | from docllm.data.pretraining.config import DocLLMPreTrainDataConfig 12 | from docllm.data.pretraining.traindata_pipe import DocLLMTrainDataPipe 13 | 14 | 15 | @pytest.fixture 16 | def test_data( 17 | num_docs: int, 18 | num_blocks: int, 19 | range_block_size: Tuple[int, int], 20 | ) -> List[List[Tuple[torch.LongTensor, 
torch.FloatTensor]]]: 21 | max_token = 1111 22 | return [ 23 | [ 24 | ( 25 | torch.randint(max_token, (b := random.randint(range_block_size[0], range_block_size[1]),)), 26 | torch.rand(b, 4), 27 | ) 28 | for _ in range(num_blocks) 29 | ] 30 | for _ in range(num_docs) 31 | ] 32 | 33 | 34 | @pytest.fixture 35 | def source_datapipe(test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]]) -> IterDataPipe: 36 | def split_list( 37 | lst: List[Tuple[torch.LongTensor, torch.FloatTensor]], 38 | ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor]]: 39 | l1, l2 = zip(*lst) 40 | return list(l1), list(l2) 41 | 42 | pipe = MagicMock(spec=IterDataPipe) 43 | pipe.__iter__.return_value = map(split_list, test_data) 44 | return pipe 45 | 46 | 47 | @pytest.fixture 48 | def docllm_train_data_pipe( 49 | source_datapipe: IterDataPipe, pretraining_config: DocLLMPreTrainDataConfig 50 | ) -> DocLLMTrainDataPipe: 51 | return DocLLMTrainDataPipe(source_datapipe, pretraining_config) 52 | 53 | 54 | @pytest.fixture 55 | def num_tokens_in_docs(test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]]) -> List[int]: 56 | return [sum(t.size(0) for t, _ in doc_data) for doc_data in test_data] 57 | 58 | 59 | def test_instantiation(docllm_train_data_pipe: DocLLMTrainDataPipe): 60 | assert isinstance(docllm_train_data_pipe, DocLLMTrainDataPipe) 61 | 62 | 63 | def test_iter_produces_three_tuples(docllm_train_data_pipe: DocLLMTrainDataPipe): 64 | assert all(isinstance(data, tuple) and len(data) == 4 for data in docllm_train_data_pipe) 65 | 66 | 67 | def test_all_tuple_entries_are_tensors(docllm_train_data_pipe: DocLLMTrainDataPipe): 68 | assert all(isinstance(tensor, torch.Tensor) for data in docllm_train_data_pipe for tensor in data) 69 | 70 | 71 | def test_first_tuple_entry_is_long_tensor(docllm_train_data_pipe: DocLLMTrainDataPipe): 72 | assert all(isinstance(t1, torch.LongTensor) for t1, _, _, _ in docllm_train_data_pipe) 73 | 74 | 75 | def test_second_tuple_entry_is_float_tensor(docllm_train_data_pipe: DocLLMTrainDataPipe): 76 | assert all(isinstance(t2, torch.FloatTensor) for _, t2, _, _ in docllm_train_data_pipe) 77 | 78 | 79 | def test_third_tuple_entry_is_bool_tensor(docllm_train_data_pipe: DocLLMTrainDataPipe): 80 | assert all(isinstance(t3, torch.BoolTensor) for _, _, t3, _ in docllm_train_data_pipe) 81 | 82 | 83 | def test_fourth_tuple_entry_is_long_tensor(docllm_train_data_pipe: DocLLMTrainDataPipe): 84 | assert all(isinstance(t4, torch.LongTensor) for _, _, _, t4 in docllm_train_data_pipe) 85 | 86 | 87 | def test_produces_correct_input_tensor_shapes( 88 | docllm_train_data_pipe: DocLLMTrainDataPipe, 89 | num_tokens_in_docs: List[int], 90 | pretraining_config: DocLLMPreTrainDataConfig, 91 | ): 92 | num_masks_values, docllm_train_data_pipe._masker._get_num_masks = extract_return_value( 93 | docllm_train_data_pipe._masker._get_num_masks 94 | ) 95 | count_bos_tokens = 1 if pretraining_config.bos_text_token else 0 96 | for (input_tensor, _, _, _), num_tokens, num_masks in zip( 97 | docllm_train_data_pipe, num_tokens_in_docs, num_masks_values 98 | ): 99 | assert input_tensor.shape == (count_bos_tokens + num_tokens + 2 * num_masks,) 100 | 101 | 102 | def test_produces_correct_spatial_input_tensor_shapes( 103 | docllm_train_data_pipe: DocLLMTrainDataPipe, 104 | num_tokens_in_docs: List[int], 105 | pretraining_config: DocLLMPreTrainDataConfig, 106 | ): 107 | num_masks_values, docllm_train_data_pipe._masker._get_num_masks = extract_return_value( 108 | docllm_train_data_pipe._masker._get_num_masks 109 
| ) 110 | count_bos_tokens = 1 if pretraining_config.bos_text_token else 0 111 | for (_, spatial_input, _, _), num_tokens, num_masks in zip( 112 | docllm_train_data_pipe, num_tokens_in_docs, num_masks_values 113 | ): 114 | assert spatial_input.shape == (count_bos_tokens + num_tokens + 2 * num_masks, 4) 115 | 116 | 117 | def test_produces_correct_loss_mask_shapes( 118 | docllm_train_data_pipe: DocLLMTrainDataPipe, 119 | num_tokens_in_docs: List[int], 120 | pretraining_config: DocLLMPreTrainDataConfig, 121 | ): 122 | num_masks_values, docllm_train_data_pipe._masker._get_num_masks = extract_return_value( 123 | docllm_train_data_pipe._masker._get_num_masks 124 | ) 125 | count_bos_tokens = 1 if pretraining_config.bos_text_token else 0 126 | for (_, _, loss_mask, _), num_tokens, num_masks in zip( 127 | docllm_train_data_pipe, num_tokens_in_docs, num_masks_values 128 | ): 129 | assert loss_mask.shape == (count_bos_tokens + num_tokens + 2 * num_masks,) 130 | 131 | 132 | def test_produces_correct_label_tensor_shapes( 133 | docllm_train_data_pipe: DocLLMTrainDataPipe, 134 | num_tokens_in_docs: List[int], 135 | pretraining_config: DocLLMPreTrainDataConfig, 136 | ): 137 | num_masks_values, docllm_train_data_pipe._masker._get_num_masks = extract_return_value( 138 | docllm_train_data_pipe._masker._get_num_masks 139 | ) 140 | count_bos_tokens = 1 if pretraining_config.bos_text_token else 0 141 | for (_, _, _, label_tensor), num_tokens, num_masks in zip( 142 | docllm_train_data_pipe, num_tokens_in_docs, num_masks_values 143 | ): 144 | assert label_tensor.shape == (count_bos_tokens + num_tokens + 2 * num_masks,) 145 | 146 | 147 | def test_produces_correct_loss_mask_values(docllm_train_data_pipe: DocLLMTrainDataPipe): 148 | extraction_results, docllm_train_data_pipe._masker._extract_masked_tokens = extract_return_value( 149 | docllm_train_data_pipe._masker._extract_masked_tokens 150 | ) 151 | for (_, _, loss_mask, _), (_, _, target_tokens, _) in zip(docllm_train_data_pipe, extraction_results): 152 | len_masked_to_one = sum(t.size(0) for t in target_tokens) 153 | assert not loss_mask[:-len_masked_to_one].any() and loss_mask[-len_masked_to_one:].all() 154 | 155 | 156 | @pytest.mark.parametrize("num_masked_blocks,max_percentage_masked_blocks", [(1.0, 1.0)], indirect=True) 157 | def test_mask_whole_sequence_works_for_text_tokens( 158 | docllm_train_data_pipe: DocLLMTrainDataPipe, 159 | pretraining_config: DocLLMPreTrainDataConfig, 160 | test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 161 | ): 162 | for dat, (t1, _, _, _) in zip(test_data, docllm_train_data_pipe): 163 | assert (t1[1 : len(dat)] == pretraining_config.mask_text_token).all() 164 | 165 | 166 | @pytest.mark.parametrize("num_masked_blocks,max_percentage_masked_blocks", [(1.0, 1.0)], indirect=True) 167 | def test_mask_whole_sequence_works_for_box_tokens( 168 | docllm_train_data_pipe: DocLLMTrainDataPipe, 169 | pretraining_config: DocLLMPreTrainDataConfig, 170 | test_data: List[List[Tuple[torch.LongTensor, torch.FloatTensor]]], 171 | ): 172 | mask_bbox_token = torch.tensor(pretraining_config.mask_bbox_token) 173 | for dat, (_, t2, _, _) in zip(test_data, docllm_train_data_pipe): 174 | assert all((box == mask_bbox_token).all() for box in t2[1 : len(dat)]) 175 | -------------------------------------------------------------------------------- /test/modules/llama/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueCrescent/DocLLM/0f9b3b3fefbe983c3c9d1a5310d6adc18db366a7/test/modules/llama/__init__.py -------------------------------------------------------------------------------- /test/modules/llama/conftest.py: -------------------------------------------------------------------------------- 1 | from test.modules.llama.helpers import InputSizes, ModelInputs 2 | from typing import Tuple 3 | 4 | import pytest 5 | import torch 6 | from transformers.cache_utils import Cache, DynamicCache 7 | 8 | from docllm.modules.llama.config import DocLLMLlamaConfig 9 | 10 | 11 | @pytest.fixture 12 | def config() -> DocLLMLlamaConfig: 13 | return DocLLMLlamaConfig() 14 | 15 | 16 | @pytest.fixture 17 | def small_config() -> DocLLMLlamaConfig: 18 | return DocLLMLlamaConfig( 19 | hidden_size=256, 20 | num_attention_heads=4, 21 | num_hidden_layers=2, 22 | intermediate_size=1024, 23 | vocab_size=320, 24 | additional_training_vocab_size=3, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def input_sizes() -> InputSizes: 30 | return InputSizes(batch_size=2, sequence_length=10) 31 | 32 | 33 | @pytest.fixture 34 | def model_inputs(input_sizes, config) -> ModelInputs: 35 | return ModelInputs( 36 | input_embeddings=torch.randn(input_sizes.batch_size, input_sizes.sequence_length, config.hidden_size), 37 | spatial_embeddings=torch.randn(input_sizes.batch_size, input_sizes.sequence_length, config.hidden_size), 38 | ) 39 | 40 | 41 | @pytest.fixture 42 | def attention_mask(input_sizes: InputSizes) -> torch.Tensor: 43 | return torch.ones(input_sizes.batch_size, 1, input_sizes.sequence_length, input_sizes.sequence_length).bool() 44 | 45 | 46 | @pytest.fixture 47 | def position_ids(input_sizes: InputSizes) -> torch.LongTensor: 48 | return torch.arange(input_sizes.sequence_length).unsqueeze(0).expand(input_sizes.batch_size, -1) 49 | 50 | 51 | @pytest.fixture 52 | def position_embeddings(model_inputs: ModelInputs, config: DocLLMLlamaConfig) -> Tuple[torch.Tensor, torch.Tensor]: 53 | return model_inputs.compute_positional_embeddings(config) 54 | 55 | 56 | @pytest.fixture 57 | def past_key_value() -> Cache: 58 | return DynamicCache() 59 | 60 | 61 | @pytest.fixture 62 | def spatial_past_key_value() -> Cache: 63 | return DynamicCache() 64 | -------------------------------------------------------------------------------- /test/modules/llama/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from pydantic import BaseModel, ConfigDict 5 | from transformers import LlamaConfig 6 | from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding 7 | 8 | 9 | class InputSizes(BaseModel): 10 | batch_size: int 11 | sequence_length: int 12 | 13 | 14 | class ModelInputs(BaseModel): 15 | model_config = ConfigDict(arbitrary_types_allowed=True) 16 | 17 | input_embeddings: torch.Tensor 18 | spatial_embeddings: torch.Tensor 19 | 20 | def compute_positional_embeddings(self, config: LlamaConfig) -> Tuple[torch.Tensor, torch.Tensor]: 21 | cache_position = torch.arange(0, self.input_embeddings.shape[1], device=self.input_embeddings.device) 22 | position_ids = cache_position.unsqueeze(0) 23 | rotary_emb = LlamaRotaryEmbedding(config=config) 24 | position_embeddings = rotary_emb(self.input_embeddings, position_ids) 25 | return position_embeddings 26 | -------------------------------------------------------------------------------- /test/modules/llama/test_attention.py: 
-------------------------------------------------------------------------------- 1 | from test.modules.llama.helpers import InputSizes, ModelInputs 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers.cache_utils import Cache 7 | 8 | from docllm.modules.llama import DocLLMAttention, DocLLMLlamaConfig 9 | from docllm.modules.llama.config import PositionalEmbeddingMode 10 | 11 | 12 | def test_output_shape_is_input_shape( 13 | config: DocLLMLlamaConfig, model_inputs: ModelInputs, position_embeddings: Tuple[torch.Tensor, torch.Tensor] 14 | ): 15 | attention = DocLLMAttention(config) 16 | output, _, _, _ = attention( 17 | model_inputs.input_embeddings, model_inputs.spatial_embeddings, position_embeddings=position_embeddings 18 | ) 19 | assert output.shape == model_inputs.input_embeddings.shape 20 | 21 | 22 | def test_with_attention_mask( 23 | config: DocLLMLlamaConfig, 24 | model_inputs: ModelInputs, 25 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 26 | attention_mask: torch.Tensor, 27 | ): 28 | attention = DocLLMAttention(config) 29 | output, _, _, _ = attention( 30 | model_inputs.input_embeddings, 31 | model_inputs.spatial_embeddings, 32 | attention_mask=attention_mask, 33 | position_embeddings=position_embeddings, 34 | ) 35 | assert output.shape == model_inputs.input_embeddings.shape 36 | 37 | 38 | def test_with_position_ids( 39 | config: DocLLMLlamaConfig, model_inputs: ModelInputs, position_embeddings: Tuple[torch.Tensor, torch.Tensor] 40 | ): 41 | attention = DocLLMAttention(config) 42 | output, _, _, _ = attention( 43 | model_inputs.input_embeddings, model_inputs.spatial_embeddings, position_embeddings=position_embeddings 44 | ) 45 | assert output.shape == model_inputs.input_embeddings.shape 46 | 47 | 48 | def test_with_cache( 49 | config: DocLLMLlamaConfig, 50 | model_inputs: ModelInputs, 51 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 52 | past_key_value: Cache, 53 | spatial_past_key_value: Cache, 54 | ): 55 | attention = DocLLMAttention(config, 0) 56 | output, _, _, _ = attention( 57 | model_inputs.input_embeddings, 58 | model_inputs.spatial_embeddings, 59 | past_key_value=past_key_value, 60 | spatial_past_key_value=spatial_past_key_value, 61 | position_embeddings=position_embeddings, 62 | ) 63 | assert output.shape == model_inputs.input_embeddings.shape 64 | 65 | 66 | def test_output_shape_is_input_shape_with_all_parameters( 67 | config: DocLLMLlamaConfig, 68 | model_inputs: ModelInputs, 69 | attention_mask: torch.Tensor, 70 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 71 | past_key_value: Cache, 72 | spatial_past_key_value: Cache, 73 | ): 74 | attention = DocLLMAttention(config, layer_idx=0) 75 | output, _, _, _ = attention( 76 | model_inputs.input_embeddings, 77 | model_inputs.spatial_embeddings, 78 | attention_mask=attention_mask, 79 | past_key_value=past_key_value, 80 | spatial_past_key_value=spatial_past_key_value, 81 | position_embeddings=position_embeddings, 82 | ) 83 | assert output.shape == model_inputs.input_embeddings.shape 84 | 85 | 86 | def test_computed_q_k_v_have_expected_shape( 87 | config: DocLLMLlamaConfig, 88 | input_sizes: InputSizes, 89 | model_inputs: ModelInputs, 90 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 91 | past_key_value: Cache, 92 | ): 93 | head_dim = config.hidden_size // config.num_attention_heads 94 | attention = DocLLMAttention(config, layer_idx=0) 95 | q_proj = k_proj = v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * head_dim) 96 | 
q, k, v = attention._compute_q_k_v_for_projections( 97 | model_inputs.input_embeddings, q_proj, k_proj, v_proj, past_key_value, position_embeddings 98 | ) 99 | expected_shape = (input_sizes.batch_size, config.num_attention_heads, input_sizes.sequence_length, head_dim) 100 | assert q.shape == expected_shape 101 | assert k.shape == expected_shape 102 | assert v.shape == expected_shape 103 | 104 | 105 | def test_after_freezing_llama_weights_spatial_layers_are_not_frozen(config: DocLLMLlamaConfig): 106 | attention = DocLLMAttention(config) 107 | attention.set_freeze_llama_layers(True) 108 | for name, param in attention.named_parameters(recurse=True): 109 | assert not param.requires_grad or name.startswith("spatial_") 110 | 111 | 112 | def test_after_unfreezing_llama_weights_everything_is_not_frozen(config: DocLLMLlamaConfig): 113 | attention = DocLLMAttention(config) 114 | attention.set_freeze_llama_layers(True) 115 | attention.set_freeze_llama_layers(False) 116 | for param in attention.parameters(recurse=True): 117 | assert param.requires_grad 118 | -------------------------------------------------------------------------------- /test/modules/llama/test_causal_docllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tempfile import TemporaryDirectory 3 | 4 | import pytest 5 | import torch 6 | from transformers import LlamaForCausalLM 7 | 8 | from docllm.modules.llama.causal_docllm import CausalDocLLMOutputWithPast, CausalLlamaDocLLM 9 | from docllm.modules.llama.config import DocLLMLlamaConfig 10 | 11 | 12 | @pytest.fixture 13 | def small_model(small_config: DocLLMLlamaConfig) -> CausalLlamaDocLLM: 14 | return CausalLlamaDocLLM(small_config) 15 | 16 | 17 | def test_initialization(small_model: CausalLlamaDocLLM): 18 | assert isinstance(small_model, CausalLlamaDocLLM) 19 | 20 | 21 | def test_forward_result_has_expected_type(small_model: CausalLlamaDocLLM): 22 | input_ids = torch.tensor([[1, 2, 3]]) 23 | coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]]) 24 | output = small_model.forward(input_ids, coordinates) 25 | assert isinstance(output, CausalDocLLMOutputWithPast) 26 | 27 | 28 | def test_forward_result_has_expected_logits_shape(small_model: CausalLlamaDocLLM, small_config: DocLLMLlamaConfig): 29 | input_ids = torch.tensor([[1, 2, 3]]) 30 | coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]]) 31 | output = small_model.forward(input_ids, coordinates) 32 | assert output.logits.shape == ( 33 | input_ids.shape[0], 34 | input_ids.shape[1], 35 | small_config.vocab_size + small_config.additional_training_vocab_size, 36 | ) 37 | 38 | 39 | def test_loss_result_has_expected_shape(small_model: CausalLlamaDocLLM): 40 | input_ids = torch.tensor([[1, 2, 3]]) 41 | coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]]) 42 | labels = torch.tensor([[2, 3, 4]]) 43 | loss_mask = torch.tensor([[0, 1, 1]]) 44 | output = small_model.forward(input_ids, coordinates, labels=labels, loss_mask=loss_mask) 45 | assert output.loss.shape == () 46 | 47 | 48 | def test_loss_is_zero_with_loss_mask_zero(small_model: CausalLlamaDocLLM): 49 | input_ids = torch.tensor([[1, 2, 3]]) 50 | coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]]) 51 | labels = torch.tensor([[2, 3, 4]]) 52 | loss_mask = torch.tensor([[0, 0, 0]]) 53 | output = small_model.forward(input_ids, coordinates, labels=labels, 
loss_mask=loss_mask) 54 | assert output.loss == 0.0 55 | 56 | 57 | def test_loss_is_not_zero_with_loss_mask_one(small_model: CausalLlamaDocLLM): 58 | input_ids = torch.tensor([[1, 2, 3]]) 59 | coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]]) 60 | labels = torch.tensor([[2, 3, 4]]) 61 | loss_mask = torch.tensor([[1, 1, 1]]) 62 | output = small_model.forward(input_ids, coordinates, labels=labels, loss_mask=loss_mask) 63 | assert output.loss != 0.0 64 | 65 | 66 | def test_after_freezing_llama_weights_spatial_layers_are_not_frozen(small_model: CausalLlamaDocLLM): 67 | small_model.set_freeze_llama_layers(True) 68 | for name, param in small_model.named_parameters(recurse=True): 69 | assert not param.requires_grad or "spatial" in name or "additional" in name 70 | 71 | 72 | def test_after_unfreezing_llama_weights_everything_is_not_frozen(small_model: CausalLlamaDocLLM): 73 | small_model.set_freeze_llama_layers(True) 74 | small_model.set_freeze_llama_layers(False) 75 | for param in small_model.parameters(recurse=True): 76 | assert param.requires_grad 77 | 78 | 79 | def test_loading_llama_weights_initiates_non_spatial_weights(small_config: DocLLMLlamaConfig): 80 | llama = LlamaForCausalLM(small_config) 81 | for param in llama.parameters(recurse=True): 82 | torch.nn.init.constant_(param, 1.0) 83 | with TemporaryDirectory() as dir: 84 | model_path = os.path.join(dir, "llama") 85 | llama.save_pretrained(model_path) 86 | model = CausalLlamaDocLLM.from_pretrained(model_path) 87 | for name, param in model.named_parameters(recurse=True): 88 | assert (param == 1.0).all().item() ^ ("spatial" in name or "additional" in name) 89 | 90 | 91 | def test_loading_llama_weights_does_not_touch_non_spatial_weights(small_config: DocLLMLlamaConfig): 92 | llama = LlamaForCausalLM(small_config) 93 | for param in llama.parameters(recurse=True): 94 | torch.nn.init.constant_(param, 1.0) 95 | with TemporaryDirectory() as dir: 96 | model_path = os.path.join(dir, "llama") 97 | llama.save_pretrained(model_path) 98 | model = CausalLlamaDocLLM.from_pretrained(model_path) 99 | for name, param in model.named_parameters(recurse=True): 100 | assert (param != 1.0).all().item() ^ ("spatial" not in name and "additional" not in name) 101 | 102 | 103 | def test_fuse_additional_embeddings_removes_additional_layers(small_model: CausalLlamaDocLLM): 104 | small_model.fuse_additional_embeddings() 105 | for name, _ in small_model.named_parameters(recurse=True): 106 | assert "additional" not in name 107 | 108 | 109 | def test_fuse_additional_embeddings_changes_config(small_model: CausalLlamaDocLLM, small_config: DocLLMLlamaConfig): 110 | expected_vocab_size_after_fuse = small_config.vocab_size + small_config.additional_training_vocab_size 111 | small_model.fuse_additional_embeddings() 112 | assert small_model.config.vocab_size == expected_vocab_size_after_fuse 113 | assert small_model.config.additional_training_vocab_size == 0 114 | 115 | 116 | def test_after_fusing_additional_embeddings_results_for_original_tokens_are_same( 117 | small_model: CausalLlamaDocLLM, small_config: DocLLMLlamaConfig 118 | ): 119 | input = torch.randint(0, small_config.vocab_size, (2, 3)) 120 | spatial_input = torch.rand((2, 3, 4)) 121 | output = small_model(input, spatial_input) 122 | small_model.fuse_additional_embeddings() 123 | fused_output = small_model(input, spatial_input) 124 | assert torch.allclose(output.logits, fused_output.logits) 125 | 126 | 127 | def 
test_after_fusing_additional_embeddings_results_for_new_tokens_are_same( 128 | small_model: CausalLlamaDocLLM, small_config: DocLLMLlamaConfig 129 | ): 130 | input = torch.randint( 131 | small_config.vocab_size, small_config.vocab_size + small_config.additional_training_vocab_size, (2, 3) 132 | ) 133 | spatial_input = torch.rand((2, 3, 4)) 134 | output = small_model(input, spatial_input) 135 | small_model.fuse_additional_embeddings() 136 | fused_output = small_model(input, spatial_input) 137 | assert torch.allclose(output.logits, fused_output.logits) 138 | -------------------------------------------------------------------------------- /test/modules/llama/test_decoder_layer.py: -------------------------------------------------------------------------------- 1 | from test.modules.llama.helpers import InputSizes, ModelInputs 2 | from typing import Tuple 3 | 4 | import torch 5 | from transformers.cache_utils import Cache 6 | 7 | from docllm.modules.llama import DocLLMLlamaConfig, DocLLMLlamaDecoderLayer 8 | 9 | 10 | def test_simple_forward_has_same_output_shape_as_input_shape( 11 | config: DocLLMLlamaConfig, model_inputs: ModelInputs, position_embeddings: Tuple[torch.Tensor, torch.Tensor] 12 | ): 13 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 14 | hidden_states, _, _, _ = decoder_layer.forward( 15 | model_inputs.input_embeddings, model_inputs.spatial_embeddings, position_embeddings=position_embeddings 16 | ) 17 | assert hidden_states.shape == model_inputs.input_embeddings.shape 18 | 19 | 20 | def test_output_attentions_have_correct_shape( 21 | config: DocLLMLlamaConfig, 22 | input_sizes: InputSizes, 23 | model_inputs: ModelInputs, 24 | attention_mask: torch.Tensor, 25 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 26 | ): 27 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 28 | _, self_attn_weights, _, _ = decoder_layer.forward( 29 | model_inputs.input_embeddings, 30 | model_inputs.spatial_embeddings, 31 | attention_mask=attention_mask, 32 | output_attentions=True, 33 | position_embeddings=position_embeddings, 34 | ) 35 | expected_attention_weights_shape = ( 36 | input_sizes.batch_size, 37 | config.num_attention_heads, 38 | input_sizes.sequence_length, 39 | input_sizes.sequence_length, 40 | ) 41 | assert self_attn_weights.shape == expected_attention_weights_shape 42 | 43 | 44 | def test_key_value_caches_are_none_by_default( 45 | config: DocLLMLlamaConfig, model_inputs: ModelInputs, position_embeddings: Tuple[torch.Tensor, torch.Tensor] 46 | ): 47 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 48 | _, _, present_key_value, spatial_present_key_value = decoder_layer.forward( 49 | model_inputs.input_embeddings, model_inputs.spatial_embeddings, position_embeddings=position_embeddings 50 | ) 51 | assert present_key_value is None 52 | assert spatial_present_key_value is None 53 | 54 | 55 | def test_key_value_caches_are_returned( 56 | config: DocLLMLlamaConfig, 57 | model_inputs: ModelInputs, 58 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 59 | past_key_value: Cache, 60 | spatial_past_key_value: Cache, 61 | ): 62 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 63 | _, _, present_key_value, spatial_present_key_value = decoder_layer.forward( 64 | model_inputs.input_embeddings, 65 | model_inputs.spatial_embeddings, 66 | past_key_value=past_key_value, 67 | spatial_past_key_value=spatial_past_key_value, 68 | position_embeddings=position_embeddings, 69 | ) 70 | assert present_key_value is not None 71 | assert 
spatial_present_key_value is not None 72 | 73 | 74 | def test_after_freezing_llama_weights_spatial_layers_are_not_frozen(config: DocLLMLlamaConfig): 75 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 76 | decoder_layer.set_freeze_llama_layers(True) 77 | for name, param in decoder_layer.named_parameters(recurse=True): 78 | assert not param.requires_grad or "spatial_" in name 79 | 80 | 81 | def test_after_unfreezing_llama_weights_everything_is_not_frozen(config: DocLLMLlamaConfig): 82 | decoder_layer = DocLLMLlamaDecoderLayer(config, layer_idx=0) 83 | decoder_layer.set_freeze_llama_layers(True) 84 | decoder_layer.set_freeze_llama_layers(False) 85 | for param in decoder_layer.parameters(recurse=True): 86 | assert param.requires_grad 87 | -------------------------------------------------------------------------------- /test/modules/part_freezable_embedding_test_helpers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import PretrainedConfig, PreTrainedModel 3 | 4 | from docllm.modules.part_freezable_embedding import PartFreezableEmbedding 5 | 6 | 7 | class EmbeddingConfig(PretrainedConfig): 8 | def __init__( 9 | self, 10 | num_embeddings: int = 5, 11 | embedding_dim: int = 11, 12 | **kwargs, 13 | ) -> None: 14 | super().__init__(**kwargs) 15 | self.num_embeddings = num_embeddings 16 | self.embedding_dim = embedding_dim 17 | 18 | 19 | class EmbeddingModel(PreTrainedModel): 20 | config_class = EmbeddingConfig 21 | 22 | def __init__(self, config: EmbeddingConfig, **kwargs) -> None: 23 | super().__init__(config) 24 | self.emb = nn.Embedding(num_embeddings=config.num_embeddings, embedding_dim=config.embedding_dim, **kwargs) 25 | 26 | def _init_weights(self, module): 27 | if isinstance(module, nn.Embedding): 28 | module.weight.data.fill_(1.0) 29 | if module.padding_idx is not None: 30 | module.weight.data[module.padding_idx].fill_(1.0) 31 | 32 | 33 | class PFEmbeddingConfig(EmbeddingConfig): 34 | def __init__( 35 | self, 36 | num_embeddings: int = 5, 37 | embedding_dim: int = 11, 38 | num_additional_tokens: int = 0, 39 | **kwargs, 40 | ) -> None: 41 | super().__init__(num_embeddings, embedding_dim, **kwargs) 42 | self.num_additional_tokens = num_additional_tokens 43 | 44 | 45 | class PFEmbeddingModel(PreTrainedModel): 46 | config_class = PFEmbeddingConfig 47 | 48 | def __init__( 49 | self, 50 | config: PFEmbeddingConfig, 51 | **kwargs, 52 | ) -> None: 53 | super().__init__(config) 54 | self.emb = PartFreezableEmbedding( 55 | num_embeddings=config.num_embeddings, 56 | embedding_dim=config.embedding_dim, 57 | num_additional_tokens=config.num_additional_tokens, 58 | **kwargs, 59 | ) 60 | 61 | def _init_weights(self, module): 62 | if isinstance(module, nn.Embedding): 63 | module.weight.data.zero_() 64 | if module.padding_idx is not None: 65 | module.weight.data[module.padding_idx].zero_() 66 | -------------------------------------------------------------------------------- /test/modules/part_freezable_linear_test_helpers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import PretrainedConfig, PreTrainedModel 3 | 4 | from docllm.modules.part_freezable_linear import PartFreezableLinear 5 | 6 | 7 | class LinearConfig(PretrainedConfig): 8 | def __init__( 9 | self, 10 | in_features: int = 5, 11 | out_features: int = 11, 12 | **kwargs, 13 | ) -> None: 14 | super().__init__(**kwargs) 15 | self.in_features = in_features 16 | 
self.out_features = out_features 17 | 18 | 19 | class LinearModel(PreTrainedModel): 20 | config_class = LinearConfig 21 | 22 | def __init__(self, config: LinearConfig, **kwargs) -> None: 23 | super().__init__(config) 24 | self.linear = nn.Linear(in_features=config.in_features, out_features=config.out_features, **kwargs) 25 | 26 | def _init_weights(self, module): 27 | if isinstance(module, nn.Linear): 28 | module.weight.data.fill_(1.0) 29 | if module.bias is not None: 30 | module.bias.data.fill_(1.0) 31 | 32 | 33 | class PFLinearConfig(LinearConfig): 34 | def __init__( 35 | self, 36 | in_features: int = 5, 37 | out_features: int = 11, 38 | num_additional_outputs: int = 0, 39 | **kwargs, 40 | ) -> None: 41 | super().__init__(in_features, out_features, **kwargs) 42 | self.num_additional_outputs = num_additional_outputs 43 | 44 | 45 | class PFLinearModel(PreTrainedModel): 46 | config_class = PFLinearConfig 47 | 48 | def __init__( 49 | self, 50 | config: PFLinearConfig, 51 | **kwargs, 52 | ) -> None: 53 | super().__init__(config) 54 | self.linear = PartFreezableLinear( 55 | in_features=config.in_features, 56 | out_features=config.out_features, 57 | num_additional_outputs=config.num_additional_outputs, 58 | **kwargs, 59 | ) 60 | 61 | def _init_weights(self, module): 62 | if isinstance(module, nn.Linear): 63 | module.weight.data.zero_() 64 | if module.bias is not None: 65 | module.bias.data.zero_() 66 | -------------------------------------------------------------------------------- /test/modules/test_part_freezable_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tempfile import TemporaryDirectory 3 | from test.modules.part_freezable_embedding_test_helpers import ( 4 | EmbeddingConfig, 5 | EmbeddingModel, 6 | PFEmbeddingConfig, 7 | PFEmbeddingModel, 8 | ) 9 | 10 | import pytest 11 | import torch 12 | 13 | from docllm.modules.part_freezable_embedding import PartFreezableEmbedding 14 | 15 | 16 | @pytest.fixture 17 | def num_embeddings() -> int: 18 | return 5 19 | 20 | 21 | @pytest.fixture 22 | def embedding_dim() -> int: 23 | return 11 24 | 25 | 26 | @pytest.fixture 27 | def num_additional_tokens(request) -> int: 28 | if not hasattr(request, "param") or request.param is None: 29 | return 3 30 | return request.param 31 | 32 | 33 | @pytest.fixture 34 | def part_freezable_emb( 35 | num_embeddings, 36 | embedding_dim, 37 | num_additional_tokens, 38 | ) -> PartFreezableEmbedding: 39 | return PartFreezableEmbedding( 40 | num_embeddings=num_embeddings, 41 | embedding_dim=embedding_dim, 42 | num_additional_tokens=num_additional_tokens, 43 | ) 44 | 45 | 46 | def test_part_freezable_embeddings_init(part_freezable_emb: PartFreezableEmbedding): 47 | assert isinstance(part_freezable_emb, PartFreezableEmbedding) 48 | 49 | 50 | @pytest.mark.parametrize("num_additional_tokens", (0,), indirect=True) 51 | def test_part_freezable_embeddings_returns_expected_shape_for_original_token_without_add_tokens( 52 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int, embedding_dim: int 53 | ): 54 | input = torch.randint(0, num_embeddings, (2, 3)) 55 | token_emb = part_freezable_emb(input) 56 | expected_size = tuple(input.shape) + (embedding_dim,) 57 | assert token_emb.shape == expected_size 58 | 59 | 60 | def test_part_freezable_embeddings_returns_expected_shape_for_original_token_with_add_tokens( 61 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int, embedding_dim: int 62 | ): 63 | input = torch.randint(0, num_embeddings, (2, 3)) 
64 | token_emb = part_freezable_emb(input) 65 | expected_size = tuple(input.shape) + (embedding_dim,) 66 | assert token_emb.shape == expected_size 67 | 68 | 69 | def test_part_freezable_embeddings_returns_expected_shape_for_new_token( 70 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int, embedding_dim: int, num_additional_tokens: int 71 | ): 72 | input = torch.randint(num_embeddings, num_embeddings + num_additional_tokens, (2, 3)) 73 | token_emb = part_freezable_emb(input) 74 | expected_size = tuple(input.shape) + (embedding_dim,) 75 | assert token_emb.shape == expected_size 76 | 77 | 78 | @pytest.mark.parametrize("num_additional_tokens", (0,), indirect=True) 79 | def test_part_freezable_embeddings_backward_for_original_token_without_add_tokens( 80 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int 81 | ): 82 | input = torch.randint(0, num_embeddings, (2, 3)) 83 | token_emb = part_freezable_emb(input) 84 | loss = token_emb.sum() 85 | loss.backward() 86 | 87 | 88 | def test_part_freezable_embeddings_backward_for_original_token_with_add_tokens( 89 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int 90 | ): 91 | input = torch.randint(0, num_embeddings, (2, 3)) 92 | token_emb = part_freezable_emb(input) 93 | loss = token_emb.sum() 94 | loss.backward() 95 | 96 | 97 | def test_part_freezable_embeddings_backward_for_new_token( 98 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int, num_additional_tokens: int 99 | ): 100 | input = torch.randint(num_embeddings, num_embeddings + num_additional_tokens, (2, 3)) 101 | token_emb = part_freezable_emb(input) 102 | loss = token_emb.sum() 103 | loss.backward() 104 | 105 | 106 | def test_after_freezing_original_embs_additional_ones_are_not_frozen(part_freezable_emb: PartFreezableEmbedding): 107 | part_freezable_emb.set_freeze_original_embeddings(True) 108 | for name, param in part_freezable_emb.named_parameters(recurse=True): 109 | assert (not param.requires_grad) ^ ("additional" in name) 110 | 111 | 112 | def test_after_unfreezing_original_embs_everything_is_not_frozen(part_freezable_emb: PartFreezableEmbedding): 113 | part_freezable_emb.set_freeze_original_embeddings(True) 114 | part_freezable_emb.set_freeze_original_embeddings(False) 115 | for param in part_freezable_emb.parameters(recurse=True): 116 | assert param.requires_grad 117 | 118 | 119 | def test_loading_part_freezable_embeddings_loads_normal_embeddings(): 120 | original_model = EmbeddingModel(EmbeddingConfig()) 121 | original_model.init_weights() 122 | with TemporaryDirectory() as dir: 123 | model_path = os.path.join(dir, "embeddings") 124 | original_model.save_pretrained(model_path) 125 | model = PFEmbeddingModel.from_pretrained(model_path, config=PFEmbeddingConfig(num_additional_tokens=3)) 126 | for name, param in model.named_parameters(recurse=True): 127 | assert (param == 1.0).all().item() ^ ("additional" in name) 128 | 129 | 130 | def test_loading_part_freezable_embeddings_does_set_additional_weights(): 131 | original_model = EmbeddingModel(EmbeddingConfig()) 132 | original_model.init_weights() 133 | with TemporaryDirectory() as dir: 134 | model_path = os.path.join(dir, "embeddings") 135 | original_model.save_pretrained(model_path) 136 | model = PFEmbeddingModel.from_pretrained(model_path, config=PFEmbeddingConfig(num_additional_tokens=3)) 137 | for name, param in model.named_parameters(recurse=True): 138 | assert (param == 0.0).all().item() ^ ("additional" not in name) 139 | 140 | 141 | def 
test_additional_embeddings_are_in_parameters(): 142 | model = PFEmbeddingModel(PFEmbeddingConfig(num_additional_tokens=3)) 143 | assert any("additional" in name for name, _ in model.named_parameters(recurse=True)) 144 | 145 | 146 | def test_additional_tokens_cause_expected_additional_num_trainable_parameter(): 147 | config = PFEmbeddingConfig(num_additional_tokens=3) 148 | original_model_num_params = EmbeddingModel(config).num_parameters() 149 | new_num_params = PFEmbeddingModel(config).num_parameters() 150 | assert new_num_params - original_model_num_params == config.num_additional_tokens * config.embedding_dim 151 | 152 | 153 | def test_freezing_original_embedding_weights_leaves_remaining_num_additional_trainable_parameter(): 154 | config = PFEmbeddingConfig(num_additional_tokens=3) 155 | model = PFEmbeddingModel(config) 156 | model.emb.set_freeze_original_embeddings(True) 157 | new_num_params = model.num_parameters(only_trainable=True) 158 | assert new_num_params == config.num_additional_tokens * config.embedding_dim 159 | 160 | 161 | def test_fusing_additional_embeddings_removes_additional_embeddings(part_freezable_emb: PartFreezableEmbedding): 162 | part_freezable_emb.fuse_additional_embeddings() 163 | for name, _ in part_freezable_emb.named_parameters(recurse=True): 164 | assert "additional" not in name 165 | 166 | 167 | def test_fusing_additional_embeddings_sets_num_additional_embeddings_to_zero( 168 | part_freezable_emb: PartFreezableEmbedding, 169 | ): 170 | part_freezable_emb.fuse_additional_embeddings() 171 | assert part_freezable_emb._num_additional_tokens == 0 172 | 173 | 174 | def test_after_fusing_additional_embeddings_results_for_original_tokens_are_same( 175 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int 176 | ): 177 | input = torch.randint(0, num_embeddings, (2, 3)) 178 | output = part_freezable_emb(input) 179 | part_freezable_emb.fuse_additional_embeddings() 180 | fused_output = part_freezable_emb(input) 181 | assert torch.allclose(output, fused_output) 182 | 183 | 184 | def test_after_fusing_additional_embeddings_results_for_new_tokens_are_same( 185 | part_freezable_emb: PartFreezableEmbedding, num_embeddings: int, num_additional_tokens: int 186 | ): 187 | input = torch.randint(num_embeddings, num_embeddings + num_additional_tokens, (2, 3)) 188 | output = part_freezable_emb(input) 189 | part_freezable_emb.fuse_additional_embeddings() 190 | fused_output = part_freezable_emb(input) 191 | assert torch.allclose(output, fused_output) 192 | -------------------------------------------------------------------------------- /test/modules/test_part_freezable_linear.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tempfile import TemporaryDirectory 3 | from test.modules.part_freezable_linear_test_helpers import LinearConfig, LinearModel, PFLinearConfig, PFLinearModel 4 | 5 | import pytest 6 | import torch 7 | 8 | from docllm.modules.part_freezable_linear import PartFreezableLinear 9 | 10 | 11 | @pytest.fixture 12 | def in_features() -> int: 13 | return 11 14 | 15 | 16 | @pytest.fixture 17 | def out_features() -> int: 18 | return 5 19 | 20 | 21 | @pytest.fixture 22 | def num_additional_outputs(request) -> int: 23 | if not hasattr(request, "param") or request.param is None: 24 | return 3 25 | return request.param 26 | 27 | 28 | @pytest.fixture 29 | def part_freezable_linear( 30 | in_features, 31 | out_features, 32 | num_additional_outputs, 33 | ) -> PartFreezableLinear: 34 | return PartFreezableLinear( 35 | 
in_features=in_features, 36 | out_features=out_features, 37 | num_additional_outputs=num_additional_outputs, 38 | ) 39 | 40 | 41 | def test_init_part_freezable_linear(part_freezable_linear: PartFreezableLinear): 42 | assert isinstance(part_freezable_linear, PartFreezableLinear) 43 | 44 | 45 | @pytest.mark.parametrize("num_additional_outputs", (0,), indirect=True) 46 | def test_part_freezable_linear_returns_expected_shape_without_additional_outputs( 47 | part_freezable_linear: PartFreezableLinear, in_features: int, out_features: int 48 | ): 49 | input = torch.rand((2, 3, in_features)) 50 | token_emb = part_freezable_linear(input) 51 | expected_size = tuple(input.shape[:-1]) + (out_features,) 52 | assert token_emb.shape == expected_size 53 | 54 | 55 | def test_part_freezable_linear_returns_expected_shape_with_additional_outputs( 56 | part_freezable_linear: PartFreezableLinear, in_features: int, out_features: int, num_additional_outputs: int 57 | ): 58 | input = torch.rand((2, 3, in_features)) 59 | token_emb = part_freezable_linear(input) 60 | expected_size = tuple(input.shape[:-1]) + (out_features + num_additional_outputs,) 61 | assert token_emb.shape == expected_size 62 | 63 | 64 | @pytest.mark.parametrize("num_additional_outputs", (0,), indirect=True) 65 | def test_part_freezable_linear_backward_without_additional_outputs( 66 | part_freezable_linear: PartFreezableLinear, in_features: int 67 | ): 68 | input = torch.rand((2, 3, in_features)) 69 | output = part_freezable_linear(input) 70 | loss = output.sum() 71 | loss.backward() 72 | 73 | 74 | def test_part_freezable_linear_backward_with_additional_outputs( 75 | part_freezable_linear: PartFreezableLinear, in_features: int 76 | ): 77 | input = torch.rand((2, 3, in_features)) 78 | output = part_freezable_linear(input) 79 | loss = output.sum() 80 | loss.backward() 81 | 82 | 83 | def test_after_freezing_original_linear_layer_additional_outputs_are_not_frozen( 84 | part_freezable_linear: PartFreezableLinear, 85 | ): 86 | part_freezable_linear.set_freeze_original_outputs(True) 87 | for name, param in part_freezable_linear.named_parameters(recurse=True): 88 | assert (not param.requires_grad) ^ ("additional" in name) 89 | 90 | 91 | def test_after_unfreezing_original_linear_layer_everything_is_not_frozen(part_freezable_linear: PartFreezableLinear): 92 | part_freezable_linear.set_freeze_original_outputs(True) 93 | part_freezable_linear.set_freeze_original_outputs(False) 94 | for param in part_freezable_linear.parameters(recurse=True): 95 | assert param.requires_grad 96 | 97 | 98 | def test_loading_part_freezable_linear_loads_normal_embeddings(): 99 | original_model = LinearModel(LinearConfig()) 100 | original_model.init_weights() 101 | with TemporaryDirectory() as dir: 102 | model_path = os.path.join(dir, "linear") 103 | original_model.save_pretrained(model_path) 104 | model = PFLinearModel.from_pretrained(model_path, config=PFLinearConfig(num_additional_outputs=3)) 105 | for name, param in model.named_parameters(recurse=True): 106 | assert (param == 1.0).all().item() ^ ("additional" in name) 107 | 108 | 109 | def test_loading_part_freezable_linear_does_set_additional_weights(): 110 | original_model = LinearModel(LinearConfig()) 111 | original_model.init_weights() 112 | with TemporaryDirectory() as dir: 113 | model_path = os.path.join(dir, "linear") 114 | original_model.save_pretrained(model_path) 115 | model = PFLinearModel.from_pretrained(model_path, config=PFLinearConfig(num_additional_outputs=3)) 116 | for name, param in 
model.named_parameters(recurse=True): 117 | assert (param == 0.0).all().item() ^ ("additional" not in name) 118 | 119 | 120 | def test_additional_outputs_are_in_parameters(): 121 | model = PFLinearModel(PFLinearConfig(num_additional_outputs=3)) 122 | assert any("additional" in name for name, _ in model.named_parameters(recurse=True)) 123 | 124 | 125 | def test_additional_outputs_cause_expected_additional_num_trainable_parameter(): 126 | config = PFLinearConfig(num_additional_outputs=3) 127 | original_model_num_params = LinearModel(config).num_parameters() 128 | new_num_params = PFLinearModel(config).num_parameters() 129 | assert new_num_params - original_model_num_params == (config.in_features + 1) * config.num_additional_outputs 130 | 131 | 132 | def test_freezing_original_linear_weights_leaves_remaining_num_additional_trainable_parameter(): 133 | config = PFLinearConfig(num_additional_outputs=3) 134 | model = PFLinearModel(config) 135 | model.linear.set_freeze_original_outputs(True) 136 | new_num_params = model.num_parameters(only_trainable=True) 137 | assert new_num_params == (config.in_features + 1) * config.num_additional_outputs 138 | 139 | 140 | def test_fusing_additional_outputs_removes_additional_outputs(part_freezable_linear: PartFreezableLinear): 141 | part_freezable_linear.fuse_additional_outputs() 142 | for name, _ in part_freezable_linear.named_parameters(recurse=True): 143 | assert "additional" not in name 144 | 145 | 146 | def test_fusing_additional_outputs_sets_num_additional_outputs_to_zero(part_freezable_linear: PartFreezableLinear): 147 | part_freezable_linear.fuse_additional_outputs() 148 | assert part_freezable_linear._num_additional_outputs == 0 149 | 150 | 151 | def test_after_fusing_additional_outputs_results_are_same(part_freezable_linear: PartFreezableLinear, in_features: int): 152 | input = torch.rand((2, 3, in_features)) 153 | output = part_freezable_linear(input) 154 | part_freezable_linear.fuse_additional_outputs() 155 | fused_output = part_freezable_linear(input) 156 | assert torch.allclose(output, fused_output) 157 | -------------------------------------------------------------------------------- /test/test_pretraining_loss.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | 4 | import pytest 5 | import torch 6 | 7 | from docllm.pretraining_loss import DocLLMCrossEntropyLoss 8 | 9 | 10 | @pytest.fixture 11 | def batch_size() -> int: 12 | return 10 13 | 14 | 15 | @pytest.fixture 16 | def seq_len() -> int: 17 | return 17 18 | 19 | 20 | @pytest.fixture 21 | def vocab_size() -> int: 22 | return 7 23 | 24 | 25 | @pytest.fixture 26 | def loss_inputs(batch_size: int, seq_len: int, vocab_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 27 | logits = torch.rand(batch_size, seq_len, vocab_size) 28 | labels = torch.randint(vocab_size, (batch_size, seq_len)) 29 | mask = torch.ones(batch_size, seq_len, dtype=torch.bool) 30 | return logits, labels, mask 31 | 32 | 33 | def test_initialization(): 34 | loss = DocLLMCrossEntropyLoss() 35 | assert isinstance(loss, DocLLMCrossEntropyLoss) 36 | 37 | 38 | def test_forward_creates_tensor(loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 39 | loss = DocLLMCrossEntropyLoss() 40 | logits, labels, mask = loss_inputs 41 | result = loss(logits, labels, mask) 42 | assert isinstance(result, torch.Tensor) 43 | 44 | 45 | def test_forward_output_is_single_value(loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 46 | 
loss = DocLLMCrossEntropyLoss() 47 | logits, labels, mask = loss_inputs 48 | result = loss(logits, labels, mask) 49 | assert result.size() == torch.Size([]) 50 | 51 | 52 | def test_with_mask_all_zero_loss_is_zero(loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 53 | loss = DocLLMCrossEntropyLoss() 54 | logits, labels, mask = loss_inputs 55 | mask = torch.zeros_like(mask, dtype=torch.bool) 56 | result = loss(logits, labels, mask) 57 | assert result.item() == 0.0 58 | 59 | 60 | def test_with_mask_all_one_loss_is_not_zero(loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 61 | loss = DocLLMCrossEntropyLoss() 62 | logits, labels, mask = loss_inputs 63 | # Ensure loss is not actually zero. 64 | logits[0][0].fill_(value=-100000.0) 65 | logits[0][0][0] = 0.0 66 | labels[0][0] = 1 67 | result = loss(logits, labels, mask) 68 | assert result.item() != pytest.approx(0.0, 0.1) 69 | 70 | 71 | @pytest.mark.parametrize("correct_index", ((0, 0), (1, 3), (9, 15))) 72 | def test_loss_is_zero_with_mask_only_one_at_correct_prediction( 73 | loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], correct_index: Tuple[int, int] 74 | ): 75 | batch_idx, seq_idx = correct_index 76 | loss = DocLLMCrossEntropyLoss() 77 | logits, labels, mask = loss_inputs 78 | # Values that cause a loss of zero. 79 | logits[batch_idx][seq_idx].fill_(value=-10000.0) 80 | logits[batch_idx][seq_idx][labels[batch_idx][seq_idx]] = 0.0 81 | # Use mask to filter out all other values. 82 | mask = torch.zeros_like(mask, dtype=torch.bool) 83 | mask[batch_idx][seq_idx] = 1 84 | result = loss(logits, labels, mask) 85 | assert result.item() == pytest.approx(0.0, 0.00001) 86 | 87 | 88 | def test_masked_result_is_same_as_result_with_ignore_index_labels( 89 | loss_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 90 | ): 91 | loss = DocLLMCrossEntropyLoss(ignore_index=-100) 92 | logits, labels, mask = loss_inputs 93 | 94 | positions = random.sample(range(mask.numel()), 17) 95 | mask.view(-1)[positions] = 0 96 | labels_with_ignore_index = labels.clone() 97 | labels_with_ignore_index.view(-1)[positions] = -100 98 | 99 | masked_result = loss(logits, labels, mask) 100 | ignore_idx_result = loss(logits, labels_with_ignore_index, None) 101 | assert masked_result == ignore_idx_result 102 | -------------------------------------------------------------------------------- /test/util.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, ParamSpec, Tuple, TypeVar 2 | 3 | Params = ParamSpec("Params") 4 | Ret = TypeVar("Ret") 5 | 6 | 7 | def extract_return_value(fct: Callable[Params, Ret]) -> Tuple[List[Ret], Callable[Params, Ret]]: 8 | result_list = [] 9 | 10 | def new_func(*args: Params.args, **kwargs: Params.kwargs) -> Ret: 11 | res = fct(*args, **kwargs) 12 | result_list.append(res) 13 | return res 14 | 15 | return result_list, new_func 16 | --------------------------------------------------------------------------------
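Usage sketch (editor's note): the snippet below is a minimal example distilled from the fixtures and tests above (the small_config fixture in test/modules/llama/conftest.py, the forward-pass tests in test/modules/llama/test_causal_docllm.py, and test/test_pretraining_loss.py). It only uses calls that the tests themselves exercise and is meant as an illustration of how the tested pieces fit together, not as an official entry point of the package.

import torch

from docllm.modules.llama.causal_docllm import CausalLlamaDocLLM
from docllm.modules.llama.config import DocLLMLlamaConfig
from docllm.pretraining_loss import DocLLMCrossEntropyLoss

# A small configuration mirroring the small_config test fixture.
config = DocLLMLlamaConfig(
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=2,
    intermediate_size=1024,
    vocab_size=320,
    additional_training_vocab_size=3,
)
model = CausalLlamaDocLLM(config)

# One document with three tokens; each token carries a normalized bounding box,
# matching the coordinate tensors used in test_causal_docllm.py.
input_ids = torch.tensor([[1, 2, 3]])
coordinates = torch.tensor([[[0.1, 0.0, 0.4, 1.0], [0.2, 0.1, 0.3, 0.2], [0.3, 0.2, 0.2, 0.3]]])
labels = torch.tensor([[2, 3, 4]])
loss_mask = torch.tensor([[0, 1, 1]])  # loss is only computed where the mask is set

output = model.forward(input_ids, coordinates, labels=labels, loss_mask=loss_mask)
# output.logits has shape (batch, seq_len, vocab_size + additional_training_vocab_size);
# output.loss is a scalar computed over the masked positions.

# The pretraining loss can also be applied standalone, as in test_pretraining_loss.py.
loss_fn = DocLLMCrossEntropyLoss()
standalone_loss = loss_fn(output.logits, labels, loss_mask.bool())

# As exercised in the freezing/fusing tests: freeze the pretrained Llama weights so only
# the spatial and additional parameters remain trainable, then later fuse the additional
# embeddings into the base vocabulary.
model.set_freeze_llama_layers(True)
model.set_freeze_llama_layers(False)
model.fuse_additional_embeddings()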