├── .gitignore ├── GoNotoCurrent.ttf ├── LICENSE ├── README.md ├── data.py ├── image_utils.py ├── imgs └── ptp.png ├── model_configs ├── ptp │ ├── config.json │ ├── merges.txt │ ├── preprocessor_config.json │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json └── screenshot-llama-380m │ ├── config.json │ ├── preprocessor_config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json ├── modeling ├── configuration_pixel.py ├── configuration_ptp.py ├── configuration_screenshot_llama.py ├── modeling_pixel.py ├── modeling_ptp.py ├── modeling_screenshot_llama.py ├── processing_ptp.py ├── processing_screenshot_llama.py ├── sincos_pos.py └── span_masking.py ├── rendering └── src │ ├── Makefile │ ├── README.md │ ├── main.cpp │ ├── render.cpp │ ├── render.h │ ├── renderer.cpp │ ├── renderer.cpython-310-x86_64-linux-gnu.so │ └── renderer.cpython-39-x86_64-linux-gnu.so ├── requirements.txt ├── run.py ├── run_configs ├── ptp.yaml ├── screenshot-llama-1.3b-from-sheared-llama.yaml └── screenshot-llama-380m.yaml ├── run_multiple_gpus.sh ├── run_single_gpu.sh ├── streaming_data.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .venv* 163 | data 164 | wandb 165 | result 166 | slurm 167 | slurm_test.sh 168 | -------------------------------------------------------------------------------- /GoNotoCurrent.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/PTP/e03435f1ec1d9902bd6663798dc84f7382ca7286/GoNotoCurrent.ttf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Princeton Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Language Understanding from Screenshots 2 | 3 | This repository contains the code, data, and models for the paper [Improving Language Understanding from Screenshots](https://arxiv.org/abs/2402.14073). In this paper, we focus on improving the language understanding ability of "screenshot LMs" (models that process everything -- including text -- within visual inputs) and propose patch-and-text prediction (PTP), a novel pre-training objective for screenshot LMs. 4 | 5 | ![Illustration for PTP](imgs/ptp.png) 6 | 7 | 8 | ## Quick Links 9 | 10 | - [Environment](#environment) 11 | - [Preparing the data](#preparing-the-data) 12 | - [Reproducing our pre-trained models](#reproducing-our-pre-trained-models) 13 | - [Downloading our models](#downloading-our-models) 14 | - [Fine-tuning PTP models](#fine-tuning-ptp-models) 15 | - [Bugs or Questions?](#bugs-or-questions) 16 | - [Citation](#citation) 17 | 18 | 19 | ## Environment 20 | 21 | First, install the latest compatible [PyTorch](https://pytorch.org). 22 | 23 | Then, install all the required packages by running: 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | We strongly recommend using the exact same `transformers` and `accelerate` versions for best reproducibility. Please check out the [renderer README](./rendering/src) to make sure that the renderer is correctly configured. 29 | 30 | 31 | ## Preparing the data 32 | 33 | For our encoder-decoder experiments and the train-from-scratch autoregressive screenshot LM experiments, we use Wikipedia+BookCorpus as the pre-training data. You can find the already-tokenized dataset on [this Hugging Face dataset page](https://huggingface.co/datasets/princeton-nlp/ptp_data) and download it by running 34 | ```bash 35 | git clone https://huggingface.co/datasets/princeton-nlp/ptp_data data 36 | ``` 37 | This folder contains four files: 38 | * `wikibook_256_opt_tk_train.npy` and `wikibook_256_opt_tk_val.npy`: Wiki+Book tokenized with the OPT tokenizer, 256 tokens per example (for the encoder-decoder models). 39 | * `wikibook_512_llama_tk_train.npy` and `wikibook_512_llama_tk_val.npy`: Wiki+Book tokenized with the LLaMA tokenizer, 512 tokens per example (for the train-from-scratch autoregressive models). 40 | 41 | For continued pre-training of [Sheared-LLaMA](https://github.com/princeton-nlp/LLM-Shearing) to use screenshots, we use Sheared-LLaMA's pipeline for processing the [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T) data. Please follow [this guideline](https://github.com/princeton-nlp/LLM-Shearing/tree/main/llmshearing/data) to process the data. Our example config uses `./data/sheared-llama-rp/for_ft` for continued pre-training and `./data/sheared-llama-rp/eval` for evaluation. 42 | 43 | 44 | ## Reproducing our pre-trained models 45 | 46 | 47 | To reproduce our models, run the following command (requires 8 GPUs): 48 | ```bash 49 | NUM_GPU=8 bash run_multiple_gpus.sh {CONFIG PATH} 50 | ``` 51 | There are three example configs: 52 | * `run_configs/ptp.yaml`: our main PTP model (encoder-decoder). 53 | * `run_configs/screenshot-llama-380m.yaml`: train-from-scratch autoregressive. 54 | * `run_configs/screenshot-llama-1.3b-from-sheared-llama.yaml`: continued pre-training from Sheared-LLaMA. 55 | 56 | You can also run the single-GPU script `run_single_gpu.sh` for testing. Note that the effective total batch size is `NUM_GPU` x `per_device_train_batch_size` x `gradient_accumulation_steps`.
To ensure the same hyperparameters, you should adjust the per-GPU batch size (`per_device_train_batch_size`) or the gradient accumulation steps (`gradient_accumulation_steps`) accordingly if you are not using 8 GPUs or your GPUs cannot fit our preset batch sizes. 57 | 58 | ## Downloading our models 59 | 60 | We provide the following pre-trained models on Huggingface: 61 | 62 | * [princeton-nlp/ptp](https://huggingface.co/princeton-nlp/ptp) 63 | * [princeton-nlp/screenshot-llama-380m](https://huggingface.co/princeton-nlp/screenshot-llama-380m) 64 | * [princeton-nlp/screenshot-llama-1.3b-from-sheared-llama](https://huggingface.co/princeton-nlp/screenshot-llama-1.3b-from-sheared-llama) 65 | 66 | ## Fine-tuning PTP models 67 | 68 | Coming soon! 69 | 70 | ## Bugs or questions? 71 | 72 | If you have any questions related to the paper, feel free to email Tianyu (`tianyug@cs.princeton.edu`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and quicker! 73 | 74 | ## Citation 75 | 76 | Please cite our paper if you use PTP in your work: 77 | 78 | ```bibtex 79 | @article{gao2024improving, 80 | title={Improving Language Understanding from Screenshots}, 81 | author={Gao, Tianyu and Wang, Zirui and Bhaskar, Adithya and Chen, Danqi}, 82 | journal={arXiv preprint arXiv:2402.14073}, 83 | year={2024} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | # import gi 4 | # gi.require_version('Pango', '1.0') 5 | # gi.require_version('PangoCairo', '1.0') 6 | # from gi.repository import Pango, PangoCairo 7 | # import cairo 8 | from PIL import Image 9 | from dataclasses import dataclass, field 10 | import torch 11 | from streaming import LocalDataset 12 | from image_utils import * 13 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 14 | from transformers.image_utils import to_numpy_array 15 | from modeling.span_masking import SpanMaskingGenerator 16 | from random import sample 17 | from image_utils import render_text 18 | 19 | class NumpyDataset(Dataset): 20 | 21 | def __init__(self, path, block_size=None): 22 | self.tokens = np.load(path) 23 | self.block_size = self.tokens.shape[1] if block_size is None else block_size 24 | self.font_size = None 25 | 26 | def __len__(self): 27 | return len(self.tokens) 28 | 29 | def __getitem__(self, idx): 30 | return {"tokens": self.tokens[idx][:self.block_size], "font_size": self.font_size} 31 | 32 | 33 | class RenderTextCollator: 34 | def __init__(self, 35 | processor: object, 36 | font_size: int, 37 | line_space: int, 38 | replace_new_line: bool, 39 | new_line_token: str, 40 | width: int, 41 | height: int, 42 | block_size: int = 1024, 43 | rendered_as_target: bool = False, 44 | patch_width: int = 16, 45 | patch_height: int = 16, 46 | text_mask_rate: float = 0, 47 | merge_text_masks: bool = False, 48 | ignore_white_patches: bool = False, 49 | add_black_patch: bool = False, 50 | add_prefix: bool = False, 51 | autoregressive: bool = False, 52 | ar_image_block_size: int = None, 53 | total_block_size: int = None, 54 | context_mask: int = None, 55 | image_mode: str = "RGB", 56 | sample_mask_at_collator: bool = False, 57 | mask_ratio: float = 0, 58 | span_masking: bool = False, 59 | max_span_length: int = 6, 60 | ): 61 | 
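        # Added overview comments (not in the original source); see `__call__` below for details.
        # * Encoder-decoder (PTP) mode: each example's token ids are decoded back to text,
        #   optionally masked at `text_mask_rate` (consecutive mask tokens can be merged),
        #   rendered into a `width` x `height` screenshot via `render_text`, and converted by the
        #   processor into `flattened_patches`; labels are the (masked or rendered) token ids.
        # * Autoregressive mode (`autoregressive=True`): the first `ar_image_block_size` tokens are
        #   rendered as a screenshot and the remaining tokens stay as text `input_ids`
        #   (`ar_image_block_size=0` falls back to text-only language modeling).
        # * `ignore_white_patches` zeroes the attention mask over trailing all-white patches;
        #   `sample_mask_at_collator` pre-samples span-based patch masks with `SpanMaskingGenerator`.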
self.processor = processor 62 | self.font_size = font_size 63 | self.line_space = line_space 64 | self.replace_new_line = replace_new_line 65 | self.new_line_token = new_line_token 66 | self.width = width 67 | self.height = height 68 | self.block_size = block_size 69 | self.rendered_as_target = rendered_as_target 70 | self.patch_width = patch_width 71 | self.patch_height = patch_height 72 | self.text_mask_rate = text_mask_rate 73 | self.merge_text_masks = merge_text_masks 74 | self.ignore_white_patches = ignore_white_patches 75 | self.add_black_patch = add_black_patch 76 | self.add_prefix = add_prefix 77 | self.autoregressive = autoregressive 78 | self.ar_image_block_size = ar_image_block_size 79 | self.total_block_size = total_block_size 80 | self.context_mask = context_mask 81 | self.image_mode = image_mode 82 | self.sample_mask_at_collator = sample_mask_at_collator 83 | self.mask_ratio = mask_ratio 84 | self.span_masking = span_masking 85 | self.max_span_length = max_span_length 86 | 87 | 88 | def mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: 89 | """ 90 | Text masking 91 | """ 92 | labels = inputs.clone() 93 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 94 | probability_matrix = torch.full(labels.shape, self.text_mask_rate) 95 | if special_tokens_mask is None: 96 | special_tokens_mask = [ 97 | self.processor.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 98 | ] 99 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 100 | else: 101 | special_tokens_mask = special_tokens_mask.bool() 102 | 103 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 104 | masked_indices = torch.bernoulli(probability_matrix).bool() 105 | 106 | inputs[masked_indices] = self.processor.tokenizer.mask_token_id 107 | 108 | return inputs, labels 109 | 110 | 111 | def __call__(self, batch): 112 | new_batch = {"flattened_patches": [], "attention_mask": [], "labels": []} 113 | if self.autoregressive: 114 | # Data for autoregressive mode 115 | new_batch["input_ids"] = [] 116 | if self.ar_image_block_size == 0: 117 | # Text only 118 | new_batch = {"input_ids": [], "attention_mask": [], "labels": []} 119 | if self.sample_mask_at_collator: 120 | # Sample patch mask in data collator 121 | new_batch["patch_mask"] = [] 122 | 123 | for item in batch: 124 | if self.autoregressive and self.ar_image_block_size == 0: 125 | # Autoregressive: text only 126 | text_tokens = torch.tensor(item["tokens"].astype(np.int64)).long() 127 | 128 | input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), text_tokens], 0) 129 | attention_mask = torch.ones(input_ids.shape).long() 130 | if self.total_block_size is not None: 131 | # Truncate 132 | input_ids = input_ids[:self.total_block_size] 133 | attention_mask = attention_mask[:self.total_block_size] 134 | new_batch["input_ids"].append(input_ids) 135 | new_batch["attention_mask"].append(attention_mask) 136 | labels = input_ids + 0 137 | if self.context_mask is not None: 138 | # Only predict on the non-masked part (mostly for evaluation) 139 | labels[:self.context_mask] = -100 140 | new_batch["labels"].append(labels) 141 | elif self.autoregressive: 142 | # Autoregressive with screenshot 143 | image_tokens = item["tokens"][:self.ar_image_block_size] # render these as screenshots 144 | 145 | text = self.processor.decode(image_tokens, skip_special_tokens=True) 146 | if 
self.replace_new_line: 147 | text = text.replace("\n", self.new_line_token) 148 | 149 | if self.add_prefix: 150 | text = "Beginning of the sequence: " + text 151 | 152 | image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height) 153 | 154 | # In the case where not all text is rendered into the screenshot, we truncate the text 155 | if self.replace_new_line: 156 | _ = rendered_text.replace(self.new_line_token, "\n").rstrip(" ") 157 | else: 158 | _ = rendered_text.rstrip(" ") 159 | encoded_num_img_tokens = len(self.processor(text=_, add_special_tokens=False)['input_ids']) 160 | text_tokens = torch.tensor(item["tokens"][min(encoded_num_img_tokens,self.ar_image_block_size):].astype(np.int64)).long() 161 | encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True) 162 | 163 | new_batch["flattened_patches"].append(encoding["flattened_patches"][0]) 164 | patch_attention_mask = encoding["attention_mask"][0] 165 | 166 | assert not self.add_black_patch # not supported (and not needed with ) 167 | 168 | # Mask out the attention to ending white patches 169 | if self.ignore_white_patches: 170 | fpatches = new_batch["flattened_patches"][-1][:, 2:] 171 | non_white_patches = ((fpatches - fpatches.mean(dim=-1, keepdim=True)) ** 2 < 1e-6).long().sum(-1) != fpatches.shape[-1] 172 | reverse_non_white_patches = non_white_patches.flip(-1) 173 | non_white_patches = reverse_non_white_patches.nonzero() 174 | if len(non_white_patches) == 0: 175 | first_white_patch = 0 176 | else: 177 | first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0] 178 | 179 | patch_attention_mask[first_white_patch:] = 0 180 | 181 | # BOS + image + text 182 | input_ids = torch.cat([torch.tensor([self.processor.tokenizer.bos_token_id]).long(), encoding["image_input_ids"][0], text_tokens], 0) 183 | attention_mask = torch.ones(input_ids.shape).long() 184 | patch_mask = input_ids == self.processor.patch_token_id 185 | attention_mask[patch_mask] = patch_attention_mask.long() 186 | if self.total_block_size is not None: 187 | input_ids = input_ids[:self.total_block_size] 188 | attention_mask = attention_mask[:self.total_block_size] 189 | new_batch["input_ids"].append(input_ids) 190 | new_batch["attention_mask"].append(attention_mask) 191 | new_batch["labels"].append(input_ids) 192 | 193 | else: 194 | if self.text_mask_rate > 0: 195 | input_ids = torch.tensor(item["tokens"].astype(np.int32)).long().unsqueeze(0) 196 | input_ids, labels = self.mask_tokens(input_ids) 197 | input_ids = input_ids.squeeze(0) 198 | labels = labels.squeeze(0) 199 | text = self.processor.decode(input_ids, skip_special_tokens=False) 200 | else: 201 | text = self.processor.decode(item["tokens"], skip_special_tokens=True) 202 | 203 | if self.replace_new_line: 204 | text = text.replace("\n", self.new_line_token) 205 | 206 | if self.merge_text_masks and self.text_mask_rate > 0: 207 | while True: 208 | if "" not in text: 209 | break 210 | text = text.replace("", "") 211 | 212 | if self.add_prefix: 213 | text = "Beginning of the sequence: " + text 214 | 215 | image, rendered_text = render_text(text=text, font_size=self.font_size, line_space=self.line_space, width=self.width, height=self.height) 216 | image = image.convert(self.image_mode) 217 | image = to_numpy_array(image) 218 | if self.image_mode != "RGB": 219 | image = np.expand_dims(image, -1) # h, w, 1 220 | if self.image_mode == "1": 221 | image = image.astype(np.float32) # bool -> float for clf 222 
| 223 | if self.rendered_as_target: 224 | if self.text_mask_rate > 0: 225 | # this is not very accurate as with the merge masks we can only estimate how much is rendered in the labels 226 | valid_num_tokens = len(self.processor.tokenizer.tokenize(rendered_text)) 227 | # consider the merged masks 228 | valid_num_tokens = int(valid_num_tokens / (len(self.processor.tokenizer.tokenize(text)) / len(labels))) 229 | labels[valid_num_tokens:] = self.processor.tokenizer.pad_token_id 230 | else: 231 | labels = self.processor.tokenizer(rendered_text, return_tensors="pt", add_special_tokens=False, max_length=self.block_size, padding="max_length", truncation=True)["input_ids"].squeeze() 232 | 233 | encoding = self.processor(images=image, return_tensors="pt", add_special_tokens=True) 234 | new_batch["flattened_patches"].append(encoding["flattened_patches"][0]) 235 | new_batch["attention_mask"].append(encoding["attention_mask"][0]) 236 | new_batch["labels"].append(labels) 237 | 238 | if self.add_black_patch: 239 | self.ignore_white_patches 240 | 241 | if self.ignore_white_patches: 242 | fpatches = new_batch["flattened_patches"][-1][:, 2:] 243 | # White patches should have all pixels = 1 (normalized) 244 | non_white_patches = (fpatches > 1 - 1e-6).long().sum(-1) != fpatches.shape[-1] 245 | reverse_non_white_patches = non_white_patches.flip(-1) 246 | non_white_patches = reverse_non_white_patches.nonzero() 247 | if len(non_white_patches) == 0: 248 | first_white_patch = 0 249 | else: 250 | first_white_patch = len(reverse_non_white_patches) - non_white_patches[0][0] 251 | 252 | new_batch["attention_mask"][-1][first_white_patch:] = 0 253 | 254 | if self.add_black_patch: 255 | if first_white_patch == len(reverse_non_white_patches): 256 | first_white_patch -= 1 # if there is no white patch, force changing the last one to black 257 | 258 | black = 0 259 | new_batch["flattened_patches"][-1][first_white_patch, 2:] = black 260 | new_batch["attention_mask"][-1][first_white_patch] = 1 261 | 262 | if self.sample_mask_at_collator: 263 | assert self.span_masking is True # we are only doing this for span masking 264 | seq_length = new_batch["flattened_patches"][-1].shape[0] 265 | len_keep = int(seq_length * (1 - self.mask_ratio)) 266 | span_masking_generator = SpanMaskingGenerator( 267 | num_patches=seq_length, 268 | num_masking_patches=seq_length-len_keep, 269 | max_span_length=self.max_span_length, 270 | spacing="span", 271 | cumulative_span_weights=[0.2,0.4,0.6,0.8,0.9,1] 272 | ) 273 | patch_mask = torch.tensor(span_masking_generator()) 274 | new_batch["patch_mask"].append(patch_mask) 275 | 276 | for key in new_batch: 277 | new_batch[key] = torch.stack(new_batch[key]) 278 | 279 | return new_batch 280 | -------------------------------------------------------------------------------- /image_utils.py: -------------------------------------------------------------------------------- 1 | # Most part of the code was adopted from https://gist.github.com/pojda/8bf989a0556845aaf4662cd34f21d269 2 | 3 | from PIL import Image, ImageDraw, ImageFont 4 | import PIL 5 | import numpy as np 6 | 7 | from io import BytesIO 8 | import base64 9 | 10 | import sys 11 | sys.path.append("rendering/src") 12 | 13 | try: 14 | import renderer 15 | except ImportError as e: 16 | print("Fail to import simple renderer") 17 | print(e) 18 | 19 | 20 | def render_text( 21 | text, 22 | width=512, 23 | height=256, 24 | font_size=10, 25 | line_space=6, # the final height of one line is font_size + line_space 26 | white_bg=True, 27 | 
no_full_rendering_warning=False, # print a message if the text is not fully rendered 28 | ): 29 | array = np.zeros(width*height, dtype=np.int8) 30 | # extra_line_space = int(font_size * line_space - font_size) 31 | 32 | rendered, rendered_text = renderer.render_unicode( 33 | array, text, height, width, font_size, line_space, True, True, True, True, True 34 | ) 35 | if no_full_rendering_warning and len(rendered_text) != len(text): 36 | print("Warning: text got cut off and was not fully rendered!!") 37 | rendered = rendered.reshape(height, width) 38 | rendered = (255 - rendered) if white_bg else rendered 39 | return Image.fromarray(rendered, "L").convert("RGB"), rendered_text 40 | 41 | 42 | def renormalize_pred(image_tensor): 43 | min = image_tensor.min(-1).values.min(-1).values 44 | max = image_tensor.max(-1).values.max(-1).values 45 | std = 255 / (max - min) 46 | mean = -min * std 47 | image_tensor = image_tensor * std.unsqueeze(-1).unsqueeze(-1) + mean.unsqueeze(-1).unsqueeze(-1) 48 | return image_tensor 49 | 50 | 51 | def renormalize(image_tensor): 52 | return image_tensor * 255 53 | 54 | 55 | def flattened_patches_to_image(flattened_patches, height=256, width=512, patch_height=16, patch_width=16, mask=None, original_patches=None, image_mode="RGB"): 56 | # Convert flattened_patches back to PIL image 57 | # flattend_patches: (num_patches, 768) 58 | h = height // patch_height 59 | w = width // patch_width 60 | c = 3 if image_mode == "RGB" else 1 61 | if image_mode == "1": 62 | flattened_patches = flattened_patches * 255 63 | image_mode = "L" # convert to grayscale for further processing 64 | if original_patches is not None and mask is not None: 65 | original_patches = renormalize(original_patches) 66 | flattened_patches = renormalize_pred(flattened_patches) 67 | flattened_patches = flattened_patches * mask.unsqueeze(-1) + original_patches * (1 - mask.unsqueeze(-1)) 68 | else: 69 | flattened_patches = renormalize(flattened_patches) 70 | 71 | flattened_patches = flattened_patches.reshape(h * w, patch_height, patch_width, c) # (h * w, ph, pw, 3) 72 | if mask is not None: 73 | flattened_patches[mask.bool(), :, :, 0] = flattened_patches[mask.bool(), :, :, 0] * 0.7 + 255 * 0.3 74 | flattened_patches = flattened_patches.reshape(h, w, patch_height, patch_width, c) # (h, w, ph, pw, 3) 75 | flattened_patches = flattened_patches.permute(0, 2, 1, 3, 4) # (h, ph, w, pw, 3) 76 | flattened_patches = flattened_patches.reshape(h * patch_height, w * patch_width, c) # (h * ph, w * pw, 3) 77 | if c == 1: 78 | flattened_patches = flattened_patches.squeeze(-1) 79 | image = flattened_patches.numpy() 80 | 81 | image = Image.fromarray(image.astype(np.uint8), mode=image_mode) 82 | 83 | return image 84 | 85 | 86 | def flattened_patches_to_vit_pixel_values(flattened_patches, height=256, width=512, patch_height=16, patch_width=16, image_mode="RGB"): 87 | # Convert flattened_patches to pixel values 88 | # pixel_values: batch_size, num_channels, height, width 89 | # flattend_patches: (num_patches, 768) 90 | h = height // patch_height 91 | w = width // patch_width 92 | num_channels = 3 if image_mode == "RGB" else 1 93 | bsz, num_patch, patch_emb = flattened_patches.shape 94 | 95 | flattened_patches = flattened_patches.reshape(bsz, h * w, patch_height, patch_width, num_channels) # (bsz, h * w, ph, pw, c) 96 | flattened_patches = flattened_patches.reshape(bsz, h, w, patch_height, patch_width, num_channels) # (bsz, h, w, ph, pw, c) 97 | flattened_patches = flattened_patches.permute(0, 5, 1, 3, 2, 4) # (bsz, c, h, ph, w, pw) 
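    # Added note: the permute above interleaves the patch grid back into image order, so the final
    # reshape below yields a standard (batch, channels, height, width) pixel tensor for a ViT.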
98 | flattened_patches = flattened_patches.reshape(bsz, num_channels, h * patch_height, w * patch_width) # (bsz, c, h * ph, w * pw) 99 | 100 | return flattened_patches 101 | 102 | -------------------------------------------------------------------------------- /imgs/ptp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/PTP/e03435f1ec1d9902bd6663798dc84f7382ca7286/imgs/ptp.png -------------------------------------------------------------------------------- /model_configs/ptp/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "PTPForConditionalGeneration" 4 | ], 5 | "model_type": "ptp", 6 | "pad_token_id": 1, 7 | "torch_dtype": "float32", 8 | "tie_word_embeddings": true, 9 | "text_config": { 10 | "emb_layer_norm": true, 11 | "d_ff": 3072, 12 | "d_kv": 64, 13 | "decoder_start_token_id": 0, 14 | "dense_act_fn": "gelu_new", 15 | "dropout_rate": 0.1, 16 | "attention_dropout": 0.1, 17 | "encoder_hidden_size": 768, 18 | "hidden_size": 768, 19 | "initializer_factor": 1.0, 20 | "initializer_range": 0.02, 21 | "is_decoder": true, 22 | "is_encoder_decoder": false, 23 | "layer_norm_epsilon": 1e-06, 24 | "model_type": "ptp_text_decoder", 25 | "num_heads": 12, 26 | "num_layers": 12, 27 | "pad_token_id": 1, 28 | "tie_word_embeddings": true, 29 | "vocab_size": 50265 30 | }, 31 | "vision_config": { 32 | "attention_probs_dropout_prob": 0.1, 33 | "decoder_hidden_size": 512, 34 | "decoder_intermediate_size": 2048, 35 | "decoder_num_attention_heads": 16, 36 | "decoder_num_hidden_layers": 8, 37 | "hidden_act": "gelu", 38 | "hidden_dropout_prob": 0.1, 39 | "hidden_size": 768, 40 | "image_size": [ 41 | 16, 42 | 8176 43 | ], 44 | "initializer_factor": 1.0, 45 | "initializer_range": 0.02, 46 | "intermediate_size": 3072, 47 | "layer_norm_eps": 1e-12, 48 | "model_type": "ptp_vision", 49 | "norm_pix_loss": true, 50 | "num_attention_heads": 12, 51 | "num_hidden_layers": 12, 52 | "patch_size": 16, 53 | "qkv_bias": true, 54 | "torch_dtype": "float32" 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /model_configs/ptp/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_convert_rgb": true, 3 | "do_normalize": false, 4 | "image_processor_type": "PTPImageProcessor", 5 | "max_patches": 2048, 6 | "patch_size": { 7 | "height": 16, 8 | "width": 16 9 | }, 10 | "processor_class": "PTPProcessor" 11 | } 12 | -------------------------------------------------------------------------------- /model_configs/ptp/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}} 2 | -------------------------------------------------------------------------------- /model_configs/ptp/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"errors": "replace", "unk_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": 
true, "__type": "AddedToken"}, "bos_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "add_bos_token": true, "special_tokens_map_file": null, "name_or_path": "patrickvonplaten/opt-30b", "processor_class": "PTPProcessor", "tokenizer_class": "GPT2Tokenizer"} 2 | -------------------------------------------------------------------------------- /model_configs/screenshot-llama-380m/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForScreenshot" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 1024, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2816, 11 | "max_position_embeddings": 4096, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "num_key_value_heads": 16, 16 | "pretraining_tp": 1, 17 | "rms_norm_eps": 1e-05, 18 | "rope_scaling": null, 19 | "tie_word_embeddings": false, 20 | "torch_dtype": "float32", 21 | "use_cache": true, 22 | "vocab_size": 32000 23 | } 24 | -------------------------------------------------------------------------------- /model_configs/screenshot-llama-380m/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_convert_rgb": true, 3 | "do_normalize": false, 4 | "concat_coord": true, 5 | "patch_size": { 6 | "height": 16, 7 | "width": 16 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /model_configs/screenshot-llama-380m/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "", 4 | "lstrip": false, 5 | "normalized": false, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "", 11 | "lstrip": false, 12 | "normalized": false, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "unk_token": { 17 | "content": "", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /model_configs/screenshot-llama-380m/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/PTP/e03435f1ec1d9902bd6663798dc84f7382ca7286/model_configs/screenshot-llama-380m/tokenizer.model -------------------------------------------------------------------------------- /model_configs/screenshot-llama-380m/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "", 16 | "lstrip": false, 17 | 
"normalized": false, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "legacy": false, 22 | "model_max_length": 1000000000000000019884624838656, 23 | "pad_token": null, 24 | "padding_side": "right", 25 | "sp_model_kwargs": {}, 26 | "tokenizer_class": "LlamaTokenizer", 27 | "unk_token": { 28 | "__type": "AddedToken", 29 | "content": "", 30 | "lstrip": false, 31 | "normalized": false, 32 | "rstrip": false, 33 | "single_word": false 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /modeling/configuration_pixel.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | PIXEL model configuration 17 | adapted from ViT MAE model configuration: 18 | https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/vit_mae/configuration_vit_mae.py 19 | """ 20 | 21 | from transformers import PretrainedConfig 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | PIXEL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | "Team-PIXEL/pixel-base": "https://huggingface.co/Team-PIXEL/pixel-base/resolve/main/config.json", 29 | } 30 | 31 | 32 | class PIXELConfig(PretrainedConfig): 33 | 34 | model_type = "pixel" 35 | 36 | def __init__( 37 | self, 38 | hidden_size=768, 39 | num_hidden_layers=12, 40 | num_attention_heads=12, 41 | intermediate_size=3072, 42 | hidden_act="gelu", 43 | hidden_dropout_prob=0.1, 44 | attention_probs_dropout_prob=0.1, 45 | initializer_range=0.02, 46 | layer_norm_eps=1e-12, 47 | image_size=(16, 8464), 48 | patch_size=16, 49 | num_channels=3, 50 | qkv_bias=True, 51 | decoder_num_attention_heads=16, 52 | decoder_hidden_size=512, 53 | decoder_num_hidden_layers=8, 54 | decoder_intermediate_size=2048, 55 | mask_ratio=0, 56 | norm_pix_loss=True, 57 | embedding_layernorm=False, 58 | image_mode="RGB", 59 | **kwargs 60 | ): 61 | super().__init__(**kwargs) 62 | 63 | self.hidden_size = hidden_size 64 | self.num_hidden_layers = num_hidden_layers 65 | self.num_attention_heads = num_attention_heads 66 | self.intermediate_size = intermediate_size 67 | self.hidden_act = hidden_act 68 | self.hidden_dropout_prob = hidden_dropout_prob 69 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 70 | self.initializer_range = initializer_range 71 | self.layer_norm_eps = layer_norm_eps 72 | self.image_size = image_size 73 | self.patch_size = patch_size 74 | self.num_channels = num_channels 75 | self.qkv_bias = qkv_bias 76 | self.decoder_num_attention_heads = decoder_num_attention_heads 77 | self.decoder_hidden_size = decoder_hidden_size 78 | self.decoder_num_hidden_layers = decoder_num_hidden_layers 79 | self.decoder_intermediate_size = decoder_intermediate_size 80 | self.mask_ratio = mask_ratio 81 | self.norm_pix_loss = norm_pix_loss 82 | 
self.embedding_layernorm = embedding_layernorm 83 | self.image_mode = image_mode -------------------------------------------------------------------------------- /modeling/configuration_ptp.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/configuration_pix2struct.py 2 | 3 | # coding=utf-8 4 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ Pix2Struct model configuration""" 18 | 19 | import copy 20 | import os 21 | from typing import Union 22 | 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.utils import logging 25 | from .configuration_pixel import PIXELConfig 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | 31 | class PTPTextConfig(PretrainedConfig): 32 | 33 | model_type = "ptp_text_decoder" 34 | keys_to_ignore_at_inference = ["past_key_values"] 35 | attribute_map = { 36 | "hidden_size": "hidden_size", 37 | "num_attention_heads": "num_heads", 38 | "num_hidden_layers": "num_layers", 39 | } 40 | 41 | def __init__( 42 | self, 43 | vocab_size=50244, 44 | hidden_size=768, 45 | d_kv=64, 46 | d_ff=2048, 47 | num_layers=12, 48 | num_heads=12, 49 | dropout_rate=0.1, 50 | attention_dropout=0.1, 51 | layer_norm_epsilon=1e-6, 52 | initializer_factor=1.0, 53 | dense_act_fn="gelu_new", 54 | decoder_start_token_id=0, 55 | use_cache=False, 56 | pad_token_id=0, 57 | eos_token_id=1, 58 | tie_word_embeddings=True, 59 | is_decoder=True, 60 | emb_layer_norm=True, 61 | is_glu=True, 62 | **kwargs, 63 | ): 64 | self.vocab_size = vocab_size 65 | self.hidden_size = hidden_size 66 | self.d_kv = d_kv 67 | self.d_ff = d_ff 68 | self.num_layers = num_layers 69 | self.num_heads = num_heads 70 | self.dropout_rate = dropout_rate 71 | self.attention_dropout = attention_dropout 72 | self.layer_norm_epsilon = layer_norm_epsilon 73 | self.layer_norm_eps = self.layer_norm_epsilon 74 | self.initializer_factor = initializer_factor 75 | self.use_cache = use_cache 76 | 77 | self.eos_token_id = eos_token_id 78 | self.decoder_start_token_id = decoder_start_token_id 79 | self.emb_layer_norm = emb_layer_norm 80 | self.is_glu = is_glu 81 | 82 | # for backwards compatibility 83 | self.dense_act_fn = dense_act_fn 84 | 85 | super().__init__( 86 | pad_token_id=pad_token_id, 87 | eos_token_id=eos_token_id, 88 | decoder_start_token_id=decoder_start_token_id, 89 | tie_word_embeddings=tie_word_embeddings, 90 | is_decoder=is_decoder, 91 | **kwargs, 92 | ) 93 | 94 | @classmethod 95 | def from_pretrained( 96 | cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs 97 | ) -> "PretrainedConfig": 98 | config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs) 99 | 100 | # get the text config dict if we are loading from Pix2StructConfig 101 | if config_dict.get("model_type") == "ptp": 102 | 
config_dict = config_dict["text_config"] 103 | 104 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 105 | logger.warning( 106 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 107 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 108 | ) 109 | 110 | return cls.from_dict(config_dict, **kwargs) 111 | 112 | 113 | 114 | class PTPVisionConfig(PretrainedConfig): 115 | 116 | model_type = "ptp_vision" 117 | 118 | def __init__( 119 | self, 120 | hidden_size=768, 121 | num_hidden_layers=12, 122 | num_attention_heads=12, 123 | intermediate_size=3072, 124 | hidden_act="gelu", 125 | hidden_dropout_prob=0.1, 126 | attention_probs_dropout_prob=0.1, 127 | initializer_range=0.02, 128 | layer_norm_eps=1e-12, 129 | image_size=(16, 8192), 130 | patch_size=16, 131 | num_channels=3, 132 | qkv_bias=True, 133 | decoder_num_attention_heads=16, 134 | decoder_hidden_size=512, 135 | decoder_num_hidden_layers=8, 136 | decoder_intermediate_size=2048, 137 | norm_pix_loss=True, 138 | embedding_layernorm=False, 139 | image_mode="RGB", 140 | **kwargs 141 | ): 142 | super().__init__(**kwargs) 143 | 144 | self.hidden_size = hidden_size 145 | self.num_hidden_layers = num_hidden_layers 146 | self.num_attention_heads = num_attention_heads 147 | self.intermediate_size = intermediate_size 148 | self.hidden_act = hidden_act 149 | self.hidden_dropout_prob = hidden_dropout_prob 150 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 151 | self.initializer_range = initializer_range 152 | self.layer_norm_eps = layer_norm_eps 153 | self.image_size = image_size 154 | self.patch_size = patch_size 155 | self.num_channels = num_channels 156 | self.qkv_bias = qkv_bias 157 | self.decoder_num_attention_heads = decoder_num_attention_heads 158 | self.decoder_hidden_size = decoder_hidden_size 159 | self.decoder_num_hidden_layers = decoder_num_hidden_layers 160 | self.decoder_intermediate_size = decoder_intermediate_size 161 | self.norm_pix_loss = norm_pix_loss 162 | self.embedding_layernorm = embedding_layernorm 163 | self.image_mode = image_mode 164 | 165 | @classmethod 166 | def from_pretrained( 167 | cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs 168 | ) -> "PretrainedConfig": 169 | config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs) 170 | 171 | # get the text config dict if we are loading from Pix2StructConfig 172 | if config_dict.get("model_type") == "ptp": 173 | config_dict = config_dict["vision_config"] 174 | 175 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 176 | logger.warning( 177 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 178 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
179 | ) 180 | 181 | return cls.from_dict(config_dict, **kwargs) 182 | 183 | 184 | class PTPConfig(PretrainedConfig): 185 | 186 | model_type = "ptp" 187 | is_composition = True 188 | 189 | def __init__( 190 | self, 191 | text_config=None, 192 | vision_config=None, 193 | tie_word_embeddings=True, 194 | is_encoder_decoder=True, 195 | add_mae_decoder=True, 196 | add_text_decoder=True, 197 | initializer_factor=1.0, 198 | initializer_range=0.02, 199 | **kwargs, 200 | ): 201 | super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs) 202 | 203 | if text_config is None: 204 | text_config = {} 205 | logger.info("text_config is None. Initializing the Pix2StructTextConfig with default values.") 206 | 207 | if vision_config is None: 208 | vision_config = {} 209 | logger.info("vision_config is None. Initializing the Pix2StructVisionConfig with default values.") 210 | 211 | self.text_config = PTPTextConfig(**text_config) 212 | self.vision_config = PTPVisionConfig(**vision_config) 213 | 214 | self.decoder_start_token_id = self.text_config.decoder_start_token_id 215 | self.pad_token_id = self.text_config.pad_token_id 216 | self.eos_token_id = self.text_config.eos_token_id 217 | 218 | self.add_mae_decoder = add_mae_decoder 219 | self.add_text_decoder = add_text_decoder 220 | 221 | self.initializer_factor = initializer_factor 222 | self.initializer_range = initializer_range 223 | 224 | 225 | @classmethod 226 | def from_text_vision_configs( 227 | cls, text_config: PTPTextConfig, vision_config: PTPVisionConfig, **kwargs 228 | ): 229 | return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) 230 | 231 | def to_dict(self): 232 | output = copy.deepcopy(self.__dict__) 233 | output["text_config"] = self.text_config.to_dict() 234 | output["vision_config"] = self.vision_config.to_dict() 235 | output["model_type"] = self.__class__.model_type 236 | return output 237 | -------------------------------------------------------------------------------- /modeling/configuration_screenshot_llama.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaConfig 2 | import copy 3 | 4 | class LlamaScreenshotConfig(LlamaConfig): 5 | 6 | model_type = "screenshot-llama" 7 | 8 | def __init__( 9 | self, 10 | patch_embed_size=768, 11 | img_begin_token_id=None, 12 | img_end_token_id=None, 13 | patch_token_id=None, 14 | newline_token_id=None, 15 | norm_pix_loss=False, 16 | pixel_decoder_config=None, 17 | **kwargs, 18 | ): 19 | super().__init__(**kwargs) 20 | 21 | # This is the size of the patch embedding (ph x pw x 3) 22 | self.patch_embed_size = patch_embed_size 23 | 24 | # Newly added special token 25 | self.img_begin_token_id = img_begin_token_id 26 | self.img_end_token_id = img_end_token_id 27 | self.patch_token_id = patch_token_id 28 | self.newline_token_id = newline_token_id 29 | 30 | # Loss for pixel-level supervision 31 | self.norm_pix_loss = norm_pix_loss 32 | 33 | if isinstance(pixel_decoder_config, dict): 34 | self.pixel_decoder_config = LlamaScreenshotConfig(pixel_decoder_config) 35 | else: 36 | self.pixel_decoder_config = pixel_decoder_config 37 | 38 | def to_dict(self): 39 | 40 | output = copy.deepcopy(self.__dict__) 41 | if self.pixel_decoder_config is not None: 42 | output['pixel_decoder_config'] = self.pixel_decoder_config.to_dict() 43 | output["model_type"] = self.__class__.model_type 44 | 45 | return output 46 | 
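As an added illustration (not part of the repository), the sketch below shows how a `LlamaScreenshotConfig` with a nested pixel-decoder config might be constructed; the token ids and layer sizes are placeholder assumptions, and the actual values come from the run configs and the processor's vocabulary.

```python
# Illustrative sketch only: all ids and sizes below are made-up placeholders.
from modeling.configuration_screenshot_llama import LlamaScreenshotConfig

# A small decoder config for reconstructing pixels from the LM's patch representations.
pixel_decoder = LlamaScreenshotConfig(
    hidden_size=512,
    intermediate_size=1376,
    num_hidden_layers=4,
    num_attention_heads=8,
    patch_embed_size=16 * 16 * 3,   # flattened 16x16 RGB patch
)

config = LlamaScreenshotConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    patch_embed_size=16 * 16 * 3,
    img_begin_token_id=32001,       # placeholder ids for the added special tokens
    img_end_token_id=32002,
    patch_token_id=32003,
    newline_token_id=32004,
    norm_pix_loss=True,
    pixel_decoder_config=pixel_decoder,  # pass an already-constructed config object
)

print(config.to_dict()["model_type"])   # "screenshot-llama"
```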
-------------------------------------------------------------------------------- /modeling/modeling_screenshot_llama.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import CrossEntropyLoss 6 | import numpy as np 7 | from typing import Optional, Tuple, Union, List 8 | from dataclasses import dataclass 9 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 10 | 11 | from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel 12 | 13 | from transformers.utils import logging 14 | from transformers.activations import ACT2FN 15 | from transformers.models.llama.modeling_llama import _expand_mask 16 | import copy 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | 21 | # def prepare_bidirectional_decoder_attn_mask(attention_mask, input_shape, inputs_embeds): 22 | # # create causal mask 23 | # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 24 | # expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 25 | # inputs_embeds.device 26 | # ) 27 | # return expanded_attn_mask 28 | 29 | 30 | @dataclass 31 | class ScreenshotCausalLMOutputWithPast(CausalLMOutputWithPast): 32 | 33 | pixel_loss: Optional[torch.FloatTensor] = None 34 | text_loss: Optional[torch.FloatTensor] = None 35 | patch_logits: Optional[torch.FloatTensor] = None 36 | 37 | 38 | class LlamaScreenshotModel(LlamaModel): 39 | """ 40 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 41 | 42 | Args: 43 | config: LlamaConfig 44 | """ 45 | 46 | def __init__(self, config): 47 | super().__init__(config) 48 | 49 | self.patch_projection = nn.Linear(config.patch_embed_size, config.hidden_size, bias=True) 50 | if getattr(config, "add_input_mlp", False): 51 | self.input_patch_mlp_or_identity = LlamaEmbedderMLP(config) 52 | else: 53 | self.input_patch_mlp_or_identity = nn.Identity() 54 | 55 | self.post_init() 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | flattened_patches: Optional[torch.Tensor] = None, 62 | position_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[List[torch.FloatTensor]] = None, 64 | inputs_embeds: Optional[torch.FloatTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> Union[Tuple, BaseModelOutputWithPast]: 70 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 71 | output_hidden_states = ( 72 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 73 | ) 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 77 | 78 | # retrieve input_ids and inputs_embeds 79 | if input_ids is not None and inputs_embeds is not None: 80 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 81 | elif input_ids is not None: 82 | batch_size, seq_length = input_ids.shape 83 | elif inputs_embeds is not None: 84 | batch_size, seq_length, _ = 
inputs_embeds.shape 85 | else: 86 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 87 | 88 | seq_length_with_past = seq_length 89 | past_key_values_length = 0 90 | 91 | if past_key_values is not None: 92 | past_key_values_length = past_key_values[0][0].shape[2] 93 | seq_length_with_past = seq_length_with_past + past_key_values_length 94 | 95 | if position_ids is None: 96 | device = input_ids.device if input_ids is not None else inputs_embeds.device 97 | position_ids = torch.arange( 98 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 99 | ) 100 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 101 | else: 102 | position_ids = position_ids.view(-1, seq_length).long() 103 | 104 | if inputs_embeds is None: 105 | inputs_embeds = self.embed_tokens(input_ids) 106 | 107 | # Patch embedding 108 | if flattened_patches is not None: 109 | patch_embeds = self.patch_projection(flattened_patches) # (B, PL, Pemb) -> (B, PL, H) 110 | patch_embeds = self.input_patch_mlp_or_identity(patch_embeds) # (B, PL, H) 111 | patch_mask = input_ids == self.config.patch_token_id # (B, L) 112 | patch_mask = patch_mask.unsqueeze(-1).expand_as(inputs_embeds) # (B, L, H) 113 | if inputs_embeds.dtype != patch_embeds.dtype: 114 | inputs_embeds = inputs_embeds.to(patch_embeds.dtype) 115 | inputs_embeds[patch_mask] = patch_embeds.reshape(-1) 116 | 117 | # embed positions 118 | if attention_mask is None: 119 | attention_mask = torch.ones( 120 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 121 | ) 122 | padding_mask = None 123 | else: 124 | if 0 in attention_mask: 125 | padding_mask = attention_mask 126 | else: 127 | padding_mask = None 128 | 129 | attention_mask = self._prepare_decoder_attention_mask( 130 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 131 | ) 132 | 133 | hidden_states = inputs_embeds 134 | 135 | if self.gradient_checkpointing and self.training: 136 | if use_cache: 137 | logger.warning_once( 138 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
139 | ) 140 | use_cache = False 141 | 142 | # decoder layers 143 | all_hidden_states = () if output_hidden_states else None 144 | all_self_attns = () if output_attentions else None 145 | next_decoder_cache = () if use_cache else None 146 | 147 | for idx, decoder_layer in enumerate(self.layers): 148 | if output_hidden_states: 149 | all_hidden_states += (hidden_states,) 150 | 151 | past_key_value = past_key_values[idx] if past_key_values is not None else None 152 | 153 | if self.gradient_checkpointing and self.training: 154 | 155 | def create_custom_forward(module): 156 | def custom_forward(*inputs): 157 | # None for past_key_value 158 | return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) 159 | 160 | return custom_forward 161 | 162 | layer_outputs = torch.utils.checkpoint.checkpoint( 163 | create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids 164 | ) 165 | else: 166 | layer_outputs = decoder_layer( 167 | hidden_states, 168 | attention_mask=attention_mask, 169 | position_ids=position_ids, 170 | past_key_value=past_key_value, 171 | output_attentions=output_attentions, 172 | use_cache=use_cache, 173 | padding_mask=padding_mask, 174 | ) 175 | 176 | hidden_states = layer_outputs[0] 177 | 178 | if use_cache: 179 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 180 | 181 | if output_attentions: 182 | all_self_attns += (layer_outputs[1],) 183 | 184 | hidden_states = self.norm(hidden_states) 185 | 186 | # add hidden states from the last decoder layer 187 | if output_hidden_states: 188 | all_hidden_states += (hidden_states,) 189 | 190 | next_cache = next_decoder_cache if use_cache else None 191 | if not return_dict: 192 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 193 | return BaseModelOutputWithPast( 194 | last_hidden_state=hidden_states, 195 | past_key_values=next_cache, 196 | hidden_states=all_hidden_states, 197 | attentions=all_self_attns, 198 | ) 199 | 200 | 201 | class LlamaEmbedderMLP(nn.Module): 202 | def __init__(self, config): 203 | super().__init__() 204 | self.config = config 205 | self.hidden_size = config.hidden_size 206 | 207 | if hasattr(config, "intermediate_emb_size") and config.intermediate_emb_size is not None: 208 | self.intermediate_size = config.intermediate_emb_size 209 | elif hasattr(config, "multiplicative_factor") and config.multiplicative_factor is not None: 210 | multiplicative_factor = config.multiplicative_factor 211 | self.intermediate_size = multiplicative_factor * config.hidden_size 212 | else: 213 | multiplicative_factor = config.vocab_size // config.hidden_size 214 | self.intermediate_size = multiplicative_factor * config.hidden_size 215 | 216 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) 217 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) 218 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) 219 | self.act_fn = ACT2FN[config.hidden_act] 220 | 221 | def forward(self, x): 222 | if self.config.pretraining_tp > 1: 223 | slice = self.intermediate_size // self.config.pretraining_tp 224 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 225 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 226 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 227 | 228 | gate_proj = torch.cat( 229 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 230 | ) 231 | up_proj = 
torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 232 | 233 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 234 | down_proj = [ 235 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 236 | ] 237 | down_proj = sum(down_proj) 238 | else: 239 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 240 | 241 | return down_proj 242 | 243 | 244 | class LlamaForScreenshot(LlamaPreTrainedModel): 245 | def __init__(self, config, **kwargs): 246 | super().__init__(config) 247 | self.model = LlamaScreenshotModel(config) 248 | 249 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 250 | 251 | if getattr(config, "add_output_mlp", False): 252 | self.output_patch_mlp_or_identity = LlamaEmbedderMLP(config) 253 | else: 254 | self.output_patch_mlp_or_identity = nn.Identity() 255 | 256 | if getattr(config, "pixel_decoder", False): 257 | print("Add a transformer PIXEL decoder!") 258 | self.pixel_decoder = LlamaScreenshotModel(config.pixel_decoder_config) 259 | self.encoder_to_decoder_proj = nn.Linear(config.hidden_size, config.pixel_decoder_config.hidden_size, bias=True) 260 | self.patch_head = nn.Linear(config.pixel_decoder_config.hidden_size, config.patch_embed_size, bias=True) 261 | else: 262 | self.patch_head = nn.Linear(config.hidden_size, config.patch_embed_size, bias=True) 263 | 264 | # Initialize weights and apply final processing 265 | self.post_init() 266 | 267 | def get_input_embeddings(self): 268 | return self.model.embed_tokens 269 | 270 | def set_input_embeddings(self, value): 271 | self.model.embed_tokens = value 272 | 273 | def get_output_embeddings(self): 274 | return self.lm_head 275 | 276 | def set_output_embeddings(self, new_embeddings): 277 | self.lm_head = new_embeddings 278 | 279 | def set_decoder(self, decoder): 280 | self.model = decoder 281 | 282 | def get_decoder(self): 283 | return self.model 284 | 285 | def pixel_loss(self, target, pred, mask): 286 | # target and pred: (bsz * npatches, pH*pW*3) 287 | if self.config.norm_pix_loss: 288 | # note that this is not the same as the pix2struct normalization 289 | # this is normalization within patches (same as the original MAE/PIXEL loss) 290 | mean = target.mean(dim=-1, keepdim=True) 291 | var = target.var(dim=-1, keepdim=True) 292 | target = (target - mean) / (var + 1.0e-6) ** 0.5 293 | 294 | loss = (pred - target) ** 2 295 | loss = loss.mean(dim=-1) # [b * PL] 296 | loss = (loss * mask).sum() / mask.sum() # mean loss on patches that have attn=1 297 | 298 | return loss 299 | 300 | def forward( 301 | self, 302 | input_ids: torch.LongTensor = None, 303 | attention_mask: Optional[torch.Tensor] = None, 304 | position_ids: Optional[torch.LongTensor] = None, 305 | flattened_patches: Optional[torch.FloatTensor] = None, 306 | past_key_values: Optional[List[torch.FloatTensor]] = None, 307 | inputs_embeds: Optional[torch.FloatTensor] = None, 308 | labels: Optional[torch.LongTensor] = None, 309 | use_cache: Optional[bool] = None, 310 | output_attentions: Optional[bool] = None, 311 | output_hidden_states: Optional[bool] = None, 312 | return_dict: Optional[bool] = None, 313 | ) -> Union[Tuple, CausalLMOutputWithPast]: 314 | 315 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 316 | output_hidden_states = ( 317 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 318 | ) 319 | 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict 320 | 321 | if flattened_patches is not None: 322 | # Remove the coordinates 323 | flattened_patches = flattened_patches[:, :, 2:] 324 | 325 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 326 | outputs = self.model( 327 | input_ids=input_ids, 328 | attention_mask=attention_mask, 329 | flattened_patches=flattened_patches, 330 | position_ids=position_ids, 331 | past_key_values=past_key_values, 332 | inputs_embeds=inputs_embeds, 333 | use_cache=use_cache, 334 | output_attentions=output_attentions, 335 | output_hidden_states=output_hidden_states, 336 | return_dict=return_dict, 337 | ) 338 | 339 | hidden_states = outputs[0] 340 | logits = self.lm_head(hidden_states) 341 | 342 | loss = None 343 | if flattened_patches is not None: 344 | shift_hidden_states = hidden_states[..., :-1, :].contiguous() 345 | shift_labels = labels[..., 1:].contiguous() 346 | shift_attention_mask = attention_mask[..., 1:].contiguous() 347 | 348 | patch_mask = shift_labels == self.config.patch_token_id 349 | patch_hidden_states = shift_hidden_states[patch_mask] # (B*PL, H) 350 | patch_attention_mask = shift_attention_mask[patch_mask] # (B*PL) 351 | patch_logits_intermediate = self.output_patch_mlp_or_identity(patch_hidden_states) 352 | 353 | if getattr(self.config, "pixel_decoder", False): 354 | batch_size = input_ids.shape[0] 355 | patch_logits_intermediate = self.encoder_to_decoder_proj(patch_logits_intermediate) 356 | pixel_decoder_outputs = self.pixel_decoder( 357 | inputs_embeds=patch_logits_intermediate.reshape(batch_size, -1, patch_logits_intermediate.shape[-1]), 358 | attention_mask=patch_attention_mask.reshape(batch_size, -1), 359 | ) 360 | pixel_decoder_hidden = pixel_decoder_outputs[0].reshape(-1, pixel_decoder_outputs[0].shape[-1]) # (B*PL, H) 361 | patch_logits = self.patch_head(pixel_decoder_hidden) 362 | else: 363 | patch_logits = self.patch_head(patch_logits_intermediate) # (B*PL, H) -> (B*PL, Pemb) 364 | 365 | pixel_loss = self.pixel_loss(flattened_patches.view(-1, flattened_patches.size(-1)), patch_logits, patch_attention_mask) 366 | loss = pixel_loss * self.ar_pixel_weight 367 | 368 | # for visualization later 369 | patch_logits[patch_attention_mask == 0] = 0 # those parts are not used for calculating loss, but they might have huge values that mess up the visualization 370 | patch_logits = patch_logits.reshape(flattened_patches.shape) 371 | else: 372 | pixel_loss = 0 373 | patch_logits = None 374 | 375 | if labels is not None: 376 | # Shift so that tokens < n predict n 377 | shift_logits = logits[..., :-1, :].contiguous() 378 | shift_labels = labels[..., 1:].contiguous() 379 | shift_attention_mask = attention_mask[..., 1:].contiguous() 380 | # Ignore the image tokens and the unmasked tokens 381 | shift_labels[shift_labels == self.config.patch_token_id] = -100 382 | shift_labels[shift_labels == self.config.img_begin_token_id] = -100 383 | shift_labels[shift_labels == self.config.img_end_token_id] = -100 384 | # We still let the model predict \n 385 | 386 | shift_labels[~shift_attention_mask.bool()] = -100 387 | 388 | # Flatten the tokens 389 | loss_fct = CrossEntropyLoss() 390 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 391 | shift_labels = shift_labels.view(-1) 392 | # Enable model parallelism 393 | shift_labels = shift_labels.to(shift_logits.device) 394 | 395 | text_loss = loss_fct(shift_logits, shift_labels) 396 | loss = text_loss * self.ar_text_weight if loss is 
None else loss + text_loss * self.ar_text_weight 397 | else: 398 | text_loss = 0 399 | 400 | if not return_dict: 401 | output = (logits,) + outputs[1:] 402 | return (loss,) + output if loss is not None else output 403 | 404 | return ScreenshotCausalLMOutputWithPast( 405 | loss=loss, 406 | logits=logits, 407 | past_key_values=outputs.past_key_values, 408 | hidden_states=outputs.hidden_states, 409 | attentions=outputs.attentions, 410 | pixel_loss=pixel_loss, 411 | text_loss=text_loss, 412 | patch_logits=patch_logits, 413 | ) 414 | 415 | def prepare_inputs_for_generation( 416 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 417 | ): 418 | if past_key_values: 419 | input_ids = input_ids[:, -1:] 420 | 421 | position_ids = kwargs.get("position_ids", None) 422 | if attention_mask is not None and position_ids is None: 423 | # create position_ids on the fly for batch generation 424 | position_ids = attention_mask.long().cumsum(-1) - 1 425 | position_ids.masked_fill_(attention_mask == 0, 1) 426 | if past_key_values: 427 | position_ids = position_ids[:, -1].unsqueeze(-1) 428 | 429 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 430 | if inputs_embeds is not None and past_key_values is None: 431 | model_inputs = {"inputs_embeds": inputs_embeds} 432 | else: 433 | model_inputs = {"input_ids": input_ids} 434 | 435 | model_inputs.update( 436 | { 437 | "position_ids": position_ids, 438 | "past_key_values": past_key_values, 439 | "use_cache": kwargs.get("use_cache"), 440 | "attention_mask": attention_mask, 441 | } 442 | ) 443 | return model_inputs 444 | 445 | @staticmethod 446 | def _reorder_cache(past_key_values, beam_idx): 447 | reordered_past = () 448 | for layer_past in past_key_values: 449 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 450 | return reordered_past 451 | 452 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 453 | 454 | def llama_forward_flash_attn( 455 | self, 456 | hidden_states: torch.Tensor, 457 | attention_mask: Optional[torch.Tensor] = None, 458 | position_ids: Optional[torch.LongTensor] = None, 459 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 460 | output_attentions: bool = False, 461 | use_cache: bool = False, 462 | **kwargs, 463 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: 464 | bsz, q_len, _ = hidden_states.size() 465 | 466 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 467 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 468 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 469 | 470 | kv_seq_len = key_states.shape[-2] 471 | if past_key_value is not None: 472 | kv_seq_len += past_key_value[0].shape[-2] 473 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 474 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 475 | # [bsz, nh, t, hd] 476 | 477 | if past_key_value is not None: 478 | # reuse k, v, self_attention 479 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 480 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 481 | 482 | past_key_value = (key_states, value_states) if use_cache else None 483 | 484 | attn_output = torch.nn.functional.scaled_dot_product_attention( 485 | query_states, key_states, value_states, 486 | 
attn_mask=attention_mask, # 0 or -inf 487 | is_causal=False, 488 | dropout_p=0 # llama has no dropout 489 | ) 490 | 491 | attn_output = attn_output.transpose(1, 2) 492 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 493 | 494 | attn_output = self.o_proj(attn_output) 495 | 496 | if not output_attentions: 497 | attn_weights = None 498 | 499 | return attn_output, attn_weights, past_key_value 500 | 501 | 502 | def find_module(root_module: nn.Module, key: str): 503 | """From OpenDelta""" 504 | sub_keys = key.split(".") 505 | parent_module = root_module 506 | for sub_key in sub_keys[:-1]: 507 | parent_module = getattr(parent_module, sub_key) 508 | module = getattr(parent_module, sub_keys[-1]) 509 | return parent_module, sub_keys[-1], module 510 | 511 | 512 | def inject_flash_attention_screenshotllama(model): 513 | for key, _ in model.named_modules(): 514 | attention_name = "self_attn" 515 | 516 | if key[-len(attention_name):] == attention_name: 517 | _, _, attn = find_module(model, key) 518 | print("Inject LLaMA flash attn:", key) 519 | attn.original_forward = attn.forward 520 | attn.forward = llama_forward_flash_attn.__get__(attn, type(attn)) -------------------------------------------------------------------------------- /modeling/processing_ptp.py: -------------------------------------------------------------------------------- 1 | # Modified based on 2 | # Pix2Struct image processor: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/image_processing_pix2struct.py 3 | # Pix2Struct processor: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/processing_pix2struct.py 4 | 5 | from typing import List, Optional, Union, Dict 6 | 7 | from transformers.processing_utils import ProcessorMixin 8 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 9 | from transformers.utils import TensorType 10 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 11 | import torch 12 | import numpy as np 13 | from transformers.image_utils import ( 14 | ChannelDimension, 15 | ImageInput, 16 | get_image_size, 17 | infer_channel_dimension_format, 18 | make_list_of_images, 19 | to_numpy_array, 20 | valid_images, 21 | ) 22 | from transformers.image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image 23 | from transformers.utils import TensorType, is_torch_available, is_vision_available, logging 24 | from transformers.utils.import_utils import requires_backends 25 | from transformers import AutoTokenizer 26 | import math 27 | import json 28 | import os 29 | 30 | class PTPImageProcessor(BaseImageProcessor): 31 | 32 | model_input_names = ["flattened_patches"] 33 | 34 | def __init__( 35 | self, 36 | do_convert_rgb: bool = True, 37 | do_normalize: bool = False, # if do_normalize, image = (image - mean) / std; otherwise image = image / 255 38 | patch_size: Dict[str, int] = None, 39 | concat_coord: bool = True, # prepend the coordinates of the patches to the patch features 40 | **kwargs, 41 | ) -> None: 42 | super().__init__(**kwargs) 43 | self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} 44 | self.do_normalize = do_normalize 45 | self.do_convert_rgb = do_convert_rgb 46 | self.concat_coord = concat_coord 47 | 48 | def extract_flattened_patches(self, image: np.ndarray, patch_size: dict, concat_coord = True, **kwargs) -> np.ndarray: 49 | 50 | 
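        # The image is split into a grid of non-overlapping patches of size
        # patch_size["height"] x patch_size["width"]; each patch is flattened to
        # patch_height * patch_width * channels values and, when concat_coord is True,
        # the 1-based (row, col) grid coordinates are prepended to every patch vector.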
requires_backends(self.extract_flattened_patches, "torch") 51 | 52 | # convert to torch 53 | image = to_channel_dimension_format(image, ChannelDimension.FIRST) 54 | image = torch.from_numpy(image) 55 | 56 | patch_height, patch_width = patch_size["height"], patch_size["width"] 57 | image_height, image_width = get_image_size(image) 58 | 59 | patches = torch_extract_patches(image, patch_height, patch_width) 60 | 61 | patches_shape = patches.shape 62 | rows = patches_shape[1] 63 | columns = patches_shape[2] 64 | depth = patches_shape[3] 65 | 66 | # [rows * columns, patch_height * patch_width * image_channels] 67 | patches = patches.reshape([rows * columns, depth]) 68 | 69 | # [rows * columns, 1] 70 | if concat_coord: 71 | row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1]) 72 | col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1]) 73 | 74 | # Offset by 1 so the ids do not contain zeros, which represent padding. 75 | row_ids += 1 76 | col_ids += 1 77 | 78 | # Prepare additional patch features. 79 | # [rows * columns, 1] 80 | row_ids = row_ids.to(torch.float32) 81 | col_ids = col_ids.to(torch.float32) 82 | 83 | # [rows * columns, 2 + patch_height * patch_width * image_channels] 84 | result = torch.cat([row_ids, col_ids, patches], -1) 85 | 86 | # [max_patches, 2 + patch_height * patch_width * image_channels] 87 | # result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() 88 | else: 89 | result = patches 90 | 91 | result = to_numpy_array(result) 92 | 93 | return result 94 | 95 | def normalize( 96 | self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs 97 | ) -> np.ndarray: 98 | """ 99 | Normalize an image. image = (image - image_mean) / image_std. 100 | 101 | The image std is to mimic the tensorflow implementation of the `per_image_standardization`: 102 | https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization 103 | 104 | Args: 105 | image (`np.ndarray`): 106 | Image to normalize. 107 | """ 108 | if image.dtype == np.uint8: 109 | image = image.astype(np.float32) 110 | 111 | # take mean across the whole `image` 112 | mean = np.mean(image) 113 | std = np.std(image) 114 | adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) 115 | 116 | return normalize(image, mean=mean, std=adjusted_stddev, **kwargs) 117 | 118 | def preprocess( 119 | self, 120 | images: ImageInput, 121 | do_convert_rgb: bool = None, 122 | do_normalize: Optional[bool] = None, 123 | concat_coord: bool = None, 124 | patch_size: Optional[Dict[str, int]] = None, 125 | return_tensors: Optional[Union[str, TensorType]] = None, 126 | data_format: ChannelDimension = ChannelDimension.FIRST, 127 | **kwargs, 128 | ) -> ImageInput: 129 | """ 130 | Preprocess an image or batch of images. The processor first computes the maximum possible number of 131 | aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the 132 | image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the 133 | images are standardized following the tensorflow implementation of `per_image_standardization` 134 | (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). 135 | 136 | 137 | Args: 138 | images (`ImageInput`): 139 | Image to preprocess. 140 | header_text (`Union[List[str], str]`, *optional*): 141 | Text to render as a header. 
Only has an effect if `image_processor.is_vqa` is `True`. 142 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 143 | Whether to convert the image to RGB. 144 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 145 | Whether to normalize the image. 146 | max_patches (`int`, *optional*, defaults to `self.max_patches`): 147 | Maximum number of patches to extract. 148 | patch_size (`dict`, *optional*, defaults to `self.patch_size`): 149 | Dictionary containing the patch height and width. 150 | return_tensors (`str` or `TensorType`, *optional*): 151 | The type of tensors to return. Can be one of: 152 | - Unset: Return a list of `np.ndarray`. 153 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 154 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 155 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 156 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 157 | """ 158 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 159 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 160 | patch_size = patch_size if patch_size is not None else self.patch_size 161 | concat_coord = concat_coord if concat_coord is not None else self.concat_coord 162 | 163 | if kwargs.get("data_format", None) is not None: 164 | raise ValueError("data_format is not an accepted input as the outputs are ") 165 | 166 | images = make_list_of_images(images) 167 | 168 | if not valid_images(images): 169 | raise ValueError( 170 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 171 | "torch.Tensor, tf.Tensor or jax.ndarray." 172 | ) 173 | 174 | # PIL RGBA images are converted to RGB 175 | if do_convert_rgb: 176 | images = [convert_to_rgb(image) for image in images] 177 | 178 | # All transformations expect numpy arrays. 
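        # The remaining steps: (1) convert each image to a numpy array, (2) either
        # per-image standardize it (do_normalize=True) or scale it to [0, 1],
        # (3) flatten it into (rows * cols, [2 +] patch_h * patch_w * 3) patch vectors,
        # and (4) build an attention mask that marks all-zero (padding) patches with 0.
        #
        # Illustrative usage sketch (the input image variable is assumed, not from this repo):
        #   processor = PTPImageProcessor(patch_size={"height": 16, "width": 16})
        #   batch = processor.preprocess(pil_image, return_tensors="pt")
        #   batch["flattened_patches"]  # shape (1, rows * cols, 2 + 16 * 16 * 3)
        #   batch["attention_mask"]     # shape (1, rows * cols)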
179 | images = [to_numpy_array(image) for image in images] 180 | 181 | if do_normalize: 182 | images = [self.normalize(image=image) for image in images] 183 | else: 184 | images = [image / 255.0 for image in images] 185 | 186 | # convert to torch tensor and permute 187 | images = [ 188 | self.extract_flattened_patches(image=image, patch_size=patch_size, concat_coord=concat_coord).astype(np.float32) 189 | for image in images 190 | ] 191 | 192 | # create attention mask in numpy 193 | attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images] 194 | 195 | encoded_outputs = BatchFeature( 196 | data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors 197 | ) 198 | 199 | return encoded_outputs 200 | 201 | 202 | 203 | class PTPProcessor: 204 | 205 | def __init__(self, config_path, **image_processor_kwargs): 206 | self.tokenizer = AutoTokenizer.from_pretrained(config_path) 207 | self.tokenizer.return_token_type_ids = False 208 | 209 | self.image_processor_config = json.load(open(os.path.join(config_path, "preprocessor_config.json"))) 210 | self.image_processor_config.update(image_processor_kwargs) 211 | self.image_processor = PTPImageProcessor(**self.image_processor_config) 212 | 213 | @staticmethod 214 | def from_pretrained(config_path): 215 | return PTPProcessor(config_path) 216 | 217 | def save_pretrained(self, save_directory): 218 | self.tokenizer.save_pretrained(save_directory) 219 | json.dump(self.image_processor_config, open(os.path.join(save_directory, "preprocessor_config.json"), "w"), indent=4) 220 | 221 | def __call__( 222 | self, 223 | images=None, 224 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 225 | add_special_tokens: bool = True, 226 | padding: Union[bool, str, PaddingStrategy] = False, 227 | truncation: Union[bool, str, TruncationStrategy] = None, 228 | max_length: Optional[int] = None, 229 | stride: int = 0, 230 | pad_to_multiple_of: Optional[int] = None, 231 | return_attention_mask: Optional[bool] = None, 232 | return_overflowing_tokens: bool = False, 233 | return_special_tokens_mask: bool = False, 234 | return_offsets_mapping: bool = False, 235 | return_token_type_ids: bool = False, 236 | return_length: bool = False, 237 | verbose: bool = True, 238 | return_tensors: Optional[Union[str, TensorType]] = None, 239 | **kwargs, 240 | ) -> BatchEncoding: 241 | 242 | if images is None and text is None: 243 | raise ValueError("You have to specify either images or text.") 244 | 245 | # Get only text 246 | if images is None: 247 | self.current_processor = self.tokenizer 248 | text_encoding = self.tokenizer( 249 | text=text, 250 | add_special_tokens=add_special_tokens, 251 | padding=padding, 252 | truncation=truncation, 253 | max_length=max_length, 254 | stride=stride, 255 | pad_to_multiple_of=pad_to_multiple_of, 256 | return_attention_mask=return_attention_mask, 257 | return_overflowing_tokens=return_overflowing_tokens, 258 | return_special_tokens_mask=return_special_tokens_mask, 259 | return_offsets_mapping=return_offsets_mapping, 260 | return_token_type_ids=return_token_type_ids, 261 | return_length=return_length, 262 | verbose=verbose, 263 | return_tensors=return_tensors, 264 | **kwargs, 265 | ) 266 | return text_encoding 267 | 268 | # add pixel_values 269 | encoding_image_processor = self.image_processor( 270 | images, return_tensors=return_tensors, **kwargs 271 | ) 272 | 273 | if text is not None: 274 | text_encoding = self.tokenizer( 275 | text=text, 276 | 
add_special_tokens=add_special_tokens, 277 | padding=padding, 278 | truncation=truncation, 279 | max_length=max_length, 280 | stride=stride, 281 | pad_to_multiple_of=pad_to_multiple_of, 282 | return_attention_mask=return_attention_mask, 283 | return_overflowing_tokens=return_overflowing_tokens, 284 | return_special_tokens_mask=return_special_tokens_mask, 285 | return_offsets_mapping=return_offsets_mapping, 286 | return_token_type_ids=return_token_type_ids, 287 | return_length=return_length, 288 | verbose=verbose, 289 | return_tensors=return_tensors, 290 | **kwargs, 291 | ) 292 | 293 | if "attention_mask" in text_encoding: 294 | text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask") 295 | if "input_ids" in text_encoding: 296 | text_encoding["decoder_input_ids"] = text_encoding.pop("input_ids") 297 | else: 298 | text_encoding = None 299 | 300 | if text_encoding is not None: 301 | encoding_image_processor.update(text_encoding) 302 | 303 | return encoding_image_processor 304 | 305 | def batch_decode(self, *args, **kwargs): 306 | """ 307 | This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 308 | Please refer to the docstring of this method for more information. 309 | """ 310 | return self.tokenizer.batch_decode(*args, **kwargs) 311 | 312 | def decode(self, *args, **kwargs): 313 | """ 314 | This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please 315 | refer to the docstring of this method for more information. 316 | """ 317 | return self.tokenizer.decode(*args, **kwargs) 318 | 319 | @property 320 | def model_input_names(self): 321 | tokenizer_input_names = self.tokenizer.model_input_names 322 | image_processor_input_names = self.image_processor.model_input_names 323 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 324 | 325 | 326 | 327 | 328 | # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2 329 | def torch_extract_patches(image_tensor, patch_height, patch_width): 330 | """ 331 | Utiliy function to extract patches from a given image tensor. Returns a tensor of shape (1, `patch_height`, 332 | `patch_width`, `num_channels`x `patch_height` x `patch_width`) 333 | 334 | Args: 335 | image_tensor (torch.Tensor): 336 | The image tensor to extract patches from. 337 | patch_height (int): 338 | The height of the patches to extract. 339 | patch_width (int): 340 | The width of the patches to extract. 
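
    Note: as implemented below, the returned tensor actually has shape
    (1, image_height // patch_height, image_width // patch_width,
    num_channels * patch_height * patch_width), i.e. a 2D grid of flattened patches.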
341 | """ 342 | requires_backends(torch_extract_patches, ["torch"]) 343 | 344 | image_tensor = image_tensor.unsqueeze(0) 345 | patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) 346 | patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1) 347 | patches = patches.permute(0, 4, 2, 3, 1).reshape( 348 | image_tensor.size(2) // patch_height, 349 | image_tensor.size(3) // patch_width, 350 | image_tensor.size(1) * patch_height * patch_width, 351 | ) 352 | return patches.unsqueeze(0) 353 | 354 | 355 | -------------------------------------------------------------------------------- /modeling/processing_screenshot_llama.py: -------------------------------------------------------------------------------- 1 | # Modified based on 2 | # Pix2Struct image processor: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/image_processing_pix2struct.py 3 | # Pix2Struct processor: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/processing_pix2struct.py 4 | 5 | from typing import List, Optional, Union, Dict 6 | 7 | from transformers.processing_utils import ProcessorMixin 8 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 9 | from transformers.utils import TensorType 10 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 11 | import torch 12 | import numpy as np 13 | from transformers.image_utils import ( 14 | ChannelDimension, 15 | ImageInput, 16 | get_image_size, 17 | infer_channel_dimension_format, 18 | make_list_of_images, 19 | to_numpy_array, 20 | valid_images, 21 | ) 22 | from transformers.image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image 23 | from transformers.utils import TensorType, is_torch_available, is_vision_available, logging 24 | from transformers.utils.import_utils import requires_backends 25 | from transformers import AutoTokenizer 26 | import math 27 | import json 28 | import os 29 | 30 | PATCH_ID = 0 31 | NEWLINE_ID = 1 32 | IMG_BEGIN_ID = 2 33 | IMG_END_ID = 3 34 | 35 | 36 | class ScreenshotLlamaImageProcessor(BaseImageProcessor): 37 | 38 | model_input_names = ["flattened_patches"] 39 | 40 | def __init__( 41 | self, 42 | do_convert_rgb: bool = True, 43 | do_normalize: bool = False, # if do_normalize, image = (image - mean) / std; otherwise image = image / 255 44 | patch_size: Dict[str, int] = None, 45 | concat_coord: bool = True, # prepend the coordinates of the patches to the patch features 46 | **kwargs, 47 | ) -> None: 48 | super().__init__(**kwargs) 49 | self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} 50 | self.do_normalize = do_normalize 51 | self.do_convert_rgb = do_convert_rgb 52 | self.concat_coord = concat_coord 53 | 54 | def extract_flattened_patches(self, image: np.ndarray, patch_size: dict, concat_coord = True, **kwargs) -> np.ndarray: 55 | 56 | requires_backends(self.extract_flattened_patches, "torch") 57 | 58 | # convert to torch 59 | image = to_channel_dimension_format(image, ChannelDimension.FIRST) 60 | image = torch.from_numpy(image) 61 | 62 | patch_height, patch_width = patch_size["height"], patch_size["width"] 63 | image_height, image_width = get_image_size(image) 64 | 65 | patches = torch_extract_patches(image, patch_height, patch_width) 66 | 67 | patches_shape = patches.shape 68 | rows = 
patches_shape[1] 69 | columns = patches_shape[2] 70 | depth = patches_shape[3] 71 | 72 | # [rows * columns, patch_height * patch_width * image_channels] 73 | patches = patches.reshape([rows * columns, depth]) 74 | 75 | # [rows * columns, 1] 76 | if concat_coord: 77 | row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1]) 78 | col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1]) 79 | 80 | # Offset by 1 so the ids do not contain zeros, which represent padding. 81 | row_ids += 1 82 | col_ids += 1 83 | 84 | # Prepare additional patch features. 85 | # [rows * columns, 1] 86 | row_ids = row_ids.to(torch.float32) 87 | col_ids = col_ids.to(torch.float32) 88 | 89 | # [rows * columns, 2 + patch_height * patch_width * image_channels] 90 | result = torch.cat([row_ids, col_ids, patches], -1) 91 | 92 | # [max_patches, 2 + patch_height * patch_width * image_channels] 93 | # result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() 94 | else: 95 | result = patches 96 | 97 | # result: the actual patch pixel values 98 | # image_token_ids: the special tokens corresponding to the image, including the patch token, the newline token, and the image start/end tokens 99 | result = to_numpy_array(result) 100 | image_token_ids = np.zeros((rows, columns + 1), dtype=np.int32) + PATCH_ID 101 | image_token_ids[:, -1] = NEWLINE_ID 102 | image_token_ids = [IMG_BEGIN_ID] + image_token_ids.flatten().tolist() + [IMG_END_ID] 103 | 104 | return result, image_token_ids 105 | 106 | def normalize( 107 | self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs 108 | ) -> np.ndarray: 109 | """ 110 | Normalize an image. image = (image - image_mean) / image_std. 111 | 112 | The image std is to mimic the tensorflow implementation of the `per_image_standardization`: 113 | https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization 114 | 115 | Args: 116 | image (`np.ndarray`): 117 | Image to normalize. 118 | """ 119 | if image.dtype == np.uint8: 120 | image = image.astype(np.float32) 121 | 122 | # take mean across the whole `image` 123 | mean = np.mean(image) 124 | std = np.std(image) 125 | adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) 126 | 127 | return normalize(image, mean=mean, std=adjusted_stddev, **kwargs) 128 | 129 | def preprocess( 130 | self, 131 | images: ImageInput, 132 | do_convert_rgb: bool = None, 133 | do_normalize: Optional[bool] = None, 134 | concat_coord: bool = None, 135 | patch_size: Optional[Dict[str, int]] = None, 136 | return_tensors: Optional[Union[str, TensorType]] = None, 137 | data_format: ChannelDimension = ChannelDimension.FIRST, 138 | **kwargs, 139 | ) -> ImageInput: 140 | """ 141 | Preprocess an image or batch of images. The processor first computes the maximum possible number of 142 | aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the 143 | image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the 144 | images are standardized following the tensorflow implementation of `per_image_standardization` 145 | (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). 146 | 147 | 148 | Args: 149 | images (`ImageInput`): 150 | Image to preprocess. 151 | header_text (`Union[List[str], str]`, *optional*): 152 | Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`. 
153 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 154 | Whether to convert the image to RGB. 155 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 156 | Whether to normalize the image. 157 | max_patches (`int`, *optional*, defaults to `self.max_patches`): 158 | Maximum number of patches to extract. 159 | patch_size (`dict`, *optional*, defaults to `self.patch_size`): 160 | Dictionary containing the patch height and width. 161 | return_tensors (`str` or `TensorType`, *optional*): 162 | The type of tensors to return. Can be one of: 163 | - Unset: Return a list of `np.ndarray`. 164 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 165 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 166 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 167 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 168 | """ 169 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 170 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 171 | patch_size = patch_size if patch_size is not None else self.patch_size 172 | concat_coord = concat_coord if concat_coord is not None else self.concat_coord 173 | 174 | if kwargs.get("data_format", None) is not None: 175 | raise ValueError("data_format is not an accepted input as the outputs are ") 176 | 177 | images = make_list_of_images(images) 178 | 179 | if not valid_images(images): 180 | raise ValueError( 181 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 182 | "torch.Tensor, tf.Tensor or jax.ndarray." 183 | ) 184 | 185 | # PIL RGBA images are converted to RGB 186 | if do_convert_rgb: 187 | images = [convert_to_rgb(image) for image in images] 188 | 189 | # All transformations expect numpy arrays. 
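        # Unlike the PTP processor, extract_flattened_patches here also returns placeholder
        # token ids for the image: for a rows x cols patch grid the sequence is IMG_BEGIN_ID,
        # then, for every row, `cols` PATCH_IDs followed by one NEWLINE_ID, and finally
        # IMG_END_ID. These raw ids are emitted as "image_raw_token_ids" and remapped by
        # ScreenshotLlamaProcessor.__call__ to the tokenizer's actual token ids.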
190 | images = [to_numpy_array(image) for image in images] 191 | 192 | if do_normalize: 193 | images = [self.normalize(image=image) for image in images] 194 | else: 195 | images = [image / 255.0 for image in images] 196 | 197 | # convert to torch tensor and permute 198 | images_and_token_ids = [ 199 | self.extract_flattened_patches(image=image, patch_size=patch_size, concat_coord=concat_coord) 200 | for image in images 201 | ] 202 | images = [image_and_token_ids[0].astype(np.float32) for image_and_token_ids in images_and_token_ids] 203 | image_token_ids = [image_and_token_ids[1] for image_and_token_ids in images_and_token_ids] 204 | 205 | # create attention mask in numpy 206 | attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images] 207 | 208 | encoded_outputs = BatchFeature( 209 | data={"flattened_patches": images, "attention_mask": attention_masks, "image_raw_token_ids": image_token_ids}, tensor_type=return_tensors 210 | ) 211 | 212 | return encoded_outputs 213 | 214 | 215 | 216 | class ScreenshotLlamaProcessor: 217 | 218 | def __init__(self, config_path, **image_processor_kwargs): 219 | self.tokenizer = AutoTokenizer.from_pretrained(config_path) 220 | self.tokenizer.return_token_type_ids = False 221 | 222 | self.tokenizer.add_tokens("", special_tokens=True) 223 | self.tokenizer.add_tokens("", special_tokens=True) 224 | self.tokenizer.add_tokens("", special_tokens=True) 225 | 226 | self.img_begin_token_id = self.tokenizer.encode("")[-1] 227 | self.img_end_token_id = self.tokenizer.encode("")[-1] 228 | self.patch_token_id = self.tokenizer.encode("")[-1] 229 | self.newline_token_id = self.tokenizer.encode("\n")[-1] 230 | 231 | self.image_processor_config = json.load(open(os.path.join(config_path, "preprocessor_config.json"))) 232 | self.image_processor_config.update(image_processor_kwargs) 233 | self.image_processor = ScreenshotLlamaImageProcessor(**self.image_processor_config) 234 | 235 | def save_pretrained(self, save_directory): 236 | self.tokenizer.save_pretrained(save_directory) 237 | json.dump(self.image_processor_config, open(os.path.join(save_directory, "preprocessor_config.json"), "w"), indent=4) 238 | 239 | def __call__( 240 | self, 241 | images=None, 242 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 243 | add_special_tokens: bool = True, 244 | padding: Union[bool, str, PaddingStrategy] = False, 245 | truncation: Union[bool, str, TruncationStrategy] = None, 246 | max_length: Optional[int] = None, 247 | stride: int = 0, 248 | pad_to_multiple_of: Optional[int] = None, 249 | return_attention_mask: Optional[bool] = None, 250 | return_overflowing_tokens: bool = False, 251 | return_special_tokens_mask: bool = False, 252 | return_offsets_mapping: bool = False, 253 | return_token_type_ids: bool = False, 254 | return_length: bool = False, 255 | verbose: bool = True, 256 | return_tensors: Optional[Union[str, TensorType]] = None, 257 | **kwargs, 258 | ) -> BatchEncoding: 259 | """ 260 | This method uses [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and 261 | [`T5TokenizerFast.__call__`] to prepare text for the model. 262 | 263 | Please refer to the docstring of the above two methods for more information. 
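
        When images are passed, the returned encoding contains "flattened_patches",
        "attention_mask", and "image_input_ids" (the image-begin / patch / newline /
        image-end placeholder ids already mapped to this tokenizer's vocabulary); when
        `text` is given as well, it also contains "decoder_input_ids" and
        "decoder_attention_mask", renamed from the tokenizer outputs.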
264 | """ 265 | if images is None and text is None: 266 | raise ValueError("You have to specify either images or text.") 267 | 268 | # Get only text 269 | if images is None: 270 | self.current_processor = self.tokenizer 271 | text_encoding = self.tokenizer( 272 | text=text, 273 | add_special_tokens=add_special_tokens, 274 | padding=padding, 275 | truncation=truncation, 276 | max_length=max_length, 277 | stride=stride, 278 | pad_to_multiple_of=pad_to_multiple_of, 279 | return_attention_mask=return_attention_mask, 280 | return_overflowing_tokens=return_overflowing_tokens, 281 | return_special_tokens_mask=return_special_tokens_mask, 282 | return_offsets_mapping=return_offsets_mapping, 283 | return_token_type_ids=return_token_type_ids, 284 | return_length=return_length, 285 | verbose=verbose, 286 | return_tensors=return_tensors, 287 | **kwargs, 288 | ) 289 | return text_encoding 290 | 291 | # add pixel_values 292 | encoding_image_processor = self.image_processor( 293 | images, return_tensors=return_tensors, **kwargs 294 | ) 295 | 296 | if text is not None: 297 | text_encoding = self.tokenizer( 298 | text=text, 299 | add_special_tokens=add_special_tokens, 300 | padding=padding, 301 | truncation=truncation, 302 | max_length=max_length, 303 | stride=stride, 304 | pad_to_multiple_of=pad_to_multiple_of, 305 | return_attention_mask=return_attention_mask, 306 | return_overflowing_tokens=return_overflowing_tokens, 307 | return_special_tokens_mask=return_special_tokens_mask, 308 | return_offsets_mapping=return_offsets_mapping, 309 | return_token_type_ids=return_token_type_ids, 310 | return_length=return_length, 311 | verbose=verbose, 312 | return_tensors=return_tensors, 313 | **kwargs, 314 | ) 315 | 316 | if "attention_mask" in text_encoding: 317 | text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask") 318 | if "input_ids" in text_encoding: 319 | text_encoding["decoder_input_ids"] = text_encoding.pop("input_ids") 320 | else: 321 | text_encoding = None 322 | 323 | if text_encoding is not None: 324 | encoding_image_processor.update(text_encoding) 325 | 326 | image_input_ids = encoding_image_processor.pop("image_raw_token_ids") 327 | image_input_ids[image_input_ids == IMG_BEGIN_ID] = self.img_begin_token_id 328 | image_input_ids[image_input_ids == IMG_END_ID] = self.img_end_token_id 329 | image_input_ids[image_input_ids == PATCH_ID] = self.patch_token_id 330 | image_input_ids[image_input_ids == NEWLINE_ID] = self.newline_token_id 331 | 332 | encoding_image_processor["image_input_ids"] = image_input_ids 333 | 334 | return encoding_image_processor 335 | 336 | def batch_decode(self, *args, **kwargs): 337 | """ 338 | This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 339 | Please refer to the docstring of this method for more information. 340 | """ 341 | return self.tokenizer.batch_decode(*args, **kwargs) 342 | 343 | def decode(self, *args, **kwargs): 344 | """ 345 | This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please 346 | refer to the docstring of this method for more information. 
347 | """ 348 | return self.tokenizer.decode(*args, **kwargs) 349 | 350 | @property 351 | def model_input_names(self): 352 | tokenizer_input_names = self.tokenizer.model_input_names 353 | image_processor_input_names = self.image_processor.model_input_names 354 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 355 | 356 | 357 | 358 | 359 | # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2 360 | def torch_extract_patches(image_tensor, patch_height, patch_width): 361 | """ 362 | Utiliy function to extract patches from a given image tensor. Returns a tensor of shape (1, `patch_height`, 363 | `patch_width`, `num_channels`x `patch_height` x `patch_width`) 364 | 365 | Args: 366 | image_tensor (torch.Tensor): 367 | The image tensor to extract patches from. 368 | patch_height (int): 369 | The height of the patches to extract. 370 | patch_width (int): 371 | The width of the patches to extract. 372 | """ 373 | requires_backends(torch_extract_patches, ["torch"]) 374 | 375 | image_tensor = image_tensor.unsqueeze(0) 376 | patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) 377 | patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1) 378 | patches = patches.permute(0, 4, 2, 3, 1).reshape( 379 | image_tensor.size(2) // patch_height, 380 | image_tensor.size(3) // patch_width, 381 | image_tensor.size(1) * patch_height * patch_width, 382 | ) 383 | return patches.unsqueeze(0) 384 | 385 | 386 | -------------------------------------------------------------------------------- /modeling/sincos_pos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_2d_sincos_pos_embed(embed_dim, seq_len): 5 | """ 6 | Create 2D sin/cos positional embeddings. 7 | 8 | Args: 9 | embed_dim (`int`): 10 | Embedding dimension. 11 | grid_size (`int`): 12 | The grid height and width. 13 | add_cls_token (`bool`, *optional*, defaults to `False`): 14 | Whether or not to add a classification (CLS) token. 
15 | 16 | Returns: 17 | (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the 18 | position embeddings (with or without classification token) 19 | """ 20 | grid_h = np.arange(seq_len, dtype=np.float32) 21 | grid_w = np.arange(seq_len, dtype=np.float32) 22 | 23 | row_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H, D/2) 24 | column_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (W, D/2) 25 | 26 | return row_embed, column_embed 27 | 28 | 29 | 30 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 31 | """ 32 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 33 | """ 34 | if embed_dim % 2 != 0: 35 | raise ValueError("embed_dim must be even") 36 | 37 | omega = np.arange(embed_dim // 2, dtype=float) 38 | omega /= embed_dim / 2.0 39 | omega = 1.0 / 10000**omega # (D/2,) 40 | 41 | pos = pos.reshape(-1) # (M,) 42 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 43 | 44 | emb_sin = np.sin(out) # (M, D/2) 45 | emb_cos = np.cos(out) # (M, D/2) 46 | 47 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 48 | return emb 49 | -------------------------------------------------------------------------------- /modeling/span_masking.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, DINO and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 10 | # Copyright Zhun Zhong & Liang Zheng 11 | # 12 | # Hacked together by / Copyright 2020 Ross Wightman 13 | # 14 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer 15 | # 16 | # Further modified by Phillip Rust, for masking spans in rendered text inputs used with PIXEL 17 | # --------------------------------------------------------' 18 | import random 19 | from typing import List, Optional, Union 20 | 21 | import numpy as np 22 | 23 | 24 | class SpanMaskingGenerator: 25 | """ 26 | Generator class that yields span masks 27 | 28 | Args: 29 | num_patches (`int`): 30 | The total number of images patches 31 | num_masking_patches (`int`, defaults to 1): 32 | The number of patches to be masked out. Typically determined by the masking ratio 33 | max_span_length (`int`, defaults to 6): 34 | The maximum number of consecutive masked patches 35 | spacing (`Union[int, str]`, default to 0): 36 | The number of non-masked patches in between consecutive masked spans. Can either be an integer value, 37 | in which case the spacing is fixed, or can be set to "span" in which case the spacing is dynamic such 38 | that on both sides of a masked span of length N patches, there will be N non-masked patches. Note that 39 | larger spacing makes it harder to greedily sample masks satisfying these constraints which can slow down 40 | masking and also cause the algorithm to terminate with a smaller mask than specified. In case of the 41 | latter, PIXEL randomly masks additional patches until the specified masking ratio is reached. 
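
    Illustrative example (the numbers are assumed, not taken from this repo's configs):
    SpanMaskingGenerator(num_patches=529, num_masking_patches=132, max_span_length=6,
    spacing="span")() returns a length-529 0/1 numpy array in which roughly 25% of the
    positions fall inside masked spans of at most 6 consecutive patches.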
42 | 43 | These are the recommended settings: 44 | - For masking ratio <= 0.4 use "span" mode. 45 | - For ratios between 0.4 and 0.7 set spacing to 1. 46 | - For higher, set spacing to 0 47 | """ 48 | 49 | def __init__( 50 | self, 51 | num_patches: int, 52 | num_masking_patches: int = 1, 53 | max_span_length: int = 6, 54 | spacing: Union[int, str] = 0, 55 | cumulative_span_weights: Optional[List[float]] = None, 56 | ): 57 | 58 | self.num_patches = num_patches 59 | self.num_masking_patches = num_masking_patches 60 | 61 | self.max_span_length = max_span_length 62 | self.spacing = spacing 63 | assert spacing == "span" or isinstance(spacing, int) 64 | 65 | self.span_range = range(1, max_span_length + 1) 66 | self.cumulative_span_weights = cumulative_span_weights 67 | self.num_text_patches = None 68 | 69 | def _mask(self, mask, max_mask_patches): 70 | delta = 0 71 | # Lower number of attempts will speed up mask generation but might cause a lot fewer patches to be masked 72 | # than desired, particularly for high masking ratios 73 | for attempt in range(100): 74 | # Randomly select span length within specified range 75 | span = random.choices(self.span_range, cum_weights=self.cumulative_span_weights, k=1)[0] 76 | if span < self.num_patches: 77 | # This is only the case in the first iteration 78 | if self.num_text_patches is not None: 79 | # Select a span where there is text 80 | # This guarantees that we never generate a mask that only masks out padding 81 | left = random.randint(0, max(0, self.num_text_patches - span)) 82 | # self.num_text_patches = None 83 | else: 84 | # Start at random horizontal index 85 | left = random.randint(0, self.num_patches - span) 86 | 87 | space = span if self.spacing == "span" else self.spacing 88 | # Ensure no patches within patches to the left are masked 89 | if space != 0: 90 | num_masked_left = mask[max(0, left - space) : left].sum() 91 | if num_masked_left > 0: 92 | continue 93 | # Ensure no patches within patches to the right are masked 94 | num_masked_right = mask[left + span : min(left + span + space, self.num_patches)].sum() 95 | if num_masked_right > 0: 96 | continue 97 | 98 | # Account for overlap 99 | num_masked_within = mask[left : left + span].sum() 100 | if 0 < span - num_masked_within <= max_mask_patches: 101 | for j in range(left, left + span): 102 | if mask[j] == 0: 103 | mask[j] = 1 104 | delta += 1 105 | 106 | if delta > 0: 107 | break 108 | return delta 109 | 110 | def __call__(self, num_text_patches=None): # , num_text_patches: int): 111 | # Start with an empty mask 112 | mask = np.zeros(shape=self.num_patches, dtype=np.int8) 113 | self.num_text_patches = None # num_text_patches 114 | 115 | # Greedily try to add mask patches until desired number of masked patches is reached 116 | mask_count = 0 117 | while mask_count < self.num_masking_patches: 118 | max_mask_patches = self.num_masking_patches - mask_count 119 | max_mask_patches = min(max_mask_patches, self.max_span_length) 120 | 121 | # We attempt to add a span to our mask up to 100 times 122 | delta = self._mask(mask, max_mask_patches) 123 | 124 | if delta == 0: 125 | # We terminate when no new span could be added to the mask after 100 attempts 126 | # This can happen before self.num_masking_patches is reached for high masking ratios with 127 | # strong constraints 128 | break 129 | else: 130 | mask_count += delta 131 | 132 | return mask 133 | 134 | 135 | class SpanMaskingGeneratorT5: 136 | """ 137 | Generator class that yields span masks 138 | 139 | Args: 140 | num_patches (`int`): 141 | 
The total number of images patches 142 | num_masking_patches (`int`, defaults to 1): 143 | The number of patches to be masked out. Typically determined by the masking ratio 144 | max_span_length (`int`, defaults to 6): 145 | The maximum number of consecutive masked patches 146 | spacing (`Union[int, str]`, default to 0): 147 | The number of non-masked patches in between consecutive masked spans. Can either be an integer value, 148 | in which case the spacing is fixed, or can be set to "span" in which case the spacing is dynamic such 149 | that on both sides of a masked span of length N patches, there will be N non-masked patches. Note that 150 | larger spacing makes it harder to greedily sample masks satisfying these constraints which can slow down 151 | masking and also cause the algorithm to terminate with a smaller mask than specified. In case of the 152 | latter, PIXEL randomly masks additional patches until the specified masking ratio is reached. 153 | 154 | These are the recommended settings: 155 | - For masking ratio <= 0.4 use "span" mode. 156 | - For ratios between 0.4 and 0.7 set spacing to 1. 157 | - For higher, set spacing to 0 158 | """ 159 | 160 | def __init__( 161 | self, 162 | num_patches: int, 163 | num_masking_patches: int = 1, 164 | max_span_length: int = 6, 165 | spacing: Union[int, str] = 0, 166 | cumulative_span_weights: Optional[List[float]] = None, 167 | ): 168 | 169 | self.num_patches = num_patches 170 | self.num_masking_patches = num_masking_patches 171 | 172 | self.max_span_length = max_span_length 173 | self.mean_span_length = max_span_length // 2 174 | self.spacing = spacing 175 | assert spacing == "span" or isinstance(spacing, int) 176 | 177 | self.span_range = range(1, max_span_length + 1) 178 | self.cumulative_span_weights = cumulative_span_weights 179 | self.num_text_patches = None 180 | 181 | 182 | def __call__(self, num_text_patches=None): # , num_text_patches: int): 183 | # Start with an empty mask 184 | mask = np.zeros(shape=self.num_patches, dtype=np.int8) 185 | spans = round(self.num_masking_patches / self.mean_span_length) 186 | 187 | def random_span_lengths(num_spans, num_tokens): 188 | indices = np.arange(1,num_tokens) 189 | np.random.shuffle(indices) 190 | span_starts = np.sort(indices[:num_spans-1]) 191 | span_lengths = np.diff(span_starts, prepend=0, append=num_tokens) 192 | return span_lengths 193 | 194 | unmasked_lengths = random_span_lengths(spans, self.num_patches - self.num_masking_patches) 195 | masked_lengths = random_span_lengths(spans, self.num_masking_patches) 196 | 197 | idc = [] 198 | last = 0 199 | for unmasked_length, masked_length in zip(unmasked_lengths, masked_lengths): 200 | idc.extend(range(last + unmasked_length, last + unmasked_length + masked_length)) 201 | last = last + unmasked_length + masked_length 202 | 203 | mask[np.array(idc)] = 1 204 | 205 | return mask 206 | -------------------------------------------------------------------------------- /rendering/src/Makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | CFLAGS=-I. 
-I/usr/include/freetype2 -I/usr/include/libpng16 -Wall \ 3 | -L/usr/local/lib -lfreetype 4 | PBFLAGS=-shared -std=c++11 -fPIC $(shell python3 -m pybind11 --includes) 5 | DEPS = $(wildcard *.h) 6 | SRCS = $(wildcard *.cpp) 7 | # SRCS := $(filter-out renderer.cpp, $(SRCS)) 8 | OBJS = $(patsubst %.cpp, %.o, $(SRCS)) 9 | target=renderer$(shell python3-config --extension-suffix) 10 | 11 | %.o: %.cpp $(DEPS) 12 | $(CC) -c -o $@ $< $(CFLAGS) $(PBFLAGS) 13 | 14 | all: $(target) # renderer 15 | 16 | renderer: $(OBJS) 17 | $(CC) -o $@ $^ $(CFLAGS) 18 | 19 | $(target): $(OBJS) 20 | $(CC) -o $(target) $^ $(CFLAGS) $(PBFLAGS) 21 | rm *.o 22 | 23 | .PHONY : clean 24 | clean: 25 | rm renderer renderer.cpython* -------------------------------------------------------------------------------- /rendering/src/README.md: -------------------------------------------------------------------------------- 1 | # Renderer 2 | 3 | Our renderer is written in c++ and can be compiled to a python package by using `pybind11`. We already compiled two versions for python 3.9 (`renderer.cpython-39-x86_64-linux-gnu.so`) and python 3.10 (`renderer.cpython-310-x86_64-linux-gnu.so`), which means that if you are using python 3.9/3.10, you do not need to do anything. 4 | 5 | If you are using a different python version or find that our renderer does not work on your machine, you can compile it from scratch by simply running the `make` command in this folder. -------------------------------------------------------------------------------- /rendering/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "render.h" 3 | 4 | #define HEIGHT 256 5 | #define WIDTH 512 6 | 7 | using namespace std; 8 | 9 | int main() { 10 | FT_Library library; 11 | FT_Face face; 12 | 13 | init_ft(library); 14 | init_face(library, face); 15 | 16 | string teststring = "This is a placeholder for the main code."; 17 | unsigned char array[HEIGHT*WIDTH]; 18 | 19 | render_text(face, teststring, (unsigned char *)array, HEIGHT, WIDTH); 20 | show_image_in_terminal(array, HEIGHT, WIDTH); 21 | 22 | finish(library, face); 23 | 24 | return 0; 25 | } -------------------------------------------------------------------------------- /rendering/src/render.cpp: -------------------------------------------------------------------------------- 1 | #include "render.h" 2 | #include 3 | #include 4 | #include 5 | 6 | void init_ft(FT_Library &library) { 7 | int error = FT_Init_FreeType(&library); 8 | if (error) { 9 | std::cerr << "Failed to initialize FreeType! (" << error << ")\n"; 10 | exit(-1); 11 | } 12 | } 13 | 14 | void init_face(FT_Library library, FT_Face &face, std::string fontpath, 15 | unsigned int fontsize) { 16 | int error = 0; 17 | 18 | error = FT_New_Face(library, fontpath.c_str(), 0, &face); 19 | if (error) { 20 | std::cerr << "Failed to initialize Face! (" << error << ")\n"; 21 | exit(-1); 22 | } 23 | 24 | error = FT_Set_Char_Size(face, 0, fontsize*64, 100, 0); 25 | if (error) { 26 | std::cerr << "Failed to set font size!(" << error << ")\n";; 27 | exit(-1); 28 | } 29 | } 30 | 31 | void finish(FT_Library &library, FT_Face &face) { 32 | FT_Done_Face(face); 33 | FT_Done_FreeType(library); 34 | } 35 | 36 | int render_single_character_on_array(unsigned char *array, 37 | unsigned int height, unsigned int width, FT_Bitmap *bitmap, 38 | unsigned int x, unsigned int y, unsigned int bearingX=0) { 39 | /** 40 | * NOTE: Currently, overwrites the bounding box of the character. 
41 | * To turn this behavior off, the assignment operator for the array 42 | * should be changed to an OR (|). 43 | * Returns 0 on success, -1 on exceeding width, -2 on exceeding height. 44 | */ 45 | unsigned int bwidth = bitmap->width, brows = bitmap->rows; 46 | if (bearingX > x) 47 | bearingX = 0; 48 | if (x + bwidth > width + bearingX) 49 | return -1; 50 | else if (y + brows > height) { 51 | // return -2; // This cuts off some characters in the last 52 | // line -- instead partially render the character 53 | // if possible 54 | if (y >= height) 55 | return 0; // Assume succesful - the caller will ensure to break 56 | brows = height - y; 57 | } 58 | 59 | for (unsigned int i = 0; i < bwidth; i++) { 60 | for (unsigned int j = 0; j < brows; j++) { 61 | array[(y+j)*width+(x+i-bearingX)] = 62 | bitmap->buffer[j*bwidth+i]; 63 | } 64 | } 65 | return 0; 66 | } 67 | 68 | std::string render_text(FT_Face &face, std::string text, unsigned char *array, 69 | unsigned int height, unsigned int width, unsigned int fontsize, 70 | int line_space, bool fixed_width, bool fix_spacing, bool no_partial) { 71 | int len = text.length(); 72 | int cur = 0; 73 | FT_UInt glyph_index; 74 | FT_GlyphSlot slot = face->glyph; 75 | 76 | unsigned int vertical_stride = fontsize; 77 | if (line_space == -1) 78 | vertical_stride += (fontsize+1)/2; 79 | else 80 | vertical_stride += ((unsigned int)line_space); 81 | unsigned int maxhoriY = 0; //, minhoriY = 100; 82 | unsigned int widths[len]; 83 | 84 | if (fixed_width) { 85 | std::unordered_map cache; 86 | for (int i = 0; i < len; i++) { 87 | if (cache.find(text[i]) != cache.end()) { 88 | widths[i] = cache[text[i]]; 89 | } else { 90 | glyph_index = FT_Get_Char_Index(face, text[i]); 91 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 92 | widths[i] = slot->advance.x >> 6; 93 | cache[text[i]] = widths[i]; 94 | unsigned int horiY = ((unsigned int) 95 | slot->metrics.horiBearingY) >> 6; 96 | if (horiY <= 50) 97 | maxhoriY = std::max(maxhoriY, horiY); 98 | // if (horiY > 0) 99 | // minhoriY = std::min(horiY, minhoriY); 100 | } 101 | } 102 | } else { 103 | for (int i = 0; i < len; i++) { 104 | glyph_index = FT_Get_Char_Index(face, text[i]); 105 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 106 | widths[i] = slot->advance.x >> 6; 107 | unsigned int horiY = ((unsigned int)slot->metrics.horiBearingY) >> 6; 108 | if (horiY <= 50) 109 | maxhoriY = std::max(maxhoriY, horiY); 110 | // if (horiY > 0) 111 | // minhoriY = std::min(horiY, minhoriY); 112 | } 113 | } 114 | 115 | memset(array, 0, height*width); 116 | unsigned int leftmargin = 3; 117 | unsigned int topmargin = 3; 118 | unsigned int x = leftmargin; 119 | unsigned int y = topmargin; 120 | 121 | unsigned int lasthoriX = 0; 122 | 123 | while (cur < len) { 124 | if (text[cur] == '\n') { 125 | // Just move to the beginning of the new line 126 | y += vertical_stride; 127 | x = leftmargin; 128 | cur++; 129 | continue; 130 | } 131 | 132 | // If the previous character was a space and the current one is 133 | // not, we need to decide now if we should move to the next line. 
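        // The check below sums the precomputed advance widths of the upcoming
        // word (everything up to the next space/tab); if the word would run
        // past `width`, we wrap to the left margin of a new line before
        // drawing it, so a word is only split across rows when it is wider
        // than a full row by itself.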
134 | if (cur && (text[cur-1] == ' ' || text[cur-1] == '\t') 135 | && (text[cur] != ' ' && text[cur] != '\t')) { 136 | unsigned int cur_width = 0; 137 | for (int i = cur; i < len && text[i] != ' ' && 138 | text[i] != '\t'; i++) { 139 | cur_width += widths[i]; 140 | if (cur_width > width) 141 | break; 142 | } 143 | if (x+cur_width > width) { 144 | y += vertical_stride; 145 | x = leftmargin; 146 | } 147 | } 148 | 149 | // Try rendering the current character 150 | glyph_index = FT_Get_Char_Index(face, text[cur]); 151 | if (FT_Load_Glyph(face, glyph_index, FT_LOAD_RENDER)) { 152 | // Can't render - move on 153 | cur++; 154 | continue; 155 | } 156 | 157 | unsigned int horiY = 158 | ((unsigned int)slot->metrics.horiBearingY) >> 6; 159 | unsigned int yoffset = maxhoriY - 160 | ((horiY > 50)? maxhoriY : horiY); 161 | 162 | unsigned int compY = (no_partial? y + maxhoriY : y); 163 | if (compY >= height) 164 | break; 165 | 166 | unsigned int bearingX; 167 | if (fix_spacing) { 168 | unsigned int horiX = 169 | ((unsigned int)slot->metrics.horiBearingX) >> 6; 170 | if (horiX > 100) 171 | horiX = 0; 172 | lasthoriX = std::max(lasthoriX, horiX); 173 | bearingX = lasthoriX - horiX; 174 | } else 175 | bearingX = 0; 176 | 177 | int status = render_single_character_on_array(array, height, 178 | width, &slot->bitmap, x, y + yoffset, bearingX); 179 | if (status == -1) { 180 | // Exceeding the specified width, try again on a new line 181 | y += vertical_stride; 182 | x = leftmargin; 183 | status = render_single_character_on_array(array, height, 184 | width, &slot->bitmap, x, y + yoffset, bearingX); 185 | } 186 | 187 | if (status == 0) { 188 | // We succeeded, so update (otherwise we skip the character) 189 | x += slot->advance.x >> 6; 190 | y += slot->advance.y >> 6; // zero 191 | if (x > width) { 192 | x = leftmargin; 193 | y += vertical_stride; 194 | } 195 | } 196 | cur++; 197 | } 198 | 199 | return text.substr(0, cur); 200 | } 201 | 202 | std::string render_text_unicode(FT_Face &face, std::string text, 203 | unsigned char *array, unsigned int height, unsigned int width, 204 | unsigned int fontsize, int line_space, bool fixed_width, 205 | bool fix_spacing, bool no_partial, bool no_margin, bool fixed_offset) { 206 | int len = text.length(); 207 | int cur = 0; 208 | FT_UInt glyph_index; 209 | FT_GlyphSlot slot = face->glyph; 210 | 211 | unsigned int vertical_stride = fontsize; 212 | if (line_space == -1) 213 | vertical_stride += (fontsize+1)/2; 214 | else 215 | vertical_stride += ((unsigned int)line_space); 216 | unsigned int maxhoriY = 0; //, minhoriY = 100; 217 | unsigned int widths[len]; 218 | 219 | unsigned char *src = (unsigned char *)text.c_str(); 220 | unsigned int decoded[len]; 221 | int cur_e = 0, cur_d = 0; 222 | while (cur_e < len) { 223 | if (src[cur_e] >> 7 == 0) 224 | decoded[cur_d++] = src[cur_e++]; 225 | else if (((src[cur_e] >> 5)&1) == 0) { 226 | decoded[cur_d++] = 227 | ((src[cur_e] & 31) << 6) | (src[cur_e+1] & 63); 228 | cur_e += 2; 229 | } else if (((src[cur_e] >> 4)&1) == 0) { 230 | decoded[cur_d++] = 231 | ((src[cur_e] & 15) << 12) | 232 | ((src[cur_e+1] & 63) << 6) | 233 | (src[cur_e+2] & 63); 234 | cur_e += 3; 235 | } else { 236 | decoded[cur_d++] = 237 | ((src[cur_e] & 7) << 18) | 238 | ((src[cur_e+1] & 63) << 12) | 239 | ((src[cur_e+2] & 63) << 6) | 240 | (src[cur_e+3] & 63); 241 | cur_e += 4; 242 | } 243 | } 244 | 245 | len = cur_d; 246 | FT_Select_Charmap(face, FT_ENCODING_UNICODE); 247 | 248 | if (fixed_width) { 249 | std::unordered_map cache; 250 | for (int i = 0; i < len; i++) { 
251 | if (cache.find(decoded[i]) != cache.end()) { 252 | widths[i] = cache[decoded[i]]; 253 | } else { 254 | glyph_index = FT_Get_Char_Index(face, decoded[i]); 255 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 256 | widths[i] = slot->advance.x >> 6; 257 | cache[decoded[i]] = widths[i]; 258 | unsigned int horiY = ((unsigned int) 259 | slot->metrics.horiBearingY) >> 6; 260 | if (horiY <= 50) 261 | maxhoriY = std::max(maxhoriY, horiY); 262 | // if (horiY > 0) 263 | // minhoriY = std::min(minhoriY, horiY); 264 | } 265 | } 266 | } else { 267 | for (int i = 0; i < len; i++) { 268 | glyph_index = FT_Get_Char_Index(face, decoded[i]); 269 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 270 | widths[i] = slot->advance.x >> 6; 271 | unsigned int horiY = ((unsigned int)slot->metrics.horiBearingY) >> 6; 272 | if (horiY <= 50) 273 | maxhoriY = std::max(maxhoriY, horiY); 274 | // if (horiY > 0) 275 | // minhoriY = std::min(minhoriY, horiY); 276 | } 277 | } 278 | 279 | memset(array, 0, height*width); 280 | unsigned int leftmargin = 3; 281 | unsigned int topmargin = 3; 282 | if (no_margin) { 283 | leftmargin = 0; 284 | topmargin = 0; 285 | } 286 | unsigned int x = leftmargin; 287 | unsigned int y = topmargin; 288 | 289 | unsigned int lasthoriX = 0; 290 | cur_e = 0; 291 | 292 | if (fixed_offset) { 293 | // We fixed the maxhoriY to make sure there is no shift 294 | // We test the best maxhoriY by using the following sequence: 295 | // 你好abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,!?./\\-+=@#$%^&*()[]{}<>~`'\";:|_ 296 | // we found that between font size 6-12 (PIXEL uses equivalent to 10) 297 | // maxhoriY = fontsize + 2 298 | if ((fontsize <= 12) and (fontsize >= 6)) { 299 | maxhoriY = fontsize + 2; 300 | } else { 301 | if (fontsize < 6) { 302 | maxhoriY = fontsize + 1; 303 | } 304 | if (fontsize > 12) { 305 | maxhoriY = fontsize + 3; 306 | } 307 | } 308 | } 309 | 310 | while (cur < len) { 311 | if (decoded[cur] == 0xA) { 312 | // Just move to the beginning of the new line 313 | y += vertical_stride; 314 | x = leftmargin; 315 | cur++; 316 | cur_e++; 317 | while ((src[cur_e] & 0xc0) == 0x80) 318 | cur_e++; 319 | continue; 320 | } 321 | 322 | // If the previous character was a space and the current one is 323 | // not, we need to decide now if we should move to the next line. 324 | if (cur && (decoded[cur-1] == 0x20 || decoded[cur-1] == 0x9) 325 | && (decoded[cur] != 0x9 && decoded[cur] != 0x9)) { 326 | unsigned int cur_width = 0; 327 | for (int i = cur; i < len && decoded[i] != 0x20 && 328 | decoded[i] != 0x9; i++) { 329 | cur_width += widths[i]; 330 | if (cur_width > width) 331 | break; 332 | } 333 | if (x+cur_width > width) { 334 | y += vertical_stride; 335 | x = leftmargin; 336 | } 337 | } 338 | 339 | // Try rendering the current character 340 | glyph_index = FT_Get_Char_Index(face, decoded[cur]); 341 | if (FT_Load_Glyph(face, glyph_index, FT_LOAD_RENDER)) { 342 | // Can't render - move on 343 | cur++; 344 | cur_e++; 345 | while ((src[cur_e] & 0xc0) == 0x80) 346 | cur_e++; 347 | continue; 348 | } 349 | 350 | unsigned int horiY = 351 | ((unsigned int)slot->metrics.horiBearingY) >> 6; 352 | unsigned int yoffset = maxhoriY - 353 | ((horiY > maxhoriY)? 
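                // First pass over the decoded code points: cache each glyph's
                // advance width so repeated characters are measured only once,
                // and track the largest horizontal bearing (maxhoriY, ignoring
                // implausibly large values) -- it is used later to align the
                // glyphs to a common baseline.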
maxhoriY : horiY); 354 | 355 | unsigned int bearingX; 356 | if (fix_spacing) { 357 | unsigned int horiX = 358 | ((unsigned int)slot->metrics.horiBearingX) >> 6; 359 | if (horiX > 100) 360 | horiX = 0; 361 | lasthoriX = std::max(lasthoriX, horiX); 362 | bearingX = lasthoriX - horiX; 363 | } else 364 | bearingX = 0; 365 | 366 | unsigned int compY = (no_partial? y + maxhoriY : y); 367 | if (compY > height) 368 | break; 369 | 370 | int status = render_single_character_on_array(array, height, 371 | width, &slot->bitmap, x, y + yoffset, bearingX); 372 | 373 | if (status == -1) { 374 | // Exceeding the specified width, try again on a new line 375 | y += vertical_stride; 376 | x = leftmargin; 377 | status = render_single_character_on_array(array, height, 378 | width, &slot->bitmap, x, y + yoffset, bearingX); 379 | } 380 | 381 | if (status == 0) { 382 | // We succeeded, so update (otherwise we skip the character) 383 | x += slot->advance.x >> 6; 384 | y += slot->advance.y >> 6; // zero 385 | if (x > width) { 386 | x = leftmargin; 387 | y += vertical_stride; 388 | } 389 | } 390 | 391 | cur++; 392 | cur_e++; 393 | while ((src[cur_e] & 0xc0) == 0x80) 394 | cur_e++; 395 | } 396 | 397 | return text.substr(0, cur_e); 398 | } 399 | 400 | void only_render_text_unicode(FT_Face &face, std::string text, 401 | unsigned char *array, unsigned int height, unsigned int width, 402 | unsigned int fontsize, int line_space, bool fixed_width, 403 | bool fix_spacing, bool no_partial) { 404 | int len = text.length(); 405 | int cur = 0; 406 | FT_UInt glyph_index; 407 | FT_GlyphSlot slot = face->glyph; 408 | 409 | unsigned int vertical_stride = fontsize; 410 | if (line_space == -1) 411 | vertical_stride += (fontsize+1)/2; 412 | else 413 | vertical_stride += ((unsigned int)line_space); 414 | unsigned int maxhoriY = 0; //, minhoriY = 100; 415 | unsigned int widths[len]; 416 | 417 | unsigned char *src = (unsigned char *)text.c_str(); 418 | unsigned int decoded[len]; 419 | int cur_e = 0, cur_d = 0; 420 | while (cur_e < len) { 421 | if (src[cur_e] >> 7 == 0) 422 | decoded[cur_d++] = src[cur_e++]; 423 | else if (((src[cur_e] >> 5)&1) == 0) { 424 | decoded[cur_d++] = 425 | ((src[cur_e] & 31) << 6) | (src[cur_e+1] & 63); 426 | cur_e += 2; 427 | } else if (((src[cur_e] >> 4)&1) == 0) { 428 | decoded[cur_d++] = 429 | ((src[cur_e] & 15) << 12) | 430 | ((src[cur_e+1] & 63) << 6) | 431 | (src[cur_e+2] & 63); 432 | cur_e += 3; 433 | } else { 434 | decoded[cur_d++] = 435 | ((src[cur_e] & 7) << 18) | 436 | ((src[cur_e+1] & 63) << 12) | 437 | ((src[cur_e+2] & 63) << 6) | 438 | (src[cur_e+3] & 63); 439 | cur_e += 4; 440 | } 441 | } 442 | 443 | len = cur_d; 444 | FT_Select_Charmap(face, FT_ENCODING_UNICODE); 445 | 446 | if (fixed_width) { 447 | std::unordered_map cache; 448 | for (int i = 0; i < len; i++) { 449 | if (cache.find(decoded[i]) != cache.end()) { 450 | widths[i] = cache[decoded[i]]; 451 | } else { 452 | glyph_index = FT_Get_Char_Index(face, decoded[i]); 453 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 454 | widths[i] = slot->advance.x >> 6; 455 | cache[decoded[i]] = widths[i]; 456 | unsigned int horiY = ((unsigned int) 457 | slot->metrics.horiBearingY) >> 6; 458 | if (horiY <= 50) 459 | maxhoriY = std::max(maxhoriY, horiY); 460 | // if (horiY > 0) 461 | // minhoriY = std::min(minhoriY, horiY); 462 | } 463 | } 464 | } else { 465 | for (int i = 0; i < len; i++) { 466 | glyph_index = FT_Get_Char_Index(face, decoded[i]); 467 | FT_Load_Glyph(face, glyph_index, FT_LOAD_NO_BITMAP); 468 | widths[i] = slot->advance.x >> 6; 469 | 
unsigned int horiY = ((unsigned int)slot->metrics.horiBearingY) >> 6; 470 | if (horiY <= 50) 471 | maxhoriY = std::max(maxhoriY, horiY); 472 | // if (horiY > 0) 473 | // minhoriY = std::min(minhoriY, horiY); 474 | } 475 | } 476 | 477 | memset(array, 0, height*width); 478 | unsigned int leftmargin = 3; 479 | unsigned int topmargin = 3; 480 | unsigned int x = leftmargin; 481 | unsigned int y = topmargin; 482 | 483 | unsigned int lasthoriX = 0; 484 | 485 | while (cur < len) { 486 | if (decoded[cur] == 0xA) { 487 | // Just move to the beginning of the new line 488 | y += vertical_stride; 489 | x = leftmargin; 490 | cur++; 491 | continue; 492 | } 493 | 494 | // If the previous character was a space and the current one is 495 | // not, we need to decide now if we should move to the next line. 496 | if (cur && (decoded[cur-1] == 0x20 || decoded[cur-1] == 0x9) 497 | && (decoded[cur] != 0x9 && decoded[cur] != 0x9)) { 498 | unsigned int cur_width = 0; 499 | for (int i = cur; i < len && decoded[i] != 0x20 && 500 | decoded[i] != 0x9; i++) { 501 | cur_width += widths[i]; 502 | if (cur_width > width) 503 | break; 504 | } 505 | if (x+cur_width > width) { 506 | y += vertical_stride; 507 | x = leftmargin; 508 | } 509 | } 510 | 511 | // Try rendering the current character 512 | glyph_index = FT_Get_Char_Index(face, decoded[cur]); 513 | if (FT_Load_Glyph(face, glyph_index, FT_LOAD_RENDER)) { 514 | // Can't render - move on 515 | cur++; 516 | continue; 517 | } 518 | 519 | unsigned int horiY = 520 | ((unsigned int)slot->metrics.horiBearingY) >> 6; 521 | unsigned int yoffset = maxhoriY - 522 | ((horiY > 50)? maxhoriY : horiY); 523 | 524 | unsigned int bearingX; 525 | if (fix_spacing) { 526 | unsigned int horiX = 527 | ((unsigned int)slot->metrics.horiBearingX) >> 6; 528 | if (horiX > 100) 529 | horiX = 0; 530 | lasthoriX = std::max(lasthoriX, horiX); 531 | bearingX = lasthoriX - horiX; 532 | } else 533 | bearingX = 0; 534 | 535 | unsigned int compY = (no_partial? y + maxhoriY : y); 536 | if (compY > height) 537 | break; 538 | 539 | int status = render_single_character_on_array(array, height, 540 | width, &slot->bitmap, x, y + yoffset, bearingX); 541 | 542 | if (status == -1) { 543 | // Exceeding the specified width, try again on a new line 544 | y += vertical_stride; 545 | x = leftmargin; 546 | status = render_single_character_on_array(array, height, 547 | width, &slot->bitmap, x, y + yoffset, bearingX); 548 | } 549 | 550 | if (status == 0) { 551 | // We succeeded, so update (otherwise we skip the character) 552 | x += slot->advance.x >> 6; 553 | y += slot->advance.y >> 6; // zero 554 | if (x > width) { 555 | x = leftmargin; 556 | y += vertical_stride; 557 | } 558 | } 559 | 560 | cur++; 561 | } 562 | } 563 | 564 | 565 | void show_image_in_terminal(unsigned char *image, unsigned int height, 566 | unsigned int width) { 567 | unsigned int i, j; 568 | for (i = 0; i < height; i++) { 569 | for (j = 0; j < width; j++) 570 | putchar(image[i*width+j] == 0 ? ' ' : 571 | (image[i*width+j] < 128 ? 
'+' : '*')); 572 | putchar('\n'); 573 | } 574 | } 575 | 576 | std::string render_text_onto_array(std::string text, unsigned char *array, 577 | int height, int width, int fontsize, int line_space, bool fixed_width, 578 | bool fix_spacing, bool no_partial, std::string fontpath) { 579 | FT_Library library; 580 | FT_Face face; 581 | 582 | init_ft(library); 583 | init_face(library, face, fontpath, fontsize); 584 | 585 | std::string rendered = render_text(face, text, (unsigned char *)array, 586 | (unsigned int)height, (unsigned int)width, (unsigned int)fontsize, 587 | line_space, fixed_width, fix_spacing, no_partial); 588 | 589 | finish(library, face); 590 | return rendered; 591 | } 592 | 593 | std::string render_text_onto_array_unicode(std::string text, unsigned char *array, 594 | int height, int width, int fontsize, int line_space, bool fixed_width, 595 | bool fix_spacing, bool no_partial, bool no_margin, bool fixed_offset, std::string fontpath) { 596 | FT_Library library; 597 | FT_Face face; 598 | 599 | init_ft(library); 600 | init_face(library, face, fontpath, fontsize); 601 | 602 | std::string rendered = render_text_unicode(face, text, 603 | (unsigned char *)array, (unsigned int)height, (unsigned int)width, 604 | (unsigned int)fontsize, line_space, fixed_width, fix_spacing, 605 | no_partial, no_margin, fixed_offset); 606 | 607 | finish(library, face); 608 | return rendered; 609 | } 610 | 611 | void only_render_text_onto_array_unicode(std::string text, 612 | unsigned char *array, int height, int width, int fontsize, 613 | int line_space, bool fixed_width, bool fix_spacing, bool no_partial, 614 | std::string fontpath) { 615 | FT_Library library; 616 | FT_Face face; 617 | 618 | init_ft(library); 619 | init_face(library, face, fontpath, fontsize); 620 | 621 | only_render_text_unicode(face, text, (unsigned char *)array, 622 | (unsigned int)height, (unsigned int)width, (unsigned int)fontsize, 623 | line_space, fixed_width, fix_spacing, no_partial); 624 | 625 | finish(library, face); 626 | } -------------------------------------------------------------------------------- /rendering/src/render.h: -------------------------------------------------------------------------------- 1 | #ifndef RENDER_H_ 2 | #define RENDER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include FT_FREETYPE_H 12 | 13 | void init_ft(FT_Library &library); 14 | void init_face(FT_Library library, FT_Face &face, std::string 15 | fontpath="GoNotoCurrent.ttf", unsigned int fontsize=16); 16 | void show_image_in_terminal(unsigned char *image, unsigned int height=256, 17 | unsigned int width=512); 18 | void finish(FT_Library &library, FT_Face &face); 19 | 20 | std::string render_text(FT_Face &face, std::string text, unsigned char *array, 21 | unsigned int height=256, unsigned int width=512, 22 | unsigned int fontsize=16, int line_space=-1, bool fixed_width=true, 23 | bool fix_spacing=true, bool no_partial=false); 24 | std::string render_text_unicode(FT_Face &face, std::string text, 25 | unsigned char *array, unsigned int height=256, unsigned int width=512, 26 | unsigned int fontsize=16, int line_space=-1, bool fixed_width=true, 27 | bool fix_spacing=true, bool no_partial=false, bool no_margin=false, bool fixed_offset=false); 28 | std::string render_text_onto_array(std::string text, unsigned char *array, 29 | int height=256, int width=512, int fontsize=16, int line_space=-1, 30 | bool fixed_width=true, bool fix_spacing=true, bool no_partial=false, 31 | std::string 
fontpath="GoNotoCurrent.ttf"); 32 | std::string render_text_onto_array_unicode(std::string text, unsigned char *array, 33 | int height=256, int width=512, int fontsize=16, int line_space=-1, 34 | bool fixed_width=true, bool fix_spacing=true, bool no_partial=false, bool no_margin=false, bool fixed_offset=false, 35 | std::string fontpath="GoNotoCurrent.ttf"); 36 | 37 | // These functions do not return the rendered part as a string 38 | void only_render_text_unicode(FT_Face &face, std::string text, unsigned char *array, 39 | unsigned int height=256, unsigned int width=512, 40 | unsigned int fontsize=16, int line_space=-1, bool fixed_width=true, 41 | bool fix_spacing=true, bool no_partial=false); 42 | void only_render_text_onto_array_unicode(std::string text, 43 | unsigned char *array, int height=256, int width=512, int fontsize=16, 44 | int line_space=-1, bool fixed_width=true, bool fix_spacing=true, 45 | bool no_partial=false, std::string fontpath="GoNotoCurrent.ttf"); 46 | 47 | #endif -------------------------------------------------------------------------------- /rendering/src/renderer.cpp: -------------------------------------------------------------------------------- 1 | #include "render.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // Interface 11 | 12 | namespace py = pybind11; 13 | 14 | py::object render(py::array_t &array, std::string text, 15 | int height, int width, int fontsize, int line_space, bool fixed_width, 16 | bool fix_spacing, bool no_partial) { 17 | auto buf = array.request(); 18 | 19 | unsigned char *ptr = (unsigned char *) buf.ptr; 20 | 21 | std::string rendered = render_text_onto_array(text, ptr, height, 22 | width, fontsize, line_space, fixed_width, fix_spacing, no_partial); 23 | 24 | return make_tuple(array, rendered); 25 | } 26 | 27 | py::object render_unicode(py::array_t &array, 28 | std::string text, int height, int width, int fontsize, 29 | int line_space, bool fixed_width, bool fix_spacing, bool no_partial, bool no_margin, bool fixed_offset) { 30 | auto buf = array.request(); 31 | 32 | unsigned char *ptr = (unsigned char *) buf.ptr; 33 | 34 | std::string rendered = render_text_onto_array_unicode(text, ptr, height, 35 | width, fontsize, line_space, fixed_width, fix_spacing, no_partial, no_margin, fixed_offset); 36 | 37 | return make_tuple(array, rendered); 38 | } 39 | 40 | py::object only_render_unicode(py::array_t &array, 41 | std::string text, int height, int width, int fontsize, 42 | int line_space, bool fixed_width, bool fix_spacing, bool no_partial) { 43 | auto buf = array.request(); 44 | 45 | unsigned char *ptr = (unsigned char *) buf.ptr; 46 | 47 | only_render_text_onto_array_unicode(text, ptr, height, 48 | width, fontsize, line_space, fixed_width, fix_spacing, no_partial); 49 | 50 | return array; 51 | } 52 | 53 | PYBIND11_MODULE(renderer, m) { 54 | m.doc() = "pybind11 plugin for rendering text"; 55 | 56 | m.def("render", &render, "A function that renders text"); 57 | m.def("render_unicode", &render_unicode, 58 | "A function that renders unicode text"); 59 | m.def("only_render_unicode", &only_render_unicode, 60 | "A function that renders unicode text"); 61 | } -------------------------------------------------------------------------------- /rendering/src/renderer.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-nlp/PTP/e03435f1ec1d9902bd6663798dc84f7382ca7286/rendering/src/renderer.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /rendering/src/renderer.cpython-39-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/PTP/e03435f1ec1d9902bd6663798dc84f7382ca7286/rendering/src/renderer.cpython-39-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.34.1 2 | accelerate==0.24.0 3 | mosaicml-streaming==0.6.1 4 | datasets 5 | evaluate 6 | wandb 7 | opencv-python 8 | scikit-learn 9 | scipy 10 | pillow 11 | pybind11 12 | flash-attn==2.5.5 13 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=text-generation 21 | """ 22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
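# A rough sketch of how this script is usually launched with the example
# configs in run_configs/ (paths are illustrative -- point train_file /
# validation_file at your own data first):
#
#   python run.py run_configs/ptp.yaml
#   python run.py run_configs/screenshot-llama-380m.yaml
#
# A single .json or .yaml argument is parsed by HfArgumentParser into
# (ModelArguments, DataTrainingArguments, OurTrainingArguments); any other
# command line falls back to the usual --flag style arguments.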
23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | import torch.distributed as dist 29 | from dataclasses import dataclass, field 30 | from itertools import chain 31 | from typing import Optional 32 | 33 | import datasets 34 | import torch 35 | from datasets import load_dataset 36 | 37 | import transformers 38 | from transformers import ( 39 | CONFIG_MAPPING, 40 | MODEL_FOR_CAUSAL_LM_MAPPING, 41 | AutoConfig, 42 | AutoModelForCausalLM, 43 | AutoTokenizer, 44 | HfArgumentParser, 45 | Trainer, 46 | TrainingArguments, 47 | default_data_collator, 48 | is_torch_tpu_available, 49 | set_seed, 50 | ) 51 | from transformers.testing_utils import CaptureLogger 52 | from transformers.trainer_utils import get_last_checkpoint 53 | from transformers.utils import check_min_version, send_example_telemetry 54 | from transformers.utils.versions import require_version 55 | from data import NumpyDataset, RenderTextCollator 56 | from streaming_data import get_multiple_domain_dataset, MDSDataset 57 | from modeling.modeling_ptp import inject_flash_attention_ptp 58 | from modeling.modeling_pixel import inject_flash_attention_pixel 59 | from modeling.modeling_screenshot_llama import inject_flash_attention_screenshotllama 60 | from modeling.configuration_ptp import PTPConfig 61 | from modeling.configuration_pixel import PIXELConfig 62 | from modeling.configuration_screenshot_llama import LlamaScreenshotConfig 63 | from modeling.modeling_screenshot_llama import LlamaForScreenshot 64 | 65 | from trainer import trainer_addon 66 | 67 | 68 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 69 | logger = logging.getLogger(__name__) 70 | 71 | 72 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 73 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 74 | 75 | 76 | @dataclass 77 | class ModelArguments: 78 | """ 79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 80 | """ 81 | 82 | model_name_or_path: Optional[str] = field( 83 | default=None, 84 | metadata={ 85 | "help": ( 86 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 87 | ) 88 | }, 89 | ) 90 | model_type: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 93 | ) 94 | config_overrides: Optional[str] = field( 95 | default=None, 96 | metadata={ 97 | "help": ( 98 | "Override some existing default config settings when a model is trained from scratch. 
Example: " 99 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 100 | ) 101 | }, 102 | ) 103 | config_name: Optional[str] = field( 104 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 105 | ) 106 | tokenizer_name: Optional[str] = field( 107 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 108 | ) 109 | cache_dir: Optional[str] = field( 110 | default=None, 111 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 112 | ) 113 | use_fast_tokenizer: bool = field( 114 | default=True, 115 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 116 | ) 117 | model_revision: str = field( 118 | default="main", 119 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 120 | ) 121 | use_auth_token: bool = field( 122 | default=False, 123 | metadata={ 124 | "help": ( 125 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 126 | "with private models)." 127 | ) 128 | }, 129 | ) 130 | torch_dtype: Optional[str] = field( 131 | default=None, 132 | metadata={ 133 | "help": ( 134 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 135 | "dtype will be automatically derived from the model's weights." 136 | ), 137 | "choices": ["auto", "bfloat16", "float16", "float32"], 138 | }, 139 | ) 140 | load_bf16: bool = field( 141 | default=False, 142 | metadata={ 143 | "help": ( 144 | "Load the model in bf16 (for llama-based models)" 145 | ) 146 | }, 147 | ) 148 | low_cpu_mem_usage: bool = field( 149 | default=False, 150 | metadata={ 151 | "help": ( 152 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 153 | "set True will benefit LLM loading time and RAM consumption." 
154 | ) 155 | }, 156 | ) 157 | flash_attn: bool = field( 158 | default=False, 159 | metadata={ 160 | "help": "Add flash attention (1.0)" 161 | } 162 | ) 163 | hf_flash_attn2: bool = field( 164 | default=False, 165 | metadata={ 166 | "help": "Add HF's flash attention 2.0 (only support BF16)" 167 | } 168 | ) 169 | mask_ratio: float = field( 170 | default=0, 171 | metadata={ 172 | "help": "Mask ratio for patch masking" 173 | } 174 | ) 175 | span_masking: bool = field( 176 | default=False, 177 | metadata={ 178 | "help": "Apply PIXEL style span masking" 179 | } 180 | ) 181 | max_span_length: int = field( 182 | default=6, 183 | metadata={"help": "For span masking: max span length"} 184 | ) 185 | mae_weight: float = field( 186 | default=1.0, 187 | metadata={ 188 | "help": "The weight for the patch prediction (MAE) loss" 189 | } 190 | ) 191 | text_weight: float = field( 192 | default=1.0, 193 | metadata={ 194 | "help": "The weight for text prediction loss" 195 | } 196 | ) 197 | add_mae_decoder: bool = field( 198 | default=False, 199 | metadata={ 200 | "help": "Add MAE decoder" 201 | } 202 | ) 203 | add_text_decoder: bool = field( 204 | default=True, 205 | metadata={ 206 | "help": "Add text decoder" 207 | } 208 | ) 209 | tie_word_embeddings: bool = field( 210 | default=True, 211 | metadata={ 212 | "help": "Tie word embeddings" 213 | } 214 | ) 215 | pixel: bool = field( 216 | default=False, 217 | metadata={"help": "The model is PIXEL"} 218 | ) 219 | screenshot_llama: bool = field( 220 | default=False, 221 | metadata={"help": "The model is screenshot-llama"} 222 | ) 223 | llama: bool = field( 224 | default=False, 225 | metadata={"help": "The model is text-only llama"} 226 | ) 227 | norm_pix_loss: bool = field( 228 | default=True, 229 | metadata={"help": "Norm pix loss (standardize the target pixels before calculating the loss)"} 230 | ) 231 | ignore_mismatched_sizes: bool = field( 232 | default=False, 233 | metadata={"help": "Ignore mismatched sizes"} 234 | ) 235 | ar_text_weight: float = field( 236 | default=1.0, 237 | metadata={ 238 | "help": "Weight for the text loss in autoregressive models" 239 | } 240 | ) 241 | ar_pixel_weight: float = field( 242 | default=1.0, 243 | metadata={ 244 | "help": "Wight for the patch prediction loss in autoregressive models" 245 | } 246 | ) 247 | embedding_layernorm: bool = field( 248 | default=False, 249 | metadata={"help": "Add a layernorm layer after the embedding"} 250 | ) 251 | 252 | 253 | def __post_init__(self): 254 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 255 | raise ValueError( 256 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 257 | ) 258 | 259 | 260 | @dataclass 261 | class DataTrainingArguments: 262 | """ 263 | Arguments pertaining to what data we are going to input our model for training and eval. 
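
    Besides the usual dataset/file arguments, this group also carries the
    online-rendering knobs (font_size, line_space, replace_new_line,
    text_mask_rate, ...) that RenderTextCollator uses to draw each token block
    into a screenshot at batch time, plus the streaming_* options for
    mosaicml-streaming (MDS) shards.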
264 | """ 265 | 266 | dataset_name: Optional[str] = field( 267 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 268 | ) 269 | dataset_config_name: Optional[str] = field( 270 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 271 | ) 272 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 273 | validation_file: Optional[str] = field( 274 | default=None, 275 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 276 | ) 277 | max_train_samples: Optional[int] = field( 278 | default=None, 279 | metadata={ 280 | "help": ( 281 | "For debugging purposes or quicker training, truncate the number of training examples to this " 282 | "value if set." 283 | ) 284 | }, 285 | ) 286 | max_eval_samples: Optional[int] = field( 287 | default=None, 288 | metadata={ 289 | "help": ( 290 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 291 | "value if set." 292 | ) 293 | }, 294 | ) 295 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 296 | block_size: Optional[int] = field( 297 | default=1024, 298 | metadata={ 299 | "help": ( 300 | "Optional input sequence length after tokenization. " 301 | "The training dataset will be truncated in block of this size for training. " 302 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 303 | ) 304 | }, 305 | ) 306 | overwrite_cache: bool = field( 307 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 308 | ) 309 | validation_split_percentage: Optional[int] = field( 310 | default=5, 311 | metadata={ 312 | "help": "The percentage of the train set used as validation set in case there's no validation split" 313 | }, 314 | ) 315 | preprocessing_num_workers: Optional[int] = field( 316 | default=None, 317 | metadata={"help": "The number of processes to use for the preprocessing."}, 318 | ) 319 | keep_linebreaks: bool = field( 320 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 321 | ) 322 | font_size: int = field( 323 | default=10, metadata={"help": "Font size for online rendering"} 324 | ) 325 | line_space: int = field( 326 | default=6, metadata={"help": "(Extra) line space for online rendering. 
The final height of one line is font_size + line_space"} 327 | ) 328 | replace_new_line: bool = field( 329 | default=True, metadata={"help": "Replace new line with a special token (to save rendering space)"} 330 | ) 331 | new_line_token: str = field( 332 | default="//", 333 | ) 334 | rendered_as_target: bool = field( 335 | default=True, metadata={"help": "Only use the rendered text as the text target and pad the rest"} 336 | ) 337 | text_mask_rate: float = field( 338 | default=0, 339 | metadata={ 340 | "help": "Text mask rate" 341 | } 342 | ) 343 | merge_text_masks: bool = field( 344 | default=True, 345 | metadata={ 346 | "help": "Merge consecutive text masks" 347 | } 348 | ) 349 | ignore_white_patches: bool = field( 350 | default=True, 351 | metadata={ 352 | "help": "Ignore white patches: exclude them in attn masks; do not mask them" 353 | } 354 | ) 355 | remove_unicode: bool = field( 356 | default=False, 357 | metadata={ 358 | "help": "Remove all unicode characters" 359 | } 360 | ) 361 | add_black_patch: bool = field( 362 | default=False, 363 | metadata={ 364 | "help": "Add black patch to the image after the text ends" 365 | } 366 | ) 367 | add_prefix: bool = field( 368 | default=False, 369 | metadata={ 370 | "help": "Add a text prefix at the beginning of the image" 371 | } 372 | ) 373 | autoregressive: bool = field( 374 | default=False, 375 | metadata={ 376 | "help": "Autoregressive style training for screenshot-llama" 377 | } 378 | ) 379 | ar_image_block_size: int = field( 380 | default=256, 381 | metadata={ 382 | "help": "(For autoregressive) number of tokens rendered in the screenshot" 383 | } 384 | ) 385 | total_block_size: int = field( 386 | default=None, 387 | metadata={ 388 | "help": "(For autoregressive) total number of tokens" 389 | } 390 | ) 391 | context_mask: int = field( 392 | default=None, 393 | metadata={ 394 | "help": "(For autoregressive) do not calculate loss on the first x tokens" 395 | } 396 | ) 397 | image_mode: str = field( 398 | default="RGB", 399 | metadata={ 400 | "help": "Image mode" 401 | } 402 | ) 403 | streaming_dataset: bool = field( 404 | default=False, metadata={"help": "Whether to use streaming dataset (mosiac) or not."} 405 | ) 406 | streaming_train_root: str = field( 407 | default=None, metadata={"help": "The root directory of the streaming training dataset."} 408 | ) 409 | streaming_val_root: str = field( 410 | default=None, metadata={"help": "The root directory of the streaming validation dataset."} 411 | ) 412 | streaming_domains: str = field( 413 | default=None, metadata={"help": "The domains/proportions of the streaming dataset. 
Should be a JSON string."} 414 | ) 415 | streaming_remote: bool = field( 416 | default=False, metadata={"help": "Whether to use remote streaming dataset or not."} 417 | ) 418 | sample_mask_at_collator: bool = field( 419 | default=False, metadata={"help": "Sample masks in the data loading part instead of the model part."} 420 | ) 421 | 422 | def __post_init__(self): 423 | if self.streaming: 424 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 425 | 426 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and not self.streaming_dataset: 427 | raise ValueError("Need either a dataset name or a training/validation file.") 428 | 429 | @dataclass 430 | class OurTrainingArguments(TrainingArguments): 431 | 432 | log_eval_image_pred: bool = field( 433 | default=False, 434 | metadata={"help": "Log eval image prediction to wandb"} 435 | ) 436 | width: int = field( 437 | default=512, metadata={"help": "Width of the input image"} 438 | ) 439 | height: int = field( 440 | default=256, metadata={"help": "Height of the input image"} 441 | ) 442 | patch_width: int = field( 443 | default=16, metadata={"help": "Width of the patch"} 444 | ) 445 | patch_height: int = field( 446 | default=16, metadata={"help": "Height of the patch"} 447 | ) 448 | cosine_w_min: bool = field( 449 | default=False, metadata={"help": "Cosine scheduler with min lr (only activated when using cosine)"} 450 | ) 451 | min_learning_rate: float = field( 452 | default=0, metadata={"help": "Minimum learning rate"} 453 | ) 454 | log_grad_norm: bool = field( 455 | default=False, metadata={"help": "Log grad norm"} 456 | ) 457 | log_train_input: bool = field( 458 | default=False, metadata={"help": "Log train input"} 459 | ) 460 | 461 | 462 | def get_model_and_processor(model_args, training_args, config): 463 | model_kwargs = {} 464 | if model_args.hf_flash_attn2: 465 | assert not model_args.flash_attn # Cannot use this with flash attention 1 together 466 | assert model_args.screenshot_llama or model_args.llama # Only support llama-based models 467 | model_kwargs["use_flash_attention_2"] = True 468 | torch_dtype = ( 469 | model_args.torch_dtype 470 | if model_args.torch_dtype in ["auto", None] 471 | else getattr(torch, model_args.torch_dtype) 472 | ) 473 | model_kwargs["torch_dtype"] = torch_dtype 474 | 475 | if model_args.screenshot_llama: 476 | config.model_type = "screenshot-llama" 477 | from modeling.processing_screenshot_llama import ScreenshotLlamaProcessor 478 | 479 | processor_kwargs = {} 480 | if not (training_args.patch_height == 16 and training_args.patch_width == 16): 481 | processor_kwargs = {"patch_size": {"height": training_args.patch_height, "width": training_args.patch_width}} 482 | config.patch_embed_size = training_args.patch_height * training_args.patch_width * 3 483 | 484 | processor = ScreenshotLlamaProcessor(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, **processor_kwargs) 485 | config.img_begin_token_id = processor.img_begin_token_id 486 | config.img_end_token_id = processor.img_end_token_id 487 | config.patch_token_id = processor.patch_token_id 488 | config.newline_token_id = processor.newline_token_id 489 | config.norm_pix_loss = model_args.norm_pix_loss 490 | 491 | if model_args.model_name_or_path: 492 | model = LlamaForScreenshot.from_pretrained(model_args.model_name_or_path, config=config, **model_kwargs) 493 | if model_args.load_bf16: 494 | model = model.to(dtype=torch.bfloat16) 495 | else: 496 | 
if model_args.hf_flash_attn2: 497 | config._flash_attn_2_enabled = True 498 | model = LlamaForScreenshot._from_config(config, torch_dtype=torch_dtype) 499 | if model_args.load_bf16: 500 | model = model.to(dtype=torch.bfloat16) 501 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 502 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 503 | 504 | # Add the special tokens for image boundary 505 | embedding_size = model.get_input_embeddings().weight.shape[0] 506 | if len(processor.tokenizer) > embedding_size: 507 | model.resize_token_embeddings(len(processor.tokenizer)) 508 | 509 | model.ar_text_weight = model_args.ar_text_weight 510 | model.ar_pixel_weight = model_args.ar_pixel_weight 511 | 512 | return processor, model 513 | elif model_args.llama: 514 | from modeling.processing_screenshot_llama import ScreenshotLlamaProcessor 515 | from transformers import LlamaForCausalLM 516 | 517 | processor_kwargs = {} 518 | processor = ScreenshotLlamaProcessor(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, **processor_kwargs) 519 | 520 | if model_args.model_name_or_path: 521 | model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **model_kwargs) 522 | if model_args.load_bf16: 523 | model = model.to(dtype=torch.bfloat16) 524 | else: 525 | if model_args.hf_flash_attn2: 526 | config._flash_attn_2_enabled = True 527 | model = LlamaForCausalLM._from_config(config, torch_dtype=torch_dtype) 528 | if model_args.load_bf16: 529 | model = model.to(dtype=torch.bfloat16) 530 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 531 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 532 | 533 | embedding_size = model.get_input_embeddings().weight.shape[0] 534 | if len(processor.tokenizer) > embedding_size: 535 | model.resize_token_embeddings(len(processor.tokenizer)) 536 | 537 | return processor, model 538 | elif config.model_type in ['ptp']: 539 | # Our main model 540 | from modeling.processing_ptp import PTPProcessor 541 | from modeling.modeling_ptp import PTPForConditionalGeneration 542 | 543 | config.add_mae_decoder = model_args.add_mae_decoder 544 | if model_args.add_mae_decoder: 545 | config.mae_weight = model_args.mae_weight 546 | config.add_text_decoder = model_args.add_text_decoder 547 | if model_args.add_text_decoder: 548 | config.text_weight = model_args.text_weight 549 | config.tie_word_embeddings = model_args.tie_word_embeddings 550 | config.text_config.tie_word_embeddings = model_args.tie_word_embeddings 551 | 552 | config.vision_config.embedding_layernorm = model_args.embedding_layernorm 553 | config.vision_config.norm_pix_loss = model_args.norm_pix_loss 554 | config.vision_config.image_size = [ 555 | training_args.height, 556 | training_args.width, 557 | ] 558 | if not (training_args.patch_height == 16 and training_args.patch_width == 16): 559 | config.vision_config.patch_size = (training_args.patch_height, training_args.patch_width) 560 | config.vision_config.num_channels = 3 if config.image_mode == "RGB" else 1 561 | 562 | processor = PTPProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path) 563 | 564 | if not (training_args.patch_height == 16 and training_args.patch_width == 16): 565 | processor.image_processor.patch_size = {"height": training_args.patch_height, "width": training_args.patch_width} 566 | 567 | if 
model_args.model_name_or_path: 568 | model = PTPForConditionalGeneration.from_pretrained(model_args.model_name_or_path, config=config) 569 | if model_args.load_bf16: 570 | model = model.to(dtype=torch.bfloat16) 571 | else: 572 | model = PTPForConditionalGeneration(config) 573 | if model_args.load_bf16: 574 | model = model.to(dtype=torch.bfloat16) 575 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 576 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 577 | 578 | model.mask_ratio = model_args.mask_ratio 579 | model.mae_weight = model_args.mae_weight 580 | model.text_weight = model_args.text_weight 581 | 582 | model.encoder.embeddings.mask_ratio = model_args.mask_ratio 583 | model.encoder.embeddings.span_masking = model_args.span_masking 584 | 585 | return processor, model 586 | elif config.model_type in ['pixel']: 587 | from modeling.processing_ptp import PTPProcessor 588 | from modeling.modeling_pixel import PIXELForPreTraining 589 | 590 | config.embedding_layernorm = model_args.embedding_layernorm 591 | config.norm_pix_loss = model_args.norm_pix_loss 592 | config.image_size = [ 593 | training_args.height, 594 | training_args.width, 595 | ] 596 | if not (training_args.patch_height == 16 and training_args.patch_width == 16): 597 | config.patch_size = (training_args.patch_height, training_args.patch_width) 598 | config.num_channels = 3 if config.image_mode == "RGB" else 1 599 | 600 | processor = PTPProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path) 601 | 602 | if not (training_args.patch_height == 16 and training_args.patch_width == 16): 603 | processor.image_processor.patch_size = {"height": training_args.patch_height, "width": training_args.patch_width} 604 | 605 | if model_args.model_name_or_path: 606 | model = PIXELForPreTraining.from_pretrained(model_args.model_name_or_path, config=config, ignore_mismatched_sizes=model_args.ignore_mismatched_sizes) 607 | if model_args.load_bf16: 608 | model = model.to(dtype=torch.bfloat16) 609 | else: 610 | model = PIXELForPreTraining(config) 611 | if model_args.load_bf16: 612 | model = model.to(dtype=torch.bfloat16) 613 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 614 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 615 | 616 | 617 | model.vit.embeddings.mask_ratio = model_args.mask_ratio 618 | model.vit.embeddings.span_masking = model_args.span_masking 619 | 620 | return processor, model 621 | else: 622 | raise NotImplementedError 623 | 624 | def main(): 625 | # See all possible arguments in src/transformers/training_args.py 626 | # or by passing the --help flag to this script. 627 | # We now keep distinct sets of args, for a cleaner separation of concerns. 628 | 629 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments)) 630 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 631 | # If we pass only one argument to the script and it's the path to a json file, 632 | # let's parse it to get our arguments. 
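        # (The .yaml branch below works the same way; every top-level key in
        # the config file has to correspond to a field of one of the three
        # dataclasses above, e.g. the keys in run_configs/ptp.yaml.)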
633 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 634 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 635 | model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1])) 636 | else: 637 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 638 | 639 | # Setup logging 640 | logging.basicConfig( 641 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 642 | datefmt="%m/%d/%Y %H:%M:%S", 643 | handlers=[logging.StreamHandler(sys.stdout)], 644 | ) 645 | 646 | if training_args.should_log: 647 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 648 | transformers.utils.logging.set_verbosity_info() 649 | 650 | log_level = training_args.get_process_log_level() 651 | logger.setLevel(log_level) 652 | datasets.utils.logging.set_verbosity(log_level) 653 | transformers.utils.logging.set_verbosity(log_level) 654 | transformers.utils.logging.enable_default_handler() 655 | transformers.utils.logging.enable_explicit_format() 656 | 657 | # Log on each process the small summary: 658 | logger.warning( 659 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 660 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 661 | ) 662 | logger.info(f"Training/evaluation parameters {training_args}") 663 | 664 | # Detecting last checkpoint. 665 | last_checkpoint = None 666 | if os.path.isdir(training_args.output_dir) and (training_args.do_train or training_args.check_dataset) and not training_args.overwrite_output_dir: 667 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 668 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 669 | raise ValueError( 670 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 671 | "Use --overwrite_output_dir to overcome." 672 | ) 673 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 674 | logger.info( 675 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 676 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 677 | ) 678 | # Set seed before initializing model. 679 | set_seed(training_args.seed) 680 | 681 | if data_args.streaming_dataset: 682 | train_dataset = get_multiple_domain_dataset(root_dir=data_args.streaming_train_root, shuffle=True, remote=data_args.streaming_remote, block_size=data_args.block_size) 683 | eval_dataset = get_multiple_domain_dataset(root_dir=data_args.streaming_val_root, shuffle=False, remote=data_args.streaming_remote, block_size=data_args.block_size) 684 | else: 685 | dataset_class = NumpyDataset 686 | 687 | logger.info("Loading train dataset (numpy)...") 688 | train_dataset = dataset_class(data_args.train_file, block_size=data_args.block_size) if data_args.train_file is not None else None 689 | logger.info("Done") 690 | 691 | logger.info("Loading validation dataset (numpy)...") 692 | eval_dataset = dataset_class(data_args.validation_file, block_size=data_args.block_size) if data_args.validation_file is not None else None 693 | logger.info("Done") 694 | 695 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 696 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
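
    # Note that train_dataset / eval_dataset only yield fixed-length blocks of
    # token ids at this point (from the .npy files or the MDS streaming
    # shards); the screenshots themselves are rendered on the fly by
    # RenderTextCollator further below, which is why the example configs raise
    # dataloader_num_workers to hide the rendering cost.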
697 | 698 | # Load pretrained model and tokenizer 699 | # 700 | # Distributed training: 701 | # The .from_pretrained methods guarantee that only one local process can concurrently 702 | # download model & vocab. 703 | 704 | config_kwargs = { 705 | "cache_dir": model_args.cache_dir, 706 | "revision": model_args.model_revision, 707 | "use_auth_token": True if model_args.use_auth_token else None, 708 | } 709 | if model_args.screenshot_llama: 710 | config_cls = LlamaScreenshotConfig 711 | elif model_args.llama: 712 | config_cls = AutoConfig 713 | elif model_args.pixel: 714 | config_cls = PIXELConfig 715 | else: 716 | config_cls = PTPConfig 717 | if model_args.config_name: 718 | config = config_cls.from_pretrained(model_args.config_name, **config_kwargs) 719 | elif model_args.model_name_or_path: 720 | config = config_cls.from_pretrained(model_args.model_name_or_path, **config_kwargs) 721 | 722 | config.image_mode = data_args.image_mode 723 | training_args.image_mode = data_args.image_mode 724 | processor, model = get_model_and_processor(model_args, training_args, config) 725 | 726 | if model_args.flash_attn: 727 | if model_args.screenshot_llama or model_args.llama: 728 | inject_flash_attention_screenshotllama(model) 729 | elif model_args.pixel: 730 | inject_flash_attention_pixel(model) 731 | else: 732 | inject_flash_attention_ptp(model) 733 | inject_flash_attention_pixel(model.encoder) 734 | if hasattr(model, "mae_decoder"): 735 | inject_flash_attention_pixel(model.mae_decoder) 736 | 737 | if training_args.do_train: 738 | if data_args.max_train_samples is not None: 739 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 740 | train_dataset = train_dataset.select(range(max_train_samples)) 741 | 742 | if training_args.do_eval: 743 | if data_args.max_eval_samples is not None: 744 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 745 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 746 | 747 | def preprocess_logits_for_metrics(logits, labels): 748 | if isinstance(logits, tuple): 749 | # Depending on the model and config, logits may contain extra tensors, 750 | # like past_key_values, but logits always come first 751 | logits = logits[0] 752 | return logits.argmax(dim=-1) 753 | 754 | collator = RenderTextCollator( 755 | processor=processor, 756 | font_size=data_args.font_size, 757 | line_space=data_args.line_space, 758 | replace_new_line=data_args.replace_new_line, 759 | new_line_token=data_args.new_line_token, 760 | width=training_args.width, height=training_args.height, 761 | block_size=data_args.block_size, 762 | rendered_as_target=data_args.rendered_as_target, 763 | patch_width=training_args.patch_width, 764 | patch_height=training_args.patch_height, 765 | text_mask_rate=data_args.text_mask_rate, 766 | merge_text_masks=data_args.merge_text_masks, 767 | ignore_white_patches=data_args.ignore_white_patches, 768 | add_black_patch=data_args.add_black_patch, 769 | add_prefix=data_args.add_prefix, 770 | autoregressive=data_args.autoregressive, 771 | ar_image_block_size=data_args.ar_image_block_size, 772 | total_block_size=data_args.total_block_size, 773 | context_mask=data_args.context_mask, 774 | image_mode=data_args.image_mode, 775 | sample_mask_at_collator=data_args.sample_mask_at_collator, 776 | mask_ratio=model_args.mask_ratio, 777 | span_masking=model_args.span_masking, 778 | max_span_length=model_args.max_span_length, 779 | ) 780 | 781 | # Initialize our Trainer 782 | trainer = Trainer( 783 | model=model, 784 | args=training_args, 785 | 
train_dataset=train_dataset if training_args.do_train else None, 786 | eval_dataset=eval_dataset if training_args.do_eval else None, 787 | tokenizer=processor, 788 | # Data collator will default to DataCollatorWithPadding, so we change it. 789 | data_collator=collator, 790 | compute_metrics=None, 791 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 792 | if training_args.do_eval and not is_torch_tpu_available() 793 | else None, 794 | ) 795 | 796 | trainer = trainer_addon(trainer, streaming_dataset=data_args.streaming_dataset) 797 | 798 | # Training 799 | if training_args.do_train: 800 | checkpoint = None 801 | if training_args.resume_from_checkpoint is not None: 802 | checkpoint = training_args.resume_from_checkpoint 803 | elif last_checkpoint is not None: 804 | checkpoint = last_checkpoint 805 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 806 | trainer.save_model() # Saves the tokenizer too for easy upload 807 | 808 | metrics = train_result.metrics 809 | 810 | max_train_samples = ( 811 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 812 | ) 813 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 814 | 815 | trainer.log_metrics("train", metrics) 816 | trainer.save_metrics("train", metrics) 817 | trainer.save_state() 818 | 819 | # Evaluation 820 | if training_args.do_eval: 821 | logger.info("*** Evaluate ***") 822 | 823 | metrics = trainer.evaluate() 824 | 825 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 826 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 827 | try: 828 | perplexity = math.exp(metrics["eval_loss"]) 829 | except OverflowError: 830 | perplexity = float("inf") 831 | metrics["perplexity"] = perplexity 832 | 833 | trainer.log_metrics("eval", metrics) 834 | trainer.save_metrics("eval", metrics) 835 | 836 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 837 | if data_args.dataset_name is not None: 838 | kwargs["dataset_tags"] = data_args.dataset_name 839 | if data_args.dataset_config_name is not None: 840 | kwargs["dataset_args"] = data_args.dataset_config_name 841 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 842 | else: 843 | kwargs["dataset"] = data_args.dataset_name 844 | 845 | 846 | def _mp_fn(index): 847 | # For xla_spawn (TPUs) 848 | main() 849 | 850 | 851 | if __name__ == "__main__": 852 | main() 853 | -------------------------------------------------------------------------------- /run_configs/ptp.yaml: -------------------------------------------------------------------------------- 1 | # Path 2 | config_name: model_configs/ptp/config.json 3 | tokenizer_name: model_configs/ptp 4 | output_dir: result/my-ptp 5 | run_name: my-ptp 6 | train_file: data/wikibook_256_opt_tk_train.npy 7 | validation_file: data/wikibook_256_opt_tk_val.npy 8 | 9 | # Speedup 10 | flash_attn: true 11 | dataloader_num_workers: 8 # Multiple workers to speedup the rendering 12 | 13 | # Logging 14 | do_eval: true 15 | do_train: true 16 | eval_steps: 20000 17 | evaluation_strategy: steps 18 | save_steps: 100000 19 | save_strategy: steps 20 | save_total_limit: 15 21 | log_eval_image_pred: true 22 | logging_steps: 10 23 | 24 | # Rendering 25 | rendered_as_target: true 26 | replace_new_line: true 27 | add_black_patch: true 28 | add_prefix: true 29 | ignore_white_patches: true 30 | 31 | # Image size 32 | font_size: 10 33 | line_space: 6 34 | 
height: 16 35 | width: 8176 # +CLS to make it 512 sequence length: good for hardware speedup 36 | patch_height: 16 37 | patch_width: 16 38 | 39 | # Optimization 40 | num_train_epochs: 16 41 | per_device_eval_batch_size: 32 42 | per_device_train_batch_size: 32 43 | warmup_steps: 50000 44 | lr_scheduler_type: "cosine" 45 | learning_rate: 1.5e-4 46 | min_learning_rate: 1.0e-5 47 | cosine_w_min: true 48 | weight_decay: 0.05 49 | block_size: 256 50 | 51 | # Modeling 52 | add_mae_decoder: true 53 | add_text_decoder: true 54 | 55 | # Extra 56 | text_mask_rate: 0.25 57 | merge_text_masks: true 58 | mask_ratio: 0.10 59 | span_masking: true 60 | sample_mask_at_collator: true 61 | norm_pix_loss: true 62 | fp16: true 63 | ignore_mismatched_sizes: true 64 | -------------------------------------------------------------------------------- /run_configs/screenshot-llama-1.3b-from-sheared-llama.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Path 3 | config_name: princeton-nlp/Sheared-LLaMA-1.3B 4 | model_name_or_path: princeton-nlp/Sheared-LLaMA-1.3B 5 | output_dir: result/my-screenshot-llama-1.3b-from-sheared-llama 6 | run_name: my-screenshot-llama-1.3b-from-sheared-llama 7 | 8 | # Dataset 9 | streaming_dataset: true 10 | streaming_train_root: data/sheared-llama-rp/for_ft/ 11 | streaming_val_root: data/sheared-llama-rp/eval/ 12 | streaming_remote: false 13 | block_size: 512 14 | 15 | # Speedup 16 | flash_attn: false 17 | hf_flash_attn2: true 18 | dataloader_num_workers: 8 19 | bf16: true 20 | 21 | # Logging 22 | do_eval: true 23 | do_train: true 24 | eval_steps: 5000 25 | evaluation_strategy: steps 26 | save_steps: 10000 # 1B 27 | save_strategy: steps 28 | save_total_limit: 20 29 | log_eval_image_pred: true 30 | logging_steps: 1 31 | 32 | # Rendering 33 | rendered_as_target: true 34 | replace_new_line: true 35 | add_black_patch: false 36 | add_prefix: false 37 | 38 | # Image size 39 | font_size: 10 40 | line_space: 6 41 | height: 16 42 | width: 8192 43 | patch_height: 16 44 | patch_width: 16 45 | 46 | # Length 47 | block_size: 512 # Total text tokens (split into 256 + 256) 48 | ar_image_block_size: 256 # Text tokens rendered as screenshot 49 | total_block_size: 768 # Total length of patch tokens (~512) + text (~256) 50 | 51 | # Optimization 52 | lr_scheduler_type: "cosine" 53 | max_steps: 50000 # 1B 54 | per_device_eval_batch_size: 32 55 | per_device_train_batch_size: 8 56 | gradient_accumulation_steps: 4 # for 8gpus, this leads to bsz=256 57 | warmup_steps: 2000 # 4% 58 | learning_rate: 1.0e-4 59 | 60 | # Extra 61 | autoregressive: true 62 | screenshot_llama: true 63 | ignore_white_patches: true 64 | norm_pix_loss: true 65 | -------------------------------------------------------------------------------- /run_configs/screenshot-llama-380m.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Path 3 | config_name: model_configs/screenshot-llama-380m 4 | tokenizer_name: model_configs/screenshot-llama-380m 5 | output_dir: result/my-screenshot-llama-380m 6 | run_name: my-screenshot-llama-380m 7 | train_file: data/wikibook_512_llama_tk_train.npy 8 | validation_file: data/wikibook_512_llama_tk_val.npy 9 | 10 | # Speedup 11 | flash_attn: true 12 | dataloader_num_workers: 8 13 | 14 | # Logging 15 | do_eval: true 16 | do_train: true 17 | eval_steps: 10000 18 | evaluation_strategy: steps 19 | save_steps: 100000 20 | save_strategy: steps 21 | save_total_limit: 10 22 | log_eval_image_pred: true 23 | logging_steps: 10 24 | 25 
| # Rendering 26 | rendered_as_target: true 27 | replace_new_line: true 28 | add_black_patch: false 29 | add_prefix: false 30 | 31 | # Image size 32 | font_size: 10 33 | line_space: 6 34 | height: 16 35 | width: 8192 36 | patch_height: 16 37 | patch_width: 16 38 | 39 | # Length 40 | block_size: 512 # Total text tokens (split into 256 + 256) 41 | ar_image_block_size: 256 # Text tokens rendered as screenshot 42 | total_block_size: 768 # Total length of patch tokens (~512) + text (~256) 43 | 44 | # Optimization 45 | lr_scheduler_type: "cosine" 46 | num_train_epochs: 16 47 | per_device_eval_batch_size: 32 48 | per_device_train_batch_size: 32 49 | warmup_steps: 50000 50 | learning_rate: 1.5e-4 51 | weight_decay: 0.05 52 | 53 | # Extra 54 | fp16: true 55 | autoregressive: true 56 | screenshot_llama: true 57 | ignore_white_patches: true 58 | norm_pix_loss: true 59 | -------------------------------------------------------------------------------- /run_multiple_gpus.sh: -------------------------------------------------------------------------------- 1 | # Requires: NUM_GPU 2 | 3 | OMP_NUM_THREADS=8 WANDB_PROJECT=ptp torchrun --nnodes=1 --nproc_per_node=$NUM_GPU run.py $1 4 | -------------------------------------------------------------------------------- /run_single_gpu.sh: -------------------------------------------------------------------------------- 1 | WANDB_PROJECT=ptp python run.py $1 2 | -------------------------------------------------------------------------------- /streaming_data.py: -------------------------------------------------------------------------------- 1 | from transformers.utils import logging 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | from streaming import LocalDataset, StreamingDataset, Stream 6 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 7 | import os 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | class MDSDataset(StreamingDataset): 12 | 13 | def __init__(self, block_size=None, return_key="tokens", **kwargs): 14 | super().__init__(**kwargs) 15 | self.block_size = block_size 16 | if block_size is not None: 17 | logger.warning("block_size set in MDSDataset, which means the input might be truncated") 18 | self.return_key = return_key 19 | 20 | 21 | def __getitem__(self, idx): 22 | item = super().__getitem__(idx) 23 | tokens = np.frombuffer(item["tokens"], np.uint16).astype(np.int64) 24 | if self.block_size is not None: 25 | tokens = tokens[:self.block_size] 26 | return {self.return_key: tokens} 27 | 28 | 29 | redpajama_domains_and_proportions = { 30 | "arxiv": 0.025, 31 | "book": 0.045, 32 | "c4-rp": 0.15, 33 | "cc": 0.67, 34 | "github": 0.045, 35 | "stackexchange": 0.02, 36 | "wiki": 0.045 37 | } 38 | 39 | def get_multiple_domain_dataset( 40 | root_dir, 41 | shuffle, 42 | domains_and_proportions=redpajama_domains_and_proportions, 43 | remote=False, 44 | block_size=None, 45 | ): 46 | logger.warning("Loading multiple domain dataset via MosaicML streaming.") 47 | logger.warning("***** Streaming dataset *****") 48 | logger.warning(f"Root dir: {root_dir}") 49 | logger.warning(f"Shuffle: {shuffle}") 50 | logger.warning(f"Domains: {domains_and_proportions}") 51 | logger.warning(f"Remote: {remote}") 52 | logger.warning(f"Block size: {block_size}") 53 | 54 | if remote: 55 | streams = [ 56 | Stream(remote=root_dir+domain, proportion=domains_and_proportions[domain]) 57 | for domain in domains_and_proportions 58 | ] 59 | else: 60 | streams = [ 61 | Stream(local=os.path.join(root_dir, domain), 
proportion=domains_and_proportions[domain]) 62 | for domain in domains_and_proportions 63 | ] 64 | 65 | dataset = MDSDataset( 66 | block_size=block_size, 67 | streams=streams, 68 | shuffle=shuffle, 69 | ) 70 | 71 | return dataset -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import transformers 3 | from transformers import Trainer 4 | import inspect 5 | from typing import Dict, Union, Any 6 | import torch 7 | import json 8 | from torch import nn 9 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model 10 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES 11 | from transformers.trainer_callback import TrainerControl, TrainerState 12 | from transformers.training_args import TrainingArguments 13 | from image_utils import flattened_patches_to_image 14 | import wandb 15 | import numpy as np 16 | import torch.distributed as dist 17 | import random 18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 19 | from transformers.optimization import get_scheduler 20 | from torch.optim.lr_scheduler import LambdaLR 21 | from torch.optim import Optimizer 22 | import math 23 | import os 24 | import subprocess 25 | from packaging import version 26 | import accelerate 27 | from transformers.trainer_pt_utils import find_batch_size, nested_concat, nested_numpify 28 | 29 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 30 | from transformers.trainer_utils import seed_worker 31 | 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | def is_ge_version(v): 36 | return version.parse(transformers.__version__) >= version.parse(v) 37 | 38 | def _set_signature_columns_if_needed(self): 39 | if self._signature_columns is None: 40 | # Inspect model forward signature to keep only the arguments it accepts. 41 | signature = inspect.signature(self.model.forward) 42 | self._signature_columns = list(signature.parameters.keys()) 43 | # Labels may be named label or label_ids, the default data collator handles that. 44 | self._signature_columns += list(set(["label", "label_ids", "tokens", "image", "font_size", "text", "patch_mask"] + self.label_names)) 45 | 46 | 47 | def compute_loss(self, model, inputs, return_outputs=False): 48 | """ 49 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 50 | 51 | Subclass and override for custom behavior. 52 | """ 53 | 54 | if self.label_smoother is not None and "labels" in inputs: 55 | labels = inputs.pop("labels") 56 | else: 57 | labels = None 58 | outputs = model(**inputs) 59 | # Save past state if it exists 60 | # TODO: this needs to be fixed and made cleaner later. 61 | if self.args.past_index >= 0: 62 | self._past = outputs[self.args.past_index] 63 | 64 | if labels is not None: 65 | if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 66 | loss = self.label_smoother(outputs, labels, shift_labels=True) 67 | else: 68 | loss = self.label_smoother(outputs, labels) 69 | else: 70 | if isinstance(outputs, dict) and "loss" not in outputs: 71 | raise ValueError( 72 | "The model did not return a loss from the inputs, only the following keys: " 73 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 
74 | ) 75 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 76 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 77 | 78 | # Extra logs 79 | mae_loss_key = "pixel_loss" if hasattr(outputs, "pixel_loss") else "mae_loss" 80 | logits_key = "patch_logits" if hasattr(outputs, "patch_logits") else "logits" 81 | prefix = "eval_" if return_outputs else "" 82 | 83 | if not dist.is_initialized() or dist.get_rank() == 0: # This is an ugly way to log stuff and probably not thread-safe so only log it on rank == 0 84 | if not hasattr(self, "extra_logs"): 85 | self.extra_logs = {} 86 | if hasattr(outputs, mae_loss_key): 87 | self.extra_logs[prefix + mae_loss_key] = outputs[mae_loss_key].item() 88 | if self.args.log_eval_image_pred and return_outputs is True: 89 | images = [ 90 | flattened_patches_to_image( 91 | outputs[logits_key][i].detach().cpu().to(torch.float32), 92 | height=self.args.height, 93 | width=self.args.width, 94 | patch_height=self.args.patch_height, 95 | patch_width=self.args.patch_width, 96 | image_mode=getattr(self.args, 'image_mode', 'RGB') 97 | ) 98 | for i in range(len(outputs[logits_key])) 99 | ] # WARNING: I didn't set the size here 100 | self.extra_logs[prefix + "image_pred"] = [wandb.Image(image) for image in images] 101 | 102 | images = [ 103 | flattened_patches_to_image( 104 | inputs["flattened_patches"][i, :, 2:].detach().cpu().to(torch.float32), 105 | height=self.args.height, 106 | width=self.args.width, 107 | patch_height=self.args.patch_height, 108 | patch_width=self.args.patch_width, 109 | image_mode=getattr(self.args, 'image_mode', 'RGB') 110 | ) 111 | for i in range(len(inputs["flattened_patches"])) 112 | ] 113 | self.extra_logs[prefix + "image_input"] = [wandb.Image(image) for image in images] 114 | 115 | if hasattr(outputs, "mask"): 116 | images = [ 117 | flattened_patches_to_image( 118 | outputs[logits_key][i].detach().cpu().to(torch.float32), 119 | mask=outputs["mask"][i].detach().cpu().long(), 120 | original_patches=inputs["flattened_patches"][i, :, 2:].detach().cpu().to(torch.float32), 121 | height=self.args.height, width=self.args.width, 122 | patch_height=self.args.patch_height, 123 | patch_width=self.args.patch_width, 124 | image_mode=getattr(self.args, 'image_mode', 'RGB') 125 | ) 126 | for i in range(len(outputs[logits_key])) 127 | ] # WARNING: I didn't set the size here 128 | self.extra_logs[prefix + "image_pred_mask"] = [wandb.Image(image) for image in images] 129 | elif "flattened_patches" in inputs and self.args.log_eval_image_pred and return_outputs is True: 130 | images = [ 131 | flattened_patches_to_image( 132 | inputs["flattened_patches"][i, :, 2:].detach().cpu().to(torch.float32), 133 | height=self.args.height, 134 | width=self.args.width, 135 | patch_height=self.args.patch_height, 136 | patch_width=self.args.patch_width, 137 | image_mode=getattr(self.args, 'image_mode', 'RGB') 138 | ) 139 | for i in range(len(inputs["flattened_patches"])) 140 | ] 141 | self.extra_logs[prefix + "image_input"] = [wandb.Image(image) for image in images] 142 | 143 | if hasattr(outputs, "mask"): # ViT MAE 144 | images = [ 145 | flattened_patches_to_image( 146 | outputs[logits_key][i].detach().cpu().to(torch.float32), 147 | height=self.args.height, 148 | width=self.args.width, 149 | patch_height=self.args.patch_height, 150 | patch_width=self.args.patch_width, 151 | image_mode=getattr(self.args, 'image_mode', 'RGB') 152 | ) 153 | for i in range(len(outputs[logits_key])) 154 | ] # WARNING: I didn't set the size here 155 
| self.extra_logs[prefix + "image_pred"] = [wandb.Image(image) for image in images] 156 | 157 | images = [flattened_patches_to_image( 158 | outputs[logits_key][i].detach().cpu().to(torch.float32), 159 | mask=outputs["mask"][i].detach().cpu().long(), 160 | original_patches=inputs["flattened_patches"][i, :, 2:].detach().cpu().to(torch.float32), 161 | height=self.args.height, width=self.args.width, 162 | patch_height=self.args.patch_height, 163 | patch_width=self.args.patch_width, 164 | image_mode=getattr(self.args, 'image_mode', 'RGB') 165 | ) 166 | for i in range(len(outputs[logits_key]))] # WARNING: I didn't set the size here 167 | self.extra_logs[prefix + "image_pred_mask"] = [wandb.Image(image) for image in images] 168 | 169 | if hasattr(outputs, "text_loss"): 170 | self.extra_logs[prefix + "text_loss"] = outputs["text_loss"] if isinstance(outputs["text_loss"], float) else outputs["text_loss"].item() 171 | 172 | if hasattr(outputs, "dice_loss") and outputs.dice_loss is not None: 173 | self.extra_logs[prefix + "dice_loss"] = outputs["dice_loss"] if isinstance(outputs["dice_loss"], float) else outputs["dice_loss"].item() 174 | 175 | # Gather losses for logging 176 | if is_ge_version("4.34.1"): 177 | # accelerator gather only applies to >=4.34.1 178 | 179 | if not hasattr(self, "extra_logs"): 180 | self.extra_logs = {} 181 | 182 | batch_size = find_batch_size(inputs) 183 | 184 | if hasattr(outputs, "text_loss"): 185 | text_losses = self.accelerator.gather_for_metrics(outputs["text_loss"].mean().detach().repeat(batch_size)) 186 | self.extra_logs[prefix+"text_loss_aggr"] = text_losses if prefix+"text_loss_aggr" not in self.extra_logs else nested_concat(self.extra_logs[prefix+"text_loss_aggr"], text_losses) 187 | if hasattr(outputs, "dice_loss") and outputs.dice_loss is not None: 188 | dice_losses = self.accelerator.gather_for_metrics(outputs["dice_loss"].mean().detach().repeat(batch_size)) 189 | self.extra_logs[prefix+"dice_loss_aggr"] = dice_losses if prefix+"dice_loss_aggr" not in self.extra_logs else nested_concat(self.extra_logs[prefix+"dice_loss_aggr"], dice_losses) 190 | if hasattr(outputs, mae_loss_key): 191 | mae_losses = self.accelerator.gather_for_metrics(outputs[mae_loss_key].mean().detach().repeat(batch_size)) 192 | self.extra_logs[prefix+mae_loss_key+"_aggr"] = mae_losses if prefix+mae_loss_key+"_aggr" not in self.extra_logs else nested_concat(self.extra_logs[prefix+mae_loss_key+"_aggr"], mae_losses) 193 | 194 | return (loss, outputs) if return_outputs else loss 195 | 196 | # New 197 | def compute_loss_wrapper(self, model, inputs, return_outputs=False): 198 | saved_kwargs = {} 199 | if "true_labels" in inputs: 200 | saved_kwargs["true_labels"] = inputs.pop("true_labels") 201 | 202 | loss_and_outputs = compute_loss(self, model, inputs, return_outputs=return_outputs) 203 | if isinstance(loss_and_outputs, tuple) and len(loss_and_outputs) == 2: 204 | loss, outputs = loss_and_outputs 205 | if isinstance(outputs, tuple): 206 | for k in saved_kwargs: 207 | outputs = outputs + (saved_kwargs[k],) 208 | elif isinstance(outputs, dict): 209 | for k in saved_kwargs: 210 | outputs[k] = saved_kwargs[k] 211 | else: 212 | for k in saved_kwargs: 213 | setattr(outputs, k, saved_kwargs[k]) 214 | return (loss, outputs) 215 | else: 216 | return loss_and_outputs 217 | 218 | def log(self, logs: Dict[str, float]) -> None: 219 | """ 220 | Log `logs` on the various objects watching training. 221 | 222 | Subclass and override this method to inject custom behavior. 
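        In this subclass, values accumulated in `self.extra_logs` are merged into `logs`
        (the `*_aggr` tensors are averaged via `nested_numpify` first), and wandb Image entries
        are popped before the record is appended to `state.log_history`, since images cannot be
        serialized there.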
223 | 224 | Args: 225 | logs (`Dict[str, float]`): 226 | The values to log. 227 | """ 228 | if self.state.epoch is not None: 229 | logs["epoch"] = round(self.state.epoch, 2) 230 | 231 | logs["step"] = self.state.global_step 232 | 233 | if hasattr(self, "extra_logs"): 234 | for key in self.extra_logs: 235 | if "aggr" in key: 236 | h = nested_numpify(self.extra_logs[key]) 237 | logs.update({key: h.mean().item()}) 238 | else: 239 | logs.update({key: self.extra_logs[key]}) 240 | 241 | self.extra_logs = {} 242 | 243 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) 244 | 245 | # Pop up the image type because they can't be saved 246 | pop_keys = [] 247 | for key in logs: 248 | if "image_pred" in key or "image_input" in key: 249 | pop_keys.append(key) 250 | for key in pop_keys: 251 | logs.pop(key) 252 | 253 | output = {**logs} 254 | self.state.log_history.append(output) 255 | 256 | 257 | 258 | import signal 259 | from subprocess import call 260 | class SIGUSR1Callback(transformers.TrainerCallback): 261 | def __init__(self) -> None: 262 | super().__init__() 263 | self.signal_received = False 264 | signal.signal(signal.SIGUSR1, self.handle_signal) 265 | # signal.signal(signal.SIGINT, self.handle_signal) 266 | logger.warn("Handler registered") 267 | 268 | def handle_signal(self, signum, frame): 269 | self.signal_received = True 270 | logger.warn("Signal received") 271 | 272 | def on_step_end(self, args, state, control, **kwargs): 273 | if self.signal_received: 274 | control.should_save = True 275 | control.should_training_stop = True 276 | 277 | def on_train_end(self, args, state, control, **kwargs): 278 | if self.signal_received: 279 | exit(0) 280 | 281 | 282 | 283 | def _pad_tensors_to_max_len(self, tensor, max_length): 284 | if self.model.config.pad_token_id is not None: 285 | pad_token_id = self.model.config.pad_token_id 286 | else: 287 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 288 | 289 | padded_tensor = pad_token_id * torch.ones( 290 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 291 | ) 292 | padded_tensor[:, : tensor.shape[-1]] = tensor 293 | return padded_tensor 294 | 295 | 296 | def prediction_step_seq2seq( 297 | self, 298 | model: nn.Module, 299 | inputs: Dict[str, Union[torch.Tensor, Any]], 300 | prediction_loss_only: bool, 301 | ignore_keys: Optional[List[str]] = None, 302 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 303 | """ 304 | Copied from HF's seq2seq_trainer.py 305 | 306 | Perform an evaluation step on `model` using `inputs`. 307 | 308 | Subclass and override to inject custom behavior. 309 | 310 | Args: 311 | model (`nn.Module`): 312 | The model to evaluate. 313 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 314 | The inputs and targets of the model. 315 | 316 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 317 | argument `labels`. Check your model's documentation for all accepted arguments. 318 | prediction_loss_only (`bool`): 319 | Whether or not to return the loss only. 320 | gen_kwargs: 321 | Additional `generate` specific kwargs. 322 | 323 | Return: 324 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 325 | labels (each being optional). 
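        In addition to the copied HF behavior, this override pops `true_labels` from the inputs
        before generation and, when present, returns them in place of `labels`.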
326 | """ 327 | 328 | has_labels = "labels" in inputs 329 | inputs = self._prepare_inputs(inputs) 330 | 331 | # XXX: adapt synced_gpus for fairscale as well 332 | # Priority (handled in generate): 333 | # gen_kwargs > model.generation_config > default GenerationConfig() 334 | gen_kwargs = self._gen_kwargs 335 | 336 | # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate 337 | # (otherwise, it would continue generating from the padded `decoder_input_ids`) 338 | if ( 339 | "labels" in inputs 340 | and "decoder_input_ids" in inputs 341 | and inputs["labels"].shape == inputs["decoder_input_ids"].shape 342 | ): 343 | inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} 344 | 345 | # New 346 | true_labels = inputs.pop("true_labels", None) 347 | generated_tokens = self.model.generate(**inputs, **gen_kwargs) 348 | 349 | # in case the batch is shorter than max length, the output should be padded 350 | # if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 351 | # generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 352 | if gen_kwargs["max_new_tokens"] is not None and generated_tokens.shape[-1] < gen_kwargs["max_new_tokens"] + 1: 353 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 354 | 355 | with torch.no_grad(): 356 | if has_labels: 357 | with self.compute_loss_context_manager(): 358 | outputs = model(**inputs) 359 | if self.label_smoother is not None: 360 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 361 | else: 362 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 363 | else: 364 | loss = None 365 | 366 | if self.args.prediction_loss_only: 367 | return loss, None, None 368 | 369 | if has_labels: 370 | labels = inputs["labels"] 371 | # if labels.shape[-1] < gen_config.max_length: 372 | # labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) 373 | # if gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: 374 | # labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) 375 | if gen_kwargs["max_new_tokens"] is not None and labels.shape[-1] < gen_kwargs["max_new_tokens"] + 1: 376 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_new_tokens"] + 1) 377 | else: 378 | labels = None 379 | 380 | # New 381 | if true_labels is not None: 382 | labels = true_labels 383 | 384 | return loss, generated_tokens, labels 385 | 386 | 387 | def get_cosine_schedule_to_min_lr_with_warmup( 388 | optimizer: Optimizer, 389 | num_warmup_steps: int, 390 | num_training_steps: int, 391 | max_lr: float, 392 | min_lr: float = 1e-5, 393 | num_cycles: float = 0.5, 394 | last_epoch: int = -1, 395 | ): 396 | """ 397 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 398 | initial lr set in the optimizer to a minimum learning rate, after a warmup period during which it increases linearly 399 | between 0 and the initial lr set in the optimizer. 400 | Args: 401 | optimizer ([`~torch.optim.Optimizer`]): 402 | The optimizer for which to schedule the learning rate. 403 | num_warmup_steps (`int`): 404 | The number of steps for the warmup phase. 405 | num_training_steps (`int`): 406 | The total number of training steps. 
407 | max_lr (`float`): 408 | The maximum learning rate after warming up, right before decaying 409 | min_lr (`float`): 410 | The minimum learning rate at the end of training 411 | num_cycles (`float`, *optional*, defaults to 0.5): 412 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to the min 413 | value following a half-cosine). 414 | last_epoch (`int`, *optional*, defaults to -1): 415 | The index of the last epoch when resuming training. 416 | Return: 417 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 418 | """ 419 | 420 | def lr_lambda(current_step): 421 | if current_step < num_warmup_steps: 422 | return float(current_step) / float(max(1, num_warmup_steps)) 423 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 424 | return ( 425 | max( 426 | min_lr, 427 | min_lr + (max_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), 428 | ) 429 | / max_lr # Scale down by max_lr because LambdaLR multiplies back by max_lr 430 | ) 431 | 432 | logger.info("***** Creating cosine scheduler to min_lr with warmup *****") 433 | logger.info(f"\t{num_warmup_steps = }") 434 | logger.info(f"\t{num_training_steps = }") 435 | logger.info(f"\t{max_lr = }") 436 | logger.info(f"\t{min_lr = }") 437 | logger.info(f"\t{num_cycles = }") 438 | logger.info(f"\t{last_epoch = }") 439 | 440 | return LambdaLR(optimizer, lr_lambda, last_epoch) 441 | 442 | 443 | 444 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): 445 | """ 446 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or 447 | passed as an argument. 448 | 449 | Args: 450 | num_training_steps (int): The number of training steps to do. 451 | """ 452 | if self.lr_scheduler is None: 453 | if self.args.lr_scheduler_type == "cosine" and self.args.cosine_w_min: 454 | self.lr_scheduler = get_cosine_schedule_to_min_lr_with_warmup( 455 | optimizer=self.optimizer if optimizer is None else optimizer, 456 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 457 | num_training_steps=num_training_steps, 458 | max_lr=self.args.learning_rate, 459 | min_lr=self.args.min_learning_rate 460 | ) 461 | else: 462 | self.lr_scheduler = get_scheduler( 463 | self.args.lr_scheduler_type, 464 | optimizer=self.optimizer if optimizer is None else optimizer, 465 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 466 | num_training_steps=num_training_steps, 467 | ) 468 | return self.lr_scheduler 469 | 470 | 471 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 472 | """ 473 | Perform a training step on a batch of inputs. 474 | 475 | Subclass and override to inject custom behavior. 476 | 477 | Args: 478 | model (`nn.Module`): 479 | The model to train. 480 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 481 | The inputs and targets of the model. 482 | 483 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 484 | argument `labels`. Check your model's documentation for all accepted arguments. 485 | 486 | Return: 487 | `torch.Tensor`: The tensor with training loss on this batch. 
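        Beyond the standard loss scaling and backward pass, this override can record per-parameter
        gradient norms (`log_grad_norm`) and the rendered screenshot inputs (`log_train_input`)
        into `self.extra_logs` for the custom `log` method above.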
488 | """ 489 | model.train() 490 | inputs = self._prepare_inputs(inputs) 491 | 492 | with self.compute_loss_context_manager(): 493 | loss = self.compute_loss(model, inputs) 494 | 495 | if self.args.n_gpu > 1: 496 | loss = loss.mean() # mean() to average on multi-gpu parallel training 497 | 498 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 499 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 500 | loss = loss / self.args.gradient_accumulation_steps 501 | 502 | if self.do_grad_scaling: 503 | self.scaler.scale(loss).backward() 504 | elif self.use_apex: 505 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 506 | scaled_loss.backward() 507 | elif self.deepspeed: 508 | # loss gets scaled under gradient_accumulation_steps in deepspeed 509 | loss = self.deepspeed.backward(loss) 510 | else: 511 | if is_ge_version("4.34.1"): 512 | self.accelerator.backward(loss) 513 | else: 514 | loss.backward() 515 | 516 | if getattr(self.args, "log_grad_norm", False): 517 | if not hasattr(self, "extra_logs"): 518 | self.extra_logs = {} 519 | # Go through all the parameters and log the gradient norm 520 | for name, param in model.named_parameters(): 521 | if param.grad is not None: 522 | self.extra_logs[f"grad_norm_{name}"] = torch.norm(param.grad.detach()).item() 523 | 524 | if getattr(self.args, "log_train_input", False) and "flattened_patches" in inputs: 525 | if not hasattr(self, "extra_logs"): 526 | self.extra_logs = {} 527 | images = [ 528 | flattened_patches_to_image( 529 | inputs["flattened_patches"][i, :, 2:].detach().cpu().to(torch.float32), 530 | height=self.args.height, 531 | width=self.args.width, 532 | patch_height=self.args.patch_height, 533 | patch_width=self.args.patch_width, 534 | image_mode=getattr(self.args, 'image_mode', 'RGB') 535 | ) for i in range(len(inputs["flattened_patches"])) 536 | ] 537 | # We save those images on the disk, in a folder that is named by the step 538 | # First create the folder (naming: step_rank) 539 | # os.makedirs(f"image_logs/{self.state.global_step}_{dist.get_rank()}", exist_ok=True) 540 | # # Save images 541 | # for i in range(len(images)): 542 | # images[i].save(f"image_logs/{self.state.global_step}_{dist.get_rank()}/{i}.png") 543 | self.extra_logs["train_image_input"] = [wandb.Image(image) for image in images] 544 | 545 | return loss.detach() 546 | 547 | 548 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 549 | 550 | def _save_checkpoint(self, model, trial, metrics=None): 551 | # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we 552 | # want to save except FullyShardedDDP. 553 | # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" 554 | 555 | # Save model checkpoint 556 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 557 | 558 | if self.hp_search_backend is None and trial is None: 559 | self.store_flos() 560 | 561 | run_dir = self._get_output_dir(trial=trial) 562 | output_dir = os.path.join(run_dir, checkpoint_folder) 563 | 564 | self._original_save_checkpoint(model, trial, metrics=metrics) 565 | 566 | 567 | def get_train_dataloader_for_streaming(self) -> DataLoader: 568 | """ 569 | Because streaming handles the distributed data parallel by itself, we don't need special data loader. 570 | The plainest data loader is enough. 
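        MosaicML's StreamingDataset handles distributed sharding (and shuffling) on its own,
        so this returns a plain `DataLoader` instead of one wrapped by `accelerator.prepare`.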
571 | """ 572 | if self.train_dataset is None: 573 | raise ValueError("Trainer: training requires a train_dataset.") 574 | 575 | train_dataset = self.train_dataset 576 | data_collator = self.data_collator 577 | data_collator = self._get_collator_with_removed_columns(data_collator, description="training") 578 | 579 | dataloader_params = { 580 | "batch_size": self._train_batch_size, 581 | "collate_fn": data_collator, 582 | "num_workers": self.args.dataloader_num_workers, # Streaming dataset is probably not multi-thread safe 583 | "pin_memory": self.args.dataloader_pin_memory, 584 | } 585 | 586 | # Streaming is iterable 587 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 588 | dataloader_params["sampler"] = self._get_train_sampler() 589 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 590 | dataloader_params["worker_init_fn"] = seed_worker 591 | 592 | # Instead of use accelerate to prepare the dataloader, we just return a plain dataloader 593 | return DataLoader(train_dataset, **dataloader_params) 594 | 595 | 596 | def get_eval_dataloader_for_streaming(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: 597 | """ 598 | Because streaming handles the distributed data parallel by itself, we don't need special data loader. 599 | The plainest data loader is enough. 600 | """ 601 | if eval_dataset is None and self.eval_dataset is None: 602 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 603 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 604 | data_collator = self.data_collator 605 | data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") 606 | 607 | dataloader_params = { 608 | "batch_size": self.args.eval_batch_size, 609 | "collate_fn": data_collator, 610 | "num_workers": self.args.dataloader_num_workers, # Streaming dataset is probably not multi-thread safe 611 | "pin_memory": self.args.dataloader_pin_memory, 612 | } 613 | 614 | # Streaming is iterable 615 | if not isinstance(eval_dataset, torch.utils.data.IterableDataset): 616 | dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) 617 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 618 | 619 | # Instead of use accelerate to prepare the dataloader, we just return a plain dataloader 620 | return DataLoader(eval_dataset, **dataloader_params) 621 | 622 | 623 | def trainer_addon(trainer, seq2seq=False, streaming_dataset=False): 624 | trainer._set_signature_columns_if_needed = _set_signature_columns_if_needed.__get__(trainer, Trainer) 625 | # New 626 | trainer.compute_loss = compute_loss_wrapper.__get__(trainer, Trainer) 627 | trainer.log = log.__get__(trainer, Trainer) 628 | trainer.create_scheduler = create_scheduler.__get__(trainer, Trainer) 629 | trainer.training_step = training_step.__get__(trainer, Trainer) 630 | trainer._original_save_checkpoint = trainer._save_checkpoint 631 | trainer._save_checkpoint = _save_checkpoint.__get__(trainer, Trainer) 632 | 633 | if streaming_dataset: 634 | trainer.get_train_dataloader = get_train_dataloader_for_streaming.__get__(trainer, Trainer) 635 | trainer.get_eval_dataloader = get_eval_dataloader_for_streaming.__get__(trainer, Trainer) 636 | 637 | trainer.add_callback(SIGUSR1Callback()) 638 | if seq2seq: 639 | trainer.prediction_step = prediction_step_seq2seq.__get__(trainer, Trainer) 640 | trainer._pad_tensors_to_max_len = _pad_tensors_to_max_len.__get__(trainer, Trainer) 641 | 642 | return trainer 643 | 
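# Usage sketch (mirroring run.py): build a standard HF Trainer, then patch it in place.
#
#   trainer = Trainer(model=model, args=training_args, data_collator=collator,
#                     train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=processor)
#   trainer = trainer_addon(trainer, streaming_dataset=data_args.streaming_dataset)
#
# After the call, compute_loss / log / training_step / create_scheduler / _save_checkpoint are the
# patched versions defined above, and a SIGUSR1 signal triggers a final checkpoint and a clean stop.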
--------------------------------------------------------------------------------