├── LICENSE ├── README.md ├── deepspeed └── ds_zero3_config.json ├── images ├── abst.png ├── image1.png └── image2.png ├── requirements.txt ├── scripts ├── create_test_chartqa_generator.sh ├── create_test_dude_generator.sh ├── create_test_infovqa_generator.sh ├── create_test_slidevqa_generator.sh └── create_train_generator.sh ├── setup.py ├── src └── vdocrag │ ├── __init__.py │ ├── utils │ ├── __init__.py │ ├── eval_opendocvqa.py │ └── format │ │ ├── __init__.py │ │ ├── convert_qas_to_trec_qrels.py │ │ └── convert_result_to_trec.py │ ├── vdocgenerator │ ├── __init__.py │ ├── arguments.py │ ├── collator.py │ ├── dataset.py │ ├── driver │ │ ├── generate.py │ │ └── train.py │ ├── modeling │ │ ├── __init__.py │ │ └── vdocgenerator.py │ └── trainer.py │ └── vdocretriever │ ├── __init__.py │ ├── arguments.py │ ├── collator.py │ ├── dataset.py │ ├── driver │ ├── encode.py │ ├── search.py │ └── train.py │ ├── modeling │ ├── __init__.py │ └── vdocretriever.py │ ├── searcher.py │ └── trainer.py └── test.py /LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. 
User shall be solely responsible for proper installation of the Software. 14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT. 15 | 4. Proprietary Rights 16 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 19 | 5.  Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 21 | 7. Limitation of Liability. 
IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARD¬LESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 22 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 23 | 9. General 24 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 25 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 26 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 29 | (f)   NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control. 
30 |   31 | EXHIBIT A 32 | The software and related data include the following files, 33 | ├── LICENSE 34 | ├── README.md 35 | ├── deepspeed 36 | │ └── ds_zero3_config.json 37 | ├── images 38 | │ ├── abst.png 39 | │ ├── image1.png 40 | │ └── image2.png 41 | ├── requirements.txt 42 | ├── scripts 43 | │ ├── create_test_chartqa_generator.sh 44 | │ ├── create_test_dude_generator.sh 45 | │ ├── create_test_infovqa_generator.sh 46 | │ ├── create_test_slidevqa_generator.sh 47 | │ └── create_train_generator.sh 48 | ├── setup.py 49 | ├── src 50 | │ └── vdocrag 51 | │ ├── __init__.py 52 | │ ├── utils 53 | │ │ ├── __init__.py 54 | │ │ ├── eval_opendocvqa.py 55 | │ │ └── format 56 | │ │ ├── __init__.py 57 | │ │ ├── convert_qas_to_trec_qrels.py 58 | │ │ └── convert_result_to_trec.py 59 | │ ├── vdocgenerator 60 | │ │ ├── __init__.py 61 | │ │ ├── arguments.py 62 | │ │ ├── collator.py 63 | │ │ ├── dataset.py 64 | │ │ ├── driver 65 | │ │ │ ├── generate.py 66 | │ │ │ └── train.py 67 | │ │ ├── modeling 68 | │ │ │ ├── __init__.py 69 | │ │ │ └── vdocgenerator.py 70 | │ │ └── trainer.py 71 | │ └── vdocretriever 72 | │ ├── __init__.py 73 | │ ├── arguments.py 74 | │ ├── collator.py 75 | │ ├── dataset.py 76 | │ ├── driver 77 | │ │ ├── encode.py 78 | │ │ ├── search.py 79 | │ │ └── train.py 80 | │ ├── modeling 81 | │ │ ├── __init__.py 82 | │ │ └── vdocretriever.py 83 | │ ├── searcher.py 84 | │ └── trainer.py 85 | └── test.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents 4 | 5 | [![Project Page](https://img.shields.io/badge/VDocRAG-Website-green?logo=googlechrome&logoColor=green)](https://vdocrag.github.io/) 6 | [![VDocRAG Paper](https://img.shields.io/badge/VDocRAG-arXiv%202025-b31b1b?logo=arxiv&logoColor=red)](http://arxiv.org/abs/2504.09795) 7 | [![Model (Retriever)](https://img.shields.io/badge/%F0%9F%A4%97%20Model-Retriever-yellow)](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision) 8 | [![Model (Generator)](https://img.shields.io/badge/%F0%9F%A4%97%20Model-Generator-yellow)](https://huggingface.co/NTT-hil-insight/VDocGenerator-Phi3-vision) 9 | [![Dataset (QA)](https://img.shields.io/badge/%F0%9F%A4%97%20Dataset-QA-yellow)](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA) 10 | [![Dataset (Corpus)](https://img.shields.io/badge/%F0%9F%A4%97%20Dataset-Corpus-yellow)](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA-Corpus) 11 | [![Conference](https://img.shields.io/badge/CVPR-2025-blue)](https://cvpr.thecvf.com/) 12 |
13 | 14 | This repository includes VDocRAG, introduced in the following paper: Ryota Tanaka, Taichi Iki, Taku Hasegawa, Kyosuke Nishida, Kuniko Saito, and Jun Suzuki. [VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents](http://arxiv.org/abs/2504.09795). In Proc. of CVPR 2025. 15 | 16 | 17 | **VDocRAG** is a new RAG framework that can directly understand diverse real-world documents purely from visual features. 18 | 19 | **💪 Key Enhancements of VDocRAG:** 20 | - **New Pretraining Tasks:** We propose novel self-supervised pre-training tasks (**RCR** and **RCG**) that adapt large vision-language models for retrieval by compressing visual information into dense token representations while aligning them with textual content in documents. 21 | - **New Dataset:** We introduce **OpenDocVQA**, the first unified collection of open-domain document visual question answering datasets, encompassing diverse document types and formats. 22 | 23 |
24 | ![VDocRAG](images/abst.png) 25 |
26 | 27 | # 📌Contents 28 | - [News](#news) 29 | - [Installation](#installation) 30 | - [Quick Start](#quick_start) 31 | - [Dataset](#dataset) 32 | - [Retriever](#retriever) 33 | - [Generator](#generator) 34 | - [LICENSE](#license) 35 | - [Citation](#citation) 36 | - [Acknowledgement](#acknowledgement) 37 | 38 | 39 | 40 | # 📢 News 41 | - [2025/04]: The pre-trained VDocRetriever weights are out in 🤗 [Huggingface Hub](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision-pretrained). 42 | - [2025/04]: The technical report, code, data, and model for VDocRAG are all available online. 43 | - [2025/02]: 🎉 VDocRAG is accepted to CVPR 2025. 44 | 45 | 46 | # ⚙️ Installation 47 | 1. Clone the repository. 48 | 2. Install PyTorch based on your CUDA version from PyTorch. 49 | 3. Install dependencies and VDocRAG. 50 | ```bash 51 | pip install -r requirements.txt 52 | pip install -e . 53 | ``` 54 | 55 | 56 | # ⚡️ Quick Start 57 | You can download [VDocRetriever](https://huggingface.co/NTT-hil-insight/VDocRetriever-Phi3-vision) and [VDocGenerator](https://huggingface.co/NTT-hil-insight/VDocGenerator-Phi3-vision) from 🤗 HuggingFace Hub. To get started, first import the libraries as shown below: 58 | ```py 59 | from PIL import Image 60 | import requests 61 | from io import BytesIO 62 | from torch.nn.functional import cosine_similarity 63 | import torch 64 | from transformers import AutoProcessor 65 | from vdocrag.vdocretriever.modeling import VDocRetriever 66 | from vdocrag.vdocgenerator.modeling import VDocGenerator 67 | ``` 68 | 69 | ## Retrieval 70 | ```py 71 | processor = AutoProcessor.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True) 72 | model = VDocRetriever.load('microsoft/Phi-3-vision-128k-instruct', 73 | lora_name_or_path='NTT-hil-insight/VDocRetriever-Phi3-vision', 74 | pooling='eos', 75 | normalize=True, 76 | trust_remote_code=True, 77 | attn_implementation="flash_attention_2", 78 | torch_dtype=torch.bfloat16, 79 | use_cache=False).to('cuda:0') 80 | 81 | # Process query inputs and get the embeddings 82 | queries = ["Instruct: I’m looking for an image that answers the question.\nQuery: What is the total percentage of Palestinians residing at West Bank?", 83 | "Instruct: I’m looking for an image that answers the question.\nQuery: How many international visitors came to Japan in 2017?"] 84 | query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=256, truncation=True).to('cuda:0') 85 | 86 | with torch.no_grad(): 87 | model_output = model(query=query_inputs, use_cache=False) 88 | query_embeddings = model_output.q_reps 89 | 90 | urls = [ 91 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image1.png", 92 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image2.png" 93 | ] 94 | 95 | doc_images = [Image.open(BytesIO(requests.get(url).content)).resize((1344, 1344)) for url in urls] 96 | 97 | # Process images and get the embeddings 98 | doc_prompt = "<|image_1|>\nWhat is shown in this image?" 
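# Each document image is processed with the same prompt; the per-image tensors are then stacked into a single batch below.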
99 | collated_list = [ 100 | processor(doc_prompt, images=image, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0') for image in doc_images 101 | ] 102 | 103 | doc_inputs = { 104 | key: torch.stack([item[key][0] for item in collated_list], dim=0) 105 | for key in ['input_ids', 'attention_mask', 'pixel_values', 'image_sizes'] 106 | } 107 | 108 | with torch.no_grad(): 109 | model_output = model(document=doc_inputs, use_cache=False) 110 | doc_embeddings = model_output.p_reps 111 | 112 | # Calculate cosine similarity 113 | num_queries = query_embeddings.size(0) 114 | num_passages = doc_embeddings.size(0) 115 | 116 | for i in range(num_queries): 117 | query_embedding = query_embeddings[i].unsqueeze(0) 118 | similarities = cosine_similarity(query_embedding, doc_embeddings) 119 | print(f"Similarities for Query {i}: {similarities.cpu().float().numpy()}") 120 | 121 | # >> Similarities for Query 0: [0.515625 0.38476562] 122 | # Similarities for Query 1: [0.37890625 0.5703125 ] 123 | ``` 124 | 125 | ## Generation 126 | ```py 127 | model = VDocGenerator.load('microsoft/Phi-3-vision-128k-instruct', 128 | lora_name_or_path='NTT-hil-insight/VDocGenerator-Phi3-vision', 129 | trust_remote_code=True, 130 | attn_implementation="flash_attention_2", 131 | torch_dtype=torch.bfloat16, 132 | use_cache=False).to('cuda:0') 133 | 134 | # Process images with the prompt 135 | query = "How many international visitors came to Japan in 2017? \n Answer briefly." 136 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(doc_images))]) 137 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}] 138 | prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 139 | processed = processor(prompt, images=doc_images, return_tensors="pt").to('cuda:0') 140 | 141 | # Generate the answer 142 | generate_ids = model.generate(processed, 143 | generation_args={ 144 | "max_new_tokens": 64, 145 | "temperature": 0.0, 146 | "do_sample": False, 147 | "eos_token_id": processor.tokenizer.eos_token_id 148 | }) 149 | generate_ids = generate_ids[:, processed['input_ids'].shape[1]:] 150 | response = processor.batch_decode(generate_ids, 151 | skip_special_tokens=True, 152 | clean_up_tokenization_spaces=False)[0].strip() 153 | 154 | print("Model prediction: {0}".format(response)) 155 | 156 | # >> Model prediction: 28.69m 157 | ``` 158 | 159 | 160 | # 💾 Dataset 161 | OpenDocVQA is a unified collection of open-domain document visual question answering datasets, encompassing diverse document types and formats. It consists of 9 open-domain DocumentVQA datasets, including a newly created MHDocVQA dataset to address multi-hop questions over multiple documents, and collected and filtered QA datasets (DocVQA, InfoVQA, DUDE, VisualMRC, ChartQA, OpenWikiTable, MPMQA, and SlideVQA). In total, OpenDocVQA contains 43k QA pairs with 200k document images. 162 | 163 | You can download the OpenDocVQA dataset from the 🤗 HuggingFace Hub as follows: 164 | - [QA Pairs](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA) 165 | - [Corpus](https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA-Corpus) 166 | 167 |
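If you want to inspect the data programmatically, the snippet below is a minimal sketch that loads one sub-dataset with the Hugging Face `datasets` library (pinned in `requirements.txt`). The repository ids, config names, and splits follow the encoding commands later in this README; only the `query_id` and `relevant_doc_ids` fields, which the evaluation utilities rely on, are accessed here, and other fields are omitted.

```py
from datasets import load_dataset

# QA pairs: pick a sub-dataset config (e.g. chartqa, slidevqa, infovqa, dude) and a split.
qa = load_dataset("NTT-hil-insight/OpenDocVQA", "chartqa", split="test")

# Document corpus: the corpus configs mirror the QA configs (plus "all" for the full pool).
corpus = load_dataset("NTT-hil-insight/OpenDocVQA-Corpus", "chartqa", split="test")

print(len(qa), len(corpus))

# Each QA example carries a query id and the ids of its relevant documents,
# which vdocrag.utils.format.convert_qas_to_trec_qrels uses to build TREC qrels.
example = qa[0]
print(example["query_id"], example["relevant_doc_ids"])
```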
168 | # 🔎 Retriever 169 | 170 | ## Pre-training VDocRetriever 171 | The following script supports our proposed pre-training tasks, including RCR (Representation Compression via Retrieval) and RCG (Representation Compression via Generation). 172 | ```bash 173 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocretriever.driver.train \ 174 | --deepspeed deepspeed/ds_zero3_config.json \ 175 | --output_dir outputs/vdocretriever-phi3-vision_pretrain \ 176 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 177 | --lora \ 178 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ 179 | --save_steps 2000 \ 180 | --dataset_name NTT-hil-insight/VDocRetriever-Pretrain-DocStruct \ 181 | --bf16 \ 182 | --pooling eos \ 183 | --append_eos_token \ 184 | --normalize \ 185 | --temperature 0.01 \ 186 | --per_device_train_batch_size 1 \ 187 | --gradient_checkpointing \ 188 | --train_group_size 1 \ 189 | --learning_rate 1e-4 \ 190 | --query_max_len 512 \ 191 | --answer_max_len 512 \ 192 | --num_train_epochs 1 \ 193 | --logging_steps 10 \ 194 | --overwrite_output_dir \ 195 | --gradient_accumulation_steps 4 \ 196 | --pretrain \ 197 | --image_attention_mask \ 198 | --report_to wandb \ 199 | ``` 200 | 201 | ## Fine-tuning VDocRetriever 202 | If you set `--lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision-pretrained`, you can use our pre-trained VDocRetriever weights without pre-training the model locally. 203 | 204 | ```bash 205 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocretriever.driver.train \ 206 | --deepspeed deepspeed/ds_zero3_config.json \ 207 | --output_dir outputs/vdocretriever-phi3-vision_finetune \ 208 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 209 | --lora_name_or_path outputs/vdocretriever-phi3-vision_pretrain \ 210 | --lora \ 211 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ 212 | --save_steps 2000 \ 213 | --dataset_name NTT-hil-insight/OpenDocVQA \ 214 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 215 | --bf16 \ 216 | --pooling eos \ 217 | --append_eos_token \ 218 | --normalize \ 219 | --temperature 0.01 \ 220 | --per_device_train_batch_size 4 \ 221 | --gradient_checkpointing \ 222 | --train_group_size 1 \ 223 | --learning_rate 1e-4 \ 224 | --query_max_len 256 \ 225 | --answer_max_len 256 \ 226 | --num_train_epochs 1 \ 227 | --logging_steps 10 \ 228 | --overwrite_output_dir \ 229 | --gradient_accumulation_steps 4 \ 230 | --report_to wandb \ 231 | ``` 232 | 233 | ## Query encoding 234 | If you want to use our fine-tuned model directly, set `lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision`. `QUERY_DATASET` must be selected from the following options: {chartqa, slidevqa, infovqa, dude}. 235 | 236 | ```bash 237 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 238 | --output_dir=temp \ 239 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 240 | --lora_name_or_path outputs/vdocretriever-phi3-vision_finetune \ 241 | --lora \ 242 | --bf16 \ 243 | --pooling eos \ 244 | --append_eos_token \ 245 | --normalize \ 246 | --encode_is_query \ 247 | --per_device_eval_batch_size 24 \ 248 | --query_max_len 256 \ 249 | --dataset_name NTT-hil-insight/OpenDocVQA \ 250 | --dataset_config $QUERY_DATASET \ 251 | --dataset_split test \ 252 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl 253 | ``` 254 | 255 | ## Document encoding 256 | If you want to use our fine-tuned model directly, set `lora_name_or_path` to `NTT-hil-insight/VDocRetriever-Phi3-vision`.
257 | `CORPUS_DATASET` must be selected from the following options: {all, chartqa, slidevqa, infovqa, dude}. 258 | 259 | ```bash 260 | for s in 0 1 2 3; do 261 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 262 | --output_dir=temp \ 263 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 264 | --lora_name_or_path outputs/vdocretriever-phi3-vision_finetune \ 265 | --lora \ 266 | --bf16 \ 267 | --pooling eos \ 268 | --append_eos_token \ 269 | --normalize \ 270 | --per_device_eval_batch_size 4 \ 271 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 272 | --corpus_config $CORPUS_DATASET \ 273 | --corpus_split test \ 274 | --dataset_number_of_shards 4 \ 275 | --dataset_shard_index ${s} \ 276 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl done 277 | ``` 278 | 279 | ## Retrieval 280 | ```bash 281 | python -m vdocrag.vdocretriever.driver.search \ 282 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 283 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 284 | --depth 1000 \ 285 | --batch_size 64 \ 286 | --save_text \ 287 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ 288 | ``` 289 | 290 | ## Evaluation 291 | ```bash 292 | # Convert retrieval results (.txt) to .trec file 293 | python -m vdocrag.utils.format.convert_result_to_trec --input $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ 294 | --output $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.trec \ 295 | --remove_query 296 | # Create ground-truth retrieval results 297 | python -m vdocrag.utils.format.convert_qas_to_trec_qrels --dataset_name NTT-hil-insight/OpenDocVQA \ 298 | --dataset_config ${QUERY_DATASET} \ 299 | --output $EMBEDDING_OUTPUT_DIR/qrels.${QUERY_DATASET}.txt 300 | # Evaluate with pyserini 301 | python -m pyserini.eval.trec_eval -c -mrecall.1,5,10 -mndcg_cut.1,5,10 $EMBEDDING_OUTPUT_DIR/qrels.${QUERY_DATASET}.txt $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.trec 302 | ``` 303 | 304 | 305 | # 💬 Generator 306 | 307 | ## Data Creation for Generator 308 | Before training and evaluating generator models, you need to create data consisting of triples `(query id, retrieved document id, retrieved score)` for the generator. You can create the data automatically as follows: 309 | - Train Data: `scripts/create_train_generator.sh` 310 | - Test Data 311 | - ChartQA: `scripts/create_test_chartqa_generator.sh` 312 | - SlideVQA: `scripts/create_test_slidevqa_generator.sh` 313 | - InfoVQA: `scripts/create_test_infovqa_generator.sh` 314 | - DUDE: `scripts/create_test_dude_generator.sh` 315 | 316 | If you want to use your own models, replace `lora_name_or_path` with your model name. By default, all generated test sets use the single-pool setting. To evaluate models under the all-pool setting, change `CORPUS_DATASET` to `all`. 317 | 318 | 319 | ## Fine-tuning VDocGenerator 320 | `retrieval_results_path` is the path where the retrieval results created in the previous section (Data Creation for Generator) are saved.
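Each line of this ranking file is a tab-separated triple of query id, document id, and retrieval score, which is the format consumed by `vdocrag.utils.format.convert_result_to_trec`. The two lines below are only an illustration with placeholder ids:

```
query_001	doc_042	0.5723
query_001	doc_007	0.5518
```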
321 | 322 | ```bash 323 | deepspeed --include localhost:0 --master_port 60000 --module vdocrag.vdocgenerator.driver.train \ 324 | --deepspeed deepspeed/ds_zero3_config.json \ 325 | --output_dir outputs/vdocgenerator-phi3-vision_finetune \ 326 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 327 | --lora \ 328 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ 329 | --save_steps 100 \ 330 | --dataset_name NTT-hil-insight/OpenDocVQA \ 331 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 332 | --retrieval_results_path outputs/vdocretriever-phi3-vision_finetune/embs/rank.train.all.txt \ 333 | --bf16 \ 334 | --per_device_train_batch_size 2 \ 335 | --gradient_checkpointing \ 336 | --learning_rate 1e-4 \ 337 | --query_max_len 256 \ 338 | --num_train_epochs 1 \ 339 | --logging_steps 10 \ 340 | --overwrite_output_dir \ 341 | --gradient_accumulation_steps 4 \ 342 | --top_k 3 \ 343 | --report_to wandb \ 344 | ``` 345 | 346 | ## Generating Answers 347 | If you want to use our fine-tuned model directly, `lora_name_or_path` is set to `NTT-hil-insight/VDocGenerator-Phi3-vision`. 348 | ```bash 349 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocgenerator.driver.generate \ 350 | --output_dir=temp \ 351 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 352 | --lora_name_or_path outputs/vdocgenerator-phi3-vision_finetune \ 353 | --lora \ 354 | --dataset_name NTT-hil-insight/OpenDocVQA \ 355 | --dataset_split test \ 356 | --dataset_config ${QUERY_DATASET} \ 357 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 358 | --corpus_config $CORPUS_DATASET \ 359 | --corpus_split test \ 360 | --retrieval_results_path outputs/vdocretriever-phi3-vision_finetune/embs/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ 361 | --bf16 \ 362 | --per_device_eval_batch_size 1 \ 363 | --top_k 3 \ 364 | --output_path outputs/vdocgenerator-phi3-vision_finetune/answers/answers.${QUERY_DATASET}.${CORPUS_DATASET}.json \ 365 | ``` 366 | 367 | ## Evaluation 368 | ```bash 369 | python -m vdocrag.utils.eval_opendocvqa --input outputs/vdocgenerator-phi3-vision_finetune/answers/answers.${QUERY_DATASET}.${CORPUS_DATASET}.json 370 | ``` 371 | 372 | 373 | # 📝 License 374 | The code is released under the NTT License as found in the [LICENSE](./LICENSE) file. 375 | 376 | 377 | # ✒️ Citation 378 | ```bibtex 379 | @inproceedings{tanaka2025vdocrag, 380 | author = {Ryota Tanaka and 381 | Taichi Iki and 382 | Taku Hasegawa and 383 | Kyosuke Nishida and 384 | Kuniko Saito and 385 | Jun Suzuki}, 386 | title = {VDocRAG: Retrieval-Augmented Generation over Visually-Rich Documents}, 387 | booktitle = {CVPR}, 388 | year = {2025} 389 | } 390 | ``` 391 | 392 | 393 | # 📔 Acknowledgement 394 | We have adapted code from [Tevatron](https://github.com/texttron/tevatron/), a flexible and efficient toolkit that supports training and inference for neural retrieval models. 
395 | 396 | -------------------------------------------------------------------------------- /deepspeed/ds_zero3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "offload_optimizer": { 5 | "device": "none", 6 | "pin_memory": true 7 | }, 8 | "offload_param": { 9 | "device": "none", 10 | "pin_memory": true 11 | }, 12 | "overlap_comm": true, 13 | "contiguous_gradients": true, 14 | "sub_group_size": 1e9, 15 | "reduce_bucket_size": 1e6, 16 | "stage3_prefetch_bucket_size": "auto", 17 | "stage3_param_persistence_threshold": "auto", 18 | "stage3_max_live_parameters": 1e9, 19 | "stage3_max_reuse_distance": 1e9, 20 | "stage3_gather_16bit_weights_on_model_save": true 21 | }, 22 | "fp16": { 23 | "enabled": "auto", 24 | "loss_scale": 0, 25 | "initial_scale_power": 10, 26 | "loss_scale_window": 1000, 27 | "hysteresis": 2, 28 | "min_loss_scale": 1 29 | }, 30 | "bf16": { 31 | "enabled": "auto", 32 | "loss_scale": 0, 33 | "initial_scale_power": 10, 34 | "loss_scale_window": 1000, 35 | "hysteresis": 2, 36 | "min_loss_scale": 1 37 | }, 38 | "optimizer": { 39 | "type": "AdamW", 40 | "params": { 41 | "lr": "auto", 42 | "betas": "auto", 43 | "eps": "auto", 44 | "weight_decay": "auto", 45 | "torch_adam": true 46 | } 47 | }, 48 | 49 | "scheduler": { 50 | "type": "WarmupDecayLR", 51 | "params": { 52 | "warmup_min_lr": "auto", 53 | "warmup_max_lr": "auto", 54 | "warmup_num_steps": "auto", 55 | "total_num_steps": "auto" 56 | } 57 | }, 58 | 59 | "gradient_accumulation_steps": "auto", 60 | "gradient_clipping": "auto", 61 | "steps_per_print": 1000, 62 | "train_batch_size": "auto", 63 | "train_micro_batch_size_per_gpu": "auto", 64 | "wall_clock_breakdown": false 65 | } 66 | -------------------------------------------------------------------------------- /images/abst.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/abst.png -------------------------------------------------------------------------------- /images/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/image1.png -------------------------------------------------------------------------------- /images/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/images/image2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepspeed==0.14.5 2 | accelerate==0.34.0 3 | datasets==2.20.0 4 | flash-attn==2.6.3 5 | huggingface-hub==0.30.1 6 | numpy==1.26.4 7 | peft==0.11.1 8 | pillow==10.4.0 9 | torch==2.4.0 10 | torchvision==0.19.0 11 | tqdm==4.66.4 12 | transformers==4.40.2 13 | wandb==0.17.7 14 | pyserini==0.44.0 15 | faiss-gpu==1.7.2 -------------------------------------------------------------------------------- /scripts/create_test_chartqa_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUERY_DATASET=chartqa 4 | CORPUS_DATASET=chartqa 5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs 6 | 7 | # encoding queries 8 | 
CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 9 | --output_dir=temp \ 10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 12 | --lora \ 13 | --bf16 \ 14 | --pooling eos \ 15 | --append_eos_token \ 16 | --normalize \ 17 | --encode_is_query \ 18 | --per_device_eval_batch_size 24 \ 19 | --query_max_len 256 \ 20 | --dataset_name NTT-hil-insight/OpenDocVQA \ 21 | --dataset_config $QUERY_DATASET \ 22 | --dataset_split test \ 23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl 24 | 25 | # encoding documents 26 | for s in 0 1 2 3 27 | do 28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 29 | --output_dir=temp \ 30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 32 | --lora \ 33 | --bf16 \ 34 | --pooling eos \ 35 | --append_eos_token \ 36 | --normalize \ 37 | --per_device_eval_batch_size 4 \ 38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 39 | --corpus_config $CORPUS_DATASET \ 40 | --corpus_split test \ 41 | --dataset_number_of_shards 4 \ 42 | --dataset_shard_index ${s} \ 43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl 44 | done 45 | 46 | # retrieving documents 47 | python -m vdocrag.vdocretriever.driver.search \ 48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 50 | --depth 1000 \ 51 | --batch_size 64 \ 52 | --save_text \ 53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ -------------------------------------------------------------------------------- /scripts/create_test_dude_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUERY_DATASET=dude 4 | CORPUS_DATASET=dude 5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs 6 | 7 | # encoding queries 8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 9 | --output_dir=temp \ 10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 12 | --lora \ 13 | --bf16 \ 14 | --pooling eos \ 15 | --append_eos_token \ 16 | --normalize \ 17 | --encode_is_query \ 18 | --per_device_eval_batch_size 24 \ 19 | --query_max_len 256 \ 20 | --dataset_name NTT-hil-insight/OpenDocVQA \ 21 | --dataset_config $QUERY_DATASET \ 22 | --dataset_split test \ 23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl 24 | 25 | # encoding documents 26 | for s in 0 1 2 3 27 | do 28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 29 | --output_dir=temp \ 30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 32 | --lora \ 33 | --bf16 \ 34 | --pooling eos \ 35 | --append_eos_token \ 36 | --normalize \ 37 | --per_device_eval_batch_size 4 \ 38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 39 | --corpus_config $CORPUS_DATASET \ 40 | --corpus_split test \ 41 | --dataset_number_of_shards 4 \ 42 | --dataset_shard_index ${s} \ 43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl 44 | done 45 | 46 | # retrieving documentss 47 | python -m vdocrag.vdocretriever.driver.search \ 48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 49 | 
--document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 50 | --depth 1000 \ 51 | --batch_size 64 \ 52 | --save_text \ 53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ -------------------------------------------------------------------------------- /scripts/create_test_infovqa_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUERY_DATASET=infovqa 4 | CORPUS_DATASET=infovqa 5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs 6 | 7 | # encoding queries 8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 9 | --output_dir=temp \ 10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 12 | --lora \ 13 | --bf16 \ 14 | --pooling eos \ 15 | --append_eos_token \ 16 | --normalize \ 17 | --encode_is_query \ 18 | --per_device_eval_batch_size 24 \ 19 | --query_max_len 256 \ 20 | --dataset_name NTT-hil-insight/OpenDocVQA \ 21 | --dataset_config $QUERY_DATASET \ 22 | --dataset_split test \ 23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl 24 | 25 | # encoding documents 26 | for s in 0 1 2 3 27 | do 28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 29 | --output_dir=temp \ 30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 31 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 32 | --lora \ 33 | --bf16 \ 34 | --pooling eos \ 35 | --append_eos_token \ 36 | --normalize \ 37 | --per_device_eval_batch_size 4 \ 38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 39 | --corpus_config $CORPUS_DATASET \ 40 | --corpus_split test \ 41 | --dataset_number_of_shards 4 \ 42 | --dataset_shard_index ${s} \ 43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl 44 | done 45 | 46 | # retrieving documentss 47 | python -m vdocrag.vdocretriever.driver.search \ 48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 50 | --depth 1000 \ 51 | --batch_size 64 \ 52 | --save_text \ 53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ -------------------------------------------------------------------------------- /scripts/create_test_slidevqa_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUERY_DATASET=slidevqa 4 | CORPUS_DATASET=slidevqa 5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs 6 | 7 | # encoding queries 8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 9 | --output_dir=temp \ 10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 12 | --lora \ 13 | --bf16 \ 14 | --pooling eos \ 15 | --append_eos_token \ 16 | --normalize \ 17 | --encode_is_query \ 18 | --per_device_eval_batch_size 24 \ 19 | --query_max_len 256 \ 20 | --dataset_name NTT-hil-insight/OpenDocVQA \ 21 | --dataset_config $QUERY_DATASET \ 22 | --dataset_split test \ 23 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl 24 | 25 | # encoding documents 26 | for s in 0 1 2 3 27 | do 28 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 29 | --output_dir=temp \ 30 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 31 | --lora_name_or_path 
NTT-hil-insight/VDocRetriever-Phi3-vision \ 32 | --lora \ 33 | --bf16 \ 34 | --pooling eos \ 35 | --append_eos_token \ 36 | --normalize \ 37 | --per_device_eval_batch_size 4 \ 38 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 39 | --corpus_config $CORPUS_DATASET \ 40 | --corpus_split test \ 41 | --dataset_number_of_shards 4 \ 42 | --dataset_shard_index ${s} \ 43 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl 44 | done 45 | 46 | # retrieving documentss 47 | python -m vdocrag.vdocretriever.driver.search \ 48 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 49 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 50 | --depth 1000 \ 51 | --batch_size 64 \ 52 | --save_text \ 53 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ -------------------------------------------------------------------------------- /scripts/create_train_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | QUERY_DATASET=train 4 | CORPUS_DATASET=all 5 | EMBEDDING_OUTPUT_DIR=outputs/vdocretriever-phi3-vision_finetune/embs 6 | 7 | # encoding queries 8 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 9 | --output_dir=temp \ 10 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 11 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 12 | --lora \ 13 | --bf16 \ 14 | --pooling eos \ 15 | --append_eos_token \ 16 | --normalize \ 17 | --encode_is_query \ 18 | --per_device_eval_batch_size 24 \ 19 | --query_max_len 256 \ 20 | --dataset_name NTT-hil-insight/OpenDocVQA \ 21 | --encode_output_path $EMBEDDING_OUTPUT_DIR/query-train.pkl 22 | 23 | # encoding documents 24 | for s in 0 1 2 3 25 | do 26 | CUDA_VISIBLE_DEVICES=0 python -m vdocrag.vdocretriever.driver.encode \ 27 | --output_dir=temp \ 28 | --model_name_or_path microsoft/Phi-3-vision-128k-instruct \ 29 | --lora_name_or_path NTT-hil-insight/VDocRetriever-Phi3-vision \ 30 | --lora \ 31 | --bf16 \ 32 | --pooling eos \ 33 | --append_eos_token \ 34 | --normalize \ 35 | --per_device_eval_batch_size 4 \ 36 | --corpus_name NTT-hil-insight/OpenDocVQA-Corpus \ 37 | --dataset_number_of_shards 4 \ 38 | --dataset_shard_index ${s} \ 39 | --encode_output_path $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}.${s}.pkl 40 | done 41 | 42 | # retrieving documents 43 | python -m vdocrag.vdocretriever.driver.search \ 44 | --query_reps $EMBEDDING_OUTPUT_DIR/query-${QUERY_DATASET}.pkl \ 45 | --document_reps $EMBEDDING_OUTPUT_DIR/corpus.${CORPUS_DATASET}'.*.pkl' \ 46 | --depth 1000 \ 47 | --batch_size 64 \ 48 | --save_text \ 49 | --save_ranking_to $EMBEDDING_OUTPUT_DIR/rank.${QUERY_DATASET}.${CORPUS_DATASET}.txt \ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='vdocrag', 5 | version='0.0.1', 6 | packages=find_packages("src"), 7 | package_dir={'': 'src'}, 8 | python_requires='>=3.7', 9 | ) 10 | -------------------------------------------------------------------------------- /src/vdocrag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/__init__.py -------------------------------------------------------------------------------- 
/src/vdocrag/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/utils/__init__.py -------------------------------------------------------------------------------- /src/vdocrag/utils/eval_opendocvqa.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | from collections import Counter 5 | 6 | def has_word(sentence, word): 7 | pattern = r"\b" + re.escape(word) + r"\b" 8 | match = re.search(pattern, sentence) 9 | if match: 10 | return True 11 | else: 12 | return False 13 | 14 | class VQAEval: 15 | def __init__(self): 16 | self.contractions = { 17 | "aint": "ain't", 18 | "arent": "aren't", 19 | "cant": "can't", 20 | "couldve": "could've", 21 | "couldnt": "couldn't", 22 | "couldn'tve": "couldn't've", 23 | "couldnt've": "couldn't've", 24 | "didnt": "didn't", 25 | "doesnt": "doesn't", 26 | "dont": "don't", 27 | "hadnt": "hadn't", 28 | "hadnt've": "hadn't've", 29 | "hadn'tve": "hadn't've", 30 | "hasnt": "hasn't", 31 | "havent": "haven't", 32 | "hed": "he'd", 33 | "hed've": "he'd've", 34 | "he'dve": "he'd've", 35 | "hes": "he's", 36 | "howd": "how'd", 37 | "howll": "how'll", 38 | "hows": "how's", 39 | "Id've": "I'd've", 40 | "I'dve": "I'd've", 41 | "Im": "I'm", 42 | "Ive": "I've", 43 | "isnt": "isn't", 44 | "itd": "it'd", 45 | "itd've": "it'd've", 46 | "it'dve": "it'd've", 47 | "itll": "it'll", 48 | "let's": "let's", 49 | "maam": "ma'am", 50 | "mightnt": "mightn't", 51 | "mightnt've": "mightn't've", 52 | "mightn'tve": "mightn't've", 53 | "mightve": "might've", 54 | "mustnt": "mustn't", 55 | "mustve": "must've", 56 | "neednt": "needn't", 57 | "notve": "not've", 58 | "oclock": "o'clock", 59 | "oughtnt": "oughtn't", 60 | "ow's'at": "'ow's'at", 61 | "'ows'at": "'ow's'at", 62 | "'ow'sat": "'ow's'at", 63 | "shant": "shan't", 64 | "shed've": "she'd've", 65 | "she'dve": "she'd've", 66 | "she's": "she's", 67 | "shouldve": "should've", 68 | "shouldnt": "shouldn't", 69 | "shouldnt've": "shouldn't've", 70 | "shouldn'tve": "shouldn't've", 71 | "somebody'd": "somebodyd", 72 | "somebodyd've": "somebody'd've", 73 | "somebody'dve": "somebody'd've", 74 | "somebodyll": "somebody'll", 75 | "somebodys": "somebody's", 76 | "someoned": "someone'd", 77 | "someoned've": "someone'd've", 78 | "someone'dve": "someone'd've", 79 | "someonell": "someone'll", 80 | "someones": "someone's", 81 | "somethingd": "something'd", 82 | "somethingd've": "something'd've", 83 | "something'dve": "something'd've", 84 | "somethingll": "something'll", 85 | "thats": "that's", 86 | "thered": "there'd", 87 | "thered've": "there'd've", 88 | "there'dve": "there'd've", 89 | "therere": "there're", 90 | "theres": "there's", 91 | "theyd": "they'd", 92 | "theyd've": "they'd've", 93 | "they'dve": "they'd've", 94 | "theyll": "they'll", 95 | "theyre": "they're", 96 | "theyve": "they've", 97 | "twas": "'twas", 98 | "wasnt": "wasn't", 99 | "wed've": "we'd've", 100 | "we'dve": "we'd've", 101 | "weve": "we've", 102 | "werent": "weren't", 103 | "whatll": "what'll", 104 | "whatre": "what're", 105 | "whats": "what's", 106 | "whatve": "what've", 107 | "whens": "when's", 108 | "whered": "where'd", 109 | "wheres": "where's", 110 | "whereve": "where've", 111 | "whod": "who'd", 112 | "whod've": "who'd've", 113 | "who'dve": "who'd've", 114 | "wholl": "who'll", 115 | "whos": "who's", 116 | "whove": "who've", 117 | "whyll": "why'll", 
118 | "whyre": "why're", 119 | "whys": "why's", 120 | "wont": "won't", 121 | "wouldve": "would've", 122 | "wouldnt": "wouldn't", 123 | "wouldnt've": "wouldn't've", 124 | "wouldn'tve": "wouldn't've", 125 | "yall": "y'all", 126 | "yall'll": "y'all'll", 127 | "y'allll": "y'all'll", 128 | "yall'd've": "y'all'd've", 129 | "y'alld've": "y'all'd've", 130 | "y'all'dve": "y'all'd've", 131 | "youd": "you'd", 132 | "youd've": "you'd've", 133 | "you'dve": "you'd've", 134 | "youll": "you'll", 135 | "youre": "you're", 136 | "youve": "you've", 137 | } 138 | self.manualMap = { 139 | "none": "0", 140 | "zero": "0", 141 | "one": "1", 142 | "two": "2", 143 | "three": "3", 144 | "four": "4", 145 | "five": "5", 146 | "six": "6", 147 | "seven": "7", 148 | "eight": "8", 149 | "nine": "9", 150 | "ten": "10", 151 | } 152 | self.articles = ["a", "an", "the"] 153 | 154 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 155 | self.commaStrip = re.compile("(\d)(\,)(\d)") 156 | self.punct = [ 157 | ";", 158 | r"/", 159 | "[", 160 | "]", 161 | '"', 162 | "{", 163 | "}", 164 | "(", 165 | ")", 166 | "=", 167 | "+", 168 | "\\", 169 | "_", 170 | "-", 171 | ">", 172 | "<", 173 | "@", 174 | "`", 175 | ",", 176 | "?", 177 | "!", 178 | ] 179 | 180 | def evaluate(self, answer, gt_answers): 181 | answer = answer.replace("\n", " ") 182 | answer = answer.replace("\t", " ") 183 | answer = answer.strip() 184 | answer = self.processPunctuation(answer) 185 | answer = self.processDigitArticle(answer) 186 | if type(gt_answers) == str: 187 | gt_answers = [gt_answers] 188 | for i in range(len(gt_answers)): 189 | gt_answers[i] = gt_answers[i].replace("\n", " ") 190 | gt_answers[i] = gt_answers[i].replace("\t", " ") 191 | gt_answers[i] = gt_answers[i].strip() 192 | gt_answers[i] = self.processPunctuation(gt_answers[i]) 193 | gt_answers[i] = self.processDigitArticle(gt_answers[i]) 194 | if answer == gt_answers[i]: 195 | return 1 196 | return 0 197 | 198 | def processPunctuation(self, inText): 199 | outText = inText 200 | for p in self.punct: 201 | if (p + " " in inText or " " + p in inText) or ( 202 | re.search(self.commaStrip, inText) != None 203 | ): 204 | outText = outText.replace(p, "") 205 | else: 206 | outText = outText.replace(p, " ") 207 | outText = self.periodStrip.sub("", outText, re.UNICODE) 208 | return outText 209 | 210 | def processDigitArticle(self, inText): 211 | outText = [] 212 | tempText = inText.lower().split() 213 | for word in tempText: 214 | word = self.manualMap.setdefault(word, word) 215 | if word not in self.articles: 216 | outText.append(word) 217 | else: 218 | pass 219 | for wordId, word in enumerate(outText): 220 | if word in self.contractions: 221 | outText[wordId] = self.contractions[word] 222 | outText = " ".join(outText) 223 | return outText 224 | 225 | def evaluate_anls(self, answer, gt_answers): 226 | answer = answer.replace("\n", " ") 227 | answer = answer.replace("\t", " ") 228 | answer = answer.strip() 229 | answer = self.processPunctuation(answer) 230 | answer = self.processDigitArticle(answer) 231 | if type(gt_answers) == str: 232 | gt_answers = [gt_answers] 233 | values = [] 234 | for i in range(len(gt_answers)): 235 | gt_answers[i] = gt_answers[i].replace("\n", " ") 236 | gt_answers[i] = gt_answers[i].replace("\t", " ") 237 | gt_answers[i] = gt_answers[i].strip() 238 | gt_answers[i] = self.processPunctuation(gt_answers[i]) 239 | gt_answers[i] = self.processDigitArticle(gt_answers[i]) 240 | dist = self.levenshtein_distance(gt_answers[i], answer) 241 | length = max(len(gt_answers[i]), len(answer)) 
242 | values.append(0.0 if length == 0 else float(dist) / float(length)) 243 | 244 | vqa_anls = 1 - min(values) 245 | return vqa_anls 246 | 247 | def evaluate_f1(self, answer, gt_answers): 248 | answer = answer.replace("\n", " ") 249 | answer = answer.replace("\t", " ") 250 | answer = answer.strip() 251 | answer = self.processPunctuation(answer) 252 | answer = self.processDigitArticle(answer) 253 | if type(gt_answers) == str: 254 | gt_answers = [gt_answers] 255 | 256 | f1s = [] 257 | for i in range(len(gt_answers)): 258 | gt_answers[i] = gt_answers[i].replace("\n", " ") 259 | gt_answers[i] = gt_answers[i].replace("\t", " ") 260 | gt_answers[i] = gt_answers[i].strip() 261 | gt_answers[i] = self.processPunctuation(gt_answers[i]) 262 | gt_answers[i] = self.processDigitArticle(gt_answers[i]) 263 | prediction_tokens = answer.split() 264 | ground_truth_tokens = gt_answers[i].split() 265 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 266 | num_same = sum(common.values()) 267 | try: 268 | precision = 1.0 * num_same / len(prediction_tokens) 269 | recall = 1.0 * num_same / len(ground_truth_tokens) 270 | f1 = (2 * precision * recall) / (precision + recall) 271 | except: 272 | f1 = 0 273 | f1s.append(f1) 274 | f1 = max(f1s) 275 | 276 | return f1 277 | 278 | def evaluate_racc(self, answer, gt_answers): 279 | answer = answer.replace("\n", " ") 280 | answer = answer.replace("\t", " ") 281 | answer = answer.strip() 282 | answer = self.processPunctuation(answer) 283 | answer = self.processDigitArticle(answer) 284 | if type(gt_answers) == str: 285 | gt_answers = [gt_answers] 286 | 287 | for i in range(len(gt_answers)): 288 | gt_answers[i] = gt_answers[i].replace("\n", " ") 289 | gt_answers[i] = gt_answers[i].replace("\t", " ") 290 | gt_answers[i] = gt_answers[i].strip() 291 | gt_answers[i] = self.processPunctuation(gt_answers[i]) 292 | gt_answers[i] = self.processDigitArticle(gt_answers[i]) 293 | answer_float = self._to_float(answer) 294 | gt_answer_float = self._to_float(gt_answers[i]) 295 | if answer_float is not None and gt_answer_float: 296 | relative_change = abs(answer_float - gt_answer_float) / abs(gt_answer_float) 297 | return relative_change <= 0.05 298 | else: 299 | return answer.lower() == gt_answers[i].lower() 300 | return 0 301 | 302 | def _to_float(self, text): 303 | try: 304 | if text.endswith("%"): 305 | # Convert percentages to floats. 
306 | return float(text.rstrip("%")) / 100.0 307 | else: 308 | return float(text) 309 | except ValueError: 310 | return None 311 | def levenshtein_distance(self, s1, s2): 312 | if len(s1) > len(s2): 313 | s1, s2 = s2, s1 314 | 315 | distances = range(len(s1) + 1) 316 | for i2, c2 in enumerate(s2): 317 | distances_ = [i2+1] 318 | for i1, c1 in enumerate(s1): 319 | if c1 == c2: 320 | distances_.append(distances[i1]) 321 | else: 322 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 323 | distances = distances_ 324 | return distances[-1] 325 | 326 | class SumEval: 327 | def __init__(self): 328 | self.contractions = { 329 | "aint": "ain't", 330 | "arent": "aren't", 331 | "cant": "can't", 332 | "couldve": "could've", 333 | "couldnt": "couldn't", 334 | "couldn'tve": "couldn't've", 335 | "couldnt've": "couldn't've", 336 | "didnt": "didn't", 337 | "doesnt": "doesn't", 338 | "dont": "don't", 339 | "hadnt": "hadn't", 340 | "hadnt've": "hadn't've", 341 | "hadn'tve": "hadn't've", 342 | "hasnt": "hasn't", 343 | "havent": "haven't", 344 | "hed": "he'd", 345 | "hed've": "he'd've", 346 | "he'dve": "he'd've", 347 | "hes": "he's", 348 | "howd": "how'd", 349 | "howll": "how'll", 350 | "hows": "how's", 351 | "Id've": "I'd've", 352 | "I'dve": "I'd've", 353 | "Im": "I'm", 354 | "Ive": "I've", 355 | "isnt": "isn't", 356 | "itd": "it'd", 357 | "itd've": "it'd've", 358 | "it'dve": "it'd've", 359 | "itll": "it'll", 360 | "let's": "let's", 361 | "maam": "ma'am", 362 | "mightnt": "mightn't", 363 | "mightnt've": "mightn't've", 364 | "mightn'tve": "mightn't've", 365 | "mightve": "might've", 366 | "mustnt": "mustn't", 367 | "mustve": "must've", 368 | "neednt": "needn't", 369 | "notve": "not've", 370 | "oclock": "o'clock", 371 | "oughtnt": "oughtn't", 372 | "ow's'at": "'ow's'at", 373 | "'ows'at": "'ow's'at", 374 | "'ow'sat": "'ow's'at", 375 | "shant": "shan't", 376 | "shed've": "she'd've", 377 | "she'dve": "she'd've", 378 | "she's": "she's", 379 | "shouldve": "should've", 380 | "shouldnt": "shouldn't", 381 | "shouldnt've": "shouldn't've", 382 | "shouldn'tve": "shouldn't've", 383 | "somebody'd": "somebodyd", 384 | "somebodyd've": "somebody'd've", 385 | "somebody'dve": "somebody'd've", 386 | "somebodyll": "somebody'll", 387 | "somebodys": "somebody's", 388 | "someoned": "someone'd", 389 | "someoned've": "someone'd've", 390 | "someone'dve": "someone'd've", 391 | "someonell": "someone'll", 392 | "someones": "someone's", 393 | "somethingd": "something'd", 394 | "somethingd've": "something'd've", 395 | "something'dve": "something'd've", 396 | "somethingll": "something'll", 397 | "thats": "that's", 398 | "thered": "there'd", 399 | "thered've": "there'd've", 400 | "there'dve": "there'd've", 401 | "therere": "there're", 402 | "theres": "there's", 403 | "theyd": "they'd", 404 | "theyd've": "they'd've", 405 | "they'dve": "they'd've", 406 | "theyll": "they'll", 407 | "theyre": "they're", 408 | "theyve": "they've", 409 | "twas": "'twas", 410 | "wasnt": "wasn't", 411 | "wed've": "we'd've", 412 | "we'dve": "we'd've", 413 | "weve": "we've", 414 | "werent": "weren't", 415 | "whatll": "what'll", 416 | "whatre": "what're", 417 | "whats": "what's", 418 | "whatve": "what've", 419 | "whens": "when's", 420 | "whered": "where'd", 421 | "wheres": "where's", 422 | "whereve": "where've", 423 | "whod": "who'd", 424 | "whod've": "who'd've", 425 | "who'dve": "who'd've", 426 | "wholl": "who'll", 427 | "whos": "who's", 428 | "whove": "who've", 429 | "whyll": "why'll", 430 | "whyre": "why're", 431 | "whys": "why's", 432 | 
"wont": "won't", 433 | "wouldve": "would've", 434 | "wouldnt": "wouldn't", 435 | "wouldnt've": "wouldn't've", 436 | "wouldn'tve": "wouldn't've", 437 | "yall": "y'all", 438 | "yall'll": "y'all'll", 439 | "y'allll": "y'all'll", 440 | "yall'd've": "y'all'd've", 441 | "y'alld've": "y'all'd've", 442 | "y'all'dve": "y'all'd've", 443 | "youd": "you'd", 444 | "youd've": "you'd've", 445 | "you'dve": "you'd've", 446 | "youll": "you'll", 447 | "youre": "you're", 448 | "youve": "you've", 449 | } 450 | self.manualMap = { 451 | "none": "0", 452 | "zero": "0", 453 | "one": "1", 454 | "two": "2", 455 | "three": "3", 456 | "four": "4", 457 | "five": "5", 458 | "six": "6", 459 | "seven": "7", 460 | "eight": "8", 461 | "nine": "9", 462 | "ten": "10", 463 | } 464 | self.articles = ["a", "an", "the"] 465 | 466 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 467 | self.commaStrip = re.compile("(\d)(\,)(\d)") 468 | self.punct = [ 469 | ";", 470 | r"/", 471 | "[", 472 | "]", 473 | '"', 474 | "{", 475 | "}", 476 | "(", 477 | ")", 478 | "=", 479 | "+", 480 | "\\", 481 | "_", 482 | "-", 483 | ">", 484 | "<", 485 | "@", 486 | "`", 487 | ",", 488 | "?", 489 | "!", 490 | ] 491 | 492 | def process(self, answer, gt_answer): 493 | answer = answer.replace("\n", " ") 494 | answer = answer.replace("\t", " ") 495 | answer = answer.strip() 496 | answer = self.processPunctuation(answer) 497 | answer = self.processDigitArticle(answer) 498 | gt_answer = gt_answer.replace("\n", " ") 499 | gt_answer = gt_answer.replace("\t", " ") 500 | gt_answer = gt_answer.strip() 501 | gt_answer = self.processPunctuation(gt_answer) 502 | gt_answer = self.processDigitArticle(gt_answer) 503 | 504 | return answer, gt_answer 505 | 506 | def processPunctuation(self, inText): 507 | outText = inText 508 | for p in self.punct: 509 | if (p + " " in inText or " " + p in inText) or ( 510 | re.search(self.commaStrip, inText) != None 511 | ): 512 | outText = outText.replace(p, "") 513 | else: 514 | outText = outText.replace(p, " ") 515 | outText = self.periodStrip.sub("", outText, re.UNICODE) 516 | return outText 517 | 518 | def processDigitArticle(self, inText): 519 | outText = [] 520 | tempText = inText.lower().split() 521 | for word in tempText: 522 | word = self.manualMap.setdefault(word, word) 523 | if word not in self.articles: 524 | outText.append(word) 525 | else: 526 | pass 527 | for wordId, word in enumerate(outText): 528 | if word in self.contractions: 529 | outText[wordId] = self.contractions[word] 530 | outText = " ".join(outText) 531 | return outText 532 | 533 | if __name__ == "__main__": 534 | parser = argparse.ArgumentParser() 535 | parser.add_argument("--input", type=str, default="outputs/vdocgenerator-phi3-vision_finetune/answers/answers.chartqa.chartqa.json") 536 | args = parser.parse_args() 537 | 538 | cor = 0 539 | r_cor = 0 540 | num = 0 541 | anls_score = 0 542 | f1_score = 0 543 | 544 | preds, gt = {}, {} 545 | eval = VQAEval() 546 | with open(args.input, 'r') as f: 547 | data = json.load(f) 548 | for question_id in data: 549 | d = data[question_id] 550 | answer = d["prediction"] 551 | gt_answers = d["ground_truth"] 552 | 553 | if eval.evaluate(answer, gt_answers)==1: 554 | cor += 1 555 | if eval.evaluate_racc(answer, gt_answers)==1: 556 | r_cor += 1 557 | 558 | anls_score+=eval.evaluate_anls(answer, gt_answers) 559 | f1_score += eval.evaluate_f1(answer, gt_answers) 560 | num += 1 561 | 562 | print("exact acc: ", float(cor) / num) 563 | print("relaxed acc: ", float(r_cor) / num) 564 | print("anls: ", float(anls_score) / num) 565 
| print("f1: ", float(f1_score) / num) 566 | -------------------------------------------------------------------------------- /src/vdocrag/utils/format/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/utils/format/__init__.py -------------------------------------------------------------------------------- /src/vdocrag/utils/format/convert_qas_to_trec_qrels.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from datasets import load_dataset 4 | 5 | parser = ArgumentParser() 6 | parser.add_argument('--dataset_name', type=str, required=True) 7 | parser.add_argument('--dataset_config', type=str, required=True) 8 | parser.add_argument('--output', type=str, required=True) 9 | args = parser.parse_args() 10 | 11 | data = load_dataset( 12 | args.dataset_name, 13 | args.dataset_config, 14 | split="test", 15 | ) 16 | 17 | with open(args.output, 'w') as f_out: 18 | for d in data: 19 | query_id = d['query_id'] 20 | for docid in d['relevant_doc_ids']: 21 | f_out.write(f'{query_id} Q0 {docid} 1\n') 22 | -------------------------------------------------------------------------------- /src/vdocrag/utils/format/convert_result_to_trec.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | parser = ArgumentParser() 4 | parser.add_argument('--input', type=str, required=True) 5 | parser.add_argument('--output', type=str, required=True) 6 | parser.add_argument('--remove_query', action='store_true') 7 | args = parser.parse_args() 8 | 9 | with open(args.input) as f_in, open(args.output, 'w') as f_out: 10 | cur_qid = None 11 | rank = 0 12 | for line in f_in: 13 | qid, docid, score = line.split('\t') 14 | score = score.replace('\n', '') 15 | if cur_qid != qid: 16 | cur_qid = qid 17 | rank = 0 18 | if args.remove_query and qid == docid: 19 | continue 20 | rank += 1 21 | f_out.write(f'{qid} Q0 {docid} {rank} {score} dense\n') 22 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/vdocgenerator/__init__.py -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | from transformers import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: str = field( 10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 11 | ) 12 | config_name: Optional[str] = field( 13 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 14 | ) 15 | tokenizer_name: Optional[str] = field( 16 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 17 | ) 18 | cache_dir: Optional[str] = field( 19 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 20 | ) 21 | # for lora 22 | lora: bool = field(default=False, 23 | 
metadata={"help": "do parameter-efficient fine-tuning with lora"} 24 | ) 25 | 26 | lora_name_or_path: Optional[str] = field( 27 | default=None, metadata={"help": "Path to pretrained lora model or model identifier from huggingface.co/models"} 28 | ) 29 | 30 | lora_r: int = field( 31 | default=8, 32 | metadata={"help": "lora r"} 33 | ) 34 | 35 | lora_alpha: int = field( 36 | default=64, 37 | metadata={"help": "lora alpha"} 38 | ) 39 | 40 | lora_dropout: float = field( 41 | default=0.1, 42 | metadata={"help": "lora dropout"} 43 | ) 44 | 45 | lora_target_modules: str = field( 46 | default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj", 47 | metadata={"help": "lora target modules"} 48 | ) 49 | 50 | # for Jax training 51 | dtype: Optional[str] = field( 52 | default="float32", 53 | metadata={ 54 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one " 55 | "of `[float32, float16, bfloat16]`. " 56 | }, 57 | ) 58 | 59 | 60 | @dataclass 61 | class DataArguments: 62 | dataset_name: str = field( 63 | default='json', metadata={"help": "huggingface dataset name"} 64 | ) 65 | dataset_path: str = field( 66 | default='json', metadata={"help": "dataset path"} 67 | ) 68 | dataset_config: str = field( 69 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"} 70 | ) 71 | 72 | dataset_path: str = field( 73 | default=None, metadata={"help": "Path to local data files or directory"} 74 | ) 75 | 76 | dataset_split: str = field( 77 | default='train', metadata={"help": "dataset split"} 78 | ) 79 | 80 | dataset_cache_dir: Optional[str] = field( 81 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"} 82 | ) 83 | corpus_name: str = field( 84 | default='NTT-hil-insight/OpenDocVQA-Corpus', metadata={"help": "huggingface dataset name"} 85 | ) 86 | 87 | corpus_config: str = field( 88 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"} 89 | ) 90 | 91 | corpus_path: str = field( 92 | default=None, metadata={"help": "Path to local data files or directory"} 93 | ) 94 | 95 | corpus_split: str = field( 96 | default='train', metadata={"help": "dataset split"} 97 | ) 98 | 99 | retrieval_results_path: str = field( 100 | default=None, metadata={"help": "Path to local data files or directory"} 101 | ) 102 | 103 | dataset_number_of_shards: int = field( 104 | default=1, metadata={"help": "number of shards to split the dataset into"} 105 | ) 106 | 107 | dataset_shard_index: int = field( 108 | default=0, metadata={"help": "shard index to use, to be used with dataset_number_of_shards"} 109 | ) 110 | 111 | top_k: int = field( 112 | default=3, metadata={"help": "number of documents used to train for each query"} 113 | ) 114 | gold: bool = field( 115 | default=False, metadata={"help": "gold retrieval results"}) 116 | 117 | output_path: str = field(default=None, metadata={"help": "where to save the encode"}) 118 | 119 | 120 | query_max_len: Optional[int] = field( 121 | default=32, 122 | metadata={ 123 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer " 124 | "than this will be truncated, sequences shorter will be padded." 125 | }, 126 | ) 127 | answer_max_len: Optional[int] = field( 128 | default=128, 129 | metadata={ 130 | "help": "The maximum total input sequence length after tokenization for document. 
Sequences longer " 131 | "than this will be truncated, sequences shorter will be padded." 132 | }, 133 | ) 134 | 135 | pad_to_multiple_of: Optional[int] = field( 136 | default=16, 137 | metadata={ 138 | "help": "If set will pad the sequence to a multiple of the provided value. This is especially useful to " 139 | "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)." 140 | }, 141 | ) 142 | 143 | 144 | @dataclass 145 | class VDocGeneratorTrainingArguments(TrainingArguments): 146 | warmup_ratio: float = field(default=0.1) 147 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from typing import List, Tuple 4 | from dataclasses import dataclass 5 | from transformers import PreTrainedTokenizer, ProcessorMixin 6 | from vdocrag.vdocgenerator.arguments import DataArguments, TrainingArguments 7 | from transformers.feature_extraction_utils import BatchFeature 8 | from collections import defaultdict 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @dataclass 14 | class TrainCollator: 15 | data_args: DataArguments 16 | tokenizer: PreTrainedTokenizer 17 | processor: ProcessorMixin 18 | 19 | def __call__(self, features: List[Tuple[str, List[str]]]): 20 | """ 21 | Collate function for training. 22 | :param features: list of (query, documents) tuples 23 | :return: tokenized query_ids, document_ids 24 | """ 25 | 26 | all_queries = [f[0] for f in features] 27 | all_answers = [f[1] for f in features] 28 | all_images = [f[2] for f in features] 29 | 30 | collated = {} 31 | all_input_ids, all_label_ids, pixel_values, image_sizes = [], [], [], [] 32 | for i, (query, answer, images) in enumerate(zip(all_queries, all_answers, all_images)): 33 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(images))]) 34 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}] 35 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 36 | processed = self.processor(prompt, images, return_tensors="pt") 37 | answer = f'{answer}<|end|>\n<|endoftext|>' 38 | answer_input_ids = self.tokenizer( 39 | answer, add_special_tokens=False, return_tensors='pt' 40 | )['input_ids'] 41 | prompt_input_ids = processed['input_ids'] 42 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 43 | labels = torch.cat( 44 | [ 45 | torch.tensor([-100] * len(prompt_input_ids[0])).unsqueeze(0), 46 | answer_input_ids, 47 | ], 48 | dim=1, 49 | ) 50 | # prepare expected shape for pad_sequence 51 | all_input_ids.append(input_ids.squeeze(0).unsqueeze(1)) 52 | all_label_ids.append(labels.squeeze(0).unsqueeze(1)) 53 | pixel_values.append(processed['pixel_values']) 54 | image_sizes.append(processed['image_sizes']) 55 | 56 | input_ids = torch._C._nn.pad_sequence( 57 | all_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 58 | ).squeeze(2) 59 | labels = torch._C._nn.pad_sequence( 60 | all_label_ids, batch_first=True, padding_value=-100 61 | ).squeeze(2) 62 | 63 | collated['input_ids'] = input_ids 64 | collated['labels'] = labels 65 | collated['pixel_values'] = torch.cat(pixel_values, dim=0) 66 | collated['image_sizes'] = torch.cat(image_sizes, dim=0) 67 | collated['attention_mask'] = collated['input_ids'].ne(self.tokenizer.pad_token_id) 68 | 69 | return collated 70 | 71 | @dataclass 72 | class DecodeCollator: 73 | 
data_args: DataArguments 74 | tokenizer: PreTrainedTokenizer 75 | processor: ProcessorMixin 76 | 77 | def __call__(self, features: List[Tuple[str, str]]): 78 | """ 79 | Collate function for encoding. 80 | :param features: list of (id, text) tuples 81 | """ 82 | query_ids = [f[0] for f in features] 83 | all_queries = [f[1] for f in features] 84 | all_answers = [f[2] for f in features] 85 | all_images = [f[3] for f in features] 86 | 87 | collated = defaultdict(list) 88 | pixel_values, image_sizes = [], [] 89 | for i, (query, images) in enumerate(zip(all_queries, all_images)): 90 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(images))]) 91 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}] 92 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 93 | processed = self.processor(prompt, images, return_tensors="pt") 94 | prompt_input_ids = processed['input_ids'] 95 | collated['input_ids'].append(prompt_input_ids) 96 | pixel_values.append(processed['pixel_values']) 97 | image_sizes.append(processed['image_sizes']) 98 | 99 | collated['input_ids'] = torch.cat(collated['input_ids'], dim=0) 100 | collated['pixel_values'] = torch.cat(pixel_values, dim=0) 101 | collated['image_sizes'] = torch.cat(image_sizes, dim=0) 102 | 103 | return query_ids, all_answers, collated -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | 4 | from datasets import load_dataset 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import os 8 | import json 9 | from vdocrag.vdocgenerator.arguments import DataArguments 10 | from scipy.special import softmax 11 | from collections import defaultdict 12 | import numpy as np 13 | from functools import partial 14 | 15 | import logging 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def format_query_for_QA(query): 20 | return query.split("Query: ")[-1].strip() + "\n Answer briefly." 
21 | 22 | def add_candidates(example, retrieved_docs, top_k): 23 | query_id = example["query_id"] 24 | candidates = retrieved_docs.get(query_id, [])[:top_k] 25 | example["candidates"] = candidates 26 | return example 27 | 28 | class TrainDataset(Dataset): 29 | def __init__(self, data_args: DataArguments, trainer = None): 30 | self.data_args = data_args 31 | self.train_data = load_dataset( 32 | self.data_args.dataset_name, 33 | self.data_args.dataset_config, 34 | data_files=self.data_args.dataset_path, 35 | split=self.data_args.dataset_split, 36 | cache_dir=self.data_args.dataset_cache_dir, 37 | ) 38 | 39 | self.corpus = load_dataset( 40 | self.data_args.corpus_name, 41 | self.data_args.corpus_config, 42 | data_files=self.data_args.corpus_path, 43 | split=self.data_args.corpus_split, 44 | cache_dir=self.data_args.dataset_cache_dir, 45 | ) 46 | 47 | self.docid2idx = {} 48 | for idx, doc_id in enumerate(self.corpus['doc_id']): 49 | self.docid2idx[str(doc_id)] = idx 50 | 51 | self.retrieved_docs = defaultdict(list) 52 | with open(self.data_args.retrieval_results_path) as f: 53 | lines = f.read().splitlines() 54 | for line in lines: 55 | query_id, doc_id, score = line.split() 56 | self.retrieved_docs[query_id].append(doc_id) 57 | 58 | self.train_data = self.train_data.map( 59 | partial(add_candidates, 60 | retrieved_docs=self.retrieved_docs, 61 | top_k=self.data_args.top_k) 62 | ) 63 | 64 | self.trainer = trainer 65 | 66 | def __len__(self): 67 | return len(self.train_data) 68 | 69 | def _get_image(self, doc_id): 70 | image = self.corpus[self.docid2idx[doc_id]]['image'] 71 | return image 72 | 73 | def __getitem__(self, item) -> Tuple[str, List[str]]: 74 | group = self.train_data[item] 75 | query = format_query_for_QA(group['query']) 76 | answer = group['answers'][0] 77 | images = [self._get_image(doc_id) for doc_id in group["candidates"]] 78 | 79 | return query, answer, images 80 | 81 | 82 | class DecodeDataset(Dataset): 83 | 84 | def __init__(self, data_args: DataArguments): 85 | self.data_args = data_args 86 | self.test_data = load_dataset( 87 | self.data_args.dataset_name, 88 | self.data_args.dataset_config, 89 | data_files=self.data_args.dataset_path, 90 | split=self.data_args.dataset_split, 91 | cache_dir=self.data_args.dataset_cache_dir, 92 | ) 93 | 94 | self.corpus = load_dataset( 95 | self.data_args.corpus_name, 96 | self.data_args.corpus_config, 97 | data_files=self.data_args.corpus_path, 98 | split=self.data_args.corpus_split, 99 | cache_dir=self.data_args.dataset_cache_dir, 100 | ) 101 | 102 | self.docid2idx = {} 103 | for idx, doc_id in enumerate(self.corpus['doc_id']): 104 | self.docid2idx[str(doc_id)] = idx 105 | 106 | self.retrieved_docs = defaultdict(list) 107 | with open(self.data_args.retrieval_results_path) as f: 108 | lines = f.read().splitlines() 109 | for line in lines: 110 | query_id, doc_id, score = line.split() 111 | self.retrieved_docs[query_id].append(doc_id) 112 | 113 | self.test_data = self.test_data.map( 114 | partial(add_candidates, 115 | retrieved_docs=self.retrieved_docs, 116 | top_k=self.data_args.top_k) 117 | ) 118 | 119 | def __len__(self): 120 | return len(self.test_data) 121 | 122 | def _get_image(self, doc_id): 123 | image = self.corpus[self.docid2idx[doc_id]]['image'] 124 | return image 125 | 126 | def __getitem__(self, item) -> Tuple[str, str]: 127 | data = self.test_data[item] 128 | query_id = data['query_id'] 129 | query = format_query_for_QA(data["query"]) 130 | answers = data['answers'] 131 | images = [self._get_image(doc_id) for doc_id in 
data["candidates"]] 132 | return query_id, query, answers, images 133 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/driver/generate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | import json 6 | from contextlib import nullcontext 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import time 13 | from torch.utils.data import DataLoader 14 | from transformers import AutoTokenizer, AutoProcessor 15 | from transformers import ( 16 | HfArgumentParser, 17 | ) 18 | 19 | from vdocrag.vdocgenerator.arguments import ModelArguments, DataArguments, \ 20 | VDocGeneratorTrainingArguments as TrainingArguments 21 | from vdocrag.vdocgenerator.dataset import DecodeDataset 22 | from vdocrag.vdocgenerator.collator import DecodeCollator 23 | from vdocrag.vdocgenerator.modeling import DecoderOutput, VDocGenerator 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 30 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 31 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 32 | else: 33 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 34 | model_args: ModelArguments 35 | data_args: DataArguments 36 | training_args: TrainingArguments 37 | 38 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 39 | raise NotImplementedError('Multi-GPU encoding is not supported.') 40 | 41 | # Setup logging 42 | logging.basicConfig( 43 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 44 | datefmt="%m/%d/%Y %H:%M:%S", 45 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 46 | ) 47 | 48 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 49 | cache_dir=model_args.cache_dir, 50 | trust_remote_code=True,) 51 | 52 | tokenizer = processor.tokenizer 53 | 54 | if tokenizer.pad_token_id is None: 55 | tokenizer.pad_token_id = tokenizer.eos_token_id 56 | tokenizer.padding_side = 'right' 57 | 58 | if training_args.bf16: 59 | torch_dtype = torch.bfloat16 60 | elif training_args.fp16: 61 | torch_dtype = torch.float16 62 | else: 63 | torch_dtype = torch.float32 64 | 65 | model = VDocGenerator.load( 66 | model_args.model_name_or_path, 67 | lora_name_or_path=model_args.lora_name_or_path, 68 | trust_remote_code=True, 69 | cache_dir=model_args.cache_dir, 70 | torch_dtype=torch_dtype, 71 | _attn_implementation='flash_attention_2', 72 | ) 73 | 74 | decode_dataset = DecodeDataset( 75 | data_args=data_args, 76 | ) 77 | 78 | decode_collator = DecodeCollator( 79 | data_args=data_args, 80 | tokenizer=tokenizer, 81 | processor=processor, 82 | ) 83 | 84 | decode_loader = DataLoader( 85 | decode_dataset, 86 | batch_size=training_args.per_device_eval_batch_size, 87 | collate_fn=decode_collator, 88 | shuffle=False, 89 | drop_last=False, 90 | num_workers=training_args.dataloader_num_workers, 91 | ) 92 | responses = {} 93 | model = model.to(training_args.device) 94 | model.eval() 95 | 96 | generation_args = { 97 | "max_new_tokens": 64, 98 | "temperature": 0.0, 99 | "do_sample": False, 100 | "eos_token_id": tokenizer.eos_token_id, 101 | } 102 | 103 | # TODO batch > 1 104 | for (batch_ids, answers, batch) in tqdm(decode_loader): 105 | with 
nullcontext(): 106 | with torch.no_grad(): 107 | for k, v in batch.items(): 108 | batch[k] = v.to(training_args.device) 109 | generate_ids = model.generate(batch, 110 | generation_args=generation_args, 111 | ) 112 | generate_ids = generate_ids[:, batch['input_ids'].shape[1]:] 113 | response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 114 | response = response.strip() 115 | responses[batch_ids[0]] = {"ground_truth": answers[0], "prediction": response} 116 | 117 | if not os.path.exists(os.path.dirname(data_args.output_path)): 118 | os.makedirs(os.path.dirname(data_args.output_path)) 119 | with open(data_args.output_path, 'w') as f: 120 | json.dump(responses, f) 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/driver/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import torch 5 | import wandb 6 | 7 | from transformers import AutoTokenizer 8 | from transformers import AutoProcessor 9 | 10 | from transformers import ( 11 | HfArgumentParser, 12 | set_seed, 13 | ) 14 | 15 | from vdocrag.vdocgenerator.arguments import ModelArguments, DataArguments, \ 16 | VDocGeneratorTrainingArguments as TrainingArguments 17 | from vdocrag.vdocgenerator.dataset import TrainDataset 18 | from vdocrag.vdocgenerator.collator import TrainCollator 19 | from vdocrag.vdocgenerator.modeling import VDocGenerator 20 | from vdocrag.vdocgenerator.trainer import VDocGeneratorTrainer as Trainer 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(): 26 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 27 | 28 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 29 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 30 | else: 31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 32 | model_args: ModelArguments 33 | data_args: DataArguments 34 | training_args: TrainingArguments 35 | 36 | if ( 37 | os.path.exists(training_args.output_dir) 38 | and os.listdir(training_args.output_dir) 39 | and training_args.do_train 40 | and not training_args.overwrite_output_dir 41 | ): 42 | raise ValueError( 43 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
44 | ) 45 | 46 | # Setup logging 47 | logging.basicConfig( 48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 49 | datefmt="%m/%d/%Y %H:%M:%S", 50 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 51 | ) 52 | logger.warning( 53 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 54 | training_args.local_rank, 55 | training_args.device, 56 | training_args.n_gpu, 57 | bool(training_args.local_rank != -1), 58 | training_args.fp16, 59 | ) 60 | logger.info("Training/evaluation parameters %s", training_args) 61 | logger.info("MODEL parameters %s", model_args) 62 | 63 | set_seed(training_args.seed) 64 | 65 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 66 | cache_dir=model_args.cache_dir, 67 | trust_remote_code=True) 68 | tokenizer = processor.tokenizer 69 | 70 | if tokenizer.pad_token_id is None: 71 | tokenizer.pad_token_id = tokenizer.eos_token_id 72 | tokenizer.padding_side = 'right' 73 | 74 | if training_args.bf16: 75 | torch_dtype = torch.bfloat16 76 | elif training_args.fp16: 77 | torch_dtype = torch.float16 78 | else: 79 | torch_dtype = torch.float32 80 | 81 | model = VDocGenerator.build( 82 | model_args, 83 | training_args, 84 | cache_dir=model_args.cache_dir, 85 | trust_remote_code=True, 86 | torch_dtype=torch_dtype, 87 | _attn_implementation='flash_attention_2', 88 | ) 89 | 90 | train_dataset = TrainDataset(data_args) 91 | collator = TrainCollator(data_args, tokenizer, processor) 92 | 93 | trainer_cls = Trainer 94 | 95 | trainer = trainer_cls( 96 | model=model, 97 | args=training_args, 98 | train_dataset=train_dataset, 99 | data_collator=collator 100 | ) 101 | train_dataset.trainer = trainer 102 | 103 | trainer.train() # TODO: resume training 104 | trainer.save_model() 105 | if trainer.is_world_process_zero(): 106 | tokenizer.save_pretrained(training_args.output_dir) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .vdocgenerator import VDocGenerator, DecoderOutput 2 | -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/modeling/vdocgenerator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch import nn, Tensor 7 | 8 | from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM 9 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel 10 | 11 | from transformers.file_utils import ModelOutput 12 | from vdocrag.vdocgenerator.arguments import ModelArguments, VDocGeneratorTrainingArguments as TrainingArguments 13 | 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @dataclass 19 | class DecoderOutput(ModelOutput): 20 | loss: Optional[Tensor] = None 21 | 22 | 23 | class VDocGenerator(nn.Module): 24 | TRANSFORMER_CLS = AutoModelForCausalLM 25 | 26 | def __init__(self, 27 | decoder: PreTrainedModel, 28 | ): 29 | super().__init__() 30 | self.config = decoder.config 31 | self.decoder = decoder 32 | self.is_ddp = dist.is_initialized() 33 | if self.is_ddp: 34 | self.process_rank = dist.get_rank() 35 | 
self.world_size = dist.get_world_size() 36 | 37 | def forward(self, inputs: Dict[str, Tensor] = None, use_cache: bool = True): 38 | outputs = self.decode(inputs, use_cache=use_cache) 39 | 40 | # for training 41 | if self.training: 42 | loss = outputs.loss 43 | 44 | if self.is_ddp: 45 | loss = loss * self.world_size # counter average weight reduction 46 | 47 | # for eval 48 | else: 49 | loss = None 50 | return DecoderOutput( 51 | loss=loss, 52 | ) 53 | 54 | def gradient_checkpointing_enable(self, **kwargs): 55 | self.decoder.model.gradient_checkpointing_enable() 56 | 57 | @classmethod 58 | def build( 59 | cls, 60 | model_args: ModelArguments, 61 | train_args: TrainingArguments, 62 | **hf_kwargs, 63 | ): 64 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 65 | if base_model.config.pad_token_id is None: 66 | base_model.config.pad_token_id = 0 67 | if model_args.lora or model_args.lora_name_or_path: 68 | if train_args.gradient_checkpointing: 69 | base_model.enable_input_require_grads() 70 | if model_args.lora_name_or_path: 71 | lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs) 72 | lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, is_trainable=True) 73 | else: 74 | lora_config = LoraConfig( 75 | base_model_name_or_path=model_args.model_name_or_path, 76 | task_type=TaskType.FEATURE_EXTRACTION, 77 | r=model_args.lora_r, 78 | lora_alpha=model_args.lora_alpha, 79 | lora_dropout=model_args.lora_dropout, 80 | target_modules=model_args.lora_target_modules.split(','), 81 | inference_mode=False 82 | ) 83 | lora_model = get_peft_model(base_model, lora_config) 84 | model = cls( 85 | decoder=lora_model, 86 | ) 87 | else: 88 | model = cls( 89 | decoder=base_model, 90 | ) 91 | return model 92 | 93 | @classmethod 94 | def load(cls, 95 | model_name_or_path: str, 96 | lora_name_or_path: str = None, 97 | **hf_kwargs): 98 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs) 99 | if base_model.config.pad_token_id is None: 100 | base_model.config.pad_token_id = 0 101 | if lora_name_or_path: 102 | lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs) 103 | lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config) 104 | lora_model = lora_model.merge_and_unload() 105 | model = cls( 106 | decoder=lora_model, 107 | ) 108 | else: 109 | model = cls( 110 | decoder=base_model, 111 | ) 112 | return model 113 | 114 | def save(self, output_dir: str): 115 | self.decoder.save_pretrained(output_dir) 116 | 117 | def decode(self, input, use_cache=True): 118 | return self.decoder(**input, use_cache=use_cache) 119 | 120 | def generate(self, input, generation_args, use_cache=True): 121 | return self.decoder.generate(**input, **generation_args, use_cache=use_cache) -------------------------------------------------------------------------------- /src/vdocrag/vdocgenerator/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from transformers.trainer import Trainer, TRAINING_ARGS_NAME 7 | import torch.distributed as dist 8 | from .modeling import VDocGenerator 9 | from huggingface_hub import login 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class VDocGeneratorTrainer(Trainer): 16 | def __init__(self, *args, **kwargs): 17 | super(VDocGeneratorTrainer, self).__init__(*args, **kwargs) 18 | 
self.is_ddp = dist.is_initialized() 19 | self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1 20 | 21 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 22 | # If we are executing this function, we are the process zero, so we don't check for that. 23 | output_dir = output_dir if output_dir is not None else self.args.output_dir 24 | os.makedirs(output_dir, exist_ok=True) 25 | logger.info(f"Saving model checkpoint to {output_dir}") 26 | 27 | supported_classes = (VDocGenerator,) 28 | # Save a trained model and configuration using `save_pretrained()`. 29 | # They can then be reloaded using `from_pretrained()` 30 | if not isinstance(self.model, supported_classes): 31 | raise ValueError(f"Unsupported model class {self.model}") 32 | else: 33 | if state_dict is None: 34 | state_dict = self.model.state_dict() 35 | prefix = 'decoder.' 36 | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) 37 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} 38 | self.model.decoder.save_pretrained( 39 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors 40 | ) 41 | 42 | if self.tokenizer is not None: 43 | self.tokenizer.save_pretrained(output_dir) 44 | 45 | # Good practice: save your training arguments together with the trained model 46 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 47 | 48 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 49 | return model(inputs).loss 50 | 51 | def training_step(self, *args): 52 | return super(VDocGeneratorTrainer, self).training_step(*args) / self._dist_loss_scale_factor 53 | 54 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttmdlab-nlp/VDocRAG/b438c325519c738f5121682cc9bef5c1fa859e0a/src/vdocrag/vdocretriever/__init__.py -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | from transformers import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: str = field( 10 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 11 | ) 12 | config_name: Optional[str] = field( 13 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 14 | ) 15 | tokenizer_name: Optional[str] = field( 16 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 17 | ) 18 | cache_dir: Optional[str] = field( 19 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 20 | ) 21 | 22 | pooling: str = field( 23 | default='cls', 24 | metadata={"help": "pooling method for query and document encoder"} 25 | ) 26 | normalize: bool = field( 27 | default=False, 28 | metadata={"help": "normalize query and document representations"} 29 | ) 30 | 31 | temperature: float = field( 32 | default=1.0, 33 | metadata={"help": "temperature for softmax"} 34 | ) 35 | 36 | # for lora 37 | lora: bool = field(default=False, 38 | metadata={"help": "do parameter-efficient fine-tuning with lora"} 39 | ) 40 | 41 | 
lora_name_or_path: Optional[str] = field( 42 | default=None, metadata={"help": "Path to pretrained lora model or model identifier from huggingface.co/models"} 43 | ) 44 | 45 | lora_r: int = field( 46 | default=8, 47 | metadata={"help": "lora r"} 48 | ) 49 | 50 | lora_alpha: int = field( 51 | default=64, 52 | metadata={"help": "lora alpha"} 53 | ) 54 | 55 | lora_dropout: float = field( 56 | default=0.1, 57 | metadata={"help": "lora dropout"} 58 | ) 59 | 60 | lora_target_modules: str = field( 61 | default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj", 62 | metadata={"help": "lora target modules"} 63 | ) 64 | 65 | dtype: Optional[str] = field( 66 | default="float32", 67 | metadata={ 68 | "help": "Floating-point format in which the model weights should be initialized and trained. Choose one " 69 | "of `[float32, float16, bfloat16]`. " 70 | }, 71 | ) 72 | 73 | @dataclass 74 | class DataArguments: 75 | dataset_path: str = field( 76 | default='json', metadata={"help": "dataset path"} 77 | ) 78 | dataset_name: str = field( 79 | default='NTT-hil-insight/OpenDocVQA', metadata={"help": "huggingface dataset name"} 80 | ) 81 | dataset_config: str = field( 82 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"} 83 | ) 84 | 85 | dataset_path: str = field( 86 | default=None, metadata={"help": "Path to local data files or directory"} 87 | ) 88 | 89 | dataset_split: str = field( 90 | default='train', metadata={"help": "dataset split"} 91 | ) 92 | 93 | dataset_cache_dir: Optional[str] = field( 94 | default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"} 95 | ) 96 | 97 | corpus_name: str = field( 98 | default='NTT-hil-insight/OpenDocVQA-Corpus', metadata={"help": "huggingface dataset name"} 99 | ) 100 | 101 | corpus_config: str = field( 102 | default=None, metadata={"help": "huggingface dataset config, useful for datasets with sub-datasets"} 103 | ) 104 | 105 | corpus_path: str = field( 106 | default=None, metadata={"help": "Path to local data files or directory"} 107 | ) 108 | 109 | corpus_split: str = field( 110 | default='train', metadata={"help": "dataset split"} 111 | ) 112 | 113 | dataset_number_of_shards: int = field( 114 | default=1, metadata={"help": "number of shards to split the dataset into"} 115 | ) 116 | 117 | dataset_shard_index: int = field( 118 | default=0, metadata={"help": "shard index to use, to be used with dataset_number_of_shards"} 119 | ) 120 | 121 | train_group_size: int = field( 122 | default=8, metadata={"help": "number of documents used to train for each query"} 123 | ) 124 | positive_document_no_shuffle: bool = field( 125 | default=False, metadata={"help": "always use the first positive document for training"}) 126 | 127 | image_attention_mask: bool = field( 128 | default=False, metadata={"help": "custom attention mask for RCG task"}) 129 | 130 | pretrain: bool = field( 131 | default=False, metadata={"help": "whether pre-training is executed or not"}) 132 | 133 | encode_is_query: bool = field(default=False) 134 | 135 | encode_output_path: str = field(default=None, metadata={"help": "where to save the encode"}) 136 | 137 | query_max_len: Optional[int] = field( 138 | default=32, 139 | metadata={ 140 | "help": "The maximum total input sequence length after tokenization for query. Sequences longer " 141 | "than this will be truncated, sequences shorter will be padded." 
142 | }, 143 | ) 144 | answer_max_len: Optional[int] = field( 145 | default=128, 146 | metadata={ 147 | "help": "The maximum total input sequence length after tokenization for document. Sequences longer " 148 | "than this will be truncated, sequences shorter will be padded." 149 | }, 150 | ) 151 | 152 | append_eos_token: bool = field( 153 | default=False, metadata={"help": "append eos token to query and document, this is currently used for repllama"} 154 | ) 155 | 156 | pad_to_multiple_of: Optional[int] = field( 157 | default=16, 158 | metadata={ 159 | "help": "If set will pad the sequence to a multiple of the provided value. This is especially useful to " 160 | "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)." 161 | }, 162 | ) 163 | 164 | 165 | @dataclass 166 | class VDocRetrieverTrainingArguments(TrainingArguments): 167 | warmup_ratio: float = field(default=0.1) 168 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from typing import List, Tuple 4 | from dataclasses import dataclass 5 | from transformers import PreTrainedTokenizer, ProcessorMixin 6 | from vdocrag.vdocretriever.arguments import DataArguments 7 | from collections import defaultdict 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass 13 | class TrainCollator: 14 | data_args: DataArguments 15 | tokenizer: PreTrainedTokenizer 16 | processor: ProcessorMixin 17 | 18 | def build_image_attention_mask(self, seq_len, input_lengths): 19 | image_attention_masks = [] 20 | for input_len in input_lengths: 21 | image_attention_mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0) 22 | image_attention_mask[input_len:, :input_len-1] = 0 23 | image_attention_masks.append(image_attention_mask.unsqueeze(0)) 24 | image_attention_masks = torch.cat(image_attention_masks, dim=0) 25 | return image_attention_masks 26 | 27 | def __call__(self, features: List[Tuple[str, List[str]]]): 28 | all_queries = [f[0] for f in features] 29 | all_images = [f[-1] for f in features] 30 | 31 | q_collated = self.tokenizer( 32 | all_queries, 33 | padding=False, 34 | truncation=True, 35 | max_length=self.data_args.query_max_len-1 if self.data_args.append_eos_token else self.data_args.query_max_len, 36 | return_attention_mask=False, 37 | return_token_type_ids=False, 38 | add_special_tokens=True, 39 | ) 40 | 41 | d_collated = {} 42 | collated_list = [self.processor("<|image_1|>\nWhat is shown in this image?", image, return_tensors="pt") for image in all_images] 43 | d_collated['input_ids'] = [d['input_ids'][0].tolist() for d in collated_list] 44 | 45 | if self.data_args.append_eos_token: 46 | q_collated['input_ids'] = [q + [self.tokenizer.eos_token_id] for q in q_collated['input_ids']] 47 | d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']] 48 | 49 | if self.data_args.pretrain: 50 | p_collated = {} 51 | all_input_ids, all_label_ids, input_lengths = [], [], [] 52 | 53 | for i, ocr in enumerate(all_queries): 54 | prompt_input_ids = torch.tensor(d_collated['input_ids'][i]).unsqueeze(0) 55 | answer = f'{ocr}<|end|>\n<|endoftext|>' 56 | answer_input_ids = self.tokenizer( 57 | answer, add_special_tokens=False, max_length=self.data_args.answer_max_len, truncation=True, return_tensors='pt')['input_ids'] 58 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 59 
| labels = torch.cat( 60 | [ 61 | torch.tensor([-100] * len(prompt_input_ids[0])).unsqueeze(0), 62 | answer_input_ids, 63 | ], 64 | dim=1, 65 | ) 66 | all_input_ids.append(input_ids.squeeze(0).unsqueeze(1)) 67 | all_label_ids.append(labels.squeeze(0).unsqueeze(1)) 68 | input_lengths.append(len(prompt_input_ids[0])) 69 | 70 | input_ids = torch._C._nn.pad_sequence( 71 | all_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 72 | ).squeeze(2) 73 | labels = torch._C._nn.pad_sequence( 74 | all_label_ids, batch_first=True, padding_value=-100 75 | ).squeeze(2) 76 | 77 | p_collated['input_ids'] = input_ids 78 | p_collated['labels'] = labels 79 | 80 | if self.data_args.image_attention_mask: 81 | image_attention_mask = self.build_image_attention_mask(input_ids.size()[1], input_lengths) 82 | p_collated['attention_mask'] = image_attention_mask.unsqueeze(1) 83 | else: 84 | p_collated = None 85 | 86 | q_collated = self.tokenizer.pad( 87 | q_collated, 88 | padding=True, 89 | pad_to_multiple_of=self.data_args.pad_to_multiple_of, 90 | return_attention_mask=True, 91 | return_tensors='pt', 92 | ) 93 | d_collated = self.tokenizer.pad( 94 | d_collated, 95 | padding=True, 96 | pad_to_multiple_of=self.data_args.pad_to_multiple_of, 97 | return_attention_mask=True, 98 | return_tensors='pt', 99 | ) 100 | 101 | d_collated['pixel_values'] = torch.stack([d['pixel_values'][0] for d in collated_list], dim=0) 102 | d_collated['image_sizes'] = torch.stack([d['image_sizes'][0] for d in collated_list], dim=0) 103 | if self.data_args.pretrain: 104 | p_collated['pixel_values'] = d_collated['pixel_values'] 105 | p_collated['image_sizes'] = d_collated['image_sizes'] 106 | 107 | return q_collated, d_collated, p_collated 108 | 109 | @dataclass 110 | class EncodeCollator: 111 | data_args: DataArguments 112 | tokenizer: PreTrainedTokenizer 113 | processor: ProcessorMixin 114 | 115 | def __call__(self, features: List[Tuple[str, str]]): 116 | text_ids = [x[0] for x in features] 117 | texts = [x[1] for x in features] 118 | images = [x[-1] for x in features] 119 | 120 | if self.data_args.encode_is_query: 121 | collated = self.tokenizer( 122 | texts, 123 | padding=False, 124 | truncation=True, 125 | max_length=self.data_args.query_max_len-1 if self.data_args.append_eos_token else self.data_args.query_max_len, 126 | return_attention_mask=False, 127 | return_token_type_ids=False, 128 | add_special_tokens=True, 129 | ) 130 | else: 131 | collated = {} 132 | collated_list = [self.processor("<|image_1|>\nWhat is shown in this image?", image, return_tensors="pt") for image in images] 133 | collated['input_ids'] = [d['input_ids'][0].tolist() for d in collated_list] 134 | 135 | if self.data_args.append_eos_token: 136 | collated['input_ids'] = [x + [self.tokenizer.eos_token_id] for x in collated['input_ids']] 137 | 138 | collated = self.tokenizer.pad( 139 | collated, 140 | padding=True, 141 | pad_to_multiple_of=self.data_args.pad_to_multiple_of, 142 | return_attention_mask=True, 143 | return_tensors='pt', 144 | ) 145 | if not self.data_args.encode_is_query: 146 | collated['pixel_values'] = torch.stack([d['pixel_values'][0] for d in collated_list], dim=0) 147 | collated['image_sizes'] = torch.stack([d['image_sizes'][0] for d in collated_list], dim=0) 148 | 149 | return text_ids, collated -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | 
from datasets import load_dataset 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import os 7 | import json 8 | from vdocrag.vdocretriever.arguments import DataArguments 9 | 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class TrainDataset(Dataset): 15 | def __init__(self, data_args: DataArguments, trainer = None): 16 | self.data_args = data_args 17 | self.train_data = load_dataset( 18 | self.data_args.dataset_name, 19 | self.data_args.dataset_config, 20 | data_files=self.data_args.dataset_path, 21 | split=self.data_args.dataset_split, 22 | cache_dir=self.data_args.dataset_cache_dir, 23 | ) 24 | if not self.data_args.pretrain: 25 | self.corpus = load_dataset( 26 | self.data_args.corpus_name, 27 | self.data_args.corpus_config, 28 | data_files=self.data_args.corpus_path, 29 | split=self.data_args.corpus_split, 30 | cache_dir=self.data_args.dataset_cache_dir, 31 | ) 32 | 33 | self.docid2idx = {} 34 | if 'doc_id' in self.corpus.features: 35 | for idx, docid in enumerate(self.corpus['doc_id']): 36 | self.docid2idx[str(docid)] = idx 37 | else: 38 | for idx in range(len(self.corpus)): 39 | self.docid2idx[str(idx)] = idx 40 | 41 | self.trainer = trainer 42 | 43 | def __len__(self): 44 | return len(self.train_data) 45 | 46 | def __getitem__(self, item) -> Tuple[str, List[str]]: 47 | group = self.train_data[item] 48 | epoch = int(self.trainer.state.epoch) 49 | 50 | _hashed_seed = hash(item + self.trainer.args.seed) 51 | 52 | query = group['query'] 53 | if self.data_args.pretrain: 54 | image = group['image'] 55 | else: 56 | relevant_docids = group['relevant_doc_ids'] 57 | 58 | if self.data_args.positive_document_no_shuffle: 59 | docid = relevant_docids[0] 60 | else: 61 | docid = relevant_docids[(_hashed_seed + epoch) % len(relevant_docids)] 62 | 63 | image = image = self.corpus[self.docid2idx[docid]]['image'] 64 | 65 | return query, image 66 | 67 | 68 | class EncodeDataset(Dataset): 69 | def __init__(self, data_args: DataArguments): 70 | self.data_args = data_args 71 | if self.data_args.encode_is_query: 72 | self.encode_data = load_dataset( 73 | self.data_args.dataset_name, 74 | self.data_args.dataset_config, 75 | data_files=self.data_args.dataset_path, 76 | split=self.data_args.dataset_split, 77 | cache_dir=self.data_args.dataset_cache_dir, 78 | ) 79 | else: 80 | self.encode_data = load_dataset( 81 | self.data_args.corpus_name, 82 | self.data_args.corpus_config, 83 | data_files=self.data_args.corpus_path, 84 | split=self.data_args.corpus_split, 85 | cache_dir=self.data_args.dataset_cache_dir, 86 | ) 87 | 88 | if self.data_args.dataset_number_of_shards > 1: 89 | self.encode_data = self.encode_data.shard( 90 | num_shards=self.data_args.dataset_number_of_shards, 91 | index=self.data_args.dataset_shard_index, 92 | ) 93 | 94 | def __len__(self): 95 | return len(self.encode_data) 96 | 97 | def __getitem__(self, item) -> Tuple[str, str]: 98 | data = self.encode_data[item] 99 | text, image = None, None 100 | if self.data_args.encode_is_query: 101 | id = data['query_id'] 102 | text = data['query'] 103 | else: 104 | id = data['doc_id'] 105 | image = data['image'] 106 | return id, text, image 107 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/driver/encode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 
| 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoTokenizer, AutoProcessor 14 | from transformers import ( 15 | HfArgumentParser, 16 | ) 17 | 18 | from vdocrag.vdocretriever.arguments import ModelArguments, DataArguments, \ 19 | VDocRetrieverTrainingArguments as TrainingArguments 20 | from vdocrag.vdocretriever.dataset import EncodeDataset 21 | from vdocrag.vdocretriever.collator import EncodeCollator 22 | from vdocrag.vdocretriever.modeling import EncoderOutput, VDocRetriever 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(): 28 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 29 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 30 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 31 | else: 32 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 33 | model_args: ModelArguments 34 | data_args: DataArguments 35 | training_args: TrainingArguments 36 | 37 | if training_args.local_rank > 0 or training_args.n_gpu > 1: 38 | raise NotImplementedError('Multi-GPU encoding is not supported.') 39 | 40 | # Setup logging 41 | logging.basicConfig( 42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 43 | datefmt="%m/%d/%Y %H:%M:%S", 44 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 45 | ) 46 | 47 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 48 | cache_dir=model_args.cache_dir, 49 | trust_remote_code=True,) 50 | tokenizer = processor.tokenizer 51 | 52 | if tokenizer.pad_token_id is None: 53 | tokenizer.pad_token_id = tokenizer.eos_token_id 54 | tokenizer.padding_side = 'right' 55 | 56 | if training_args.bf16: 57 | torch_dtype = torch.bfloat16 58 | elif training_args.fp16: 59 | torch_dtype = torch.float16 60 | else: 61 | torch_dtype = torch.float32 62 | 63 | model = VDocRetriever.load( 64 | model_args.model_name_or_path, 65 | pooling=model_args.pooling, 66 | normalize=model_args.normalize, 67 | lora_name_or_path=model_args.lora_name_or_path, 68 | trust_remote_code=True, 69 | cache_dir=model_args.cache_dir, 70 | torch_dtype=torch_dtype, 71 | _attn_implementation='flash_attention_2', 72 | ) 73 | 74 | encode_dataset = EncodeDataset( 75 | data_args=data_args, 76 | ) 77 | 78 | encode_collator = EncodeCollator( 79 | data_args=data_args, 80 | tokenizer=tokenizer, 81 | processor=processor, 82 | ) 83 | 84 | encode_loader = DataLoader( 85 | encode_dataset, 86 | batch_size=training_args.per_device_eval_batch_size, 87 | collate_fn=encode_collator, 88 | shuffle=False, 89 | drop_last=False, 90 | num_workers=training_args.dataloader_num_workers, 91 | ) 92 | encoded = [] 93 | lookup_indices = [] 94 | model = model.to(training_args.device) 95 | model.eval() 96 | 97 | for (batch_ids, batch) in tqdm(encode_loader): 98 | lookup_indices.extend(batch_ids) 99 | with nullcontext(): 100 | with torch.no_grad(): 101 | for k, v in batch.items(): 102 | batch[k] = v.to(training_args.device) 103 | if data_args.encode_is_query: 104 | model_output: EncoderOutput = model(query=batch, use_cache=False) 105 | encoded.append(model_output.q_reps.cpu().detach().float().numpy()) 106 | else: 107 | model_output: EncoderOutput = model(document=batch, use_cache=False) 108 | encoded.append(model_output.p_reps.cpu().detach().float().numpy()) 109 | 110 | encoded = np.concatenate(encoded) 111 | if not 
os.path.exists(os.path.dirname(data_args.encode_output_path)): 112 | os.makedirs(os.path.dirname(data_args.encode_output_path)) 113 | with open(data_args.encode_output_path, 'wb') as f: 114 | pickle.dump((encoded, lookup_indices), f) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/driver/search.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import glob 5 | from argparse import ArgumentParser 6 | from itertools import chain 7 | from tqdm import tqdm 8 | 9 | from vdocrag.vdocretriever.searcher import FaissFlatSearcher 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig( 14 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S", 16 | level=logging.INFO, 17 | ) 18 | 19 | 20 | def search_queries(retriever, q_reps, p_lookup, args): 21 | if args.batch_size > 0: 22 | all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size, args.quiet) 23 | else: 24 | all_scores, all_indices = retriever.search(q_reps, args.depth) 25 | 26 | psg_indices = [[str(p_lookup[x]) for x in q_dd] for q_dd in all_indices] 27 | psg_indices = np.array(psg_indices) 28 | return all_scores, psg_indices 29 | 30 | 31 | def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): 32 | with open(ranking_save_file, 'w') as f: 33 | for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): 34 | score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)] 35 | score_list = sorted(score_list, key=lambda x: x[0], reverse=True) 36 | for s, idx in score_list: 37 | f.write(f'{qid}\t{idx}\t{s}\n') 38 | 39 | 40 | def pickle_load(path): 41 | with open(path, 'rb') as f: 42 | reps, lookup = pickle.load(f) 43 | return np.array(reps), lookup 44 | 45 | 46 | def pickle_save(obj, path): 47 | with open(path, 'wb') as f: 48 | pickle.dump(obj, f) 49 | 50 | 51 | def main(): 52 | parser = ArgumentParser() 53 | parser.add_argument('--query_reps', required=True) 54 | parser.add_argument('--document_reps', required=True) 55 | parser.add_argument('--batch_size', type=int, default=128) 56 | parser.add_argument('--depth', type=int, default=1000) 57 | parser.add_argument('--save_ranking_to', required=True) 58 | parser.add_argument('--save_text', action='store_true') 59 | parser.add_argument('--quiet', action='store_true') 60 | 61 | args = parser.parse_args() 62 | 63 | index_files = glob.glob(args.document_reps) 64 | logger.info(f'Pattern match found {len(index_files)} files; loading them into index.') 65 | 66 | p_reps_0, p_lookup_0 = pickle_load(index_files[0]) 67 | retriever = FaissFlatSearcher(p_reps_0) 68 | 69 | shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:])) 70 | if len(index_files) > 1: 71 | shards = tqdm(shards, desc='Loading shards into index', total=len(index_files)) 72 | look_up = [] 73 | for p_reps, p_lookup in shards: 74 | retriever.add(p_reps) 75 | look_up += p_lookup 76 | 77 | q_reps, q_lookup = pickle_load(args.query_reps) 78 | q_reps = q_reps 79 | 80 | logger.info('Index Search Start') 81 | all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) 82 | logger.info('Index Search Finished') 83 | 84 | if args.save_text: 85 | write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) 86 | else: 87 | pickle_save((all_scores, 
psg_indices), args.save_ranking_to) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/driver/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import torch 5 | import wandb 6 | 7 | from transformers import AutoTokenizer 8 | from transformers import AutoProcessor 9 | 10 | from transformers import ( 11 | HfArgumentParser, 12 | set_seed, 13 | ) 14 | 15 | from vdocrag.vdocretriever.arguments import ModelArguments, DataArguments, \ 16 | VDocRetrieverTrainingArguments as TrainingArguments 17 | from vdocrag.vdocretriever.dataset import TrainDataset 18 | from vdocrag.vdocretriever.collator import TrainCollator 19 | from vdocrag.vdocretriever.modeling import VDocRetriever 20 | from vdocrag.vdocretriever.trainer import VDocRetrieverTrainer as Trainer 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(): 26 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 27 | 28 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 29 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 30 | else: 31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 32 | model_args: ModelArguments 33 | data_args: DataArguments 34 | training_args: TrainingArguments 35 | 36 | if ( 37 | os.path.exists(training_args.output_dir) 38 | and os.listdir(training_args.output_dir) 39 | and training_args.do_train 40 | and not training_args.overwrite_output_dir 41 | ): 42 | raise ValueError( 43 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
44 | ) 45 | 46 | # Setup logging 47 | logging.basicConfig( 48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 49 | datefmt="%m/%d/%Y %H:%M:%S", 50 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 51 | ) 52 | logger.warning( 53 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 54 | training_args.local_rank, 55 | training_args.device, 56 | training_args.n_gpu, 57 | bool(training_args.local_rank != -1), 58 | training_args.fp16, 59 | ) 60 | logger.info("Training/evaluation parameters %s", training_args) 61 | logger.info("MODEL parameters %s", model_args) 62 | 63 | set_seed(training_args.seed) 64 | 65 | processor = AutoProcessor.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 66 | cache_dir=model_args.cache_dir, 67 | trust_remote_code=True) 68 | tokenizer = processor.tokenizer 69 | 70 | if training_args.bf16: 71 | torch_dtype = torch.bfloat16 72 | elif training_args.fp16: 73 | torch_dtype = torch.float16 74 | else: 75 | torch_dtype = torch.float32 76 | 77 | model = VDocRetriever.build( 78 | model_args, 79 | training_args, 80 | cache_dir=model_args.cache_dir, 81 | trust_remote_code=True, 82 | torch_dtype=torch_dtype, 83 | _attn_implementation='eager' if data_args.pretrain else 'flash_attention_2', 84 | ) 85 | 86 | train_dataset = TrainDataset(data_args) 87 | collator = TrainCollator(data_args, tokenizer, processor) 88 | 89 | trainer_cls = Trainer 90 | 91 | trainer = trainer_cls( 92 | model=model, 93 | args=training_args, 94 | train_dataset=train_dataset, 95 | data_collator=collator 96 | ) 97 | train_dataset.trainer = trainer 98 | 99 | trainer.train() # TODO: resume training 100 | trainer.save_model() 101 | if trainer.is_world_process_zero(): 102 | tokenizer.save_pretrained(training_args.output_dir) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .vdocretriever import VDocRetriever, EncoderOutput 2 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/modeling/vdocretriever.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch import nn, Tensor 7 | 8 | from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq 9 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel 10 | 11 | from transformers.file_utils import ModelOutput 12 | from vdocrag.vdocretriever.arguments import ModelArguments, VDocRetrieverTrainingArguments as TrainingArguments 13 | 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @dataclass 19 | class EncoderOutput(ModelOutput): 20 | q_reps: Optional[Tensor] = None 21 | p_reps: Optional[Tensor] = None 22 | loss: Optional[Tensor] = None 23 | scores: Optional[Tensor] = None 24 | 25 | class VDocRetriever(nn.Module): 26 | TRANSFORMER_CLS = AutoModelForCausalLM 27 | 28 | def __init__(self, 29 | encoder: PreTrainedModel, 30 | pooling: str = 'cls', 31 | normalize: bool = False, 32 | temperature: float = 1.0, 33 | ): 34 | super().__init__() 35 | self.config = encoder.config 36 | self.encoder = encoder 37 | 
self.pooling = pooling 38 | self.normalize = normalize 39 | self.temperature = temperature 40 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 41 | self.is_ddp = dist.is_initialized() 42 | if self.is_ddp: 43 | self.process_rank = dist.get_rank() 44 | self.world_size = dist.get_world_size() 45 | 46 | def forward(self, query: Dict[str, Tensor] = None, 47 | document: Dict[str, Tensor] = None, 48 | pair: Dict[str, Tensor] = None, 49 | use_cache: bool = True 50 | ): 51 | q_reps = self.encode_query(query, use_cache=use_cache) if query else None 52 | p_reps = self.encode_document(document, use_cache=use_cache) if document else None 53 | outputs = self.generate_output(pair, use_cache=use_cache) if pair else None # pre-training 54 | 55 | # for inference 56 | if q_reps is None or p_reps is None: 57 | return EncoderOutput( 58 | q_reps=q_reps, 59 | p_reps=p_reps 60 | ) 61 | 62 | # for training 63 | if self.training: 64 | if self.is_ddp: 65 | q_reps = self._dist_gather_tensor(q_reps) 66 | p_reps = self._dist_gather_tensor(p_reps) 67 | 68 | scores = self.compute_similarity(q_reps, p_reps) 69 | scores = scores.view(q_reps.size(0), -1) 70 | 71 | target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) 72 | target = target * (p_reps.size(0) // q_reps.size(0)) 73 | 74 | loss = self.compute_loss(scores / self.temperature, target) 75 | 76 | if outputs: 77 | loss = loss + outputs.loss 78 | 79 | if self.is_ddp: 80 | loss = loss * self.world_size # counter average weight reduction 81 | 82 | # for eval 83 | else: 84 | scores = self.compute_similarity(q_reps, p_reps) 85 | loss = None 86 | 87 | return EncoderOutput( 88 | loss=loss, 89 | scores=scores, 90 | q_reps=q_reps, 91 | p_reps=p_reps, 92 | ) 93 | 94 | def compute_similarity(self, q_reps, p_reps): 95 | return torch.matmul(q_reps, p_reps.transpose(0, 1)) 96 | 97 | def compute_loss(self, scores, target): 98 | return self.cross_entropy(scores, target) 99 | 100 | def gradient_checkpointing_enable(self, **kwargs): 101 | self.encoder.model.gradient_checkpointing_enable() 102 | 103 | def _dist_gather_tensor(self, t: Optional[torch.Tensor]): 104 | if t is None: 105 | return None 106 | t = t.contiguous() 107 | 108 | all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] 109 | dist.all_gather(all_tensors, t) 110 | 111 | all_tensors[self.process_rank] = t 112 | all_tensors = torch.cat(all_tensors, dim=0) 113 | 114 | return all_tensors 115 | 116 | @classmethod 117 | def build( 118 | cls, 119 | model_args: ModelArguments, 120 | train_args: TrainingArguments, 121 | **hf_kwargs, 122 | ): 123 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs) 124 | if base_model.config.pad_token_id is None: 125 | base_model.config.pad_token_id = 0 126 | 127 | if model_args.lora or model_args.lora_name_or_path: 128 | if train_args.gradient_checkpointing: 129 | base_model.enable_input_require_grads() 130 | if model_args.lora_name_or_path: 131 | lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs) 132 | lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, is_trainable=True) 133 | else: 134 | lora_config = LoraConfig( 135 | base_model_name_or_path=model_args.model_name_or_path, 136 | task_type=TaskType.FEATURE_EXTRACTION, 137 | r=model_args.lora_r, 138 | lora_alpha=model_args.lora_alpha, 139 | lora_dropout=model_args.lora_dropout, 140 | target_modules=model_args.lora_target_modules.split(','), 141 | inference_mode=False 142 | ) 143 | lora_model = 
get_peft_model(base_model, lora_config) 144 | model = cls( 145 | encoder=lora_model, 146 | pooling=model_args.pooling, 147 | normalize=model_args.normalize, 148 | temperature=model_args.temperature, 149 | ) 150 | else: 151 | model = cls( 152 | encoder=base_model, 153 | pooling=model_args.pooling, 154 | normalize=model_args.normalize, 155 | temperature=model_args.temperature 156 | ) 157 | return model 158 | 159 | @classmethod 160 | def load(cls, 161 | model_name_or_path: str, 162 | pooling: str = 'cls', 163 | normalize: bool = False, 164 | lora_name_or_path: str = None, 165 | **hf_kwargs): 166 | base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs) 167 | if base_model.config.pad_token_id is None: 168 | base_model.config.pad_token_id = 0 169 | if lora_name_or_path: 170 | lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs) 171 | lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config) 172 | lora_model = lora_model.merge_and_unload() 173 | model = cls( 174 | encoder=lora_model, 175 | pooling=pooling, 176 | normalize=normalize 177 | ) 178 | else: 179 | model = cls( 180 | encoder=base_model, 181 | pooling=pooling, 182 | normalize=normalize 183 | ) 184 | return model 185 | 186 | def save(self, output_dir: str): 187 | self.encoder.save_pretrained(output_dir) 188 | 189 | def encode_query(self, qry, use_cache=True): 190 | query_hidden_states = self.encoder(**qry, return_dict=True, output_hidden_states=True, output_attentions=True, use_cache=use_cache) 191 | query_hidden_states = query_hidden_states.hidden_states[-1] 192 | return self._pooling(query_hidden_states, qry['attention_mask']) 193 | 194 | def encode_document(self, doc, use_cache=True): 195 | return self.encode_query(doc, use_cache=use_cache) 196 | 197 | def generate_output(self, pair, use_cache=True): 198 | return self.encoder(**pair, use_cache=use_cache) 199 | 200 | def _pooling(self, last_hidden_state, attention_mask): 201 | if self.pooling in ['cls', 'first']: 202 | reps = last_hidden_state[:, 0] 203 | elif self.pooling in ['mean', 'avg', 'average']: 204 | masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) 205 | reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 206 | elif self.pooling in ['last', 'eos']: 207 | sequence_lengths = attention_mask.sum(dim=1) - 1 208 | batch_size = last_hidden_state.shape[0] 209 | reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths] 210 | else: 211 | raise ValueError(f'unknown pooling method: {self.pooling}') 212 | if self.normalize: 213 | reps = torch.nn.functional.normalize(reps, p=2, dim=-1) 214 | return reps -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/searcher.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class FaissFlatSearcher: 12 | def __init__(self, init_reps: np.ndarray): 13 | index = faiss.IndexFlatIP(init_reps.shape[1]) 14 | self.index = index 15 | 16 | def add(self, p_reps: np.ndarray): 17 | self.index.add(p_reps) 18 | 19 | def search(self, q_reps: np.ndarray, k: int): 20 | return self.index.search(q_reps, k) 21 | 22 | def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int, quiet: bool=False): 23 | num_query = q_reps.shape[0] 24 | all_scores = [] 25 
| all_indices = [] 26 | for start_idx in tqdm(range(0, num_query, batch_size), disable=quiet): 27 | nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) 28 | all_scores.append(nn_scores) 29 | all_indices.append(nn_indices) 30 | all_scores = np.concatenate(all_scores, axis=0) 31 | all_indices = np.concatenate(all_indices, axis=0) 32 | 33 | return all_scores, all_indices 34 | 35 | 36 | class FaissSearcher(FaissFlatSearcher): 37 | 38 | def __init__(self, init_reps: np.ndarray, factory_str: str): 39 | index = faiss.index_factory(init_reps.shape[1], factory_str) 40 | self.index = index 41 | self.index.verbose = True 42 | if not self.index.is_trained: 43 | self.index.train(init_reps) 44 | -------------------------------------------------------------------------------- /src/vdocrag/vdocretriever/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from transformers.trainer import Trainer, TRAINING_ARGS_NAME 7 | import torch.distributed as dist 8 | from .modeling import VDocRetriever 9 | from huggingface_hub import login 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class VDocRetrieverTrainer(Trainer): 16 | def __init__(self, *args, **kwargs): 17 | super(VDocRetrieverTrainer, self).__init__(*args, **kwargs) 18 | self.is_ddp = dist.is_initialized() 19 | self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1 20 | 21 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 22 | # If we are executing this function, we are the process zero, so we don't check for that. 23 | output_dir = output_dir if output_dir is not None else self.args.output_dir 24 | os.makedirs(output_dir, exist_ok=True) 25 | logger.info(f"Saving model checkpoint to {output_dir}") 26 | 27 | supported_classes = (VDocRetriever,) 28 | # Save a trained model and configuration using `save_pretrained()`. 29 | # They can then be reloaded using `from_pretrained()` 30 | if not isinstance(self.model, supported_classes): 31 | raise ValueError(f"Unsupported model class {self.model}") 32 | else: 33 | if state_dict is None: 34 | state_dict = self.model.state_dict() 35 | prefix = 'encoder.' 
36 | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) 37 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} 38 | self.model.encoder.save_pretrained( 39 | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors 40 | ) 41 | 42 | if self.tokenizer is not None: 43 | self.tokenizer.save_pretrained(output_dir) 44 | 45 | # Good practice: save your training arguments together with the trained model 46 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 47 | 48 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 49 | query, document, pair = inputs 50 | return model(query=query, document=document, pair=pair).loss 51 | 52 | def training_step(self, *args): 53 | return super(VDocRetrieverTrainer, self).training_step(*args) / self._dist_loss_scale_factor 54 | 55 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import requests 3 | from io import BytesIO 4 | from torch.nn.functional import cosine_similarity 5 | import torch 6 | from transformers import AutoProcessor 7 | from vdocrag.vdocretriever.modeling import VDocRetriever 8 | from vdocrag.vdocgenerator.modeling import VDocGenerator 9 | 10 | ### Retrieval ### 11 | 12 | processor = AutoProcessor.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True) 13 | model = VDocRetriever.load('microsoft/Phi-3-vision-128k-instruct', lora_name_or_path='NTT-hil-insight/VDocRetriever-Phi3-vision', pooling='eos', normalize=True, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0') 14 | 15 | # Process query inputs and get embeddings 16 | queries = ["Instruct: I’m looking for an image that answers the question.\nQuery: What is the total percentage of Palestinians residing at West Bank?", 17 | "Instruct: I’m looking for an image that answers the question.\nQuery: How many international visitors came to Japan in 2017?"] 18 | query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=256, truncation=True).to('cuda:0') 19 | 20 | with torch.no_grad(): 21 | model_output = model(query=query_inputs, use_cache=False) 22 | query_embeddings = model_output.q_reps 23 | 24 | # List of image URLs 25 | urls = [ 26 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image1.png", 27 | "https://huggingface.co/datasets/NTT-hil-insight/OpenDocVQA/resolve/main/image2.png" 28 | ] 29 | 30 | # Download, open, and resize images 31 | doc_images = [Image.open(BytesIO(requests.get(url).content)).resize((1344, 1344)) for url in urls] 32 | 33 | # Process images with prompt 34 | doc_prompt = "<|image_1|>\nWhat is shown in this image?" 
35 | collated_list = [ 36 | processor( 37 | doc_prompt, 38 | images=image, 39 | return_tensors="pt", 40 | padding="longest", 41 | max_length=4096, 42 | truncation=True 43 | ).to('cuda:0') for image in doc_images 44 | ] 45 | 46 | # Stack tensors into input dict 47 | doc_inputs = { 48 | key: torch.stack([item[key][0] for item in collated_list], dim=0) 49 | for key in ['input_ids', 'attention_mask', 'pixel_values', 'image_sizes'] 50 | } 51 | 52 | with torch.no_grad(): 53 | model_output = model(document=doc_inputs, use_cache=False) 54 | doc_embeddings = model_output.p_reps 55 | 56 | # Calculate cosine similarity 57 | num_queries = query_embeddings.size(0) 58 | num_passages = doc_embeddings.size(0) 59 | 60 | for i in range(num_queries): 61 | query_embedding = query_embeddings[i].unsqueeze(0) 62 | similarities = cosine_similarity(query_embedding, doc_embeddings) 63 | print(f"Similarities for Query {i}: {similarities.cpu().float().numpy()}") 64 | 65 | # >> Similarities for Query 0: [0.5078125 0.38085938] 66 | # Similarities for Query 1: [0.37695312 0.5703125 ] 67 | 68 | ### Generation ### 69 | 70 | model = VDocGenerator.load('microsoft/Phi-3-vision-128k-instruct', lora_name_or_path='NTT-hil-insight/VDocGenerator-Phi3-vision', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0') 71 | 72 | query = "How many international visitors came to Japan in 2017? \n Answer briefly." 73 | 74 | image_tokens = "\n".join([f"<|image_{i+1}|>" for i in range(len(doc_images))]) 75 | messages = [{"role": "user", "content": f"{image_tokens}\n{query}"}] 76 | prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 77 | 78 | processed = processor(prompt, images=doc_images, return_tensors="pt").to('cuda:0') 79 | generate_ids = model.generate(processed, generation_args={"max_new_tokens": 64, "temperature": 0.0, "do_sample": False, "eos_token_id": processor.tokenizer.eos_token_id}) 80 | generate_ids = generate_ids[:, processed['input_ids'].shape[1]:] 81 | response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 82 | response = response.strip() 83 | print("Model prediction: {0}".format(response)) 84 | 85 | ## >> Model prediction: 28.69m --------------------------------------------------------------------------------
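
Optional usage sketch (not part of the repository files above): test.py ranks the two documents with an explicit cosine-similarity loop. Because the retriever is loaded with normalize=True, the same ranking can be obtained with the exact inner-product index implemented by FaissFlatSearcher in src/vdocrag/vdocretriever/searcher.py. The snippet below is a hypothetical continuation of test.py; `query_embeddings` and `doc_embeddings` are assumed to be the tensors computed earlier in that script, and the k and batch_size values are illustrative.

# Hypothetical continuation of test.py (illustrative only, assumes the variables above).
from vdocrag.vdocretriever.searcher import FaissFlatSearcher

# FAISS expects float32 NumPy arrays on the CPU.
q_reps = query_embeddings.cpu().float().numpy()
p_reps = doc_embeddings.cpu().float().numpy()

# IndexFlatIP performs exact inner-product search; on L2-normalized embeddings
# this reproduces the cosine-similarity ranking printed above.
searcher = FaissFlatSearcher(p_reps)   # builds an index sized to the embedding dimension
searcher.add(p_reps)                   # add the document embeddings to the index
scores, indices = searcher.batch_search(q_reps, k=2, batch_size=16, quiet=True)

for i, (doc_scores, doc_ids) in enumerate(zip(scores, indices)):
    print(f"Query {i}: ranked doc ids {doc_ids.tolist()} with scores {doc_scores.tolist()}")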