"
15 |
16 | def disable_torch_init():
17 | """
18 | Disable the redundant torch default initialization to accelerate model creation.
19 | """
20 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
21 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
22 |
23 | def expand2square(pil_img, background_color):
24 |     # pad to a centered square
25 |     width, height = pil_img.size
26 |     if width == height:
27 |         return pil_img
28 |     elif width > height:
29 |         result = Image.new(pil_img.mode, (width, width), background_color)
30 |         result.paste(pil_img, (0, (width - height) // 2))
31 |         return result
32 |     else:
33 |         result = Image.new(pil_img.mode, (height, height), background_color)
34 |         result.paste(pil_img, ((height - width) // 2, 0))
35 |         return result
36 |
37 | def padding_336(b):
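    """Pad the image top/bottom with white so its height becomes the next
    multiple of 336; the width is left unchanged."""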
38 |     width, height = b.size
39 |     tar = int(np.ceil(height / 336) * 336)
40 |     top_padding = int((tar - height) / 2)
41 |     bottom_padding = tar - height - top_padding
42 | 
43 |     left_padding = 0
44 |     right_padding = 0
45 | 
46 |     mean_fill = tuple(int(255 * v) for v in (0.48145466, 0.4578275, 0.40821073))  # CLIP pixel mean as RGB; unused, the pad below fills with white
47 |     b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])
48 | 
49 |     return b
50 |
51 | def HD_transform(img, hd_num=9):
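    """Resize `img` so its width becomes a multiple of 336 while the implied
    count of 336x336 tiles stays within `hd_num`, then pad the height up to a
    multiple of 336. Portrait images are transposed first and restored at the
    end, so the long side is always treated as the width."""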
52 |     width, height = img.size
53 |     trans = False
54 |     if width < height:
55 |         img = img.transpose(Image.TRANSPOSE)
56 |         trans = True
57 |         width, height = img.size
58 |     ratio = width / height
59 |     scale = int(np.ceil(width / 336))
60 |     # print(width, height, ratio, scale, scale*np.ceil(scale/ratio))
61 |     while scale * np.ceil(scale / ratio) > hd_num:
62 |         scale -= 1
63 |     # print(scale*np.ceil(scale/ratio))
64 |     new_w = int(scale * 336)
65 |     new_h = int(new_w / ratio)
66 | 
67 |     img = transforms.functional.resize(img, [new_h, new_w])
68 |     img = padding_336(img)
69 |     width, height = img.size
70 |     if trans:
71 |         img = img.transpose(Image.TRANSPOSE)
72 | 
73 |     return img
74 |
75 | class ImageTestProcessorHD:
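    """Applies HD_transform tiling followed by CLIP-style ToTensor + Normalize."""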
76 |     def __init__(self, image_size=224, mean=None, std=None, hd_num=-1):
77 |         if mean is None:
78 |             mean = (0.48145466, 0.4578275, 0.40821073)
79 |         if std is None:
80 |             std = (0.26862954, 0.26130258, 0.27577711)
81 |         self.mean, self.std = mean, std
82 |         self.normalize = transforms.Normalize(mean, std)
83 |         self.transform = transforms.Compose(
84 |             [
85 |                 transforms.ToTensor(),
86 |                 self.normalize,
87 |             ]
88 |         )
89 |         self.hd_num = hd_num
90 | 
91 |     def __call__(self, item):
92 |         return self.transform(HD_transform(item, hd_num=self.hd_num))
93 |
94 | def main(args):
95 |     disable_torch_init()
96 |     model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
97 | 
98 |     model = model.cuda().eval()
99 |     image_processor = ImageTestProcessorHD(336, hd_num=16)
100 |     from bigmodelvis import Visualization
101 |     Visualization(model).structure_graph()
102 | 
103 |     questions = [
104 |         '将图中表格转成html格式.',  # "Convert the table in the image into HTML format."
105 |         '请解析输入的文档.'        # "Parse the input document."
106 |     ]
107 | 
108 |     raw_image = Image.open('../infmllm2/docs/doc_02.png').convert('RGB')
109 |     image_tensor = image_processor(raw_image).cuda()
110 | 
111 |     history = []
112 | 
113 |     print("\n" + "=" * 20)
114 |     for i, question in enumerate(questions):
115 |         history.append({
116 |             'from': 'human',
117 |             'value': question,
118 |         })
119 |         history.append(
120 |             {"from": 'gpt', "value": ""})
121 |         samples = {
122 |             'images': [image_tensor.unsqueeze(0)],
123 |             'conversations': [history]
124 |         }
125 |         with torch.inference_mode():
126 |             pred_answers, prompts = model.generate(
127 |                 samples=samples,
128 |                 max_length=args.max_new_tokens,
129 |                 min_length=1,
130 |                 num_beams=args.num_beams,
131 |                 top_p=args.top_p,
132 |                 temperature=args.temperature,
133 |                 return_prompts=True
134 |             )
135 |         answer = pred_answers[0]
136 |         print(f"Q{i+1}: {question}")
137 |         print(f"A{i+1}: {answer}")
138 |         history[-1]['value'] = answer
139 |
140 | if __name__ == '__main__':
141 |     import argparse
142 |     parser = argparse.ArgumentParser()
143 |     parser.add_argument("--model_path", type=str, default="./InfMLLM_7B_Chat")
144 |     parser.add_argument("--temperature", type=float, default=0.)
145 |     parser.add_argument("--top_p", type=float, default=None)
146 |     parser.add_argument("--num_beams", type=int, default=1)
147 |     parser.add_argument("--max_new_tokens", type=int, default=4096)
148 |     args = parser.parse_args()
149 | 
150 |     main(args)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # INF-MLLM
2 |
3 |
4 |
5 |
6 |
7 | ## Introduction
8 |
9 | INF-MLLM is a series of open-source multimodal large language models developed by INF Tech. This repository contains the code, models, and documentation for our projects, which aim to advance the state-of-the-art in visual-language understanding and document intelligence. We are committed to open research and have released our models and datasets to the community to foster collaboration and innovation.
10 |
11 | ## Updates
12 |
13 | - [2025/11/03] The [Infinity-Parser-7B](https://huggingface.co/infly/Infinity-Parser-7B), [Infinity-Doc-400K dataset](https://huggingface.co/datasets/infly/Infinity-Doc-400K), and synthetic data [generation code](https://github.com/infly-ai/INF-MLLM/tree/main/Infinity-Parser/Infinity-Synth) have been released.
14 | - [2025/09/19] VL-Rethinker has been accepted as a Spotlight paper at NeurIPS 2025!
15 | - [2025/06/30] We have added an introduction to our latest model, **Infinity-Parser**. The [Infinity-Doc-55K dataset](https://huggingface.co/datasets/infly/Infinity-Doc-55K) and [Infinity-Parser web demo](https://huggingface.co/spaces/infly/Infinity-Parser-Demo) are now available.
16 | - [2025/04/22] VL-Rethinker models (7B & 72B) are released! They achieve new state-of-the-art results on MathVista, MathVerse, and MathVision benchmarks.
17 | - [2024/08/19] We have released **INF-MLLM2**, with the [INF-MLLM2-7B model](https://huggingface.co/QianYEee/InfMLLM2_7B_chat) and evaluation code now available.
18 | - [2023/12/06] The models and evaluation code for **INF-MLLM1** are now available.
19 | - [2023/11/06] We have released **INF-MLLM1** and uploaded the initial version of the manuscript to [arXiv](https://arxiv.org/abs/2311.06791).
20 |
21 | ## Models
22 |
23 | Here is a brief overview of the models available in this repository. For more details, please refer to the respective project directories.
24 |
25 | ### [Infinity-Parser](Infinity-Parser)
26 |
27 | **Infinity-Parser** is an end-to-end scanned document parsing model trained with reinforcement learning. It is designed to maintain the original document's structure and content with high fidelity by incorporating verifiable rewards based on layout and content. Infinity-Parser demonstrates state-of-the-art performance on various benchmarks for text recognition, table and formula extraction, and reading-order detection.
28 |
29 | - **Key Features:** Layout-aware, reinforcement learning, high-fidelity document parsing.
30 | - **Paper:** [Infinity Parser: Layout Aware Reinforcement Learning for Scanned Document Parsing](https://arxiv.org/abs/2506.03197)
31 | - **Dataset:** [Infinity-Doc-55K](https://huggingface.co/datasets/infly/Infinity-Doc-55K), [Infinity-Doc-400K](https://huggingface.co/datasets/infly/Infinity-Doc-400K)
32 | - **Model:** [Infinity-Parser-7B](https://huggingface.co/infly/Infinity-Parser-7B)
33 | - **Web Demo:** [Infinity-Parser-Demo](https://huggingface.co/spaces/infly/Infinity-Parser-Demo)
34 |
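For a sense of how the released checkpoint is meant to be used, here is a minimal, hypothetical loading sketch. It assumes Infinity-Parser-7B follows the same `trust_remote_code` pattern as the other models in this repository (see `INF-MLLM1/demo.py`); consult the model card for the actual entry point.

```python
# Hypothetical sketch, not the official quick start: assumes the standard
# transformers remote-code loading pattern used elsewhere in this repo.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "infly/Infinity-Parser-7B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda().eval()
```
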
35 | ### [VL-Rethinker](https://github.com/TIGER-AI-Lab/VL-Rethinker)
36 |
37 | **VL-Rethinker** is a project designed to incentivize the self-reflection capabilities of Vision-Language Models (VLMs) through Reinforcement Learning. The research introduces a novel technique called Selective Sample Replay (SSR) to enhance the GRPO algorithm, addressing the "vanishing advantages" problem. It also employs "Forced Rethinking" to explicitly guide the model through a self-reflection reasoning step. By combining these methods, VL-Rethinker significantly advances the state-of-the-art performance on multiple vision-language benchmarks, including MathVista, MathVerse, and MathVision.
38 |
39 | - **Key Features:** Advanced RL techniques, fine-grained multimodal dataset, fully open-sourced.
40 | - **Paper:** [VL-Rethinker: Incentivizing Self-Reflection of Vision-Language Models with Reinforcement Learning](https://arxiv.org/abs/2504.08837)
41 | - **Dataset:** [ViRL39K](https://huggingface.co/datasets/TIGER-Lab/ViRL39K)
42 | - **Models:** [VL-Rethinker-7B](https://huggingface.co/TIGER-Lab/VL-Rethinker-7B), [VL-Rethinker-72B](https://huggingface.co/TIGER-Lab/VL-Rethinker-72B)
43 | - **Web Demo:** [VL-Rethinker-Demo](https://huggingface.co/spaces/TIGER-Lab/VL-Rethinker)
44 |
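To make the "vanishing advantages" problem concrete, here is a small illustrative sketch (ours, not the authors' code). GRPO centers each rollout's reward against its group mean (and typically also normalizes by the group's reward std), so a group in which every rollout earns the same reward yields all-zero advantages and no learning signal; SSR counteracts this by replaying past samples whose advantages were non-zero.

```python
# Illustrative only: group-relative advantage centering as in GRPO.
def group_advantages(rewards: list[float]) -> list[float]:
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

print(group_advantages([1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0] -> vanishing advantages
print(group_advantages([1.0, 0.0]))       # [0.5, -0.5]     -> informative signal
```
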
45 | ### [INF-MLLM2](INF-MLLM2)
46 |
47 | **INF-MLLM2** is an advanced multimodal model with significant improvements in high-resolution image processing and document understanding. It supports dynamic image resolutions up to 1344x1344 pixels and features enhanced OCR capabilities for robust document parsing, table and formula recognition, and key information extraction.
48 |
49 | - **Key Features:** High-resolution image support, advanced OCR, progressive multi-stage training.
50 | - **Paper:** [Technical Report](INF-MLLM2/docs/tech_report.pdf)
51 | - **Model:** [INF-MLLM2-7B](https://huggingface.co/QianYEee/InfMLLM2_7B_chat)
52 |
53 | ### [INF-MLLM1](INF-MLLM1)
54 |
55 | **INF-MLLM1** is a unified model for a wide range of visual-language tasks. It is designed to handle both multitask and instruction-tuning scenarios, demonstrating strong performance on various VQA and visual grounding datasets.
56 |
57 | - **Key Features:** Unified framework, multitask learning, instruction tuning.
58 | - **Paper:** [InfMLLM: A Unified Framework for Visual-Language Tasks](https://arxiv.org/abs/2311.06791)
59 | - **Models:** [InfMLLM-7B](https://huggingface.co/mightyzau/InfMLLM_7B), [InfMLLM-7B-Chat](https://huggingface.co/mightyzau/InfMLLM_7B_Chat), [InfMLLM-13B-Chat](https://huggingface.co/mightyzau/inf-mllm-13b-chat)
60 |
--------------------------------------------------------------------------------
/INF-MLLM1/demo.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | rootdir = os.path.abspath(os.path.dirname(__file__))
3 | if rootdir not in sys.path:
4 |     sys.path.insert(0, rootdir)
5 |
6 | import re
7 | import torch
8 | from PIL import Image
9 | import requests
10 | from transformers import AutoModel, AutoTokenizer
11 |
12 | from evaluate.infmllm_chat.utils import tokenizer_image_token
13 | from evaluate.infmllm_chat.conversation import conv_templates, SeparatorStyle
14 |
15 | IMAGE_TOKEN_INDEX = -200
16 | DEFAULT_IMAGE_TOKEN = "<image>"
17 |
18 | def disable_torch_init():
19 | """
20 | Disable the redundant torch default initialization to accelerate model creation.
21 | """
22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24 |
25 | def expand2square(pil_img, background_color):
26 |     # pad to a centered square
27 |     width, height = pil_img.size
28 |     if width == height:
29 |         return pil_img
30 |     elif width > height:
31 |         result = Image.new(pil_img.mode, (width, width), background_color)
32 |         result.paste(pil_img, (0, (width - height) // 2))
33 |         return result
34 |     else:
35 |         result = Image.new(pil_img.mode, (height, height), background_color)
36 |         result.paste(pil_img, ((height - width) // 2, 0))
37 |         return result
38 |
39 | def get_prompt(conv_mode, question, history=[]):
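    """Build the full conversation prompt: prepend the image token to the
    first human turn, replay the (question, answer) history, then append
    the new question with an empty assistant slot for generation."""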
40 |     conv = conv_templates[conv_mode].copy()
41 |     if len(history) == 0:
42 |         question = DEFAULT_IMAGE_TOKEN + '\n' + question
43 |     else:
44 |         if DEFAULT_IMAGE_TOKEN not in history[0][0]:
45 |             history[0][0] = DEFAULT_IMAGE_TOKEN + '\n' + history[0][0]
46 | 
47 |     for qa in history:
48 |         conv.append_message(conv.roles[0], qa[0])
49 |         conv.append_message(conv.roles[1], qa[1])
50 | 
51 |     conv.append_message(conv.roles[0], question)
52 |     conv.append_message(conv.roles[1], None)
53 | 
54 |     prompt = conv.get_prompt()
55 |     return prompt
56 |
57 | def generate(model, tokenizer, stop_str, input_ids, image_tensor):
58 |     with torch.inference_mode():
59 |         output_ids = model.generate(
60 |             input_ids,
61 |             images=image_tensor.unsqueeze(0).to(dtype=torch.bfloat16, device='cuda', non_blocking=True),
62 |             do_sample=True if args.temperature > 0 else False,
63 |             temperature=args.temperature,
64 |             top_p=args.top_p,
65 |             num_beams=args.num_beams,
66 |             max_new_tokens=args.max_new_tokens,
67 |             use_cache=True)
68 | 
69 |     input_token_len = input_ids.shape[1]
70 |     n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
71 |     if n_diff_input_output > 0:
72 |         print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
73 |     outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
74 |     outputs = outputs.strip()
75 |     if outputs.endswith(stop_str):
76 |         outputs = outputs[:-len(stop_str)]
77 |     return outputs
78 |
79 |
80 | def main(args):
81 |     disable_torch_init()
82 |     tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False)
83 |     model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
84 |     model = model.cuda().eval()
85 |     image_processor = model.get_model().get_vision_tower().image_processor
86 | 
87 |     stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
88 | 
89 |     img_url = 'https://farm5.staticflickr.com/4016/4349416002_e3743125b7_z.jpg'
90 |     questions = [
91 |         'Why is this image interesting?',
92 |         'What is the cat watching?',
93 |         'What is the scientific name of the bird in the picture?',
94 |         'How is the weather outside?',
95 |         'What season is it now?'
96 |     ]
97 | 
98 |     print(img_url)
99 | 
100 |     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
101 |     image = expand2square(raw_image, tuple(int(x*255) for x in image_processor.image_mean))
102 |     image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
103 | 
104 |     history = []
105 | 
106 |     print("\n" + "=" * 20)
107 |     for i, question in enumerate(questions):
108 |         prompt = get_prompt(args.conv_mode, question, history)
109 |         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
110 |         input_ids = input_ids.to(device='cuda', non_blocking=True)
111 |         answer = generate(model, tokenizer, stop_str, input_ids, image_tensor)
112 | 
113 |         print(f"Q{i+1}: {question}")
114 |         print(f"A{i+1}: {answer}")
115 |         history.append([question, answer])
116 |
117 |
118 | if __name__ == '__main__':
119 |     import argparse
120 |     parser = argparse.ArgumentParser()
121 |     parser.add_argument("--model_path", type=str, default="./InfMLLM_7B_Chat")
122 |     parser.add_argument("--conv_mode", type=str, default="vicuna_v1")
123 |     parser.add_argument("--temperature", type=float, default=0.)
124 |     parser.add_argument("--top_p", type=float, default=None)
125 |     parser.add_argument("--num_beams", type=int, default=1)
126 |     parser.add_argument("--max_new_tokens", type=int, default=1024)
127 |     args = parser.parse_args()
128 | 
129 |     main(args)
130 |
131 |
--------------------------------------------------------------------------------
/INF-MLLM1/evaluate/infmllm_chat/calculation_mme.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
4 |
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('-r', '--results_dir', default='./LaVIN', type=str)
8 |
9 | eval_type_dict = {
10 | "Perception": ["existence", "count", "position", "color", "posters", "celebrity", "scene", "landmark", "artwork", "OCR"],
11 | "Cognition": ["commonsense_reasoning", "numerical_calculation", "text_translation", "code_reasoning"]
12 | }
13 |
14 |
15 | class calculate_metrics:
16 | def divide_chunks(self, l, n=2):
17 |         # yield successive n-sized chunks of l
18 | for i in range(0, len(l), n):
19 | yield l[i:i + n]
20 |
21 | return
22 |
23 | def parse_pred_ans(self, pred_ans):
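        """Map a raw model answer to 'yes'/'no'/'other', accepting either an
        exact match or a yes/no within the first four characters."""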
24 | pred_label = None
25 | if pred_ans in ["yes", "no"]:
26 | pred_label = pred_ans
27 | else:
28 | prefix_pred_ans = pred_ans[:4]
29 |
30 | if "yes" in prefix_pred_ans:
31 | pred_label = "yes"
32 | elif "no" in prefix_pred_ans:
33 | pred_label = "no"
34 | else:
35 | pred_label = "other"
36 |
37 | return pred_label
38 |
39 |
40 | def compute_metric(self, gts, preds):
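        """Accuracy is computed over all samples; precision, recall, and the
        confusion matrix use only definite yes/no predictions, with 'other'
        answers counted separately."""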
41 | assert len(gts) == len(preds)
42 |
43 | label_map = {
44 | "yes": 1,
45 | "no": 0,
46 | "other": -1,
47 | }
48 |
49 | gts = [label_map[x] for x in gts]
50 | preds = [label_map[x] for x in preds]
51 |
52 | acc = accuracy_score(gts, preds)
53 |
54 | clean_gts = []
55 | clean_preds = []
56 | other_num = 0
57 | for gt, pred in zip(gts, preds):
58 | if pred == -1:
59 | other_num += 1
60 | continue
61 | clean_gts.append(gt)
62 | clean_preds.append(pred)
63 |
64 |
65 | conf_mat = confusion_matrix(clean_gts, clean_preds, labels=[1,0])
66 | precision = precision_score(clean_gts, clean_preds, average='binary')
67 | recall = recall_score(clean_gts, clean_preds, average='binary')
68 | tp, fn = conf_mat[0]
69 | fp, tn = conf_mat[1]
70 |
71 | metric_dict = dict()
72 | metric_dict = {
73 | "TP": tp,
74 | "FN": fn,
75 | "TN": tn,
76 | "FP": fp,
77 | "precision": precision,
78 | "recall": recall,
79 | "other_num": other_num,
80 | "acc": acc,
81 | }
82 |
83 | return metric_dict
84 |
85 |
86 | def process_result(self, results_dir):
87 |
88 | model_score_dict = dict()
89 | for eval_type, task_name_list in eval_type_dict.items():
90 | print("===========", eval_type, "===========")
91 |
92 | scores = 0
93 | task_score_dict = dict()
94 |
95 | for task_name in task_name_list:
96 | print(task_name)
97 |
98 | task_txt = os.path.join(results_dir, task_name + ".txt")
99 | lines = open(task_txt, 'r').readlines()
100 | chunk_lines = list(self.divide_chunks(lines))
101 |
102 | img_num = len(chunk_lines)
103 | task_other_ans_num = 0
104 | task_score = 0
105 | acc_plus_correct_num = 0
106 | gts = []
107 | preds = []
108 |
109 | for img_items in chunk_lines:
110 | assert len(img_items) == 2
111 | img_correct_num = 0
112 |
113 | for img_item in img_items:
114 |                     try:
115 |                         img_name, question, gt_ans, pred_ans = img_item.split("\t")
116 |                     except ValueError:
117 |                         print('img_item: {}'.format(img_item))
118 |                         continue
119 | gt_ans = gt_ans.lower()
120 | pred_ans = pred_ans.lower()
121 |
122 | assert gt_ans in ["yes", "no"] # gt can only be yes or no.
123 |
124 | pred_ans = self.parse_pred_ans(pred_ans)
125 | assert pred_ans in ["yes", "no", "other"]
126 |
127 | gts.append(gt_ans)
128 | preds.append(pred_ans)
129 |
130 | if gt_ans == pred_ans:
131 | img_correct_num += 1
132 |
133 | if pred_ans not in ["yes", "no"]:
134 | task_other_ans_num += 1
135 |
136 | if img_correct_num == 2:
137 | acc_plus_correct_num += 1
138 |
139 | # cal TP precision acc, etc.
140 | metric_dict = self.compute_metric(gts, preds)
141 | acc_plus = acc_plus_correct_num / img_num
142 | metric_dict["acc_plus"] = acc_plus
143 |
144 |
145 | for k, v in metric_dict.items():
146 | if k in ["acc", "acc_plus"]:
147 | task_score += v*100
148 |
149 | task_score_dict[task_name] = task_score
150 |
151 | scores += task_score
152 |
153 | print("total score:", scores, "\n")
154 | for task_name, score in task_score_dict.items():
155 | print("\t", task_name, " score:", score)
156 | print("\n")
157 |
158 | return
159 |
160 |
161 |
162 |
163 | if __name__ == "__main__":
164 | cal = calculate_metrics()
165 |
166 | args = parser.parse_args()
167 |
168 | results_dir = args.results_dir
169 | cal.process_result(results_dir)
170 |
171 |
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/templates/three_columns/document.css.jinja:
--------------------------------------------------------------------------------
1 | .container h3 {
2 | margin-bottom: {{styles.gap.h3p_gap}};
3 | text-align: {{styles.h3location|default("left")}};
4 | }
5 |
6 | .container p {
7 | margin-top: 0px;
8 | line-height: {{ styles.line_height }};
9 | text-indent: 2em;
10 | margin-bottom: 0px;
11 | }
12 |
13 | body {
14 | height: 100%;
15 | margin: 0;
16 | padding: 0;
17 | display: flex;
18 | justify-content: center;
19 | align-items: center;
20 | }
21 |
22 | p {
23 | word-break: break-all;
24 | }
25 |
26 | .a4-page {
27 | width: 210mm;
28 | height: 297mm;
29 | border: 1px solid #ccc;
30 | display: flex;
31 | flex-direction: column;
32 | padding: 20mm;
33 | page-break-after: always;
34 | box-sizing: border-box;
35 | justify-content: space-between;
36 | position: relative;
37 | }
38 |
39 | .main_content {
40 | column-count: {{styles.columns}};
41 | column-gap: 16px;
42 | }
43 |
44 | .header {
45 | width: calc(100% - 20px);
46 | height: 10mm;
47 | background-color: {{styles.header.background_color}};
48 | padding: 10px;
49 | box-sizing: border-box;
50 | display: flex;
51 | justify-content: space-between;
52 | align-items: center;
53 | position: absolute;
54 | top: 5px;
55 | left: 5px;
56 | right: 5px;
57 | z-index: 1000;
58 | font-family: 'Arial', sans-serif;
59 | font-size: 12px;
60 | font-weight: bold;
61 | text-align: center;
62 | }
63 |
64 | .header-left,
65 | .header-mid,
66 | .header-right {
67 | font-size: 10px;
68 | text-align: center;
69 | }
70 |
71 | .header .page-number-container {
72 | display: flex;
73 | align-items: center;
74 | position: relative;
75 | }
76 |
77 | .header .page-number {
78 | font-family: 'Courier New', monospace;
79 | font-size: 16px;
80 | margin-right: 10px;
81 | }
82 |
83 | .header .rectangle {
84 | width: 40px;
85 | height: 4px;
86 | background-color: #ff6347;
87 | }
88 |
89 | .hcentered-line {
90 | position: absolute;
91 | bottom: 0px;
92 | left: 50%;
93 | transform: translateX(-50%);
94 | width: calc(80% - 20px);
95 | border-bottom: 2px solid black;
96 | }
97 |
98 | .page-num {
99 | flex: 1;
100 | text-align: center;
101 | }
102 |
103 | .footer {
104 | width: 100%;
105 | height: 10mm;
106 | border-top: 2px solid black;
107 | display: flex;
108 | justify-content: space-between;
109 | align-items: center;
110 | padding: 10px;
111 | box-sizing: border-box;
112 | background-image: url('path/to/your/image.jpg');
113 | background-size: cover;
114 | background-position: center;
115 | background-color: {{styles.header.background_color}};
116 | position: absolute;
117 | bottom: 5px;
118 | left: 5px;
119 | right: 5px;
120 | }
121 |
122 | .footer-left,
123 | .footer-mid,
124 | .footer-right {
125 | flex: 1;
126 | text-align: center;
127 | }
128 |
129 | .circle-background {
130 | color: #b8311a;
131 | width: 45px;
132 | height: 15px;
133 | background-color: {{ styles.page_num.background_color }};
134 | border-radius: 50%;
135 | top: -25px;
136 | left: calc(50% - 75px);
137 | z-index: -1;
138 | }
139 |
140 | .title {
141 | font-size: {{ styles.title.font_size|default('10pt') }};
142 | font-family: {{ styles.title.font_family|default('Arial, sans-serif') }};
143 | color: {{ styles.title.color|default('#333') }};
144 | background-color: {{ styles.title.background_color|default('#fff') }};
145 | margin-bottom: {{styles.title_margin_bottom}};
146 |     text-align: {{ styles.title.center|default('center') }};
147 | }
148 |
149 | .figure_caption {
150 | text-align: justify;
151 | font-size: 12px;
152 | color: #333;
153 | }
154 |
155 | .formula-block {
156 | display: flex;
157 | align-items: center;
158 | width: 100%;
159 | box-sizing: border-box;
160 | }
161 |
162 | .formula {
163 | width: max-content;
164 | margin: 0 auto;
165 | text-align: center;
166 | }
167 |
168 | .formula_caption {
169 | width: max-content;
170 | text-align: right;
171 | font-size: 12px;
172 | }
173 |
174 | .table_outer {
175 | width: 67%;
176 | margin: 1px auto;
177 | }
178 |
179 | .table_caption {
180 | width: max-content;
181 | margin: 16px auto;
182 | text-align: left;
183 | font-size: 12px;
184 | }
185 |
186 | .table_footnote {
187 | width: max-content;
188 | text-align: left;
189 | font-size: 11px;
190 | }
191 |
192 | .table-block {
193 | width: 100%;
194 | text-align: center;
195 | }
196 |
197 | .table-block table {
198 | margin: 0;
199 | font-size: 14px;
200 | border-collapse: collapse;
201 | width: 100%;
202 | border: 1px solid #ffffff;
203 | }
204 |
205 | .table-block th,
206 | .table-block td {
207 | padding: 2px 2px;
208 | border: 1px solid #ddd;
209 | font-size: 10px;
210 | text-align: center;
211 | border-left: none;
212 | border-right: none;
213 | border-top: 1px solid #ccc;
214 | border-bottom: 1px solid #ccc;
215 | font-weight: bold;
216 | }
217 |
218 | .table-block table thead tr:first-child th {
219 | border-top: 2px solid #000;
220 | }
221 |
222 | h3 {
223 | font-size: 16px;
224 | margin-bottom: 4px;
225 | }
226 |
227 | .text {
228 | font-size: 12px;
229 | text-align: left;
230 | text-indent: 4ch;
231 | line-height: 1.2;
232 | margin: 2px auto;
233 | }
234 |
235 | .MathJax,
236 | .mjx-tex-display,
237 | .MathJax_Display,
238 | .mjx-math {
239 | margin: 0 !important;
240 | padding: 0 !important;
241 | }
242 |
243 | .page_footnote {
244 | position: relative;
245 | font-size: 9px;
246 | text-align: left;
247 | }
248 |
249 | .page_footnote::before {
250 | content: "";
251 | position: absolute;
252 | top: 0;
253 | left: 0;
254 | width: 40%;
255 | border-top: 1px solid black;
256 | }
257 |
258 | .page_footnote_p {
259 | display: inline-block;
260 | }
261 |
262 | @page {
263 | size: A4;
264 | margin: 0;
265 | }
266 |
267 | @media print {
268 | html, body {
269 | margin: 0;
270 | padding: 0;
271 | }
272 |
273 | .a4-page {
274 | width: 210mm;
275 | height: calc(297mm - 0.5mm);
276 | box-sizing: border-box;
277 | overflow: hidden;
278 | }
279 | }
280 |
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/config/Config.py:
--------------------------------------------------------------------------------
1 | # Config
2 |
3 | import random
4 |
5 | class Config:
6 |
7 | text_colors = [
8 | '#000000',
9 | "#333333",
10 | "#222222",
11 | "#0a0a0a",
12 | "#003366",
13 | "#2f4f4f",
14 | "#483d8b",
15 | "#4b0082",
16 | "#2e8b57",
17 | "#696969",
18 | "#800000"
19 | ]
20 | background_colors = [
21 | "transparent",
22 | "#f8f8f8",
23 | "#fafafa",
24 | "#f0f0f0",
25 | "#e0e0e0",
26 | "#fff8e1",
27 | "#f0f8ff",
28 | "#f5f5f5",
29 | "#f4fff4",
30 | "#fff0f5",
31 | "#fffff0"
32 | ]
33 |
34 | font_styles = ["normal", "italic", "oblique"]
35 |
36 |
37 | fonts = {
38 | "english": [
39 | "Times New Roman",
40 | "Georgia",
41 | "Garamond",
42 | "Arial",
43 | "Helvetica",
44 | "Verdana"
45 | ],
46 | "chinese": [
47 | "SimSun",
48 | "NSimSun",
49 | "SimHei",
50 | "Microsoft YaHei",
51 | "KaiTi",
52 | "FangSong"
53 | ]
54 | }
55 |
56 |
57 | font_size_options = {
58 | "title": [ "10pt", "11pt", "12pt", "13pt"],
59 | "authors": ["9pt", "10pt", "11pt"],
60 | "abstract": ["10pt","8pt", "9pt"],
61 | "content": [ "9pt", "10pt", "11pt", "12pt"],
62 | "table": ["10px", "9px", "11px", '12px'],
63 | "width": [155, 160, 165, 170],
64 | "table_caption": ["10px", "9px", "11px"],
65 | "container_img_width": [85, 90, 95, 100],
66 | "abstract_img_width": [85, 90, 95, 100],
67 | "head_figure_width": [ 60, 70, 80, 90],
68 |
69 | }
70 |
71 | table = {
72 | "line_colors": [
73 | '#000000',
74 | "#333333",
75 | "#222222",
76 | "#0a0a0a",
77 | "#003366",
78 | "#2f4f4f",
79 | "#483d8b",
80 | "#4b0082",
81 | "#2e8b57",
82 | "#696969",
83 | "#800000"
84 | ],
85 | "back_color": [
86 | "transparent",
87 | "#f8f8f8",
88 | "#fafafa",
89 | "#f0f0f0",
90 | "#e0e0e0",
91 | "#fff8e1",
92 | "#f0f8ff",
93 | "#f5f5f5",
94 | "#f4fff4",
95 | "#fff0f5",
96 | "#fffff0"
97 | ],
98 | "align": ['center', 'left'],
99 | "width": [80, 90, 100],
100 |
101 | }
102 |
103 | align = ['center', 'left']
104 |
105 | continer = {
106 | "h3p_gap": ["1px", "3px", "5px", "7px"],
107 | "column_gap": ["20px", "25px", "30px"],
108 | "margin_bottom": ["8px", "10px", "12px", "16px"],
109 | "line_height": [1.5, 1.6, 1.7, 1.8],
110 | "align": ['center', 'left']
111 | }
112 |
113 | header = {
114 | "font_size": ["10pt","8pt", "9pt"],
115 |
116 | }
117 |
118 | footer = {
119 | "font_size": ["10pt","8pt", "9pt"],
120 |
121 | }
122 |
123 | container_layout = {
124 | "left": [60, 62, 64, 66],
125 | "gap": [1, 2],
126 | "background_colors": [
127 | "transparent",
128 | "#f8f8f8",
129 | "#fafafa",
130 | "#f0f0f0",
131 | "#e0e0e0",
132 | "#fff8e1",
133 | "#f0f8ff",
134 | "#f5f5f5",
135 | "#f4fff4",
136 | "#fff0f5",
137 | "#fffff0"
138 | ],
139 |
140 | "dark_background_colors": [
141 | "#2c2c2c",
142 | "#36454f",
143 | "#191970",
144 | "#2f4f4f",
145 | "#000080",
146 | "#556b2f",
147 | "#301934",
148 | "#800000",
149 | "#4b0082",
150 | "#000000"
151 | ]
152 | }
153 |
154 |
155 | page_num = {
156 | "back_color": [
157 | "#f8f8f8",
158 | "#fafafa",
159 | "#f0f0f0",
160 | "#e0e0e0",
161 | "#fff8e1",
162 | "#f0f8ff",
163 | "#f5f5f5",
164 | "#f4fff4",
165 | "#fff0f5",
166 | "#fffff0",
167 | "#000000",
168 | "#333333",
169 | "#222222",
170 | "#0a0a0a",
171 | "#003366",
172 | "#2f4f4f",
173 | "#483d8b",
174 | "#4b0082",
175 | "#2e8b57",
176 | "#696969",
177 | "#800000"
178 | ]
179 | }
180 |
181 |
182 | def random_value_from_list(list_name):
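    """Decorator factory: draw a random entry from the named Config list
    (the first entry is weighted 100x over the rest) and pass it to the
    wrapped function."""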
183 | def decorator(func):
184 | def wrapper(*args, **kwargs):
185 | list_to_use = getattr(Config, list_name, [])
186 | if not list_to_use:
187 | raise ValueError(f"List '{list_name}' not found in Config class.")
188 | weights = [ 100 if i == 0 else 1 for i in range(len(list_to_use)) ]
189 | random_value = random.choices(list_to_use, weights=weights, k=1)[0]
190 | return func(random_value, *args, **kwargs)
191 | return wrapper
192 | return decorator
193 |
194 |
195 | def get_config_value_by_list(list_name):
196 | @random_value_from_list(list_name)
197 | def wrapper(random_value):
198 | return random_value
199 | return wrapper()
200 |
201 | def random_value_from_dict(config_key):
202 | def decorator(func):
203 | def wrapper(*args, **kwargs):
204 | dict_name, key = config_key.split('.')
205 | config_dict = getattr(Config, dict_name, None)
206 | if not config_dict:
207 | raise ValueError(f"Config dictionary '{dict_name}' not found")
208 | options = config_dict.get(key, [])
209 | if not options:
210 | raise ValueError(f"No options available for '{key}' in '{dict_name}'")
211 | selected_value = random.choice(options)
212 | return func(selected_value, *args, **kwargs)
213 | return wrapper
214 | return decorator
215 |
216 |
217 | def get_config_value_by_dict(config_key):
218 | @random_value_from_dict(config_key)
219 | def wrapper(random_value):
220 | return random_value
221 | return wrapper()
222 |
223 | def get_config_value(para):
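    """Resolve a config key: a dotted 'dict.key' samples uniformly from a
    Config dict entry, while a bare name samples from a Config list with
    the first entry heavily favored."""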
224 | if len(para.split('.'))>1:
225 | value = get_config_value_by_dict(para)
226 | return value
227 | else:
228 | return get_config_value_by_list(para)
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/config/styles.py:
--------------------------------------------------------------------------------
1 | from config.Config import get_config_value
2 | import random
3 | from utils.utils import get_text_color, random_hex_color , generate_font_color
4 | import re
5 |
6 | def extract_single_number(text):
7 | match = re.search(r'(\d+)pt', text)
8 | return int(match.group(1)) if match else None
9 |
10 | def produce_styles():
11 | page_back_color = get_config_value("page_num.back_color")
12 | header_back_color = random_hex_color()
13 | right_backcolor = header_back_color if random.random()>0.4 else random_hex_color()
14 |
15 | styles = {
16 | "incude_image_table": True if random.random()>1 else False,
17 |
18 | "title": {
19 | "font_size": get_config_value('font_size_options.title'),
20 | "font_family": get_config_value("fonts.chinese"),
21 | "font_weight": "bold",
22 | "color": get_config_value('text_colors'),
23 | "background_color": get_config_value('background_colors'),
24 | "center": get_config_value("align")
25 | },
26 | "authors": {
27 | "font_size": get_config_value('font_size_options.authors'),
28 | "font_family": get_config_value("fonts.chinese"),
29 | "font_weight": "normal", # Typically, author info is not bold
30 | "color": get_config_value('text_colors'),
31 | "background_color": get_config_value('background_colors'),
32 | "center": get_config_value("align")
33 | },
34 | "abstract": {
35 | "font_size": get_config_value('font_size_options.abstract'),
36 | "font_family": get_config_value("fonts.chinese"),
37 | "font_weight": "italic", # Abstracts are often italicized for emphasis
38 | "color": get_config_value('text_colors'),
39 | "background_color": get_config_value('background_colors'),
40 | "center": get_config_value("align")
41 | },
42 | "content": {
43 | "font_size": get_config_value('font_size_options.content'),
44 | "font_family": get_config_value("fonts.chinese"),
45 | "font_weight": "normal", # Regular content typically does not use bold
46 | "color": get_config_value('text_colors'),
47 | "background_color": get_config_value('background_colors')
48 | },
49 | "section_title": {
50 | "font_size": get_config_value('font_size_options.content'),
51 | "font_family": get_config_value("fonts.chinese"),
52 | "font_weight": "bold", # Regular content typically does not use bold
53 | "color": get_config_value('text_colors'),
54 | "background_color": get_config_value('background_colors')
55 |
56 | },
57 |
58 | "table": {
59 |
60 | "font_size": get_config_value('font_size_options.table'),
61 | "font_family_en": get_config_value("fonts.english"),
62 | "font_family_zh": get_config_value("fonts.chinese"),
63 | # "font_weight": "bold", # Regular content typically does not use bold
64 | "line_color": get_config_value('table.line_colors'),
65 | "background_color": get_config_value('background_colors'),
66 | "back_color": get_config_value('table.back_color'),
67 | "align": get_config_value("table.align"),
68 | "width": get_config_value("table.width"),
69 | "table_caption": get_config_value("font_size_options.table_caption"),
70 | },
71 | "body_text": {
72 | "font_size": "1em",
73 | "font_family": "Arial, sans-serif",
74 | "font_weight": "normal",
75 | "color": "#444",
76 | "background_color": "#fff",
77 | "line_height": "1.6"
78 | },
79 | "gap":{
80 | "h3p_gap":get_config_value("continer.h3p_gap")
81 | },
82 | "h3location": get_config_value("continer.align"),
83 |
84 | "column_gap": get_config_value("continer.column_gap"),
85 |
86 | "title_margin_bottom": get_config_value("continer.margin_bottom"),
87 |
88 | "authors_margin_bottom": get_config_value("continer.margin_bottom"),
89 |
90 | "abstract_margin_bottom": get_config_value("continer.margin_bottom"),
91 |
92 | "abstract_width": get_config_value("font_size_options.width"),
93 |
94 | "line_height": get_config_value("continer.line_height"),
95 |
96 | "caption":{
97 | "font_size": get_config_value('font_size_options.content'),
98 | "line_height": get_config_value("continer.line_height"),
99 | },
100 | "should_cross_column": "True",
101 |
102 | "figure_up": "True" if random.random() > 0.5 else None,
103 | "container_per_width": get_config_value("font_size_options.container_img_width"),
104 | "abstract_per_width": get_config_value("font_size_options.abstract_img_width"),
105 | "head_figure_width": get_config_value("font_size_options.head_figure_width"),
106 | "three_line": "True" if random.random() > 0.5 else None,
107 | "two_line": "True" if random.random() > 0.1 else None,
108 |
109 |
110 | "header": {
111 | "page_num_size": get_config_value("header.font_size"),
112 | "background_color": get_config_value('background_colors'),
113 |
114 | },
115 |
116 | "footer":{
117 | "page_num_size": get_config_value("footer.font_size"),
118 | "background_color": get_config_value('background_colors'),
119 |
120 | },
121 |
122 | "page_num":{
123 | "background_color": page_back_color,
124 | "page_num_coloer": get_text_color(page_back_color)
125 | },
126 |
127 | "container_layout": {
128 | "left": get_config_value("container_layout.left"),
129 | "gap": get_config_value("container_layout.gap"),
130 | "back_color": get_config_value('container_layout.background_colors')
131 | },
132 | "header_right": {
133 | "header_backcolor": header_back_color,
134 | "right_backcolor": right_backcolor,
135 | "header_font_color": generate_font_color(header_back_color),
136 | "right_font_color": generate_font_color(right_backcolor),
137 | "include_P": "True" if random.random()>0.5 else None,
138 | "padding_value": random.randint(16,20),
139 | }
140 | }
141 |
142 | return styles
143 |
144 |
145 | def get_styles_num(config) -> dict:
146 | """
147 | """
148 |     styles = produce_styles()
149 |
150 | styles["columns"] = config["layout_config"]["columns"]
151 |
152 |
153 | return styles
154 |
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/utils/LatexUtil.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Pattern
3 |
4 | class LatexError(Exception):
5 | pass
6 |
7 |
8 | class LatexValidationError(LatexError):
9 | pass
10 |
11 |
12 | class BracketMismatchError(LatexValidationError):
13 | pass
14 |
15 |
16 | class EnvironmentMismatchError(LatexValidationError):
17 | pass
18 |
19 |
20 | class InvalidCharacterError(LatexValidationError):
21 | pass
22 |
23 |
24 | class LatexSimplificationError(LatexError):
25 | pass
26 |
27 | class LatexValidator:
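    """Structural sanity checks for LaTeX snippets: control characters,
    unrecognized single-character escapes, bracket balance, and
    begin/end environment pairing."""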
28 | _invalid_unicode_re: Pattern[str] = re.compile(r"[\u0000-\u001F\u007F]")
29 | _env_token_re: Pattern[str] = re.compile(r"\\(begin|end)\{([^\}]+)\}")
30 | _illegal_backslash_re: Pattern[str] = re.compile(r"(\\[^a-zA-Z])")
31 | _allowed_non_letter_prefixes = {
32 | "\\\\",
33 | "\\[",
34 | "\\]",
35 | "\\(",
36 | "\\)",
37 | "\\%",
38 | "\\&",
39 | "\\$",
40 | "\\#",
41 | "\\,",
42 | "\\;",
43 | "\\:",
44 | "\\!",
45 | "\\ ",
46 | "\\quad",
47 | "\\qquad",
48 | }
49 |
50 | def __call__(self, latex: str) -> bool:
51 | return self.is_valid(latex)
52 |
53 | def is_valid(self, latex: str) -> bool:
54 | if not latex or not isinstance(latex, str):
55 | raise LatexValidationError("Input is empty or not a string.")
56 |
57 | for i, line in enumerate(latex.splitlines(), start=1):
58 | if self._invalid_unicode_re.search(line):
59 | snippet = repr(line.strip())[:60]
60 | raise InvalidCharacterError(
61 | f"Line {i} contains invalid Unicode control characters: {snippet}"
62 | )
63 |
64 | if self._has_illegal_backslashes(latex):
65 | raise InvalidCharacterError("Contains illegal backslash usage.")
66 |
67 | if not self._are_brackets_balanced(latex, "{", "}"):
68 | raise BracketMismatchError("Mismatched {} brackets.")
69 | if not self._are_brackets_balanced(latex, "[", "]"):
70 | raise BracketMismatchError("Mismatched [] brackets.")
71 | if not self._are_brackets_balanced(latex, "(", ")"):
72 | raise BracketMismatchError("Mismatched () brackets.")
73 | if not self._are_environments_balanced(latex):
74 | raise EnvironmentMismatchError("Environment \\begin/\\end mismatch.")
75 | return True
76 |
77 | def _are_brackets_balanced(self, s: str, open_b: str, close_b: str) -> bool:
78 | stack = []
79 | for c in s:
80 | if c == open_b:
81 | stack.append(c)
82 | elif c == close_b:
83 | if not stack:
84 | return False
85 | stack.pop()
86 | return not stack
87 |
88 | def _are_environments_balanced(self, s: str) -> bool:
89 | tokens = self._env_token_re.findall(s)
90 | stack = []
91 | for kind, name in tokens:
92 | if kind == "begin":
93 | stack.append(name)
94 | elif kind == "end":
95 | if not stack or stack[-1] != name:
96 | return False
97 | stack.pop()
98 | return not stack
99 |
100 | def _has_illegal_backslashes(self, s: str) -> bool:
101 | for match in self._illegal_backslash_re.findall(s):
102 | if match not in self._allowed_non_letter_prefixes:
103 | return True
104 | return False
105 |
106 |
107 |
108 | class LatexSimplifier:
109 | _whitespace_re: Pattern[str] = re.compile(r"\s+")
110 | _operator_spacing_re: Pattern[str] = re.compile(r"\s*([=+\-*/<>])\s*")
111 | _inline_wrap_re: Pattern[str] = re.compile(r"^\$(.*?)\$$", re.DOTALL)
112 | _display_wrap_re: Pattern[str] = re.compile(r"^\$\$(.*?)\$\$$", re.DOTALL)
113 | _bracket_wrap_re: Pattern[str] = re.compile(r"^\\[\[\(](.*?)\\[\]\)]$", re.DOTALL)
114 | _text_expr_re: Pattern[str] = re.compile(r"\\text\{.*?\}")
115 | _operator_expr_re: Pattern[str] = re.compile(r"\\operatorname\{.*?\}")
116 | _structure_spacing_re = re.compile(r"\s*(\\(?:begin|end)\{[^\}]+\})\s*")
117 | # _old_style_font_re: Pattern[str] = re.compile(r"(\\(?:bf|it|rm|tt|sf|sl|sc))\s+")
118 | _backslash_spacing_re = re.compile(r"(\\)\s")
119 | _cmd_spacing_re = re.compile(r"(\\[a-zA-Z]+)\s+(?=[a-zA-Z])")
120 | _all_space_re = re.compile(r"\s+")
121 |
122 | @staticmethod
123 | def _protect_space(m) -> str:
124 | return m.group(0).replace(" ", "␣")
125 |
126 | @staticmethod
127 | def _protect_oldstylefontspace(m) -> str:
128 | return m.group(1) + "␣"
129 |
130 | def remove_wrappers(self, latex: str) -> str:
131 | latex = latex.strip()
132 | for pattern in [
133 | self._display_wrap_re,
134 | self._inline_wrap_re,
135 | self._bracket_wrap_re,
136 | ]:
137 | match = pattern.match(latex)
138 | if match:
139 | return match.group(1).strip()
140 | return latex
141 |
142 | def compress_whitespace(self, latex: str) -> str:
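        """Collapse all whitespace while keeping required spaces: spaces inside
        \\text{...}/\\operatorname{...}, after a lone backslash, and between a
        command and a following letter are swapped for a sentinel, everything
        else is stripped, and sentinels are restored as single spaces."""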
143 |
144 | latex = self._text_expr_re.sub(LatexSimplifier._protect_space, latex)
145 | latex = self._operator_expr_re.sub(LatexSimplifier._protect_space, latex)
146 |
147 | latex = self._backslash_spacing_re.sub(r"\1␣", latex)
148 |
149 | latex = self._cmd_spacing_re.sub(r"\1␣", latex)
150 |
151 | latex = self._all_space_re.sub("", latex)
152 |
153 | latex = latex.replace("␣", " ")
154 | return latex
155 |
156 |
157 |
158 |
159 | class LatexNormalizer:
160 | def __init__(
161 | self,
162 | *,
163 | strip_wrappers: bool = True,
164 | flatten_multiline_to_single_line: bool = True,
165 | simplify_whitespace: bool = True,
166 | validate: bool = True,
167 | ) -> None:
168 | self.strip_wrappers = strip_wrappers
169 | self.flatten_multiline_to_single_line = flatten_multiline_to_single_line
170 | self.simplify_whitespace = simplify_whitespace
171 | self.validate = validate
172 |
173 | self._validator = LatexValidator()
174 | self._simplifier = LatexSimplifier()
175 |
176 | def __call__(self, latex: str) -> str:
177 | if not isinstance(latex, str):
178 | raise LatexValidationError("Input is not a string.")
179 |
180 | if self.strip_wrappers:
181 | latex = self._simplifier.remove_wrappers(latex)
182 |
183 | if self.flatten_multiline_to_single_line:
184 | lines = [line.strip() for line in latex.splitlines() if line.strip()]
185 | latex = " ".join(lines)
186 |
187 | if self.simplify_whitespace:
188 | latex = self._simplifier.compress_whitespace(latex)
189 |
190 | if self.validate:
191 | self._validator(latex)
192 | return latex
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/scripts/doc_parser_v2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from tqdm import tqdm
4 | import random
6 | import os
7 |
8 | current_file = os.path.abspath(__file__)  # absolute path of this file
9 | parent_dir = os.path.dirname(os.path.dirname(current_file))  # parent directory
10 | sys.path.append(parent_dir)
11 |
12 |
13 | from utils.LatexUtil import LatexNormalizer, LatexError
14 | from typing import TextIO
15 |
16 |
17 |
18 | latextool = LatexNormalizer()
19 |
20 | prompts = [
21 | "Please convert the document content into Markdown format.",
22 | ]
23 |
24 | from bs4 import BeautifulSoup
25 |
52 |
53 | def html_table_to_markdown(html: str) -> str:
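    """Convert the first <table> in `html` into a Markdown table, expanding
    rowspan/colspan cells through a rectangular cell matrix."""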
54 | soup = BeautifulSoup(html, "html.parser")
55 | table = soup.find("table")
56 | if table is None:
57 | return "No found."
58 |
59 | def get_cell_text(cell):
60 | return cell.get_text(strip=True).replace("|", "\\|")
61 |
62 | rows = table.find_all("tr")
63 | if not rows:
64 | return ""
65 |
66 |     # Build a cell matrix to resolve rowspan/colspan
67 | matrix = []
68 | max_cols = 0
69 |
70 |     # First pass: compute the maximum column count
71 | for row in rows:
72 | cells = row.find_all(["td", "th"])
73 | col_count = 0
74 | for cell in cells:
75 | colspan = int(cell.get("colspan", 1))
76 | col_count += colspan
77 | max_cols = max(max_cols, col_count)
78 |
79 |     # Second pass: fill the matrix
80 | for row_idx, row in enumerate(rows):
81 | if row_idx >= len(matrix):
82 | matrix.append([None] * max_cols)
83 |
84 | cells = row.find_all(["td", "th"])
85 | col_idx = 0
86 |
87 | for cell in cells:
88 |             # find the next free slot
89 | while col_idx < max_cols and matrix[row_idx][col_idx] is not None:
90 | col_idx += 1
91 |
92 | if col_idx >= max_cols:
93 | break
94 |
95 | colspan = int(cell.get("colspan", 1))
96 | rowspan = int(cell.get("rowspan", 1))
97 | cell_text = get_cell_text(cell)
98 |
99 |             # fill this cell and the region it spans
100 | for r in range(row_idx, min(row_idx + rowspan, len(rows))):
101 |                 # make sure enough rows exist
102 | while len(matrix) <= r:
103 | matrix.append([None] * max_cols)
104 |
105 | for c in range(col_idx, min(col_idx + colspan, max_cols)):
106 | if r == row_idx and c == col_idx:
107 |                     # the main cell
108 | matrix[r][c] = cell_text
109 | else:
110 |                     # spanned positions are marked with empty strings
111 | matrix[r][c] = ""
112 |
113 | col_idx += colspan
114 |
115 |     # ensure every row has the same column count
116 | for row in matrix:
117 | while len(row) < max_cols:
118 | row.append("")
119 |         # replace None with empty strings
120 | for i in range(len(row)):
121 | if row[i] is None:
122 | row[i] = ""
123 |
124 | if not matrix:
125 | return ""
126 |
127 |     # emit the Markdown table
128 | markdown_lines = []
129 |
130 |     # header row
131 | header_line = "| " + " | ".join(matrix[0]) + " |"
132 | markdown_lines.append(header_line)
133 |
134 |     # separator row
135 | separator_line = "| " + " | ".join(["---"] * max_cols) + " |"
136 | markdown_lines.append(separator_line)
137 |
138 |     # data rows
139 | for row in matrix[1:]:
140 | data_line = "| " + " | ".join(row) + " |"
141 | markdown_lines.append(data_line)
142 |
143 | return "\n".join(markdown_lines)
144 |
145 |
146 | def form2docparse(datas):
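    """Turn parsed page records into (image, conversation) training samples:
    titles become '#' headings, formulas are normalized into $$...$$ blocks,
    tables become Markdown, and figure/header/footer items are dropped."""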
147 |
148 | results = []
149 | for ind, data in tqdm(enumerate(datas)):
150 | image = data['image']
151 | res = []
152 | try:
153 | for idx, item in enumerate(data['form']):
154 | if item['category'] == 'title':
155 | res.append('#'*item['level'] + ' ' + item['text'])
156 | elif item['category'] == "formula":
157 | res.append("$$" + latextool(item['text']) + "$$")
158 | elif item['category'] not in ['figure', 'header', 'footer', 'table', "formula"]:
159 | res.append(item['text'])
160 | elif item['category'] == "table":
161 | res.append(html_table_to_markdown(item['text']))
162 | markdown = '\n\n'.join(res)
163 | results.append({
164 | 'images': [image],
165 | 'conversations': [
166 | {
167 | 'from': 'human',
168 | 'value': random.choice(prompts)
169 | },
170 | {
171 | 'from': 'gpt',
172 | 'value': f'```markdown\n{markdown}\n```'
173 | }
174 | ]
175 |
176 | })
177 | except Exception as e:
178 | continue
179 |
180 | return results
181 |
182 | def load_and_merge_json_files(directory):
183 | """读取目录下所有 JSON 文件并合并成一个字典列表"""
184 | merged_data = []
185 | for filename in os.listdir(directory):
186 | if filename.endswith(".json"):
187 | filepath = os.path.join(directory, filename)
188 | with open(filepath, "r", encoding="utf-8") as file:
189 | data = json.load(file)
190 |             if isinstance(data, list):  # a JSON array: merge its items directly
191 | merged_data.extend(data)
192 |             else:  # a single object: append it to the list
193 | merged_data.append(data)
194 | return merged_data
195 |
196 | if __name__ == "__main__":
197 | if len(sys.argv) != 3:
198 | print("Usage: python script.py ")
199 | sys.exit(1)
200 |
201 | input_dir = sys.argv[1]
202 | output_file = sys.argv[2]
203 |
204 |     # Read and merge all JSON files under the input directory
205 | merged_data = load_and_merge_json_files(input_dir)
206 |
207 |     # Convert the merged records
208 | result = form2docparse(merged_data)
209 |
210 |     # Write the results to the output file
211 | with open(output_file, "w", encoding="utf-8") as file:
212 | json.dump(result, file, indent=2, ensure_ascii=False)
213 |
214 |
--------------------------------------------------------------------------------
/INF-MLLM1/evaluate/infmllm_chat/model_vqa_loader.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | from PIL import Image
7 | import math
8 | import shortuuid
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from transformers import AutoModel, AutoTokenizer
12 | from evaluate.infmllm_chat.utils import tokenizer_image_token
13 | from evaluate.infmllm_chat.conversation import conv_templates, SeparatorStyle
14 |
15 | IMAGE_TOKEN_INDEX = -200
16 | DEFAULT_IMAGE_TOKEN = "<image>"
17 |
18 | def expand2square(pil_img, background_color):
19 | # pad to middle for square shape
20 | width, height = pil_img.size
21 | if width == height:
22 | return pil_img
23 | elif width > height:
24 | result = Image.new(pil_img.mode, (width, width), background_color)
25 | result.paste(pil_img, (0, (width - height) // 2))
26 | return result
27 | else:
28 | result = Image.new(pil_img.mode, (height, height), background_color)
29 | result.paste(pil_img, ((height - width) // 2, 0))
30 | return result
31 |
32 | def disable_torch_init():
33 | """
34 | Disable the redundant torch default initialization to accelerate model creation.
35 | """
36 | import torch
37 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
38 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
39 |
40 | def split_list(lst, n):
41 | """Split a list into n (roughly) equal-sized chunks"""
42 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
43 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
44 |
45 |
46 | def get_chunk(lst, n, k):
47 | chunks = split_list(lst, n)
48 | return chunks[k]
49 |
50 |
51 | # Custom dataset class
52 | class CustomDataset(Dataset):
53 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
54 | self.questions = questions
55 | self.image_folder = image_folder
56 | self.tokenizer = tokenizer
57 | self.image_processor = image_processor
58 | self.model_config = model_config
59 |
60 | def __getitem__(self, index):
61 | line = self.questions[index]
62 | image_file = line["image"]
63 | qs = line["text"]
64 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
65 |
66 | conv = conv_templates[args.conv_mode].copy()
67 | conv.append_message(conv.roles[0], qs)
68 | conv.append_message(conv.roles[1], None)
69 | prompt = conv.get_prompt()
70 |
71 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
72 |         # pad to a square, to stay consistent with training-time preprocessing
73 | image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
74 | image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
75 |
76 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
77 |
78 | return input_ids, image_tensor
79 |
80 | def __len__(self):
81 | return len(self.questions)
82 |
83 |
84 | # DataLoader
85 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
86 | assert batch_size == 1, "batch_size must be 1"
87 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
88 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
89 | return data_loader
90 |
91 |
92 | def eval_model(args):
93 | # Model
94 | disable_torch_init()
95 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False)
96 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
97 | model = model.cuda().eval()
98 | image_processor = model.get_model().get_vision_tower().image_processor
99 |
100 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
101 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
102 | answers_file = os.path.expanduser(args.answers_file)
103 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
104 | ans_file = open(answers_file, "w")
105 |
106 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
107 |
108 | for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
109 | idx = line["question_id"]
110 | cur_prompt = line["text"]
111 |
112 | stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
113 | input_ids = input_ids.to(device='cuda', non_blocking=True)
114 |
115 | with torch.inference_mode():
116 | output_ids = model.generate(
117 | input_ids,
118 | images=image_tensor.to(dtype=torch.bfloat16, device='cuda', non_blocking=True),
119 | do_sample=True if args.temperature > 0 else False,
120 | temperature=args.temperature,
121 | top_p=args.top_p,
122 | num_beams=args.num_beams,
123 | max_new_tokens=128,
124 | use_cache=True)
125 |
126 | input_token_len = input_ids.shape[1]
127 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
128 | if n_diff_input_output > 0:
129 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
130 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
131 | outputs = outputs.strip()
132 | if outputs.endswith(stop_str):
133 | outputs = outputs[:-len(stop_str)]
134 | outputs = outputs.strip()
135 |
136 | ans_id = shortuuid.uuid()
137 | ans_file.write(json.dumps({"question_id": idx,
138 | "prompt": cur_prompt,
139 | "text": outputs,
140 | "answer_id": ans_id,
141 | "metadata": {}}) + "\n")
142 | # ans_file.flush()
143 | ans_file.close()
144 |
145 | print("image_size: {}".format(model.config.image_size))
146 | print("pool_out_size: {}".format(model.config.pool_out_size))
147 |
148 | if __name__ == "__main__":
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument("--model-path", type=str)
151 | parser.add_argument("--image-folder", type=str, default="")
152 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
153 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
154 | parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
155 | parser.add_argument("--num-chunks", type=int, default=1)
156 | parser.add_argument("--chunk-idx", type=int, default=0)
157 | parser.add_argument("--temperature", type=float, default=0.2)
158 | parser.add_argument("--top_p", type=float, default=None)
159 | parser.add_argument("--num_beams", type=int, default=1)
160 | args = parser.parse_args()
161 | eval_model(args)
162 |
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/utils/Text.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import random
3 | import json
4 | from bs4 import BeautifulSoup # required for check_merged_cells()
5 |
6 |
7 | def add_html_header(text: str, level: int, serial_num: str) -> str:
8 | """
9 | Wrap the given text with an HTML header tag based on level (h2, h3, h4).
10 | :param text: header text
11 | :param level: heading level 1–3 (internally mapped to h2–h4)
12 | :param serial_num: numbering prefix like "1.2.3"
13 | """
14 | level = level + 1 # convert 1→h2, 2→h3, 3→h4
15 | if level not in [2, 3, 4]:
16 | raise ValueError("Header level must map to h2, h3, or h4")
17 |
18 | return f"{serial_num} {text}"
19 |
20 |
21 | def generate_next_headings(levels: list, start: str) -> list:
22 | """
23 | Given a list of hierarchical levels and a starting heading number,
24 | generate the subsequent hierarchical numbering.
25 |     Example: levels=[2,3,2], start="2.1" → ["2.2", "2.2.1", "2.3"]
26 | """
27 | current = list(map(int, start.split('.')))
28 | results = [start]
29 |
30 | for level in levels:
31 | if level > len(current):
32 | current.append(1)
33 | elif level == len(current):
34 | current[-1] += 1
35 | else:
36 | current = current[:level]
37 | current[-1] += 1
38 |
39 | results.append('.'.join(map(str, current)))
40 |
41 | return results[1:]
42 |
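# Worked example of the walk above (hand-checked): start "2.1", levels [2, 3, 2].
#   level 2 == len([2, 1])    -> bump last component -> "2.2"
#   level 3 >  len([2, 2])    -> descend, append 1   -> "2.2.1"
#   level 2 <  len([2, 2, 1]) -> truncate, then bump -> "2.3"
#   >>> generate_next_headings([2, 3, 2], "2.1")
#   ['2.2', '2.2.1', '2.3']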
43 |
44 | def generate_random_list(length: int) -> list:
45 | """
46 |     Generate a random list of heading levels 1/2/3 in which every 1 or 3 is followed by a 2, so levels 1 and 3 are never adjacent.
47 | """
48 | if length <= 0:
49 | return []
50 |
51 | result = []
52 | choices = [1, 2, 3]
53 |
54 | for i in range(length):
55 | if i == 0:
56 | result.append(random.choice(choices))
57 | else:
58 | if result[-1] == 1:
59 | next_choices = [2]
60 | elif result[-1] == 3:
61 | next_choices = [2]
62 | else:
63 | next_choices = choices
64 | result.append(random.choice(next_choices))
65 |
66 | return result
67 |
68 |
69 | def generate_random_number(level):
70 | """
71 | Generate hierarchical numbering based on level depth 1/2/3.
72 | """
73 | parts = [random.randint(1, 10) for _ in range(level)]
74 | return ".".join(map(str, parts))
75 |
76 |
77 | def produce_multihead_number(text: dict):
78 | """
79 | Build multi-level HTML headings and merge adjacent paragraphs randomly.
80 | """
81 | level = generate_random_list(len(text))
82 | start_num = generate_random_number(level[0])
83 | num_list = generate_next_headings(level, start_num)
84 |
85 | ordered = OrderedDict()
86 | pre_text = ""
87 |
88 | for i, (key, value) in enumerate(text.items()):
89 | next_level = level[i + 1] if i + 1 < len(text) else 1
90 | new_key = add_html_header(key, level[i], num_list[i])
91 |
92 | if next_level > level[i] and random.random() > 0.3 and isinstance(value, str):
93 | ordered[new_key] = None
94 | pre_text = value
95 | else:
96 | if isinstance(value, dict):
97 | ordered[new_key] = value
98 | elif isinstance(value, list):
99 | value.append(pre_text)
100 | pre_text = ""
101 | ordered[new_key] = value
102 | else:
103 | ordered[new_key] = value + pre_text
104 | pre_text = ""
105 |
106 | return ordered
107 |
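# Illustration of the merge rule above (hypothetical input): with
# text = {"Intro": "overview text", "Details": ["p1", "p2"]} and a drawn level
# list like [1, 2], "Intro" may be emitted as a bare heading (value None) and
# its "overview text" pushed down into the deeper "Details" section; otherwise
# the pending text is re-attached to the current value.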
108 |
109 | def generate_random_list_only_2(length: int) -> tuple:
110 | """
111 | Randomly generate a level list using only {1,2} or {2,3}.
112 | """
113 | mode = random.choice(['1,2', '2,3'])
114 | choices = [1, 2] if mode == '1,2' else [2, 3]
115 | return random.choices(choices, k=length), mode
116 |
117 |
118 | def generate_title_numbers(levels, mode):
119 | """
120 | Generate hierarchical title numbering, ensuring consistent style per level.
121 | Reset lower-level counters when higher ones appear.
122 | """
123 | if len(levels) > 40:
124 |         print("Too long: at most 40 heading levels are supported")
125 | return []
126 |
127 | counters = {lvl: 1 for lvl in range(1, max(levels) + 1)}
128 | chinese = [
129 | '一', '二', '三', '四', '五', '六', '七', '八', '九', '十',
130 | '十一', '十二', '十三', '十四', '十五', '十六', '十七', '十八', '十九', '二十',
131 | '二十一', '二十二', '二十三', '二十四', '二十五', '二十六', '二十七', '二十八', '二十九', '三十'
132 | ]
133 | chinese_b = [f"({c})" for c in chinese]
134 | arabic = [f"第{x}节" for x in range(1, 51)]
135 |
136 | style_defs = {
137 | 1: [lambda x: chinese_b[x - 1], lambda x: f"第{x}章", lambda x: chinese[x - 1]],
138 | 2: [lambda x: arabic[x - 1], lambda x: f"第{x}节", lambda x: f"(第{x}节)"],
139 | 3: [lambda x: chinese[x - 1], lambda x: chinese_b[x - 1]],
140 | }
141 |
142 | available_levels = [1, 2] if mode == '1,2' else [2, 3]
143 | used = set()
144 | level_styles = {}
145 |
146 | for lvl in available_levels:
147 | opts = [f for f in style_defs[lvl] if f not in used]
148 | style = random.choice(opts) if opts else (lambda x: f"{lvl}.{x}")
149 | level_styles[lvl] = style
150 | used.add(style)
151 |
152 | result = []
153 | for lvl in levels:
154 | if lvl not in available_levels:
155 | continue
156 | num = counters[lvl]
157 | style = level_styles[lvl]
158 | result.append(style(num))
159 | counters[lvl] += 1
160 | for lower in range(lvl + 1, max(levels) + 1):
161 | counters[lower] = 1
162 |
163 | return result
164 |
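# One possible draw (illustrative): levels [1, 2, 2, 1, 2] in mode '1,2', with
# level 1 styled "第x章" and level 2 styled "第x节", yields
#   ['第1章', '第1节', '第2节', '第2章', '第1节']
# i.e. the level-2 counter resets to 1 after each level-1 heading.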
165 |
166 | def produce_simple_number(text: dict):
167 | """
168 | Build simple hierarchical headings with either 1–2 or 2–3 rules.
169 | """
170 | level, mode = generate_random_list_only_2(len(text))
171 | num_list = generate_title_numbers(level, mode)
172 |
173 | ordered = OrderedDict()
174 | pre_text = ""
175 |
176 | for i, (key, value) in enumerate(text.items()):
177 | next_level = level[i + 1] if i + 1 < len(text) else 1
178 | new_key = add_html_header(key, level[i], num_list[i])
179 |
180 | if next_level > level[i] and random.random() > 0.3 and isinstance(value, str):
181 | ordered[new_key] = None
182 | pre_text = value
183 | else:
184 | if isinstance(value, dict):
185 | ordered[new_key] = value
186 | elif isinstance(value, list):
187 | value.append(pre_text)
188 | pre_text = ""
189 | ordered[new_key] = value
190 | else:
191 | ordered[new_key] = value + pre_text
192 | pre_text = ""
193 |
194 | return ordered
195 |
196 |
197 | def check_merged_cells(html_content: str) -> bool:
198 | """
199 | Detect if HTML tables contain colspan or rowspan (merged cells).
200 | """
201 | soup = BeautifulSoup(html_content, 'html.parser')
202 | for table in soup.find_all('table'):
203 | for cell in table.find_all(['td', 'th']):
204 | if cell.has_attr('colspan') or cell.has_attr('rowspan'):
205 | return True
206 | return False
207 |
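# Examples (hand-checked):
#   >>> check_merged_cells('<table><tr><td colspan="2">x</td></tr></table>')
#   True
#   >>> check_merged_cells('<table><tr><td>x</td><td>y</td></tr></table>')
#   False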
--------------------------------------------------------------------------------
/INF-MLLM1/evaluate/infmllm_chat/model_vqa_science.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 | from PIL import Image
8 | import math
9 |
10 | from transformers import AutoModel, AutoTokenizer
11 | from evaluate.infmllm_chat.utils import tokenizer_image_token, KeywordsStoppingCriteria
12 | from evaluate.infmllm_chat.conversation import conv_templates, SeparatorStyle
13 |
14 | IMAGE_TOKEN_INDEX = -200
15 | DEFAULT_IMAGE_TOKEN = "<image>"
16 |
17 | def disable_torch_init():
18 | """
19 | Disable the redundant torch default initialization to accelerate model creation.
20 | """
21 | import torch
22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24 |
25 | def expand2square(pil_img, background_color):
26 | # pad to middle for square shape
27 | width, height = pil_img.size
28 | if width == height:
29 | return pil_img
30 | elif width > height:
31 | result = Image.new(pil_img.mode, (width, width), background_color)
32 | result.paste(pil_img, (0, (width - height) // 2))
33 | return result
34 | else:
35 | result = Image.new(pil_img.mode, (height, height), background_color)
36 | result.paste(pil_img, ((height - width) // 2, 0))
37 | return result
38 |
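# Example (hand-checked): a 100x60 image becomes 100x100, pasted at
# y = (100 - 60) // 2 = 20, so the original content is vertically centered.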
39 | def split_list(lst, n):
40 | """Split a list into n (roughly) equal-sized chunks"""
41 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
42 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
43 |
44 |
45 | def get_chunk(lst, n, k):
46 | chunks = split_list(lst, n)
47 | return chunks[k]
48 |
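# Example (hand-checked): chunk_size = ceil(10 / 3) = 4, so the final chunk is
# shorter.
#   >>> split_list(list(range(10)), 3)
#   [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
#   >>> get_chunk(list(range(10)), 3, 1)
#   [4, 5, 6, 7]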
49 |
50 | def eval_model(args):
51 | # Model
52 | disable_torch_init()
53 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False)
54 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
55 | model = model.cuda().eval()
56 | image_processor = model.get_model().get_vision_tower().image_processor
57 |
58 | questions = json.load(open(os.path.expanduser(args.question_file), "r"))
59 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
60 | answers_file = os.path.expanduser(args.answers_file)
61 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
62 | ans_file = open(answers_file, "w")
63 | for i, line in enumerate(tqdm(questions)):
64 | idx = line["id"]
65 | question = line['conversations'][0]
66 |         qs = question['value'].replace('<image>', '').strip()
67 | cur_prompt = qs
68 |
69 | if 'image' in line:
70 | image_file = line["image"]
71 |             image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
72 |
73 |             # Pad to a square with the mean color, consistent with training preprocessing.
74 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
75 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
76 | images = image_tensor.unsqueeze(0).to(dtype=torch.bfloat16, device='cuda')
77 |
78 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
79 |             cur_prompt = '<image>' + '\n' + cur_prompt
80 | else:
81 | images = None
82 |
83 | if args.single_pred_prompt:
84 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
85 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
86 |
87 | conv = conv_templates[args.conv_mode].copy()
88 | conv.append_message(conv.roles[0], qs)
89 | conv.append_message(conv.roles[1], None)
90 | prompt = conv.get_prompt()
91 |
92 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
93 |
94 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
95 | keywords = [stop_str]
96 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
97 |
98 | with torch.inference_mode():
99 | output_ids = model.generate(
100 | input_ids,
101 | images=images,
102 | do_sample=True if args.temperature > 0 else False,
103 | temperature=args.temperature,
104 | max_new_tokens=1024,
105 | use_cache=True,
106 | stopping_criteria=stopping_criteria,
107 | )
108 |
109 | input_token_len = input_ids.shape[1]
110 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
111 | if n_diff_input_output > 0:
112 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
113 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
114 | outputs = outputs.strip()
115 | if outputs.endswith(stop_str):
116 | outputs = outputs[:-len(stop_str)]
117 | outputs = outputs.strip()
118 |
119 | # prompt for answer
120 | if args.answer_prompter:
121 | outputs_reasoning = outputs
122 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
123 |
124 | with torch.inference_mode():
125 | output_ids = model.generate(
126 | input_ids,
127 | images=images,
128 | do_sample=True if args.temperature > 0 else False,
129 | temperature=args.temperature,
130 | max_new_tokens=64,
131 | use_cache=True,
132 |                     stopping_criteria=stopping_criteria)
133 |
134 | input_token_len = input_ids.shape[1]
135 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
136 | if n_diff_input_output > 0:
137 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
138 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
139 | outputs = outputs.strip()
140 | if outputs.endswith(stop_str):
141 | outputs = outputs[:-len(stop_str)]
142 | outputs = outputs.strip()
143 | outputs = outputs_reasoning + '\n The answer is ' + outputs
144 |
145 | ans_id = shortuuid.uuid()
146 | ans_file.write(json.dumps({"question_id": idx,
147 | "prompt": cur_prompt,
148 | "text": outputs,
149 | "answer_id": ans_id,
150 | "metadata": {}}) + "\n")
151 | ans_file.flush()
152 | ans_file.close()
153 |
154 | print("image_size: {}".format(model.config.image_size))
155 | print("pool_out_size: {}".format(model.config.pool_out_size))
156 |
157 | if __name__ == "__main__":
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument("--model-path", type=str)
160 | parser.add_argument("--image-folder", type=str, default="")
161 | parser.add_argument("--question-file", type=str, default="tables/question.json")
162 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
163 | parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
164 | parser.add_argument("--num-chunks", type=int, default=1)
165 | parser.add_argument("--chunk-idx", type=int, default=0)
166 | parser.add_argument("--temperature", type=float, default=0.2)
167 | parser.add_argument("--answer-prompter", action="store_true")
168 | parser.add_argument("--single-pred-prompt", action="store_true")
169 | args = parser.parse_args()
170 |
171 | eval_model(args)
172 |
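# Example invocation (illustrative; paths are placeholders and running as a
# module from the repository root is an assumption based on the imports above):
#   python -m evaluate.infmllm_chat.model_vqa_science \
#       --model-path /path/to/checkpoint \
#       --image-folder /path/to/scienceqa/images \
#       --question-file /path/to/questions.json \
#       --answers-file results/answers.jsonl \
#       --single-pred-prompt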
--------------------------------------------------------------------------------
/INF-MLLM1/evaluate/infmllm_chat/model_vqa_mmbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | import pandas as pd
6 | from tqdm import tqdm
7 | import shortuuid
8 | from PIL import Image
9 | import math
10 |
11 | from transformers import AutoModel, AutoTokenizer
12 | from evaluate.infmllm_chat.utils import tokenizer_image_token, load_image_from_base64
13 | from evaluate.infmllm_chat.conversation import conv_templates, SeparatorStyle
14 |
15 | IMAGE_TOKEN_INDEX = -200
16 | DEFAULT_IMAGE_TOKEN = "<image>"
17 |
18 | all_options = ['A', 'B', 'C', 'D']
19 |
20 |
21 | def disable_torch_init():
22 | """
23 | Disable the redundant torch default initialization to accelerate model creation.
24 | """
25 | import torch
26 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
27 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
28 |
29 | def expand2square(pil_img, background_color):
30 | # pad to middle for square shape
31 | width, height = pil_img.size
32 | if width == height:
33 | return pil_img
34 | elif width > height:
35 | result = Image.new(pil_img.mode, (width, width), background_color)
36 | result.paste(pil_img, (0, (width - height) // 2))
37 | return result
38 | else:
39 | result = Image.new(pil_img.mode, (height, height), background_color)
40 | result.paste(pil_img, ((height - width) // 2, 0))
41 | return result
42 |
43 | def split_list(lst, n):
44 | """Split a list into n (roughly) equal-sized chunks"""
45 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
46 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
47 |
48 |
49 | def get_chunk(lst, n, k):
50 | chunks = split_list(lst, n)
51 | return chunks[k]
52 |
53 |
54 | def is_none(value):
55 | if value is None:
56 | return True
57 | if type(value) is float and math.isnan(value):
58 | return True
59 | if type(value) is str and value.lower() == 'nan':
60 | return True
61 | if type(value) is str and value.lower() == 'none':
62 | return True
63 | return False
64 |
65 | def get_options(row, options):
66 | parsed_options = []
67 | for option in options:
68 | option_value = row[option]
69 | if is_none(option_value):
70 | break
71 | parsed_options.append(option_value)
72 | return parsed_options
73 |
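# Example (hand-checked): parsing stops at the first missing option cell.
#   >>> get_options({'A': 'cat', 'B': 'dog', 'C': float('nan'), 'D': 'fox'}, all_options)
#   ['cat', 'dog']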
74 |
75 | def eval_model(args):
76 | # Model
77 | disable_torch_init()
78 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False)
79 | model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
80 | model = model.cuda().eval()
81 | image_processor = model.get_model().get_vision_tower().image_processor
82 |
83 | questions = pd.read_table(os.path.expanduser(args.question_file))
84 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
85 | answers_file = os.path.expanduser(args.answers_file)
86 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
87 | ans_file = open(answers_file, "w")
88 |
89 | for index, row in tqdm(questions.iterrows(), total=len(questions)):
90 | options = get_options(row, all_options)
91 | cur_option_char = all_options[:len(options)]
92 |
93 | if args.all_rounds:
94 | num_rounds = len(options)
95 | else:
96 | num_rounds = 1
97 |
98 | for round_idx in range(num_rounds):
99 | idx = row['index']
100 | question = row['question']
101 | hint = row['hint']
102 | image = load_image_from_base64(row['image'])
103 | if not is_none(hint):
104 | question = hint + '\n' + question
105 | for option_char, option in zip(all_options[:len(options)], options):
106 | question = question + '\n' + option_char + '. ' + option
107 | qs = cur_prompt = question
108 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
109 |
110 | if args.single_pred_prompt:
111 | if args.lang == 'cn':
112 | qs = qs + '\n' + "请直接回答选项字母。"
113 | else:
114 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
115 |
116 | conv = conv_templates[args.conv_mode].copy()
117 | conv.append_message(conv.roles[0], qs)
118 | conv.append_message(conv.roles[1], None)
119 | prompt = conv.get_prompt()
120 |
121 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
122 |
123 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
124 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
125 |
126 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
127 |
128 | with torch.inference_mode():
129 | output_ids = model.generate(
130 | input_ids,
131 | images=image_tensor.unsqueeze(0).to(dtype=torch.bfloat16, device='cuda'),
132 | do_sample=True if args.temperature > 0 else False,
133 | temperature=args.temperature,
134 | top_p=args.top_p,
135 | num_beams=args.num_beams,
136 | # no_repeat_ngram_size=3,
137 | max_new_tokens=1024,
138 | use_cache=True)
139 |
140 | input_token_len = input_ids.shape[1]
141 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
142 | if n_diff_input_output > 0:
143 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
144 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
145 | outputs = outputs.strip()
146 | if outputs.endswith(stop_str):
147 | outputs = outputs[:-len(stop_str)]
148 | outputs = outputs.strip()
149 |
150 | ans_id = shortuuid.uuid()
151 | ans_file.write(json.dumps({"question_id": idx,
152 | "round_id": round_idx,
153 | "prompt": cur_prompt,
154 | "text": outputs,
155 | "options": options,
156 | "option_char": cur_option_char,
157 | "answer_id": ans_id,
158 | "metadata": {}}) + "\n")
159 | ans_file.flush()
160 |
161 |             # Rotate options so each candidate answer appears under every letter across rounds.
162 | options = options[1:] + options[:1]
163 | cur_option_char = cur_option_char[1:] + cur_option_char[:1]
164 | ans_file.close()
165 |
166 | print("image_size: {}".format(model.config.image_size))
167 | print("pool_out_size: {}".format(model.config.pool_out_size))
168 |
169 |
170 | if __name__ == "__main__":
171 | parser = argparse.ArgumentParser()
172 | parser.add_argument("--model-path", type=str)
173 | parser.add_argument("--image-folder", type=str, default="")
174 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
175 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
176 | parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
177 | parser.add_argument("--num-chunks", type=int, default=1)
178 | parser.add_argument("--chunk-idx", type=int, default=0)
179 | parser.add_argument("--temperature", type=float, default=0.2)
180 | parser.add_argument("--top_p", type=float, default=None)
181 | parser.add_argument("--num_beams", type=int, default=1)
182 | parser.add_argument("--all-rounds", action="store_true")
183 | parser.add_argument("--single-pred-prompt", action="store_true")
184 | parser.add_argument("--lang", type=str, default="en")
185 | args = parser.parse_args()
186 |
187 | eval_model(args)
188 |
--------------------------------------------------------------------------------
/Infinity-Parser/Infinity-Synth/templates/three_columns/document.html.jinja:
--------------------------------------------------------------------------------
1 | {# Copyright (c) Microsoft Corporation. All rights reserved. #}
2 |
3 | {% extends "base.html.jinja" %}
4 | {%- block style %}
5 | {# Global Style #}
6 | {% import "macro/dimension.css.jinja" as dimension %}
7 | {{ dimension.a4_paper() }}
8 | {% import "macro/text.css.jinja" as text %}
9 | {{ text.set_font(font_family, font_size) }}
10 | {{ text.set_hyphenation(hyphenate) }}
11 | {{ text.set_text_align(text_align) }}
12 | {% import "macro/page_layout.css.jinja" as layout %}
13 | {{ layout.set_page_num() }}
14 | {# Element-Specific Style #}
15 | {%- include "three_columns/document.css.jinja" with context %}
16 |
17 | mjx-container[jax="CHTML"][display="false"] { display: inline-block; vertical-align: baseline; }
18 | mjx-container[jax="CHTML"][display="true"] { display: block; text-align: center; margin: .6em 0; }
19 | pre, code { white-space: pre; }
20 |
21 | {% endblock style %}
22 |
23 | {% block body %}
24 |
25 |
26 |
27 |
28 | {% set header = input_data.get('header', {}) %}
29 | {% if header %}
30 |
44 | {% endif %}
45 |
46 |
47 |
48 | {% set ns = namespace(formula_idx=1, fig_idx=1, tab_idx=1) %}
49 | {% for ele in input_data.get("body", []) %}
50 |
51 | {% if ele.type == "table" %}
52 |
53 |
54 | {{ ele.caption }}
55 |
56 | {{ ele.html | safe }}
57 |
58 |
59 |
60 |
61 |
62 | {% set ns.tab_idx = ns.tab_idx + 1 %}
63 | {% elif ele.type == "figure" %}
64 |
65 |
66 | 图{{ ns.fig_idx }}:{{ ele.caption }}
67 |
68 | {% set ns.fig_idx = ns.fig_idx + 1 %}
69 |
70 | {% elif ele.type == "title" %}
71 | {{ ele.content }}
72 |
73 | {% elif ele.type == "Body" %}
74 |
75 | {{ ele.heading }}
76 |
77 | {% for txt in ele.text %}
78 | {{ txt }}
79 | {% endfor %}
80 |
81 | {% elif ele.type == "formula" %}
82 |
83 |
89 |
90 | {% set ns.formula_idx = ns.formula_idx + 1 %}
91 |
92 | {% endif %}
93 |
94 | {% endfor %}
95 |
96 |
97 |
98 |
99 | {% set page_footnote = input_data.get('page_footnote', None) %}
100 |
101 |
102 |
103 | {% if page_footnote %}
104 |
107 | {% endif %}
108 |
109 |
110 | {% set footer = input_data.get('footer', {}) %}
111 | {% if footer %}
112 |
124 | {% endif %}
125 |
126 |
127 |
128 |
233 |
234 |
235 |
252 |
254 |
255 | {% endblock body %}
--------------------------------------------------------------------------------