├── .gitignore
├── LICENSE
├── README.md
├── annotations
│   └── sample_coco_annotation.json
├── assets
│   ├── cover.png
│   └── td-id-caption.png
├── coco_to_florence.py
├── images
│   ├── sample_paper_page_1.png
│   └── sample_paper_page_2.png
├── inference.py
├── pdf_to_table_figures.py
├── pdfs
│   └── arxiv_2305_04160.pdf
├── sample_images
│   └── arxiv_2305_10853_5.png
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Yifei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TF-ID
This repository contains the full training code to reproduce all TF-ID models. We also open-source the model weights and the human-annotated dataset, all under the MIT license.

## Model Summary
![TF-ID](https://github.com/ai8hyf/TF-ID/blob/main/assets/cover.png)

TF-ID (Table/Figure IDentifier), created by [Yifei Hu](https://x.com/hu_yifei), is a family of object detection models finetuned to extract tables and figures from academic papers.
They come in four versions:
| Model | Model size | Model Description |
| ------- | ------------- | ------------- |
| TF-ID-base[[HF]](https://huggingface.co/yifeihu/TF-ID-base) | 0.23B | Extract tables/figures and their caption text |
| TF-ID-large[[HF]](https://huggingface.co/yifeihu/TF-ID-large) (Recommended) | 0.77B | Extract tables/figures and their caption text |
| TF-ID-base-no-caption[[HF]](https://huggingface.co/yifeihu/TF-ID-base-no-caption) | 0.23B | Extract tables/figures without caption text |
| TF-ID-large-no-caption[[HF]](https://huggingface.co/yifeihu/TF-ID-large-no-caption) (Recommended) | 0.77B | Extract tables/figures without caption text |

All TF-ID models are finetuned from [microsoft/Florence-2](https://huggingface.co/microsoft/Florence-2-large-ft) checkpoints.

## Sample Usage
- Use `python inference.py` to extract bounding boxes from one given image (the parsed output is sketched in the Output Format section below)
- Use `python pdf_to_table_figures.py` to extract all tables and figures from one PDF paper and save the cropped figures and tables under `./sample_output`
- **TF-ID-large** is used in the scripts by default. You can switch to a different variant by changing the `model_id` in the scripts, but the large models are always recommended.

## Train TF-ID models from scratch
1. Clone the repo: `git clone https://github.com/ai8hyf/TF-ID`
2. `cd TF-ID`
3. Download the dataset [huggingface.co/datasets/yifeihu/TF-ID-arxiv-papers](https://huggingface.co/datasets/yifeihu/TF-ID-arxiv-papers) from Hugging Face
4. Move **annotations_with_caption.json** to `./annotations` (use **annotations_no_caption.json** if you don't want the bounding boxes to include text captions)
5. Unzip **arxiv_paper_images.zip** and move the .png images to `./images`
6. Convert the COCO-format dataset to Florence-2 format: `python coco_to_florence.py`
7. You should see **train.jsonl** and **test.jsonl** under `./annotations`
8. Train the model with Accelerate: `accelerate launch train.py`
9. The checkpoints will be saved under `./model_checkpoints`

## Hardware Requirement
With [microsoft/Florence-2-large-ft](https://huggingface.co/microsoft/Florence-2-large-ft), `BATCH_SIZE=4` requires at least 40GB of VRAM on a single GPU. The [microsoft/Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft) model takes much less VRAM. Please adjust the `BATCH_SIZE` and `CHECKPOINT` parameters in `train.py` before you start training.

## Benchmarks
We tested the models on paper pages outside the training dataset. The papers are a subset of the Hugging Face Daily Papers.
Correct output: the model draws correct bounding boxes for every table/figure on the given page.

| Model | Total Images | Correct Output | Success Rate |
|---------------------------------------------------------------|--------------|----------------|--------------|
| TF-ID-base[[HF]](https://huggingface.co/yifeihu/TF-ID-base) | 258 | 251 | 97.29% |
| TF-ID-large[[HF]](https://huggingface.co/yifeihu/TF-ID-large) | 258 | 253 | 98.06% |
| TF-ID-base-no-caption[[HF]](https://huggingface.co/yifeihu/TF-ID-base-no-caption) | 261 | 253 | 96.93% |
| TF-ID-large-no-caption[[HF]](https://huggingface.co/yifeihu/TF-ID-large-no-caption) | 261 | 254 | 97.32% |

Depending on the use case, some "incorrect" outputs can still be perfectly usable. For example, the model may draw two bounding boxes for one figure that has two child components.
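
## Output Format
Both `inference.py` and `pdf_to_table_figures.py` call `processor.post_process_generation` with the `<OD>` task, which returns the detected boxes in pixel coordinates together with their labels. A minimal sketch of the parsed output is shown below (the numbers are illustrative, not real model output):

```python
# illustrative structure only; coordinates are made up, label names come from the dataset categories
parsed_answer = {
    "<OD>": {
        "bboxes": [[52.0, 331.2, 548.7, 612.9]],  # [x1, y1, x2, y2] in pixel coordinates on the input image
        "labels": ["table"]
    }
}
```

`pdf_to_table_figures.py` indexes into `bboxes` and `labels` from this dictionary to crop and name the saved images.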

## Acknowledgement
- I learned how to work with Florence-2 models from [Roboflow's awesome tutorial](https://blog.roboflow.com/fine-tune-florence-2-object-detection/).
- My friend Yi Zhang helped annotate some data to train our proof-of-concept models, including a YOLO-based TF-ID model.

## Citation
If you find TF-ID useful, please cite this project as:
```
@misc{TF-ID,
  author = {Yifei Hu},
  title = {TF-ID: Table/Figure IDentifier for academic papers},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ai8hyf/TF-ID}},
}
```
--------------------------------------------------------------------------------
/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/assets/cover.png
--------------------------------------------------------------------------------
/assets/td-id-caption.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/assets/td-id-caption.png
--------------------------------------------------------------------------------
/coco_to_florence.py:
--------------------------------------------------------------------------------
import os
import json
import random

# 1. download the dataset from https://huggingface.co/datasets/yifeihu/TF-ID-arxiv-papers
# 2. move annotations_with_caption.json to the ./annotations folder
# 3. unzip the arxiv_paper_images.zip and move the images to the ./images folder
# 4. run the script: python coco_to_florence.py

train_percentage = 0.85
coco_json_dir = "./annotations/annotations_with_caption.json" # we take a coco format dataset by default
# coco_json_dir = "./annotations/annotations_no_caption.json" # use this file instead if you don't want captions inside the bounding boxes
output_dir = "./annotations"

def convert_to_florence_format(coco_json_dir, output_dir):

    # the code here is intentionally verbose so that it is easy to debug and understand

    print("start converting coco annotations to florence format...")

    with open(coco_json_dir, 'r') as file:
        data = json.load(file)

    category_dict = {category['id']: category['name'] for category in data['categories']}
    print("labels:", category_dict)

    img_dict = {}
    for img in data['images']:
        img_dict[img['id']] = {
            'width': img['width'],
            'height': img['height'],
            'file_name': img['file_name'],
            'annotations': [],
            'annotations_str': ""
        }

    annotation_dict = {annotation['image_id']: annotation['bbox'] for annotation in data['annotations']} # kept for debugging; not used below

    def format_annotation(annotation):
        category_id = annotation['category_id']
        bbox = annotation['bbox'] # coco bbox format: [x, y, width, height]
        this_image_width = img_dict[int(annotation['image_id'])]['width']
        this_image_height = img_dict[int(annotation['image_id'])]['height']
        # normalize the coordinates to 0-1, then scale them to the 0-1000 grid used by Florence-2 location tokens
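        # illustrative example (hypothetical numbers): a box edge at x = 306 on a
        # 612-pt-wide page becomes int(306 / 612 * 1000) = 500, i.e. the <loc_500> token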
        # florence 2 format: label<loc_x1><loc_y1><loc_x2><loc_y2>
        x1 = int(bbox[0] / this_image_width * 1000)
        y1 = int(bbox[1] / this_image_height * 1000)
        x2 = int((bbox[0] + bbox[2]) / this_image_width * 1000)
        y2 = int((bbox[1] + bbox[3]) / this_image_height * 1000)

        return f"{category_dict[category_id]}<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>"

    for annotation in data['annotations']:
        try:
            annotation_str = format_annotation(annotation)
            if annotation['image_id'] in img_dict:
                img_dict[annotation['image_id']]['annotations'].append(annotation_str)
        except Exception:
            # skip malformed annotations
            continue

    florence_data = []
    for img_id, img_data in img_dict.items():
        annotations_str = "".join(img_data['annotations'])

        if len(annotations_str) > 0:
            florence_data.append({
                "image": img_data['file_name'],
                "prefix": "<OD>",
                "suffix": annotations_str
            })
        else:
            # OPTIONAL: some images have no annotations, you can choose to ignore them
            # only randomly sample 5% of the images without annotations
            if random.random() < 0.05:
                florence_data.append({
                    "image": img_data['file_name'],
                    "prefix": "<OD>",
                    "suffix": ""
                })

    print("total number of images:", len(florence_data))

    # split the data into train and test and save them into jsonl files
    train_split = int(len(florence_data) * train_percentage)
    train_data = florence_data[:train_split]
    test_data = florence_data[train_split:]

    print("train size:", len(train_data))
    print("test size:", len(test_data))

    train_output_dir = os.path.join(output_dir, "train.jsonl")
    test_output_dir = os.path.join(output_dir, "test.jsonl")

    # save train and test data into jsonl files
    if os.path.exists(train_output_dir):
        os.remove(train_output_dir)

    with open(train_output_dir, 'w') as file:
        for entry in train_data:
            json.dump(entry, file)
            file.write("\n")

    if os.path.exists(test_output_dir):
        os.remove(test_output_dir)

    with open(test_output_dir, 'w') as file:
        for entry in test_data:
            json.dump(entry, file)
            file.write("\n")

    print("train and test data saved to", output_dir)
    print("Now you can run \"accelerate launch train.py\" to train the model.")

convert_to_florence_format(coco_json_dir, output_dir)
--------------------------------------------------------------------------------
/images/sample_paper_page_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/images/sample_paper_page_1.png
--------------------------------------------------------------------------------
/images/sample_paper_page_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/images/sample_paper_page_2.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import requests # only needed if you load the image from a remote URL
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

model_id = "yifeihu/TF-ID-large" # recommended: use large models for better performance
# model_id = "yifeihu/TF-ID-base"
# model_id = "yifeihu/TF-ID-large-no-caption" # recommended: use large models for better performance
# model_id = "yifeihu/TF-ID-base-no-caption"
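
# note: all four TF-ID checkpoints share the same Florence-2 interface and the "<OD>"
# task prompt used below; the -no-caption variants simply exclude caption text from the boxes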
"yifeihu/TF-ID-base-no-caption" 9 | 10 | model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True) 11 | processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) 12 | 13 | prompt = "" 14 | 15 | # TF-ID models were trained on digital pdf papers 16 | image_url = "./sample_images/arxiv_2305_10853_5.png" 17 | image = Image.open(requests.get(image_url, stream=True).raw) 18 | 19 | inputs = processor(text=prompt, images=image, return_tensors="pt") 20 | 21 | generated_ids = model.generate( 22 | input_ids=inputs["input_ids"], 23 | pixel_values=inputs["pixel_values"], 24 | max_new_tokens=1024, 25 | do_sample=False, 26 | num_beams=3 27 | ) 28 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 29 | 30 | parsed_answer = processor.post_process_generation(generated_text, task="", image_size=(image.width, image.height)) 31 | 32 | print(parsed_answer) 33 | 34 | # to visualize the generated answer, check out this colab example from Florence 2 repo: https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb 35 | -------------------------------------------------------------------------------- /pdf_to_table_figures.py: -------------------------------------------------------------------------------- 1 | from pdf2image import convert_from_path, convert_from_bytes 2 | from pdf2image.exceptions import ( 3 | PDFInfoNotInstalledError, 4 | PDFPageCountError, 5 | PDFSyntaxError 6 | ) 7 | from PIL import Image 8 | from transformers import AutoProcessor, AutoModelForCausalLM 9 | 10 | import os 11 | import json 12 | import time 13 | 14 | def pdf_to_image(pdf_path): 15 | images = convert_from_path(pdf_path) 16 | return images 17 | 18 | def tf_id_detection(image, model, processor): 19 | prompt = "" 20 | inputs = processor(text=prompt, images=image, return_tensors="pt") 21 | generated_ids = model.generate( 22 | input_ids=inputs["input_ids"], 23 | pixel_values=inputs["pixel_values"], 24 | max_new_tokens=1024, 25 | do_sample=False, 26 | num_beams=3 27 | ) 28 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 29 | annotation = processor.post_process_generation(generated_text, task="", image_size=(image.width, image.height)) 30 | return annotation[""] 31 | 32 | def save_image_from_bbox(image, annotation, page, output_dir): 33 | # the name should be page + label + index 34 | for i in range(len(annotation['bboxes'])): 35 | bbox = annotation['bboxes'][i] 36 | label = annotation['labels'][i] 37 | x1, y1, x2, y2 = bbox 38 | cropped_image = image.crop((x1, y1, x2, y2)) 39 | cropped_image.save(os.path.join(output_dir, f"page_{page}_{label}_{i}.png")) 40 | 41 | def pdf_to_table_figures(pdf_path, model_id, output_dir): 42 | timestr = time.strftime("%Y%m%d-%H%M%S") 43 | output_dir = os.path.join(output_dir, timestr) 44 | 45 | os.makedirs(output_dir, exist_ok=True) 46 | 47 | images = pdf_to_image(pdf_path) 48 | print(f"PDF loaded. Number of pages: {len(images)}") 49 | model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True) 50 | processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) 51 | print("Model loaded: ", model_id) 52 | 53 | print("=====================================") 54 | print("start saving cropped images") 55 | for i, image in enumerate(images): 56 | annotation = tf_id_detection(image, model, processor) 57 | save_image_from_bbox(image, annotation, i, output_dir) 58 | print(f"Page {i} saved. 

    print("=====================================")
    print("All images saved to: ", output_dir)

model_id = "yifeihu/TF-ID-large"
pdf_path = "./pdfs/arxiv_2305_04160.pdf"
output_dir = "./sample_output"

pdf_to_table_figures(pdf_path, model_id, output_dir)
--------------------------------------------------------------------------------
/pdfs/arxiv_2305_04160.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/pdfs/arxiv_2305_04160.pdf
--------------------------------------------------------------------------------
/sample_images/arxiv_2305_10853_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai8hyf/TF-ID/7f29dafd98b1a01d14079077720c01b8e1d1b270/sample_images/arxiv_2305_10853_5.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# this script uses code snippets from Roboflow's tutorial: https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb

import torch
from torch.optim import AdamW # the AdamW class in transformers is deprecated, so we use torch's implementation
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    get_scheduler
)
from tqdm import tqdm
import os
import json
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Tuple
from accelerate import Accelerator

# run coco_to_florence.py first to convert coco annotations to florence format
# use "accelerate launch train.py" to run this script

BATCH_SIZE = 4 # adjust based on your GPU specs
gradient_accumulation_steps = 2 # adjust based on your GPU specs
NUM_WORKERS = 0
epochs = 8
learning_rate = 5e-6

img_dir = "./images"
train_labels = "./annotations/train.jsonl" # generated by running coco_to_florence.py
test_labels = "./annotations/test.jsonl" # generated by running coco_to_florence.py
output_dir = "./model_checkpoints"

CHECKPOINT = "microsoft/Florence-2-base-ft"
# CHECKPOINT = "microsoft/Florence-2-large-ft"

accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

DEVICE = accelerator.device
model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT, trust_remote_code=True).to(DEVICE)
processor = AutoProcessor.from_pretrained(
    CHECKPOINT, trust_remote_code=True)

class JSONLDataset:
    def __init__(self, jsonl_file_path: str, image_directory_path: str):
        self.jsonl_file_path = jsonl_file_path
        self.image_directory_path = image_directory_path
        self.entries = self._load_entries()

    def _load_entries(self) -> List[Dict[str, Any]]:
        entries = []
        with open(self.jsonl_file_path, 'r') as file:
            for line in file:
                data = json.loads(line)
                entries.append(data)
        return entries

    def __len__(self) -> int:
        return len(self.entries)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, Dict[str, Any]]:
        if idx < 0 or idx >= len(self.entries):
            raise IndexError("Index out of range")

        entry = self.entries[idx]
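        # each jsonl entry was written by coco_to_florence.py and holds the keys
        # 'image' (file name), 'prefix' (the "<OD>" task prompt) and 'suffix' (label + location tokens)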
        image_path = os.path.join(self.image_directory_path, entry['image'])
        try:
            image = Image.open(image_path)
            return (image, entry)
        except FileNotFoundError:
            raise FileNotFoundError(f"Image file {image_path} not found.")

class DetectionDataset(Dataset):
    def __init__(self, jsonl_file_path: str, image_directory_path: str):
        self.dataset = JSONLDataset(jsonl_file_path, image_directory_path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, data = self.dataset[idx]
        prefix = data['prefix']
        suffix = data['suffix']
        return prefix, suffix, image

def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(DEVICE)
    return inputs, answers

train_dataset = DetectionDataset(
    jsonl_file_path = train_labels,
    image_directory_path = img_dir
)

test_dataset = DetectionDataset(
    jsonl_file_path = test_labels,
    image_directory_path = img_dir
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=False)

def train_model(train_loader, val_loader, model, processor, epochs, lr):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=12,
        num_training_steps=num_training_steps,
    )

    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_loader, lr_scheduler
    )

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            with accelerator.accumulate(model):
                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                labels = processor.tokenizer(
                    text=answers,
                    return_tensors="pt",
                    padding=True,
                    return_token_type_ids=False
                ).input_ids.to(DEVICE)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        print(f"Average Training Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for inputs, answers in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):

                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                labels = processor.tokenizer(
                    text=answers,
                    return_tensors="pt",
                    padding=True,
                    return_token_type_ids=False
                ).input_ids.to(DEVICE)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")
        # this loss is not very informative; a proper detection metric would be better for evaluation

        weights_output_dir = output_dir + f"/epoch_{epoch+1}"
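        # one checkpoint folder per epoch, e.g. ./model_checkpoints/epoch_1 (README step 9)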
f"/epoch_{epoch+1}" 169 | os.makedirs(weights_output_dir, exist_ok=True) 170 | accelerator.save_model(model, weights_output_dir) 171 | 172 | train_model(train_loader, test_loader, model, processor, epochs=epochs, lr=learning_rate) --------------------------------------------------------------------------------