├── LICENSE
├── README.md
├── deepspeed
│   └── ds_zero3_config.json
├── images
│   ├── abst.png
│   ├── image1.png
│   └── image2.png
├── requirements.txt
├── scripts
│   ├── create_test_chartqa_generator.sh
│   ├── create_test_dude_generator.sh
│   ├── create_test_infovqa_generator.sh
│   ├── create_test_slidevqa_generator.sh
│   └── create_train_generator.sh
├── setup.py
├── src
│   └── vdocrag
│       ├── __init__.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── eval_opendocvqa.py
│       │   └── format
│       │       ├── __init__.py
│       │       ├── convert_qas_to_trec_qrels.py
│       │       └── convert_result_to_trec.py
│       ├── vdocgenerator
│       │   ├── __init__.py
│       │   ├── arguments.py
│       │   ├── collator.py
│       │   ├── dataset.py
│       │   ├── driver
│       │   │   ├── generate.py
│       │   │   └── train.py
│       │   ├── modeling
│       │   │   ├── __init__.py
│       │   │   └── vdocgenerator.py
│       │   └── trainer.py
│       └── vdocretriever
│           ├── __init__.py
│           ├── arguments.py
│           ├── collator.py
│           ├── dataset.py
│           ├── driver
│           │   ├── encode.py
│           │   ├── search.py
│           │   └── train.py
│           ├── modeling
│           │   ├── __init__.py
│           │   └── vdocretriever.py
│           ├── searcher.py
│           └── trainer.py
└── test.py
/LICENSE:
--------------------------------------------------------------------------------
1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION
2 |
3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT").
4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE.
5 |
6 |
7 | BACKGROUND
8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement.
9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement.
10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement.
11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows:
12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1.
13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software.
14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT.
15 | 4. Proprietary Rights
16 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software.
17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE.
18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied.
19 | 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT AND DEFEND ANY ACTION RELATING TO THE SOFTWARE.
20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT.
21 | 7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3.
22 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent.
23 | 9. General
24 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect.
25 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter.
26 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User.
27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding.
28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof.
29 | (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control.
30 |
31 | EXHIBIT A
32 | The software and related data include the following files,
33 | ├── LICENSE
34 | ├── README.md
35 | ├── deepspeed
36 | │ └── ds_zero3_config.json
37 | ├── images
38 | │ ├── abst.png
39 | │ ├── image1.png
40 | │ └── image2.png
41 | ├── requirements.txt
42 | ├── scripts
43 | │ ├── create_test_chartqa_generator.sh
44 | │ ├── create_test_dude_generator.sh
45 | │ ├── create_test_infovqa_generator.sh
46 | │ ├── create_test_slidevqa_generator.sh
47 | │ └── create_train_generator.sh
48 | ├── setup.py
49 | ├── src
50 | │ └── vdocrag
51 | │ ├── __init__.py
52 | │ ├── utils
53 | │ │ ├── __init__.py
54 | │ │ ├── eval_opendocvqa.py
55 | │ │ └── format
56 | │ │ ├── __init__.py
57 | │ │ ├── convert_qas_to_trec_qrels.py
58 | │ │ └── convert_result_to_trec.py
59 | │ ├── vdocgenerator
60 | │ │ ├── __init__.py
61 | │ │ ├── arguments.py
62 | │ │ ├── collator.py
63 | │ │ ├── dataset.py
64 | │ │ ├── driver
65 | │ │ │ ├── generate.py
66 | │ │ │ └── train.py
67 | │ │ ├── modeling
68 | │ │ │ ├── __init__.py
69 | │ │ │ └── vdocgenerator.py
70 | │ │ └── trainer.py
71 | │ └── vdocretriever
72 | │ ├── __init__.py
73 | │ ├── arguments.py
74 | │ ├── collator.py
75 | │ ├── dataset.py
76 | │ ├── driver
77 | │ │ ├── encode.py
78 | │ │ ├── search.py
79 | │ │ └── train.py
80 | │ ├── modeling
81 | │ │ ├── __init__.py
82 | │ │ └── vdocretriever.py
83 | │ ├── searcher.py
84 | │ └── trainer.py
85 | └── test.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents
4 |
5 | [Project Page](https://vdocrag.github.io/)
6 | [Paper (arXiv:2504.09795)](http://arxiv.org/abs/2504.09795)
7 | [Model: VDocRetriever-Phi3-vision](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision)
8 | [Model: VDocGenerator-Phi3-vision](https://huggingface.co/NTT-hil-insight/VDocGenerator-Phi3-vision)
9 | [Dataset: OpenDocVQA](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA)
10 | [Dataset: OpenDocVQA-Corpus](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA-Corpus)
11 | [CVPR 2025](https://cvpr.thecvf.com/)
12 |
13 |
14 | This repository contains VDocRAG, introduced in the following paper: Ryota Tanaka, Taichi Iki, Taku Hasegawa, Kyosuke Nishida, Kuniko Saito, and Jun Suzuki. [VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents](http://arxiv.org/abs/2504.09795). In Proc. of CVPR 2025.
15 |
16 |
17 | **VDocRAG** is a new RAG framework that can directly understand diverse real-world documents purely from visual features.
18 |
19 | **💪 Key Enhancements of VDocRAG:**
20 | - **New Pretraining Tasks:** we propose novel self-supervised pre-training tasks (**RCR** and **RCG**) that adapt large vision-language models for retrieval by compressing visual information into dense token representations while aligning them with textual content in documents.
21 | - **New Dataset:** we introduce **OpenDocVQA**, the first unified collection of open-domain document visual question answering datasets, encompassing diverse document types and formats.
22 |
23 |
24 | ![Overview of VDocRAG](images/abst.png)
25 |
26 |
27 | # 📌Contents
28 | - [News](#news)
29 | - [Installation](#installation)
30 | - [Quick Start](#quick_start)
31 | - [Dataset](#dataset)
32 | - [Retriever](#retriever)
33 | - [Generator](#generator)
34 | - [LICENSE](#license)
35 | - [Citation](#citation)
36 | - [Acknowledgement](#acknowledgement)
37 |
38 |
39 |
40 | # 📢 News
41 | - [2025/04]: The pre-trained VDocRetriever weights are out in 🤗 [Huggingface Hub](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision-pretrained).
42 | - [2025/04]: The technical report, code, data, and model for VDocRAG are all available online.
43 | - [2025/02]: 🎉 VDocRAG is accepted to CVPR 2025.
44 |
45 |
46 | # ⚙️ Installation
47 | 1. Clone the repository.
48 | 2. Install PyTorch matching your CUDA version from the official [PyTorch](https://pytorch.org/) site.
49 | 3. Install dependencies and VDocRAG.
50 | ```bash
51 | pip install -r requirements.txt
52 | pip install -e .
53 | ```
54 |
55 |
56 | # ⚡️ Quick Start
57 | You can download [VDocRetriever](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision) and [VDocGenerator](https://huggingface.co/NTT-hil-insight/VDocGenerator-Phi3-vision) from 🤗 HuggingFace Hub. To get started, first import the libraries as shown below:
58 | ```py
59 | from PIL import Image
60 | import requests
61 | from io import BytesIO
62 | from torch.nn.functional import cosine_similarity
63 | import torch
64 | from transformers import AutoProcessor
65 | from vdocrag.vdocretriever.modeling import VDocRetriever
66 | from vdocrag.vdocgenerator.modeling import VDocGenerator
67 | ```
68 |
69 | ## Retrieval
70 | ```py
71 | processor = AutoProcessor.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True)
72 | model = VDocRetriever.load('microsoft/Phi-3-vision-128k-instruct',
73 | lora_name_or_path='NTT-hil-insight/VDocRetriever-Phi3-vision',
74 | pooling='eos',
75 | normalize=True,
76 | trust_remote_code=True,
77 | attn_implementation="flash_attention_2",
78 | torch_dtype=torch.bfloat16,
79 | use_cache=False).to('cuda:0')
80 |
81 | # Process query inputs and get the embeddings
82 | queries = ["Instruct: I’m looking for an image that answers the question.\nQuery: What is the total percentage of Palestinians residing at West Bank?",
83 | "Instruct: I’m looking for an image that answers the question.\nQuery: How many international visitors came to Japan in 2017?"]
84 | query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=256, truncation=True).to('cuda:0')
85 |
86 | with torch.no_grad():
87 | model_output = model(query=query_inputs, use_cache=False)
88 | query_embeddings = model_output.q_reps
89 |
90 | urls = [
91 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image1.png",
92 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image2.png"
93 | ]
94 |
95 | doc_images = [Image.open(BytesIO(requests.get(url).content)).resize((1344, 1344)) for url in urls]
96 |
97 | # Process images and get the embeddings
98 | doc_prompt = "<|image_1|>\nWhat is shown in this image?"
99 | collated_list = [
100 | processor(doc_prompt, images=image, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0') for image in doc_images
101 | ]
102 |
103 | doc_inputs = {
104 | key: torch.stack([item[key][0] for item in collated_list], dim=0)
105 | for key in ['input_ids', 'attention_mask', 'pixel_values', 'image_sizes']
106 | }
107 |
108 | with torch.no_grad():
109 | model_output = model(document=doc_inputs, use_cache=False)
110 | doc_embeddings = model_output.p_reps
111 |
112 | # Calculate cosine similarity
113 | num_queries = query_embeddings.size(0)
114 | num_passages = doc_embeddings.size(0)
115 |
116 | for i in range(num_queries):
117 | query_embedding = query_embeddings[i].unsqueeze(0)
118 | similarities = cosine_similarity(query_embedding, doc_embeddings)
119 | print(f"Similarities for Query {i}: {similarities.cpu().float().numpy()}")
120 |
121 | # >> Similarities for Query 0: [0.515625 0.38476562]
122 | # Similarities for Query 1: [0.37890625 0.5703125 ]
123 | ```
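
To pass only the most relevant pages on to the generator (the fine-tuned pipeline below uses `--top_k 3`), you can rank the documents per query directly from the embeddings computed above. This is an illustrative sketch using standard PyTorch operations, reusing `query_embeddings`, `doc_embeddings`, and `doc_images` from the previous block; it is not a dedicated VDocRAG API.

```py
# Since the retriever was loaded with normalize=True, the embeddings are
# L2-normalized and the dot product equals cosine similarity.
scores = query_embeddings @ doc_embeddings.T          # (num_queries, num_docs)
top_k = min(2, doc_embeddings.size(0))                # only two demo pages here
top_scores, top_indices = torch.topk(scores, k=top_k, dim=1)

for q_idx, (doc_ids, doc_scores) in enumerate(zip(top_indices.tolist(), top_scores.tolist())):
    print(f"Query {q_idx}: top documents {doc_ids} with scores {doc_scores}")

# e.g. keep the retrieved pages for the first query to feed the generator
retrieved_images = [doc_images[i] for i in top_indices[0].tolist()]
```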
124 |
125 | ## Generation
126 | ```py
127 | model = VDocGenerator.load('microsoft/Phi-3-vision-128k-instruct',
128 | lora_name_or_path='NTT-hil-insight/VDocGenerator-Phi3-vision',
129 | trust_remote_code=True,
130 | attn_implementation="flash_attention_2",
131 | torch_dtype=torch.bfloat16,
132 | use_cache=False).to('cuda:0')
133 |
134 | # Process images with the prompt
135 | query = "How many international visitors came to Japan in 2017? \n Answer briefly."
136 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(doc_images))])
137 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}]
138 | prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
139 | processed = processor(prompt, images=doc_images, return_tensors="pt").to('cuda:0')
140 |
141 | # Generate the answer
142 | generate_ids = model.generate(processed,
143 | generation_args={
144 | "max_new_tokens": 64,
145 | "temperature": 0.0,
146 | "do_sample": False,
147 | "eos_token_id": processor.tokenizer.eos_token_id
148 | })
149 | generate_ids = generate_ids[:, processed['input_ids'].shape[1]:]
150 | response = processor.batch_decode(generate_ids,
151 | skip_special_tokens=True,
152 | clean_up_tokenization_spaces=False)[0].strip()
153 |
154 | print("Model prediction: {0}".format(response))
155 |
156 | # >> Model prediction: 28.69m
157 | ```
158 |
159 |
160 | # 💾 Dataset
161 | OpenDocVQA is a unified collection of open-domain document visual question answering datasets, encompassing diverse document types and formats. It consists of nine open-domain document VQA datasets: a newly created MHDocVQA dataset that addresses multi-hop questions over multiple documents, plus collected and filtered QA datasets (DocVQA, InfoVQA, DUDE, VisualMRC, ChartQA, OpenWikiTable, MPMQA, and SlideVQA). In total, OpenDocVQA contains 43k QA pairs and 200k document images.
162 |
163 | You can download the OpenDocVQA dataset from the 🤗 HuggingFace Hub (a minimal loading sketch follows the links below):
164 | - [QA Pairs](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA)
165 | - [Corpus](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA-Corpus)
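
As a minimal loading sketch (assuming only the 🤗 `datasets` library pinned in `requirements.txt`, and the config/split names used by the scripts in this repository, e.g. `chartqa` and `test`):

```py
from datasets import load_dataset

# QA pairs for one sub-dataset; configs used in this repo include chartqa, slidevqa, infovqa, and dude
qa_pairs = load_dataset("NTT-hil-insight/OpenDocVQA", "chartqa", split="test")

# Document images forming the retrieval pool for the same sub-dataset
corpus = load_dataset("NTT-hil-insight/OpenDocVQA-Corpus", "chartqa", split="test")

print(qa_pairs[0])   # inspect one QA example
print(len(corpus))   # size of the document pool
```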
166 |
167 |
168 | # 🔎 Retriever
169 |
170 | ## Pre-training VDocRetriever
171 | This script supports our proposed pre-training tasks, including RCR (Representation Compression via Retrieval) and RCG (Representation Compression via Generation).
172 | ```bash
173 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocretriever.driver.train \
174 | --deepspeed deepspeed/ds_zero3_config.json \
175 | --output_dir outputs/vdocretriever-phi3-vision_pretrain \
176 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
177 | --lora \
178 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
179 | --save_steps 2000 \
180 | --dataset_name NTT-hil-insight/VDocRetriever-Pretrain-DocStruct \
181 | --bf16 \
182 | --pooling eos \
183 | --append_eos_token \
184 | --normalize \
185 | --temperature 0.01 \
186 | --per_device_train_batch_size 1 \
187 | --gradient_checkpointing \
188 | --train_group_size 1 \
189 | --learning_rate 1e-4 \
190 | --query_max_len 512 \
191 | --answer_max_len 512 \
192 | --num_train_epochs 1 \
193 | --logging_steps 10 \
194 | --overwrite_output_dir \
195 | --gradient_accumulation_steps 4 \
196 | --pretrain \
197 | --image_attention_mask \
198 | --report_to wandb \
199 | ```
200 |
201 | ## Fine-tuning VDocRetriever
202 | If you set `--lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision-pretrained`, you can use our pre-trained VDocRetriever weights without pre-training the model locally.
203 |
204 | ```bash
205 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocretriever.driver.train \
206 | --deepspeed deepspeed/ds_zero3_config.json \
207 | --output_dir outputs/vdocretriever-phi3-vision_finetune \
208 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
209 | --lora_name_or_path outputs/vdocretriever-phi3-vision_pretrain \
210 | --lora \
211 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
212 | --save_steps 2000 \
213 | --dataset_name NTT-hil-insight/OpenDocVQA \
214 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
215 | --bf16 \
216 | --pooling eos \
217 | --append_eos_token \
218 | --normalize \
219 | --temperature 0.01 \
220 | --per_device_train_batch_size 4 \
221 | --gradient_checkpointing \
222 | --train_group_size 1 \
223 | --learning_rate 1e-4 \
224 | --query_max_len 256 \
225 | --answer_max_len 256 \
226 | --num_train_epochs 1 \
227 | --logging_steps 10 \
228 | --overwrite_output_dir \
229 | --gradient_accumulation_steps 4 \
230 | --report_to wandb \
231 | ```
232 |
233 | ## Query encoding
234 | To use our fine-tuned model directly, set `lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision`. `QUERY_DATASET` must be one of {chartqa, slidevqa, infovqa, dude}.
235 |
236 | ```bash
237 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
238 | --output_dir=temp \
239 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
240 | --lora_name_or_path outputs/vdocretriever-phi3-vision_finetune \
241 | --lora \
242 | --bf16 \
243 | --pooling eos \
244 | --append_eos_token \
245 | --normalize \
246 | --encode_is_query \
247 | --per_device_eval_batch_size 24 \
248 | --query_max_len 256 \
249 | --dataset_name NTT-hil-insight/OpenDocVQA \
250 | --dataset_config $QUERY_DATASET \
251 | --dataset_split test \
252 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl
253 | ```
254 |
255 | ## Document encoding
256 | To use our fine-tuned model directly, set `lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision`.
257 | `CORPUS_DATASET` must be one of {all, chartqa, slidevqa, infovqa, dude}.
258 |
259 | ```bash
260 | for s in 0 1 2 3; do
261 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
262 | --output_dir=temp \
263 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
264 | --lora_name_or_path outputs/vdocretriever-phi3-vision_finetune \
265 | --lora \
266 | --bf16 \
267 | --pooling eos \
268 | --append_eos_token \
269 | --normalize \
270 | --per_device_eval_batch_size 4 \
271 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
272 | --corpus_config $CORPUS_DATASET \
273 | --corpus_split test \
274 | --dataset_number_of_shards 4 \
275 | --dataset_shard_index ${s} \
276 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
277 | done
278 | ```
279 | ## Retrieval
280 | ```bash
281 | python -m vdocrag.vdocretriever.driver.search \
282 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
283 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
284 | --depth 1000 \
285 | --batch_size 64 \
286 | --save_text \
287 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
288 | ```
289 |
290 | ## Evaluation
291 | ```bash
292 | # Convert retrieval results (.txt) to .trec file
293 | python -m vdocrag.utils.format.convert_result_to_trec --input $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
294 | --output $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.trec \
295 | --remove_query
296 | # Create ground-truth retrieval results
297 | python -m vdocrag.utils.format.convert_qas_to_trec_qrels --dataset_name NTT-hil-insight/OpenDocVQA \
298 | --dataset_config ${QUERY_DATASET} \
299 | --output $EMBEDDING_OUTPUT_DIR/qrels.${QUERY_DATASET}.txt \
300 | # Evaluate with pyserini
301 | python -m pyserini.eval.trec_eval -c -mrecall.1,5,10 -mndcg_cut.1,5,10 $EMBEDDING_OUTPUT_DIR/qrels.${QUERY_DATASET}.txt $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.trec
302 | ```
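
For reference, `pyserini.eval.trec_eval` consumes standard TREC files; the exact columns are written by the conversion scripts above, and the sketch below only assumes the usual conventions (run files: `qid Q0 docid rank score tag`; qrels: `qid 0 docid relevance`). The helper is hypothetical and exists purely to illustrate how to peek at a run file:

```py
# Hypothetical helper: read the first few rows of a TREC-format run file,
# assuming the standard 6-column layout (qid Q0 docid rank score tag).
def peek_trec_run(path, n=5):
    rows = []
    with open(path) as f:
        for line in f:
            qid, _, docid, rank, score, tag = line.split()
            rows.append((qid, docid, int(rank), float(score)))
            if len(rows) == n:
                break
    return rows

print(peek_trec_run("rank.chartqa.chartqa.trec"))
```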
303 |
304 |
305 | # 💬 Generator
306 |
307 | ## Data Creation for Generator
308 | Before training and evaluating generator models, you need to create data consisting of `(query id, retrieved document id, retrieval score)` triples for the generator. You can create this data automatically with the following scripts:
309 | - Train Data: `scripts/create_train_generator.sh`
310 | - Test Data
311 |   - ChartQA: `scripts/create_test_chartqa_generator.sh`
312 |   - SlideVQA: `scripts/create_test_slidevqa_generator.sh`
313 |   - InfoVQA: `scripts/create_test_infovqa_generator.sh`
314 |   - DUDE: `scripts/create_test_dude_generator.sh`
315 |
316 | If you want to use your own models, replace `lora_name_or_path` with your model name. By default, every generated test set uses the single-pool setting. To evaluate models under the all-pool setting, change `CORPUS_DATASET` to `all`. A minimal sketch of the retrieval-results file produced by these scripts follows.
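
The retrieval results are written by the `--save_text` option of the search driver, i.e. as a plain-text file of the `(query id, retrieved document id, retrieval score)` triples described above. As a minimal sketch, assuming one whitespace-separated triple per line (the convention of the Tevatron-style search driver this code adapts), the file can be grouped per query like this:

```py
from collections import defaultdict

# Group retrieved documents by query id and keep the top-k per query,
# mirroring what the generator consumes via --retrieval_results_path / --top_k.
def load_retrieval_results(path, top_k=3):
    results = defaultdict(list)
    with open(path) as f:
        for line in f:
            query_id, doc_id, score = line.split()
            results[query_id].append((doc_id, float(score)))
    return {qid: sorted(docs, key=lambda x: -x[1])[:top_k] for qid, docs in results.items()}

results = load_retrieval_results("outputs/vdocretriever-phi3-vision_finetune/embs/rank.train.all.txt")
```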
317 |
318 |
319 | ## Fine-tuning VDocGenerator
320 | `retrieval_results_path` is the path where the retrieval results created in the previous section (Data Creation for Generator) were saved.
321 |
322 | ```bash
323 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocgenerator.driver.train \
324 | --deepspeed deepspeed/ds_zero3_config.json \
325 | --output_dir outputs/vdocgenerator-phi3-vision_finetune \
326 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
327 | --lora \
328 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
329 | --save_steps 100 \
330 | --dataset_name NTT-hil-insight/OpenDocVQA \
331 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
332 | --retrieval_results_path outputs/vdocretriever-phi3-vision_finetune/embs/rank.train.all.txt \
333 | --bf16 \
334 | --per_device_train_batch_size 2 \
335 | --gradient_checkpointing \
336 | --learning_rate 1e-4 \
337 | --query_max_len 256 \
338 | --num_train_epochs 1 \
339 | --logging_steps 10 \
340 | --overwrite_output_dir \
341 | --gradient_accumulation_steps 4 \
342 | --top_k 3 \
343 | --report_to wandb \
344 | ```
345 |
346 | ## Generating Answers
347 | To use our fine-tuned model directly, set `lora_name_or_path` to `NTT-hil-insight/VDocGenerator-Phi3-vision`.
348 | ```bash
349 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocgenerator.driver.generate \
350 | --output_dir=temp \
351 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
352 | --lora_name_or_path outputs/vdocgenerator-phi3-vision_finetune \
353 | --lora \
354 | --dataset_name NTT-hil-insight/OpenDocVQA \
355 | --dataset_split test \
356 | --dataset_config ${QUERY_DATASET} \
357 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
358 | --corpus_config $CORPUS_DATASET \
359 | --corpus_split test \
360 | --retrieval_results_path outputs/vdocretriever-phi3-vision_finetune/embs/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
361 | --bf16 \
362 | --per_device_eval_batch_size 1 \
363 | --top_k 3 \
364 | --output_path outputs/vdocgenerator-phi3-vision_finetune/answers/answers.${QUERY_DATASET}.${CORPUS_DATASET}.json \
365 | ```
366 |
367 | ## Evaluation
368 | ```bash
369 | python -m vdocrag.utils.eval_opendocvqa --input outputs/vdocgenerator-phi3-vision_finetune/answers/answers.${QUERY_DATASET}.${CORPUS_DATASET}.json
370 | ```
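
The evaluation script (`src/vdocrag/utils/eval_opendocvqa.py`) reads the answers JSON written via `--output_path` and expects a mapping from question id to a record with `prediction` and `ground_truth` fields. A minimal hand-made example of that structure (the question id and values here are hypothetical):

```py
import json

# Hypothetical example of the structure expected by eval_opendocvqa:
# question_id -> {"prediction": model answer, "ground_truth": gold answer or list of answers}
answers = {
    "chartqa_0001": {
        "prediction": "28.69m",
        "ground_truth": ["28.69 million"],
    }
}
with open("answers.example.json", "w") as f:
    json.dump(answers, f, indent=2)
```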
371 |
372 |
373 | # 📝 License
374 | The code is released under the NTT License as found in the [LICENSE](./LICENSE) file.
375 |
376 |
377 | # ✒️ Citation
378 | ```bibtex
379 | @inproceedings{tanaka2025vdocrag,
380 | author = {Ryota Tanaka and
381 | Taichi Iki and
382 | Taku Hasegawa and
383 | Kyosuke Nishida and
384 | Kuniko Saito and
385 | Jun Suzuki},
386 | title = {VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents},
387 | booktitle = {CVPR},
388 | year = {2025}
389 | }
390 | ```
391 |
392 |
393 | # 📔 Acknowledgement
394 | We have adapted code from [Tevatron](https://github.com/texttron/tevatron/), a flexible and efficient toolkit that supports training and inference for neural retrieval models.
395 |
396 |
--------------------------------------------------------------------------------
/deepspeed/ds_zero3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "offload_optimizer": {
5 | "device": "none",
6 | "pin_memory": true
7 | },
8 | "offload_param": {
9 | "device": "none",
10 | "pin_memory": true
11 | },
12 | "overlap_comm": true,
13 | "contiguous_gradients": true,
14 | "sub_group_size": 1e9,
15 | "reduce_bucket_size": 1e6,
16 | "stage3_prefetch_bucket_size": "auto",
17 | "stage3_param_persistence_threshold": "auto",
18 | "stage3_max_live_parameters": 1e9,
19 | "stage3_max_reuse_distance": 1e9,
20 | "stage3_gather_16bit_weights_on_model_save": true
21 | },
22 | "fp16": {
23 | "enabled": "auto",
24 | "loss_scale": 0,
25 | "initial_scale_power": 10,
26 | "loss_scale_window": 1000,
27 | "hysteresis": 2,
28 | "min_loss_scale": 1
29 | },
30 | "bf16": {
31 | "enabled": "auto",
32 | "loss_scale": 0,
33 | "initial_scale_power": 10,
34 | "loss_scale_window": 1000,
35 | "hysteresis": 2,
36 | "min_loss_scale": 1
37 | },
38 | "optimizer": {
39 | "type": "AdamW",
40 | "params": {
41 | "lr": "auto",
42 | "betas": "auto",
43 | "eps": "auto",
44 | "weight_decay": "auto",
45 | "torch_adam": true
46 | }
47 | },
48 |
49 | "scheduler": {
50 | "type": "WarmupDecayLR",
51 | "params": {
52 | "warmup_min_lr": "auto",
53 | "warmup_max_lr": "auto",
54 | "warmup_num_steps": "auto",
55 | "total_num_steps": "auto"
56 | }
57 | },
58 |
59 | "gradient_accumulation_steps": "auto",
60 | "gradient_clipping": "auto",
61 | "steps_per_print": 1000,
62 | "train_batch_size": "auto",
63 | "train_micro_batch_size_per_gpu": "auto",
64 | "wall_clock_breakdown": false
65 | }
66 |
--------------------------------------------------------------------------------
/images/abst.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/abst.png
--------------------------------------------------------------------------------
/images/image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/image1.png
--------------------------------------------------------------------------------
/images/image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/image2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | deepspeed==0.14.5
2 | accelerate==0.34.0
3 | datasets==2.20.0
4 | flash-attn==2.6.3
5 | huggingface-hub==0.30.1
6 | numpy==1.26.4
7 | peft==0.11.1
8 | pillow==10.4.0
9 | torch==2.4.0
10 | torchvision==0.19.0
11 | tqdm==4.66.4
12 | transformers==4.40.2
13 | wandb==0.17.7
14 | pyserini==0.44.0
15 | faiss-gpu==1.7.2
--------------------------------------------------------------------------------
/scripts/create_test_chartqa_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | QUERY_DATASET=chartqa
4 | CORPUS_DATASET=chartqa
5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs
6 |
7 | # encoding queries
8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
9 | --output_dir=temp \
10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
12 | --lora \
13 | --bf16 \
14 | --pooling eos \
15 | --append_eos_token \
16 | --normalize \
17 | --encode_is_query \
18 | --per_device_eval_batch_size 24 \
19 | --query_max_len 256 \
20 | --dataset_name NTT-hil-insight/OpenDocVQA \
21 | --dataset_config $QUERY_DATASET \
22 | --dataset_split test \
23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl
24 |
25 | # encoding documents
26 | for s in 0 1 2 3
27 | do
28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
29 | --output_dir=temp \
30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
32 | --lora \
33 | --bf16 \
34 | --pooling eos \
35 | --append_eos_token \
36 | --normalize \
37 | --per_device_eval_batch_size 4 \
38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
39 | --corpus_config $CORPUS_DATASET \
40 | --corpus_split test \
41 | --dataset_number_of_shards 4 \
42 | --dataset_shard_index ${s} \
43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
44 | done
45 |
46 | # retrieving documents
47 | python -m vdocrag.vdocretriever.driver.search \
48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
50 | --depth 1000 \
51 | --batch_size 64 \
52 | --save_text \
53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
--------------------------------------------------------------------------------
/scripts/create_test_dude_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | QUERY_DATASET=dude
4 | CORPUS_DATASET=dude
5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs
6 |
7 | # encoding queries
8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
9 | --output_dir=temp \
10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
12 | --lora \
13 | --bf16 \
14 | --pooling eos \
15 | --append_eos_token \
16 | --normalize \
17 | --encode_is_query \
18 | --per_device_eval_batch_size 24 \
19 | --query_max_len 256 \
20 | --dataset_name NTT-hil-insight/OpenDocVQA \
21 | --dataset_config $QUERY_DATASET \
22 | --dataset_split test \
23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl
24 |
25 | # encoding documents
26 | for s in 0 1 2 3
27 | do
28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
29 | --output_dir=temp \
30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
32 | --lora \
33 | --bf16 \
34 | --pooling eos \
35 | --append_eos_token \
36 | --normalize \
37 | --per_device_eval_batch_size 4 \
38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
39 | --corpus_config $CORPUS_DATASET \
40 | --corpus_split test \
41 | --dataset_number_of_shards 4 \
42 | --dataset_shard_index ${s} \
43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
44 | done
45 |
46 | # retrieving documents
47 | python -m vdocrag.vdocretriever.driver.search \
48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
50 | --depth 1000 \
51 | --batch_size 64 \
52 | --save_text \
53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
--------------------------------------------------------------------------------
/scripts/create_test_infovqa_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | QUERY_DATASET=infovqa
4 | CORPUS_DATASET=infovqa
5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs
6 |
7 | # encoding queries
8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
9 | --output_dir=temp \
10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
12 | --lora \
13 | --bf16 \
14 | --pooling eos \
15 | --append_eos_token \
16 | --normalize \
17 | --encode_is_query \
18 | --per_device_eval_batch_size 24 \
19 | --query_max_len 256 \
20 | --dataset_name NTT-hil-insight/OpenDocVQA \
21 | --dataset_config $QUERY_DATASET \
22 | --dataset_split test \
23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl
24 |
25 | # encoding documents
26 | for s in 0 1 2 3
27 | do
28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
29 | --output_dir=temp \
30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
32 | --lora \
33 | --bf16 \
34 | --pooling eos \
35 | --append_eos_token \
36 | --normalize \
37 | --per_device_eval_batch_size 4 \
38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
39 | --corpus_config $CORPUS_DATASET \
40 | --corpus_split test \
41 | --dataset_number_of_shards 4 \
42 | --dataset_shard_index ${s} \
43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
44 | done
45 |
46 | # retrieving documents
47 | python -m vdocrag.vdocretriever.driver.search \
48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
50 | --depth 1000 \
51 | --batch_size 64 \
52 | --save_text \
53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
--------------------------------------------------------------------------------
/scripts/create_test_slidevqa_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | QUERY_DATASET=slidevqa
4 | CORPUS_DATASET=slidevqa
5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs
6 |
7 | # encoding queries
8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
9 | --output_dir=temp \
10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
12 | --lora \
13 | --bf16 \
14 | --pooling eos \
15 | --append_eos_token \
16 | --normalize \
17 | --encode_is_query \
18 | --per_device_eval_batch_size 24 \
19 | --query_max_len 256 \
20 | --dataset_name NTT-hil-insight/OpenDocVQA \
21 | --dataset_config $QUERY_DATASET \
22 | --dataset_split test \
23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl
24 |
25 | # encoding documents
26 | for s in 0 1 2 3
27 | do
28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
29 | --output_dir=temp \
30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
32 | --lora \
33 | --bf16 \
34 | --pooling eos \
35 | --append_eos_token \
36 | --normalize \
37 | --per_device_eval_batch_size 4 \
38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
39 | --corpus_config $CORPUS_DATASET \
40 | --corpus_split test \
41 | --dataset_number_of_shards 4 \
42 | --dataset_shard_index ${s} \
43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
44 | done
45 |
46 | # retrieving documents
47 | python -m vdocrag.vdocretriever.driver.search \
48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
50 | --depth 1000 \
51 | --batch_size 64 \
52 | --save_text \
53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
--------------------------------------------------------------------------------
/scripts/create_train_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | QUERY_DATASET=train
4 | CORPUS_DATASET=all
5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs
6 |
7 | # encoding queries
8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
9 | --output_dir=temp \
10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
12 | --lora \
13 | --bf16 \
14 | --pooling eos \
15 | --append_eos_token \
16 | --normalize \
17 | --encode_is_query \
18 | --per_device_eval_batch_size 24 \
19 | --query_max_len 256 \
20 | --dataset_name NTT-hil-insight/OpenDocVQA \
21 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-train.pkl
22 |
23 | # encoding documents
24 | for s in 0 1 2 3
25 | do
26 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \
27 | --output_dir=temp \
28 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \
29 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \
30 | --lora \
31 | --bf16 \
32 | --pooling eos \
33 | --append_eos_token \
34 | --normalize \
35 | --per_device_eval_batch_size 4 \
36 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \
37 | --dataset_number_of_shards 4 \
38 | --dataset_shard_index ${s} \
39 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl
40 | done
41 |
42 | # retrieving documents
43 | python -m vdocrag.vdocretriever.driver.search \
44 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \
45 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \
46 | --depth 1000 \
47 | --batch_size 64 \
48 | --save_text \
49 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='vdocrag',
5 | version='0.0.1',
6 | packages=find_packages("src"),
7 | package_dir={'': 'src'},
8 | python_requires='>=3.7',
9 | )
10 |
--------------------------------------------------------------------------------
/src/vdocrag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/__init__.py
--------------------------------------------------------------------------------
/src/vdocrag/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/utils/__init__.py
--------------------------------------------------------------------------------
/src/vdocrag/utils/eval_opendocvqa.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import argparse
4 | from collections import Counter
5 |
6 | def has_word(sentence, word):
7 | pattern = r"\b" + re.escape(word) + r"\b"
8 | match = re.search(pattern, sentence)
9 | if match:
10 | return True
11 | else:
12 | return False
13 |
14 | class VQAEval:
15 | def __init__(self):
16 | self.contractions = {
17 | "aint": "ain't",
18 | "arent": "aren't",
19 | "cant": "can't",
20 | "couldve": "could've",
21 | "couldnt": "couldn't",
22 | "couldn'tve": "couldn't've",
23 | "couldnt've": "couldn't've",
24 | "didnt": "didn't",
25 | "doesnt": "doesn't",
26 | "dont": "don't",
27 | "hadnt": "hadn't",
28 | "hadnt've": "hadn't've",
29 | "hadn'tve": "hadn't've",
30 | "hasnt": "hasn't",
31 | "havent": "haven't",
32 | "hed": "he'd",
33 | "hed've": "he'd've",
34 | "he'dve": "he'd've",
35 | "hes": "he's",
36 | "howd": "how'd",
37 | "howll": "how'll",
38 | "hows": "how's",
39 | "Id've": "I'd've",
40 | "I'dve": "I'd've",
41 | "Im": "I'm",
42 | "Ive": "I've",
43 | "isnt": "isn't",
44 | "itd": "it'd",
45 | "itd've": "it'd've",
46 | "it'dve": "it'd've",
47 | "itll": "it'll",
48 | "let's": "let's",
49 | "maam": "ma'am",
50 | "mightnt": "mightn't",
51 | "mightnt've": "mightn't've",
52 | "mightn'tve": "mightn't've",
53 | "mightve": "might've",
54 | "mustnt": "mustn't",
55 | "mustve": "must've",
56 | "neednt": "needn't",
57 | "notve": "not've",
58 | "oclock": "o'clock",
59 | "oughtnt": "oughtn't",
60 | "ow's'at": "'ow's'at",
61 | "'ows'at": "'ow's'at",
62 | "'ow'sat": "'ow's'at",
63 | "shant": "shan't",
64 | "shed've": "she'd've",
65 | "she'dve": "she'd've",
66 | "she's": "she's",
67 | "shouldve": "should've",
68 | "shouldnt": "shouldn't",
69 | "shouldnt've": "shouldn't've",
70 | "shouldn'tve": "shouldn't've",
71 | "somebody'd": "somebodyd",
72 | "somebodyd've": "somebody'd've",
73 | "somebody'dve": "somebody'd've",
74 | "somebodyll": "somebody'll",
75 | "somebodys": "somebody's",
76 | "someoned": "someone'd",
77 | "someoned've": "someone'd've",
78 | "someone'dve": "someone'd've",
79 | "someonell": "someone'll",
80 | "someones": "someone's",
81 | "somethingd": "something'd",
82 | "somethingd've": "something'd've",
83 | "something'dve": "something'd've",
84 | "somethingll": "something'll",
85 | "thats": "that's",
86 | "thered": "there'd",
87 | "thered've": "there'd've",
88 | "there'dve": "there'd've",
89 | "therere": "there're",
90 | "theres": "there's",
91 | "theyd": "they'd",
92 | "theyd've": "they'd've",
93 | "they'dve": "they'd've",
94 | "theyll": "they'll",
95 | "theyre": "they're",
96 | "theyve": "they've",
97 | "twas": "'twas",
98 | "wasnt": "wasn't",
99 | "wed've": "we'd've",
100 | "we'dve": "we'd've",
101 | "weve": "we've",
102 | "werent": "weren't",
103 | "whatll": "what'll",
104 | "whatre": "what're",
105 | "whats": "what's",
106 | "whatve": "what've",
107 | "whens": "when's",
108 | "whered": "where'd",
109 | "wheres": "where's",
110 | "whereve": "where've",
111 | "whod": "who'd",
112 | "whod've": "who'd've",
113 | "who'dve": "who'd've",
114 | "wholl": "who'll",
115 | "whos": "who's",
116 | "whove": "who've",
117 | "whyll": "why'll",
118 | "whyre": "why're",
119 | "whys": "why's",
120 | "wont": "won't",
121 | "wouldve": "would've",
122 | "wouldnt": "wouldn't",
123 | "wouldnt've": "wouldn't've",
124 | "wouldn'tve": "wouldn't've",
125 | "yall": "y'all",
126 | "yall'll": "y'all'll",
127 | "y'allll": "y'all'll",
128 | "yall'd've": "y'all'd've",
129 | "y'alld've": "y'all'd've",
130 | "y'all'dve": "y'all'd've",
131 | "youd": "you'd",
132 | "youd've": "you'd've",
133 | "you'dve": "you'd've",
134 | "youll": "you'll",
135 | "youre": "you're",
136 | "youve": "you've",
137 | }
138 | self.manualMap = {
139 | "none": "0",
140 | "zero": "0",
141 | "one": "1",
142 | "two": "2",
143 | "three": "3",
144 | "four": "4",
145 | "five": "5",
146 | "six": "6",
147 | "seven": "7",
148 | "eight": "8",
149 | "nine": "9",
150 | "ten": "10",
151 | }
152 | self.articles = ["a", "an", "the"]
153 |
154 |         self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
155 |         self.commaStrip = re.compile(r"(\d)(\,)(\d)")
156 | self.punct = [
157 | ";",
158 | r"/",
159 | "[",
160 | "]",
161 | '"',
162 | "{",
163 | "}",
164 | "(",
165 | ")",
166 | "=",
167 | "+",
168 | "\\",
169 | "_",
170 | "-",
171 | ">",
172 | "<",
173 | "@",
174 | "`",
175 | ",",
176 | "?",
177 | "!",
178 | ]
179 |
180 |     def evaluate(self, answer, gt_answers):  # exact-match accuracy (1/0) after normalization
181 | answer = answer.replace("\n", " ")
182 | answer = answer.replace("\t", " ")
183 | answer = answer.strip()
184 | answer = self.processPunctuation(answer)
185 | answer = self.processDigitArticle(answer)
186 | if type(gt_answers) == str:
187 | gt_answers = [gt_answers]
188 | for i in range(len(gt_answers)):
189 | gt_answers[i] = gt_answers[i].replace("\n", " ")
190 | gt_answers[i] = gt_answers[i].replace("\t", " ")
191 | gt_answers[i] = gt_answers[i].strip()
192 | gt_answers[i] = self.processPunctuation(gt_answers[i])
193 | gt_answers[i] = self.processDigitArticle(gt_answers[i])
194 | if answer == gt_answers[i]:
195 | return 1
196 | return 0
197 |
198 | def processPunctuation(self, inText):
199 | outText = inText
200 | for p in self.punct:
201 | if (p + " " in inText or " " + p in inText) or (
202 | re.search(self.commaStrip, inText) != None
203 | ):
204 | outText = outText.replace(p, "")
205 | else:
206 | outText = outText.replace(p, " ")
207 | outText = self.periodStrip.sub("", outText, re.UNICODE)
208 | return outText
209 |
210 | def processDigitArticle(self, inText):
211 | outText = []
212 | tempText = inText.lower().split()
213 | for word in tempText:
214 | word = self.manualMap.setdefault(word, word)
215 | if word not in self.articles:
216 | outText.append(word)
217 | else:
218 | pass
219 | for wordId, word in enumerate(outText):
220 | if word in self.contractions:
221 | outText[wordId] = self.contractions[word]
222 | outText = " ".join(outText)
223 | return outText
224 |
225 |     def evaluate_anls(self, answer, gt_answers):  # ANLS: 1 - minimum normalized edit distance over ground truths
226 | answer = answer.replace("\n", " ")
227 | answer = answer.replace("\t", " ")
228 | answer = answer.strip()
229 | answer = self.processPunctuation(answer)
230 | answer = self.processDigitArticle(answer)
231 | if type(gt_answers) == str:
232 | gt_answers = [gt_answers]
233 | values = []
234 | for i in range(len(gt_answers)):
235 | gt_answers[i] = gt_answers[i].replace("\n", " ")
236 | gt_answers[i] = gt_answers[i].replace("\t", " ")
237 | gt_answers[i] = gt_answers[i].strip()
238 | gt_answers[i] = self.processPunctuation(gt_answers[i])
239 | gt_answers[i] = self.processDigitArticle(gt_answers[i])
240 | dist = self.levenshtein_distance(gt_answers[i], answer)
241 | length = max(len(gt_answers[i]), len(answer))
242 | values.append(0.0 if length == 0 else float(dist) / float(length))
243 |
244 | vqa_anls = 1 - min(values)
245 | return vqa_anls
246 |
247 |     def evaluate_f1(self, answer, gt_answers):  # token-level F1, best over ground truths
248 | answer = answer.replace("\n", " ")
249 | answer = answer.replace("\t", " ")
250 | answer = answer.strip()
251 | answer = self.processPunctuation(answer)
252 | answer = self.processDigitArticle(answer)
253 | if type(gt_answers) == str:
254 | gt_answers = [gt_answers]
255 |
256 | f1s = []
257 | for i in range(len(gt_answers)):
258 | gt_answers[i] = gt_answers[i].replace("\n", " ")
259 | gt_answers[i] = gt_answers[i].replace("\t", " ")
260 | gt_answers[i] = gt_answers[i].strip()
261 | gt_answers[i] = self.processPunctuation(gt_answers[i])
262 | gt_answers[i] = self.processDigitArticle(gt_answers[i])
263 | prediction_tokens = answer.split()
264 | ground_truth_tokens = gt_answers[i].split()
265 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
266 | num_same = sum(common.values())
267 | try:
268 | precision = 1.0 * num_same / len(prediction_tokens)
269 | recall = 1.0 * num_same / len(ground_truth_tokens)
270 | f1 = (2 * precision * recall) / (precision + recall)
271 |             except ZeroDivisionError:
272 | f1 = 0
273 | f1s.append(f1)
274 | f1 = max(f1s)
275 |
276 | return f1
277 |
278 |     def evaluate_racc(self, answer, gt_answers):  # relaxed accuracy: numeric answers within 5% relative error
279 | answer = answer.replace("\n", " ")
280 | answer = answer.replace("\t", " ")
281 | answer = answer.strip()
282 | answer = self.processPunctuation(answer)
283 | answer = self.processDigitArticle(answer)
284 | if type(gt_answers) == str:
285 | gt_answers = [gt_answers]
286 |
287 | for i in range(len(gt_answers)):
288 | gt_answers[i] = gt_answers[i].replace("\n", " ")
289 | gt_answers[i] = gt_answers[i].replace("\t", " ")
290 | gt_answers[i] = gt_answers[i].strip()
291 | gt_answers[i] = self.processPunctuation(gt_answers[i])
292 | gt_answers[i] = self.processDigitArticle(gt_answers[i])
293 | answer_float = self._to_float(answer)
294 | gt_answer_float = self._to_float(gt_answers[i])
295 | if answer_float is not None and gt_answer_float:
296 | relative_change = abs(answer_float - gt_answer_float) / abs(gt_answer_float)
297 | return relative_change <= 0.05
298 | else:
299 | return answer.lower() == gt_answers[i].lower()
300 | return 0
301 |
302 | def _to_float(self, text):
303 | try:
304 | if text.endswith("%"):
305 | # Convert percentages to floats.
306 | return float(text.rstrip("%")) / 100.0
307 | else:
308 | return float(text)
309 | except ValueError:
310 | return None
311 | def levenshtein_distance(self, s1, s2):
312 | if len(s1) > len(s2):
313 | s1, s2 = s2, s1
314 |
315 | distances = range(len(s1) + 1)
316 | for i2, c2 in enumerate(s2):
317 | distances_ = [i2+1]
318 | for i1, c1 in enumerate(s1):
319 | if c1 == c2:
320 | distances_.append(distances[i1])
321 | else:
322 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
323 | distances = distances_
324 | return distances[-1]
325 |
326 | class SumEval:
327 | def __init__(self):
328 | self.contractions = {
329 | "aint": "ain't",
330 | "arent": "aren't",
331 | "cant": "can't",
332 | "couldve": "could've",
333 | "couldnt": "couldn't",
334 | "couldn'tve": "couldn't've",
335 | "couldnt've": "couldn't've",
336 | "didnt": "didn't",
337 | "doesnt": "doesn't",
338 | "dont": "don't",
339 | "hadnt": "hadn't",
340 | "hadnt've": "hadn't've",
341 | "hadn'tve": "hadn't've",
342 | "hasnt": "hasn't",
343 | "havent": "haven't",
344 | "hed": "he'd",
345 | "hed've": "he'd've",
346 | "he'dve": "he'd've",
347 | "hes": "he's",
348 | "howd": "how'd",
349 | "howll": "how'll",
350 | "hows": "how's",
351 | "Id've": "I'd've",
352 | "I'dve": "I'd've",
353 | "Im": "I'm",
354 | "Ive": "I've",
355 | "isnt": "isn't",
356 | "itd": "it'd",
357 | "itd've": "it'd've",
358 | "it'dve": "it'd've",
359 | "itll": "it'll",
360 | "let's": "let's",
361 | "maam": "ma'am",
362 | "mightnt": "mightn't",
363 | "mightnt've": "mightn't've",
364 | "mightn'tve": "mightn't've",
365 | "mightve": "might've",
366 | "mustnt": "mustn't",
367 | "mustve": "must've",
368 | "neednt": "needn't",
369 | "notve": "not've",
370 | "oclock": "o'clock",
371 | "oughtnt": "oughtn't",
372 | "ow's'at": "'ow's'at",
373 | "'ows'at": "'ow's'at",
374 | "'ow'sat": "'ow's'at",
375 | "shant": "shan't",
376 | "shed've": "she'd've",
377 | "she'dve": "she'd've",
378 | "she's": "she's",
379 | "shouldve": "should've",
380 | "shouldnt": "shouldn't",
381 | "shouldnt've": "shouldn't've",
382 | "shouldn'tve": "shouldn't've",
383 | "somebody'd": "somebodyd",
384 | "somebodyd've": "somebody'd've",
385 | "somebody'dve": "somebody'd've",
386 | "somebodyll": "somebody'll",
387 | "somebodys": "somebody's",
388 | "someoned": "someone'd",
389 | "someoned've": "someone'd've",
390 | "someone'dve": "someone'd've",
391 | "someonell": "someone'll",
392 | "someones": "someone's",
393 | "somethingd": "something'd",
394 | "somethingd've": "something'd've",
395 | "something'dve": "something'd've",
396 | "somethingll": "something'll",
397 | "thats": "that's",
398 | "thered": "there'd",
399 | "thered've": "there'd've",
400 | "there'dve": "there'd've",
401 | "therere": "there're",
402 | "theres": "there's",
403 | "theyd": "they'd",
404 | "theyd've": "they'd've",
405 | "they'dve": "they'd've",
406 | "theyll": "they'll",
407 | "theyre": "they're",
408 | "theyve": "they've",
409 | "twas": "'twas",
410 | "wasnt": "wasn't",
411 | "wed've": "we'd've",
412 | "we'dve": "we'd've",
413 | "weve": "we've",
414 | "werent": "weren't",
415 | "whatll": "what'll",
416 | "whatre": "what're",
417 | "whats": "what's",
418 | "whatve": "what've",
419 | "whens": "when's",
420 | "whered": "where'd",
421 | "wheres": "where's",
422 | "whereve": "where've",
423 | "whod": "who'd",
424 | "whod've": "who'd've",
425 | "who'dve": "who'd've",
426 | "wholl": "who'll",
427 | "whos": "who's",
428 | "whove": "who've",
429 | "whyll": "why'll",
430 | "whyre": "why're",
431 | "whys": "why's",
432 | "wont": "won't",
433 | "wouldve": "would've",
434 | "wouldnt": "wouldn't",
435 | "wouldnt've": "wouldn't've",
436 | "wouldn'tve": "wouldn't've",
437 | "yall": "y'all",
438 | "yall'll": "y'all'll",
439 | "y'allll": "y'all'll",
440 | "yall'd've": "y'all'd've",
441 | "y'alld've": "y'all'd've",
442 | "y'all'dve": "y'all'd've",
443 | "youd": "you'd",
444 | "youd've": "you'd've",
445 | "you'dve": "you'd've",
446 | "youll": "you'll",
447 | "youre": "you're",
448 | "youve": "you've",
449 | }
450 | self.manualMap = {
451 | "none": "0",
452 | "zero": "0",
453 | "one": "1",
454 | "two": "2",
455 | "three": "3",
456 | "four": "4",
457 | "five": "5",
458 | "six": "6",
459 | "seven": "7",
460 | "eight": "8",
461 | "nine": "9",
462 | "ten": "10",
463 | }
464 | self.articles = ["a", "an", "the"]
465 |
466 |         self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
467 |         self.commaStrip = re.compile(r"(\d)(\,)(\d)")
468 | self.punct = [
469 | ";",
470 | r"/",
471 | "[",
472 | "]",
473 | '"',
474 | "{",
475 | "}",
476 | "(",
477 | ")",
478 | "=",
479 | "+",
480 | "\\",
481 | "_",
482 | "-",
483 | ">",
484 | "<",
485 | "@",
486 | "`",
487 | ",",
488 | "?",
489 | "!",
490 | ]
491 |
492 | def process(self, answer, gt_answer):
493 | answer = answer.replace("\n", " ")
494 | answer = answer.replace("\t", " ")
495 | answer = answer.strip()
496 | answer = self.processPunctuation(answer)
497 | answer = self.processDigitArticle(answer)
498 | gt_answer = gt_answer.replace("\n", " ")
499 | gt_answer = gt_answer.replace("\t", " ")
500 | gt_answer = gt_answer.strip()
501 | gt_answer = self.processPunctuation(gt_answer)
502 | gt_answer = self.processDigitArticle(gt_answer)
503 |
504 | return answer, gt_answer
505 |
506 | def processPunctuation(self, inText):
507 | outText = inText
508 | for p in self.punct:
509 | if (p + " " in inText or " " + p in inText) or (
510 |                 re.search(self.commaStrip, inText) is not None
511 | ):
512 | outText = outText.replace(p, "")
513 | else:
514 | outText = outText.replace(p, " ")
515 |         outText = self.periodStrip.sub("", outText)
516 | return outText
517 |
518 | def processDigitArticle(self, inText):
519 | outText = []
520 | tempText = inText.lower().split()
521 | for word in tempText:
522 |             word = self.manualMap.get(word, word)
523 | if word not in self.articles:
524 | outText.append(word)
525 | else:
526 | pass
527 | for wordId, word in enumerate(outText):
528 | if word in self.contractions:
529 | outText[wordId] = self.contractions[word]
530 | outText = " ".join(outText)
531 | return outText
532 |
533 | if __name__ == "__main__":
534 | parser = argparse.ArgumentParser()
535 | parser.add_argument("--input", type=str, default="outputs/vdocgenerator-phi3-vision_finetune/answers/answers.chartqa.chartqa.json")
536 | args = parser.parse_args()
537 |
538 | cor = 0
539 | r_cor = 0
540 | num = 0
541 | anls_score = 0
542 | f1_score = 0
543 |
544 | preds, gt = {}, {}
545 |     evaluator = VQAEval()
546 | with open(args.input, 'r') as f:
547 | data = json.load(f)
548 | for question_id in data:
549 | d = data[question_id]
550 | answer = d["prediction"]
551 | gt_answers = d["ground_truth"]
552 |
553 |             if evaluator.evaluate(answer, gt_answers) == 1:
554 |                 cor += 1
555 |             if evaluator.evaluate_racc(answer, gt_answers) == 1:
556 |                 r_cor += 1
557 | 
558 |             anls_score += evaluator.evaluate_anls(answer, gt_answers)
559 |             f1_score += evaluator.evaluate_f1(answer, gt_answers)
560 | num += 1
561 |
562 | print("exact acc: ", float(cor) / num)
563 | print("relaxed acc: ", float(r_cor) / num)
564 | print("anls: ", float(anls_score) / num)
565 | print("f1: ", float(f1_score) / num)
566 |
--------------------------------------------------------------------------------
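A minimal usage sketch of the evaluator above, assuming the package has been installed via setup.py (so vdocrag is importable) and that VQAEval keeps the no-argument constructor used in the __main__ block:

# Illustrative only: shows how predictions and ground truths are normalized
# (lowercasing, article removal, punctuation handling, number-word mapping)
# before the accuracy/ANLS/F1 comparisons.
from vdocrag.utils.eval_opendocvqa import VQAEval

normalizer = VQAEval()
pred, gold = normalizer.process("The answer is twenty-two!", "Twenty two")
print(pred)
print(gold)
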
/src/vdocrag/utils/format/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/utils/format/__init__.py
--------------------------------------------------------------------------------
/src/vdocrag/utils/format/convert_qas_to_trec_qrels.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from datasets import load_dataset
4 |
5 | parser = ArgumentParser()
6 | parser.add_argument('--dataset_name', type=str, required=True)
7 | parser.add_argument('--dataset_config', type=str, required=True)
8 | parser.add_argument('--output', type=str, required=True)
9 | args = parser.parse_args()
10 |
11 | data = load_dataset(
12 | args.dataset_name,
13 | args.dataset_config,
14 | split="test",
15 | )
16 |
17 | with open(args.output, 'w') as f_out:
18 | for d in data:
19 | query_id = d['query_id']
20 | for docid in d['relevant_doc_ids']:
21 | f_out.write(f'{query_id} Q0 {docid} 1\n')
22 |
--------------------------------------------------------------------------------
/src/vdocrag/utils/format/convert_result_to_trec.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | parser = ArgumentParser()
4 | parser.add_argument('--input', type=str, required=True)
5 | parser.add_argument('--output', type=str, required=True)
6 | parser.add_argument('--remove_query', action='store_true')
7 | args = parser.parse_args()
8 |
9 | with open(args.input) as f_in, open(args.output, 'w') as f_out:
10 | cur_qid = None
11 | rank = 0
12 | for line in f_in:
13 | qid, docid, score = line.split('\t')
14 | score = score.replace('\n', '')
15 | if cur_qid != qid:
16 | cur_qid = qid
17 | rank = 0
18 | if args.remove_query and qid == docid:
19 | continue
20 | rank += 1
21 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n')
22 |
--------------------------------------------------------------------------------
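Together these two utilities produce standard TREC files: qrels lines of the form "query_id Q0 doc_id 1" and run lines of the form "query_id Q0 doc_id rank score dense". Below is a small illustrative sketch (the file names are placeholders) that reads both back and computes a simple hit@k, e.g. as a sanity check before handing the files to a TREC evaluation tool:

from collections import defaultdict

def load_qrels(path):
    # qrels line: "query_id Q0 doc_id relevance"
    qrels = defaultdict(set)
    with open(path) as f:
        for line in f:
            qid, _, docid, rel = line.split()
            if int(rel) > 0:
                qrels[qid].add(docid)
    return qrels

def load_run(path, k=5):
    # run line: "query_id Q0 doc_id rank score tag"
    run = defaultdict(list)
    with open(path) as f:
        for line in f:
            qid, _, docid, rank, score, _ = line.split()
            if int(rank) <= k:
                run[qid].append(docid)
    return run

qrels = load_qrels("qrels.opendocvqa.txt")   # placeholder paths
run = load_run("run.opendocvqa.trec", k=5)
hits = sum(bool(set(run[qid]) & docs) for qid, docs in qrels.items())
print(f"hit@5: {hits / max(len(qrels), 1):.3f}")
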
/src/vdocrag/vdocgenerator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/vdocgenerator/__init__.py
--------------------------------------------------------------------------------
/src/vdocrag/vdocgenerator/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 | from transformers import TrainingArguments
5 |
6 |
7 | @dataclass
8 | class ModelArguments:
9 | model_name_or_path: str = field(
10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
11 | )
12 | config_name: Optional[str] = field(
13 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
14 | )
15 | tokenizer_name: Optional[str] = field(
16 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
17 | )
18 | cache_dir: Optional[str] = field(
19 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
20 | )
21 | # for lora
22 | lora: bool = field(default=False,
23 | metadata={"help": "do parameter-efficient fine-tuning with lora"}
24 | )
25 |
26 | lora_name_or_path: Optional[str] = field(
27 | default=None, metadata={"help": "Path to pretrained lora model or model identifier from huggingface.co/models"}
28 | )
29 |
30 | lora_r: int = field(
31 | default=8,
32 | metadata={"help": "lora r"}
33 | )
34 |
35 | lora_alpha: int = field(
36 | default=64,
37 | metadata={"help": "lora alpha"}
38 | )
39 |
40 | lora_dropout: float = field(
41 | default=0.1,
42 | metadata={"help": "lora dropout"}
43 | )
44 |
45 | lora_target_modules: str = field(
46 | default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
47 | metadata={"help": "lora target modules"}
48 | )
49 |
50 | # for Jax training
51 | dtype: Optional[str] = field(
52 | default="float32",
53 | metadata={
54 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
55 | "of `[float32, float16, bfloat16]`. "
56 | },
57 | )
58 |
59 |
60 | @dataclass
61 | class DataArguments:
62 | dataset_name: str = field(
63 | default='json', metadata={"help": "huggingface dataset name"}
64 | )
65 | dataset_path: str = field(
66 | default='json', metadata={"help": "dataset path"}
67 | )
68 | dataset_config: str = field(
69 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"}
70 | )
71 |
72 | dataset_path: str = field(
73 | default=None, metadata={"help": "Path to local data files or directory"}
74 | )
75 |
76 | dataset_split: str = field(
77 | default='train', metadata={"help": "dataset split"}
78 | )
79 |
80 | dataset_cache_dir: Optional[str] = field(
81 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"}
82 | )
83 | corpus_name: str = field(
84 | default='NTT-hil-insight/OpenDocVQA-Corpus', metadata={"help": "huggingface dataset name"}
85 | )
86 |
87 | corpus_config: str = field(
88 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"}
89 | )
90 |
91 | corpus_path: str = field(
92 | default=None, metadata={"help": "Path to local data files or directory"}
93 | )
94 |
95 | corpus_split: str = field(
96 | default='train', metadata={"help": "dataset split"}
97 | )
98 |
99 | retrieval_results_path: str = field(
100 | default=None, metadata={"help": "Path to local data files or directory"}
101 | )
102 |
103 | dataset_number_of_shards: int = field(
104 | default=1, metadata={"help": "number of shards to split the dataset into"}
105 | )
106 |
107 | dataset_shard_index: int = field(
108 | default=0, metadata={"help": "shard index to use, to be used with dataset_number_of_shards"}
109 | )
110 |
111 | top_k: int = field(
112 |         default=3, metadata={"help": "number of retrieved documents passed to the generator for each query"}
113 | )
114 | gold: bool = field(
115 | default=False, metadata={"help": "gold retrieval results"})
116 |
117 |     output_path: str = field(default=None, metadata={"help": "where to save the generated answers"})
118 |
119 |
120 | query_max_len: Optional[int] = field(
121 | default=32,
122 | metadata={
123 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
124 | "than this will be truncated, sequences shorter will be padded."
125 | },
126 | )
127 | answer_max_len: Optional[int] = field(
128 | default=128,
129 | metadata={
130 |             "help": "The maximum total input sequence length after tokenization for the answer. Sequences longer "
131 | "than this will be truncated, sequences shorter will be padded."
132 | },
133 | )
134 |
135 | pad_to_multiple_of: Optional[int] = field(
136 | default=16,
137 | metadata={
138 | "help": "If set will pad the sequence to a multiple of the provided value. This is especially useful to "
139 | "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
140 | },
141 | )
142 |
143 |
144 | @dataclass
145 | class VDocGeneratorTrainingArguments(TrainingArguments):
146 | warmup_ratio: float = field(default=0.1)
147 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocgenerator/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from typing import List, Tuple
4 | from dataclasses import dataclass
5 | from transformers import PreTrainedTokenizer, ProcessorMixin
6 | from vdocrag.vdocgenerator.arguments import DataArguments
7 | from transformers.feature_extraction_utils import BatchFeature
8 | from collections import defaultdict
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | @dataclass
14 | class TrainCollator:
15 | data_args: DataArguments
16 | tokenizer: PreTrainedTokenizer
17 | processor: ProcessorMixin
18 |
19 | def __call__(self, features: List[Tuple[str, List[str]]]):
20 | """
21 | Collate function for training.
22 |         :param features: list of (query, answer, images) tuples
23 |         :return: dict with input_ids, labels, pixel_values, image_sizes, attention_mask
24 | """
25 |
26 | all_queries = [f[0] for f in features]
27 | all_answers = [f[1] for f in features]
28 | all_images = [f[2] for f in features]
29 |
30 | collated = {}
31 | all_input_ids, all_label_ids, pixel_values, image_sizes = [], [], [], []
32 | for i, (query, answer, images) in enumerate(zip(all_queries, all_answers, all_images)):
33 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(images))])
34 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}]
35 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
36 | processed = self.processor(prompt, images, return_tensors="pt")
37 | answer = f'{answer}<|end|>\n<|endoftext|>'
38 | answer_input_ids = self.tokenizer(
39 | answer, add_special_tokens=False, return_tensors='pt'
40 | )['input_ids']
41 | prompt_input_ids = processed['input_ids']
42 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1)
43 | labels = torch.cat(
44 | [
45 | torch.tensor([-100] * len(prompt_input_ids[0])).unsqueeze(0),
46 | answer_input_ids,
47 | ],
48 | dim=1,
49 | )
50 | # prepare expected shape for pad_sequence
51 | all_input_ids.append(input_ids.squeeze(0).unsqueeze(1))
52 | all_label_ids.append(labels.squeeze(0).unsqueeze(1))
53 | pixel_values.append(processed['pixel_values'])
54 | image_sizes.append(processed['image_sizes'])
55 |
56 |         input_ids = torch.nn.utils.rnn.pad_sequence(
57 | all_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
58 | ).squeeze(2)
59 |         labels = torch.nn.utils.rnn.pad_sequence(
60 | all_label_ids, batch_first=True, padding_value=-100
61 | ).squeeze(2)
62 |
63 | collated['input_ids'] = input_ids
64 | collated['labels'] = labels
65 | collated['pixel_values'] = torch.cat(pixel_values, dim=0)
66 | collated['image_sizes'] = torch.cat(image_sizes, dim=0)
67 | collated['attention_mask'] = collated['input_ids'].ne(self.tokenizer.pad_token_id)
68 |
69 | return collated
70 |
71 | @dataclass
72 | class DecodeCollator:
73 | data_args: DataArguments
74 | tokenizer: PreTrainedTokenizer
75 | processor: ProcessorMixin
76 |
77 | def __call__(self, features: List[Tuple[str, str]]):
78 | """
79 |         Collate function for generation.
80 |         :param features: list of (query_id, query, answers, images) tuples
81 | """
82 | query_ids = [f[0] for f in features]
83 | all_queries = [f[1] for f in features]
84 | all_answers = [f[2] for f in features]
85 | all_images = [f[3] for f in features]
86 |
87 | collated = defaultdict(list)
88 | pixel_values, image_sizes = [], []
89 | for i, (query, images) in enumerate(zip(all_queries, all_images)):
90 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(images))])
91 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}]
92 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
93 | processed = self.processor(prompt, images, return_tensors="pt")
94 | prompt_input_ids = processed['input_ids']
95 | collated['input_ids'].append(prompt_input_ids)
96 | pixel_values.append(processed['pixel_values'])
97 | image_sizes.append(processed['image_sizes'])
98 |
99 | collated['input_ids'] = torch.cat(collated['input_ids'], dim=0)
100 | collated['pixel_values'] = torch.cat(pixel_values, dim=0)
101 | collated['image_sizes'] = torch.cat(image_sizes, dim=0)
102 |
103 | return query_ids, all_answers, collated
--------------------------------------------------------------------------------
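The central detail in TrainCollator above is that prompt tokens are masked with -100 in the labels, so only the answer tokens contribute to the loss, and ragged sequences are right-padded. A self-contained sketch of that construction with dummy token ids (no tokenizer or processor involved; the pad id is a placeholder):

import torch
from torch.nn.utils.rnn import pad_sequence

# Dummy "tokenized" prompt and answer ids for two examples of different lengths.
examples = [
    (torch.tensor([[101, 7, 8, 9]]), torch.tensor([[42, 43, 2]])),
    (torch.tensor([[101, 5]]), torch.tensor([[44, 2]])),
]

all_input_ids, all_label_ids = [], []
for prompt_ids, answer_ids in examples:
    input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
    # Prompt positions get -100 so they are ignored by the cross-entropy loss.
    labels = torch.cat(
        [torch.full((1, prompt_ids.size(1)), -100), answer_ids], dim=1
    )
    all_input_ids.append(input_ids.squeeze(0))
    all_label_ids.append(labels.squeeze(0))

pad_token_id = 0  # placeholder; the real collator uses tokenizer.pad_token_id
input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=pad_token_id)
labels = pad_sequence(all_label_ids, batch_first=True, padding_value=-100)
attention_mask = input_ids.ne(pad_token_id)
print(input_ids)
print(labels)
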
/src/vdocrag/vdocgenerator/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List, Tuple
3 |
4 | from datasets import load_dataset
5 | from torch.utils.data import Dataset
6 | from PIL import Image
7 | import os
8 | import json
9 | from vdocrag.vdocgenerator.arguments import DataArguments
10 | from scipy.special import softmax
11 | from collections import defaultdict
12 | import numpy as np
13 | from functools import partial
14 |
15 | import logging
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def format_query_for_QA(query):
20 | return query.split("Query: ")[-1].strip() + "\n Answer briefly."
21 |
22 | def add_candidates(example, retrieved_docs, top_k):
23 | query_id = example["query_id"]
24 | candidates = retrieved_docs.get(query_id, [])[:top_k]
25 | example["candidates"] = candidates
26 | return example
27 |
28 | class TrainDataset(Dataset):
29 | def __init__(self, data_args: DataArguments, trainer = None):
30 | self.data_args = data_args
31 | self.train_data = load_dataset(
32 | self.data_args.dataset_name,
33 | self.data_args.dataset_config,
34 | data_files=self.data_args.dataset_path,
35 | split=self.data_args.dataset_split,
36 | cache_dir=self.data_args.dataset_cache_dir,
37 | )
38 |
39 | self.corpus = load_dataset(
40 | self.data_args.corpus_name,
41 | self.data_args.corpus_config,
42 | data_files=self.data_args.corpus_path,
43 | split=self.data_args.corpus_split,
44 | cache_dir=self.data_args.dataset_cache_dir,
45 | )
46 |
47 | self.docid2idx = {}
48 | for idx, doc_id in enumerate(self.corpus['doc_id']):
49 | self.docid2idx[str(doc_id)] = idx
50 |
51 | self.retrieved_docs = defaultdict(list)
52 | with open(self.data_args.retrieval_results_path) as f:
53 | lines = f.read().splitlines()
54 | for line in lines:
55 | query_id, doc_id, score = line.split()
56 | self.retrieved_docs[query_id].append(doc_id)
57 |
58 | self.train_data = self.train_data.map(
59 | partial(add_candidates,
60 | retrieved_docs=self.retrieved_docs,
61 | top_k=self.data_args.top_k)
62 | )
63 |
64 | self.trainer = trainer
65 |
66 | def __len__(self):
67 | return len(self.train_data)
68 |
69 | def _get_image(self, doc_id):
70 | image = self.corpus[self.docid2idx[doc_id]]['image']
71 | return image
72 |
73 | def __getitem__(self, item) -> Tuple[str, List[str]]:
74 | group = self.train_data[item]
75 | query = format_query_for_QA(group['query'])
76 | answer = group['answers'][0]
77 | images = [self._get_image(doc_id) for doc_id in group["candidates"]]
78 |
79 | return query, answer, images
80 |
81 |
82 | class DecodeDataset(Dataset):
83 |
84 | def __init__(self, data_args: DataArguments):
85 | self.data_args = data_args
86 | self.test_data = load_dataset(
87 | self.data_args.dataset_name,
88 | self.data_args.dataset_config,
89 | data_files=self.data_args.dataset_path,
90 | split=self.data_args.dataset_split,
91 | cache_dir=self.data_args.dataset_cache_dir,
92 | )
93 |
94 | self.corpus = load_dataset(
95 | self.data_args.corpus_name,
96 | self.data_args.corpus_config,
97 | data_files=self.data_args.corpus_path,
98 | split=self.data_args.corpus_split,
99 | cache_dir=self.data_args.dataset_cache_dir,
100 | )
101 |
102 | self.docid2idx = {}
103 | for idx, doc_id in enumerate(self.corpus['doc_id']):
104 | self.docid2idx[str(doc_id)] = idx
105 |
106 | self.retrieved_docs = defaultdict(list)
107 | with open(self.data_args.retrieval_results_path) as f:
108 | lines = f.read().splitlines()
109 | for line in lines:
110 | query_id, doc_id, score = line.split()
111 | self.retrieved_docs[query_id].append(doc_id)
112 |
113 | self.test_data = self.test_data.map(
114 | partial(add_candidates,
115 | retrieved_docs=self.retrieved_docs,
116 | top_k=self.data_args.top_k)
117 | )
118 |
119 | def __len__(self):
120 | return len(self.test_data)
121 |
122 | def _get_image(self, doc_id):
123 | image = self.corpus[self.docid2idx[doc_id]]['image']
124 | return image
125 |
126 | def __getitem__(self, item) -> Tuple[str, str]:
127 | data = self.test_data[item]
128 | query_id = data['query_id']
129 | query = format_query_for_QA(data["query"])
130 | answers = data['answers']
131 | images = [self._get_image(doc_id) for doc_id in data["candidates"]]
132 | return query_id, query, answers, images
133 |
--------------------------------------------------------------------------------
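Both datasets above read retrieval_results_path as a whitespace/tab-separated ranking with one "query_id doc_id score" triple per line (the format written by the retriever's search step), and add_candidates keeps the first top_k doc_ids per query, i.e. the highest-ranked ones. A toy sketch of that parsing:

from collections import defaultdict

# Toy ranking in the "query_id doc_id score" format consumed above.
lines = [
    "q1 docA 12.3",
    "q1 docB 11.9",
    "q1 docC 10.2",
    "q2 docD 9.8",
]

retrieved_docs = defaultdict(list)
for line in lines:
    query_id, doc_id, score = line.split()
    retrieved_docs[query_id].append(doc_id)

top_k = 2
candidates = {qid: docs[:top_k] for qid, docs in retrieved_docs.items()}
print(candidates)  # {'q1': ['docA', 'docB'], 'q2': ['docD']}
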
/src/vdocrag/vdocgenerator/driver/generate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | import json
6 | from contextlib import nullcontext
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import time
13 | from torch.utils.data import DataLoader
14 | from transformers import AutoTokenizer, AutoProcessor
15 | from transformers import (
16 | HfArgumentParser,
17 | )
18 |
19 | from vdocrag.vdocgenerator.arguments import ModelArguments, DataArguments, \
20 | VDocGeneratorTrainingArguments as TrainingArguments
21 | from vdocrag.vdocgenerator.dataset import DecodeDataset
22 | from vdocrag.vdocgenerator.collator import DecodeCollator
23 | from vdocrag.vdocgenerator.modeling import DecoderOutput, VDocGenerator
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def main():
29 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
30 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
31 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
32 | else:
33 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
34 | model_args: ModelArguments
35 | data_args: DataArguments
36 | training_args: TrainingArguments
37 |
38 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
39 | raise NotImplementedError('Multi-GPU encoding is not supported.')
40 |
41 | # Setup logging
42 | logging.basicConfig(
43 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
44 | datefmt="%m/%d/%Y %H:%M:%S",
45 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
46 | )
47 |
48 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
49 | cache_dir=model_args.cache_dir,
50 | trust_remote_code=True,)
51 |
52 | tokenizer = processor.tokenizer
53 |
54 | if tokenizer.pad_token_id is None:
55 | tokenizer.pad_token_id = tokenizer.eos_token_id
56 | tokenizer.padding_side = 'right'
57 |
58 | if training_args.bf16:
59 | torch_dtype = torch.bfloat16
60 | elif training_args.fp16:
61 | torch_dtype = torch.float16
62 | else:
63 | torch_dtype = torch.float32
64 |
65 | model = VDocGenerator.load(
66 | model_args.model_name_or_path,
67 | lora_name_or_path=model_args.lora_name_or_path,
68 | trust_remote_code=True,
69 | cache_dir=model_args.cache_dir,
70 | torch_dtype=torch_dtype,
71 | _attn_implementation='flash_attention_2',
72 | )
73 |
74 | decode_dataset = DecodeDataset(
75 | data_args=data_args,
76 | )
77 |
78 | decode_collator = DecodeCollator(
79 | data_args=data_args,
80 | tokenizer=tokenizer,
81 | processor=processor,
82 | )
83 |
84 | decode_loader = DataLoader(
85 | decode_dataset,
86 | batch_size=training_args.per_device_eval_batch_size,
87 | collate_fn=decode_collator,
88 | shuffle=False,
89 | drop_last=False,
90 | num_workers=training_args.dataloader_num_workers,
91 | )
92 | responses = {}
93 | model = model.to(training_args.device)
94 | model.eval()
95 |
96 | generation_args = {
97 | "max_new_tokens": 64,
98 | "temperature": 0.0,
99 | "do_sample": False,
100 | "eos_token_id": tokenizer.eos_token_id,
101 | }
102 |
103 | # TODO batch > 1
104 | for (batch_ids, answers, batch) in tqdm(decode_loader):
105 | with nullcontext():
106 | with torch.no_grad():
107 | for k, v in batch.items():
108 | batch[k] = v.to(training_args.device)
109 | generate_ids = model.generate(batch,
110 | generation_args=generation_args,
111 | )
112 | generate_ids = generate_ids[:, batch['input_ids'].shape[1]:]
113 | response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
114 | response = response.strip()
115 | responses[batch_ids[0]] = {"ground_truth": answers[0], "prediction": response}
116 |
117 | if not os.path.exists(os.path.dirname(data_args.output_path)):
118 | os.makedirs(os.path.dirname(data_args.output_path))
119 | with open(data_args.output_path, 'w') as f:
120 | json.dump(responses, f)
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
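For a quick check outside the driver, the same load-and-generate path can be exercised directly. This is a hedged sketch: the base checkpoint id, the image file, and the absence of a LoRA adapter are placeholders/assumptions, and the chat template plus the <|image_1|> token follow the Phi-3-Vision-style processor this code targets; a CUDA device is assumed.

import torch
from PIL import Image
from transformers import AutoProcessor
from vdocrag.vdocgenerator.modeling import VDocGenerator

base_model = "microsoft/Phi-3-vision-128k-instruct"   # placeholder checkpoint id
processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
model = VDocGenerator.load(
    base_model,
    lora_name_or_path=None,        # or a fine-tuned LoRA directory
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

image = Image.open("page.png")     # placeholder document image
question = "What is the total revenue?\n Answer briefly."
messages = [{"role": "user", "content": f"<|image_1|>\n{question}"}]
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(prompt, [image], return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(inputs, generation_args={"max_new_tokens": 64, "do_sample": False})
print(processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])
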
/src/vdocrag/vdocgenerator/driver/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import torch
5 | import wandb
6 |
7 | from transformers import AutoTokenizer
8 | from transformers import AutoProcessor
9 |
10 | from transformers import (
11 | HfArgumentParser,
12 | set_seed,
13 | )
14 |
15 | from vdocrag.vdocgenerator.arguments import ModelArguments, DataArguments, \
16 | VDocGeneratorTrainingArguments as TrainingArguments
17 | from vdocrag.vdocgenerator.dataset import TrainDataset
18 | from vdocrag.vdocgenerator.collator import TrainCollator
19 | from vdocrag.vdocgenerator.modeling import VDocGenerator
20 | from vdocrag.vdocgenerator.trainer import VDocGeneratorTrainer as Trainer
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def main():
26 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
27 |
28 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
29 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
30 | else:
31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
32 | model_args: ModelArguments
33 | data_args: DataArguments
34 | training_args: TrainingArguments
35 |
36 | if (
37 | os.path.exists(training_args.output_dir)
38 | and os.listdir(training_args.output_dir)
39 | and training_args.do_train
40 | and not training_args.overwrite_output_dir
41 | ):
42 | raise ValueError(
43 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
44 | )
45 |
46 | # Setup logging
47 | logging.basicConfig(
48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
49 | datefmt="%m/%d/%Y %H:%M:%S",
50 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
51 | )
52 | logger.warning(
53 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
54 | training_args.local_rank,
55 | training_args.device,
56 | training_args.n_gpu,
57 | bool(training_args.local_rank != -1),
58 | training_args.fp16,
59 | )
60 | logger.info("Training/evaluation parameters %s", training_args)
61 | logger.info("MODEL parameters %s", model_args)
62 |
63 | set_seed(training_args.seed)
64 |
65 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
66 | cache_dir=model_args.cache_dir,
67 | trust_remote_code=True)
68 | tokenizer = processor.tokenizer
69 |
70 | if tokenizer.pad_token_id is None:
71 | tokenizer.pad_token_id = tokenizer.eos_token_id
72 | tokenizer.padding_side = 'right'
73 |
74 | if training_args.bf16:
75 | torch_dtype = torch.bfloat16
76 | elif training_args.fp16:
77 | torch_dtype = torch.float16
78 | else:
79 | torch_dtype = torch.float32
80 |
81 | model = VDocGenerator.build(
82 | model_args,
83 | training_args,
84 | cache_dir=model_args.cache_dir,
85 | trust_remote_code=True,
86 | torch_dtype=torch_dtype,
87 | _attn_implementation='flash_attention_2',
88 | )
89 |
90 | train_dataset = TrainDataset(data_args)
91 | collator = TrainCollator(data_args, tokenizer, processor)
92 |
93 | trainer_cls = Trainer
94 |
95 | trainer = trainer_cls(
96 | model=model,
97 | args=training_args,
98 | train_dataset=train_dataset,
99 | data_collator=collator
100 | )
101 | train_dataset.trainer = trainer
102 |
103 | trainer.train() # TODO: resume training
104 | trainer.save_model()
105 | if trainer.is_world_process_zero():
106 | tokenizer.save_pretrained(training_args.output_dir)
107 |
108 |
109 | if __name__ == "__main__":
110 | main()
111 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocgenerator/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .vdocgenerator import VDocGenerator, DecoderOutput
2 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocgenerator/modeling/vdocgenerator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | import torch.distributed as dist
6 | from torch import nn, Tensor
7 |
8 | from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM
9 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel
10 |
11 | from transformers.file_utils import ModelOutput
12 | from vdocrag.vdocgenerator.arguments import ModelArguments, VDocGeneratorTrainingArguments as TrainingArguments
13 |
14 | import logging
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | @dataclass
19 | class DecoderOutput(ModelOutput):
20 | loss: Optional[Tensor] = None
21 |
22 |
23 | class VDocGenerator(nn.Module):
24 | TRANSFORMER_CLS = AutoModelForCausalLM
25 |
26 | def __init__(self,
27 | decoder: PreTrainedModel,
28 | ):
29 | super().__init__()
30 | self.config = decoder.config
31 | self.decoder = decoder
32 | self.is_ddp = dist.is_initialized()
33 | if self.is_ddp:
34 | self.process_rank = dist.get_rank()
35 | self.world_size = dist.get_world_size()
36 |
37 | def forward(self, inputs: Dict[str, Tensor] = None, use_cache: bool = True):
38 | outputs = self.decode(inputs, use_cache=use_cache)
39 |
40 | # for training
41 | if self.training:
42 | loss = outputs.loss
43 |
44 | if self.is_ddp:
45 |                 loss = loss * self.world_size  # counteract DDP's mean gradient reduction across ranks
46 |
47 | # for eval
48 | else:
49 | loss = None
50 | return DecoderOutput(
51 | loss=loss,
52 | )
53 |
54 | def gradient_checkpointing_enable(self, **kwargs):
55 | self.decoder.model.gradient_checkpointing_enable()
56 |
57 | @classmethod
58 | def build(
59 | cls,
60 | model_args: ModelArguments,
61 | train_args: TrainingArguments,
62 | **hf_kwargs,
63 | ):
64 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
65 | if base_model.config.pad_token_id is None:
66 | base_model.config.pad_token_id = 0
67 | if model_args.lora or model_args.lora_name_or_path:
68 | if train_args.gradient_checkpointing:
69 | base_model.enable_input_require_grads()
70 | if model_args.lora_name_or_path:
71 | lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
72 | lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, is_trainable=True)
73 | else:
74 | lora_config = LoraConfig(
75 | base_model_name_or_path=model_args.model_name_or_path,
76 | task_type=TaskType.FEATURE_EXTRACTION,
77 | r=model_args.lora_r,
78 | lora_alpha=model_args.lora_alpha,
79 | lora_dropout=model_args.lora_dropout,
80 | target_modules=model_args.lora_target_modules.split(','),
81 | inference_mode=False
82 | )
83 | lora_model = get_peft_model(base_model, lora_config)
84 | model = cls(
85 | decoder=lora_model,
86 | )
87 | else:
88 | model = cls(
89 | decoder=base_model,
90 | )
91 | return model
92 |
93 | @classmethod
94 | def load(cls,
95 | model_name_or_path: str,
96 | lora_name_or_path: str = None,
97 | **hf_kwargs):
98 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
99 | if base_model.config.pad_token_id is None:
100 | base_model.config.pad_token_id = 0
101 | if lora_name_or_path:
102 | lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
103 | lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
104 | lora_model = lora_model.merge_and_unload()
105 | model = cls(
106 | decoder=lora_model,
107 | )
108 | else:
109 | model = cls(
110 | decoder=base_model,
111 | )
112 | return model
113 |
114 | def save(self, output_dir: str):
115 | self.decoder.save_pretrained(output_dir)
116 |
117 | def decode(self, input, use_cache=True):
118 | return self.decoder(**input, use_cache=use_cache)
119 |
120 | def generate(self, input, generation_args, use_cache=True):
121 | return self.decoder.generate(**input, **generation_args, use_cache=use_cache)
--------------------------------------------------------------------------------
/src/vdocrag/vdocgenerator/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | import torch
5 |
6 | from transformers.trainer import Trainer, TRAINING_ARGS_NAME
7 | import torch.distributed as dist
8 | from .modeling import VDocGenerator
9 | from huggingface_hub import login
10 |
11 | import logging
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class VDocGeneratorTrainer(Trainer):
16 | def __init__(self, *args, **kwargs):
17 | super(VDocGeneratorTrainer, self).__init__(*args, **kwargs)
18 | self.is_ddp = dist.is_initialized()
19 | self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
20 |
21 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
22 | # If we are executing this function, we are the process zero, so we don't check for that.
23 | output_dir = output_dir if output_dir is not None else self.args.output_dir
24 | os.makedirs(output_dir, exist_ok=True)
25 | logger.info(f"Saving model checkpoint to {output_dir}")
26 |
27 | supported_classes = (VDocGenerator,)
28 | # Save a trained model and configuration using `save_pretrained()`.
29 | # They can then be reloaded using `from_pretrained()`
30 | if not isinstance(self.model, supported_classes):
31 | raise ValueError(f"Unsupported model class {self.model}")
32 | else:
33 | if state_dict is None:
34 | state_dict = self.model.state_dict()
35 | prefix = 'decoder.'
36 | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38 | self.model.decoder.save_pretrained(
39 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
40 | )
41 |
42 | if self.tokenizer is not None:
43 | self.tokenizer.save_pretrained(output_dir)
44 |
45 | # Good practice: save your training arguments together with the trained model
46 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
47 |
48 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
49 | return model(inputs).loss
50 |
51 | def training_step(self, *args):
52 | return super(VDocGeneratorTrainer, self).training_step(*args) / self._dist_loss_scale_factor
53 |
54 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocretriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/vdocretriever/__init__.py
--------------------------------------------------------------------------------
/src/vdocrag/vdocretriever/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 | from transformers import TrainingArguments
5 |
6 |
7 | @dataclass
8 | class ModelArguments:
9 | model_name_or_path: str = field(
10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
11 | )
12 | config_name: Optional[str] = field(
13 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
14 | )
15 | tokenizer_name: Optional[str] = field(
16 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
17 | )
18 | cache_dir: Optional[str] = field(
19 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
20 | )
21 |
22 | pooling: str = field(
23 | default='cls',
24 | metadata={"help": "pooling method for query and document encoder"}
25 | )
26 | normalize: bool = field(
27 | default=False,
28 | metadata={"help": "normalize query and document representations"}
29 | )
30 |
31 | temperature: float = field(
32 | default=1.0,
33 | metadata={"help": "temperature for softmax"}
34 | )
35 |
36 | # for lora
37 | lora: bool = field(default=False,
38 | metadata={"help": "do parameter-efficient fine-tuning with lora"}
39 | )
40 |
41 | lora_name_or_path: Optional[str] = field(
42 | default=None, metadata={"help": "Path to pretrained lora model or model identifier from huggingface.co/models"}
43 | )
44 |
45 | lora_r: int = field(
46 | default=8,
47 | metadata={"help": "lora r"}
48 | )
49 |
50 | lora_alpha: int = field(
51 | default=64,
52 | metadata={"help": "lora alpha"}
53 | )
54 |
55 | lora_dropout: float = field(
56 | default=0.1,
57 | metadata={"help": "lora dropout"}
58 | )
59 |
60 | lora_target_modules: str = field(
61 | default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
62 | metadata={"help": "lora target modules"}
63 | )
64 |
65 | dtype: Optional[str] = field(
66 | default="float32",
67 | metadata={
68 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
69 | "of `[float32, float16, bfloat16]`. "
70 | },
71 | )
72 |
73 | @dataclass
74 | class DataArguments:
75 | dataset_path: str = field(
76 | default='json', metadata={"help": "dataset path"}
77 | )
78 | dataset_name: str = field(
79 | default='NTT-hil-insight/OpenDocVQA', metadata={"help": "huggingface dataset name"}
80 | )
81 | dataset_config: str = field(
82 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"}
83 | )
84 |
85 | dataset_path: str = field(
86 | default=None, metadata={"help": "Path to local data files or directory"}
87 | )
88 |
89 | dataset_split: str = field(
90 | default='train', metadata={"help": "dataset split"}
91 | )
92 |
93 | dataset_cache_dir: Optional[str] = field(
94 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"}
95 | )
96 |
97 | corpus_name: str = field(
98 | default='NTT-hil-insight/OpenDocVQA-Corpus', metadata={"help": "huggingface dataset name"}
99 | )
100 |
101 | corpus_config: str = field(
102 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"}
103 | )
104 |
105 | corpus_path: str = field(
106 | default=None, metadata={"help": "Path to local data files or directory"}
107 | )
108 |
109 | corpus_split: str = field(
110 | default='train', metadata={"help": "dataset split"}
111 | )
112 |
113 | dataset_number_of_shards: int = field(
114 | default=1, metadata={"help": "number of shards to split the dataset into"}
115 | )
116 |
117 | dataset_shard_index: int = field(
118 | default=0, metadata={"help": "shard index to use, to be used with dataset_number_of_shards"}
119 | )
120 |
121 | train_group_size: int = field(
122 | default=8, metadata={"help": "number of documents used to train for each query"}
123 | )
124 | positive_document_no_shuffle: bool = field(
125 | default=False, metadata={"help": "always use the first positive document for training"})
126 |
127 | image_attention_mask: bool = field(
128 | default=False, metadata={"help": "custom attention mask for RCG task"})
129 |
130 | pretrain: bool = field(
131 | default=False, metadata={"help": "whether pre-training is executed or not"})
132 |
133 | encode_is_query: bool = field(default=False)
134 |
135 | encode_output_path: str = field(default=None, metadata={"help": "where to save the encode"})
136 |
137 | query_max_len: Optional[int] = field(
138 | default=32,
139 | metadata={
140 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
141 | "than this will be truncated, sequences shorter will be padded."
142 | },
143 | )
144 | answer_max_len: Optional[int] = field(
145 | default=128,
146 | metadata={
147 |             "help": "The maximum total input sequence length after tokenization for the answer. Sequences longer "
148 | "than this will be truncated, sequences shorter will be padded."
149 | },
150 | )
151 |
152 | append_eos_token: bool = field(
153 | default=False, metadata={"help": "append eos token to query and document, this is currently used for repllama"}
154 | )
155 |
156 | pad_to_multiple_of: Optional[int] = field(
157 | default=16,
158 | metadata={
159 | "help": "If set will pad the sequence to a multiple of the provided value. This is especially useful to "
160 | "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
161 | },
162 | )
163 |
164 |
165 | @dataclass
166 | class VDocRetrieverTrainingArguments(TrainingArguments):
167 | warmup_ratio: float = field(default=0.1)
168 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocretriever/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from typing import List, Tuple
4 | from dataclasses import dataclass
5 | from transformers import PreTrainedTokenizer, ProcessorMixin
6 | from vdocrag.vdocretriever.arguments import DataArguments
7 | from collections import defaultdict
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | @dataclass
13 | class TrainCollator:
14 | data_args: DataArguments
15 | tokenizer: PreTrainedTokenizer
16 | processor: ProcessorMixin
17 |
18 | def build_image_attention_mask(self, seq_len, input_lengths):
19 | image_attention_masks = []
20 | for input_len in input_lengths:
21 | image_attention_mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0)
22 | image_attention_mask[input_len:, :input_len-1] = 0
23 | image_attention_masks.append(image_attention_mask.unsqueeze(0))
24 | image_attention_masks = torch.cat(image_attention_masks, dim=0)
25 | return image_attention_masks
26 |
27 | def __call__(self, features: List[Tuple[str, List[str]]]):
28 | all_queries = [f[0] for f in features]
29 | all_images = [f[-1] for f in features]
30 |
31 | q_collated = self.tokenizer(
32 | all_queries,
33 | padding=False,
34 | truncation=True,
35 | max_length=self.data_args.query_max_len-1 if self.data_args.append_eos_token else self.data_args.query_max_len,
36 | return_attention_mask=False,
37 | return_token_type_ids=False,
38 | add_special_tokens=True,
39 | )
40 |
41 | d_collated = {}
42 | collated_list = [self.processor("<|image_1|>\nWhat is shown in this image?", image, return_tensors="pt") for image in all_images]
43 | d_collated['input_ids'] = [d['input_ids'][0].tolist() for d in collated_list]
44 |
45 | if self.data_args.append_eos_token:
46 | q_collated['input_ids'] = [q + [self.tokenizer.eos_token_id] for q in q_collated['input_ids']]
47 | d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']]
48 |
49 | if self.data_args.pretrain:
50 | p_collated = {}
51 | all_input_ids, all_label_ids, input_lengths = [], [], []
52 |
53 | for i, ocr in enumerate(all_queries):
54 | prompt_input_ids = torch.tensor(d_collated['input_ids'][i]).unsqueeze(0)
55 | answer = f'{ocr}<|end|>\n<|endoftext|>'
56 | answer_input_ids = self.tokenizer(
57 | answer, add_special_tokens=False, max_length=self.data_args.answer_max_len, truncation=True, return_tensors='pt')['input_ids']
58 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1)
59 | labels = torch.cat(
60 | [
61 | torch.tensor([-100] * len(prompt_input_ids[0])).unsqueeze(0),
62 | answer_input_ids,
63 | ],
64 | dim=1,
65 | )
66 | all_input_ids.append(input_ids.squeeze(0).unsqueeze(1))
67 | all_label_ids.append(labels.squeeze(0).unsqueeze(1))
68 | input_lengths.append(len(prompt_input_ids[0]))
69 |
70 |             input_ids = torch.nn.utils.rnn.pad_sequence(
71 | all_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
72 | ).squeeze(2)
73 |             labels = torch.nn.utils.rnn.pad_sequence(
74 | all_label_ids, batch_first=True, padding_value=-100
75 | ).squeeze(2)
76 |
77 | p_collated['input_ids'] = input_ids
78 | p_collated['labels'] = labels
79 |
80 | if self.data_args.image_attention_mask:
81 | image_attention_mask = self.build_image_attention_mask(input_ids.size()[1], input_lengths)
82 | p_collated['attention_mask'] = image_attention_mask.unsqueeze(1)
83 | else:
84 | p_collated = None
85 |
86 | q_collated = self.tokenizer.pad(
87 | q_collated,
88 | padding=True,
89 | pad_to_multiple_of=self.data_args.pad_to_multiple_of,
90 | return_attention_mask=True,
91 | return_tensors='pt',
92 | )
93 | d_collated = self.tokenizer.pad(
94 | d_collated,
95 | padding=True,
96 | pad_to_multiple_of=self.data_args.pad_to_multiple_of,
97 | return_attention_mask=True,
98 | return_tensors='pt',
99 | )
100 |
101 | d_collated['pixel_values'] = torch.stack([d['pixel_values'][0] for d in collated_list], dim=0)
102 | d_collated['image_sizes'] = torch.stack([d['image_sizes'][0] for d in collated_list], dim=0)
103 | if self.data_args.pretrain:
104 | p_collated['pixel_values'] = d_collated['pixel_values']
105 | p_collated['image_sizes'] = d_collated['image_sizes']
106 |
107 | return q_collated, d_collated, p_collated
108 |
109 | @dataclass
110 | class EncodeCollator:
111 | data_args: DataArguments
112 | tokenizer: PreTrainedTokenizer
113 | processor: ProcessorMixin
114 |
115 | def __call__(self, features: List[Tuple[str, str]]):
116 | text_ids = [x[0] for x in features]
117 | texts = [x[1] for x in features]
118 | images = [x[-1] for x in features]
119 |
120 | if self.data_args.encode_is_query:
121 | collated = self.tokenizer(
122 | texts,
123 | padding=False,
124 | truncation=True,
125 | max_length=self.data_args.query_max_len-1 if self.data_args.append_eos_token else self.data_args.query_max_len,
126 | return_attention_mask=False,
127 | return_token_type_ids=False,
128 | add_special_tokens=True,
129 | )
130 | else:
131 | collated = {}
132 | collated_list = [self.processor("<|image_1|>\nWhat is shown in this image?", image, return_tensors="pt") for image in images]
133 | collated['input_ids'] = [d['input_ids'][0].tolist() for d in collated_list]
134 |
135 | if self.data_args.append_eos_token:
136 | collated['input_ids'] = [x + [self.tokenizer.eos_token_id] for x in collated['input_ids']]
137 |
138 | collated = self.tokenizer.pad(
139 | collated,
140 | padding=True,
141 | pad_to_multiple_of=self.data_args.pad_to_multiple_of,
142 | return_attention_mask=True,
143 | return_tensors='pt',
144 | )
145 | if not self.data_args.encode_is_query:
146 | collated['pixel_values'] = torch.stack([d['pixel_values'][0] for d in collated_list], dim=0)
147 | collated['image_sizes'] = torch.stack([d['image_sizes'][0] for d in collated_list], dim=0)
148 |
149 | return text_ids, collated
--------------------------------------------------------------------------------
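build_image_attention_mask above produces a causal mask in which the OCR tokens appended during pre-training can attend to the final prompt/image token and to themselves (causally), but not to the earlier prompt positions. A tiny self-contained sketch of the same computation on a toy size:

import torch

def build_image_attention_mask(seq_len, input_lengths):
    # Same computation as TrainCollator.build_image_attention_mask above.
    masks = []
    for input_len in input_lengths:
        mask = torch.tril(torch.ones(seq_len, seq_len))
        # Positions after the prompt may only see the last prompt token
        # (plus themselves/earlier answer tokens), not the earlier prompt.
        mask[input_len:, :input_len - 1] = 0
        masks.append(mask.unsqueeze(0))
    return torch.cat(masks, dim=0)

print(build_image_attention_mask(seq_len=6, input_lengths=[3])[0])
# tensor([[1., 0., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 0., 0., 0.],
#         [0., 0., 1., 1., 0., 0.],
#         [0., 0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1., 1.]])
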
/src/vdocrag/vdocretriever/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from datasets import load_dataset
4 | from torch.utils.data import Dataset
5 | from PIL import Image
6 | import os
7 | import json
8 | from vdocrag.vdocretriever.arguments import DataArguments
9 |
10 | import logging
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class TrainDataset(Dataset):
15 | def __init__(self, data_args: DataArguments, trainer = None):
16 | self.data_args = data_args
17 | self.train_data = load_dataset(
18 | self.data_args.dataset_name,
19 | self.data_args.dataset_config,
20 | data_files=self.data_args.dataset_path,
21 | split=self.data_args.dataset_split,
22 | cache_dir=self.data_args.dataset_cache_dir,
23 | )
24 | if not self.data_args.pretrain:
25 | self.corpus = load_dataset(
26 | self.data_args.corpus_name,
27 | self.data_args.corpus_config,
28 | data_files=self.data_args.corpus_path,
29 | split=self.data_args.corpus_split,
30 | cache_dir=self.data_args.dataset_cache_dir,
31 | )
32 |
33 |             self.docid2idx = {}
34 |             if 'doc_id' in self.corpus.features:
35 |                 for idx, docid in enumerate(self.corpus['doc_id']):
36 |                     self.docid2idx[str(docid)] = idx
37 |             else:
38 |                 for idx in range(len(self.corpus)):
39 |                     self.docid2idx[str(idx)] = idx
40 |
41 | self.trainer = trainer
42 |
43 | def __len__(self):
44 | return len(self.train_data)
45 |
46 | def __getitem__(self, item) -> Tuple[str, List[str]]:
47 | group = self.train_data[item]
48 | epoch = int(self.trainer.state.epoch)
49 |
50 | _hashed_seed = hash(item + self.trainer.args.seed)
51 |
52 | query = group['query']
53 | if self.data_args.pretrain:
54 | image = group['image']
55 | else:
56 | relevant_docids = group['relevant_doc_ids']
57 |
58 | if self.data_args.positive_document_no_shuffle:
59 | docid = relevant_docids[0]
60 | else:
61 | docid = relevant_docids[(_hashed_seed + epoch) % len(relevant_docids)]
62 |
63 |             image = self.corpus[self.docid2idx[docid]]['image']
64 |
65 | return query, image
66 |
67 |
68 | class EncodeDataset(Dataset):
69 | def __init__(self, data_args: DataArguments):
70 | self.data_args = data_args
71 | if self.data_args.encode_is_query:
72 | self.encode_data = load_dataset(
73 | self.data_args.dataset_name,
74 | self.data_args.dataset_config,
75 | data_files=self.data_args.dataset_path,
76 | split=self.data_args.dataset_split,
77 | cache_dir=self.data_args.dataset_cache_dir,
78 | )
79 | else:
80 | self.encode_data = load_dataset(
81 | self.data_args.corpus_name,
82 | self.data_args.corpus_config,
83 | data_files=self.data_args.corpus_path,
84 | split=self.data_args.corpus_split,
85 | cache_dir=self.data_args.dataset_cache_dir,
86 | )
87 |
88 | if self.data_args.dataset_number_of_shards > 1:
89 | self.encode_data = self.encode_data.shard(
90 | num_shards=self.data_args.dataset_number_of_shards,
91 | index=self.data_args.dataset_shard_index,
92 | )
93 |
94 | def __len__(self):
95 | return len(self.encode_data)
96 |
97 | def __getitem__(self, item) -> Tuple[str, str]:
98 | data = self.encode_data[item]
99 | text, image = None, None
100 | if self.data_args.encode_is_query:
101 | id = data['query_id']
102 | text = data['query']
103 | else:
104 | id = data['doc_id']
105 | image = data['image']
106 | return id, text, image
107 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocretriever/driver/encode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 |
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoTokenizer, AutoProcessor
14 | from transformers import (
15 | HfArgumentParser,
16 | )
17 |
18 | from vdocrag.vdocretriever.arguments import ModelArguments, DataArguments, \
19 | VDocRetrieverTrainingArguments as TrainingArguments
20 | from vdocrag.vdocretriever.dataset import EncodeDataset
21 | from vdocrag.vdocretriever.collator import EncodeCollator
22 | from vdocrag.vdocretriever.modeling import EncoderOutput, VDocRetriever
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main():
28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
31 | else:
32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
33 | model_args: ModelArguments
34 | data_args: DataArguments
35 | training_args: TrainingArguments
36 |
37 | if training_args.local_rank > 0 or training_args.n_gpu > 1:
38 | raise NotImplementedError('Multi-GPU encoding is not supported.')
39 |
40 | # Setup logging
41 | logging.basicConfig(
42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
43 | datefmt="%m/%d/%Y %H:%M:%S",
44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
45 | )
46 |
47 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
48 | cache_dir=model_args.cache_dir,
49 | trust_remote_code=True,)
50 | tokenizer = processor.tokenizer
51 |
52 | if tokenizer.pad_token_id is None:
53 | tokenizer.pad_token_id = tokenizer.eos_token_id
54 | tokenizer.padding_side = 'right'
55 |
56 | if training_args.bf16:
57 | torch_dtype = torch.bfloat16
58 | elif training_args.fp16:
59 | torch_dtype = torch.float16
60 | else:
61 | torch_dtype = torch.float32
62 |
63 | model = VDocRetriever.load(
64 | model_args.model_name_or_path,
65 | pooling=model_args.pooling,
66 | normalize=model_args.normalize,
67 | lora_name_or_path=model_args.lora_name_or_path,
68 | trust_remote_code=True,
69 | cache_dir=model_args.cache_dir,
70 | torch_dtype=torch_dtype,
71 | _attn_implementation='flash_attention_2',
72 | )
73 |
74 | encode_dataset = EncodeDataset(
75 | data_args=data_args,
76 | )
77 |
78 | encode_collator = EncodeCollator(
79 | data_args=data_args,
80 | tokenizer=tokenizer,
81 | processor=processor,
82 | )
83 |
84 | encode_loader = DataLoader(
85 | encode_dataset,
86 | batch_size=training_args.per_device_eval_batch_size,
87 | collate_fn=encode_collator,
88 | shuffle=False,
89 | drop_last=False,
90 | num_workers=training_args.dataloader_num_workers,
91 | )
92 | encoded = []
93 | lookup_indices = []
94 | model = model.to(training_args.device)
95 | model.eval()
96 |
97 | for (batch_ids, batch) in tqdm(encode_loader):
98 | lookup_indices.extend(batch_ids)
99 | with nullcontext():
100 | with torch.no_grad():
101 | for k, v in batch.items():
102 | batch[k] = v.to(training_args.device)
103 | if data_args.encode_is_query:
104 | model_output: EncoderOutput = model(query=batch, use_cache=False)
105 | encoded.append(model_output.q_reps.cpu().detach().float().numpy())
106 | else:
107 | model_output: EncoderOutput = model(document=batch, use_cache=False)
108 | encoded.append(model_output.p_reps.cpu().detach().float().numpy())
109 |
110 | encoded = np.concatenate(encoded)
111 | if not os.path.exists(os.path.dirname(data_args.encode_output_path)):
112 | os.makedirs(os.path.dirname(data_args.encode_output_path))
113 | with open(data_args.encode_output_path, 'wb') as f:
114 | pickle.dump((encoded, lookup_indices), f)
115 |
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
--------------------------------------------------------------------------------
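Each invocation of encode.py writes one pickle shard containing the stacked embeddings and their ids (doc_ids or query_ids). A short sketch of inspecting such a shard; the file name is a placeholder:

import pickle
import numpy as np

with open("corpus_emb.0.pkl", "rb") as f:   # placeholder shard path
    reps, lookup = pickle.load(f)

reps = np.asarray(reps)
print(reps.shape, reps.dtype)     # (num_docs_or_queries, hidden_dim), float32
print(lookup[:3])                 # the corresponding doc_ids / query_ids
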
/src/vdocrag/vdocretriever/driver/search.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import numpy as np
4 | import glob
5 | from argparse import ArgumentParser
6 | from itertools import chain
7 | from tqdm import tqdm
8 |
9 | from vdocrag.vdocretriever.searcher import FaissFlatSearcher
10 |
11 | import logging
12 | logger = logging.getLogger(__name__)
13 | logging.basicConfig(
14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
15 | datefmt="%m/%d/%Y %H:%M:%S",
16 | level=logging.INFO,
17 | )
18 |
19 |
20 | def search_queries(retriever, q_reps, p_lookup, args):
21 | if args.batch_size > 0:
22 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size, args.quiet)
23 | else:
24 | all_scores, all_indices = retriever.search(q_reps, args.depth)
25 |
26 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices]
27 | psg_indices = np.array(psg_indices)
28 | return all_scores, psg_indices
29 |
30 |
31 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file):
32 | with open(ranking_save_file, 'w') as f:
33 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices):
34 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
35 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
36 | for s, idx in score_list:
37 | f.write(f'{qid}\t{idx}\t{s}\n')
38 |
39 |
40 | def pickle_load(path):
41 | with open(path, 'rb') as f:
42 | reps, lookup = pickle.load(f)
43 | return np.array(reps), lookup
44 |
45 |
46 | def pickle_save(obj, path):
47 | with open(path, 'wb') as f:
48 | pickle.dump(obj, f)
49 |
50 |
51 | def main():
52 | parser = ArgumentParser()
53 | parser.add_argument('--query_reps', required=True)
54 | parser.add_argument('--document_reps', required=True)
55 | parser.add_argument('--batch_size', type=int, default=128)
56 | parser.add_argument('--depth', type=int, default=1000)
57 | parser.add_argument('--save_ranking_to', required=True)
58 | parser.add_argument('--save_text', action='store_true')
59 | parser.add_argument('--quiet', action='store_true')
60 |
61 | args = parser.parse_args()
62 |
63 | index_files = glob.glob(args.document_reps)
64 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')
65 |
66 | p_reps_0, p_lookup_0 = pickle_load(index_files[0])
67 | retriever = FaissFlatSearcher(p_reps_0)
68 |
69 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
70 | if len(index_files) > 1:
71 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files))
72 | look_up = []
73 | for p_reps, p_lookup in shards:
74 | retriever.add(p_reps)
75 | look_up += p_lookup
76 |
77 | q_reps, q_lookup = pickle_load(args.query_reps)
78 |
79 |
80 | logger.info('Index Search Start')
81 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args)
82 | logger.info('Index Search Finished')
83 |
84 | if args.save_text:
85 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to)
86 | else:
87 | pickle_save((all_scores, psg_indices), args.save_ranking_to)
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
92 |
--------------------------------------------------------------------------------
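FaissFlatSearcher is used here only through its constructor, add(), and search()/batch_search(). The sketch below exercises that same interface on random vectors (shapes and dtype are assumptions; FAISS expects float32) and writes the tab-separated "query_id<TAB>doc_id<TAB>score" ranking that the generator-side datasets and convert_result_to_trec.py later consume:

import numpy as np
from vdocrag.vdocretriever.searcher import FaissFlatSearcher

rng = np.random.default_rng(0)
doc_reps = rng.standard_normal((100, 32), dtype=np.float32)   # toy document embeddings
query_reps = rng.standard_normal((4, 32), dtype=np.float32)   # toy query embeddings
doc_ids = [f"doc{i}" for i in range(100)]

retriever = FaissFlatSearcher(doc_reps)
scores, indices = retriever.search(query_reps, 5)   # top-5 per query

with open("toy_run.txt", "w") as f:
    for qid, q_scores, q_indices in zip(["q0", "q1", "q2", "q3"], scores, indices):
        for s, idx in sorted(zip(q_scores, q_indices), reverse=True):
            f.write(f"{qid}\t{doc_ids[idx]}\t{s}\n")
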
/src/vdocrag/vdocretriever/driver/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import torch
5 | import wandb
6 |
7 | from transformers import AutoTokenizer
8 | from transformers import AutoProcessor
9 |
10 | from transformers import (
11 | HfArgumentParser,
12 | set_seed,
13 | )
14 |
15 | from vdocrag.vdocretriever.arguments import ModelArguments, DataArguments, \
16 | VDocRetrieverTrainingArguments as TrainingArguments
17 | from vdocrag.vdocretriever.dataset import TrainDataset
18 | from vdocrag.vdocretriever.collator import TrainCollator
19 | from vdocrag.vdocretriever.modeling import VDocRetriever
20 | from vdocrag.vdocretriever.trainer import VDocRetrieverTrainer as Trainer
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def main():
26 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
27 |
28 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
29 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
30 | else:
31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
32 | model_args: ModelArguments
33 | data_args: DataArguments
34 | training_args: TrainingArguments
35 |
36 | if (
37 | os.path.exists(training_args.output_dir)
38 | and os.listdir(training_args.output_dir)
39 | and training_args.do_train
40 | and not training_args.overwrite_output_dir
41 | ):
42 | raise ValueError(
43 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
44 | )
45 |
46 | # Setup logging
47 | logging.basicConfig(
48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
49 | datefmt="%m/%d/%Y %H:%M:%S",
50 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
51 | )
52 | logger.warning(
53 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
54 | training_args.local_rank,
55 | training_args.device,
56 | training_args.n_gpu,
57 | bool(training_args.local_rank != -1),
58 | training_args.fp16,
59 | )
60 | logger.info("Training/evaluation parameters %s", training_args)
61 | logger.info("MODEL parameters %s", model_args)
62 |
63 | set_seed(training_args.seed)
64 |
65 |     processor = AutoProcessor.from_pretrained(
66 |         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
67 |         cache_dir=model_args.cache_dir, trust_remote_code=True)
68 | tokenizer = processor.tokenizer
69 |
70 | if training_args.bf16:
71 | torch_dtype = torch.bfloat16
72 | elif training_args.fp16:
73 | torch_dtype = torch.float16
74 | else:
75 | torch_dtype = torch.float32
76 |
77 | model = VDocRetriever.build(
78 | model_args,
79 | training_args,
80 | cache_dir=model_args.cache_dir,
81 | trust_remote_code=True,
82 | torch_dtype=torch_dtype,
83 | _attn_implementation='eager' if data_args.pretrain else 'flash_attention_2',
84 | )
85 |
86 | train_dataset = TrainDataset(data_args)
87 | collator = TrainCollator(data_args, tokenizer, processor)
88 |
89 | trainer_cls = Trainer
90 |
91 | trainer = trainer_cls(
92 | model=model,
93 | args=training_args,
94 | train_dataset=train_dataset,
95 | data_collator=collator
96 | )
97 | train_dataset.trainer = trainer
98 |
99 | trainer.train() # TODO: resume training
100 | trainer.save_model()
101 | if trainer.is_world_process_zero():
102 | tokenizer.save_pretrained(training_args.output_dir)
103 |
104 |
105 | if __name__ == "__main__":
106 | main()
107 |
--------------------------------------------------------------------------------
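trainer.train() above leaves resumption as a TODO. With the underlying Hugging Face Trainer this is a one-argument change; the sketch below uses the real transformers helper get_last_checkpoint to locate the newest checkpoint-* directory (the directory layout is fabricated purely for illustration):

    import os
    import tempfile
    from transformers.trainer_utils import get_last_checkpoint

    # fabricate an output directory with two checkpoints to show what the helper returns
    output_dir = tempfile.mkdtemp()
    for step in (500, 1000):
        os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))

    last_checkpoint = get_last_checkpoint(output_dir)
    print(last_checkpoint)  # .../checkpoint-1000 (highest step), or None if no checkpoint exists

    # inside main() this would become:
    #   trainer.train(resume_from_checkpoint=last_checkpoint)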
/src/vdocrag/vdocretriever/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .vdocretriever import VDocRetriever, EncoderOutput
2 |
--------------------------------------------------------------------------------
/src/vdocrag/vdocretriever/modeling/vdocretriever.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | import torch.distributed as dist
6 | from torch import nn, Tensor
7 |
8 | from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq
9 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel
10 |
11 | from transformers.file_utils import ModelOutput
12 | from vdocrag.vdocretriever.arguments import ModelArguments, VDocRetrieverTrainingArguments as TrainingArguments
13 |
14 | import logging
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | @dataclass
19 | class EncoderOutput(ModelOutput):
20 | q_reps: Optional[Tensor] = None
21 | p_reps: Optional[Tensor] = None
22 | loss: Optional[Tensor] = None
23 | scores: Optional[Tensor] = None
24 |
25 | class VDocRetriever(nn.Module):
26 | TRANSFORMER_CLS = AutoModelForCausalLM
27 |
28 | def __init__(self,
29 | encoder: PreTrainedModel,
30 | pooling: str = 'cls',
31 | normalize: bool = False,
32 | temperature: float = 1.0,
33 | ):
34 | super().__init__()
35 | self.config = encoder.config
36 | self.encoder = encoder
37 | self.pooling = pooling
38 | self.normalize = normalize
39 | self.temperature = temperature
40 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
41 | self.is_ddp = dist.is_initialized()
42 | if self.is_ddp:
43 | self.process_rank = dist.get_rank()
44 | self.world_size = dist.get_world_size()
45 |
46 | def forward(self, query: Dict[str, Tensor] = None,
47 | document: Dict[str, Tensor] = None,
48 | pair: Dict[str, Tensor] = None,
49 | use_cache: bool = True
50 | ):
51 | q_reps = self.encode_query(query, use_cache=use_cache) if query else None
52 | p_reps = self.encode_document(document, use_cache=use_cache) if document else None
53 |         outputs = self.generate_output(pair, use_cache=use_cache) if pair else None # pre-training: pair batches carry labels, so the base LM returns a language-modeling loss
54 |
55 | # for inference
56 | if q_reps is None or p_reps is None:
57 | return EncoderOutput(
58 | q_reps=q_reps,
59 | p_reps=p_reps
60 | )
61 |
62 | # for training
63 | if self.training:
64 | if self.is_ddp:
65 | q_reps = self._dist_gather_tensor(q_reps)
66 | p_reps = self._dist_gather_tensor(p_reps)
67 |
68 | scores = self.compute_similarity(q_reps, p_reps)
69 | scores = scores.view(q_reps.size(0), -1)
70 |
71 | target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
72 | target = target * (p_reps.size(0) // q_reps.size(0))
73 |
74 | loss = self.compute_loss(scores / self.temperature, target)
75 |
76 | if outputs:
77 | loss = loss + outputs.loss
78 |
79 | if self.is_ddp:
80 |                 loss = loss * self.world_size # counteract the per-process mean reduction applied under DDP
81 |
82 | # for eval
83 | else:
84 | scores = self.compute_similarity(q_reps, p_reps)
85 | loss = None
86 |
87 | return EncoderOutput(
88 | loss=loss,
89 | scores=scores,
90 | q_reps=q_reps,
91 | p_reps=p_reps,
92 | )
93 |
94 | def compute_similarity(self, q_reps, p_reps):
95 | return torch.matmul(q_reps, p_reps.transpose(0, 1))
96 |
97 | def compute_loss(self, scores, target):
98 | return self.cross_entropy(scores, target)
99 |
100 | def gradient_checkpointing_enable(self, **kwargs):
101 | self.encoder.model.gradient_checkpointing_enable()
102 |
103 | def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
104 | if t is None:
105 | return None
106 | t = t.contiguous()
107 |
108 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
109 | dist.all_gather(all_tensors, t)
110 |
111 | all_tensors[self.process_rank] = t
112 | all_tensors = torch.cat(all_tensors, dim=0)
113 |
114 | return all_tensors
115 |
116 | @classmethod
117 | def build(
118 | cls,
119 | model_args: ModelArguments,
120 | train_args: TrainingArguments,
121 | **hf_kwargs,
122 | ):
123 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
124 | if base_model.config.pad_token_id is None:
125 | base_model.config.pad_token_id = 0
126 |
127 | if model_args.lora or model_args.lora_name_or_path:
128 | if train_args.gradient_checkpointing:
129 | base_model.enable_input_require_grads()
130 | if model_args.lora_name_or_path:
131 | lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
132 | lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, is_trainable=True)
133 | else:
134 | lora_config = LoraConfig(
135 | base_model_name_or_path=model_args.model_name_or_path,
136 | task_type=TaskType.FEATURE_EXTRACTION,
137 | r=model_args.lora_r,
138 | lora_alpha=model_args.lora_alpha,
139 | lora_dropout=model_args.lora_dropout,
140 | target_modules=model_args.lora_target_modules.split(','),
141 | inference_mode=False
142 | )
143 | lora_model = get_peft_model(base_model, lora_config)
144 | model = cls(
145 | encoder=lora_model,
146 | pooling=model_args.pooling,
147 | normalize=model_args.normalize,
148 | temperature=model_args.temperature,
149 | )
150 | else:
151 | model = cls(
152 | encoder=base_model,
153 | pooling=model_args.pooling,
154 | normalize=model_args.normalize,
155 | temperature=model_args.temperature
156 | )
157 | return model
158 |
159 | @classmethod
160 | def load(cls,
161 | model_name_or_path: str,
162 | pooling: str = 'cls',
163 | normalize: bool = False,
164 | lora_name_or_path: str = None,
165 | **hf_kwargs):
166 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
167 | if base_model.config.pad_token_id is None:
168 | base_model.config.pad_token_id = 0
169 | if lora_name_or_path:
170 | lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
171 | lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
172 | lora_model = lora_model.merge_and_unload()
173 | model = cls(
174 | encoder=lora_model,
175 | pooling=pooling,
176 | normalize=normalize
177 | )
178 | else:
179 | model = cls(
180 | encoder=base_model,
181 | pooling=pooling,
182 | normalize=normalize
183 | )
184 | return model
185 |
186 | def save(self, output_dir: str):
187 | self.encoder.save_pretrained(output_dir)
188 |
189 | def encode_query(self, qry, use_cache=True):
190 | query_hidden_states = self.encoder(**qry, return_dict=True, output_hidden_states=True, output_attentions=True, use_cache=use_cache)
191 | query_hidden_states = query_hidden_states.hidden_states[-1]
192 | return self._pooling(query_hidden_states, qry['attention_mask'])
193 |
194 | def encode_document(self, doc, use_cache=True):
195 | return self.encode_query(doc, use_cache=use_cache)
196 |
197 | def generate_output(self, pair, use_cache=True):
198 | return self.encoder(**pair, use_cache=use_cache)
199 |
200 | def _pooling(self, last_hidden_state, attention_mask):
201 | if self.pooling in ['cls', 'first']:
202 | reps = last_hidden_state[:, 0]
203 | elif self.pooling in ['mean', 'avg', 'average']:
204 | masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
205 | reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
206 | elif self.pooling in ['last', 'eos']:
207 | sequence_lengths = attention_mask.sum(dim=1) - 1
208 | batch_size = last_hidden_state.shape[0]
209 | reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
210 | else:
211 | raise ValueError(f'unknown pooling method: {self.pooling}')
212 | if self.normalize:
213 | reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
214 | return reps
--------------------------------------------------------------------------------
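The training branch of forward() scores every query against every in-batch document and uses the stride p_reps.size(0) // q_reps.size(0) as the documents-per-query group size, so the positive for query i sits at column i * group_size. A toy check of that target computation; the shapes are illustrative only:

    import torch

    num_queries, docs_per_query, dim = 3, 2, 4   # e.g. one positive + one negative per query
    q_reps = torch.randn(num_queries, dim)
    p_reps = torch.randn(num_queries * docs_per_query, dim)

    scores = q_reps @ p_reps.T                   # (3, 6): every query vs. every in-batch document
    target = torch.arange(scores.size(0))
    target = target * (p_reps.size(0) // q_reps.size(0))
    print(target)                                # tensor([0, 2, 4]) -> column of each query's positive

Cross-entropy over each row of scores then pushes the score at the target column above the remaining in-batch documents.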
/src/vdocrag/vdocretriever/searcher.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 |
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class FaissFlatSearcher:
12 | def __init__(self, init_reps: np.ndarray):
13 | index = faiss.IndexFlatIP(init_reps.shape[1])
14 | self.index = index
15 |
16 | def add(self, p_reps: np.ndarray):
17 | self.index.add(p_reps)
18 |
19 | def search(self, q_reps: np.ndarray, k: int):
20 | return self.index.search(q_reps, k)
21 |
22 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int, quiet: bool=False):
23 | num_query = q_reps.shape[0]
24 | all_scores = []
25 | all_indices = []
26 | for start_idx in tqdm(range(0, num_query, batch_size), disable=quiet):
27 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k)
28 | all_scores.append(nn_scores)
29 | all_indices.append(nn_indices)
30 | all_scores = np.concatenate(all_scores, axis=0)
31 | all_indices = np.concatenate(all_indices, axis=0)
32 |
33 | return all_scores, all_indices
34 |
35 |
36 | class FaissSearcher(FaissFlatSearcher):
37 |
38 | def __init__(self, init_reps: np.ndarray, factory_str: str):
39 |         index = faiss.index_factory(init_reps.shape[1], factory_str)  # note: index_factory defaults to the L2 metric; pass faiss.METRIC_INNER_PRODUCT to match FaissFlatSearcher
40 | self.index = index
41 | self.index.verbose = True
42 | if not self.index.is_trained:
43 | self.index.train(init_reps)
44 |
--------------------------------------------------------------------------------
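FaissFlatSearcher wraps an exact inner-product index, so with L2-normalized embeddings (normalize=True in the retriever) the returned scores are cosine similarities. A self-contained toy run, assuming faiss and the package are installed (pip install -e .); the data is random and purely illustrative:

    import numpy as np
    from vdocrag.vdocretriever.searcher import FaissFlatSearcher

    rng = np.random.default_rng(0)
    docs = rng.standard_normal((8, 16)).astype(np.float32)
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)   # faiss needs float32; normalize for cosine

    searcher = FaissFlatSearcher(docs)   # the constructor only fixes the dimensionality
    searcher.add(docs)                   # vectors are added explicitly, as in the search driver
    scores, indices = searcher.batch_search(docs[:2], k=3, batch_size=1, quiet=True)
    print(indices[:, 0])                 # [0 1]: each query retrieves itself first (score ~1.0)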
/src/vdocrag/vdocretriever/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | import torch
5 |
6 | from transformers.trainer import Trainer, TRAINING_ARGS_NAME
7 | import torch.distributed as dist
8 | from .modeling import VDocRetriever
9 | from huggingface_hub import login
10 |
11 | import logging
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class VDocRetrieverTrainer(Trainer):
16 | def __init__(self, *args, **kwargs):
17 | super(VDocRetrieverTrainer, self).__init__(*args, **kwargs)
18 | self.is_ddp = dist.is_initialized()
19 | self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
20 |
21 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
22 | # If we are executing this function, we are the process zero, so we don't check for that.
23 | output_dir = output_dir if output_dir is not None else self.args.output_dir
24 | os.makedirs(output_dir, exist_ok=True)
25 | logger.info(f"Saving model checkpoint to {output_dir}")
26 |
27 | supported_classes = (VDocRetriever,)
28 | # Save a trained model and configuration using `save_pretrained()`.
29 | # They can then be reloaded using `from_pretrained()`
30 | if not isinstance(self.model, supported_classes):
31 | raise ValueError(f"Unsupported model class {self.model}")
32 | else:
33 | if state_dict is None:
34 | state_dict = self.model.state_dict()
35 | prefix = 'encoder.'
36 | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
37 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
38 | self.model.encoder.save_pretrained(
39 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
40 | )
41 |
42 | if self.tokenizer is not None:
43 | self.tokenizer.save_pretrained(output_dir)
44 |
45 | # Good practice: save your training arguments together with the trained model
46 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
47 |
48 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
49 | query, document, pair = inputs
50 | return model(query=query, document=document, pair=pair).loss
51 |
52 | def training_step(self, *args):
53 | return super(VDocRetrieverTrainer, self).training_step(*args) / self._dist_loss_scale_factor
54 |
55 |
--------------------------------------------------------------------------------
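_save() strips the 'encoder.' prefix so the checkpoint can later be reloaded with from_pretrained on the wrapped encoder itself. A toy illustration of that key rewriting; the keys below are made up:

    # made-up keys standing in for a real state_dict
    state_dict = {
        "encoder.model.embed_tokens.weight": "...",
        "encoder.lm_head.weight": "...",
    }
    prefix = "encoder."
    assert all(k.startswith(prefix) for k in state_dict)
    stripped = {k[len(prefix):]: v for k, v in state_dict.items()}
    print(list(stripped))  # ['model.embed_tokens.weight', 'lm_head.weight']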
/test.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import requests
3 | from io import BytesIO
4 | from torch.nn.functional import cosine_similarity
5 | import torch
6 | from transformers import AutoProcessor
7 | from vdocrag.vdocretriever.modeling import VDocRetriever
8 | from vdocrag.vdocgenerator.modeling import VDocGenerator
9 |
10 | ### Retrieval ###
11 |
12 | processor = AutoProcessor.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True)
13 | model = VDocRetriever.load('microsoft/Phi-3-vision-128k-instruct', lora_name_or_path='NTT-hil-insight/VDocRetriever-Phi3-vision', pooling='eos', normalize=True, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')
14 |
15 | # Process query inputs and get embeddings
16 | queries = ["Instruct: I’m looking for an image that answers the question.\nQuery: What is the total percentage of Palestinians residing at West Bank?",
17 | "Instruct: I’m looking for an image that answers the question.\nQuery: How many international visitors came to Japan in 2017?"]
18 | query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=256, truncation=True).to('cuda:0')
19 |
20 | with torch.no_grad():
21 | model_output = model(query=query_inputs, use_cache=False)
22 | query_embeddings = model_output.q_reps
23 |
24 | # List of image URLs
25 | urls = [
26 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image1.png",
27 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image2.png"
28 | ]
29 |
30 | # Download, open, and resize images
31 | doc_images = [Image.open(BytesIO(requests.get(url).content)).resize((1344, 1344)) for url in urls]
32 |
33 | # Process images with prompt
34 | doc_prompt = "<|image_1|>\nWhat is shown in this image?"
35 | collated_list = [
36 | processor(
37 | doc_prompt,
38 | images=image,
39 | return_tensors="pt",
40 | padding="longest",
41 | max_length=4096,
42 | truncation=True
43 | ).to('cuda:0') for image in doc_images
44 | ]
45 |
46 | # Stack tensors into input dict
47 | doc_inputs = {
48 | key: torch.stack([item[key][0] for item in collated_list], dim=0)
49 | for key in ['input_ids', 'attention_mask', 'pixel_values', 'image_sizes']
50 | }
51 |
52 | with torch.no_grad():
53 | model_output = model(document=doc_inputs, use_cache=False)
54 | doc_embeddings = model_output.p_reps
55 |
56 | # Calculate cosine similarity
57 | num_queries = query_embeddings.size(0)
58 | num_passages = doc_embeddings.size(0)
59 |
60 | for i in range(num_queries):
61 | query_embedding = query_embeddings[i].unsqueeze(0)
62 | similarities = cosine_similarity(query_embedding, doc_embeddings)
63 | print(f"Similarities for Query {i}: {similarities.cpu().float().numpy()}")
64 |
65 | # >> Similarities for Query 0: [0.5078125 0.38085938]
66 | # Similarities for Query 1: [0.37695312 0.5703125 ]
67 |
68 | ### Generation ###
69 |
70 | model = VDocGenerator.load('microsoft/Phi-3-vision-128k-instruct', lora_name_or_path='NTT-hil-insight/VDocGenerator-Phi3-vision', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')
71 |
72 | query = "How many international visitors came to Japan in 2017? \n Answer briefly."
73 |
74 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(doc_images))])
75 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}]
76 | prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
77 |
78 | processed = processor(prompt, images=doc_images, return_tensors="pt").to('cuda:0')
79 | generate_ids = model.generate(processed, generation_args={"max_new_tokens": 64, "temperature": 0.0, "do_sample": False, "eos_token_id": processor.tokenizer.eos_token_id})
80 | generate_ids = generate_ids[:, processed['input_ids'].shape[1]:]
81 | response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
82 | response = response.strip()
83 | print("Model prediction: {0}".format(response))
84 |
85 | ## >> Model prediction: 28.69m
--------------------------------------------------------------------------------
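In a full retrieve-then-generate loop, the similarity scores computed in the retrieval half of test.py would be used to pick the top-k pages handed to VDocGenerator. A minimal sketch of that selection step with toy embeddings; the names and sizes are illustrative, not part of the repository:

    import torch
    from torch.nn.functional import normalize

    # stand-ins for query_embeddings / doc_embeddings computed above
    query_embeddings = normalize(torch.randn(2, 8), dim=-1)
    doc_embeddings = normalize(torch.randn(4, 8), dim=-1)

    scores = query_embeddings @ doc_embeddings.T   # cosine similarity (both sides L2-normalized)
    topk = scores.topk(k=2, dim=-1)                # top-2 candidate pages per query
    for qi, page_ids in enumerate(topk.indices.tolist()):
        print(f"Query {qi}: feed pages {page_ids} to the generator")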