.
79 |     """
80 |     # prepare the sketches (distribute load among all workers)
81 |     worker_sketches, all_sketches = list(), WORLD_SIZE * [None]
82 |     for i in torch.arange(len(dataset['image'])).tensor_split(WORLD_SIZE)[RANK]:
83 |         # randomize in which epochs how many images should be sketchified
84 |         num_sketches = choice([floor(product:=ratio*num_epochs), ceil(product)])
85 |         sketch_epochs = sample(range(num_epochs), k=num_sketches)
86 |         # generate the sketches
87 |         sketches = [sketchifier(dataset['image'][i.item()]) for _ in range(num_sketches)]
88 |         worker_sketches.append([sketches.pop() if epoch in sketch_epochs else None for epoch in range(num_epochs)])
89 |
90 |     torch.distributed.all_gather_object(all_sketches, worker_sketches) # type: ignore
91 |     dataset['sketches'] = list(chain.from_iterable(all_sketches)) # type: ignore
92 |     return dataset
93 |
94 | def parse_args():
95 |     argument_parser = ArgumentParser(
96 |         description="Sketchify an existing DaTikZ dataset."
97 |     )
98 |     argument_parser.add_argument(
99 |         "--path",
100 |         default="nllg/datikz-v3",
101 |         help="Path or name of the DaTikZ dataset.",
102 |     )
103 |     return argument_parser.parse_args()
104 |
105 | if __name__ == "__main__":
106 |     args = parse_args()
107 |     datikz = load_dataset(args.path)
108 |     torch.distributed.init_process_group()
109 |
110 |     train = datikz['train'].map(
111 |         function=sketchify,
112 |         batched=True,
113 |         batch_size=WORLD_SIZE * 1000,
114 |         desc="Sketchify (train)",
115 |         fn_kwargs=dict(
116 |             num_epochs=5,
117 |             ratio=0.5,
118 |             sketchifier=(sketchifier:=Sketchifier())
119 |         )
120 |     )
121 |
122 |     # test split is small so keep it simple and do it only on the main process
123 |     if RANK == 0:
124 |         test = datikz['test'].map(
125 |             function=lambda ex: ex | {"sketch": sketchifier(ex['image'])},
126 |             desc="Sketchify (test)"
127 |         )
128 |
129 |         train.to_parquet("datikz-train-sketches.parquet", compression="GZIP")
130 |         test.to_parquet("datikz-test-sketches.parquet", compression="GZIP")
131 |
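For reference, a minimal sketch of the per-image bookkeeping the `sketchify` function above produces (assuming `num_epochs=5` and `ratio=0.5` as in the `fn_kwargs`): each entry of the resulting `sketches` column is a list of length `num_epochs`, with `None` for epochs in which the original image should be used.

```python
from math import ceil, floor
from random import choice, sample

num_epochs, ratio = 5, 0.5

# pick either floor or ceil of ratio * num_epochs sketches for this image
num_sketches = choice([floor(product := ratio * num_epochs), ceil(product)])
sketch_epochs = sample(range(num_epochs), k=num_sketches)

# placeholder strings stand in for actual sketch images
sketches = [f"sketch_{i}" for i in range(num_sketches)]
entry = [sketches.pop() if epoch in sketch_epochs else None for epoch in range(num_epochs)]
print(entry)  # e.g. [None, 'sketch_1', None, 'sketch_0', None]
```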
--------------------------------------------------------------------------------
/detikzify/model/processing_detikzify.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from
16 | # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be
17 |
18 | from typing import List, Optional, Union, Unpack
19 |
20 | from transformers.feature_extraction_utils import BatchFeature
21 | from transformers.image_utils import ImageInput, make_list_of_images
22 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
23 | from transformers.tokenization_utils_base import (
24 | BatchEncoding,
25 | PreTokenizedInput,
26 | TextInput,
27 | )
28 | from transformers.utils import logging
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 |
33 | class DetikzifyProcessorKwargs(ProcessingKwargs, total=False):
34 | _defaults = {
35 | "text_kwargs": {
36 | "add_special_tokens": False,
37 | "padding": False,
38 | },
39 | }
40 |
41 |
42 | class DetikzifyProcessor(ProcessorMixin):
43 | attributes = ["image_processor", "tokenizer"]
44 | image_processor_class = "AutoImageProcessor"
45 | tokenizer_class = "AutoTokenizer"
46 |
47 | def __init__(
48 | self,
49 | image_processor,
50 | tokenizer=None,
51 | image_seq_len: int = 300,
52 | image_token: str = "<|reserved_special_token_2|>",
53 | model_expects_text: bool = False,
54 | **kwargs,
55 | ):
56 | if image_processor is None:
57 | raise ValueError("You need to specify an `image_processor`.")
58 | if tokenizer is None:
59 | raise ValueError("You need to specify a `tokenizer`.")
60 | if image_token not in tokenizer.vocab:
61 | raise ValueError(f"{image_token} needs to be added to the `tokenizer` vocabulary.")
62 |
63 | self.image_token = image_token
64 | self.image_seq_len = image_seq_len
65 | self.model_expects_text = model_expects_text
66 |
67 | super().__init__(image_processor, tokenizer, **kwargs)
68 |
69 | def __call__(
70 | self,
71 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
72 | images: ImageInput = None,
73 | image_seq_len: Optional[int] = None,
74 | add_bos_token: bool = None,
75 | add_eos_token: bool = None,
76 | **kwargs: Unpack[DetikzifyProcessorKwargs],
77 | ) -> BatchEncoding:
78 | output_kwargs = self._merge_kwargs(
79 | DetikzifyProcessorKwargs,
80 | tokenizer_init_kwargs=self.tokenizer.init_kwargs,
81 | **kwargs,
82 | )
83 | # Temporary fix for "padding_side" in init_kwargs
84 | output_kwargs["text_kwargs"].pop("padding_side", None)
85 |
86 | if images is None:
87 | raise ValueError("`images` are expected as arguments to a `DetikzifyProcessor` instance.")
88 | else:
89 | if isinstance(images, list) and all(isinstance(img, list) and len(img) == 1 for img in images):
90 | # compatibility with trl
91 | images = [img[0] for img in images]
92 | images = make_list_of_images(images)
93 | if text is None:
94 | text = len(images) * [""]
95 | elif isinstance(text, str):
96 | text = [text]
97 | if len(images) != len(text):
98 | raise ValueError(
99 | f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
100 | )
101 |
102 | prompt_strings = []
103 | for prompt in text:
104 | assert self.image_token not in prompt, "Image tokens are added by the processor!"
105 | if add_bos_token:
106 | prompt += self.tokenizer.bos_token
107 | if add_eos_token:
108 | prompt += self.tokenizer.eos_token
109 | image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
110 | prompt_strings.append((self.image_token * image_seq_len) + prompt)
111 |
112 | image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
113 | text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
114 |
115 | return BatchFeature(data={**image_inputs, **text_inputs})
116 |
117 | def batch_decode(self, *args, **kwargs):
118 | return self.tokenizer.batch_decode(*args, **kwargs)
119 |
120 | def decode(self, *args, **kwargs):
121 | return self.tokenizer.decode(*args, **kwargs)
122 |
123 | @property
124 | def model_input_names(self):
125 | tokenizer_input_names = self.tokenizer.model_input_names
126 | image_processor_input_names = self.image_processor.model_input_names
127 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
128 |
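A minimal usage sketch of `DetikzifyProcessor` (assumptions: the checkpoint name is taken from `webui/strings.py`, the loader from `examples/eval.py`, and it is assumed that this loader returns an instance of this processor class for that checkpoint):

```python
from PIL import Image
from detikzify.model import load  # same loader as used in examples/eval.py

# assumed checkpoint; listed in detikzify/webui/strings.py
model, processor = load(model_name_or_path="nllg/detikzify-v2-8b")

image = Image.new("RGB", (384, 384), "white")  # placeholder figure
batch = processor(images=image, text="", return_tensors="pt")

# every prompt is prefixed with image_seq_len copies of image_token
print(batch["input_ids"].shape)  # (1, image_seq_len + number of prompt tokens)
```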
--------------------------------------------------------------------------------
/detikzify/train/train.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import os
3 | from random import random
4 | from typing import Dict, List
5 |
6 | from PIL import Image
7 | import torch
8 | from torch.utils.data import Dataset
9 | from transformers import Trainer, TrainerCallback, TrainingArguments
10 | from transformers.trainer_utils import get_last_checkpoint
11 | from transformers.utils import logging
12 |
13 | from ..util import SketchAugment, SplitEpochSaveCallback
14 | from .pretrain import tokenize
15 |
16 | logger = logging.get_logger("transformers")
17 |
18 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
19 | RANK = int(os.environ.get("RANK", 0))
20 |
21 | class ImageSketchDataset(Dataset, TrainerCallback):
22 | """
23 |     Dataset which samples sketches instead of images when a sketch exists
24 | for the current epoch.
25 | """
26 | def __init__(self, dataset, processor, ds_sketch_ratio=.5):
27 | super().__init__()
28 | self.processor = processor
29 | self.dataset = dataset.with_transform(self.tokenize)
30 | self.ds_sketch_ratio = ds_sketch_ratio
31 | self.sketchify = SketchAugment()
32 | self.cur_epoch = 0
33 |
34 | def __len__(self):
35 | return len(self.dataset)
36 |
37 | def tokenize(self, batch):
38 | for idx, sketches in enumerate(batch['sketches']):
39 | if (sketch:=sketches[self.cur_epoch]):
40 | if random() >= self.ds_sketch_ratio:
41 | batch['image'][idx] = Image.open(BytesIO(sketch['bytes'])).convert("RGB")
42 | else:
43 | batch['image'][idx] = self.sketchify(batch['image'][idx])
44 |
45 | return tokenize(
46 | batch=batch,
47 | processor=self.processor,
48 | return_tensors="pt",
49 | truncation=False,
50 | padding=True
51 | )
52 |
53 | def filter(self, *args, **kwargs):
54 | self.dataset = self.dataset.filter(*args, **kwargs)
55 |
56 | def __getitem__(self, index) -> Dict[str, torch.Tensor]:
57 | return self.dataset[index]
58 |
59 | def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]:
60 | return self.dataset[*indices]
61 |
62 | def on_epoch_end(self, *args, **kwargs):
63 | self.cur_epoch += 1
64 |
65 | def train(
66 | output_dir: str,
67 | model,
68 | processor,
69 | dataset,
70 | overwrite=False,
71 | deepspeed=None,
72 | # training hyperparams
73 | batch_size: int = 128,
74 | micro_batch_size: int = 1,
75 | num_epochs: int = 5,
76 | learning_rate: float = 5e-5,
77 | sketch_ratio=.5,
78 | gradient_checkpointing: bool = False,
79 | ):
80 | assert num_epochs <= len(dataset[0]['sketches'])
81 | gradient_accumulation_steps = batch_size // micro_batch_size
82 | if WORLD_SIZE != 1:
83 | gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
84 |
85 | dataset = ImageSketchDataset(dataset, processor, ds_sketch_ratio=sketch_ratio)
86 | logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}")
87 | eos_token_id, model_max_length = processor.tokenizer.eos_token_id, processor.tokenizer.model_max_length
88 | dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length)
89 | logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}")
90 |
91 | last_checkpoint = None
92 | if os.path.isdir(output_dir) and not overwrite:
93 | last_checkpoint = get_last_checkpoint(output_dir)
94 | if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
95 | raise ValueError(
96 | f"Output directory ({output_dir}) already exists and is not empty. "
97 | "Use `overwrite` to overcome."
98 | )
99 | elif last_checkpoint is not None:
100 | logger.info(
101 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
102 | "the `output_dir` or add `overwrite` to train from scratch."
103 | )
104 |
105 | trainer = Trainer(
106 | model=model,
107 | train_dataset=dataset,
108 | args=TrainingArguments(
109 | per_device_train_batch_size=micro_batch_size,
110 | gradient_accumulation_steps=gradient_accumulation_steps,
111 | gradient_checkpointing=gradient_checkpointing,
112 | # https://github.com/huggingface/transformers/issues/32576
113 | gradient_checkpointing_kwargs={'use_reentrant':False},
114 | dataloader_num_workers=WORLD_SIZE,
115 | warmup_ratio=0.03,
116 | weight_decay=0,
117 | num_train_epochs=num_epochs,
118 | learning_rate=learning_rate,
119 | torch_compile=True,
120 | bf16=True,
121 | tf32=True,
122 | logging_steps=10,
123 | lr_scheduler_type="cosine",
124 | optim="adamw_torch" if deepspeed else "adamw_torch_fused",
125 | ddp_find_unused_parameters=False,
126 | remove_unused_columns=False,
127 | save_strategy="epoch",
128 | report_to="none",
129 | save_total_limit=1,
130 | output_dir=output_dir,
131 | deepspeed=deepspeed,
132 | ),
133 | callbacks=[SplitEpochSaveCallback(step_size=0.25)],
134 | data_collator=lambda batch: batch
135 | )
136 |
137 | trainer.add_callback(trainer.train_dataset)
138 | trainer.train(resume_from_checkpoint=last_checkpoint)
139 |
140 | if trainer.is_deepspeed_enabled:
141 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
142 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
143 | last_checkpoint = get_last_checkpoint(output_dir)
144 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
145 |
146 | trainer.save_model(output_dir)
147 | trainer.save_state()
148 |
149 | return model, processor
150 |
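A small worked example of the gradient accumulation arithmetic at the top of `train` (the numbers are illustrative): with `batch_size=128`, `micro_batch_size=1`, and 8 workers, each process accumulates 16 micro-batches per optimizer step, so the global batch size stays at 128.

```python
batch_size, micro_batch_size, world_size = 128, 1, 8  # illustrative values

gradient_accumulation_steps = batch_size // micro_batch_size
if world_size != 1:
    gradient_accumulation_steps //= world_size

effective_batch_size = micro_batch_size * gradient_accumulation_steps * world_size
print(gradient_accumulation_steps, effective_batch_size)  # 16 128
```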
--------------------------------------------------------------------------------
/detikzify/mcts/README.md:
--------------------------------------------------------------------------------
1 | A Python3 library for running a [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), either traditionally by drilling down to end game states or with expert policies as might be provided by a neural network.
2 |
3 | Adapted from **Version:** 1.3.1 of [ImparaAI/monte-carlo-tree-search](https://github.com/ImparaAI/monte-carlo-tree-search).
4 |
5 | # Monte Carlo tree search basics
6 |
7 | The Monte Carlo tree search (MCTS) algorithm can help with making a decision from a number of options. It avoids exploring every possible option by randomly sampling a small number of pathways and picking the move with the highest probability of victory. This is commonly applied to games like chess or go where it's useful to know what move should come next if you want to win the game.
8 |
9 | MCTS works by expanding the search tree to figure out which moves (or child/subsequent states) are likely to produce a positive result if chosen. While time is available, the algorithm continues to explore the tree, always slightly favoring the direction that has either proven to be fruitful or is less explored. When no time is left, the most explored direction is chosen.
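The trade-off between directions that have "proven to be fruitful" and those that are "less explored" is usually expressed as an upper-confidence-style selection score. The snippet below is a generic UCB1-style illustration of that idea, not necessarily the exact formula this library uses; `win_value`, `visits`, and `parent_visits` are hypothetical per-node statistics.

```python
from math import log, sqrt

def selection_score(win_value, visits, parent_visits, discovery_factor=0.35):
    # assumes the child has been visited at least once
    exploitation = win_value / visits                                    # demonstrated value
    exploration = discovery_factor * sqrt(log(parent_visits) / visits)   # prefer rarely visited children
    return exploitation + exploration
```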
10 |
11 | The search tree expansion can be done in two different ways:
12 |
13 | - **Traditional**: Perform at least one random rollout to a game's end state (e.g. win, loss, tie) for each move under evaluation so the algorithm can make a choice.
14 | - **Expert policy (i.e. neural network)**: Instead of expensively rolling all the way out to a game's end state, ask an expert (a neural network, for example) which move is most likely to produce a positive outcome.
15 |
16 | For a deeper dive into the topic, check out [this article](http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/).
17 |
18 | # This library
19 |
20 | As the user of this library, you only have to provide:
21 |
22 | - A function that finds the direct children of each search tree node (called the **`child_finder`**)
23 | - A function for evaluating nodes for end state outcomes (called the **`node_evaluator`**)
24 |   - *(Not necessary with a neural network)*
25 |
26 | # Usage
27 |
28 | Create a new Monte Carlo tree:
29 |
30 | ```python
31 | from chess import Game
32 | from montecarlo.node import Node
33 | from montecarlo.montecarlo import MonteCarlo
34 |
35 | chess_game = Game()
36 | montecarlo = MonteCarlo(Node(chess_game))
37 | ```
38 |
39 | The root node describes your current game state. This state will be used by you later in the **`child_finder`** and the **`node_evaluator`**.
40 |
41 | For the sake of demonstration, we will assume you have a generic `Game` library that can tell you what moves are possible and allows you to perform those moves to change the game's state.
42 |
43 | ## Traditional Monte Carlo
44 |
45 | Add a **`child_finder`** and a **`node_evaluator`**:
46 |
47 | ```python
48 | def child_finder(node, montecarlo):
49 | for move in node.state.get_possible_moves():
50 | child = Node(deepcopy(node.state)) #or however you want to construct the child's state
51 | child.state.move(move) #or however your library works
52 | node.add_child(child)
53 |
54 | def node_evaluator(node, montecarlo):
55 | if node.state.won():
56 | return 1
57 | elif node.state.lost():
58 | return -1
59 |
60 | montecarlo.child_finder = child_finder
61 | montecarlo.node_evaluator = node_evaluator
62 | ```
63 |
64 | The **`child_finder`** should add any child nodes to the parent node passed into the function, if there are any. If there are none, the parent should be in an end state, so the **`node_evaluator`** should return a value between `-1` and `1`.
65 |
66 | ## Expert policy (AI)
67 |
68 | If you have an expert policy that you can apply to the children as they're being generated, the library will recognize that it doesn't need to make the costly drill down to an end state. If your neural net produces both an expert policy value for the children and a win value for the parent node, you can skip declaring the `node_evaluator` altogether.
69 |
70 | ```python
71 | def child_finder(node, montecarlo):
72 | win_value, expert_policy_values = neural_network.predict(node.state)
73 |
74 | for move in node.state.get_possible_moves():
75 | child = Node(deepcopy(node.state))
76 | child.state.move(move)
77 | child.player_number = child.state.whose_turn()
78 | child.policy_value = get_child_policy_value(child, expert_policy_values) #should return a probability value between 0 and 1
79 | node.add_child(child)
80 |
81 | node.update_win_value(win_value)
82 |
83 | montecarlo.child_finder = child_finder
84 | ```
85 |
86 | ## Simulate and make a choice
87 |
88 | Run the simulations:
89 |
90 | ```python
91 | montecarlo.simulate(50) #number of expansions to run. higher is typically more accurate at the cost of processing time
92 | ```
93 |
94 | Once the simulations have run you can ask the instance to make a choice:
95 |
96 | ```python
97 | chosen_child_node = montecarlo.make_choice()
98 | chosen_child_node.state.do_something()
99 | ```
100 |
101 | After you've chosen a new root node, you can override it on the `montecarlo` instance and do more simulations from the new position in the tree.
102 |
103 | ```python
104 | montecarlo.root_node = montecarlo.make_choice()
105 | ```
106 |
107 | If you're training a neural network, you may want to make a more exploratory choice for the first N moves of a game:
108 |
109 | ```python
110 | montecarlo.root_node = montecarlo.make_exploratory_choice()
111 | ```
112 |
113 | This won't provide a purely random choice; rather, it will be random with a bias favoring the more explored pathways.
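One way to picture such a biased-but-random choice (a sketch only; the library's actual weighting may differ) is to sample children in proportion to their visit counts:

```python
import random

def exploratory_choice(children_visits):
    # more frequently visited children are proportionally more likely to be picked
    return random.choices(range(len(children_visits)), weights=children_visits, k=1)[0]

print(exploratory_choice([30, 15, 5]))  # index 0 is returned most often, but not always
```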
114 |
115 | ## Turn-based environments
116 |
117 | If you are modeling a turn-based environment (e.g. a two-player board game), set the `player_number` on each node so the selection process can invert child win values:
118 |
119 | ```python
120 | node = Node(state)
121 | node.player_number = 1
122 | ```
123 |
124 | It doesn't matter what this number is (you can use 1 and 2 or 5 and 6), only that it is consistent with other nodes.
125 |
126 | ## Tweaking the discovery factor
127 |
128 | When building a new child node, you can change the rate at which discovery is preferred:
129 |
130 | ```python
131 | node = Node(state)
132 | node.discovery_factor = 0.2 #0.35 by default, can be between 0 and 1
133 | ```
134 |
135 | The closer this number is to 1, the more discovery will be favored over demonstrated value in later simulations.
136 |
--------------------------------------------------------------------------------
/detikzify/evaluate/imagesim.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from math import tanh
3 | from typing import List, Literal, Optional
4 |
5 | from PIL import Image
6 | from ot.lp import emd2
7 | import torch
8 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported
9 | import torch.nn.functional as F
10 | from torchmetrics import Metric
11 | from torchmetrics.functional import pairwise_cosine_similarity
12 | from transformers import AutoImageProcessor, AutoModel, PreTrainedModel, ProcessorMixin
13 |
14 | from ..model.adapter import (
15 | AdapterProcessor,
16 | CrossAttentionAdapterMixin as AdapterMixin,
17 | has_adapter,
18 | )
19 | from ..util import cast, expand, infer_device, load, unwrap_processor
20 |
21 | class ImageSim(Metric):
22 | """Perceptual image similarity using visual encoders."""
23 |
24 | higher_is_better = True
25 |
26 | def __init__(
27 | self,
28 | model_name: str = "google/siglip-so400m-patch14-384",
29 | mode: Literal["emd", "cos", "cos_avg"] = "cos",
30 | preprocess: bool = True,
31 | device: str = infer_device(),
32 | dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16,
33 | **kwargs
34 | ):
35 | super().__init__(**kwargs)
36 | self.model_name = model_name
37 | self.preprocess = preprocess
38 | self.mode = mode
39 | self._device = device
40 | self.set_dtype(dtype)
41 |
42 | self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
43 | self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
44 |
45 | def __str__(self):
46 | return self.__class__.__name__ + f" ({self.mode.upper().replace('_', '-')})"
47 |
48 | @cached_property
49 | def model(self):
50 | # even if we instantiate with from_detikzify we still end up in this function
51 | if (model:=dict(self.named_children()).get("model")) is None:
52 | model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype)
53 | model = model.vision_model.to(self.device)
54 | return model
55 |
56 | @cached_property
57 | def processor(self):
58 | return AutoImageProcessor.from_pretrained(self.model_name)
59 |
60 | @classmethod
61 | def from_detikzify(cls, model: PreTrainedModel, processor: ProcessorMixin, mode=None, *args, **kwargs):
62 | derived_kwargs = dict(
63 | model_name = model.name_or_path,
64 | mode = getattr(model.config, "pooling_mode", "emd") if mode is None else mode,
65 | device = model.device,
66 | dtype = model.dtype,
67 | )
68 | imagesim = cls(*args, **(derived_kwargs | kwargs))
69 |
70 | if has_adapter(model):
71 | class AdapterVisionModel(type(model.model.vision_model), AdapterMixin):
72 | embedding_model=model.embedding_model
73 | adapter=model.adapter
74 |
75 | @classmethod
76 | def cast(cls, vision_model):
77 | adapter_vision_model = cast(cls, vision_model)
78 | adapter_vision_model.add_hooks()
79 | return adapter_vision_model
80 |
81 | imagesim.model = AdapterVisionModel.cast(model.model.vision_model)
82 | imagesim.processor = AdapterProcessor(
83 | processor=unwrap_processor(processor).image_processor,
84 | tokenizer=processor.tokenizer # type: ignore
85 | )
86 | else:
87 | imagesim.model = model.model.vision_model
88 | imagesim.processor = unwrap_processor(processor).image_processor
89 | return imagesim
90 |
91 | def get_vision_features(self, image: Optional[Image.Image | str] = None, text: Optional[str] = None):
92 | if image is not None:
93 | image = load(image)
94 | if self.preprocess:
95 | image = expand(image, max(image.size), do_trim=True)
96 |
97 | with torch.inference_mode():
98 | if text is not None:
99 | encoding = self.processor(text=text, images=image, return_tensors="pt").to(self.device, self.dtype)
100 | else:
101 | encoding = self.processor(images=image, return_tensors="pt").to(self.device, self.dtype)
102 | if self.mode == "cos":
103 | return self.model(**encoding).pooler_output.squeeze()
104 | elif self.mode == "cos_avg":
105 | return self.model(**encoding).last_hidden_state.squeeze().mean(dim=0)
106 | else:
107 | return self.model(**encoding).last_hidden_state.squeeze()
108 |
109 | def get_similarity(
110 | self,
111 | img1: Optional[Image.Image | str] = None,
112 | img2: Optional[Image.Image | str] = None,
113 | text1: Optional[str] = None,
114 | text2: Optional[str] = None,
115 | ):
116 | img1_feats = self.get_vision_features(img1, text1)
117 | img2_feats = self.get_vision_features(img2, text2)
118 |
119 | if img1_feats.is_mps: # mps backend does not support dtype double
120 | img1_feats, img2_feats = img1_feats.cpu(), img2_feats.cpu()
121 | if img1_feats.ndim > 1:
122 | dists = 1 - pairwise_cosine_similarity(img1_feats.double(), img2_feats.double()).cpu().numpy()
123 | return 2 * tanh(-emd2(M=dists, a=list(), b=list())) + 1 # type: ignore
124 | else:
125 | return F.cosine_similarity(img1_feats.double(), img2_feats.double(), dim=0).item()
126 |
127 | def update(
128 | self,
129 | img1: Optional[Image.Image | str | List[Image.Image | str]] = None,
130 | img2: Optional[Image.Image | str | List[Image.Image | str]] = None,
131 | text1: Optional[str | List[str]] = None,
132 | text2: Optional[str | List[str]] = None,
133 | ):
134 | inputs = dict()
135 | for key, value in dict(img1=img1, img2=img2, text1=text1, text2=text2).items():
136 | if value is not None:
137 | inputs[key] = value if isinstance(value, List) else [value]
138 |
139 | assert not ({"img1", "text1"}.isdisjoint(inputs.keys()) or {"img2", "text2"}.isdisjoint(inputs.keys()))
140 | assert len(set(map(len, inputs.values()))) == 1
141 |
142 | for inpt in zip(*inputs.values()):
143 | self.score += self.get_similarity(**dict(zip(inputs.keys(), inpt)))
144 | self.n_samples += 1
145 |
146 | def compute(self):
147 | return (self.score / self.n_samples).item()
148 |
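A minimal usage sketch of `ImageSim` (the file paths are placeholders; with the default settings this downloads the SigLIP encoder named in the constructor on first use):

```python
from detikzify.evaluate import ImageSim  # exported as used in examples/eval.py

metric = ImageSim(mode="cos")  # pooled cosine similarity
metric.update(img1="reference1.png", img2="prediction1.png")  # placeholder paths
metric.update(img1="reference2.png", img2="prediction2.png")
print(metric.compute())  # mean similarity over all updated pairs
```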
--------------------------------------------------------------------------------
/detikzify/infer/tikz.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from functools import cache, cached_property
3 | from io import BytesIO
4 | from os import environ
5 | from os.path import isfile, join
6 | from re import MULTILINE, escape, findall, search
7 | from subprocess import CalledProcessError, DEVNULL, TimeoutExpired
8 | from tempfile import NamedTemporaryFile, TemporaryDirectory
9 | from typing import Dict, Optional, Union
10 |
11 | from PIL import Image
12 | from pdf2image.pdf2image import convert_from_bytes
13 | from pdfCropMargins import crop
14 | import pymupdf
15 | from transformers.utils import logging
16 |
17 | from ..util import check_output, expand, redact as redact_text
18 |
19 | logger = logging.get_logger("transformers")
20 |
21 | class TikzDocument:
22 | """
23 |     Facilitate some operations with TikZ code. To compile the images, a full
24 |     TeX Live installation is assumed to be on the PATH. Cropping additionally
25 | requires Ghostscript, and rasterization needs poppler.
26 | """
27 | # engines to try, could also try: https://tex.stackexchange.com/a/495999
28 | engines = ["pdflatex", "lualatex", "xelatex"]
29 | Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""])
30 |
31 | def __init__(self, code: str, timeout: Optional[int] = 60):
32 | self.code = code
33 | self.timeout = timeout
34 | # https://stackoverflow.com/a/68550238
35 | self.compile = cache(self.compile)
36 |
37 | @property
38 | def status(self) -> int:
39 | return self.compile().status
40 |
41 | @property
42 | def pdf(self) -> Optional[pymupdf.Document]: # type: ignore
43 | return self.compile().pdf
44 |
45 | @property
46 | def log(self) -> str:
47 | return self.compile().log
48 |
49 | @property
50 | def compiled_with_errors(self) -> bool:
51 | return self.status != 0
52 |
53 | @property
54 | def errors(self, rootfile: Optional[str] = None) -> Dict[int, str]:
55 | """
56 | Returns a dict of (linenr, errormsg) pairs. linenr==0 is a special
57 | value reserved for errors that do not have a linenumber in rootfile.
58 | """
59 | if self.compiled_with_errors:
60 | if not rootfile and (match:=search(r"^\((.+)$", self.log, MULTILINE)):
61 | rootfile = match.group(1)
62 | else:
63 |                 raise ValueError("rootfile not found!")
64 |
65 | errors = dict()
66 | for file, line, error in findall(r'^(.+):(\d+):(.+)$', self.log, MULTILINE):
67 | if file == rootfile:
68 | errors[int(line)] = error.strip()
69 | else: # error occurred in other file
70 | errors[0] = error.strip()
71 |
72 | return errors or {0: "Fatal error occurred, no output PDF file produced!"}
73 | return dict()
74 |
75 | @cached_property
76 | def is_rasterizable(self) -> bool:
77 | """true if we have an image"""
78 | return self.rasterize() is not None
79 |
80 | @cached_property
81 | def has_content(self) -> bool:
82 | """true if we have an image that isn't empty"""
83 | return (img:=self.rasterize()) is not None and img.getcolors(1) is None
84 |
85 | @classmethod
86 | def set_engines(cls, engines: Union[str, list]):
87 | cls.engines = [engines] if isinstance(engines, str) else engines
88 |
89 | def compile(self) -> "Output":
90 | output = dict()
91 | with TemporaryDirectory() as tmpdirname:
92 | with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile:
93 | codelines = self.code.split("\n")
94 | # make sure we don't have page numbers in compiled pdf (for cropping)
95 | codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}"))
96 | tmpfile.write("\n".join(codelines).encode())
97 |
98 | try:
99 | # compile
100 | errorln, tmppdf, outpdf = -1, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf")
101 | open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile
102 |
103 | def try_save_last_page():
104 | try:
105 | doc = pymupdf.open(tmppdf)
106 | doc.select([len(doc)-1])
107 | doc.save(outpdf)
108 | except:
109 | pass
110 |
111 | for engine in self.engines:
112 | try:
113 | check_output(
114 | cwd=tmpdirname,
115 | timeout=self.timeout,
116 | stderr=DEVNULL,
117 | env=environ | dict(max_print_line="1000"), # improve formatting of log
118 | args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name]
119 | )
120 | except (CalledProcessError, TimeoutExpired) as proc:
121 | log = (getattr(proc, "output", b'') or b'').decode(errors="ignore")
122 | error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE)
123 | # only update status and log if first error occurs later than in previous engine
124 | if (linenr:=int(error.group(1)) if error else 0) > errorln:
125 | errorln = linenr
126 | output.update(status=getattr(proc, 'returncode', -1), log=log)
127 | try_save_last_page()
128 | else:
129 | output.update(status=0, log='')
130 | try_save_last_page()
131 | break
132 |
133 | # crop
134 | croppdf = f"{tmpfile.name}.crop"
135 | crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True)
136 | if isfile(croppdf):
137 | output['pdf'] = pymupdf.open(croppdf)
138 |
139 | except FileNotFoundError:
140 | logger.error("Missing dependencies: Did you install TeX Live?")
141 | except RuntimeError: # pdf error during cropping
142 | pass
143 |
144 | if output.get("status") == 0 and not output.get("pdf", None):
145 |             logger.warning("Could compile the document, but something seems to have gone wrong during cropping!")
146 |
147 | return self.Output(**output)
148 |
149 | def rasterize(self, size=420, expand_to_square=True, redact=False, **redact_kwargs) -> Optional[Image.Image]:
150 | if pdf:=self.pdf:
151 | if redact:
152 | pdf = redact_text(pdf, **redact_kwargs)
153 | image = convert_from_bytes(pdf.tobytes(), size=size, single_file=True)[0]
154 | if expand_to_square:
155 | return expand(image, size)
156 | return image
157 |
158 | def save(self, filename: str, *args, **kwargs):
159 | match filename.split(".")[-1]:
160 | case "tex": content = self.code.encode()
161 | case "pdf" if self.pdf: content = self.pdf.tobytes()
162 | case fmt if img := self.rasterize(*args, **kwargs):
163 | img.save(imgByteArr:=BytesIO(), format=fmt)
164 | content = imgByteArr.getvalue()
165 | case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!")
166 |
167 | with open(filename, "wb") as f:
168 | f.write(content)
169 |
--------------------------------------------------------------------------------
/detikzify/webui/strings.py:
--------------------------------------------------------------------------------
1 | from os.path import basename
2 |
3 | from transformers import is_timm_available
4 |
5 | BANNER = '''\
6 | DeTikZify: Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | '''
23 |
24 | MODELS = {
25 | basename(model): model
26 | for model in [
27 | "nllg/detikzify-v2.5-8b",
28 | "nllg/detikzify-v2-8b",
29 | ]
30 | }
31 |
32 | if is_timm_available():
33 | MODELS |= {
34 | basename(model).replace("detikzify", "detikzify-v1"): model
35 | for model in [
36 | "nllg/detikzify-ds-7b",
37 | "nllg/detikzify-cl-7b",
38 | "nllg/detikzify-ds-1.3b",
39 | "nllg/detikzify-tl-1.1b",
40 | ]
41 | }
42 |
43 | ALGORITHMS = {
44 | "mcts": "MCTS",
45 | "sampling": "Sampling"
46 | }
47 |
48 | # https://github.com/gradio-app/gradio/issues/3202#issuecomment-1741571240
49 | # https://github.com/gradio-app/gradio/issues/2666#issuecomment-1651127149
50 | # https://stackoverflow.com/a/64033350
51 | CSS = """
52 | .input-image {
53 | flex-grow: 1;
54 | }
55 | .output-code {
56 | flex-grow: 1;
57 | height: 0vh;
58 | min-height: 250px;
59 | scrollbar-width: thin !important;
60 | }
61 | .output-code .hide {
62 | display: none;
63 | }
64 | .output-code .cm-scroller {
65 | flex-grow: 1;
66 | }
67 | .output-code .cm-gutters {
68 | position: relative !important;
69 | }
70 | .output-image {
71 | flex-grow: 1;
72 | height: 0vh;
73 | min-height: 250px;
74 | overflow-y: auto !important;
75 | scrollbar-width: thin !important;
76 | }
77 | .output-image .image-container, .output-image .grid-container {
78 | width: 100%;
79 | height: 100%;
80 | }
81 | .output-image .thumbnail-item img {
82 | object-fit: contain;
83 | }
84 | .output-image .grid-wrap.fixed-height {
85 | max-height: 100% !important;
86 | }
87 | .outputs .tabs {
88 | display: flex;
89 | flex-direction: column;
90 | flex-grow: 1;
91 | }
92 | .outputs .tabitem[style="display: block;"] {
93 | flex-grow: 1;
94 | display: flex !important;
95 | }
96 | .outputs .gap {
97 | flex-grow: 1;
98 | }
99 | .outputs .form {
100 | flex-grow: 1 !important;
101 | }
102 | .outputs .form > :last-child{
103 | flex-grow: 1;
104 | }
105 | """
106 |
107 | # (Ab)use an invisible fake button with id preview-close to propagate the
108 | # actual press of the button that closes a preview
109 | # https://github.com/gradio-app/gradio/issues/6697
110 | GALLERY_DESELECT_HACK = """
111 |
131 | """
132 |
--------------------------------------------------------------------------------
/detikzify/train/adapter/train.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from datetime import timedelta
3 | import os
4 | from typing import Dict, List
5 |
6 | from accelerate import Accelerator, InitProcessGroupKwargs
7 | from datasets import Dataset
8 | import torch
9 | from torch.utils.data import Dataset
10 | from transformers import Trainer, TrainingArguments
11 | from transformers.trainer_utils import get_last_checkpoint
12 | from transformers.utils import logging
13 |
14 | from ...util import SplitEpochSaveCallback, unwrap_processor
15 |
16 | logger = logging.get_logger("transformers")
17 |
18 | IGNORE_INDEX = -100
19 | WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
20 | RANK = int(os.environ.get("RANK", 0))
21 |
22 | def tokenize(
23 | batch,
24 | processor,
25 | caption_condition=False,
26 | **kwargs
27 | ):
28 | unwrapped_processor = unwrap_processor(processor)
29 | image_token = unwrapped_processor.image_token
30 | image_token_id = unwrapped_processor.tokenizer.convert_tokens_to_ids(image_token)
31 | bos_token = unwrapped_processor.tokenizer.bos_token
32 |
33 | input_ids = processor(
34 | text=batch['caption'],
35 | images_kwargs=dict(
36 | text=[bos_token.join(text) for text in zip(batch['caption'], batch['code'])] if caption_condition else batch['code'],
37 | max_length=unwrapped_processor.tokenizer.model_max_length,
38 | pad_to_multiple_of=8,
39 | add_eos_token=True,
40 | truncation=False,
41 | padding=True
42 | ),
43 | text_kwargs=dict(
44 | padding=True,
45 | truncation=True,
46 | ),
47 | **kwargs
48 | )
49 | input_ids['labels'] = deepcopy(input_ids['input_ids'])
50 |
51 | if caption_condition:
52 |         # do not train on caption, image, and pad tokens
53 | for label_ids in input_ids['labels']:
54 | after_bos_token = False
55 | for idx, label_id in enumerate(label_ids):
56 | if not after_bos_token or label_id in {image_token_id, unwrapped_processor.tokenizer.pad_token_id}:
57 | if label_id == unwrapped_processor.tokenizer.bos_token_id:
58 | after_bos_token = True
59 | label_ids[idx] = IGNORE_INDEX
60 | elif label_id == unwrapped_processor.tokenizer.bos_token_id:
61 | after_bos_token = True
62 | else:
63 | # do not train on image and pad tokens
64 | for label_ids in input_ids['labels']:
65 | for idx, label_id in enumerate(label_ids):
66 | if label_id in {image_token_id, processor.tokenizer.pad_token_id}:
67 | label_ids[idx] = IGNORE_INDEX
68 |
69 | return input_ids
70 |
71 | class AdapterDataset(Dataset):
72 | def __init__(self, dataset, processor, caption_condition=False):
73 | super().__init__()
74 | self.processor = processor
75 | self.dataset = dataset.with_transform(self.tokenize)
76 | self.caption_condition = caption_condition
77 |
78 | def __len__(self):
79 | return len(self.dataset)
80 |
81 | def tokenize(self, batch):
82 | return tokenize(
83 | batch=batch,
84 | processor=self.processor,
85 | caption_condition=self.caption_condition,
86 | return_tensors="pt",
87 | )
88 |
89 | def filter(self, *args, **kwargs):
90 | self.dataset = self.dataset.filter(*args, **kwargs)
91 |
92 | def __getitem__(self, index) -> Dict[str, torch.Tensor]:
93 | return self.dataset[index]
94 |
95 | def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]:
96 | return self.dataset[*indices]
97 |
98 | def train(
99 | output_dir: str,
100 | model,
101 | processor,
102 | dataset,
103 | overwrite=False,
104 | deepspeed=None,
105 | # training hyperparams
106 | caption_condition: bool = False,
107 | batch_size: int = 128,
108 | micro_batch_size: int = 1,
109 | num_epochs: int = 5,
110 | learning_rate: float = 5e-5,
111 | gradient_checkpointing: bool = False,
112 | ):
113 | gradient_accumulation_steps = batch_size // micro_batch_size
114 | if WORLD_SIZE != 1:
115 | gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
116 |
117 | for _, param in model.model.vision_model.named_parameters():
118 | param.requires_grad = False
119 | for _, param in model.adapter.named_parameters():
120 | param.requires_grad = False
121 | for _, param in model.embedding_model.named_parameters():
122 | param.requires_grad = False
123 | model.enable_input_require_grads()
124 | model.embedding_model.enable_input_require_grads()
125 |
126 | dataset = AdapterDataset(dataset, processor, caption_condition=caption_condition)
127 | logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}")
128 | eos_token_id, model_max_length = unwrap_processor(processor).tokenizer.eos_token_id, unwrap_processor(processor).tokenizer.model_max_length
129 | with Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(days=3))]).main_process_first():
130 | dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length, num_proc=64, batch_size=16)
131 | logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}")
132 |
133 | last_checkpoint = None
134 | if os.path.isdir(output_dir) and not overwrite:
135 | last_checkpoint = get_last_checkpoint(output_dir)
136 | if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
137 | raise ValueError(
138 | f"Output directory ({output_dir}) already exists and is not empty. "
139 | "Use `overwrite` to overcome."
140 | )
141 | elif last_checkpoint is not None:
142 | logger.info(
143 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
144 | "the `output_dir` or add `overwrite` to train from scratch."
145 | )
146 |
147 | trainer = Trainer(
148 | model=model,
149 | train_dataset=dataset,
150 | args=TrainingArguments(
151 | per_device_train_batch_size=micro_batch_size,
152 | gradient_accumulation_steps=gradient_accumulation_steps,
153 | gradient_checkpointing=gradient_checkpointing,
154 | # https://github.com/huggingface/transformers/issues/32576
155 | #gradient_checkpointing_kwargs={'use_reentrant':False},
156 | dataloader_num_workers=WORLD_SIZE,
157 | warmup_ratio=0.03,
158 | weight_decay=0,
159 | num_train_epochs=num_epochs,
160 | learning_rate=learning_rate,
161 | torch_compile=True,
162 | bf16=True,
163 | tf32=True,
164 | logging_steps=10,
165 | lr_scheduler_type="cosine",
166 | optim="adamw_torch" if deepspeed else "adamw_torch_fused",
167 | ddp_find_unused_parameters=False,
168 | remove_unused_columns=False,
169 | save_strategy="epoch",
170 | report_to="none",
171 | save_total_limit=1,
172 | output_dir=output_dir,
173 | deepspeed=deepspeed,
174 | ),
175 | callbacks=[SplitEpochSaveCallback(step_size=0.25)],
176 | data_collator=lambda batch: batch
177 | )
178 |
179 | trainer.add_callback(trainer.train_dataset)
180 | trainer.train(resume_from_checkpoint=last_checkpoint)
181 |
182 | if trainer.is_deepspeed_enabled:
183 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
184 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
185 | last_checkpoint = get_last_checkpoint(output_dir)
186 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
187 |
188 | trainer.model.unload_cross_attn_adapter()
189 | trainer.save_model(output_dir)
190 | trainer.save_state()
191 | processor.processor.save_pretrained(output_dir)
192 |
193 | return model, processor
194 |
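A toy illustration of the caption-conditioning label masking in `tokenize` above (all token ids are made up): every label up to and including the BOS separator, as well as image and pad tokens, is set to `IGNORE_INDEX`, so the loss is only computed on the code tokens.

```python
IGNORE_INDEX = -100
IMAGE, BOS, PAD = 9, 1, 0                       # made-up special token ids
labels = [IMAGE, IMAGE, 5, 6, BOS, 7, 8, PAD]   # caption tokens 5, 6; code tokens 7, 8

after_bos_token = False
for idx, label_id in enumerate(list(labels)):
    if not after_bos_token or label_id in {IMAGE, PAD}:
        if label_id == BOS:
            after_bos_token = True
        labels[idx] = IGNORE_INDEX
    elif label_id == BOS:
        after_bos_token = True

print(labels)  # [-100, -100, -100, -100, -100, 7, 8, -100]
```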
--------------------------------------------------------------------------------
/examples/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu
2 | from argparse import ArgumentParser
3 | from collections import defaultdict
4 | from datetime import timedelta
5 | from functools import partial
6 | from itertools import count
7 | from json import dump, load as load_json
8 | from operator import itemgetter
9 | from os import getenv
10 | from os.path import isfile, join
11 | from time import time
12 |
13 | from datasets import load_dataset
14 | from numpy import array
15 | from scipy.stats.mstats import winsorize
16 | from torch import bfloat16, distributed as dist, float16
17 | from torch.cuda import is_available as is_cuda_available, is_bf16_supported
18 | from tqdm import tqdm
19 | from transformers import set_seed
20 | from transformers.utils import is_flash_attn_2_available
21 |
22 | from detikzify.evaluate import (
23 | ClipScore,
24 | CrystalBLEU,
25 | DreamSim,
26 | ImageSim,
27 | KernelInceptionDistance,
28 | TexEditDistance,
29 | )
30 | from detikzify.infer import DetikzifyPipeline, TikzDocument
31 | from detikzify.model import adapter, load as load_model
32 |
33 | WORLD_SIZE = int(getenv("WORLD_SIZE", 1))
34 | RANK = int(getenv("RANK", 0))
35 |
36 | def parse_args():
37 | argument_parser = ArgumentParser(
38 | description="Evaluate fine-tuned models."
39 | )
40 | argument_parser.add_argument(
41 | "--cache_dir",
42 | help="directory where model outputs should be saved to",
43 | )
44 | argument_parser.add_argument(
45 | "--trainset",
46 | default="nllg/datikz-v3",
47 | help="path or name of the DaTikZ train set",
48 | )
49 | argument_parser.add_argument(
50 | "--testset",
51 | required=True,
52 | help="path to the DaTikZ test split processed by the ./sketchify script (in parquet format)",
53 | )
54 | argument_parser.add_argument(
55 | "--output",
56 | required=True,
57 | help="where to save the computed scores (as json)",
58 | )
59 | argument_parser.add_argument(
60 | "--timeout",
61 | type=int,
62 | help="minimum time to run MCTS in seconds",
63 | )
64 | argument_parser.add_argument(
65 | "--model_inputs",
66 | default="image",
67 | choices=["image", "sketch", "caption", "caption-image", "caption-sketch"],
68 | help="which inputs to condition the model on",
69 | )
70 | argument_parser.add_argument(
71 | "--path",
72 | nargs='+',
73 | metavar="MODEL=PATH[:ADAPTER] | MODEL=JSON",
74 | required=True,
75 | help="(multiple) key-value pairs of model names and paths/urls to models and optionally adapters or json files",
76 | )
77 | return argument_parser.parse_args()
78 |
79 | # https://stackoverflow.com/a/54802737
80 | def chunk(l, n):
81 | """Yield n number of striped chunks from l."""
82 | for i in range(0, n):
83 | yield l[i::n]
84 |
85 | def interleave(chunks):
86 | """Interleave chunks until one is exhausted."""
87 | interleaved = list()
88 | for idx in count():
89 | try:
90 | interleaved.extend(chunk[idx] for chunk in chunks)
91 | except IndexError:
92 | break
93 | return interleaved
94 |
95 | def generate(pipe, item, model_inputs, strict=False, timeout=None, **tqdm_kwargs):
96 | """Run MCTS until the generated tikz code compiles."""
97 | start, success, tikzpics = time(), False, set()
98 | inputs = {"text" if key == "caption" else "image": item[key] for key in model_inputs.split("-")}
99 |
100 | for score, tikzpic in tqdm(pipe.simulate(**inputs), desc="Try", **tqdm_kwargs):
101 | tikzpics.add((score, tikzpic.code))
102 |         if (not tikzpic.compiled_with_errors) if strict else tikzpic.is_rasterizable:
103 | success = True
104 | if success and (not timeout or time() - start >= timeout):
105 | break
106 | return [tikzpic for _, tikzpic in sorted(tikzpics, key=itemgetter(0))]
107 |
108 | def predict(model_name, base_model, testset, model_inputs="image", adapter_model=None, cache_file=None, timeout=None):
109 | predictions, worker_preds = list(), list()
110 | model, processor = load_model(
111 | model_name_or_path=base_model,
112 | device_map=RANK,
113 | torch_dtype=bfloat16 if is_cuda_available() and is_bf16_supported() else float16,
114 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
115 | )
116 | if adapter_model is not None:
117 | model, processor = adapter.load(model, processor, adapter_model)
118 |     # if we don't have a timeout (i.e., we only run MCTS until we obtain something compilable), we can use fast metrics
119 | pipe = DetikzifyPipeline(model=model, processor=processor, metric="model" if timeout else "fast")
120 |
121 | if cache_file and isfile(cache_file):
122 | with open(cache_file) as f:
123 | predictions = load_json(f)
124 | try:
125 | worker_chunk = list(chunk(list(testset)[len(predictions):], WORLD_SIZE))[RANK]
126 |         # FIXME: right now there is only a progress bar for rank 0
127 | for item in tqdm(worker_chunk, desc=f"{model_name.title()} ({RANK})", disable=RANK!=0):
128 | tikz = generate(pipe, item, model_inputs, timeout=timeout, position=1, leave=False, disable=RANK!=0)
129 | worker_preds.append(tikz)
130 | del model, processor, pipe
131 | finally:
132 | dist.all_gather_object(gathered:=WORLD_SIZE * [None], worker_preds)
133 | predictions.extend(interleave(gathered))
134 | if cache_file and RANK == 0:
135 | with open(cache_file, 'w') as f:
136 | dump(predictions, f)
137 | return predictions
138 |
139 | def load_metrics(trainset, measure_throughput=False, **kwargs):
140 | bleu = CrystalBLEU(corpus=trainset, **kwargs)
141 | eed = TexEditDistance(**kwargs)
142 | clip = ClipScore(**kwargs)
143 | imgsim = ImageSim(**kwargs)
144 | dreamsim = DreamSim(**kwargs)
145 | kid = KernelInceptionDistance(**kwargs)
146 |
147 | def mean_token_efficiency(predictions, limit=0.05):
148 | samples = list()
149 | for preds in predictions:
150 | samples.append(len(preds[-1].code)/sum(len(pred.code) for pred in preds))
151 | return winsorize(array(samples), limits=limit).mean().item()
152 |
153 | def mean_sampling_throughput(predictions, limit=0.05):
154 | return winsorize(array(list(map(len, predictions))), limits=limit).mean().item()
155 |
156 | def compute(references, predictions, compute_redacted=True, **redact_kwargs):
157 | ref_code, pred_code = [[ref['code']] for ref in references], [pred[-1].code for pred in predictions]
158 | ref_image, pred_image = [ref['image'] for ref in references], [pred[-1].rasterize() for pred in predictions]
159 | captions = [ref['caption'] for ref in references]
160 | assert all(pred[-1].is_rasterizable for pred in predictions)
161 |
162 | if measure_throughput:
163 | scores = {"MeanSamplingThroughput": mean_sampling_throughput(predictions=predictions)}
164 | else:
165 | scores = {"MeanTokenEfficiency": mean_token_efficiency(predictions=predictions)}
166 |
167 | redacted_metrics, standard_metrics = {}, {
168 | bleu: partial(bleu.update, list_of_references=ref_code, hypotheses=pred_code),
169 | eed: partial(eed.update, target=ref_code, preds=pred_code),
170 | clip: partial(clip.update, text=captions, images=pred_image),
171 | imgsim: lambda: [imgsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_image)],
172 | dreamsim: lambda: [dreamsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_image)],
173 | kid: lambda: [(kid.update(img1, True), kid.update(img2, False)) for img1, img2 in zip(ref_image, pred_image)],
174 | }
175 |
176 | if compute_redacted:
177 | pred_redacted = [pred[-1].rasterize(redact=True, **redact_kwargs) for pred in predictions]
178 | redacted_metrics.update({
179 | clip: partial(clip.update, text=captions, images=pred_redacted),
180 | imgsim: lambda: [imgsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_redacted)],
181 | dreamsim: lambda: [dreamsim.update(img1=img1, img2=img2) for img1, img2 in zip(ref_image, pred_redacted)],
182 | })
183 |
184 | for metrics, redacted in [(standard_metrics, False), (redacted_metrics, True)]:
185 | for metric, update in metrics.items():
186 | update()
187 | if redacted:
188 | scores[f"Redacted {str(metric)}"] = metric.compute()
189 | else:
190 | scores[str(metric)] = metric.compute() # type: ignore
191 | metric.reset()
192 |
193 | return scores
194 |
195 | return compute
196 |
197 | if __name__ == "__main__":
198 | set_seed(0)
199 | dist.init_process_group(timeout=timedelta(days=3))
200 | args = parse_args()
201 |
202 | trainset = load_dataset(args.trainset, split="train")
203 | testset = load_dataset("parquet", data_files={"test": args.testset}, split="test").sort("caption") # type: ignore
204 |
205 | predictions = defaultdict(list)
206 | for model_name, path in map(lambda s: s.split('='), tqdm(args.path, desc="Predicting")):
207 | if path.endswith("json"):
208 | with open(path) as f:
209 | predictions[model_name] = load_json(f)
210 | else:
211 | cache_file = join(args.cache_dir, f'{model_name}.json') if args.cache_dir else None
212 | predictions[model_name] = predict(
213 | model_name=model_name,
214 | base_model=path.partition(":")[0],
215 | adapter_model=path.partition(":")[2] or None,
216 | model_inputs=args.model_inputs,
217 | testset=testset,
218 | cache_file=cache_file,
219 | timeout=args.timeout,
220 | )
221 |
222 | if RANK == 0: # Scoring only on main process
223 | scores = dict()
224 | metrics = load_metrics(trainset['code'], measure_throughput=args.timeout is not None, sync_on_compute=False) # type: ignore
225 | for model_name, prediction in tqdm(predictions.items(), desc="Computing metrics", total=len(predictions)):
226 | scores[model_name] = metrics(
227 | references=testset,
228 | # use an unrealistically long timeout as we know that the (last) images compile
229 | predictions=[[TikzDocument(code, 600) for code in pred] for pred in prediction],
230 | rot_13=True
231 | )
232 | with open(args.output, "w") as file:
233 | dump(scores, file)
234 |
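A small round-trip example of the `chunk` and `interleave` helpers defined at the top of this script (assuming both are in scope): `chunk` deals the test set out in stripes across workers, and `interleave` restores the original order from the rank-ordered results gathered by `all_gather_object`.

```python
items = list(range(7))          # e.g. indices of test-set examples
chunks = list(chunk(items, 3))  # striped across 3 workers
print(chunks)                   # [[0, 3, 6], [1, 4], [2, 5]]
print(interleave(chunks))       # [0, 1, 2, 3, 4, 5, 6] -- original order restored
```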
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/detikzify/model/v1/processing_detikzify.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/modeling_timm.py. Below is the original copyright:
2 | # Copyright 2024 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from typing import Dict, List, Optional, Union
18 |
19 | import numpy as np
20 | from timm.data import resolve_data_config
21 | from timm.models import resolve_pretrained_cfg
22 | from transformers.image_processing_utils import (
23 | BaseImageProcessor,
24 | BatchFeature,
25 | get_size_dict,
26 | )
27 | from transformers.image_transforms import resize, to_channel_dimension_format
28 | from transformers.image_utils import (
29 | ChannelDimension,
30 | IMAGENET_STANDARD_MEAN,
31 | IMAGENET_STANDARD_STD,
32 | ImageInput,
33 | PILImageResampling,
34 | make_list_of_images,
35 | to_numpy_array,
36 | valid_images,
37 | )
38 | from transformers.utils import TensorType
39 |
40 |
41 | class DetikzifyImageProcessor(BaseImageProcessor):
42 | r"""
43 | Constructs a ViT image processor.
44 |
45 | Args:
46 | do_resize (`bool`, *optional*, defaults to `True`):
47 | Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
48 | size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
49 | size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
50 | Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
51 | method.
52 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
53 | Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
54 | `preprocess` method.
55 | do_rescale (`bool`, *optional*, defaults to `True`):
56 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
57 | parameter in the `preprocess` method.
58 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
59 | Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
60 | `preprocess` method.
61 | do_normalize (`bool`, *optional*, defaults to `True`):
62 | Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
63 | method.
64 | image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
65 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of
66 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
67 | image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
68 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
69 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
70 | """
71 |
72 | model_input_names = ["pixel_values"]
73 |
74 | def __init__(
75 | self,
76 | do_resize: bool = True,
77 | size: Optional[Dict[str, int]] = None,
78 | resample: PILImageResampling = PILImageResampling.BILINEAR,
79 | do_rescale: bool = True,
80 | rescale_factor: Union[int, float] = 1 / 255,
81 | do_normalize: bool = True,
82 | image_mean: Optional[Union[float, List[float]]] = None,
83 | image_std: Optional[Union[float, List[float]]] = None,
84 | **kwargs,
85 | ) -> None:
86 | super().__init__(**kwargs)
87 | size = size if size is not None else {"height": 224, "width": 224}
88 | size = get_size_dict(size)
89 | self.do_resize = do_resize
90 | self.do_rescale = do_rescale
91 | self.do_normalize = do_normalize
92 | self.size = size
93 | self.resample = resample
94 | self.rescale_factor = rescale_factor
95 | self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
96 | self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
97 |
98 | @classmethod
99 | def from_pretrained(
100 | cls,
101 | pretrained_model_name_or_path: Union[str, os.PathLike],
102 | **kwargs,
103 | ):
104 | pretrained_cfg = resolve_pretrained_cfg(variant=pretrained_model_name_or_path)
105 | timm_config_dict = resolve_data_config(pretrained_cfg.to_dict())
106 |
107 | _, im_h, im_w = timm_config_dict.get("input_size", [3, 224, 224])
108 |
109 | image_preprocess_config_dict = {
110 | "crop_size": {"height": im_h, "width": im_w},
111 |             "do_center_crop": timm_config_dict.get("crop_mode") == "center",
112 | "do_normalize": True,
113 | "do_reduce_labels": False,
114 | "do_rescale": True,
115 | "do_resize": True,
116 | "image_mean": timm_config_dict.get("mean", IMAGENET_STANDARD_MEAN),
117 | "image_processor_type": "TimmImageProcessor",
118 | "image_std": timm_config_dict.get("std", IMAGENET_STANDARD_STD),
119 | "resample": 3,
120 | "rescale_factor": 0.00392156862745098,
121 | "size": {"height": im_h, "width": im_w},
122 | }
123 |
124 | return cls.from_dict(image_preprocess_config_dict, **kwargs)
125 |
126 | def resize(
127 | self,
128 | image: np.ndarray,
129 | size: Dict[str, int],
130 | resample: PILImageResampling = PILImageResampling.BILINEAR,
131 | data_format: Optional[Union[str, ChannelDimension]] = None,
132 | **kwargs,
133 | ) -> np.ndarray:
134 | """
135 | Resize an image to `(size["height"], size["width"])`.
136 |
137 | Args:
138 | image (`np.ndarray`):
139 | Image to resize.
140 | size (`Dict[str, int]`):
141 | Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
142 | resample:
143 | `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
144 | data_format (`ChannelDimension` or `str`, *optional*):
145 | The channel dimension format for the output image. If unset, the channel dimension format of the input
146 | image is used. Can be one of:
147 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
148 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
149 |
150 | Returns:
151 | `np.ndarray`: The resized image.
152 | """
153 | size = get_size_dict(size)
154 | if "height" not in size or "width" not in size:
155 | raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
156 | if image.ndim == 2:
157 | image = np.stack([image] * 3, axis=-1)
158 | return resize(
159 | image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
160 | )
161 |
162 | def preprocess(
163 | self,
164 | images: ImageInput,
165 | do_resize: Optional[bool] = None,
166 |         size: Optional[Dict[str, int]] = None,
167 |         resample: Optional[PILImageResampling] = None,
168 | do_rescale: Optional[bool] = None,
169 | rescale_factor: Optional[float] = None,
170 | do_normalize: Optional[bool] = None,
171 | image_mean: Optional[Union[float, List[float]]] = None,
172 | image_std: Optional[Union[float, List[float]]] = None,
173 | return_tensors: Optional[Union[str, TensorType]] = None,
174 | data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
175 | **kwargs,
176 | ):
177 | """
178 | Preprocess an image or batch of images.
179 |
180 | Args:
181 | images (`ImageInput`):
182 | Image to preprocess.
183 | do_resize (`bool`, *optional*, defaults to `self.do_resize`):
184 | Whether to resize the image.
185 | size (`Dict[str, int]`, *optional*, defaults to `self.size`):
186 | Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
187 | resizing.
188 | resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
189 | `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
190 | an effect if `do_resize` is set to `True`.
191 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
192 | Whether to rescale the image values between [0 - 1].
193 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
194 | Rescale factor to rescale the image by if `do_rescale` is set to `True`.
195 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
196 | Whether to normalize the image.
197 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
198 | Image mean to use if `do_normalize` is set to `True`.
199 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
200 | Image standard deviation to use if `do_normalize` is set to `True`.
201 | return_tensors (`str` or `TensorType`, *optional*):
202 | The type of tensors to return. Can be one of:
203 | - Unset: Return a list of `np.ndarray`.
204 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
205 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
206 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
207 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
208 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
209 | The channel dimension format for the output image. Can be one of:
210 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
211 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
212 | - Unset: Use the channel dimension format of the input image.
213 | """
214 | do_resize = do_resize if do_resize is not None else self.do_resize
215 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale
216 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize
217 | resample = resample if resample is not None else self.resample
218 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
219 | image_mean = image_mean if image_mean is not None else self.image_mean
220 | image_std = image_std if image_std is not None else self.image_std
221 |
222 | size = size if size is not None else self.size
223 | size_dict = get_size_dict(size)
224 |
225 | images = make_list_of_images(images)
226 |
227 | if not valid_images(images):
228 | raise ValueError(
229 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
230 | "torch.Tensor, tf.Tensor or jax.ndarray."
231 | )
232 |
233 | if do_resize and size is None:
234 | raise ValueError("Size must be specified if do_resize is True.")
235 |
236 | if do_rescale and rescale_factor is None:
237 | raise ValueError("Rescale factor must be specified if do_rescale is True.")
238 |
239 | # All transformations expect numpy arrays.
240 | images = [to_numpy_array(image) for image in images]
241 |
242 | if do_resize:
243 | images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
244 |
245 | if do_rescale:
246 | images = [self.rescale(image=image, scale=rescale_factor) for image in images]
247 |
248 | if do_normalize:
249 | images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
250 |
251 | images = [to_channel_dimension_format(image, data_format) for image in images]
252 | data = {"pixel_values": images}
253 | return BatchFeature(data=data, tensor_type=return_tensors)
254 |
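A minimal usage sketch for the DetikzifyImageProcessor defined above, assuming Pillow and PyTorch are installed; "example.png" is a hypothetical placeholder for any local image. With the default 224x224 size, the returned batch should have shape (1, 3, 224, 224):

    from PIL import Image

    processor = DetikzifyImageProcessor()                      # defaults: resize to 224x224, rescale by 1/255, normalize
    image = Image.open("example.png").convert("RGB")           # "example.png" is a hypothetical input file
    batch = processor.preprocess(image, return_tensors="pt")   # resize -> rescale -> normalize -> channels-first
    print(batch["pixel_values"].shape)                         # torch.Size([1, 3, 224, 224])
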
--------------------------------------------------------------------------------
/examples/refine.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env -S torchrun --nproc_per_node gpu
2 | from argparse import ArgumentParser
3 | from collections import ChainMap
4 | from datetime import timedelta
5 | from functools import cached_property, partial
6 | from operator import itemgetter
7 | from os import environ
8 | from os.path import basename
9 | from random import choice, random
10 | from sys import exit
11 | from typing import Optional
12 |
13 | from accelerate import Accelerator
14 | from datasets import (
15 | Dataset,
16 | DatasetDict,
17 | Features,
18 | Image,
19 | Value,
20 | concatenate_datasets,
21 | load_dataset,
22 | )
23 | from torch.distributed import init_process_group
24 | from torch.multiprocessing import Pool, set_start_method
25 | from transformers import set_seed
26 | from transformers.trainer_utils import get_last_checkpoint
27 | from transformers.utils import is_flash_attn_2_available
28 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info
29 |
30 | from detikzify.evaluate.imagesim import ImageSim
31 | from detikzify.infer.tikz import TikzDocument
32 | from detikzify.model import load
33 | from detikzify.util import SketchAugment, batchify, expand
34 | from sketchify import Sketchifier
35 |
36 | try:
37 | from trl import GRPOConfig, GRPOTrainer
38 | except ModuleNotFoundError:
39 | print(
40 | "You need to install trl with vision support to be able to use this script:",
41 | "git clone https://github.com/hellopahe/trl.git",
42 | "curl -L http://github.com/huggingface/trl/pull/3568.patch | git -C trl apply",
43 |         "pip install ./trl",
44 | sep="\n\t"
45 | )
46 | exit(1)
47 |
48 |
49 | WORLD_SIZE = int(environ.get("WORLD_SIZE", 1))
50 |
51 |
52 | class RandSketchifier:
53 | def __init__(self, size, sketch_ratio):
54 | self.sketchifier = Sketchifier()
55 | self.deep_sketchifier = SketchAugment()
56 | self.size = size
57 | self.sketch_ratio = sketch_ratio
58 |
59 | @staticmethod
60 | def randbool():
61 | return choice([True, False])
62 |
63 | def sketchify(self, img):
64 | return self.sketchifier(img) if self.randbool() else self.deep_sketchifier(img)
65 |
66 | def randsketchify(self, img):
67 | return self.sketchify(img) if random() < self.sketch_ratio else img
68 |
69 | def __call__(self, img):
70 | return self.randsketchify(expand(img, self.size, do_trim=True))
71 |
72 |
73 | class TrainDataset:
74 | def __init__(self, processor, datikz_name="nllg/datikz-v3", size=420, sketch_ratio=.5):
75 | self.processor = processor
76 | self.datikz: DatasetDict = load_dataset(datikz_name) # type: ignore
77 | self.sketchify = RandSketchifier(
78 | size=size,
79 | sketch_ratio=sketch_ratio
80 | )
81 |
82 | @staticmethod
83 | def get_fig_type(ftype):
84 | for type_ in ["table", "photograph", "plot", "schematic", "other"]:
85 | if type_ in ftype.lower():
86 | return type_
87 | return "N/A"
88 |
89 | @batchify
90 | def extract_figures(self, batch, meta_ds, filter_urls):
91 | for img in batch['image']:
92 | filename = basename(img['path'].split("::")[0])
93 | if filename in meta_ds and filename.rpartition("v")[0] not in filter_urls:
94 | meta = meta_ds[filename]
95 | ftype = self.get_fig_type(meta["figure_type"])
96 | # "other" are mostly text snippets
97 | if meta["content_type"] == "figure":
98 | yield dict(image=img, type=ftype)
99 |
100 | def sample_spiqa_dataset(self, n_samples, split="train"):
101 | img_ds: Dataset = load_dataset( # type: ignore
102 | path="google/spiqa",
103 | data_files="train_val/SPIQA_train_val_Images.zip",
104 | split="train",
105 | features=Features({"image": Image(decode=False), "label": Value("string")})
106 | )
107 |
108 | meta_ds = load_dataset(path="google/spiqa", data_files=f"train_val/SPIQA_{split}.json", split="train")
109 | meta_ds = ChainMap(*map(itemgetter("all_figures"), meta_ds[0].values())) # type: ignore
110 |
111 | filter_urls = concatenate_datasets(list(self.datikz.values()))['uri']
112 | filter_urls = {basename(url) for url in filter_urls if url.startswith("https://arxiv.org")}
113 |
114 | img_ds = img_ds.map(
115 | self.extract_figures,
116 | batched=True,
117 | remove_columns=img_ds.column_names,
118 | fn_kwargs=dict(meta_ds=meta_ds, filter_urls=filter_urls)
119 | ).shuffle()
120 |
121 | schematics = img_ds.filter(lambda ex: ex['type'] == "schematic").select(range(round(.6 * n_samples)))
122 | plots = img_ds.filter(lambda ex: ex['type'] == "plot").select(range(round(.2 * n_samples)))
123 | other = img_ds.filter(lambda ex: ex['type'] not in ["plot", "schematic"]).select(range(n_samples - len(schematics) - len(plots)))
124 |
125 | for ex in concatenate_datasets([schematics, plots, other]).cast_column("image", Image()):
126 | yield {"prompt": "", "images": self.sketchify(ex['image'])} # type: ignore
127 |
128 | def sample_datikz_dataset(self, n_samples, split="train"):
129 | tokenizer, image_seq_len = self.processor.tokenizer, self.processor.image_seq_len
130 |
131 | datikz_filtered = self.datikz[split].filter(
132 | function=lambda ex: len(tokenizer.tokenize(ex['code'])) + image_seq_len > tokenizer.model_max_length,
133 | ).train_test_split(train_size=n_samples)
134 |
135 | for ex in datikz_filtered['train']:
136 | yield {"prompt": "", "images": self.sketchify(ex['image'])} # type: ignore
137 |
138 | def sample(self, n_samples):
139 | spiqa: Dataset = Dataset.from_generator( # type: ignore
140 | generator=self.sample_spiqa_dataset,
141 | gen_kwargs=dict(n_samples=round(.5 * n_samples))
142 | )
143 | datikz: Dataset = Dataset.from_generator( # type: ignore
144 | generator=self.sample_datikz_dataset,
145 | gen_kwargs=dict(n_samples=n_samples-len(spiqa))
146 | )
147 |
148 | return concatenate_datasets([spiqa, datikz])
149 |
150 |
151 | class RewardFunc:
152 | __name__ = "SelfSim Reward"
153 |
154 | def __init__(self, model, processor, num_workers=1, strict=False):
155 | self.model = model
156 | self.processor = processor
157 | self.strict = strict
158 | self.pool = Pool(num_workers)
159 |
160 | @cached_property
161 | def reward_model(self):
162 | return ImageSim.from_detikzify(self.model, self.processor, sync_on_compute=False)
163 |
164 | @staticmethod
165 | def compile(code, size, strict):
166 | doc = TikzDocument(code)
167 |
168 | if doc.is_rasterizable and not (strict and doc.compiled_with_errors):
169 | return doc.rasterize(size=size)
170 |
171 | def __call__(self, images, completions, **_):
172 | rewards, compile = list(), partial(
173 | self.compile,
174 | size=self.model.config.vision_config.image_size,
175 | strict=self.strict
176 | )
177 |
178 | for doc, img in zip(self.pool.imap(compile, completions), images):
179 | if doc is not None:
180 | self.reward_model.update(doc, img)
181 | rewards.append(self.reward_model.compute())
182 | else:
183 | rewards.append(-1.)
184 | self.reward_model.reset()
185 | return rewards
186 |
187 |
188 | def train(
189 | model,
190 | processor,
191 | dataset,
192 | output_dir: str,
193 | overwrite: bool = False,
194 | deepspeed: Optional[str] = None,
195 | num_compile_workers: int = 4,
196 | # training hyperparams
197 | strict: bool = False,
198 | freeze_encoder: bool = True,
199 | num_generations: int = 32,
200 | batch_size: int = 16,
201 | micro_batch_size: int = 1,
202 | num_train_steps: int = 500,
203 | learning_rate: float = 1e-5,
204 | gradient_checkpointing: bool = False,
205 | ):
206 | for _, param in model.model.vision_model.named_parameters():
207 | param.requires_grad = not freeze_encoder
208 | model.enable_input_require_grads()
209 |
210 | training_args = GRPOConfig(
211 | per_device_train_batch_size=micro_batch_size,
212 | gradient_accumulation_steps=num_generations * batch_size // micro_batch_size // WORLD_SIZE,
213 | num_generations=num_generations,
214 | gradient_checkpointing=gradient_checkpointing,
215 | # https://github.com/huggingface/transformers/issues/32576
216 | gradient_checkpointing_kwargs={'use_reentrant': False},
217 | lr_scheduler_type="cosine",
218 | weight_decay=0.01,
219 | epsilon=0.4,
220 | temperature=max(1., model.generation_config.temperature),
221 | top_p=model.generation_config.top_p,
222 | top_k=model.generation_config.top_k,
223 | max_steps=num_train_steps,
224 | logging_steps=num_train_steps//100,
225 | save_steps=num_train_steps//10,
226 | save_strategy="steps",
227 | save_total_limit=1,
228 | learning_rate=learning_rate,
229 | torch_compile=True,
230 | bf16=True,
231 | tf32=True,
232 | max_completion_length=processor.tokenizer.model_max_length-processor.image_seq_len,
233 | max_prompt_length=None,
234 | optim="adamw_torch" if deepspeed else "adamw_torch_fused",
235 | ddp_find_unused_parameters=False,
236 | remove_unused_columns=False,
237 | report_to="none",
238 | log_completions=True,
239 | num_completions_to_print=1,
240 | output_dir=output_dir,
241 | overwrite_output_dir=overwrite,
242 | deepspeed=deepspeed,
243 | )
244 |
245 | trainer = GRPOTrainer(
246 | model=model,
247 | processing_class=processor,
248 | reward_funcs=[RewardFunc(model, processor, num_workers=num_compile_workers, strict=strict)],
249 | args=training_args,
250 | train_dataset=dataset,
251 | )
252 | trainer.generation_config.bad_words_ids = [[model.config.image_token_id]]
253 | # trainer.generation_config.begin_suppress_tokens = [model.config.text_config.eos_token_id]
254 | trainer.train(resume_from_checkpoint=None if overwrite else get_last_checkpoint(output_dir))
255 |
256 | if trainer.is_deepspeed_enabled:
257 | # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
258 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint # type: ignore
259 | last_checkpoint = get_last_checkpoint(output_dir)
260 | load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
261 |
262 | trainer.save_model(output_dir)
263 |
264 |
265 | def parse_args():
266 | argument_parser = ArgumentParser(
267 | description="Post-train DeTikZify with GRPO."
268 | )
269 | argument_parser.add_argument("--base_model",
270 | required=True,
271 | help="The model checkpoint for weights initialization."
272 | )
273 | argument_parser.add_argument("--datikz",
274 | default="nllg/datikz-v3",
275 | help="Path or name of the DaTikZ dataset.",
276 | )
277 | argument_parser.add_argument("--sketch_ratio",
278 | default=.5,
279 | type=float,
280 | help="ratio of synthetic sketches generated through UltraSketch or image transforms",
281 | )
282 | argument_parser.add_argument("--output",
283 | required=True,
284 | dest="output_dir",
285 | help="directory where to write the model files",
286 | )
287 | argument_parser.add_argument("--num_compile_workers",
288 | default=4,
289 | type=int,
290 | help="number of threads to compile TikZ code with",
291 | )
292 | argument_parser.add_argument("--deepspeed",
293 | help="path to a DeepSpeed json config file",
294 | )
295 | argument_parser.add_argument("--gradient_checkpointing",
296 | action="store_true",
297 | help="use gradient checkpointing",
298 | )
299 | argument_parser.add_argument("--strict",
300 | action="store_true",
301 | help="treat recoverable compilation errors as fatal errors",
302 | )
303 | argument_parser.add_argument("--batch_size",
304 | default=16,
305 | type=int,
306 | help="global batch size for training",
307 | )
308 | argument_parser.add_argument("--num_train_steps",
309 | default=500,
310 | type=int,
311 | help="number of training steps to run GRPO for",
312 | )
313 | return vars(argument_parser.parse_args())
314 |
315 |
316 | if __name__ == "__main__":
317 | set_verbosity_info()
318 | enable_explicit_format()
319 | set_start_method('forkserver') # https://github.com/pytorch/pytorch/issues/17199#issuecomment-465313245
320 | init_process_group("nccl", timeout=timedelta(days=3))
321 | set_seed(0)
322 |
323 | args = parse_args()
324 | model, processor = load(
325 | model_name_or_path=args.pop("base_model"),
326 | torch_dtype="bfloat16",
327 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
328 | )
329 |
330 | with Accelerator().main_process_first():
331 | dataset = TrainDataset(
332 | processor=processor,
333 | datikz_name=args.pop('datikz'),
334 | sketch_ratio=args.pop("sketch_ratio"),
335 | size=model.config.vision_config.image_size,
336 | ).sample(args['batch_size'] * args['num_train_steps'])
337 |
338 | train(model=model, processor=processor, dataset=dataset, **args)
339 |
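A brief worked example for the gradient_accumulation_steps computation in the GRPOConfig above: it is chosen so that one optimizer step still covers num_generations completions for every prompt in the global batch. The values below are hypothetical (the script's defaults plus an assumed WORLD_SIZE of 8):

    num_generations, batch_size, micro_batch_size, world_size = 32, 16, 1, 8       # world_size of 8 is an assumption
    accum_steps = num_generations * batch_size // micro_batch_size // world_size   # 512 // 1 // 8 = 64
    completions_per_step = world_size * accum_steps * micro_batch_size             # 8 * 64 * 1 = 512
    assert completions_per_step == num_generations * batch_size                    # matches the intended global batch
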
--------------------------------------------------------------------------------
/detikzify/model/v1/modeling_detikzify.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava.py. Below is the original copyright:
2 | # Copyright 2023 Haotian Liu
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from pickle import UnpicklingError
18 | from typing import List, Optional, Tuple, Union
19 |
20 | from numpy import clip
21 | from safetensors.torch import load_file
22 | from timm import create_model as create_vision_model
23 | import torch
24 | import torch.nn as nn
25 | from torch.nn import CrossEntropyLoss
26 | import torch.nn.functional as F
27 | from transformers import (
28 | BatchEncoding,
29 | LlamaConfig,
30 | LlamaForCausalLM,
31 | LlamaModel,
32 | PretrainedConfig,
33 | PreTrainedModel,
34 | )
35 | from transformers.modeling_outputs import (
36 | BaseModelOutputWithPoolingAndNoAttention,
37 | BaseModelOutputWithPast,
38 | CausalLMOutputWithPast,
39 | )
40 | from transformers.utils import logging
41 |
42 | from .configuration_detikzify import DetikzifyConfig
43 | from .processing_detikzify import DetikzifyImageProcessor
44 |
45 |
46 | logger = logging.get_logger("transformers")
47 |
48 |
49 | class DetikzifyVisionModel(PreTrainedModel):
50 | _no_split_modules = ["VisionTransformer"]
51 | def __init__(self, model, **kwargs) -> None:
52 | super().__init__(PretrainedConfig.from_dict(model.pretrained_cfg), **kwargs)
53 | # HACK: wrap in list so that vision model does not count as a parameter
54 | self.model = [model]
55 |
56 | def get_input_embeddings(self) -> torch.nn.Module:
57 | return self.model[0].patch_embed
58 |
59 | def to_input_dtype(self, pixel_values: torch.Tensor):
60 | target_dtype = self.get_input_embeddings().proj.weight.dtype
61 | return pixel_values.to(dtype=target_dtype)
62 |
63 | def forward(self, pixel_values: torch.Tensor):
64 | last_hidden_state = self.model[0].forward_features(self.to_input_dtype(pixel_values))
65 | pooler_output = self.model[0].forward_head(last_hidden_state)
66 | return BaseModelOutputWithPoolingAndNoAttention(
67 | last_hidden_state=last_hidden_state,
68 | pooler_output=pooler_output
69 | )
70 |
71 | def get_intermediate_layers(self, pixel_values: torch.Tensor, *args, **kwargs):
72 | return self.model[0].get_intermediate_layers(self.to_input_dtype(pixel_values), *args, **kwargs)
73 |
74 |
75 | class DetikzifyModel(LlamaModel):
76 | config_class = DetikzifyConfig
77 |
78 | def __init__(self, config: LlamaConfig):
79 | super(DetikzifyModel, self).__init__(config)
80 |
81 | if getattr(config, "use_mm_proj"):
82 | self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
83 | self.vision_model = DetikzifyVisionModel(create_vision_model(config.vision_tower))
84 |
85 | def initialize_vision_modules(
86 | self,
87 | vision_tower,
88 | patch_token_id,
89 | concat_patches=3,
90 | feature_layer=-1,
91 | modality_projector=None,
92 | **kwargs
93 | ):
94 | vision_model = create_vision_model(vision_tower, pretrained=True, **kwargs)
95 | self.vision_model = DetikzifyVisionModel(vision_model.to(self.device, self.dtype).eval().requires_grad_(False))
96 |
97 | processor = DetikzifyImageProcessor.from_pretrained(vision_tower)
98 |
99 | self.config.use_mm_proj = True
100 | self.config.vision_tower = vision_tower
101 | self.config.mm_hidden_size = vision_model.embed_dim * concat_patches
102 | self.config.patch_token_id = patch_token_id
103 | self.config.concat_patches = concat_patches
104 | self.config.feature_layer = int(clip(feature_layer, -(depth:=len(vision_model.blocks)), depth-1) % depth)
105 | self.config.vision_config = processor.to_dict() # type: ignore
106 | self.config.num_patches = vision_model.patch_embed.num_patches // concat_patches
107 |
108 | if not hasattr(self, 'mm_projector'):
109 | self.mm_projector = nn.Linear(
110 | self.config.mm_hidden_size,
111 | self.config.hidden_size,
112 | dtype=self.dtype,
113 | device=self.device
114 | )
115 |
116 | if modality_projector is not None:
117 | try: # first try to load as pickle
118 | mm_projector_weights = torch.load(modality_projector, map_location=self.device)
119 | except UnpicklingError: # and if that fails we try safetensors
120 | mm_projector_weights = load_file(modality_projector, device=str(self.device))
121 | self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
122 |
123 | return processor
124 |
125 | # https://stackoverflow.com/a/57208704
126 | def _apply(self, fn):
127 | super()._apply(fn)
128 | if hasattr(self, "vision_model"):
129 |             self.vision_model = self.vision_model._apply(fn)
130 | return self
131 |
132 | def get_vision_features(self, pixel_values):
133 | concat, n_patch, layer = self.config.concat_patches, self.config.num_patches, self.config.feature_layer
134 | feats = self.vision_model.get_intermediate_layers(pixel_values, n=[layer], norm=True)[0]
135 | # in case the number of feature vectors is not divisible by the number
136 | # of patches we want to concatenate, we remove the first feature(s)
137 | return feats[:, -n_patch * concat:].reshape(-1, n_patch, feats.shape[-1] * concat)
138 |
139 | def is_tensor(self, thing):
140 | if isinstance(thing, (BatchEncoding, dict)):
141 | return all(isinstance(v, torch.Tensor) for v in thing.values())
142 | return isinstance(thing, torch.Tensor)
143 |
144 | def forward(
145 | self,
146 | input_ids: torch.LongTensor = None,
147 | attention_mask: Optional[torch.Tensor] = None,
148 | past_key_values: Optional[List[torch.FloatTensor]] = None,
149 | inputs_embeds: Optional[torch.FloatTensor] = None,
150 | use_cache: Optional[bool] = None,
151 | output_attentions: Optional[bool] = None,
152 | output_hidden_states: Optional[bool] = None,
153 | pixel_values: Optional[torch.FloatTensor] = None,
154 | return_dict: Optional[bool] = None,
155 | ) -> Union[Tuple, BaseModelOutputWithPast]:
156 |
157 | if inputs_embeds is None:
158 | inputs_embeds = self.embed_tokens(input_ids)
159 |
160 | if hasattr(self, "vision_model") and (input_ids.shape[1] != 1 or self.training) and pixel_values is not None:
161 | with torch.no_grad():
162 | image_features = self.get_vision_features(pixel_values)
163 | image_features = self.mm_projector(image_features)
164 | dummy_image_features = torch.zeros(len(image_features[0]), self.config.mm_hidden_size, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
165 | dummy_image_features = self.mm_projector(dummy_image_features)
166 |
167 | new_input_embeds = []
168 | cur_image_idx = 0
169 | for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
170 | if (cur_input_ids == self.config.image_token_id).sum() == 0:
171 | # multimodal LLM, but the current sample is not multimodal
172 | cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
173 | new_input_embeds.append(cur_input_embeds)
174 | cur_image_idx += 1
175 | continue
176 |
177 | cur_image_features = image_features[cur_image_idx].to(cur_input_embeds.device)
178 | num_patches = cur_image_features.shape[0]
179 | if (cur_input_ids == self.config.image_token_id).sum() != num_patches:
180 | raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
181 | masked_indices = torch.where(cur_input_ids == self.config.image_token_id)[0]
182 | mask_index_start = masked_indices[0]
183 | if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
184 | raise ValueError("The image patch tokens should be consecutive.")
185 | cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
186 | new_input_embeds.append(cur_new_input_embeds)
187 | cur_image_idx += 1
188 |
189 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
190 |
191 | return super(DetikzifyModel, self).forward(
192 | input_ids=None,
193 | attention_mask=attention_mask,
194 | past_key_values=past_key_values,
195 | inputs_embeds=inputs_embeds,
196 | use_cache=use_cache,
197 | output_attentions=output_attentions,
198 | output_hidden_states=output_hidden_states,
199 | return_dict=return_dict
200 | )
201 |
202 |
203 | class DetikzifyForCausalLM(LlamaForCausalLM):
204 | config_class = DetikzifyConfig
205 |
206 | def __init__(self, config):
207 | super(LlamaForCausalLM, self).__init__(config)
208 | self.model = DetikzifyModel(config)
209 |
210 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
211 |
212 | # Initialize weights and apply final processing
213 | self.post_init()
214 |
215 | def get_model(self):
216 | return self.model
217 |
218 | def forward(
219 | self,
220 | input_ids: torch.LongTensor = None,
221 | attention_mask: Optional[torch.Tensor] = None,
222 | past_key_values: Optional[List[torch.FloatTensor]] = None,
223 | inputs_embeds: Optional[torch.FloatTensor] = None,
224 | labels: Optional[torch.LongTensor] = None,
225 | use_cache: Optional[bool] = None,
226 | output_attentions: Optional[bool] = None,
227 | output_hidden_states: Optional[bool] = None,
228 | pixel_values: Optional[torch.FloatTensor] = None,
229 | return_dict: Optional[bool] = None,
230 | ) -> Union[Tuple, CausalLMOutputWithPast]:
231 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
232 | output_hidden_states = (
233 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
234 | )
235 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
236 |
237 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
238 | outputs = self.model(
239 | input_ids=input_ids,
240 | attention_mask=attention_mask,
241 | past_key_values=past_key_values,
242 | inputs_embeds=inputs_embeds,
243 | use_cache=use_cache,
244 | output_attentions=output_attentions,
245 | output_hidden_states=output_hidden_states,
246 | return_dict=return_dict,
247 | pixel_values=pixel_values
248 | )
249 |
250 | hidden_states = outputs[0]
251 | if self.config.pretraining_tp > 1:
252 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
253 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
254 | logits = torch.cat(logits, dim=-1)
255 | else:
256 | logits = self.lm_head(hidden_states)
257 | logits = logits.float()
258 |
259 |
260 | loss = None
261 | if labels is not None:
262 | # Shift so that tokens < n predict n
263 | shift_logits = logits[..., :-1, :].contiguous()
264 | shift_labels = labels[..., 1:].contiguous()
265 | # Flatten the tokens
266 | loss_fct = CrossEntropyLoss()
267 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
268 | shift_labels = shift_labels.view(-1)
269 | # Enable model/pipeline parallelism
270 | shift_labels = shift_labels.to(shift_logits.device)
271 | loss = loss_fct(shift_logits, shift_labels)
272 |
273 | if not return_dict:
274 | output = (logits,) + outputs[1:]
275 | return (loss,) + output if loss is not None else output
276 |
277 | return CausalLMOutputWithPast(
278 | loss=loss,
279 | logits=logits,
280 | past_key_values=outputs.past_key_values,
281 | hidden_states=outputs.hidden_states,
282 | attentions=outputs.attentions,
283 | )
284 |
285 | def prepare_inputs_for_generation(
286 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
287 | ):
288 | if past_key_values:
289 | input_ids = input_ids[:, -1:]
290 |
291 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
292 | if inputs_embeds is not None and past_key_values is None:
293 | model_inputs = {"inputs_embeds": inputs_embeds}
294 | else:
295 | model_inputs = {"input_ids": input_ids}
296 |
297 | model_inputs.update(
298 | {
299 | "past_key_values": past_key_values,
300 | "use_cache": kwargs.get("use_cache"),
301 | "attention_mask": attention_mask,
302 | "pixel_values": kwargs.get("pixel_values", None),
303 | }
304 | )
305 | return model_inputs
306 |
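A toy sketch of the embedding splice performed in DetikzifyModel.forward above, using synthetic tensors rather than the model's real API: the run of consecutive image-placeholder positions is replaced by the projected image features while the sequence length stays unchanged. The token id, patch count, and hidden size below are hypothetical:

    import torch

    IMAGE_TOKEN_ID, num_patches, hidden = 7, 3, 4                   # hypothetical toy values
    input_ids = torch.tensor([1, 7, 7, 7, 2])                       # three consecutive placeholder tokens
    inputs_embeds = torch.zeros(len(input_ids), hidden)             # stand-in for embed_tokens(input_ids)
    image_features = torch.randn(num_patches, hidden)               # stand-in for the mm_projector output

    start = torch.where(input_ids == IMAGE_TOKEN_ID)[0][0]          # first placeholder position
    spliced = torch.cat((inputs_embeds[:start],
                         image_features,
                         inputs_embeds[start + num_patches:]), dim=0)
    assert spliced.shape == inputs_embeds.shape                     # sequence length is preserved
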
--------------------------------------------------------------------------------