├── detikzify
├── __init__.py
├── mcts
│ ├── __init__.py
│ ├── LICENSE
│ ├── node.py
│ ├── montecarlo.py
│ └── README.md
├── dataset
│ ├── scicap
│ │ ├── __init__.py
│ │ └── scicap.py
│ ├── paper2fig
│ │ ├── __init__.py
│ │ └── paper2fig.py
│ └── __init__.py
├── infer
│ ├── __init__.py
│ └── tikz.py
├── train
│ ├── __init__.py
│ ├── adapter
│ │ ├── __init__.py
│ │ └── train.py
│ ├── pretrain.py
│ └── train.py
├── webui
│ ├── __init__.py
│ ├── __main__.py
│ ├── README.md
│ ├── helpers.py
│ └── strings.py
├── util
│ ├── __init__.py
│ ├── torch.py
│ ├── subprocess.py
│ ├── functools.py
│ ├── image.py
│ ├── generation.py
│ └── trainer.py
├── model
│ ├── v1
│ │ ├── configuration_detikzify.py
│ │ ├── __init__.py
│ │ ├── processing_detikzify.py
│ │ └── modeling_detikzify.py
│ ├── adapter
│ │ ├── __init__.py
│ │ └── processing_adapter.py
│ ├── __init__.py
│ ├── configuration_detikzify.py
│ └── processing_detikzify.py
└── evaluate
│ ├── __init__.py
│ ├── clipscore.py
│ ├── kid.py
│ ├── eed.py
│ ├── dreamsim.py
│ ├── crystalbleu.py
│ └── imagesim.py
├── MANIFEST.in
├── .gitattributes
├── examples
├── tikzero
│ ├── README.md
│ ├── train.py
│ └── pretrain.py
├── README.md
├── infer.py
├── train.py
├── pretrain.py
├── sketchify.py
├── eval.py
└── refine.py
├── pyproject.toml
├── .gitignore
└── LICENSE
/detikzify/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/detikzify/mcts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune examples
2 |
--------------------------------------------------------------------------------
/detikzify/dataset/scicap/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/detikzify/dataset/paper2fig/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | detikzify/mcts/** linguist-vendored
2 |
--------------------------------------------------------------------------------
/detikzify/infer/__init__.py:
--------------------------------------------------------------------------------
1 | from .tikz import *
2 | from .generate import *
3 |
--------------------------------------------------------------------------------
/detikzify/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .pretrain import train as pretrain
2 | from .train import train
3 |
--------------------------------------------------------------------------------
/detikzify/webui/__init__.py:
--------------------------------------------------------------------------------
1 | from .webui import *
2 | from .strings import *
3 | from .helpers import *
4 |
--------------------------------------------------------------------------------
/detikzify/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .functools import *
2 | from .generation import *
3 | from .image import *
4 | from .subprocess import *
5 | from
.torch import * 6 | from .trainer import * 7 | -------------------------------------------------------------------------------- /detikzify/train/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import SiglipVisionModel 2 | 3 | from ...model.adapter import CrossAttentionAdapterMixin 4 | from .pretrain import train as pretrain 5 | from .train import train 6 | #from .train import train 7 | 8 | 9 | class CrossAttentionSiglipVisionModel(SiglipVisionModel, CrossAttentionAdapterMixin): 10 | ... 11 | -------------------------------------------------------------------------------- /detikzify/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.load import load_dataset as _load_dataset 2 | from os.path import dirname, isdir, join 3 | 4 | def load_dataset(path, *args, **kwargs): 5 | if isdir(local := join(dirname(__file__), path)): 6 | return _load_dataset(local, *args, trust_remote_code=True, **kwargs) 7 | return _load_dataset(path, *args, **kwargs) 8 | -------------------------------------------------------------------------------- /examples/tikzero/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | The examples in this directory are specific to training Ti*k*Zero adapters. Each 3 | script has a command line interface and information about available options can 4 | be found by invoking them with the `--help` flag. 5 | 6 | > [!TIP] 7 | > The [`eval.py`](../eval.py) evaluation script in the parent directory also 8 | > works with adapters. 9 | -------------------------------------------------------------------------------- /detikzify/model/v1/configuration_detikzify.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaConfig 2 | 3 | class DetikzifyConfig(LlamaConfig): 4 | model_type = "detikzify" 5 | 6 | # compatibility with new inference code 7 | @property 8 | def image_token_id(self): 9 | return self.patch_token_id 10 | 11 | @property 12 | def pooling_mode(self): 13 | return "cos" 14 | -------------------------------------------------------------------------------- /detikzify/util/torch.py: -------------------------------------------------------------------------------- 1 | from torch.cuda import is_available as is_torch_cuda_available 2 | from transformers.utils import is_torch_npu_available, is_torch_xpu_available 3 | 4 | # https://github.com/huggingface/peft/blob/c4cf9e7d3b2948e71ec65a19e6cd1ff230781d13/src/peft/utils/other.py#L60-L71 5 | def infer_device(): 6 | if is_torch_cuda_available(): 7 | torch_device = "cuda" 8 | elif is_torch_xpu_available(): 9 | torch_device = "xpu" 10 | elif is_torch_npu_available(): 11 | torch_device = "npu" 12 | else: 13 | torch_device = "cpu" 14 | return torch_device 15 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | The examples in this directory may be helpful for training or pretraining your 3 | own DeTi*k*Zify models, or reproducing our evaluation results. Each script has 4 | a command line interface and information about available options can be found 5 | by invoking them with the `--help` flag. 
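For instance, to list the options of the fine-tuning script (run from the repository root; the other scripts work the same way):

```sh
python examples/train.py --help
```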
6 | 7 | > [!NOTE] 8 | > The scripts provided here reflect the [training pipeline and overall 9 | > state](https://huggingface.co/nllg/detikzify-v2-8b-preview#model-card-for-detikzifyv2-8b) 10 | > of DeTi*k*Zifyv2. If you want to reproduce 11 | > DeTi*k*Zifyv1 you have to switch to a previous 12 | > [release](https://github.com/potamides/DeTikZify/releases) of this 13 | > repository. 14 | -------------------------------------------------------------------------------- /detikzify/model/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from .modeling_adapter import CrossAttentionAdapterMixin 4 | from .processing_adapter import AdapterProcessor 5 | 6 | def has_adapter(model): 7 | return hasattr(model, "adapter") 8 | 9 | def load(model, processor, adapter_name_or_path=None, **kwargs): 10 | embedding_model = "meta-llama/Llama-3.2-1B" 11 | model.load_cross_attn_adapter(embedding_model, adapter_name_or_path, **kwargs) 12 | processor = AdapterProcessor( 13 | processor=processor, 14 | tokenizer=AutoTokenizer.from_pretrained( 15 | embedding_model, 16 | pad_token="<|finetune_right_pad_id|>", 17 | model_max_length=512, 18 | ), 19 | ) 20 | model.embedding_model.config.pad_token_id = processor.tokenizer.pad_token_id 21 | 22 | return model, processor 23 | -------------------------------------------------------------------------------- /detikzify/mcts/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2010-2018 ImparaAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /detikzify/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | # pyright: reportUnsupportedDunderAll=false 2 | from importlib import import_module 3 | from typing import Any 4 | 5 | from .imagesim import * # this metric is used by MCTS, so it is not optional 6 | 7 | __all__ = [ 8 | "ImageSim", 9 | "CrystalBLEU", 10 | "KernelInceptionDistance", 11 | "TexEditDistance", 12 | "DreamSim", 13 | "ClipScore", 14 | ] 15 | 16 | # lazy import optional metrics (https://peps.python.org/pep-0562/) 17 | def __getattr__(name) -> Any: 18 | def load(metric): 19 | return getattr(import_module("." 
+ metric, __name__), name) 20 | try: 21 | match name: 22 | case "CrystalBLEU": 23 | return load("crystalbleu") 24 | case "KernelInceptionDistance": 25 | return load("kid") 26 | case "TexEditDistance": 27 | return load("eed") 28 | case "DreamSim": 29 | return load("dreamsim") 30 | case "ClipScore": 31 | return load("clipscore") 32 | 33 | except ImportError: 34 | raise ValueError( 35 | "Missing dependencies: " 36 | "Install this project with the [evaluate] feature name!" 37 | ) 38 | return import_module("." + name, __name__) 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "detikzify" 7 | readme = "README.md" 8 | license = {text = "Apache-2.0 License"} 9 | dependencies = [ 10 | "torch~=2.7.1", 11 | "torchvision~=0.22.1", 12 | "transformers[accelerate,tokenizers]~=4.52.4", 13 | "datasets~=3.6.0", 14 | "Pillow~=10.4.0", 15 | "requests~=2.32.3", 16 | "numpy~=2.1.1", 17 | # pdf 18 | "PyMuPDF~=1.24.10", 19 | "pdf2image~=1.17.0", 20 | "pdfCropMargins~=2.1.4", 21 | # webui 22 | "gradio~=4.38.1", 23 | "fastapi~=0.112.4", # https://github.com/gradio-app/gradio/issues/9278 24 | "pydantic~=2.10.6", # https://github.com/gradio-app/gradio/issues/10662 25 | # evaluate 26 | "POT~=0.9.4", 27 | "torchmetrics~=1.7.2", 28 | ] 29 | requires-python = "~=3.11" 30 | dynamic = ["version"] 31 | 32 | [project.optional-dependencies] 33 | evaluate = [ 34 | "Pygments~=2.18.0", 35 | "crystalbleu~=0.1.0", 36 | "nltk~=3.9.1", 37 | "sacremoses~=0.1.1", 38 | "dreamsim~=0.2.1", 39 | "protobuf~=5.28.3", 40 | "sentencepiece~=0.2.0" 41 | ] 42 | examples = [ 43 | "detikzify[evaluate]", 44 | "diffusers~=0.30.2" 45 | ] 46 | legacy = [ 47 | "timm~=1.0.11" 48 | ] 49 | deepspeed = [ 50 | "deepspeed~=0.17.1" 51 | ] 52 | 53 | [project.urls] 54 | repository = "https://github.com/potamides/DeTikZify" 55 | 56 | [tool.setuptools_scm] 57 | write_to = "detikzify/_version.py" 58 | parentdir_prefix_version = "detikzify-" 59 | 60 | [tool.setuptools.packages.find] 61 | include = ["detikzify*"] 62 | -------------------------------------------------------------------------------- /detikzify/webui/__main__.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from .strings import ALGORITHMS, MODELS 4 | from .webui import build_ui 5 | 6 | def parse_args(): 7 | argument_parser = ArgumentParser( 8 | description="Web UI for DeTikZify." 9 | ) 10 | argument_parser.add_argument( 11 | "--model", 12 | default=list(MODELS)[0], 13 | help="Initially selected model. 
You can also specify a path to your own models.",
14 |     )
15 |     argument_parser.add_argument(
16 |         "--algorithm",
17 |         default=list(ALGORITHMS)[0],
18 |         choices=list(ALGORITHMS),
19 |         help="The inference algorithm to use.",
20 |     )
21 |     argument_parser.add_argument(
22 |         "--lock",
23 |         action="store_true",
24 |         help="Whether to allow users to change the model or not.",
25 |     )
26 |     argument_parser.add_argument(
27 |         "--lock_reason",
28 |         default="Duplicate this space to be able to change this value.",
29 |         help="Additional information why model selection is locked.",
30 |     )
31 |     argument_parser.add_argument(
32 |         "--share",
33 |         action="store_true",
34 |         help="Whether to create a publicly shareable link for the interface.",
35 |     )
36 |     argument_parser.add_argument(
37 |         "--light",
38 |         action="store_true",
39 |         help= "Whether to enforce light theme (useful for vector graphics with dark text)."
40 |     )
41 |     argument_parser.add_argument(
42 |         "--timeout",
43 |         default=60,
44 |         type=int,
45 |         help="Allowed timeframe for compilation.",
46 |     )
47 |     return vars(argument_parser.parse_args())
48 |
49 | if __name__ == "__main__":
50 |     args = parse_args()
51 |     share = args.pop("share")
52 |     build_ui(**args).queue().launch(share=share)
53 |
--------------------------------------------------------------------------------
/detikzify/util/subprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import signal
3 | import subprocess
4 |
5 | # Patched subprocess.run and subprocess.check_output that also kill children of
6 | # the started process on timeouts (cf.
7 | # https://alexandra-zaharia.github.io/posts/kill-subprocess-and-its-children-on-timeout-python/)
8 | class _Popen(subprocess.Popen):
9 |     def __init__(self, *args, **kwargs):
10 |         if os.name == "nt":
11 |             return super().__init__(
12 |                 *args, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, **kwargs
13 |             )
14 |         return super().__init__(*args, start_new_session=True, **kwargs)
15 |
16 |     def safe_killpg(self):
17 |         try:
18 |             if os.name == "nt":
19 |                 # https://stackoverflow.com/a/28609523
20 |                 return os.kill(self.pid, signal.CTRL_BREAK_EVENT)
21 |             return os.killpg(os.getpgid(self.pid), signal.SIGKILL)
22 |         except ProcessLookupError:
23 |             pass # Suppress the race condition error; bpo-40550.
24 | 25 | def run(*popenargs, input=None, timeout=None, check=False, **kwargs): 26 | with _Popen(*popenargs, **kwargs) as process: 27 | try: 28 | stdout, stderr = process.communicate(input, timeout=timeout) 29 | except subprocess.TimeoutExpired: 30 | process.safe_killpg() 31 | process.wait() 32 | raise 33 | except: 34 | process.safe_killpg() 35 | raise 36 | retcode = process.poll() 37 | if check and retcode: 38 | raise subprocess.CalledProcessError( 39 | retcode, process.args, output=stdout, stderr=stderr 40 | ) 41 | return subprocess.CompletedProcess( 42 | process.args, retcode, stdout, stderr # type: ignore 43 | ) 44 | 45 | def check_output(*popenargs, timeout=None, **kwargs): 46 | return run( 47 | *popenargs, stdout=subprocess.PIPE, timeout=timeout, check=True, **kwargs 48 | ).stdout 49 | -------------------------------------------------------------------------------- /examples/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from argparse import ArgumentParser 3 | from sys import flags 4 | 5 | from PIL import UnidentifiedImageError 6 | from torch import bfloat16, float16 7 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 8 | from transformers import TextStreamer, set_seed 9 | from transformers.utils import is_flash_attn_2_available 10 | 11 | from detikzify.infer import DetikzifyPipeline 12 | from detikzify.model import load 13 | 14 | try: 15 | import readline # patches input() 16 | except: 17 | pass 18 | 19 | def parse_args(): 20 | argument_parser = ArgumentParser( 21 | description="Inference helper for fine-tuned models." 22 | ) 23 | argument_parser.add_argument( 24 | "--model_name_or_path", 25 | required=True, 26 | help="the model checkpoint for weights initialization (local or hub)", 27 | ) 28 | return argument_parser.parse_args() 29 | 30 | if __name__ == "__main__": 31 | set_seed(0) 32 | model, processor = load( 33 | **vars(parse_args()), 34 | device_map="auto", 35 | torch_dtype=bfloat16 if is_cuda_available() and is_bf16_supported() else float16, 36 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, 37 | ) 38 | pipe = DetikzifyPipeline( 39 | model=model, 40 | processor=processor, 41 | streamer=TextStreamer( 42 | tokenizer=processor.tokenizer, 43 | skip_prompt=True, 44 | skip_special_tokens=True 45 | ) 46 | ) 47 | 48 | if flags.interactive: 49 | print("pipe(*args, **kwargs):", str(DetikzifyPipeline.sample.__doc__).strip()) 50 | else: 51 | print("Specify the path to an image (locally or as URL) to detikzify it!") 52 | while True: 53 | try: 54 | image = input("Image: ") 55 | except (KeyboardInterrupt, EOFError): 56 | break 57 | try: 58 | pipe(image=image) 59 | except KeyboardInterrupt: 60 | pass 61 | except (UnidentifiedImageError, FileNotFoundError, AttributeError, ValueError): 62 | print("Error: Cannot identify image file!") 63 | -------------------------------------------------------------------------------- /detikzify/util/functools.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from collections.abc import Callable 3 | from copy import copy 4 | from functools import cache, wraps 5 | from typing import Any 6 | 7 | def cache_cast(cast_func: Callable[..., Any]): 8 | """ 9 | functools.cache which takes a user-defined function to convert arguments 10 | into something immutable so it can be cached. 
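
    A minimal illustrative example (the ``total`` function below is
    hypothetical; any callable whose arguments can be mapped to something
    hashable works the same way)::

        >>> @cache_cast(lambda numbers: tuple(numbers))
        ... def total(numbers):
        ...     return sum(numbers)
        >>> total([1, 2, 3])
        6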
11 | """ 12 | def decorator(func): 13 | cache_args, cache_kwargs = None, None 14 | @cache 15 | def cached_func(_): 16 | return func(*cache_args, **cache_kwargs) 17 | @wraps(func) 18 | def wrapped_func(*args, **kwargs): 19 | nonlocal cache_args, cache_kwargs 20 | cache_args, cache_kwargs = args, kwargs 21 | return cached_func(cast_func(*args, **kwargs)) 22 | return wrapped_func 23 | return decorator 24 | 25 | def cast(cls, object): 26 | clone = copy(object) 27 | clone.__class__ = cls 28 | return clone 29 | 30 | # https://stackoverflow.com/a/12377059 31 | def listify(fn=None, wrapper=list): 32 | """ 33 | A decorator which wraps a function's return value in ``list(...)``. 34 | 35 | Useful when an algorithm can be expressed more cleanly as a generator but 36 | the function should return a list. 37 | 38 | Example:: 39 | 40 | >>> @listify 41 | ... def get_lengths(iterable): 42 | ... for i in iterable: 43 | ... yield len(i) 44 | >>> get_lengths(["spam", "eggs"]) 45 | [4, 4] 46 | >>> 47 | >>> @listify(wrapper=tuple) 48 | ... def get_lengths_tuple(iterable): 49 | ... for i in iterable: 50 | ... yield len(i) 51 | >>> get_lengths_tuple(["foo", "bar"]) 52 | (3, 3) 53 | """ 54 | def listify_return(fn): 55 | @wraps(fn) 56 | def listify_helper(*args, **kw): 57 | return wrapper(fn(*args, **kw)) 58 | return listify_helper 59 | if fn is None: 60 | return listify_return 61 | return listify_return(fn) 62 | 63 | def batchify(fn=None): 64 | def batch(list_of_dicts): 65 | batch_dict = defaultdict(list) 66 | for d in list_of_dicts: 67 | for k, v in d.items(): 68 | batch_dict[k].append(v) 69 | return batch_dict 70 | return listify(fn=fn, wrapper=batch) 71 | -------------------------------------------------------------------------------- /detikzify/model/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets import DownloadManager 2 | from safetensors.torch import load_file 3 | from transformers.utils.hub import has_file 4 | from transformers import ( 5 | AutoConfig, 6 | AutoModelForVision2Seq, 7 | AutoProcessor, 8 | is_timm_available, 9 | ) 10 | from transformers.utils.hub import is_remote_url 11 | 12 | from .configuration_detikzify import * 13 | from .modeling_detikzify import * 14 | from .processing_detikzify import * 15 | from .adapter import load as load_adapter 16 | 17 | if is_timm_available(): 18 | from .v1 import models as v1_models, load as load_v1 19 | 20 | def register(): 21 | try: 22 | AutoConfig.register("detikzify", DetikzifyConfig) 23 | AutoModelForVision2Seq.register(DetikzifyConfig, DetikzifyForConditionalGeneration) 24 | AutoProcessor.register(DetikzifyConfig, DetikzifyProcessor) 25 | except ValueError: 26 | pass # already registered 27 | 28 | def load(model_name_or_path, modality_projector=None, is_v1=False, **kwargs): 29 | # backwards compatibility with v1 models 30 | if is_timm_available() and (is_v1 or model_name_or_path in v1_models): # type: ignore 31 | model, tokenizer, image_processor = load_v1( # type: ignore 32 | model_name_or_path=model_name_or_path, 33 | modality_projector=modality_projector, 34 | **kwargs 35 | ) 36 | return model, DetikzifyProcessor( 37 | tokenizer=tokenizer, 38 | image_processor=image_processor, 39 | image_seq_len=model.config.num_patches, 40 | image_token=tokenizer.convert_ids_to_tokens(model.config.patch_token_id) 41 | ) 42 | 43 | register() 44 | processor = AutoProcessor.from_pretrained(model_name_or_path) 45 | model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, **kwargs) 46 | 47 | if 
modality_projector is not None: 48 | if is_remote_url(modality_projector): 49 | modality_projector = DownloadManager().download(modality_projector) 50 | model.load_state_dict( 51 | state_dict=load_file( 52 | filename=modality_projector, # type: ignore 53 | device=str(model.device) 54 | ), 55 | strict=False 56 | ) 57 | 58 | if has_file(model_name_or_path, "adapter/model.safetensors"): 59 | model, processor = load_adapter(model=model, processor=processor) 60 | 61 | return model, processor 62 | -------------------------------------------------------------------------------- /detikzify/dataset/paper2fig/paper2fig.py: -------------------------------------------------------------------------------- 1 | """ 2 | Images from the Paper2Fig100k dataset. 3 | """ 4 | from itertools import chain 5 | from json import load 6 | from os.path import basename 7 | import tarfile 8 | 9 | from datasets import Features, Image, Sequence, Value, builder 10 | from datasets.info import DatasetInfo 11 | from datasets.splits import Split, SplitGenerator 12 | 13 | from detikzify.util import convert, expand 14 | 15 | class Paper2FigConfig(builder.BuilderConfig): 16 | """BuilderConfig for Paper2Fig.""" 17 | 18 | def __init__(self, size, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.size = size 21 | self.archive = "https://zenodo.org/records/7299423/files/Paper2Fig100k.tar.gz" 22 | 23 | class Paper2Fig(builder.GeneratorBasedBuilder): 24 | """The Paper2Fig100k dataset in the format DeTikZify expects (everything is training data).""" 25 | 26 | BUILDER_CONFIG_CLASS = Paper2FigConfig 27 | 28 | def _info(self): 29 | features = { 30 | "caption": Value("string"), 31 | "mention": Sequence(Sequence(Value("string"))), 32 | "ocr": Sequence(Value("string")), 33 | "image": Image(), 34 | } 35 | return DatasetInfo( 36 | description=str(__doc__), 37 | features=Features(features), 38 | ) 39 | 40 | def _split_generators(self, dl_manager): 41 | archive = dl_manager.download(self.config.archive) # type: ignore 42 | return [SplitGenerator(name=str(Split.TRAIN), gen_kwargs=dict(archive=archive))] 43 | 44 | def _generate_examples(self, archive): 45 | with tarfile.open(archive) as tf: 46 | metadata = dict() 47 | for figdata in chain.from_iterable(load(tf.extractfile(f)) for f in tf if f.name.endswith(".json")): # type: ignore 48 | metadata[figdata.pop("figure_id")] = figdata 49 | for idx, member in enumerate(tf): 50 | if member.name.endswith(".png"): 51 | figure_id = basename(member.name).removesuffix(".png") 52 | figdata = metadata[figure_id] 53 | yield idx, dict( 54 | caption=figdata["captions"][0], 55 | mention=[figdata["captions"][1:]], 56 | ocr=[result['text'] for result in figdata['ocr_result']['ocr_result']], 57 | image=convert(expand(tf.extractfile(member), self.config.size), "png"), 58 | ) 59 | -------------------------------------------------------------------------------- /detikzify/model/v1/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets import DownloadManager 2 | from transformers import AutoConfig, AutoModel 3 | from transformers import AutoTokenizer, PretrainedConfig 4 | from transformers.utils.hub import is_remote_url 5 | 6 | from .configuration_detikzify import * 7 | from .modeling_detikzify import * 8 | from .processing_detikzify import * 9 | 10 | models = [ 11 | "nllg/detikzify-ds-1.3b", 12 | "nllg/detikzify-ds-7b", 13 | "nllg/detikzify-tl-1.1b", 14 | "nllg/detikzify-cl-7b", 15 | ] 16 | 17 | def register(): 18 | try: 19 | 
AutoConfig.register("detikzify", DetikzifyConfig) 20 | AutoModel.register(DetikzifyConfig, DetikzifyForCausalLM) 21 | except ValueError: 22 | pass # already registered 23 | 24 | def load(model_name_or_path, vision_tower="vit_so400m_patch14_siglip_384.webli", modality_projector=None, **kwargs): 25 | base_tokenizer = PretrainedConfig.from_pretrained(model_name_or_path).name_or_path or model_name_or_path 26 | tokenizer = AutoTokenizer.from_pretrained( 27 | pretrained_model_name_or_path=base_tokenizer, 28 | model_max_length=2048, 29 | add_bos_token=False, 30 | add_eos_token=True, 31 | pad_token="", 32 | padding_side="right", # NOTE: only for training, need to change to "left" for batched inference 33 | legacy=False 34 | ) 35 | model = DetikzifyForCausalLM.from_pretrained( 36 | pretrained_model_name_or_path=model_name_or_path, 37 | use_cache=True, 38 | **kwargs 39 | ) 40 | model.config.model_type = DetikzifyConfig.model_type # type: ignore 41 | model.generation_config.pad_token_id = tokenizer.pad_token_id # type: ignore 42 | 43 | if len(tokenizer) > model.config.vocab_size: # type: ignore 44 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) # type: ignore 45 | if modality_projector and is_remote_url(modality_projector): 46 | modality_projector = DownloadManager().download(modality_projector) 47 | 48 | processor = model.get_model().initialize_vision_modules( # type: ignore 49 | patch_token_id=tokenizer.bos_token_id, 50 | modality_projector=modality_projector, 51 | vision_tower=getattr(model.config, "vision_tower", vision_tower), # type: ignore 52 | feature_layer=getattr(model.config, "feature_layer", -1), # type: ignore 53 | concat_patches=getattr(model.config, "concat_patches", 3) # type: ignore 54 | ) 55 | 56 | return model, tokenizer, processor 57 | -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from os.path import basename, join 4 | 5 | from datasets import Dataset 6 | from transformers import set_seed 7 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info 8 | 9 | from detikzify.dataset import load_dataset 10 | from detikzify.model import load 11 | from detikzify.train import train 12 | 13 | def parse_args(): 14 | argument_parser = ArgumentParser( 15 | description="Fine-tune DeTikZify on DaTikZ." 16 | ) 17 | argument_parser.add_argument("--base_model", 18 | required=True, 19 | help="The model checkpoint for weights initialization." 
20 | ) 21 | argument_parser.add_argument( 22 | "--projector", 23 | help="url or path to a pretrained modality projector" 24 | ) 25 | argument_parser.add_argument("--datikz", 26 | required=True, 27 | help="path to the DaTikZ train split processed by the ./sketchify script (in parquet format)", 28 | ) 29 | argument_parser.add_argument("--sketch_ratio", 30 | default=.5, 31 | help="ratio of synthetic sketches generated through the ./sketchify script or image transforms", 32 | ) 33 | argument_parser.add_argument("--output", 34 | required=True, 35 | help="directory where to write the model files", 36 | ) 37 | argument_parser.add_argument("--deepspeed", 38 | help="path to a DeepSpeed json config file", 39 | ) 40 | argument_parser.add_argument("--gradient_checkpointing", 41 | action="store_true", 42 | help="use gradient checkpointing", 43 | ) 44 | 45 | return argument_parser.parse_args() 46 | 47 | if __name__ == "__main__": 48 | set_verbosity_info() 49 | enable_explicit_format() 50 | set_seed(0) 51 | 52 | args = parse_args() 53 | model, processor = load(args.base_model, modality_projector=args.projector) 54 | 55 | datikz: Dataset = load_dataset("parquet", data_files=args.datikz, split="train") # type: ignore 56 | datikz = datikz.select_columns(["image", "code", "sketches"]).rename_column("code", "text") 57 | 58 | train( 59 | model=model, 60 | processor=processor, 61 | dataset=datikz, 62 | sketch_ratio=args.sketch_ratio, 63 | output_dir=join(args.output, basename(model.config.name_or_path)), # type: ignore 64 | gradient_checkpointing=args.gradient_checkpointing, 65 | deepspeed=args.deepspeed, 66 | ) 67 | -------------------------------------------------------------------------------- /detikzify/evaluate/clipscore.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import List 3 | 4 | from PIL import Image 5 | import torch 6 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 7 | from torchmetrics import Metric 8 | from transformers import AutoModel, AutoProcessor 9 | 10 | from ..util import expand, infer_device, load 11 | 12 | class ClipScore(Metric): 13 | """Calculates CLIPScore which is a text-to-image similarity metric.""" 14 | 15 | higher_is_better = True 16 | 17 | def __init__( 18 | self, 19 | model_name: str = "google/siglip-so400m-patch14-384", 20 | preprocess: bool = True, 21 | device: str = infer_device(), 22 | dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, 23 | **kwargs 24 | ): 25 | super().__init__(**kwargs) 26 | self.model_name = model_name 27 | self.preprocess = preprocess 28 | self._device = device 29 | self.set_dtype(dtype) 30 | 31 | self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") 32 | self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 33 | 34 | def __str__(self): 35 | return self.__class__.__name__ 36 | 37 | @cached_property 38 | def model(self): 39 | model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) 40 | return model.to(self.device) 41 | 42 | @cached_property 43 | def processor(self): 44 | return AutoProcessor.from_pretrained(self.model_name) 45 | 46 | def update( 47 | self, 48 | images: Image.Image | str | List[Image.Image | str], 49 | text: str | List[str] 50 | ): 51 | images = images if isinstance(images, List) else [images] 52 | text = text if isinstance(text, List) else [text] 53 | 54 | for img, txt in zip(images, text): 55 | 
img = load(img) 56 | if self.preprocess: 57 | img = expand(img, max(img.size), do_trim=True) 58 | 59 | with torch.inference_mode(): 60 | inputs = self.processor(text=txt, images=img, truncation=True, return_tensors="pt") 61 | outputs = self.model( 62 | input_ids=inputs.input_ids.to(self.device), 63 | pixel_values=inputs.pixel_values.to(self.device, self.dtype) 64 | ) 65 | self.score += torch.sigmoid(outputs.logits_per_image).item() 66 | self.n_samples += 1 67 | 68 | def compute(self): 69 | return (self.score / self.n_samples).item() 70 | -------------------------------------------------------------------------------- /detikzify/evaluate/kid.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import List 3 | 4 | from PIL import Image 5 | import torch 6 | from torch import nn 7 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 8 | from torchmetrics.image.kid import KernelInceptionDistance as KID 9 | from transformers import AutoModel, AutoImageProcessor 10 | 11 | from ..util import expand, infer_device, load 12 | 13 | class FeatureWrapper(nn.Module): 14 | def __init__(self, model_name, device, dtype): 15 | super().__init__() 16 | self.model_name = model_name 17 | self.device = device 18 | self.dtype = dtype 19 | 20 | @cached_property 21 | def model(self): 22 | model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) 23 | return model.to(self.device) 24 | 25 | def forward(self, pixel_values): 26 | with torch.inference_mode(): 27 | return self.model.get_image_features(pixel_values.to(self.device, self.dtype)) 28 | 29 | class KernelInceptionDistance(KID): 30 | """Wrapper around torchmetrics Kernel Inception Distance with CLIP""" 31 | 32 | def __init__( 33 | self, 34 | model_name: str = "google/siglip-so400m-patch14-384", 35 | subset_size: int = 50, 36 | preprocess: bool = True, 37 | device: str = infer_device(), 38 | dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, 39 | **kwargs 40 | ): 41 | super().__init__( 42 | subset_size=subset_size, 43 | feature=FeatureWrapper( 44 | model_name=model_name, 45 | device=device, 46 | dtype=dtype), 47 | **kwargs 48 | ) 49 | self.preprocess = preprocess 50 | 51 | def __str__(self): 52 | return self.__class__.__name__ 53 | 54 | @cached_property 55 | def processor(self): 56 | return AutoImageProcessor.from_pretrained(self.inception.model_name) 57 | 58 | def open(self, img): 59 | img = load(img) 60 | if self.preprocess: 61 | return expand(img, max(img.size), do_trim=True) 62 | return img 63 | 64 | def update(self, imgs: Image.Image | str | List[Image.Image | str], *args, **kwargs): 65 | if not isinstance(imgs, List): 66 | imgs = [imgs] 67 | super().update( 68 | self.processor([self.open(img) for img in imgs], return_tensors="pt")["pixel_values"], 69 | *args, 70 | **kwargs 71 | ) 72 | 73 | def compute(self, *args, **kwargs): # type: ignore 74 | return tuple(tensor.item() for tensor in super().compute(*args, **kwargs)) 75 | -------------------------------------------------------------------------------- /detikzify/evaluate/eed.py: -------------------------------------------------------------------------------- 1 | from pygments.lexers.markup import TexLexer 2 | from pygments.token import Comment, Text 3 | from torchmetrics.text import ExtendedEditDistance 4 | from torchmetrics.functional.text.eed import ( 5 | _compute_sentence_statistics, 6 | _preprocess_en, 7 | _preprocess_ja, 8 | ) 9 | from 
torchmetrics.functional.text.helper import _validate_inputs 10 | 11 | class TexEditDistance(ExtendedEditDistance): 12 | """Adapt torchmetrics ExtendedEditDistance for TeX""" 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.lexer = TexLexer() 16 | 17 | def __str__(self): 18 | return self.__class__.__name__ 19 | 20 | def _preprocess_sentences(self, preds, target, language): 21 | target, preds = _validate_inputs(hypothesis_corpus=preds, ref_corpus=target) 22 | 23 | def tokenize(text): 24 | tokens = list() 25 | for tokentype, value in self.lexer.get_tokens(text): 26 | if value.strip(): 27 | if tokentype is Text: 28 | if language == "en": 29 | preprocess_function = _preprocess_en 30 | elif language == "ja": 31 | preprocess_function = _preprocess_ja 32 | else: 33 | raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}") 34 | tokens.extend(preprocess_function(value).split()) 35 | elif not tokentype is Comment: 36 | tokens.extend(value.split()) 37 | 38 | return " " + " ".join(tokens) + " " 39 | 40 | preds = [tokenize(pred) for pred in preds] 41 | target = [[tokenize(ref) for ref in reference] for reference in target] 42 | 43 | return preds, target 44 | 45 | def update(self, preds, target): 46 | """Update state with predictions and targets.""" 47 | preds, target = self._preprocess_sentences(preds, target, self.language) 48 | 49 | if self.sentence_eed is None: 50 | self.sentence_eed = [] 51 | 52 | if 0 in (len(preds), len(target[0])): 53 | return self.sentence_eed 54 | 55 | for (hypothesis, target_words) in zip(preds, target): 56 | score = _compute_sentence_statistics( 57 | hypothesis, 58 | target_words, 59 | self.alpha, 60 | self.rho, 61 | self.deletion, 62 | self.insertion 63 | ) 64 | self.sentence_eed.append(score) 65 | 66 | return self.sentence_eed 67 | 68 | def compute(self, *args, **kwargs): 69 | return super().compute(*args, **kwargs).item() # type: ignore 70 | -------------------------------------------------------------------------------- /detikzify/mcts/node.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | from math import log, sqrt 4 | 5 | class Node: 6 | def __init__(self, state): 7 | self.state = state 8 | self.win_value = 0 9 | self.policy_value = None 10 | self.visits = 0 11 | self.parent = None 12 | self.children = [] 13 | self.expanded = False 14 | self.player_number = None 15 | self.discovery_factor = 0.35 16 | self.is_widen_node = False 17 | 18 | def update_win_value(self, value): 19 | self.win_value += value 20 | self.visits += 1 21 | 22 | if self.parent: 23 | self.parent.update_win_value(value) 24 | 25 | def update_policy_value(self, value): 26 | self.policy_value = value 27 | 28 | def add_child(self, child): 29 | self.children.append(child) 30 | child.parent = self 31 | 32 | def add_children(self, children): 33 | for child in children: 34 | self.add_child(child) 35 | 36 | def get_preferred_child(self, root_node): 37 | best_children = [] 38 | best_score = float("-inf") 39 | 40 | for child in self.children: 41 | score = child.get_score(root_node) 42 | 43 | if score > best_score: 44 | best_score = score 45 | best_children = [child] 46 | elif score == best_score: 47 | best_children.append(child) 48 | 49 | return random.choice(best_children) 50 | 51 | def get_score(self, root_node): 52 | discovery_operand = ( 53 | self.discovery_factor 54 | * (self.policy_value or 1) 55 | * sqrt(log(self.parent.visits) / (self.visits or 1)) 56 | ) 
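        # Illustrative note: the discovery_operand above is a UCT-style
        # exploration bonus -- the (optional) policy prior scales
        # sqrt(log(parent visits) / child visits), so rarely visited children
        # with high priors receive a larger bonus.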
57 | 58 | if self.is_widen_node: 59 | win_operand = 0 60 | else: 61 | win_multiplier = ( 62 | 1 if self.parent.player_number == root_node.player_number else -1 63 | ) 64 | win_operand = win_multiplier * self.win_value / (self.visits or 1) 65 | 66 | self.score = win_operand + discovery_operand 67 | 68 | return self.score 69 | 70 | def is_scorable(self): 71 | return self.visits or self.policy_value != None 72 | 73 | def print_node(self, f, i, root, st): 74 | escape = lambda x : json.dumps(x).strip('"') 75 | if self.parent is None: 76 | f.write((' ' * i) + st + " [label=\"" + escape(self.state) + "\",shape=box]\n") 77 | else: 78 | diff = '\n'.join([x for x in self.state.split("\n") if x not in self.parent.state.split("\n")]) 79 | f.write((' ' * i) + st + " [label=\"" + escape(diff) + "\",shape=box]\n") 80 | 81 | num = 0 82 | for child in self.children: 83 | new_st = st + "_" + str(num) 84 | child.print_node(f, i + 2, root, new_st) 85 | f.write(' ' * i + st + " -- " + new_st + "\n") 86 | num = num + 1 87 | -------------------------------------------------------------------------------- /examples/pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | from itertools import chain 5 | from os.path import basename, join 6 | 7 | from datasets import Dataset, IterableDataset 8 | from transformers import set_seed 9 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info 10 | 11 | from detikzify.dataset import load_dataset 12 | from detikzify.model import load 13 | from detikzify.train import pretrain 14 | from detikzify.util import convert, expand, batchify 15 | 16 | @batchify 17 | def preprocess(batch, size): 18 | """Concatenate captions and OCR tokens.""" 19 | for caption_images in chain.from_iterable(batch['caption_images']): 20 | caption = caption_images['caption'] 21 | for cil_pair in caption_images['cil_pairs']: 22 | sub_caption = cil_pair['sub_caption'] 23 | ocr = " ".join(cil_pair['image_ocr']) 24 | if text:=" ".join(filter(None, [caption, sub_caption, ocr])): 25 | yield dict( 26 | text=text, 27 | image=convert(expand(cil_pair['image'], size, do_trim=True), "png") 28 | ) 29 | 30 | def parse_args(): 31 | argument_parser = ArgumentParser( 32 | description="Pretrain projection layer of DeTikZify." 33 | ) 34 | argument_parser.add_argument("--base_model", 35 | required=True, 36 | help="The model checkpoint for weights initialization." 
37 | ) 38 | argument_parser.add_argument("--size", 39 | default=1_000_000, 40 | type=int, 41 | help="the amount of figures to use for pretraining" 42 | ) 43 | argument_parser.add_argument("--output", 44 | required=True, 45 | help="directory where to write the model files", 46 | ) 47 | argument_parser.add_argument("--deepspeed", 48 | help="path to a DeepSpeed json config file", 49 | ) 50 | argument_parser.add_argument("--gradient_checkpointing", 51 | action="store_true", 52 | help="use gradient checkpointing", 53 | ) 54 | 55 | return argument_parser.parse_args() 56 | 57 | if __name__ == "__main__": 58 | set_verbosity_info() 59 | enable_explicit_format() 60 | set_seed(0) 61 | 62 | args = parse_args() 63 | model, processor = load(args.base_model) 64 | 65 | arxivcap: IterableDataset = load_dataset("MMInstruction/ArxivCap", split="train", streaming=True) # type: ignore 66 | arxivcap = arxivcap.shuffle(0).map( 67 | preprocess, 68 | batched=True, 69 | remove_columns=arxivcap.column_names, 70 | fn_kwargs=dict(size=model.config.vision_config.image_size), 71 | ) 72 | 73 | pretrain( 74 | model=model, 75 | processor=processor, 76 | output_dir=join(args.output, basename(model.config.name_or_path)), 77 | gradient_checkpointing=args.gradient_checkpointing, 78 | deepspeed=args.deepspeed, 79 | dataset=Dataset.from_generator( 80 | generator=partial(iter, arxivcap.take(args.size)), 81 | features=arxivcap.features, 82 | ) 83 | ) 84 | -------------------------------------------------------------------------------- /detikzify/util/image.py: -------------------------------------------------------------------------------- 1 | from base64 import b64decode 2 | from codecs import encode 3 | from io import BytesIO 4 | from os.path import isfile 5 | 6 | from PIL import Image, ImageChops, ImageOps 7 | import pymupdf 8 | import requests 9 | from transformers.utils.hub import is_remote_url 10 | 11 | DUMMY_IMAGE = Image.new("RGB", (24, 24), color="white") 12 | 13 | def convert(image, filetype): 14 | image.save(imgbytes:=BytesIO(), format=filetype) 15 | return Image.open(imgbytes) 16 | 17 | def remove_alpha(image, bg): 18 | # https://stackoverflow.com/a/62414364 19 | background = Image.new('RGBA', image.size, bg) 20 | alpha_composite = Image.alpha_composite(background, image.convert("RGBA")) 21 | return alpha_composite.convert("RGB") 22 | 23 | # https://stackoverflow.com/a/10616717 24 | def trim(image, bg="white"): 25 | bg = Image.new(image.mode, image.size, bg) 26 | diff = ImageChops.difference(image, bg) 27 | #diff = ImageChops.add(diff, diff, 2.0, -10) 28 | return image.crop(bbox) if (bbox:=diff.getbbox()) else image 29 | 30 | def expand(image, size, do_trim=False, bg="white"): 31 | """Expand image to a square of size {size}. 
Optionally trims borders first.""" 32 | image = trim(image, bg=bg) if do_trim else image 33 | return ImageOps.pad(image, (size, size), color=bg, method=Image.Resampling.LANCZOS) 34 | 35 | # based on transformers/image_utils.py (added support for rgba images) 36 | def load(image: Image.Image | str | bytes, bg="white", timeout=None): 37 | if isinstance(image, bytes): 38 | # assume image bytes and open 39 | image = Image.open(BytesIO(image)) 40 | elif isinstance(image, str): 41 | if is_remote_url(image): 42 | # https://stackoverflow.com/a/69791396 43 | headers = {'user-agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:68.0) Gecko/20100101 Firefox/68.0'} 44 | image = Image.open(BytesIO(requests.get(image, timeout=timeout, headers=headers).content)) 45 | elif isfile(image): 46 | image = Image.open(image) 47 | else: 48 | try: 49 | image.removeprefix("data:image/") 50 | image = Image.open(BytesIO(b64decode(image))) 51 | except Exception as e: 52 | raise ValueError( 53 | "Incorrect image source. " 54 | "Must be a valid URL starting with `http://` or `https://`, " 55 | "a valid path to an image file, bytes, or a base64 encoded " 56 | f"string. Got {image}. Failed with {e}" 57 | ) 58 | 59 | image = ImageOps.exif_transpose(image) # type: ignore 60 | return remove_alpha(image, bg=bg) 61 | 62 | def redact(doc, rot_13=False): 63 | for page in (copy:=pymupdf.open("pdf", doc.tobytes())): 64 | for word in page.get_text("words", clip=pymupdf.INFINITE_RECT()): # type: ignore 65 | text = encode(word[4], "rot13") if rot_13 else None 66 | page.add_redact_annot(word[:4], text=text, fill=False) # type: ignore 67 | page.apply_redactions( # type: ignore 68 | images=pymupdf.PDF_REDACT_IMAGE_NONE, # type: ignore 69 | graphics=pymupdf.PDF_REDACT_LINE_ART_NONE # type: ignore 70 | ) 71 | return copy 72 | -------------------------------------------------------------------------------- /detikzify/evaluate/dreamsim.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import List 3 | 4 | from PIL import Image 5 | from dreamsim import dreamsim 6 | from huggingface_hub import cached_assets_path 7 | import torch 8 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 9 | from torchmetrics import Metric 10 | 11 | from ..util import expand, infer_device, load 12 | 13 | class DreamSim(Metric): 14 | """Perceptual image similarity using DreamSim""" 15 | 16 | higher_is_better = True 17 | 18 | def __init__( 19 | self, 20 | model_name: str = "ensemble", 21 | pretrained: bool = True, 22 | normalize: bool = True, 23 | preprocess: bool = True, 24 | device: str = infer_device(), 25 | dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, 26 | **kwargs 27 | ): 28 | super().__init__(**kwargs) 29 | self.model_name = model_name 30 | self.pretrained = pretrained 31 | self.normalize = normalize 32 | self._device = device 33 | self.set_dtype(dtype) 34 | self.preprocess = preprocess 35 | 36 | self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") 37 | self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 38 | 39 | def __str__(self): 40 | return self.__class__.__name__ 41 | 42 | @cached_property 43 | def dreamsim(self): 44 | model, processor = dreamsim( 45 | dreamsim_type=self.model_name, 46 | pretrained = self.pretrained, 47 | normalize_embeds=self.normalize, 48 | device=str(self.device), 49 | 
cache_dir=str(cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower())) 50 | ) 51 | for extractor in model.extractor_list: 52 | extractor.model = extractor.model.to(self.dtype) 53 | extractor.proj = extractor.proj.to(self.dtype) 54 | return dict( 55 | model=model.to(self.dtype), 56 | processor=processor 57 | ) 58 | 59 | @property 60 | def model(self): 61 | return self.dreamsim['model'] 62 | 63 | @property 64 | def processor(self): 65 | return self.dreamsim['processor'] 66 | 67 | def update( 68 | self, 69 | img1: Image.Image | str | List[Image.Image | str], 70 | img2: Image.Image | str | List[Image.Image | str], 71 | ): 72 | if isinstance(img1, List) or isinstance(img2, List): 73 | assert type(img1) == type(img2) and len(img1) == len(img2) # type: ignore 74 | else: 75 | img1, img2 = [img1], [img2] 76 | 77 | for i1, i2 in zip(img1, img2): # type: ignore 78 | i1, i2 = load(i1), load(i2) 79 | if self.preprocess: 80 | i1 = expand(load(i1), max(i1.size), do_trim=True) 81 | i2 = expand(load(i2), max(i2.size), do_trim=True) 82 | i1 = self.processor(i1).to(self.device, self.dtype) 83 | i2 = self.processor(i2).to(self.device, self.dtype) 84 | with torch.inference_mode(): 85 | self.score += 1 - self.model(i1, i2).item() # type: ignore 86 | self.n_samples += 1 87 | 88 | def compute(self): 89 | return (self.score / self.n_samples).item() 90 | -------------------------------------------------------------------------------- /detikzify/util/generation.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | from typing import Optional 3 | 4 | from transformers import StoppingCriteria 5 | from transformers.generation import streamers 6 | 7 | class ExplicitAbort(StoppingCriteria): 8 | """ 9 | Abort a model generation explicitly (i.e., when using a streamer in a thread). 10 | """ 11 | def __init__(self): 12 | super().__init__() 13 | self.should_stop = False 14 | 15 | def __call__(self, input_ids, scores, **kwargs) -> bool: 16 | return self.should_stop 17 | 18 | def reset(self): 19 | self.should_stop = False 20 | return self 21 | 22 | def abort(self): 23 | self.should_stop = True 24 | 25 | class TokenStreamer(streamers.BaseStreamer): 26 | """ 27 | Stream raw token ids (i.e., not decoded strings). 
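
    Illustrative usage (``model`` and ``model_inputs`` are placeholders for a
    loaded model and its tokenized inputs; generation runs in a background
    thread while the raw ids are consumed here)::

        from threading import Thread

        streamer = TokenStreamer(skip_prompt=True)
        Thread(target=model.generate, kwargs=dict(**model_inputs, streamer=streamer)).start()
        for token_id in streamer:
            ...  # process raw token ids as they arrive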
28 | """ 29 | def __init__(self, skip_prompt: bool = True, timeout: Optional[float] = None): 30 | self.skip_prompt = skip_prompt 31 | self.next_tokens_are_prompt = True 32 | self.token_queue = Queue() 33 | self.stop_signal = None 34 | self.timeout = timeout 35 | 36 | def put(self, value): 37 | if len(value.shape) > 1 and value.shape[0] > 1: 38 | raise ValueError("TokenStreamer only supports batch size 1") 39 | elif len(value.shape) > 1: 40 | value = value[0] 41 | 42 | if self.skip_prompt and self.next_tokens_are_prompt: 43 | self.next_tokens_are_prompt = False 44 | return 45 | 46 | for token_id in value.tolist(): 47 | self.token_queue.put(token_id, timeout=self.timeout) 48 | 49 | def end(self): 50 | self.next_tokens_are_prompt = True 51 | self.token_queue.put(self.stop_signal, timeout=self.timeout) 52 | 53 | def propagate_error(self, exc): 54 | self.token_queue.put(exc, timeout=self.timeout) 55 | 56 | def __iter__(self): 57 | return self 58 | 59 | def __next__(self): 60 | value = self.token_queue.get(timeout=self.timeout) 61 | if value == self.stop_signal: 62 | raise StopIteration() 63 | elif isinstance(value, BaseException): 64 | raise value 65 | else: 66 | return value 67 | 68 | class TextIteratorStreamer(streamers.TextIteratorStreamer): 69 | def propagate_error(self, exc): 70 | self.text_queue.put(exc, timeout=self.timeout) 71 | 72 | def __next__(self): 73 | value = self.text_queue.get(timeout=self.timeout) 74 | if value == self.stop_signal: 75 | raise StopIteration() 76 | elif isinstance(value, BaseException): 77 | raise value 78 | else: 79 | return value 80 | 81 | class StreamerList(list, streamers.BaseStreamer): 82 | """ 83 | Similar to StoppingCriteriaList, only for Streamers. 84 | """ 85 | def put(self, value): 86 | for streamer in self: 87 | streamer.put(value) 88 | 89 | def end(self): 90 | for streamer in self: 91 | streamer.end() 92 | 93 | def unwrap_processor(processor): 94 | """ 95 | Unwrap a processor, nested processors can happen when using the adapter 96 | processor. 
97 | """ 98 | if hasattr(processor, "processor"): 99 | return unwrap_processor(processor.processor) 100 | else: 101 | return processor 102 | -------------------------------------------------------------------------------- /detikzify/model/adapter/processing_adapter.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, TYPE_CHECKING, Union, Unpack 2 | 3 | from transformers.feature_extraction_utils import BatchFeature 4 | from transformers.image_utils import ImageInput, make_list_of_images 5 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin 6 | from transformers.tokenization_utils_base import ( 7 | BatchEncoding, 8 | PreTokenizedInput, 9 | TextInput, 10 | ) 11 | from transformers.utils import logging 12 | 13 | from ...util import DUMMY_IMAGE 14 | 15 | if TYPE_CHECKING: 16 | from transformers.tokenization_utils_base import PreTokenizedInput 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | class AdapterProcessor(ProcessorMixin): 21 | attributes = ["processor", "tokenizer"] 22 | processor_class = ("ProcessorMixin", "ImageProcessingMixin") 23 | tokenizer_class = "AutoTokenizer" 24 | 25 | def __init__(self, processor, tokenizer=None, **kwargs): 26 | if processor is None: 27 | raise ValueError("You need to specify a `processor`.") 28 | if tokenizer is None: 29 | raise ValueError("You need to specify a `tokenizer`.") 30 | super().__init__(processor, tokenizer, **kwargs) 31 | 32 | def __call__( 33 | self, 34 | text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, 35 | images: Optional[ImageInput] = None, 36 | **kwargs: Unpack[ProcessingKwargs], 37 | ) -> BatchEncoding: 38 | if images is None and text is None: 39 | raise ValueError("Either `images` or `text` (or both) are expected as arguments to an `AdapterProcessor` instance.") 40 | 41 | text_kwargs, images_kwargs = kwargs.pop("text_kwargs", {}), kwargs.pop("images_kwargs", {}) 42 | 43 | if text is None: 44 | text_inputs = dict() 45 | else: 46 | text = [text] if isinstance(text, str) else text 47 | text_inputs = {f"adapter_{key}": value for key, value in self.tokenizer(text=text, **kwargs, **text_kwargs).items()} 48 | if getattr(self.processor, "model_expects_text", False): 49 | images_kwargs.update(text=text, add_bos_token=True) 50 | if images is None: 51 | image_inputs = self.processor(images=len(text) * [DUMMY_IMAGE], **kwargs, **images_kwargs) 52 | image_inputs = dict((k, image_inputs[k]) for k in ["input_ids", "attention_mask"] if k in image_inputs) 53 | else: 54 | images = make_list_of_images(images) 55 | image_inputs = self.processor(images=images, **kwargs, **images_kwargs) 56 | 57 | if text is not None and images is not None and len(images) != len(text): 58 | raise ValueError( 59 | f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." 
60 | ) 61 | 62 | return BatchFeature(data={**image_inputs, **text_inputs}) 63 | 64 | def batch_decode(self, *args, **kwargs): 65 | return self.processor.batch_decode(*args, **kwargs) 66 | 67 | def decode(self, *args, **kwargs): 68 | return self.processor.decode(*args, **kwargs) 69 | 70 | @property 71 | def model_input_names(self): 72 | tokenizer_input_names = self.tokenizer.model_input_names 73 | processor_input_names = self.processor.model_input_names 74 | return list(dict.fromkeys(tokenizer_input_names + processor_input_names)) 75 | -------------------------------------------------------------------------------- /detikzify/mcts/montecarlo.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | 5 | class MonteCarlo: 6 | def __init__(self, root_node, mins_timeout=None): 7 | self.root_node = root_node 8 | self.solution = None 9 | self.child_finder = None 10 | self.node_evaluator = lambda child, montecarlo: None 11 | self.stats_expansion_count = 0 12 | self.stats_failed_expansion_count = 0 13 | self.mins_timeout = mins_timeout 14 | 15 | def make_choice(self): 16 | best_children = [] 17 | most_visits = float("-inf") 18 | 19 | for child in self.root_node.children: 20 | if child.visits > most_visits: 21 | most_visits = child.visits 22 | best_children = [child] 23 | elif child.visits == most_visits: 24 | best_children.append(child) 25 | 26 | return random.choice(best_children) 27 | 28 | def make_exploratory_choice(self): 29 | children_visits = map(lambda child: child.visits, self.root_node.children) 30 | children_visit_probabilities = [ 31 | visit / self.root_node.visits for visit in children_visits 32 | ] 33 | random_probability = random.uniform(0, 1) 34 | probabilities_already_counted = 0.0 35 | 36 | for i, probability in enumerate(children_visit_probabilities): 37 | if probabilities_already_counted + probability >= random_probability: 38 | return self.root_node.children[i] 39 | 40 | probabilities_already_counted += probability 41 | 42 | def simulate(self, expansion_count=1): 43 | i = 0 44 | 45 | start_time = time.time() 46 | 47 | while expansion_count is None or i < expansion_count: 48 | i += 1 49 | 50 | if self.solution is not None: 51 | return 52 | 53 | if self.mins_timeout is not None: 54 | curr_time = time.time() 55 | duration = curr_time - start_time 56 | 57 | if duration > (self.mins_timeout * 60): 58 | print("reached timelimit, stopping expansion on current node") 59 | return 60 | 61 | current_node = self.root_node 62 | 63 | while current_node.expanded: 64 | current_node = current_node.get_preferred_child(self.root_node) 65 | 66 | self.expand(current_node) 67 | 68 | def expand(self, node): 69 | self.stats_expansion_count += 1 70 | self.child_finder(node, self) 71 | 72 | for child in node.children: 73 | child_win_value = self.node_evaluator(child, self) 74 | 75 | if child_win_value != None: 76 | child.update_win_value(child_win_value) 77 | 78 | if not child.is_scorable(): 79 | self.random_rollout(child) 80 | child.children = [] 81 | 82 | if len(node.children): 83 | node.expanded = True 84 | else: 85 | self.stats_failed_expansion_count += 1 86 | 87 | def random_rollout(self, node): 88 | self.child_finder(node, self) 89 | child = random.choice(node.children) 90 | node.children = [] 91 | node.add_child(child) 92 | child_win_value = self.node_evaluator(child, self) 93 | 94 | if child_win_value != None: 95 | node.update_win_value(child_win_value) 96 | else: 97 | self.random_rollout(child) 98 | 99 | def print_tree(self, 
f): 100 | f.write("graph\n{\n") 101 | self.root_node.print_node(f, 0, self.root_node, "a") 102 | f.write("}\n") 103 | -------------------------------------------------------------------------------- /examples/tikzero/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from datetime import timedelta 4 | from os.path import basename, join 5 | 6 | from datasets import load_dataset 7 | from torch import distributed as dist 8 | from transformers import AutoTokenizer, set_seed 9 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info 10 | 11 | from detikzify.model import load 12 | from detikzify.model.adapter import AdapterProcessor 13 | from detikzify.train.adapter import CrossAttentionSiglipVisionModel, train 14 | 15 | 16 | def load_adapter(base_model, embedding_model, adapter_model): 17 | model, processor = load(base_model) 18 | vision_model = CrossAttentionSiglipVisionModel.from_pretrained( 19 | pretrained_model_name_or_path=None, 20 | config=model.config.vision_config, 21 | state_dict=model.model.vision_model.state_dict(), 22 | torch_dtype="bfloat16", 23 | ) 24 | del model 25 | 26 | vision_model.load_cross_attn_adapter(embedding_model, adapter_model) 27 | processor = AdapterProcessor( 28 | processor=processor.image_processor, 29 | tokenizer=AutoTokenizer.from_pretrained( 30 | embedding_model, 31 | pad_token="<|finetune_right_pad_id|>", 32 | model_max_length=512, 33 | ) 34 | ) 35 | vision_model.embedding_model.config.pad_token_id = processor.tokenizer.pad_token_id 36 | 37 | return vision_model, processor 38 | 39 | def parse_args(): 40 | argument_parser = ArgumentParser( 41 | description="Fine-tune a TikZero adapter end-to-end, optionally conditioned on captions." 42 | ) 43 | argument_parser.add_argument("--base_model", 44 | required=True, 45 | help="The DeTikZify model checkpoint for weights initialization." 46 | ) 47 | argument_parser.add_argument("--embedding_model", 48 | default="meta-llama/Llama-3.2-1B", 49 | help=( 50 | "The adapter embedding model checkpoint for weights initialization. " 51 | "Only LLaMA 3.1/3.2 models are officially supported." 52 | ) 53 | ) 54 | argument_parser.add_argument("--adapter_model", 55 | required=True, 56 | help= "The adapter model checkpoint obtained from the `pretrain.py` script." 
57 | ) 58 | argument_parser.add_argument("--datikz", 59 | default="nllg/datikz-v3", 60 | help="Path or name of the DaTikZ dataset.", 61 | ) 62 | argument_parser.add_argument("--caption_condition", 63 | action="store_true", 64 | help="whether to also condition the model on captions", 65 | ) 66 | argument_parser.add_argument("--output", 67 | required=True, 68 | help="directory where to write the model files", 69 | ) 70 | argument_parser.add_argument("--deepspeed", 71 | help="path to a DeepSpeed json config file", 72 | ) 73 | argument_parser.add_argument("--gradient_checkpointing", 74 | action="store_true", 75 | help="use gradient checkpointing", 76 | ) 77 | 78 | return argument_parser.parse_args() 79 | 80 | if __name__ == "__main__": 81 | set_verbosity_info() 82 | enable_explicit_format() 83 | dist.init_process_group(timeout=timedelta(days=3)) 84 | set_seed(0) 85 | 86 | args = parse_args() 87 | vision_model, processor = load_adapter(args.base_model, args.embedding_model, args.adapter_model) 88 | datikz = load_dataset(args.datikz, split="train") 89 | 90 | train( 91 | model=vision_model, 92 | processor=processor, 93 | dataset=datikz.filter(lambda ex: len(ex['caption']) > 0), 94 | caption_condition=args.caption_condition, 95 | output_dir=join(args.output, basename(args.base_model)), 96 | gradient_checkpointing=args.gradient_checkpointing, 97 | deepspeed=args.deepspeed, 98 | ) 99 | -------------------------------------------------------------------------------- /detikzify/train/pretrain.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | import os 4 | from typing import List 5 | 6 | from transformers import Trainer, TrainingArguments 7 | 8 | IGNORE_INDEX = -100 9 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) 10 | 11 | def tokenize( 12 | batch, 13 | processor, 14 | **kwargs 15 | ): 16 | image_token = processor.image_token 17 | image_token_id = processor.tokenizer.convert_tokens_to_ids(image_token) 18 | 19 | input_ids = processor( 20 | text=batch['text'], 21 | images=batch['image'], 22 | max_length=processor.tokenizer.model_max_length, 23 | pad_to_multiple_of=8, 24 | add_eos_token=True, 25 | **kwargs 26 | ) 27 | input_ids['labels'] = copy.deepcopy(input_ids['input_ids']) 28 | 29 | # do not train on image and pad tokens 30 | for label_ids in input_ids['labels']: 31 | for idx, label_id in enumerate(label_ids): 32 | if label_id in {image_token_id, processor.tokenizer.pad_token_id}: 33 | label_ids[idx] = IGNORE_INDEX 34 | 35 | return input_ids 36 | 37 | 38 | def train( 39 | output_dir: str, 40 | model, 41 | processor, 42 | dataset, 43 | deepspeed=None, 44 | # training hyperparams 45 | batch_size: int = 256, 46 | micro_batch_size: int = 1, 47 | num_epochs: int = 1, 48 | learning_rate: float = 1e-3, 49 | gradient_checkpointing: bool = False, 50 | full_finetune_modules: List[str] = [ 51 | "modality_projection", 52 | ], 53 | ): 54 | gradient_accumulation_steps = batch_size // micro_batch_size 55 | 56 | if WORLD_SIZE != 1: 57 | gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE 58 | for name, param in model.named_parameters(): 59 | if not any(module in name for module in full_finetune_modules): 60 | param.requires_grad = False 61 | 62 | dataset.set_transform(partial( 63 | tokenize, 64 | processor=processor, 65 | return_tensors="pt", 66 | truncation=True, 67 | padding=True 68 | )) 69 | 70 | trainer = Trainer( 71 | model=model, 72 | train_dataset=dataset, 73 |
args=TrainingArguments( 74 | per_device_train_batch_size=micro_batch_size, 75 | gradient_accumulation_steps=gradient_accumulation_steps, 76 | gradient_checkpointing=gradient_checkpointing, 77 | # https://github.com/huggingface/transformers/issues/21381 78 | gradient_checkpointing_kwargs={'use_reentrant':False}, 79 | dataloader_num_workers=WORLD_SIZE, 80 | warmup_ratio=0.03, 81 | weight_decay=0, 82 | num_train_epochs=num_epochs, 83 | learning_rate=learning_rate, 84 | torch_compile=True, 85 | bf16=True, 86 | tf32=True, 87 | logging_steps=10, 88 | lr_scheduler_type="cosine", 89 | optim="adamw_torch" if deepspeed else "adamw_torch_fused", 90 | ddp_find_unused_parameters=False, 91 | remove_unused_columns=False, 92 | save_strategy="no", 93 | report_to="none", 94 | output_dir=output_dir, 95 | deepspeed=deepspeed, 96 | ) 97 | ) 98 | 99 | if trainer.is_deepspeed_enabled and trainer.accelerator.state.deepspeed_plugin.hf_ds_config.is_zero3(): 100 | raise ValueError("Pretraining with zero stage 3 is not yet supported.") 101 | 102 | trainer.train() 103 | 104 | model.save_pretrained( 105 | output_dir, 106 | state_dict={ 107 | name: weight 108 | for name, weight in model.state_dict().items() 109 | if any(key_match in name for key_match in full_finetune_modules) 110 | }, 111 | ) 112 | trainer.save_state() 113 | 114 | return model, processor 115 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Project specific 163 | */_version.py 164 | -------------------------------------------------------------------------------- /detikzify/webui/README.md: -------------------------------------------------------------------------------- 1 | # Web UI 2 | The web UI of DeTi*k*Zify requires [TeX Live 3 | 2023](https://www.tug.org/texlive), [ghostscript](https://www.ghostscript.com), 4 | and [poppler](https://poppler.freedesktop.org). You can launch it by running 5 | `python -m detikzify.webui`. It comes with a command line interface. With the 6 | `--share` flag, for example, you can create a shareable link. Check out `--help` 7 | for a full list of supported options. As scientific figures usually use black 8 | fonts on a white background, it is best to use the web UI in light mode. This 9 | can be enforced by using the `--light` flag. If [FlashAttention]( 10 | https://huggingface.co/docs/transformers/en/perf_infer_gpu_one#flashattention) 11 | is installed, it is picked up automatically and should boost inference speeds. 12 | 13 | ## Usage Tips 14 | **Visual Prompting** Creating sketches for DeTi*k*Zify (or providing any input 15 | images) shares many similarities with the process of [prompting large language 16 | models](https://en.wikipedia.org/wiki/Prompt_engineering). If DeTi*k*Zify 17 | struggles to comprehend your intent, consider "rephrasing" your input. This 18 | could mean simplifying your sketches or focusing more on the key issue at hand.
19 | In 20 | [this](https://github.com/potamides/DeTikZify/assets/53401822/2819ebca-81f6-4173-8809-0b4255d3e976) 21 | particular instance, for example, we attempted to "prompt" DeTi*k*Zify to align 22 | characters diagonally around an equal sign, but it was unsuccessful even after 23 | many simulations. However, upon adjusting the input (by reducing the stroke 24 | width and using more easily recognizable characters), we achieved the [intended 25 | output](https://github.com/potamides/DeTikZify/assets/53401822/c8ecfbff-d22e-41d5-8f73-e0cfafe88690) 26 | after only one simulation. 27 | 28 | **Image Editor** You can draw sketches in the integrated image editor, but its 29 | feature set is quite limited. If you are not satisfied with the synthesized 30 | Ti*k*Z programs, try drawing more elaborate sketches in an editor of your 31 | choice (perhaps with graphics primitives) and upload them into the UI. 32 | Alternatively, experimenting with line thickness and/or colors in the 33 | integrated editor might also help. 34 | 35 | **Input Postprocessing** Please note that all input images are cropped to the 36 | smallest square around their content and then resized to the resolution 37 | DeTi*k*Zify expects. If you leave large margins, DeTi*k*Zify 38 | might perceive your input differently from how you intended (e.g., by drawing 39 | thicker axes). As a rule of thumb, always try to fill as much of the canvas as 40 | possible. 41 | 42 | **Input Complexity** If you provide very complex sketches (or figures) and are 43 | not satisfied with the results, you can also try segmenting (or simplifying) 44 | your input and letting DeTi*k*Zify synthesize the individual pieces 45 | independently. This has the advantage that the results will probably be better, 46 | and the disadvantage that you will have to modify and assemble the pieces 47 | yourself. 48 | 49 | **Source Code Artifacts** Due to the way we preprocess our 50 | [arXiv.org](https://arxiv.org) data, the preambles of the extracted Ti*k*Z 51 | programs sometimes include packages that are not used inside the `tikzpicture` 52 | environments, and the DeTi*k*Zify models pick up on this behavior. While this 53 | does not hinder compilation in any way, we still recommend checking the 54 | generated preambles and cleaning them up if necessary. 55 | 56 | **Accuracy-Efficiency Trade-Off** We noticed that lower values for temperature 57 | and top-p (nucleus sampling) force DeTi*k*Zify to generate Ti*k*Z programs that 58 | follow the input images more closely, at the expense of generating more 59 | compile-time errors. We pick sensible defaults that aim to balance these two 60 | aspects, but you might want to tune these parameters yourself. 61 | 62 | **External Graphics** In DaTi*k*Zv2, we replace any externally 63 | included graphics in the `tikzpicture` environments with the [example 64 | image](https://mirrors.ctan.org/macros/latex/contrib/mwe/example-image.pdf) 65 | placeholder from the [mwe](http://www.ctan.org/pkg/mwe) package. So if you want 66 | to generate code with placeholders for your own external graphics, just draw 67 | that example image.
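As a quick reference, the command line options mentioned above can be combined in a single invocation. The snippet below is only an illustrative example (assuming the package is installed in the current environment):

```sh
# launch the web UI in light mode and create a shareable link
python -m detikzify.webui --light --share
```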
68 | -------------------------------------------------------------------------------- /detikzify/webui/helpers.py: -------------------------------------------------------------------------------- 1 | from functools import cache, lru_cache 2 | from inspect import signature 3 | from operator import itemgetter 4 | from os import fdopen 5 | from tempfile import mkstemp 6 | 7 | import gradio as gr 8 | 9 | from ..infer import TikzDocument 10 | from ..model import load 11 | 12 | def to_svg( 13 | tikzdoc: TikzDocument, 14 | build_dir: str 15 | ): 16 | if not tikzdoc.is_rasterizable: 17 | if tikzdoc.compiled_with_errors: 18 | raise gr.Error("TikZ code did not compile!") 19 | else: 20 | gr.Warning("TikZ code compiled to an empty image!") 21 | elif tikzdoc.compiled_with_errors: 22 | gr.Warning("TikZ code compiled with errors!") 23 | 24 | fd, path = mkstemp(dir=build_dir, suffix=".svg") 25 | with fdopen(fd, "w") as f: 26 | if pdf:=tikzdoc.pdf: 27 | f.write(pdf[0].get_svg_image()) 28 | return path if pdf else None 29 | 30 | # https://stackoverflow.com/a/50992575 31 | def make_ordinal(n): 32 | n = int(n) 33 | if 11 <= (n % 100) <= 13: 34 | suffix = 'th' 35 | else: 36 | suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)] 37 | return str(n) + suffix 38 | 39 | class MctsOutputs(set): 40 | def __init__(self, build_dir, *args, **kwargs): 41 | super().__init__(*args, **kwargs) 42 | self.build_dir, self.svgmap, self.fails = build_dir, dict(), 0 43 | 44 | def add(self, score, tikzdoc): # type: ignore 45 | if (score, tikzdoc) not in self: 46 | try: 47 | svg = to_svg(tikzdoc, build_dir=self.build_dir) 48 | super().add((score, tikzdoc)) 49 | self.svgmap[tikzdoc] = svg 50 | except gr.Error: 51 | gr.Warning("TikZ code did not compile, discarding output!") 52 | if len(self): self.fails += 1 53 | elif len(self): self.fails += 1 54 | 55 | @property 56 | def programs(self): 57 | return [tikzdoc.code for _, tikzdoc in sorted(self, key=itemgetter(0), reverse=True)] 58 | 59 | @property 60 | def images(self): 61 | return [ 62 | (self.svgmap[tikzdoc], make_ordinal(idx)) 63 | for idx, (_, tikzdoc) in enumerate(sorted(self, key=itemgetter(0), reverse=True), 1) 64 | ] 65 | 66 | @property 67 | def first_success(self): 68 | return len(self) == 1 and not self.fails 69 | 70 | def make_light(stylable): 71 | """ 72 | Patch gradio to only contain light mode colors. 73 | """ 74 | if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme 75 | params = signature(stylable.set).parameters 76 | colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params} 77 | return stylable.set(**colors) 78 | elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals) 79 | stylable.load( 80 | fn=None, 81 | js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))" 82 | ) 83 | return stylable 84 | else: 85 | raise ValueError 86 | 87 | @lru_cache(maxsize=1) 88 | def cached_load(*args, **kwargs): 89 | gr.Info("Instantiating model. This could take a while...") 90 | return load(*args, **kwargs) 91 | 92 | @cache 93 | def info_once(message): 94 | gr.Info(message) 95 | 96 | class GeneratorLock: 97 | """ 98 | Ensure that only one instance of a given generator is active. 99 | Useful when a previous invocation was canceled. See 100 | https://github.com/gradio-app/gradio/issues/8503 for more information. 
101 | """ 102 | def __init__(self, gen_func): 103 | self.gen_func = gen_func 104 | self.generator = None 105 | 106 | def generate(self, *args, **kwargs): 107 | if self.generator: 108 | if self.generator.gi_running: 109 | return # somehow we can end up here 110 | self.generator.close() 111 | self.generator = self.gen_func(*args, **kwargs) 112 | yield from self.generator 113 | 114 | def __call__(self, *args, **kwargs): 115 | yield from self.generate(*args, **kwargs) 116 | -------------------------------------------------------------------------------- /detikzify/dataset/scicap/scicap.py: -------------------------------------------------------------------------------- 1 | """ 2 | The SciCap dataset, unified in a single train split. 3 | """ 4 | 5 | from json import load 6 | from os import symlink 7 | from os.path import basename, join 8 | from subprocess import run 9 | from tempfile import TemporaryDirectory 10 | from zipfile import ZipFile 11 | 12 | from datasets import Features, Image, Sequence, Value, builder 13 | from datasets.info import DatasetInfo 14 | from datasets.splits import Split, SplitGenerator 15 | from datasets.utils.hub import hf_hub_url 16 | 17 | from detikzify.util import convert, expand 18 | 19 | class SciCapConfig(builder.BuilderConfig): 20 | """BuilderConfig for SciCap.""" 21 | 22 | def __init__(self, size, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.repo_id = "CrowdAILab/scicap" 25 | self.size = size 26 | self.files = { 27 | "img": { 28 | (public:="img-split"): 10, 29 | (hidden:="img-hide_test"): 0 30 | }, 31 | "text": { 32 | "train": public, 33 | "train-acl": public, 34 | "val": public, 35 | "public-test": public, 36 | "hide_test": hidden, 37 | } 38 | } 39 | 40 | 41 | class SciCap(builder.GeneratorBasedBuilder): 42 | """The SciCap dataset in the format DeTikZify expects (everything is training data).""" 43 | 44 | BUILDER_CONFIG_CLASS = SciCapConfig 45 | 46 | def _info(self): 47 | features = { 48 | "caption": Value("string"), 49 | "mention": Sequence(Sequence(Value("string"))), 50 | "paragraph": Sequence(Value("string")), 51 | "ocr": Sequence(Value("string")), 52 | "image": Image(), 53 | } 54 | return DatasetInfo( 55 | description=str(__doc__), 56 | features=Features(features), 57 | ) 58 | def _split_generators(self, dl_manager): 59 | with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: 60 | def dl(path): 61 | return dl_manager.download(hf_hub_url(self.config.repo_id, path)) # type: ignore 62 | 63 | def zip_dl(path, num_splits=0): 64 | paths = [f"{path}.zip"] + list(f"{path}.z{{:02d}}".format(i+1) for i in range(num_splits)) 65 | downloaded = [dl(path) for path in paths] 66 | if num_splits: 67 | output = join(tmpdirname, f"{path}-joined.zip") 68 | for src, dst in zip(downloaded, paths): 69 | symlink(src, join(tmpdirname, dst)) # type: ignore 70 | run(["zip", "-FF", join(tmpdirname, paths[0]), "--out", output], check=True, capture_output=True) 71 | return output 72 | else: 73 | return downloaded[0] 74 | 75 | files_to_download = self.config.files # type: ignore 76 | img = {file:zip_dl(file, num_splits) for file, num_splits in files_to_download['img'].items()} 77 | text = {dl(f"{file}.json"):img[img_file] for file, img_file in files_to_download['text'].items()} 78 | 79 | yield SplitGenerator(name=str(Split.TRAIN), gen_kwargs={"shards": text}) 80 | 81 | def _generate_examples(self, shards): 82 | idx = 0 83 | for path, image_zip in shards.items(): 84 | with ZipFile(file=image_zip, mode='r') as zf: 85 | imagemap = {basename(name):name for 
name in zf.namelist()} 86 | with open(path) as f: 87 | images, annotations = load(f).values() 88 | for annotation, image in zip(annotations, images): 89 | assert image["id"] == annotation['image_id'] 90 | with zf.open(imagemap[image['file_name']]) as img: 91 | yield idx, dict( 92 | caption=annotation.get("caption_no_index"), 93 | mention=annotation.get("mention"), 94 | paragraph=annotation.get("paragraph"), 95 | ocr=image.get("ocr"), 96 | image=convert(expand(img, self.config.size), "png") 97 | ) 98 | idx += 1 99 | -------------------------------------------------------------------------------- /detikzify/evaluate/crystalbleu.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from functools import cached_property 3 | from hashlib import md5 4 | from itertools import chain, tee 5 | from pickle import dump, load 6 | from typing import List 7 | 8 | from crystalbleu import corpus_bleu 9 | from datasets.utils.logging import get_logger 10 | from huggingface_hub import cached_assets_path 11 | from pygments.lexers.markup import TexLexer 12 | from pygments.token import Comment, Name, Text 13 | from sacremoses import MosesTokenizer 14 | from torchmetrics import Metric 15 | 16 | logger = get_logger("datasets") 17 | 18 | # adopted from nltk 19 | def pad_sequence(sequence, n, pad_left=False, pad_right=False, left_pad_symbol=None, right_pad_symbol=None): 20 | sequence = iter(sequence) 21 | if pad_left: 22 | sequence = chain((left_pad_symbol,) * (n - 1), sequence) 23 | if pad_right: 24 | sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) 25 | return sequence 26 | 27 | # adopted from nltk 28 | def ngrams(sequence, n, **kwargs): 29 | sequence = pad_sequence(sequence, n, **kwargs) 30 | iterables = tee(sequence, n) 31 | 32 | for i, sub_iterable in enumerate(iterables): # For each window, 33 | for _ in range(i): # iterate through every order of ngrams 34 | next(sub_iterable, None) # generate the ngrams within the window. 35 | return zip(*iterables) # Unpack and flattens the iterables. 36 | 37 | class CrystalBLEU(Metric): 38 | """Wrapper around https://github.com/sola-st/crystalbleu (adapted for LaTeX)""" 39 | 40 | def __init__(self, corpus, k=500, n=4, use_cache=True, **kwargs): 41 | super().__init__(**kwargs) 42 | self.lexer = TexLexer() 43 | self.tokenizer = MosesTokenizer() 44 | self.use_cache = use_cache 45 | self.corpus = corpus 46 | self.k = k 47 | self.n = n 48 | 49 | self.add_state("list_of_references", [], dist_reduce_fx="cat") 50 | self.add_state("hypotheses", [], dist_reduce_fx="cat") 51 | 52 | def __str__(self): 53 | return self.__class__.__name__ 54 | 55 | @cached_property 56 | def trivially_shared_ngrams(self): 57 | """ 58 | Computes trivially shared ngrams and caches them. 
59 | """ 60 | cache_dir = cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower()) 61 | dhash = md5() 62 | dhash.update(str(sorted(self.corpus)).encode()) 63 | hashname = f"{dhash.hexdigest()}.pkl" 64 | 65 | if (cache_file:=(cache_dir / hashname)).is_file() and self.use_cache: 66 | logger.info(f"Found cached trivially shared ngrams ({cache_file})") 67 | with open(cache_file, "rb") as f: 68 | return load(f) 69 | else: 70 | all_ngrams = list() 71 | for o in range(1, self.n+1): 72 | for tex in self.corpus: 73 | all_ngrams.extend(ngrams(self._tokenize(tex), o)) 74 | frequencies = Counter(all_ngrams) 75 | 76 | trivially_shared_ngrams = dict(frequencies.most_common(self.k)) 77 | if self.use_cache: 78 | logger.info(f"Caching trivially shared ngrams ({cache_file})") 79 | with open(cache_file, "wb") as f: 80 | dump(trivially_shared_ngrams, f) 81 | return trivially_shared_ngrams 82 | 83 | def _tokenize(self, text): 84 | tokens = list() 85 | for tokentype, value in self.lexer.get_tokens(text): 86 | if value.strip() and not tokentype is Comment: 87 | if any(tokentype is tp for tp in [Text, Name.Attribute, Name.Builtin]): 88 | tokens.extend(self.tokenizer.tokenize(value.strip())) 89 | else: 90 | tokens.append(value.strip()) 91 | return tokens 92 | 93 | def update( 94 | self, 95 | list_of_references: List[List[str]], 96 | hypotheses: List[str], 97 | ): 98 | assert len(list_of_references) == len(hypotheses) 99 | self.list_of_references.extend([self._tokenize(ref) for ref in refs] for refs in list_of_references) 100 | self.hypotheses.extend(self._tokenize(hyp) for hyp in hypotheses) 101 | 102 | def compute(self): 103 | return corpus_bleu( 104 | list_of_references=self.list_of_references, 105 | hypotheses=self.hypotheses, 106 | ignoring=self.trivially_shared_ngrams 107 | ) 108 | -------------------------------------------------------------------------------- /detikzify/model/configuration_detikzify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # 14 | # Adapted from 15 | # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be 16 | 17 | import os 18 | from typing import Union 19 | 20 | from transformers import CONFIG_MAPPING 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class DetikzifyVisionConfig(PretrainedConfig): 29 | model_type = "detikzify" 30 | 31 | def __init__( 32 | self, 33 | hidden_size=1152, 34 | intermediate_size=4304, 35 | num_hidden_layers=27, 36 | num_attention_heads=16, 37 | num_channels=3, 38 | image_size=420, 39 | patch_size=14, 40 | hidden_act="gelu_pytorch_tanh", 41 | layer_norm_eps=1e-6, 42 | attention_dropout=0.0, 43 | initializer_range=0.02, 44 | **kwargs, 45 | ): 46 | super().__init__(**kwargs) 47 | 48 | self.hidden_size = hidden_size 49 | self.intermediate_size = intermediate_size 50 | self.num_hidden_layers = num_hidden_layers 51 | self.num_attention_heads = num_attention_heads 52 | self.num_channels = num_channels 53 | self.patch_size = patch_size 54 | self.image_size = image_size 55 | self.attention_dropout = attention_dropout 56 | self.layer_norm_eps = layer_norm_eps 57 | self.hidden_act = hidden_act 58 | self.initializer_range = initializer_range 59 | 60 | @classmethod 61 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 62 | cls._set_token_in_kwargs(kwargs) 63 | 64 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 65 | 66 | # get the vision config dict if we are loading from DetikzifyConfig 67 | if config_dict.get("model_type") == "detikzify": 68 | config_dict = config_dict["vision_config"] 69 | 70 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 71 | logger.warning( 72 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 73 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
74 | ) 75 | 76 | return cls.from_dict(config_dict, **kwargs) 77 | 78 | 79 | class DetikzifyConfig(PretrainedConfig): 80 | model_type = "detikzify" 81 | is_composition = True 82 | 83 | def __init__( 84 | self, 85 | use_cache=True, 86 | image_token_id=128005, 87 | tie_word_embeddings=False, 88 | vision_config=None, 89 | text_config=None, 90 | concat_factor=3, 91 | pad_token_id=128004, 92 | **kwargs, 93 | ): 94 | self.image_token_id = image_token_id 95 | self.use_cache = use_cache 96 | self.tie_word_embeddings = tie_word_embeddings 97 | 98 | if vision_config is None: 99 | self.vision_config = DetikzifyVisionConfig() 100 | logger.info("vision_config is None, using default vision config") 101 | elif isinstance(vision_config, dict): 102 | self.vision_config = DetikzifyVisionConfig(**vision_config) 103 | elif isinstance(vision_config, DetikzifyVisionConfig): 104 | self.vision_config = vision_config 105 | 106 | if isinstance(text_config, dict): 107 | text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" 108 | text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) 109 | elif text_config is None: 110 | logger.info("text_config is None, using default text config") 111 | text_config = CONFIG_MAPPING["llama"]( 112 | rms_norm_eps=1e-5, 113 | pad_token_id=pad_token_id, 114 | tie_word_embeddings=False, 115 | ) 116 | 117 | self.text_config = text_config 118 | self.concat_factor = concat_factor 119 | 120 | super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) 121 | -------------------------------------------------------------------------------- /detikzify/util/trainer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from numpy import arange 4 | import torch 5 | from torchvision import tv_tensors 6 | from torchvision.transforms import v2 7 | from transformers import ( 8 | IntervalStrategy, 9 | TrainerCallback, 10 | TrainerControl, 11 | TrainerState, 12 | TrainingArguments, 13 | ) 14 | from transformers.trainer_utils import has_length 15 | from torchvision.transforms.v2._utils import query_size 16 | 17 | class SplitEpochSaveCallback(TrainerCallback): 18 | """ 19 | If save_strategy==EPOCH also save checkpoints at arbitrary fractions of an 20 | epoch (controlled by step_size). 21 | """ 22 | 23 | def __init__(self, step_size: float = 0.5): 24 | self.steps = arange(step_size, 1, step_size) 25 | 26 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 27 | if has_length(train_dataloader:=kwargs['train_dataloader']): 28 | self.num_update_steps_per_epoch = max(len(train_dataloader) // args.gradient_accumulation_steps, 1) 29 | else: 30 | self.num_update_steps_per_epoch = args.max_steps 31 | 32 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): # type: ignore 33 | steps = [round(self.num_update_steps_per_epoch * step) for step in self.steps] 34 | if ( 35 | state.global_step % self.num_update_steps_per_epoch in steps 36 | and args.save_strategy == IntervalStrategy.EPOCH 37 | ): 38 | control.should_save = True 39 | 40 | return control 41 | 42 | class SketchAugment(v2.Compose): 43 | def __init__(self, intensity=1): 44 | super().__init__([ 45 | v2.RandomOrder([ 46 | v2.ElasticTransform(alpha=50. 
* intensity, fill=255), 47 | v2.JPEG((40 * intensity, 100)), 48 | v2.ColorJitter(brightness=(.75 + .25 * intensity, 1.75)), 49 | v2.RandomEqualize(), 50 | v2.RandomGrayscale() 51 | ]), 52 | v2.RGB() 53 | ]) 54 | 55 | class FullErase(v2.Lambda): 56 | def __init__(self, value=255): 57 | super().__init__(partial(v2.functional.erase, i=0, j=0, h=-1, w=-1, v=torch.tensor(value))) 58 | 59 | class EditBase(v2.Transform): 60 | def __init__(self, *, alpha: float = 1.0) -> None: 61 | super().__init__() 62 | self.alpha = float(alpha) 63 | self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) 64 | 65 | def _get_boxes(self, flat_inputs): 66 | lam = self._dist.sample((len(flat_inputs),)).squeeze() # type: ignore 67 | 68 | H, W = query_size(flat_inputs) 69 | 70 | r_x = torch.randint(W, size=(len(flat_inputs),)) 71 | r_y = torch.randint(H, size=(len(flat_inputs),)) 72 | 73 | r = 0.5 * torch.sqrt(1.0 - lam) 74 | r_w_half = (r * W).int() 75 | r_h_half = (r * H).int() 76 | 77 | x1 = torch.clamp(r_x - r_w_half, min=0) 78 | y1 = torch.clamp(r_y - r_h_half, min=0) 79 | x2 = torch.clamp(r_x + r_w_half, max=W) 80 | y2 = torch.clamp(r_y + r_h_half, max=H) 81 | 82 | grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="ij") 83 | 84 | grid_x = grid_x.unsqueeze(0).expand(len(flat_inputs), -1, -1) 85 | grid_y = grid_y.unsqueeze(0).expand(len(flat_inputs), -1, -1) 86 | 87 | mask = (grid_x >= x1.unsqueeze(1).unsqueeze(2)) & (grid_x < x2.unsqueeze(1).unsqueeze(2)) & \ 88 | (grid_y >= y1.unsqueeze(1).unsqueeze(2)) & (grid_y < y2.unsqueeze(1).unsqueeze(2)) 89 | 90 | return mask.unsqueeze(1).expand(-1, 3, -1, -1) 91 | 92 | class EditCutMix(EditBase): 93 | def _transform(self, inpt, params): 94 | output = inpt.clone() 95 | rolled = inpt.roll(1, 0) 96 | box = self._get_boxes(inpt) 97 | output[box] = rolled[box] 98 | 99 | if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): 100 | output = tv_tensors.wrap(output, like=inpt) 101 | 102 | return output 103 | 104 | class EditMixUp(EditBase): 105 | def _transform(self, inpt, params): 106 | lam = self._dist.sample((len(inpt),)).view(-1, *([1] * len(inpt.shape[1:]))) # type: ignore 107 | output = inpt.roll(1, 0).mul(1.0 - lam).add_(inpt.mul(lam)).to(inpt.dtype) 108 | 109 | if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): 110 | output = tv_tensors.wrap(output, like=inpt) 111 | 112 | return output 113 | 114 | class EditCutOut(EditBase): 115 | def __init__(self, *args, value=255, **kwargs): 116 | self.value = value 117 | super().__init__(*args, **kwargs) 118 | 119 | def _transform(self, inpt, params): 120 | output = inpt.clone() 121 | box = self._get_boxes(inpt) 122 | output[box] = self.value 123 | 124 | if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): 125 | output = tv_tensors.wrap(output, like=inpt) 126 | 127 | return output 128 | -------------------------------------------------------------------------------- /examples/tikzero/pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from datetime import timedelta 4 | from itertools import chain 5 | from os import sched_getaffinity 6 | from os.path import basename, join 7 | 8 | from accelerate import Accelerator 9 | from datasets import Dataset 10 | from datasets import concatenate_datasets, load_dataset 11 | from torch import distributed as dist 12 | from transformers import AutoTokenizer, set_seed 13 | from 
transformers.utils.logging import enable_explicit_format, set_verbosity_info 14 | 15 | from detikzify.model import load 16 | from detikzify.model.adapter import AdapterProcessor 17 | from detikzify.train.adapter import CrossAttentionSiglipVisionModel, pretrain 18 | from detikzify.util import batchify, convert, expand 19 | 20 | @batchify 21 | def process_arxivcap(batch, size): 22 | """Concatenate captions and OCR tokens.""" 23 | for caption_images in chain.from_iterable(batch['caption_images']): 24 | caption = caption_images['caption'] 25 | for cil_pair in caption_images['cil_pairs']: 26 | sub_caption = cil_pair['sub_caption'] 27 | if text:=" ".join(filter(None, [caption, sub_caption])): 28 | yield dict( 29 | text=text, 30 | image=convert(expand(cil_pair['image'], size, do_trim=True), "png") 31 | ) 32 | 33 | def process_openmoji(ex, size): 34 | ex['image'] = convert(expand(ex['image'], size, do_trim=True), "png") 35 | return ex 36 | 37 | def init_adapter(base_model, embedding_model): 38 | model, processor = load(base_model) 39 | 40 | vision_model = CrossAttentionSiglipVisionModel.from_pretrained( 41 | pretrained_model_name_or_path=None, 42 | config=model.config.vision_config, 43 | state_dict=model.model.vision_model.state_dict(), 44 | torch_dtype="bfloat16", 45 | ) 46 | del model 47 | 48 | vision_model.init_cross_attn_adapter(embedding_model) 49 | processor = AdapterProcessor( 50 | processor=processor.image_processor, 51 | tokenizer=AutoTokenizer.from_pretrained( 52 | embedding_model, 53 | pad_token="<|finetune_right_pad_id|>", 54 | model_max_length=512, 55 | ) 56 | ) 57 | vision_model.embedding_model.config.pad_token_id = processor.tokenizer.pad_token_id 58 | 59 | return vision_model, processor 60 | 61 | def parse_args(): 62 | argument_parser = ArgumentParser( 63 | description="Pre-train a TikZero adapter on ArxivCap." 64 | ) 65 | argument_parser.add_argument("--base_model", 66 | required=True, 67 | help="The DeTikZify model checkpoint for weights initialization." 68 | ) 69 | argument_parser.add_argument("--embedding_model", 70 | default="meta-llama/Llama-3.2-1B", 71 | help=( 72 | "The adapter embedding model checkpoint for weights initialization. " 73 | "Only LLaMA 3.1/3.2 models are officially supported." 
74 | ) 75 | ) 76 | argument_parser.add_argument("--output", 77 | required=True, 78 | help="directory where to write the model files", 79 | ) 80 | argument_parser.add_argument("--deepspeed", 81 | help="path to a DeepSpeed json config file", 82 | ) 83 | argument_parser.add_argument("--gradient_checkpointing", 84 | action="store_true", 85 | help="use gradient checkpointing", 86 | ) 87 | argument_parser.add_argument("--mse_loss", 88 | action="store_true", 89 | help="train using mse loss instead of cosine similarity", 90 | ) 91 | 92 | return argument_parser.parse_args() 93 | 94 | if __name__ == "__main__": 95 | set_verbosity_info() 96 | enable_explicit_format() 97 | dist.init_process_group(timeout=timedelta(days=3)) 98 | set_seed(0) 99 | 100 | args = parse_args() 101 | vision_model, processor = init_adapter(args.base_model, args.embedding_model) 102 | 103 | with Accelerator().main_process_first(): 104 | arxivcap: Dataset = load_dataset("MMInstruction/ArxivCap", split="train") # type: ignore 105 | openmoji: Dataset = load_dataset("soypablo/Emoji_Dataset-Openmoji", split="train") # type: ignore 106 | arxivcap = arxivcap.map( 107 | process_arxivcap, 108 | batched=True, 109 | remove_columns=arxivcap.column_names, 110 | batch_size=100, 111 | fn_kwargs=dict(size=vision_model.config.image_size), 112 | num_proc=len(sched_getaffinity(0)) 113 | ) 114 | openmoji = openmoji.map( 115 | process_openmoji, 116 | fn_kwargs=dict(size=vision_model.config.image_size), 117 | ) 118 | 119 | pretrain( 120 | model=vision_model, 121 | processor=processor, 122 | dataset=concatenate_datasets([arxivcap, openmoji]), 123 | output_dir=join(args.output, basename(args.base_model)), 124 | gradient_checkpointing=args.gradient_checkpointing, 125 | deepspeed=args.deepspeed, 126 | mse_loss=args.mse_loss, 127 | ) 128 | -------------------------------------------------------------------------------- /examples/sketchify.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | 3 | from argparse import ArgumentParser 4 | from functools import cached_property 5 | from itertools import chain 6 | from math import ceil, floor 7 | from os import environ 8 | from random import choice, gauss, random, sample 9 | 10 | from PIL import Image 11 | from datasets import load_dataset 12 | from diffusers import DiffusionPipeline 13 | import torch 14 | 15 | from detikzify.util import convert 16 | 17 | # performance optimizations: https://huggingface.co/blog/sd3 18 | torch.set_float32_matmul_precision("high") 19 | torch._inductor.config.conv_1x1_as_mm = True 20 | torch._inductor.config.coordinate_descent_tuning = True 21 | torch._inductor.config.epilogue_fusion = False 22 | torch._inductor.config.coordinate_descent_check_all_directions = True 23 | 24 | WORLD_SIZE = int(environ.get("WORLD_SIZE", 1)) 25 | RANK = int(environ.get("RANK", 0)) 26 | 27 | class Sketchifier: 28 | def __init__( 29 | self, 30 | model="nllg/sketch-pix2pix", 31 | device=torch.device("cuda", RANK), 32 | grayscale_ratio=0.1, 33 | ): 34 | self.model = model 35 | self.device = torch.device(device) 36 | self.grayscale_ratio = grayscale_ratio 37 | 38 | @cached_property 39 | def pipe(self): 40 | pipe = DiffusionPipeline.from_pretrained( 41 | pretrained_model_name_or_path="nllg/ultrasketch", 42 | custom_pipeline="nllg/ultrasketch", 43 | trust_remote_code=True, 44 | torch_dtype=torch.float16, 45 | ) 46 | pipe.set_progress_bar_config(disable=True) 47 | 48 | # speed up inference 49 | pipe.to(self.device) 50 | 
pipe.transformer.to(memory_format=torch.channels_last) 51 | pipe.vae.to(memory_format=torch.channels_last) 52 | 53 | pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) 54 | pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True) 55 | 56 | return pipe 57 | 58 | def __call__(self, *args, **kwargs): 59 | return self.sketchify(*args, **kwargs) 60 | 61 | def sketchify(self, image): 62 | with torch.inference_mode(), torch.autocast(self.device.type, enabled=False): # type: ignore 63 | sketch = self.pipe( 64 | prompt="Turn it into a hand-drawn sketch", 65 | image=image, 66 | mask_img=Image.new("RGB", image.size, "white"), 67 | num_inference_steps=50, 68 | image_guidance_scale=1.7, 69 | guidance_scale=1.5, 70 | strength=max(.85, min(.95, gauss(.9, .5))) 71 | ).images[0] 72 | sketch = sketch if random() > self.grayscale_ratio else sketch.convert("L") 73 | return convert(sketch, "png") 74 | 75 | def sketchify(dataset, num_epochs, ratio, sketchifier): 76 | """ 77 | Randomly sketchify `ratio` of all examples in `dataset` for each 78 | of the `num_epochs` epochs. 79 | """ 80 | # prepare the sketches (distribute load among all workers) 81 | worker_sketches, all_sketches = list(), WORLD_SIZE * [None] 82 | for i in torch.arange(len(dataset['image'])).tensor_split(WORLD_SIZE)[RANK]: 83 | # randomize in which epochs how many images should be sketchified 84 | num_sketches = choice([floor(product:=ratio*num_epochs), ceil(product)]) 85 | sketch_epochs = sample(range(num_epochs), k=num_sketches) 86 | # generate the sketches 87 | sketches = [sketchifier(dataset['image'][i.item()]) for _ in range(num_sketches)] 88 | worker_sketches.append([sketches.pop() if epoch in sketch_epochs else None for epoch in range(num_epochs)]) 89 | 90 | torch.distributed.all_gather_object(all_sketches, worker_sketches) # type: ignore 91 | dataset['sketches'] = list(chain.from_iterable(all_sketches)) # type: ignore 92 | return dataset 93 | 94 | def parse_args(): 95 | argument_parser = ArgumentParser( 96 | description="Sketchify an existing DaTikZ dataset." 97 | ) 98 | argument_parser.add_argument( 99 | "--path", 100 | default="nllg/datikz-v3", 101 | help="Path or name of the DaTikZ dataset.", 102 | ) 103 | return argument_parser.parse_args() 104 | 105 | if __name__ == "__main__": 106 | args = parse_args() 107 | datikz = load_dataset(args.path) 108 | torch.distributed.init_process_group() 109 | 110 | train = datikz['train'].map( 111 | function=sketchify, 112 | batched=True, 113 | batch_size=WORLD_SIZE * 1000, 114 | desc="Sketchify (train)", 115 | fn_kwargs=dict( 116 | num_epochs=5, 117 | ratio=0.5, 118 | sketchifier=(sketchifier:=Sketchifier()) 119 | ) 120 | ) 121 | 122 | # test split is small so keep it simple and do it only on the main process 123 | if RANK == 0: 124 | test = datikz['test'].map( 125 | function=lambda ex: ex | {"sketch": sketchifier(ex['image'])}, 126 | desc="Sketchify (test)" 127 | ) 128 | 129 | train.to_parquet("datikz-train-sketches.parquet", compression="GZIP") 130 | test.to_parquet("datikz-test-sketches.parquet", compression="GZIP") 131 | -------------------------------------------------------------------------------- /detikzify/model/processing_detikzify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Adapted from 16 | # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be 17 | 18 | from typing import List, Optional, Union, Unpack 19 | 20 | from transformers.feature_extraction_utils import BatchFeature 21 | from transformers.image_utils import ImageInput, make_list_of_images 22 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin 23 | from transformers.tokenization_utils_base import ( 24 | BatchEncoding, 25 | PreTokenizedInput, 26 | TextInput, 27 | ) 28 | from transformers.utils import logging 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | class DetikzifyProcessorKwargs(ProcessingKwargs, total=False): 34 | _defaults = { 35 | "text_kwargs": { 36 | "add_special_tokens": False, 37 | "padding": False, 38 | }, 39 | } 40 | 41 | 42 | class DetikzifyProcessor(ProcessorMixin): 43 | attributes = ["image_processor", "tokenizer"] 44 | image_processor_class = "AutoImageProcessor" 45 | tokenizer_class = "AutoTokenizer" 46 | 47 | def __init__( 48 | self, 49 | image_processor, 50 | tokenizer=None, 51 | image_seq_len: int = 300, 52 | image_token: str = "<|reserved_special_token_2|>", 53 | model_expects_text: bool = False, 54 | **kwargs, 55 | ): 56 | if image_processor is None: 57 | raise ValueError("You need to specify an `image_processor`.") 58 | if tokenizer is None: 59 | raise ValueError("You need to specify a `tokenizer`.") 60 | if image_token not in tokenizer.vocab: 61 | raise ValueError(f"{image_token} needs to be added to the `tokenizer` vocabulary.") 62 | 63 | self.image_token = image_token 64 | self.image_seq_len = image_seq_len 65 | self.model_expects_text = model_expects_text 66 | 67 | super().__init__(image_processor, tokenizer, **kwargs) 68 | 69 | def __call__( 70 | self, 71 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 72 | images: ImageInput = None, 73 | image_seq_len: Optional[int] = None, 74 | add_bos_token: bool = None, 75 | add_eos_token: bool = None, 76 | **kwargs: Unpack[DetikzifyProcessorKwargs], 77 | ) -> BatchEncoding: 78 | output_kwargs = self._merge_kwargs( 79 | DetikzifyProcessorKwargs, 80 | tokenizer_init_kwargs=self.tokenizer.init_kwargs, 81 | **kwargs, 82 | ) 83 | # Temporary fix for "padding_side" in init_kwargs 84 | output_kwargs["text_kwargs"].pop("padding_side", None) 85 | 86 | if images is None: 87 | raise ValueError("`images` are expected as arguments to a `DetikzifyProcessor` instance.") 88 | else: 89 | if isinstance(images, list) and all(isinstance(img, list) and len(img) == 1 for img in images): 90 | # compatibility with trl 91 | images = [img[0] for img in images] 92 | images = make_list_of_images(images) 93 | if text is None: 94 | text = len(images) * [""] 95 | elif isinstance(text, str): 96 | text = [text] 97 | if len(images) != len(text): 98 | raise ValueError( 99 | f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." 
100 | ) 101 | 102 | prompt_strings = [] 103 | for prompt in text: 104 | assert self.image_token not in prompt, "Image tokens are added by the processor!" 105 | if add_bos_token: 106 | prompt += self.tokenizer.bos_token 107 | if add_eos_token: 108 | prompt += self.tokenizer.eos_token 109 | image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len 110 | prompt_strings.append((self.image_token * image_seq_len) + prompt) 111 | 112 | image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) 113 | text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) 114 | 115 | return BatchFeature(data={**image_inputs, **text_inputs}) 116 | 117 | def batch_decode(self, *args, **kwargs): 118 | return self.tokenizer.batch_decode(*args, **kwargs) 119 | 120 | def decode(self, *args, **kwargs): 121 | return self.tokenizer.decode(*args, **kwargs) 122 | 123 | @property 124 | def model_input_names(self): 125 | tokenizer_input_names = self.tokenizer.model_input_names 126 | image_processor_input_names = self.image_processor.model_input_names 127 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 128 | -------------------------------------------------------------------------------- /detikzify/train/train.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import os 3 | from random import random 4 | from typing import Dict, List 5 | 6 | from PIL import Image 7 | import torch 8 | from torch.utils.data import Dataset 9 | from transformers import Trainer, TrainerCallback, TrainingArguments 10 | from transformers.trainer_utils import get_last_checkpoint 11 | from transformers.utils import logging 12 | 13 | from ..util import SketchAugment, SplitEpochSaveCallback 14 | from .pretrain import tokenize 15 | 16 | logger = logging.get_logger("transformers") 17 | 18 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) 19 | RANK = int(os.environ.get("RANK", 0)) 20 | 21 | class ImageSketchDataset(Dataset, TrainerCallback): 22 | """ 23 | Dataset which samples sketches instead of images, when a sketch exists 24 | for the current epoch. 
25 | """ 26 | def __init__(self, dataset, processor, ds_sketch_ratio=.5): 27 | super().__init__() 28 | self.processor = processor 29 | self.dataset = dataset.with_transform(self.tokenize) 30 | self.ds_sketch_ratio = ds_sketch_ratio 31 | self.sketchify = SketchAugment() 32 | self.cur_epoch = 0 33 | 34 | def __len__(self): 35 | return len(self.dataset) 36 | 37 | def tokenize(self, batch): 38 | for idx, sketches in enumerate(batch['sketches']): 39 | if (sketch:=sketches[self.cur_epoch]): 40 | if random() >= self.ds_sketch_ratio: 41 | batch['image'][idx] = Image.open(BytesIO(sketch['bytes'])).convert("RGB") 42 | else: 43 | batch['image'][idx] = self.sketchify(batch['image'][idx]) 44 | 45 | return tokenize( 46 | batch=batch, 47 | processor=self.processor, 48 | return_tensors="pt", 49 | truncation=False, 50 | padding=True 51 | ) 52 | 53 | def filter(self, *args, **kwargs): 54 | self.dataset = self.dataset.filter(*args, **kwargs) 55 | 56 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 57 | return self.dataset[index] 58 | 59 | def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]: 60 | return self.dataset[*indices] 61 | 62 | def on_epoch_end(self, *args, **kwargs): 63 | self.cur_epoch += 1 64 | 65 | def train( 66 | output_dir: str, 67 | model, 68 | processor, 69 | dataset, 70 | overwrite=False, 71 | deepspeed=None, 72 | # training hyperparams 73 | batch_size: int = 128, 74 | micro_batch_size: int = 1, 75 | num_epochs: int = 5, 76 | learning_rate: float = 5e-5, 77 | sketch_ratio=.5, 78 | gradient_checkpointing: bool = False, 79 | ): 80 | assert num_epochs <= len(dataset[0]['sketches']) 81 | gradient_accumulation_steps = batch_size // micro_batch_size 82 | if WORLD_SIZE != 1: 83 | gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE 84 | 85 | dataset = ImageSketchDataset(dataset, processor, ds_sketch_ratio=sketch_ratio) 86 | logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}") 87 | eos_token_id, model_max_length = processor.tokenizer.eos_token_id, processor.tokenizer.model_max_length 88 | dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length) 89 | logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}") 90 | 91 | last_checkpoint = None 92 | if os.path.isdir(output_dir) and not overwrite: 93 | last_checkpoint = get_last_checkpoint(output_dir) 94 | if last_checkpoint is None and len(os.listdir(output_dir)) > 0: 95 | raise ValueError( 96 | f"Output directory ({output_dir}) already exists and is not empty. " 97 | "Use `overwrite` to overcome." 98 | ) 99 | elif last_checkpoint is not None: 100 | logger.info( 101 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 102 | "the `output_dir` or add `overwrite` to train from scratch." 
103 | ) 104 | 105 | trainer = Trainer( 106 | model=model, 107 | train_dataset=dataset, 108 | args=TrainingArguments( 109 | per_device_train_batch_size=micro_batch_size, 110 | gradient_accumulation_steps=gradient_accumulation_steps, 111 | gradient_checkpointing=gradient_checkpointing, 112 | # https://github.com/huggingface/transformers/issues/32576 113 | gradient_checkpointing_kwargs={'use_reentrant':False}, 114 | dataloader_num_workers=WORLD_SIZE, 115 | warmup_ratio=0.03, 116 | weight_decay=0, 117 | num_train_epochs=num_epochs, 118 | learning_rate=learning_rate, 119 | torch_compile=True, 120 | bf16=True, 121 | tf32=True, 122 | logging_steps=10, 123 | lr_scheduler_type="cosine", 124 | optim="adamw_torch" if deepspeed else "adamw_torch_fused", 125 | ddp_find_unused_parameters=False, 126 | remove_unused_columns=False, 127 | save_strategy="epoch", 128 | report_to="none", 129 | save_total_limit=1, 130 | output_dir=output_dir, 131 | deepspeed=deepspeed, 132 | ), 133 | callbacks=[SplitEpochSaveCallback(step_size=0.25)], 134 | data_collator=lambda batch: batch 135 | ) 136 | 137 | trainer.add_callback(trainer.train_dataset) 138 | trainer.train(resume_from_checkpoint=last_checkpoint) 139 | 140 | if trainer.is_deepspeed_enabled: 141 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading 142 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint 143 | last_checkpoint = get_last_checkpoint(output_dir) 144 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) 145 | 146 | trainer.save_model(output_dir) 147 | trainer.save_state() 148 | 149 | return model, processor 150 | -------------------------------------------------------------------------------- /detikzify/mcts/README.md: -------------------------------------------------------------------------------- 1 | A Python3 library for running a [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), either traditionally by drilling down to end game states or with expert policies as might be provided by a neural network. 2 | 3 | Adapted from **Version:** 1.3.1 of [ImparaAI/monte-carlo-tree-search](https://github.com/ImparaAI/monte-carlo-tree-search). 4 | 5 | # Monte Carlo tree search basics 6 | 7 | The Monte Carlo tree search (MCTS) algorithm can help with making a decision from a number of options. It avoids exploring every possible option by randomly sampling a small number of pathways and picking the move with the highest probability of victory. This is commonly applied to games like chess or go where it's useful to know what move should come next if you want to win the game. 8 | 9 | MCTS works by expanding the search tree to figure out which moves (or child/subsequent states) are likely to produce a positive result if chosen. While time is available, the algorithm continues to explore the tree, always slightly favoring the direction that has either proven to be fruitful or is less explored. When no time is left, the most explored direction is chosen. 10 | 11 | The search tree expansion can be done in two different ways: 12 | 13 | - **Traditional**: At least one random rollout to a game's end state (e.g. win, loss, tie) for each move under evaluation so the algorithm can make a choice. 14 | - **Expert policy (i.e. neural network)**: Instead of expensively rolling all the way out to a game's end state ask an expert (a neural network for example) which move is most likely to produce a positive outcome. 
15 | 16 | For a deeper dive into the topic, check out [this article](http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/). 17 | 18 | # This library 19 | 20 | As the user of this library, you only have to provide: 21 | 22 | - A function that finds the direct children of each search tree node (called the **`child_finder`**) 23 | - A function for evaluating nodes for end state outcomes (called the **`node_evaluator`**) 24 | -- *(Not necessary with neural network)* 25 | 26 | # Usage 27 | 28 | Create a new Monte Carlo tree: 29 | 30 | ```python 31 | from chess import Game 32 | from montecarlo.node import Node 33 | from montecarlo.montecarlo import MonteCarlo 34 | 35 | chess_game = Game() 36 | montecarlo = MonteCarlo(Node(chess_game)) 37 | ``` 38 | 39 | The root node describes your current game state. This state will be used by you later in the **`child_finder`** and the **`node_evaluator`**. 40 | 41 | For the sake of demonstration, we will assume you have a generic `Game` library that can tell you what moves are possible and allows you to perform those moves to change the game's state. 42 | 43 | ## Traditional Monte Carlo 44 | 45 | Add a **`child_finder`** and a **`node_evaluator`**: 46 | 47 | ```python 48 | def child_finder(node, montecarlo): 49 | for move in node.state.get_possible_moves(): 50 | child = Node(deepcopy(node.state)) #or however you want to construct the child's state 51 | child.state.move(move) #or however your library works 52 | node.add_child(child) 53 | 54 | def node_evaluator(node, montecarlo): 55 | if node.state.won(): 56 | return 1 57 | elif node.state.lost(): 58 | return -1 59 | 60 | montecarlo.child_finder = child_finder 61 | montecarlo.node_evaluator = node_evaluator 62 | ``` 63 | 64 | The **`child_finder`** should add any child nodes to the parent node passed into the function, if there are any. If there are none, the parent should be in an end state, so the **`node_evaluator`** should return a value between `-1` and `1`. 65 | 66 | ## Expert policy (AI) 67 | 68 | If you have an expert policy that you can apply to the children as they're being generated, the library will recognize that it doesn't need to make the costly drill down to an end state. If your neural net produces both an expert policy value for the children and a win value for the parent node, you can skip declaring the `node_evaluator` altogether. 69 | 70 | ```python 71 | def child_finder(node, montecarlo): 72 | win_value, expert_policy_values = neural_network.predict(node.state) 73 | 74 | for move in node.state.get_possible_moves(): 75 | child = Node(deepcopy(node.state)) 76 | child.state.move(move) 77 | child.player_number = child.state.whose_turn() 78 | child.policy_value = get_child_policy_value(child, expert_policy_values) #should return a probability value between 0 and 1 79 | node.add_child(child) 80 | 81 | node.update_win_value(win_value) 82 | 83 | montecarlo.child_finder = child_finder 84 | ``` 85 | 86 | ## Simulate and make a choice 87 | 88 | Run the simulations: 89 | 90 | ```python 91 | montecarlo.simulate(50) #number of expansions to run. higher is typically more accurate at the cost of processing time 92 | ``` 93 | 94 | Once the simulations have run you can ask the instance to make a choice: 95 | 96 | ```python 97 | chosen_child_node = montecarlo.make_choice() 98 | chosen_child_node.state.do_something() 99 | ``` 100 | 101 | After you've chosen a new root node, you can override it on the `montecarlo` instance and do more simulations from the new position in the tree. 
102 | 103 | ```python 104 | montecarlo.root_node = montecarlo.make_choice() 105 | ``` 106 | 107 | If you're training a neural network, you may want to make a more exploratory choice for the first N moves of a game: 108 | 109 | ```python 110 | montecarlo.root_node = montecarlo.make_exploratory_choice() 111 | ``` 112 | 113 | This won't provide a purely random choice, rather it will be random with a bias favoring the more explored pathways. 114 | 115 | ## Turn-based environments 116 | 117 | If you are modeling a turn-based environment (e.g. a two player board game), set the `player_number` on each node so the selection process can invert child win values: 118 | 119 | ```python 120 | node = Node(state) 121 | node.player_number = 1 122 | ``` 123 | 124 | It doesn't matter what this number is (you can use 1 and 2 or 5 and 6), only that it is consistent with other nodes. 125 | 126 | ## Tweaking the discovery factor 127 | 128 | When building a new child node, you can change the rate at which discovery is preferred: 129 | 130 | ```python 131 | node = Node(state) 132 | node.discovery_factor = 0.2 #0.35 by default, can be between 0 and 1 133 | ``` 134 | 135 | The closer this number is to 1, the more discovery will be favored over demonstrated value in later simulations. 136 | -------------------------------------------------------------------------------- /detikzify/evaluate/imagesim.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from math import tanh 3 | from typing import List, Literal, Optional 4 | 5 | from PIL import Image 6 | from ot.lp import emd2 7 | import torch 8 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 9 | import torch.nn.functional as F 10 | from torchmetrics import Metric 11 | from torchmetrics.functional import pairwise_cosine_similarity 12 | from transformers import AutoImageProcessor, AutoModel, PreTrainedModel, ProcessorMixin 13 | 14 | from ..model.adapter import ( 15 | AdapterProcessor, 16 | CrossAttentionAdapterMixin as AdapterMixin, 17 | has_adapter, 18 | ) 19 | from ..util import cast, expand, infer_device, load, unwrap_processor 20 | 21 | class ImageSim(Metric): 22 | """Perceptual image similarity using visual encoders.""" 23 | 24 | higher_is_better = True 25 | 26 | def __init__( 27 | self, 28 | model_name: str = "google/siglip-so400m-patch14-384", 29 | mode: Literal["emd", "cos", "cos_avg"] = "cos", 30 | preprocess: bool = True, 31 | device: str = infer_device(), 32 | dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, 33 | **kwargs 34 | ): 35 | super().__init__(**kwargs) 36 | self.model_name = model_name 37 | self.preprocess = preprocess 38 | self.mode = mode 39 | self._device = device 40 | self.set_dtype(dtype) 41 | 42 | self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") 43 | self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 44 | 45 | def __str__(self): 46 | return self.__class__.__name__ + f" ({self.mode.upper().replace('_', '-')})" 47 | 48 | @cached_property 49 | def model(self): 50 | # even if we instantiate with from_detikzify we still end up in this function 51 | if (model:=dict(self.named_children()).get("model")) is None: 52 | model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) 53 | model = model.vision_model.to(self.device) 54 | return model 55 | 56 | @cached_property 57 | def processor(self): 58 | return 
AutoImageProcessor.from_pretrained(self.model_name) 59 | 60 | @classmethod 61 | def from_detikzify(cls, model: PreTrainedModel, processor: ProcessorMixin, mode=None, *args, **kwargs): 62 | derived_kwargs = dict( 63 | model_name = model.name_or_path, 64 | mode = getattr(model.config, "pooling_mode", "emd") if mode is None else mode, 65 | device = model.device, 66 | dtype = model.dtype, 67 | ) 68 | imagesim = cls(*args, **(derived_kwargs | kwargs)) 69 | 70 | if has_adapter(model): 71 | class AdapterVisionModel(type(model.model.vision_model), AdapterMixin): 72 | embedding_model=model.embedding_model 73 | adapter=model.adapter 74 | 75 | @classmethod 76 | def cast(cls, vision_model): 77 | adapter_vision_model = cast(cls, vision_model) 78 | adapter_vision_model.add_hooks() 79 | return adapter_vision_model 80 | 81 | imagesim.model = AdapterVisionModel.cast(model.model.vision_model) 82 | imagesim.processor = AdapterProcessor( 83 | processor=unwrap_processor(processor).image_processor, 84 | tokenizer=processor.tokenizer # type: ignore 85 | ) 86 | else: 87 | imagesim.model = model.model.vision_model 88 | imagesim.processor = unwrap_processor(processor).image_processor 89 | return imagesim 90 | 91 | def get_vision_features(self, image: Optional[Image.Image | str] = None, text: Optional[str] = None): 92 | if image is not None: 93 | image = load(image) 94 | if self.preprocess: 95 | image = expand(image, max(image.size), do_trim=True) 96 | 97 | with torch.inference_mode(): 98 | if text is not None: 99 | encoding = self.processor(text=text, images=image, return_tensors="pt").to(self.device, self.dtype) 100 | else: 101 | encoding = self.processor(images=image, return_tensors="pt").to(self.device, self.dtype) 102 | if self.mode == "cos": 103 | return self.model(**encoding).pooler_output.squeeze() 104 | elif self.mode == "cos_avg": 105 | return self.model(**encoding).last_hidden_state.squeeze().mean(dim=0) 106 | else: 107 | return self.model(**encoding).last_hidden_state.squeeze() 108 | 109 | def get_similarity( 110 | self, 111 | img1: Optional[Image.Image | str] = None, 112 | img2: Optional[Image.Image | str] = None, 113 | text1: Optional[str] = None, 114 | text2: Optional[str] = None, 115 | ): 116 | img1_feats = self.get_vision_features(img1, text1) 117 | img2_feats = self.get_vision_features(img2, text2) 118 | 119 | if img1_feats.is_mps: # mps backend does not support dtype double 120 | img1_feats, img2_feats = img1_feats.cpu(), img2_feats.cpu() 121 | if img1_feats.ndim > 1: 122 | dists = 1 - pairwise_cosine_similarity(img1_feats.double(), img2_feats.double()).cpu().numpy() 123 | return 2 * tanh(-emd2(M=dists, a=list(), b=list())) + 1 # type: ignore 124 | else: 125 | return F.cosine_similarity(img1_feats.double(), img2_feats.double(), dim=0).item() 126 | 127 | def update( 128 | self, 129 | img1: Optional[Image.Image | str | List[Image.Image | str]] = None, 130 | img2: Optional[Image.Image | str | List[Image.Image | str]] = None, 131 | text1: Optional[str | List[str]] = None, 132 | text2: Optional[str | List[str]] = None, 133 | ): 134 | inputs = dict() 135 | for key, value in dict(img1=img1, img2=img2, text1=text1, text2=text2).items(): 136 | if value is not None: 137 | inputs[key] = value if isinstance(value, List) else [value] 138 | 139 | assert not ({"img1", "text1"}.isdisjoint(inputs.keys()) or {"img2", "text2"}.isdisjoint(inputs.keys())) 140 | assert len(set(map(len, inputs.values()))) == 1 141 | 142 | for inpt in zip(*inputs.values()): 143 | self.score += 
self.get_similarity(**dict(zip(inputs.keys(), inpt))) 144 | self.n_samples += 1 145 | 146 | def compute(self): 147 | return (self.score / self.n_samples).item() 148 | -------------------------------------------------------------------------------- /detikzify/infer/tikz.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import cache, cached_property 3 | from io import BytesIO 4 | from os import environ 5 | from os.path import isfile, join 6 | from re import MULTILINE, escape, findall, search 7 | from subprocess import CalledProcessError, DEVNULL, TimeoutExpired 8 | from tempfile import NamedTemporaryFile, TemporaryDirectory 9 | from typing import Dict, Optional, Union 10 | 11 | from PIL import Image 12 | from pdf2image.pdf2image import convert_from_bytes 13 | from pdfCropMargins import crop 14 | import pymupdf 15 | from transformers.utils import logging 16 | 17 | from ..util import check_output, expand, redact as redact_text 18 | 19 | logger = logging.get_logger("transformers") 20 | 21 | class TikzDocument: 22 | """ 23 | Facilitate some operations with TikZ code. To compile the images a full 24 | TeXLive installation is assumed to be on the PATH. Cropping additionally 25 | requires Ghostscript, and rasterization needs poppler. 26 | """ 27 | # engines to try, could also try: https://tex.stackexchange.com/a/495999 28 | engines = ["pdflatex", "lualatex", "xelatex"] 29 | Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""]) 30 | 31 | def __init__(self, code: str, timeout: Optional[int] = 60): 32 | self.code = code 33 | self.timeout = timeout 34 | # https://stackoverflow.com/a/68550238 35 | self.compile = cache(self.compile) 36 | 37 | @property 38 | def status(self) -> int: 39 | return self.compile().status 40 | 41 | @property 42 | def pdf(self) -> Optional[pymupdf.Document]: # type: ignore 43 | return self.compile().pdf 44 | 45 | @property 46 | def log(self) -> str: 47 | return self.compile().log 48 | 49 | @property 50 | def compiled_with_errors(self) -> bool: 51 | return self.status != 0 52 | 53 | @property 54 | def errors(self, rootfile: Optional[str] = None) -> Dict[int, str]: 55 | """ 56 | Returns a dict of (linenr, errormsg) pairs. linenr==0 is a special 57 | value reserved for errors that do not have a linenumber in rootfile. 
58 | """ 59 | if self.compiled_with_errors: 60 | if not rootfile and (match:=search(r"^\((.+)$", self.log, MULTILINE)): 61 | rootfile = match.group(1) 62 | else: 63 | ValueError("rootfile not found!") 64 | 65 | errors = dict() 66 | for file, line, error in findall(r'^(.+):(\d+):(.+)$', self.log, MULTILINE): 67 | if file == rootfile: 68 | errors[int(line)] = error.strip() 69 | else: # error occurred in other file 70 | errors[0] = error.strip() 71 | 72 | return errors or {0: "Fatal error occurred, no output PDF file produced!"} 73 | return dict() 74 | 75 | @cached_property 76 | def is_rasterizable(self) -> bool: 77 | """true if we have an image""" 78 | return self.rasterize() is not None 79 | 80 | @cached_property 81 | def has_content(self) -> bool: 82 | """true if we have an image that isn't empty""" 83 | return (img:=self.rasterize()) is not None and img.getcolors(1) is None 84 | 85 | @classmethod 86 | def set_engines(cls, engines: Union[str, list]): 87 | cls.engines = [engines] if isinstance(engines, str) else engines 88 | 89 | def compile(self) -> "Output": 90 | output = dict() 91 | with TemporaryDirectory() as tmpdirname: 92 | with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile: 93 | codelines = self.code.split("\n") 94 | # make sure we don't have page numbers in compiled pdf (for cropping) 95 | codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}")) 96 | tmpfile.write("\n".join(codelines).encode()) 97 | 98 | try: 99 | # compile 100 | errorln, tmppdf, outpdf = -1, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf") 101 | open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile 102 | 103 | def try_save_last_page(): 104 | try: 105 | doc = pymupdf.open(tmppdf) 106 | doc.select([len(doc)-1]) 107 | doc.save(outpdf) 108 | except: 109 | pass 110 | 111 | for engine in self.engines: 112 | try: 113 | check_output( 114 | cwd=tmpdirname, 115 | timeout=self.timeout, 116 | stderr=DEVNULL, 117 | env=environ | dict(max_print_line="1000"), # improve formatting of log 118 | args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name] 119 | ) 120 | except (CalledProcessError, TimeoutExpired) as proc: 121 | log = (getattr(proc, "output", b'') or b'').decode(errors="ignore") 122 | error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE) 123 | # only update status and log if first error occurs later than in previous engine 124 | if (linenr:=int(error.group(1)) if error else 0) > errorln: 125 | errorln = linenr 126 | output.update(status=getattr(proc, 'returncode', -1), log=log) 127 | try_save_last_page() 128 | else: 129 | output.update(status=0, log='') 130 | try_save_last_page() 131 | break 132 | 133 | # crop 134 | croppdf = f"{tmpfile.name}.crop" 135 | crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True) 136 | if isfile(croppdf): 137 | output['pdf'] = pymupdf.open(croppdf) 138 | 139 | except FileNotFoundError: 140 | logger.error("Missing dependencies: Did you install TeX Live?") 141 | except RuntimeError: # pdf error during cropping 142 | pass 143 | 144 | if output.get("status") == 0 and not output.get("pdf", None): 145 | logger.warning("Could compile document but something seems to have gone wrong during cropping!") 146 | 147 | return self.Output(**output) 148 | 149 | def rasterize(self, size=420, expand_to_square=True, redact=False, **redact_kwargs) -> Optional[Image.Image]: 150 | if pdf:=self.pdf: 151 | if redact: 152 | 
pdf = redact_text(pdf, **redact_kwargs) 153 | image = convert_from_bytes(pdf.tobytes(), size=size, single_file=True)[0] 154 | if expand_to_square: 155 | return expand(image, size) 156 | return image 157 | 158 | def save(self, filename: str, *args, **kwargs): 159 | match filename.split(".")[-1]: 160 | case "tex": content = self.code.encode() 161 | case "pdf" if self.pdf: content = self.pdf.tobytes() 162 | case fmt if img := self.rasterize(*args, **kwargs): 163 | img.save(imgByteArr:=BytesIO(), format=fmt) 164 | content = imgByteArr.getvalue() 165 | case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!") 166 | 167 | with open(filename, "wb") as f: 168 | f.write(content) 169 | -------------------------------------------------------------------------------- /detikzify/webui/strings.py: -------------------------------------------------------------------------------- 1 | from os.path import basename 2 | 3 | from transformers import is_timm_available 4 | 5 | BANNER = '''\ 6 |

DeTikZify: Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ
[HTML banner markup stripped in this dump; badge links: View on arXiv · View on GitHub · View on Hugging Face · Open in Colab]
22 | ''' 23 | 24 | MODELS = { 25 | basename(model): model 26 | for model in [ 27 | "nllg/detikzify-v2.5-8b", 28 | "nllg/detikzify-v2-8b", 29 | ] 30 | } 31 | 32 | if is_timm_available(): 33 | MODELS |= { 34 | basename(model).replace("detikzify", "detikzify-v1"): model 35 | for model in [ 36 | "nllg/detikzify-ds-7b", 37 | "nllg/detikzify-cl-7b", 38 | "nllg/detikzify-ds-1.3b", 39 | "nllg/detikzify-tl-1.1b", 40 | ] 41 | } 42 | 43 | ALGORITHMS = { 44 | "mcts": "MCTS", 45 | "sampling": "Sampling" 46 | } 47 | 48 | # https://github.com/gradio-app/gradio/issues/3202#issuecomment-1741571240 49 | # https://github.com/gradio-app/gradio/issues/2666#issuecomment-1651127149 50 | # https://stackoverflow.com/a/64033350 51 | CSS = """ 52 | .input-image { 53 | flex-grow: 1; 54 | } 55 | .output-code { 56 | flex-grow: 1; 57 | height: 0vh; 58 | min-height: 250px; 59 | scrollbar-width: thin !important; 60 | } 61 | .output-code .hide { 62 | display: none; 63 | } 64 | .output-code .cm-scroller { 65 | flex-grow: 1; 66 | } 67 | .output-code .cm-gutters { 68 | position: relative !important; 69 | } 70 | .output-image { 71 | flex-grow: 1; 72 | height: 0vh; 73 | min-height: 250px; 74 | overflow-y: auto !important; 75 | scrollbar-width: thin !important; 76 | } 77 | .output-image .image-container, .output-image .grid-container { 78 | width: 100%; 79 | height: 100%; 80 | } 81 | .output-image .thumbnail-item img { 82 | object-fit: contain; 83 | } 84 | .output-image .grid-wrap.fixed-height { 85 | max-height: 100% !important; 86 | } 87 | .outputs .tabs { 88 | display: flex; 89 | flex-direction: column; 90 | flex-grow: 1; 91 | } 92 | .outputs .tabitem[style="display: block;"] { 93 | flex-grow: 1; 94 | display: flex !important; 95 | } 96 | .outputs .gap { 97 | flex-grow: 1; 98 | } 99 | .outputs .form { 100 | flex-grow: 1 !important; 101 | } 102 | .outputs .form > :last-child{ 103 | flex-grow: 1; 104 | } 105 | """ 106 | 107 | # (Ab)use an invisible fake button with id preview-close to propagate the 108 | # actual press of the button that closes a preview 109 | # https://github.com/gradio-app/gradio/issues/6697 110 | GALLERY_DESELECT_HACK = """ 111 | 131 | """ 132 | -------------------------------------------------------------------------------- /detikzify/train/adapter/train.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from datetime import timedelta 3 | import os 4 | from typing import Dict, List 5 | 6 | from accelerate import Accelerator, InitProcessGroupKwargs 7 | from datasets import Dataset 8 | import torch 9 | from torch.utils.data import Dataset 10 | from transformers import Trainer, TrainingArguments 11 | from transformers.trainer_utils import get_last_checkpoint 12 | from transformers.utils import logging 13 | 14 | from ...util import SplitEpochSaveCallback, unwrap_processor 15 | 16 | logger = logging.get_logger("transformers") 17 | 18 | IGNORE_INDEX = -100 19 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) 20 | RANK = int(os.environ.get("RANK", 0)) 21 | 22 | def tokenize( 23 | batch, 24 | processor, 25 | caption_condition=False, 26 | **kwargs 27 | ): 28 | unwrapped_processor = unwrap_processor(processor) 29 | image_token = unwrapped_processor.image_token 30 | image_token_id = unwrapped_processor.tokenizer.convert_tokens_to_ids(image_token) 31 | bos_token = unwrapped_processor.tokenizer.bos_token 32 | 33 | input_ids = processor( 34 | text=batch['caption'], 35 | images_kwargs=dict( 36 | text=[bos_token.join(text) for text in 
zip(batch['caption'], batch['code'])] if caption_condition else batch['code'], 37 | max_length=unwrapped_processor.tokenizer.model_max_length, 38 | pad_to_multiple_of=8, 39 | add_eos_token=True, 40 | truncation=False, 41 | padding=True 42 | ), 43 | text_kwargs=dict( 44 | padding=True, 45 | truncation=True, 46 | ), 47 | **kwargs 48 | ) 49 | input_ids['labels'] = deepcopy(input_ids['input_ids']) 50 | 51 | if caption_condition: 52 | # do not train on caption and pad tokens 53 | for label_ids in input_ids['labels']: 54 | after_bos_token = False 55 | for idx, label_id in enumerate(label_ids): 56 | if not after_bos_token or label_id in {image_token_id, unwrapped_processor.tokenizer.pad_token_id}: 57 | if label_id == unwrapped_processor.tokenizer.bos_token_id: 58 | after_bos_token = True 59 | label_ids[idx] = IGNORE_INDEX 60 | elif label_id == unwrapped_processor.tokenizer.bos_token_id: 61 | after_bos_token = True 62 | else: 63 | # do not train on image and pad tokens 64 | for label_ids in input_ids['labels']: 65 | for idx, label_id in enumerate(label_ids): 66 | if label_id in {image_token_id, processor.tokenizer.pad_token_id}: 67 | label_ids[idx] = IGNORE_INDEX 68 | 69 | return input_ids 70 | 71 | class AdapterDataset(Dataset): 72 | def __init__(self, dataset, processor, caption_condition=False): 73 | super().__init__() 74 | self.processor = processor 75 | self.dataset = dataset.with_transform(self.tokenize) 76 | self.caption_condition = caption_condition 77 | 78 | def __len__(self): 79 | return len(self.dataset) 80 | 81 | def tokenize(self, batch): 82 | return tokenize( 83 | batch=batch, 84 | processor=self.processor, 85 | caption_condition=self.caption_condition, 86 | return_tensors="pt", 87 | ) 88 | 89 | def filter(self, *args, **kwargs): 90 | self.dataset = self.dataset.filter(*args, **kwargs) 91 | 92 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 93 | return self.dataset[index] 94 | 95 | def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]: 96 | return self.dataset[*indices] 97 | 98 | def train( 99 | output_dir: str, 100 | model, 101 | processor, 102 | dataset, 103 | overwrite=False, 104 | deepspeed=None, 105 | # training hyperparams 106 | caption_condition: bool = False, 107 | batch_size: int = 128, 108 | micro_batch_size: int = 1, 109 | num_epochs: int = 5, 110 | learning_rate: float = 5e-5, 111 | gradient_checkpointing: bool = False, 112 | ): 113 | gradient_accumulation_steps = batch_size // micro_batch_size 114 | if WORLD_SIZE != 1: 115 | gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE 116 | 117 | for _, param in model.model.vision_model.named_parameters(): 118 | param.requires_grad = False 119 | for _, param in model.adapter.named_parameters(): 120 | param.requires_grad = False 121 | for _, param in model.embedding_model.named_parameters(): 122 | param.requires_grad = False 123 | model.enable_input_require_grads() 124 | model.embedding_model.enable_input_require_grads() 125 | 126 | dataset = AdapterDataset(dataset, processor, caption_condition=caption_condition) 127 | logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}") 128 | eos_token_id, model_max_length = unwrap_processor(processor).tokenizer.eos_token_id, unwrap_processor(processor).tokenizer.model_max_length 129 | with Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(days=3))]).main_process_first(): 130 | dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length, num_proc=64, batch_size=16) 131 | 
logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}") 132 | 133 | last_checkpoint = None 134 | if os.path.isdir(output_dir) and not overwrite: 135 | last_checkpoint = get_last_checkpoint(output_dir) 136 | if last_checkpoint is None and len(os.listdir(output_dir)) > 0: 137 | raise ValueError( 138 | f"Output directory ({output_dir}) already exists and is not empty. " 139 | "Use `overwrite` to overcome." 140 | ) 141 | elif last_checkpoint is not None: 142 | logger.info( 143 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 144 | "the `output_dir` or add `overwrite` to train from scratch." 145 | ) 146 | 147 | trainer = Trainer( 148 | model=model, 149 | train_dataset=dataset, 150 | args=TrainingArguments( 151 | per_device_train_batch_size=micro_batch_size, 152 | gradient_accumulation_steps=gradient_accumulation_steps, 153 | gradient_checkpointing=gradient_checkpointing, 154 | # https://github.com/huggingface/transformers/issues/32576 155 | #gradient_checkpointing_kwargs={'use_reentrant':False}, 156 | dataloader_num_workers=WORLD_SIZE, 157 | warmup_ratio=0.03, 158 | weight_decay=0, 159 | num_train_epochs=num_epochs, 160 | learning_rate=learning_rate, 161 | torch_compile=True, 162 | bf16=True, 163 | tf32=True, 164 | logging_steps=10, 165 | lr_scheduler_type="cosine", 166 | optim="adamw_torch" if deepspeed else "adamw_torch_fused", 167 | ddp_find_unused_parameters=False, 168 | remove_unused_columns=False, 169 | save_strategy="epoch", 170 | report_to="none", 171 | save_total_limit=1, 172 | output_dir=output_dir, 173 | deepspeed=deepspeed, 174 | ), 175 | callbacks=[SplitEpochSaveCallback(step_size=0.25)], 176 | data_collator=lambda batch: batch 177 | ) 178 | 179 | trainer.add_callback(trainer.train_dataset) 180 | trainer.train(resume_from_checkpoint=last_checkpoint) 181 | 182 | if trainer.is_deepspeed_enabled: 183 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading 184 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint 185 | last_checkpoint = get_last_checkpoint(output_dir) 186 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) 187 | 188 | trainer.model.unload_cross_attn_adapter() 189 | trainer.save_model(output_dir) 190 | trainer.save_state() 191 | processor.processor.save_pretrained(output_dir) 192 | 193 | return model, processor 194 | -------------------------------------------------------------------------------- /examples/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | from datetime import timedelta 5 | from functools import partial 6 | from itertools import count 7 | from json import dump, load as load_json 8 | from operator import itemgetter 9 | from os import getenv 10 | from os.path import isfile, join 11 | from time import time 12 | 13 | from datasets import load_dataset 14 | from numpy import array 15 | from scipy.stats.mstats import winsorize 16 | from torch import bfloat16, distributed as dist, float16 17 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported 18 | from tqdm import tqdm 19 | from transformers import set_seed 20 | from transformers.utils import is_flash_attn_2_available 21 | 22 | from detikzify.evaluate import ( 23 | ClipScore, 24 | CrystalBLEU, 25 | DreamSim, 26 | ImageSim, 27 | KernelInceptionDistance, 28 | 
TexEditDistance, 29 | ) 30 | from detikzify.infer import DetikzifyPipeline, TikzDocument 31 | from detikzify.model import adapter, load as load_model 32 | 33 | WORLD_SIZE = int(getenv("WORLD_SIZE", 1)) 34 | RANK = int(getenv("RANK", 0)) 35 | 36 | def parse_args(): 37 | argument_parser = ArgumentParser( 38 | description="Evaluate fine-tuned models." 39 | ) 40 | argument_parser.add_argument( 41 | "--cache_dir", 42 | help="directory where model outputs should be saved to", 43 | ) 44 | argument_parser.add_argument( 45 | "--trainset", 46 | default="nllg/datikz-v3", 47 | help="path or name of the DaTikZ train set", 48 | ) 49 | argument_parser.add_argument( 50 | "--testset", 51 | required=True, 52 | help="path to the DaTikZ test split processed by the ./sketchify script (in parquet format)", 53 | ) 54 | argument_parser.add_argument( 55 | "--output", 56 | required=True, 57 | help="where to save the computed scores (as json)", 58 | ) 59 | argument_parser.add_argument( 60 | "--timeout", 61 | type=int, 62 | help="minimum time to run MCTS in seconds", 63 | ) 64 | argument_parser.add_argument( 65 | "--model_inputs", 66 | default="image", 67 | choices=["image", "sketch", "caption", "caption-image", "caption-sketch"], 68 | help="which inputs to condition the model on", 69 | ) 70 | argument_parser.add_argument( 71 | "--path", 72 | nargs='+', 73 | metavar="MODEL=PATH[:ADAPTER] | MODEL=JSON", 74 | required=True, 75 | help="(multiple) key-value pairs of model names and paths/urls to models and optionally adapters or json files", 76 | ) 77 | return argument_parser.parse_args() 78 | 79 | # https://stackoverflow.com/a/54802737 80 | def chunk(l, n): 81 | """Yield n number of striped chunks from l.""" 82 | for i in range(0, n): 83 | yield l[i::n] 84 | 85 | def interleave(chunks): 86 | """Interleave chunks until one is exhausted.""" 87 | interleaved = list() 88 | for idx in count(): 89 | try: 90 | interleaved.extend(chunk[idx] for chunk in chunks) 91 | except IndexError: 92 | break 93 | return interleaved 94 | 95 | def generate(pipe, item, model_inputs, strict=False, timeout=None, **tqdm_kwargs): 96 | """Run MCTS until the generated tikz code compiles.""" 97 | start, success, tikzpics = time(), False, set() 98 | inputs = {"text" if key == "caption" else "image": item[key] for key in model_inputs.split("-")} 99 | 100 | for score, tikzpic in tqdm(pipe.simulate(**inputs), desc="Try", **tqdm_kwargs): 101 | tikzpics.add((score, tikzpic.code)) 102 | if not tikzpic.compiled_with_errors if strict else tikzpic.is_rasterizable: 103 | success = True 104 | if success and (not timeout or time() - start >= timeout): 105 | break 106 | return [tikzpic for _, tikzpic in sorted(tikzpics, key=itemgetter(0))] 107 | 108 | def predict(model_name, base_model, testset, model_inputs="image", adapter_model=None, cache_file=None, timeout=None): 109 | predictions, worker_preds = list(), list() 110 | model, processor = load_model( 111 | model_name_or_path=base_model, 112 | device_map=RANK, 113 | torch_dtype=bfloat16 if is_cuda_available() and is_bf16_supported() else float16, 114 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, 115 | ) 116 | if adapter_model is not None: 117 | model, processor = adapter.load(model, processor, adapter_model) 118 | # if we don't have a timeout (i.e., only run mcts until we obtain smth compileable), we can use fast metrics 119 | pipe = DetikzifyPipeline(model=model, processor=processor, metric="model" if timeout else "fast") 120 | 121 | if cache_file and isfile(cache_file): 122 
| with open(cache_file) as f: 123 | predictions = load_json(f) 124 | try: 125 | worker_chunk = list(chunk(list(testset)[len(predictions):], WORLD_SIZE))[RANK] 126 | # FIXME: right now there only is a progress bar for Rank 0 127 | for item in tqdm(worker_chunk, desc=f"{model_name.title()} ({RANK})", disable=RANK!=0): 128 | tikz = generate(pipe, item, model_inputs, timeout=timeout, position=1, leave=False, disable=RANK!=0) 129 | worker_preds.append(tikz) 130 | del model, processor, pipe 131 | finally: 132 | dist.all_gather_object(gathered:=WORLD_SIZE * [None], worker_preds) 133 | predictions.extend(interleave(gathered)) 134 | if cache_file and RANK == 0: 135 | with open(cache_file, 'w') as f: 136 | dump(predictions, f) 137 | return predictions 138 | 139 | def load_metrics(trainset, measure_throughput=False, **kwargs): 140 | bleu = CrystalBLEU(corpus=trainset, **kwargs) 141 | eed = TexEditDistance(**kwargs) 142 | clip = ClipScore(**kwargs) 143 | imgsim = ImageSim(**kwargs) 144 | dreamsim = DreamSim(**kwargs) 145 | kid = KernelInceptionDistance(**kwargs) 146 | 147 | def mean_token_efficiency(predictions, limit=0.05): 148 | samples = list() 149 | for preds in predictions: 150 | samples.append(len(preds[-1].code)/sum(len(pred.code) for pred in preds)) 151 | return winsorize(array(samples), limits=limit).mean().item() 152 | 153 | def mean_sampling_throughput(predictions, limit=0.05): 154 | return winsorize(array(list(map(len, predictions))), limits=limit).mean().item() 155 | 156 | def compute(references, predictions, compute_redacted=True, **redact_kwargs): 157 | ref_code, pred_code = [[ref['code']] for ref in references], [pred[-1].code for pred in predictions] 158 | ref_image, pred_image = [ref['image'] for ref in references], [pred[-1].rasterize() for pred in predictions] 159 | captions = [ref['caption'] for ref in references] 160 | assert all(pred[-1].is_rasterizable for pred in predictions) 161 | 162 | if measure_throughput: 163 | scores = {"MeanSamplingThroughput": mean_sampling_throughput(predictions=predictions)} 164 | else: 165 | scores = {"MeanTokenEfficiency": mean_token_efficiency(predictions=predictions)} 166 | 167 | redacted_metrics, standard_metrics = {}, { 168 | bleu: partial(bleu.update, list_of_references=ref_code, hypotheses=pred_code), 169 | eed: partial(eed.update, target=ref_code, preds=pred_code), 170 | clip: partial(clip.update, text=captions, images=pred_image), 171 | imgsim: lambda: [imgsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_image)], 172 | dreamsim: lambda: [dreamsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_image)], 173 | kid: lambda: [(kid.update(img1, True), kid.update(img2, False)) for img1, img2 in zip(ref_image, pred_image)], 174 | } 175 | 176 | if compute_redacted: 177 | pred_redacted = [pred[-1].rasterize(redact=True, **redact_kwargs) for pred in predictions] 178 | redacted_metrics.update({ 179 | clip: partial(clip.update, text=captions, images=pred_redacted), 180 | imgsim: lambda: [imgsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_redacted)], 181 | dreamsim: lambda: [dreamsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_redacted)], 182 | }) 183 | 184 | for metrics, redacted in [(standard_metrics, False), (redacted_metrics, True)]: 185 | for metric, update in metrics.items(): 186 | update() 187 | if redacted: 188 | scores[f"Redacted {str(metric)}"] = metric.compute() 189 | else: 190 | scores[str(metric)] = metric.compute() # type: ignore 191 | metric.reset() 192 
| 193 | return scores 194 | 195 | return compute 196 | 197 | if __name__ == "__main__": 198 | set_seed(0) 199 | dist.init_process_group(timeout=timedelta(days=3)) 200 | args = parse_args() 201 | 202 | trainset = load_dataset(args.trainset, split="train") 203 | testset = load_dataset("parquet", data_files={"test": args.testset}, split="test").sort("caption") # type: ignore 204 | 205 | predictions = defaultdict(list) 206 | for model_name, path in map(lambda s: s.split('='), tqdm(args.path, desc="Predicting")): 207 | if path.endswith("json"): 208 | with open(path) as f: 209 | predictions[model_name] = load_json(f) 210 | else: 211 | cache_file = join(args.cache_dir, f'{model_name}.json') if args.cache_dir else None 212 | predictions[model_name] = predict( 213 | model_name=model_name, 214 | base_model=path.partition(":")[0], 215 | adapter_model=path.partition(":")[2] or None, 216 | model_inputs=args.model_inputs, 217 | testset=testset, 218 | cache_file=cache_file, 219 | timeout=args.timeout, 220 | ) 221 | 222 | if RANK == 0: # Scoring only on main process 223 | scores = dict() 224 | metrics = load_metrics(trainset['code'], measure_throughput=args.timeout is not None, sync_on_compute=False) # type: ignore 225 | for model_name, prediction in tqdm(predictions.items(), desc="Computing metrics", total=len(predictions)): 226 | scores[model_name] = metrics( 227 | references=testset, 228 | # use an unrealistically long timeout as we know that the (last) images compile 229 | predictions=[[TikzDocument(code, 600) for code in pred] for pred in prediction], 230 | rot_13=True 231 | ) 232 | with open(args.output, "w") as file: 233 | dump(scores, file) 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /detikzify/model/v1/processing_detikzify.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/modeling_timm.py Below is the original copyright: 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from typing import Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | from timm.data import resolve_data_config 21 | from timm.models import resolve_pretrained_cfg 22 | from transformers.image_processing_utils import ( 23 | BaseImageProcessor, 24 | BatchFeature, 25 | get_size_dict, 26 | ) 27 | from transformers.image_transforms import resize, to_channel_dimension_format 28 | from transformers.image_utils import ( 29 | ChannelDimension, 30 | IMAGENET_STANDARD_MEAN, 31 | IMAGENET_STANDARD_STD, 32 | ImageInput, 33 | PILImageResampling, 34 | make_list_of_images, 35 | to_numpy_array, 36 | valid_images, 37 | ) 38 | from transformers.utils import TensorType 39 | 40 | 41 | class DetikzifyImageProcessor(BaseImageProcessor): 42 | r""" 43 | Constructs a ViT image processor. 44 | 45 | Args: 46 | do_resize (`bool`, *optional*, defaults to `True`): 47 | Whether to resize the image's (height, width) dimensions to the specified `(size["height"], 48 | size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. 49 | size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`): 50 | Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` 51 | method. 52 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): 53 | Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the 54 | `preprocess` method. 55 | do_rescale (`bool`, *optional*, defaults to `True`): 56 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` 57 | parameter in the `preprocess` method. 58 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): 59 | Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the 60 | `preprocess` method. 61 | do_normalize (`bool`, *optional*, defaults to `True`): 62 | Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` 63 | method. 64 | image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): 65 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 66 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 67 | image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): 68 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 69 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
70 | """ 71 | 72 | model_input_names = ["pixel_values"] 73 | 74 | def __init__( 75 | self, 76 | do_resize: bool = True, 77 | size: Optional[Dict[str, int]] = None, 78 | resample: PILImageResampling = PILImageResampling.BILINEAR, 79 | do_rescale: bool = True, 80 | rescale_factor: Union[int, float] = 1 / 255, 81 | do_normalize: bool = True, 82 | image_mean: Optional[Union[float, List[float]]] = None, 83 | image_std: Optional[Union[float, List[float]]] = None, 84 | **kwargs, 85 | ) -> None: 86 | super().__init__(**kwargs) 87 | size = size if size is not None else {"height": 224, "width": 224} 88 | size = get_size_dict(size) 89 | self.do_resize = do_resize 90 | self.do_rescale = do_rescale 91 | self.do_normalize = do_normalize 92 | self.size = size 93 | self.resample = resample 94 | self.rescale_factor = rescale_factor 95 | self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN 96 | self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD 97 | 98 | @classmethod 99 | def from_pretrained( 100 | cls, 101 | pretrained_model_name_or_path: Union[str, os.PathLike], 102 | **kwargs, 103 | ): 104 | pretrained_cfg = resolve_pretrained_cfg(variant=pretrained_model_name_or_path) 105 | timm_config_dict = resolve_data_config(pretrained_cfg.to_dict()) 106 | 107 | _, im_h, im_w = timm_config_dict.get("input_size", [3, 224, 224]) 108 | 109 | image_preprocess_config_dict = { 110 | "crop_size": {"height": im_h, "width": im_w}, 111 | "do_center_crop": True if timm_config_dict.get("crop_mode") == "center" else False, 112 | "do_normalize": True, 113 | "do_reduce_labels": False, 114 | "do_rescale": True, 115 | "do_resize": True, 116 | "image_mean": timm_config_dict.get("mean", IMAGENET_STANDARD_MEAN), 117 | "image_processor_type": "TimmImageProcessor", 118 | "image_std": timm_config_dict.get("std", IMAGENET_STANDARD_STD), 119 | "resample": 3, 120 | "rescale_factor": 0.00392156862745098, 121 | "size": {"height": im_h, "width": im_w}, 122 | } 123 | 124 | return cls.from_dict(image_preprocess_config_dict, **kwargs) 125 | 126 | def resize( 127 | self, 128 | image: np.ndarray, 129 | size: Dict[str, int], 130 | resample: PILImageResampling = PILImageResampling.BILINEAR, 131 | data_format: Optional[Union[str, ChannelDimension]] = None, 132 | **kwargs, 133 | ) -> np.ndarray: 134 | """ 135 | Resize an image to `(size["height"], size["width"])`. 136 | 137 | Args: 138 | image (`np.ndarray`): 139 | Image to resize. 140 | size (`Dict[str, int]`): 141 | Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. 142 | resample: 143 | `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. 144 | data_format (`ChannelDimension` or `str`, *optional*): 145 | The channel dimension format for the output image. If unset, the channel dimension format of the input 146 | image is used. Can be one of: 147 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 148 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 149 | 150 | Returns: 151 | `np.ndarray`: The resized image. 152 | """ 153 | size = get_size_dict(size) 154 | if "height" not in size or "width" not in size: 155 | raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") 156 | if image.ndim == 2: 157 | image = np.stack([image] * 3, axis=-1) 158 | return resize( 159 | image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs 160 | ) 161 | 162 | def preprocess( 163 | self, 164 | images: ImageInput, 165 | do_resize: Optional[bool] = None, 166 | size: Dict[str, int] = None, 167 | resample: PILImageResampling = None, 168 | do_rescale: Optional[bool] = None, 169 | rescale_factor: Optional[float] = None, 170 | do_normalize: Optional[bool] = None, 171 | image_mean: Optional[Union[float, List[float]]] = None, 172 | image_std: Optional[Union[float, List[float]]] = None, 173 | return_tensors: Optional[Union[str, TensorType]] = None, 174 | data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, 175 | **kwargs, 176 | ): 177 | """ 178 | Preprocess an image or batch of images. 179 | 180 | Args: 181 | images (`ImageInput`): 182 | Image to preprocess. 183 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 184 | Whether to resize the image. 185 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 186 | Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after 187 | resizing. 188 | resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): 189 | `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has 190 | an effect if `do_resize` is set to `True`. 191 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 192 | Whether to rescale the image values between [0 - 1]. 193 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 194 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 195 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 196 | Whether to normalize the image. 197 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 198 | Image mean to use if `do_normalize` is set to `True`. 199 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 200 | Image standard deviation to use if `do_normalize` is set to `True`. 201 | return_tensors (`str` or `TensorType`, *optional*): 202 | The type of tensors to return. Can be one of: 203 | - Unset: Return a list of `np.ndarray`. 204 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 205 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 206 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 207 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 208 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 209 | The channel dimension format for the output image. Can be one of: 210 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 211 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 212 | - Unset: Use the channel dimension format of the input image. 
213 | """ 214 | do_resize = do_resize if do_resize is not None else self.do_resize 215 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 216 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 217 | resample = resample if resample is not None else self.resample 218 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 219 | image_mean = image_mean if image_mean is not None else self.image_mean 220 | image_std = image_std if image_std is not None else self.image_std 221 | 222 | size = size if size is not None else self.size 223 | size_dict = get_size_dict(size) 224 | 225 | images = make_list_of_images(images) 226 | 227 | if not valid_images(images): 228 | raise ValueError( 229 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 230 | "torch.Tensor, tf.Tensor or jax.ndarray." 231 | ) 232 | 233 | if do_resize and size is None: 234 | raise ValueError("Size must be specified if do_resize is True.") 235 | 236 | if do_rescale and rescale_factor is None: 237 | raise ValueError("Rescale factor must be specified if do_rescale is True.") 238 | 239 | # All transformations expect numpy arrays. 240 | images = [to_numpy_array(image) for image in images] 241 | 242 | if do_resize: 243 | images = [self.resize(image=image, size=size_dict, resample=resample) for image in images] 244 | 245 | if do_rescale: 246 | images = [self.rescale(image=image, scale=rescale_factor) for image in images] 247 | 248 | if do_normalize: 249 | images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] 250 | 251 | images = [to_channel_dimension_format(image, data_format) for image in images] 252 | data = {"pixel_values": images} 253 | return BatchFeature(data=data, tensor_type=return_tensors) 254 | -------------------------------------------------------------------------------- /examples/refine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu 2 | from argparse import ArgumentParser 3 | from collections import ChainMap 4 | from datetime import timedelta 5 | from functools import cached_property, partial 6 | from operator import itemgetter 7 | from os import environ 8 | from os.path import basename 9 | from random import choice, random 10 | from sys import exit 11 | from typing import Optional 12 | 13 | from accelerate import Accelerator 14 | from datasets import ( 15 | Dataset, 16 | DatasetDict, 17 | Features, 18 | Image, 19 | Value, 20 | concatenate_datasets, 21 | load_dataset, 22 | ) 23 | from torch.distributed import init_process_group 24 | from torch.multiprocessing import Pool, set_start_method 25 | from transformers import set_seed 26 | from transformers.trainer_utils import get_last_checkpoint 27 | from transformers.utils import is_flash_attn_2_available 28 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info 29 | 30 | from detikzify.evaluate.imagesim import ImageSim 31 | from detikzify.infer.tikz import TikzDocument 32 | from detikzify.model import load 33 | from detikzify.util import SketchAugment, batchify, expand 34 | from sketchify import Sketchifier 35 | 36 | try: 37 | from trl import GRPOConfig, GRPOTrainer 38 | except ModuleNotFoundError: 39 | print( 40 | "You need to install trl with vision support to be able to use this script:", 41 | "git clone https://github.com/hellopahe/trl.git", 42 | "curl -L http://github.com/huggingface/trl/pull/3568.patch | git -C trl 
apply", 43 | "pip install trl", 44 | sep="\n\t" 45 | ) 46 | exit(1) 47 | 48 | 49 | WORLD_SIZE = int(environ.get("WORLD_SIZE", 1)) 50 | 51 | 52 | class RandSketchifier: 53 | def __init__(self, size, sketch_ratio): 54 | self.sketchifier = Sketchifier() 55 | self.deep_sketchifier = SketchAugment() 56 | self.size = size 57 | self.sketch_ratio = sketch_ratio 58 | 59 | @staticmethod 60 | def randbool(): 61 | return choice([True, False]) 62 | 63 | def sketchify(self, img): 64 | return self.sketchifier(img) if self.randbool() else self.deep_sketchifier(img) 65 | 66 | def randsketchify(self, img): 67 | return self.sketchify(img) if random() < self.sketch_ratio else img 68 | 69 | def __call__(self, img): 70 | return self.randsketchify(expand(img, self.size, do_trim=True)) 71 | 72 | 73 | class TrainDataset: 74 | def __init__(self, processor, datikz_name="nllg/datikz-v3", size=420, sketch_ratio=.5): 75 | self.processor = processor 76 | self.datikz: DatasetDict = load_dataset(datikz_name) # type: ignore 77 | self.sketchify = RandSketchifier( 78 | size=size, 79 | sketch_ratio=sketch_ratio 80 | ) 81 | 82 | @staticmethod 83 | def get_fig_type(ftype): 84 | for type_ in ["table", "photograph", "plot", "schematic", "other"]: 85 | if type_ in ftype.lower(): 86 | return type_ 87 | return "N/A" 88 | 89 | @batchify 90 | def extract_figures(self, batch, meta_ds, filter_urls): 91 | for img in batch['image']: 92 | filename = basename(img['path'].split("::")[0]) 93 | if filename in meta_ds and filename.rpartition("v")[0] not in filter_urls: 94 | meta = meta_ds[filename] 95 | ftype = self.get_fig_type(meta["figure_type"]) 96 | # "other" are mostly text snippets 97 | if meta["content_type"] == "figure": 98 | yield dict(image=img, type=ftype) 99 | 100 | def sample_spiqa_dataset(self, n_samples, split="train"): 101 | img_ds: Dataset = load_dataset( # type: ignore 102 | path="google/spiqa", 103 | data_files="train_val/SPIQA_train_val_Images.zip", 104 | split="train", 105 | features=Features({"image": Image(decode=False), "label": Value("string")}) 106 | ) 107 | 108 | meta_ds = load_dataset(path="google/spiqa", data_files=f"train_val/SPIQA_{split}.json", split="train") 109 | meta_ds = ChainMap(*map(itemgetter("all_figures"), meta_ds[0].values())) # type: ignore 110 | 111 | filter_urls = concatenate_datasets(list(self.datikz.values()))['uri'] 112 | filter_urls = {basename(url) for url in filter_urls if url.startswith("https://arxiv.org")} 113 | 114 | img_ds = img_ds.map( 115 | self.extract_figures, 116 | batched=True, 117 | remove_columns=img_ds.column_names, 118 | fn_kwargs=dict(meta_ds=meta_ds, filter_urls=filter_urls) 119 | ).shuffle() 120 | 121 | schematics = img_ds.filter(lambda ex: ex['type'] == "schematic").select(range(round(.6 * n_samples))) 122 | plots = img_ds.filter(lambda ex: ex['type'] == "plot").select(range(round(.2 * n_samples))) 123 | other = img_ds.filter(lambda ex: ex['type'] not in ["plot", "schematic"]).select(range(n_samples - len(schematics) - len(plots))) 124 | 125 | for ex in concatenate_datasets([schematics, plots, other]).cast_column("image", Image()): 126 | yield {"prompt": "", "images": self.sketchify(ex['image'])} # type: ignore 127 | 128 | def sample_datikz_dataset(self, n_samples, split="train"): 129 | tokenizer, image_seq_len = self.processor.tokenizer, self.processor.image_seq_len 130 | 131 | datikz_filtered = self.datikz[split].filter( 132 | function=lambda ex: len(tokenizer.tokenize(ex['code'])) + image_seq_len > tokenizer.model_max_length, 133 | ).train_test_split(train_size=n_samples) 
134 | 135 | for ex in datikz_filtered['train']: 136 | yield {"prompt": "", "images": self.sketchify(ex['image'])} # type: ignore 137 | 138 | def sample(self, n_samples): 139 | spiqa: Dataset = Dataset.from_generator( # type: ignore 140 | generator=self.sample_spiqa_dataset, 141 | gen_kwargs=dict(n_samples=round(.5 * n_samples)) 142 | ) 143 | datikz: Dataset = Dataset.from_generator( # type: ignore 144 | generator=self.sample_datikz_dataset, 145 | gen_kwargs=dict(n_samples=n_samples-len(spiqa)) 146 | ) 147 | 148 | return concatenate_datasets([spiqa, datikz]) 149 | 150 | 151 | class RewardFunc: 152 | __name__ = "SelfSim Reward" 153 | 154 | def __init__(self, model, processor, num_workers=1, strict=False): 155 | self.model = model 156 | self.processor = processor 157 | self.strict = strict 158 | self.pool = Pool(num_workers) 159 | 160 | @cached_property 161 | def reward_model(self): 162 | return ImageSim.from_detikzify(self.model, self.processor, sync_on_compute=False) 163 | 164 | @staticmethod 165 | def compile(code, size, strict): 166 | doc = TikzDocument(code) 167 | 168 | if doc.is_rasterizable and not (strict and doc.compiled_with_errors): 169 | return doc.rasterize(size=size) 170 | 171 | def __call__(self, images, completions, **_): 172 | rewards, compile = list(), partial( 173 | self.compile, 174 | size=self.model.config.vision_config.image_size, 175 | strict=self.strict 176 | ) 177 | 178 | for doc, img in zip(self.pool.imap(compile, completions), images): 179 | if doc is not None: 180 | self.reward_model.update(doc, img) 181 | rewards.append(self.reward_model.compute()) 182 | else: 183 | rewards.append(-1.) 184 | self.reward_model.reset() 185 | return rewards 186 | 187 | 188 | def train( 189 | model, 190 | processor, 191 | dataset, 192 | output_dir: str, 193 | overwrite: bool = False, 194 | deepspeed: Optional[str] = None, 195 | num_compile_workers: int = 4, 196 | # training hyperparams 197 | strict: bool = False, 198 | freeze_encoder: bool = True, 199 | num_generations: int = 32, 200 | batch_size: int = 16, 201 | micro_batch_size: int = 1, 202 | num_train_steps: int = 500, 203 | learning_rate: float = 1e-5, 204 | gradient_checkpointing: bool = False, 205 | ): 206 | for _, param in model.model.vision_model.named_parameters(): 207 | param.requires_grad = not freeze_encoder 208 | model.enable_input_require_grads() 209 | 210 | training_args = GRPOConfig( 211 | per_device_train_batch_size=micro_batch_size, 212 | gradient_accumulation_steps=num_generations * batch_size // micro_batch_size // WORLD_SIZE, 213 | num_generations=num_generations, 214 | gradient_checkpointing=gradient_checkpointing, 215 | # https://github.com/huggingface/transformers/issues/32576 216 | gradient_checkpointing_kwargs={'use_reentrant': False}, 217 | lr_scheduler_type="cosine", 218 | weight_decay=0.01, 219 | epsilon=0.4, 220 | temperature=max(1., model.generation_config.temperature), 221 | top_p=model.generation_config.top_p, 222 | top_k=model.generation_config.top_k, 223 | max_steps=num_train_steps, 224 | logging_steps=num_train_steps//100, 225 | save_steps=num_train_steps//10, 226 | save_strategy="steps", 227 | save_total_limit=1, 228 | learning_rate=learning_rate, 229 | torch_compile=True, 230 | bf16=True, 231 | tf32=True, 232 | max_completion_length=processor.tokenizer.model_max_length-processor.image_seq_len, 233 | max_prompt_length=None, 234 | optim="adamw_torch" if deepspeed else "adamw_torch_fused", 235 | ddp_find_unused_parameters=False, 236 | remove_unused_columns=False, 237 | report_to="none", 238 | 
log_completions=True, 239 | num_completions_to_print=1, 240 | output_dir=output_dir, 241 | overwrite_output_dir=overwrite, 242 | deepspeed=deepspeed, 243 | ) 244 | 245 | trainer = GRPOTrainer( 246 | model=model, 247 | processing_class=processor, 248 | reward_funcs=[RewardFunc(model, processor, num_workers=num_compile_workers, strict=strict)], 249 | args=training_args, 250 | train_dataset=dataset, 251 | ) 252 | trainer.generation_config.bad_words_ids = [[model.config.image_token_id]] 253 | # trainer.generation_config.begin_suppress_tokens = [model.config.text_config.eos_token_id] 254 | trainer.train(resume_from_checkpoint=None if overwrite else get_last_checkpoint(output_dir)) 255 | 256 | if trainer.is_deepspeed_enabled: 257 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading 258 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint # type: ignore 259 | last_checkpoint = get_last_checkpoint(output_dir) 260 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) 261 | 262 | trainer.save_model(output_dir) 263 | 264 | 265 | def parse_args(): 266 | argument_parser = ArgumentParser( 267 | description="Post-train DeTikZify with GRPO." 268 | ) 269 | argument_parser.add_argument("--base_model", 270 | required=True, 271 | help="The model checkpoint for weights initialization." 272 | ) 273 | argument_parser.add_argument("--datikz", 274 | default="nllg/datikz-v3", 275 | help="Path or name of the DaTikZ dataset.", 276 | ) 277 | argument_parser.add_argument("--sketch_ratio", 278 | default=.5, 279 | type=float, 280 | help="ratio of synthetic sketches generated through UltraSketch or image transforms", 281 | ) 282 | argument_parser.add_argument("--output", 283 | required=True, 284 | dest="output_dir", 285 | help="directory where to write the model files", 286 | ) 287 | argument_parser.add_argument("--num_compile_workers", 288 | default=4, 289 | type=int, 290 | help="number of threads to compile TikZ code with", 291 | ) 292 | argument_parser.add_argument("--deepspeed", 293 | help="path to a DeepSpeed json config file", 294 | ) 295 | argument_parser.add_argument("--gradient_checkpointing", 296 | action="store_true", 297 | help="use gradient checkpointing", 298 | ) 299 | argument_parser.add_argument("--strict", 300 | action="store_true", 301 | help="treat recoverable compilation errors as fatal errors", 302 | ) 303 | argument_parser.add_argument("--batch_size", 304 | default=16, 305 | type=int, 306 | help="global batch size for training", 307 | ) 308 | argument_parser.add_argument("--num_train_steps", 309 | default=500, 310 | type=int, 311 | help="number of training steps to run GRPO for", 312 | ) 313 | return vars(argument_parser.parse_args()) 314 | 315 | 316 | if __name__ == "__main__": 317 | set_verbosity_info() 318 | enable_explicit_format() 319 | set_start_method('forkserver') # https://github.com/pytorch/pytorch/issues/17199#issuecomment-465313245 320 | init_process_group("nccl", timeout=timedelta(days=3)) 321 | set_seed(0) 322 | 323 | args = parse_args() 324 | model, processor = load( 325 | model_name_or_path=args.pop("base_model"), 326 | torch_dtype="bfloat16", 327 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, 328 | ) 329 | 330 | with Accelerator().main_process_first(): 331 | dataset = TrainDataset( 332 | processor=processor, 333 | datikz_name=args.pop('datikz'), 334 | sketch_ratio=args.pop("sketch_ratio"), 335 | size=model.config.vision_config.image_size, 336 | 
).sample(args['batch_size'] * args['num_train_steps']) 337 | 338 | train(model=model, processor=processor, dataset=dataset, **args) 339 | -------------------------------------------------------------------------------- /detikzify/model/v1/modeling_detikzify.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava.py. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from pickle import UnpicklingError 18 | from typing import List, Optional, Tuple, Union 19 | 20 | from numpy import clip 21 | from safetensors.torch import load_file 22 | from timm import create_model as create_vision_model 23 | import torch 24 | import torch.nn as nn 25 | from torch.nn import CrossEntropyLoss 26 | import torch.nn.functional as F 27 | from transformers import ( 28 | BatchEncoding, 29 | LlamaConfig, 30 | LlamaForCausalLM, 31 | LlamaModel, 32 | PretrainedConfig, 33 | PreTrainedModel, 34 | ) 35 | from transformers.modeling_outputs import ( 36 | BaseModelOutputWithPoolingAndNoAttention, 37 | BaseModelOutputWithPast, 38 | CausalLMOutputWithPast, 39 | ) 40 | from transformers.utils import logging 41 | 42 | from .configuration_detikzify import DetikzifyConfig 43 | from .processing_detikzify import DetikzifyImageProcessor 44 | 45 | 46 | logger = logging.get_logger("transformers") 47 | 48 | 49 | class DetikzifyVisionModel(PreTrainedModel): 50 | _no_split_modules = ["VisionTransformer"] 51 | def __init__(self, model, **kwargs) -> None: 52 | super().__init__(PretrainedConfig.from_dict(model.pretrained_cfg), **kwargs) 53 | # HACK: wrap in list so that vision model does not count as a parameter 54 | self.model = [model] 55 | 56 | def get_input_embeddings(self) -> torch.nn.Module: 57 | return self.model[0].patch_embed 58 | 59 | def to_input_dtype(self, pixel_values: torch.Tensor): 60 | target_dtype = self.get_input_embeddings().proj.weight.dtype 61 | return pixel_values.to(dtype=target_dtype) 62 | 63 | def forward(self, pixel_values: torch.Tensor): 64 | last_hidden_state = self.model[0].forward_features(self.to_input_dtype(pixel_values)) 65 | pooler_output = self.model[0].forward_head(last_hidden_state) 66 | return BaseModelOutputWithPoolingAndNoAttention( 67 | last_hidden_state=last_hidden_state, 68 | pooler_output=pooler_output 69 | ) 70 | 71 | def get_intermediate_layers(self, pixel_values: torch.Tensor, *args, **kwargs): 72 | return self.model[0].get_intermediate_layers(self.to_input_dtype(pixel_values), *args, **kwargs) 73 | 74 | 75 | class DetikzifyModel(LlamaModel): 76 | config_class = DetikzifyConfig 77 | 78 | def __init__(self, config: LlamaConfig): 79 | super(DetikzifyModel, self).__init__(config) 80 | 81 | if getattr(config, "use_mm_proj"): 82 | self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) 83 | self.vision_model = 
DetikzifyVisionModel(create_vision_model(config.vision_tower)) 84 | 85 | def initialize_vision_modules( 86 | self, 87 | vision_tower, 88 | patch_token_id, 89 | concat_patches=3, 90 | feature_layer=-1, 91 | modality_projector=None, 92 | **kwargs 93 | ): 94 | vision_model = create_vision_model(vision_tower, pretrained=True, **kwargs) 95 | self.vision_model = DetikzifyVisionModel(vision_model.to(self.device, self.dtype).eval().requires_grad_(False)) 96 | 97 | processor = DetikzifyImageProcessor.from_pretrained(vision_tower) 98 | 99 | self.config.use_mm_proj = True 100 | self.config.vision_tower = vision_tower 101 | self.config.mm_hidden_size = vision_model.embed_dim * concat_patches 102 | self.config.patch_token_id = patch_token_id 103 | self.config.concat_patches = concat_patches 104 | self.config.feature_layer = int(clip(feature_layer, -(depth:=len(vision_model.blocks)), depth-1) % depth) 105 | self.config.vision_config = processor.to_dict() # type: ignore 106 | self.config.num_patches = vision_model.patch_embed.num_patches // concat_patches 107 | 108 | if not hasattr(self, 'mm_projector'): 109 | self.mm_projector = nn.Linear( 110 | self.config.mm_hidden_size, 111 | self.config.hidden_size, 112 | dtype=self.dtype, 113 | device=self.device 114 | ) 115 | 116 | if modality_projector is not None: 117 | try: # first try to load as pickle 118 | mm_projector_weights = torch.load(modality_projector, map_location=self.device) 119 | except UnpicklingError: # and if that fails we try safetensors 120 | mm_projector_weights = load_file(modality_projector, device=str(self.device)) 121 | self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) 122 | 123 | return processor 124 | 125 | # https://stackoverflow.com/a/57208704 126 | def _apply(self, fn): 127 | super()._apply(fn) 128 | if hasattr(self, "vision_model"): 129 | self.set_vision_model = self.vision_model._apply(fn) 130 | return self 131 | 132 | def get_vision_features(self, pixel_values): 133 | concat, n_patch, layer = self.config.concat_patches, self.config.num_patches, self.config.feature_layer 134 | feats = self.vision_model.get_intermediate_layers(pixel_values, n=[layer], norm=True)[0] 135 | # in case the number of feature vectors is not divisible by the number 136 | # of patches we want to concatenate, we remove the first feature(s) 137 | return feats[:, -n_patch * concat:].reshape(-1, n_patch, feats.shape[-1] * concat) 138 | 139 | def is_tensor(self, thing): 140 | if isinstance(thing, (BatchEncoding, dict)): 141 | return all(isinstance(v, torch.Tensor) for v in thing.values()) 142 | return isinstance(thing, torch.Tensor) 143 | 144 | def forward( 145 | self, 146 | input_ids: torch.LongTensor = None, 147 | attention_mask: Optional[torch.Tensor] = None, 148 | past_key_values: Optional[List[torch.FloatTensor]] = None, 149 | inputs_embeds: Optional[torch.FloatTensor] = None, 150 | use_cache: Optional[bool] = None, 151 | output_attentions: Optional[bool] = None, 152 | output_hidden_states: Optional[bool] = None, 153 | pixel_values: Optional[torch.FloatTensor] = None, 154 | return_dict: Optional[bool] = None, 155 | ) -> Union[Tuple, BaseModelOutputWithPast]: 156 | 157 | if inputs_embeds is None: 158 | inputs_embeds = self.embed_tokens(input_ids) 159 | 160 | if hasattr(self, "vision_model") and (input_ids.shape[1] != 1 or self.training) and pixel_values is not None: 161 | with torch.no_grad(): 162 | image_features = self.get_vision_features(pixel_values) 163 | image_features = self.mm_projector(image_features) 
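            # A zero "dummy" feature is also pushed through the modality projector and later
            # added to text-only samples after multiplying by 0 (see below), so the projector
            # stays in the autograd graph even for batch elements without image patches
            # (presumably to avoid unused-parameter issues in distributed training).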
164 |             dummy_image_features = torch.zeros(len(image_features[0]), self.config.mm_hidden_size, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
165 |             dummy_image_features = self.mm_projector(dummy_image_features)
166 | 
167 |             new_input_embeds = []
168 |             cur_image_idx = 0
169 |             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
170 |                 if (cur_input_ids == self.config.patch_token_id).sum() == 0:
171 |                     # multimodal LLM, but the current sample is not multimodal
172 |                     cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
173 |                     new_input_embeds.append(cur_input_embeds)
174 |                     cur_image_idx += 1
175 |                     continue
176 | 
177 |                 cur_image_features = image_features[cur_image_idx].to(cur_input_embeds.device)
178 |                 num_patches = cur_image_features.shape[0]
179 |                 if (cur_input_ids == self.config.patch_token_id).sum() != num_patches:
180 |                     raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
181 |                 masked_indices = torch.where(cur_input_ids == self.config.patch_token_id)[0]
182 |                 mask_index_start = masked_indices[0]
183 |                 if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
184 |                     raise ValueError("The image patch tokens should be consecutive.")
185 |                 cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
186 |                 new_input_embeds.append(cur_new_input_embeds)
187 |                 cur_image_idx += 1
188 | 
189 |             inputs_embeds = torch.stack(new_input_embeds, dim=0)
190 | 
191 |         return super(DetikzifyModel, self).forward(
192 |             input_ids=None,
193 |             attention_mask=attention_mask,
194 |             past_key_values=past_key_values,
195 |             inputs_embeds=inputs_embeds,
196 |             use_cache=use_cache,
197 |             output_attentions=output_attentions,
198 |             output_hidden_states=output_hidden_states,
199 |             return_dict=return_dict
200 |         )
201 | 
202 | 
203 | class DetikzifyForCausalLM(LlamaForCausalLM):
204 |     config_class = DetikzifyConfig
205 | 
206 |     def __init__(self, config):
207 |         super(LlamaForCausalLM, self).__init__(config)
208 |         self.model = DetikzifyModel(config)
209 | 
210 |         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
211 | 
212 |         # Initialize weights and apply final processing
213 |         self.post_init()
214 | 
215 |     def get_model(self):
216 |         return self.model
217 | 
218 |     def forward(
219 |         self,
220 |         input_ids: torch.LongTensor = None,
221 |         attention_mask: Optional[torch.Tensor] = None,
222 |         past_key_values: Optional[List[torch.FloatTensor]] = None,
223 |         inputs_embeds: Optional[torch.FloatTensor] = None,
224 |         labels: Optional[torch.LongTensor] = None,
225 |         use_cache: Optional[bool] = None,
226 |         output_attentions: Optional[bool] = None,
227 |         output_hidden_states: Optional[bool] = None,
228 |         pixel_values: Optional[torch.FloatTensor] = None,
229 |         return_dict: Optional[bool] = None,
230 |     ) -> Union[Tuple, CausalLMOutputWithPast]:
231 |         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
232 |         output_hidden_states = (
233 |             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
234 |         )
235 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
236 | 
237 |         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
238 |         outputs = self.model(
239 |             input_ids=input_ids,
240 |             attention_mask=attention_mask,
241 | past_key_values=past_key_values, 242 | inputs_embeds=inputs_embeds, 243 | use_cache=use_cache, 244 | output_attentions=output_attentions, 245 | output_hidden_states=output_hidden_states, 246 | return_dict=return_dict, 247 | pixel_values=pixel_values 248 | ) 249 | 250 | hidden_states = outputs[0] 251 | if self.config.pretraining_tp > 1: 252 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 253 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 254 | logits = torch.cat(logits, dim=-1) 255 | else: 256 | logits = self.lm_head(hidden_states) 257 | logits = logits.float() 258 | 259 | 260 | loss = None 261 | if labels is not None: 262 | # Shift so that tokens < n predict n 263 | shift_logits = logits[..., :-1, :].contiguous() 264 | shift_labels = labels[..., 1:].contiguous() 265 | # Flatten the tokens 266 | loss_fct = CrossEntropyLoss() 267 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 268 | shift_labels = shift_labels.view(-1) 269 | # Enable model/pipeline parallelism 270 | shift_labels = shift_labels.to(shift_logits.device) 271 | loss = loss_fct(shift_logits, shift_labels) 272 | 273 | if not return_dict: 274 | output = (logits,) + outputs[1:] 275 | return (loss,) + output if loss is not None else output 276 | 277 | return CausalLMOutputWithPast( 278 | loss=loss, 279 | logits=logits, 280 | past_key_values=outputs.past_key_values, 281 | hidden_states=outputs.hidden_states, 282 | attentions=outputs.attentions, 283 | ) 284 | 285 | def prepare_inputs_for_generation( 286 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 287 | ): 288 | if past_key_values: 289 | input_ids = input_ids[:, -1:] 290 | 291 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 292 | if inputs_embeds is not None and past_key_values is None: 293 | model_inputs = {"inputs_embeds": inputs_embeds} 294 | else: 295 | model_inputs = {"input_ids": input_ids} 296 | 297 | model_inputs.update( 298 | { 299 | "past_key_values": past_key_values, 300 | "use_cache": kwargs.get("use_cache"), 301 | "attention_mask": attention_mask, 302 | "pixel_values": kwargs.get("pixel_values", None), 303 | } 304 | ) 305 | return model_inputs 306 | --------------------------------------------------------------------------------