├── LICENSE ├── README.md ├── data ├── __pycache__ │ └── chartqa_data.cpython-310.pyc └── chartqa_data.py ├── finetune_chartqa.py └── model ├── __pycache__ └── chartqa_model.cpython-310.pyc └── chartqa_model.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 vis-nlp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning 2 | 3 | * Authors: [Ahmed Masry](https://ahmedmasryku.github.io/)*, Parsa Kavehzadeh*, Do Long, Shafiq Joty, Enamul Hoque (*equal contribution) 4 | * Paper Link: [UniChart](https://arxiv.org/abs/2305.14761) 5 | * **[NEW]** If you are looking for more advanced Chart Models, explore our latest models for chart understanding: 6 | * [ChartInstruct](https://github.com/vis-nlp/ChartInstruct) 7 | * Our advanced Chart Large Language Model based on LLaVA, supporting LLama2 (7B) and Flan-T5-XL (3B). Perfect for a wide range of chart-related tasks. 8 | * [ChartGemma](https://github.com/vis-nlp/ChartGemma) 9 | * The state-of-the-art Chart LLM built on PaliGemma (3B), optimized for visual reasoning tasks. 10 | * **Both models are user-friendly and can be run with just a few lines of code. Public web demos are available! Check out their GitHub repositories for more details.** 11 | 12 | ## UniChart Pretraining Dataset 13 | Our pretraining dataset is divided into two primary components: 14 | 1. A zip file encompassing all the images. You can access the images through this huggingface dataset: [Images](https://huggingface.co/datasets/ahmed-masry/UniChart-pretrain-images) 15 | 2. A Huggingface dataset containing the input/output pairs utilized for model pretraining. 
You can find the dataset here: [Huggingface Dataset](https://huggingface.co/datasets/ahmed-masry/unichart-pretrain-data) 16 | 17 | ## UniChart Model Checkpoints 18 | We release the checkpoints for our pretrained models as well as the finetuned checkpoints on the different downstream tasks. 19 | | Task | Checkpoint Path | 20 | | ------------- | ------------- | 21 | | Pretrained | [unichart-base-960](https://huggingface.co/ahmed-masry/unichart-base-960) | 22 | | ChartQA | [unichart-chartqa-960](https://huggingface.co/ahmed-masry/unichart-chartqa-960) | 23 | | Chart2Text-Statista | [unichart-chart2text-statista-960](https://huggingface.co/ahmed-masry/unichart-chart2text-statista-960) | 24 | | Chart2Text-Pew | [unichart-chart2text-pew-960](https://huggingface.co/ahmed-masry/unichart-chart2text-pew-960) | 25 | | OpenCQA | [unichart-opencqa-960](https://huggingface.co/ahmed-masry/unichart-opencqa-960) | 26 | 27 | ## Web Demo 28 | If you wish to quickly try our models, you can access our public web demos hosted on the Hugging Face Spaces platform with a friendly interface! 29 | 30 | | Tasks | Web Demo | 31 | | ------------- | ------------- | 32 | | Base Model (Best for Chart Summarization and Data Table Generation) | [UniChart-Base](https://huggingface.co/spaces/ahmed-masry/UniChart-Base) | 33 | | Chart Question Answering | [UniChart-ChartQA](https://huggingface.co/spaces/ahmed-masry/UniChart-ChartQA) | 34 | 35 | The input prompt for Chart Summarization is **\<summarize_chart\>** and for Data Table Generation it is **\<extract_data_table\>** 36 | 37 | ## Requirements 38 | 39 | ``` 40 | transformers==4.28.1 41 | pytorch-lightning==1.8.5 42 | datasets 43 | sentencepiece 44 | ``` 45 | Please make sure to use the **exact same version** of the **Transformers** library. We have noticed that there might be a drop in performance when using different versions of the library! 46 | ## Inference 47 | You can easily use our models for inference with the huggingface library! 48 | You just need to do the following: 49 | 1. Change _model_name_ to your preferred checkpoint. 50 | 2. Change the _image_path_ to your chart example image path on your system. 51 | 3. Write the _input_prompt_ based on your preferred task as shown in the table below. 52 | 53 | | Task | Input Prompt | 54 | | ------------- | ------------- | 55 | | Chart Question Answering | \<chartqa\> question \<s_answer\> | 56 | | Open Chart Question Answering | \<opencqa\> question \<s_answer\> | 57 | | Chart Summarization | \<summarize_chart\> \<s_answer\> | 58 | | Data Table Extraction | \<extract_data_table\> \<s_answer\> | 59 | 60 | ``` 61 | from transformers import DonutProcessor, VisionEncoderDecoderModel 62 | from PIL import Image 63 | import torch, os, re 64 | 65 | torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png') 66 | 67 | model_name = "ahmed-masry/unichart-chartqa-960" 68 | image_path = "/content/chart_example_1.png" 69 | input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"
70 | 71 | model = VisionEncoderDecoderModel.from_pretrained(model_name) 72 | processor = DonutProcessor.from_pretrained(model_name) 73 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 74 | model.to(device) 75 | 76 | image = Image.open(image_path).convert("RGB") 77 | decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids 78 | pixel_values = processor(image, return_tensors="pt").pixel_values 79 | 80 | outputs = model.generate( 81 | pixel_values.to(device), 82 | decoder_input_ids=decoder_input_ids.to(device), 83 | max_length=model.decoder.config.max_position_embeddings, 84 | early_stopping=True, 85 | pad_token_id=processor.tokenizer.pad_token_id, 86 | eos_token_id=processor.tokenizer.eos_token_id, 87 | use_cache=True, 88 | num_beams=4, 89 | bad_words_ids=[[processor.tokenizer.unk_token_id]], 90 | return_dict_in_generate=True, 91 | ) 92 | sequence = processor.batch_decode(outputs.sequences)[0] 93 | sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") 94 | sequence = sequence.split("<s_answer>")[1].strip() 95 | print(sequence) 96 | 97 | ``` 98 | 99 | ## Finetuning 100 | In order to finetune the model on the ChartQA dataset, you can edit and run the following command: 101 | ``` 102 | python finetune_chartqa.py --data-path "ahmed-masry/chartqa_without_images" --train-images '/content/ChartQA/ChartQA Dataset/train/png/' \ 103 | --valid-images '/content/ChartQA/ChartQA Dataset/val/png' --max-steps 40000 --batch-size 8 --valid-batch-size 1 --num-workers 12 --lr 5e-5 \ 104 | --check-val-every-n-epoch 1 --warmup-steps 100 --checkpoint-steps 7000 --checkpoint-path "ahmed-masry/unichart-base-960" 105 | ``` 106 | 107 | # Contact 108 | If you have any questions about this work, please contact **[Ahmed Masry](https://ahmedmasryku.github.io/)** using the following email addresses: **amasry17@ku.edu.tr** or **ahmed.elmasry24653@gmail.com**. 109 | 110 | # Reference 111 | Please cite our paper if you use our models or dataset in your research.
112 | 113 | ``` 114 | @misc{masry2023unichart, 115 | title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning}, 116 | author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty}, 117 | year={2023}, 118 | eprint={2305.14761}, 119 | archivePrefix={arXiv}, 120 | primaryClass={cs.CL} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /data/__pycache__/chartqa_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vis-nlp/UniChart/bd6004bc8fe9ef8ce9a6cdfd88712f845d78b918/data/__pycache__/chartqa_data.cpython-310.pyc -------------------------------------------------------------------------------- /data/chartqa_data.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import random 3 | from typing import Any, List, Tuple 4 | from PIL import Image 5 | import torch 6 | from torch.utils.data import Dataset 7 | from transformers import DonutProcessor 8 | from datasets import load_dataset, load_from_disk 9 | 10 | added_tokens = [] 11 | 12 | class ChartQADataset(Dataset): 13 | """ 14 | """ 15 | 16 | def __init__( 17 | self, 18 | dataset: str, 19 | images_folder: str, 20 | max_length: int, 21 | processor : DonutProcessor = None, 22 | split: str = "train", 23 | ignore_id: int = -100, 24 | prompt_end_token: str = None, 25 | task_prefix: str = '', 26 | sort_json_key: bool = True, 27 | ): 28 | super().__init__() 29 | 30 | self.max_length = max_length 31 | self.split = split 32 | self.ignore_id = ignore_id 33 | 34 | self.prompt_end_token = prompt_end_token 35 | self.sort_json_key = sort_json_key 36 | self.images_folder = images_folder 37 | 38 | 39 | self.dataset = dataset 40 | self.dataset_length = len(self.dataset) 41 | 42 | self.processor = processor 43 | self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token) 44 | self.task_prefix = task_prefix 45 | 46 | 47 | def __len__(self) -> int: 48 | return self.dataset_length 49 | 50 | def __getitem__(self, idx: int): 51 | 52 | sample = self.dataset[idx] 53 | 54 | # input_tensor 55 | img_path = os.path.join(self.images_folder, sample['imgname']) 56 | img = Image.open(img_path) 57 | pixel_values = self.processor(img.convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values 58 | input_tensor = pixel_values.squeeze() 59 | 60 | # input_ids 61 | processed_parse = self.task_prefix + " " + sample['query'] + " " + self.prompt_end_token + " " + sample['label'] + self.processor.tokenizer.eos_token 62 | input_ids = self.processor.tokenizer( 63 | processed_parse, 64 | add_special_tokens=False, 65 | max_length=self.max_length, 66 | padding="max_length", 67 | truncation=True, 68 | return_tensors="pt", 69 | )["input_ids"].squeeze(0) 70 | 71 | if self.split == "train": 72 | labels = input_ids.clone() 73 | labels[ 74 | labels == self.processor.tokenizer.pad_token_id 75 | ] = self.ignore_id # model doesn't need to predict pad token 76 | labels[ 77 | : torch.nonzero(labels == self.prompt_end_token_id).sum() + 1 78 | ] = self.ignore_id # model doesn't need to predict prompt 79 | return input_tensor, input_ids, labels 80 | else: 81 | prompt_end_index = torch.nonzero( 82 | input_ids == self.prompt_end_token_id 83 | ).sum() # return prompt end index instead of target output labels 84 | return input_tensor, input_ids, prompt_end_index, processed_parse 
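For reference, here is how the dataset class above is meant to be consumed. This is an editorial usage sketch, not a file in the repository: it mirrors the defaults from `finetune_chartqa.py` and assumes the ChartQA PNG images have been extracted to the folder shown below and that the snippet is run from the repository root.

```
# Usage sketch (assumptions: ChartQA images downloaded locally, run from the repo root).
from datasets import load_dataset
from transformers import DonutProcessor

from data.chartqa_data import ChartQADataset

processor = DonutProcessor.from_pretrained("ahmed-masry/unichart-base-960")
raw = load_dataset("ahmed-masry/chartqa_without_images")

train_ds = ChartQADataset(
    raw["train"],
    images_folder="/content/ChartQA/ChartQA Dataset/train/png/",  # local chart images
    processor=processor,
    max_length=512,
    split="train",
    prompt_end_token="<s_answer>",
    task_prefix="<chartqa>",
)

pixel_values, input_ids, labels = train_ds[0]
print(pixel_values.shape)      # image tensor fed to the Donut/Swin encoder
print(input_ids.shape)         # "<chartqa> question <s_answer> answer</s>", padded to max_length
print((labels != -100).sum())  # only the answer tokens are supervised; prompt and padding are masked
```

For the validation split, `__getitem__` instead returns the prompt end index and the full target string, which is what `ChartQAModule.validation_step` consumes.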
-------------------------------------------------------------------------------- /finetune_chartqa.py: -------------------------------------------------------------------------------- 1 | from transformers import VisionEncoderDecoderConfig 2 | from transformers import DonutProcessor, VisionEncoderDecoderModel, BartConfig 3 | import argparse 4 | from torch.utils.data import DataLoader 5 | from typing import List 6 | from datasets import load_dataset 7 | 8 | from data.chartqa_data import ChartQADataset 9 | from model.chartqa_model import ChartQAModule 10 | 11 | import pytorch_lightning as pl 12 | 13 | #from pytorch_lightning.loggers import WandbLogger 14 | #from pytorch_lightning.callbacks import LearningRateMonitor 15 | from pytorch_lightning.callbacks import ModelCheckpoint 16 | 17 | 18 | 19 | def main(): 20 | # Instantiate the parser 21 | parser = argparse.ArgumentParser(description='Train Chart Transformer') 22 | parser.add_argument('--data-path', type=str, default = "ahmed-masry/chartqa_without_images", help='Path to the data file') 23 | parser.add_argument('--train-images', type=str, default='/content/ChartQA/ChartQA Dataset/train/png/', help='Path to the training images') 24 | parser.add_argument('--valid-images', type=str, default='/content/ChartQA/ChartQA Dataset/val/png', help='Path to the validation images') 25 | 26 | parser.add_argument('--output-dir', type=str, default="/content/output_data", help='Path to the output directory for saving the checkpoints') 27 | parser.add_argument('--max-steps', type=int, default = 1000, help='Max number of iterations') 28 | parser.add_argument('--batch-size', type=int, default=2, help='Batch Size for the model') 29 | parser.add_argument('--valid-batch-size', type=int, default=2, help='Valid Batch Size for the model') 30 | parser.add_argument('--max-length', type=int, default=512, help='Max length for decoder generation') 31 | parser.add_argument('--num-workers', type=int, default=2, help='Number of workers') 32 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 33 | 34 | parser.add_argument('--check-val-every-n-epoch', type=int, default=1, help='Run validation every n epochs') 35 | parser.add_argument('--log-every-n-steps', type=int, default=50, help='Log every n steps') 36 | parser.add_argument('--warmup-steps', type=int, default=50, help='Warmup steps') 37 | parser.add_argument('--checkpoint-steps', type=int, default=1000, help='Checkpoint steps') 38 | parser.add_argument('--gradient-clip-val', type=float, default=1.0, help='gradient clip value') 39 | 40 | parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='accumulate grad batches') 41 | parser.add_argument('--gpus-num', type=int, default=1, help='gpus num') 42 | parser.add_argument('--nodes-num', type=int, default=1, help='nodes num') 43 | 44 | parser.add_argument('--checkpoint-path', type=str, default = "ahmed-masry/unichart-base-960", help='Path to the checkpoint') 45 | 46 | args = parser.parse_args() 47 | 48 | processor = DonutProcessor.from_pretrained(args.checkpoint_path) 49 | model = VisionEncoderDecoderModel.from_pretrained(args.checkpoint_path) 50 | 51 | dataset = load_dataset(args.data_path) 52 | 53 | train_dataset = ChartQADataset(dataset["train"], images_folder = args.train_images, processor = processor, max_length=args.max_length, 54 | split="train", prompt_end_token="<s_answer>", task_prefix = "<chartqa>" 55 | ) 56 | 57 | val_dataset = ChartQADataset(dataset["val"], images_folder = args.valid_images, processor = processor, 
max_length=args.max_length, 58 | split="valid", prompt_end_token="<s_answer>", task_prefix = "<chartqa>" 59 | ) 60 | 61 | 62 | config = {"max_steps":args.max_steps, 63 | "check_val_every_n_epoch":args.check_val_every_n_epoch, 64 | "log_every_n_steps":args.log_every_n_steps, 65 | "gradient_clip_val":args.gradient_clip_val, 66 | "num_training_samples_per_epoch": len(dataset["train"]), 67 | "lr":args.lr, 68 | "train_batch_sizes": [args.batch_size], 69 | "val_batch_sizes": [args.valid_batch_size], 70 | "num_nodes": args.nodes_num, 71 | "warmup_steps": args.warmup_steps, 72 | "result_path": args.output_dir, 73 | "verbose": True, 74 | } 75 | 76 | model_module = ChartQAModule(config, processor, model, args, train_dataset, val_dataset) 77 | 78 | # wandb_logger = WandbLogger(project="UniChart-ChartQA") 79 | # lr_callback = LearningRateMonitor(logging_interval="step") 80 | checkpoint_callback = ModelCheckpoint(dirpath=args.output_dir, every_n_train_steps = args.checkpoint_steps, save_last = True, save_top_k = -1) 81 | 82 | trainer = pl.Trainer( 83 | accelerator="gpu", 84 | devices=args.gpus_num, 85 | max_steps=args.max_steps, 86 | check_val_every_n_epoch=args.check_val_every_n_epoch, 87 | # val_check_interval=100, 88 | log_every_n_steps=args.log_every_n_steps, 89 | gradient_clip_val=args.gradient_clip_val, 90 | 91 | num_nodes=args.nodes_num, 92 | precision=16, # we'll use mixed precision 93 | num_sanity_val_steps=0, 94 | #enable_checkpointing=True, 95 | default_root_dir=args.output_dir, 96 | # logger=wandb_logger, 97 | callbacks=[checkpoint_callback], 98 | ) 99 | 100 | trainer.fit(model_module) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /model/__pycache__/chartqa_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vis-nlp/UniChart/bd6004bc8fe9ef8ce9a6cdfd88712f845d78b918/model/__pycache__/chartqa_model.cpython-310.pyc -------------------------------------------------------------------------------- /model/chartqa_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import re 3 | from nltk import edit_distance 4 | import numpy as np 5 | import math, os 6 | 7 | from torch.nn.utils.rnn import pad_sequence 8 | from torch.optim.lr_scheduler import LambdaLR 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | import pytorch_lightning as pl 13 | from pytorch_lightning.utilities import rank_zero_only 14 | 15 | 16 | class ChartQAModule(pl.LightningModule): 17 | def __init__(self, config, processor, model, args, train_dataset, val_dataset): 18 | super().__init__() 19 | self.config = config 20 | self.processor = processor 21 | self.model = model 22 | self.train_dataset = train_dataset 23 | self.val_dataset = val_dataset 24 | self.args=args 25 | 26 | def training_step(self, batch, batch_idx): 27 | pixel_values, decoder_input_ids, labels = batch 28 | 29 | outputs = self.model(pixel_values, 30 | decoder_input_ids=decoder_input_ids[:, :-1], 31 | labels=labels[:, 1:]) 32 | loss = outputs.loss 33 | self.log_dict({"train_loss": loss}, sync_dist=True) 34 | return loss 35 | 36 | def compute_metric(self, gt, pred): 37 | try: 38 | gt = float(gt) 39 | pred = float(pred) 40 | return abs(gt - pred) / abs(gt) <= 0.05 41 | except: 42 | return str(gt).lower() == str(pred).lower() 43 | 44 | def validation_step(self, batch, batch_idx, dataset_idx=0): 45 | pixel_values, 
decoder_input_ids, prompt_end_idxs, answers = batch 46 | decoder_prompts = pad_sequence( 47 | [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)], 48 | batch_first=True, 49 | ) 50 | 51 | outputs = self.model.generate(pixel_values, 52 | decoder_input_ids=decoder_prompts, 53 | max_length=self.args.max_length, 54 | early_stopping=True, 55 | pad_token_id=self.processor.tokenizer.pad_token_id, 56 | eos_token_id=self.processor.tokenizer.eos_token_id, 57 | use_cache=True, 58 | num_beams=4, 59 | bad_words_ids=[[self.processor.tokenizer.unk_token_id]], 60 | return_dict_in_generate=True,) 61 | 62 | predictions = [] 63 | for seq in self.processor.tokenizer.batch_decode(outputs.sequences): 64 | seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "") 65 | predictions.append(seq) 66 | 67 | scores = list() 68 | for pred, answer in zip(predictions, answers): 69 | pred = pred.split("<s_answer>")[1] 70 | pred = pred.replace(self.processor.tokenizer.eos_token, "").replace("<s_answer>", "").strip(' ') 71 | answer = answer.split("<s_answer>")[1] 72 | answer = answer.replace(self.processor.tokenizer.eos_token, "").strip(' ') 73 | if self.compute_metric(answer, pred): 74 | scores.append(1) 75 | else: 76 | scores.append(0) 77 | 78 | return scores 79 | 80 | def validation_epoch_end(self, validation_step_outputs): 81 | # I set this to 1 manually 82 | # (previously set to len(self.config.dataset_name_or_paths)) 83 | num_of_loaders = 1 84 | if num_of_loaders == 1: 85 | validation_step_outputs = [validation_step_outputs] 86 | assert len(validation_step_outputs) == num_of_loaders 87 | cnt = [0] * num_of_loaders 88 | total_metric = [0] * num_of_loaders 89 | val_metric = [0] * num_of_loaders 90 | for i, results in enumerate(validation_step_outputs): 91 | for scores in results: 92 | cnt[i] += len(scores) 93 | total_metric[i] += np.sum(scores) 94 | val_metric[i] = total_metric[i] / cnt[i] 95 | val_metric_name = f"val_metric_{i}th_dataset" 96 | self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True) 97 | self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True) 98 | print("Epoch:", str(self.current_epoch), "Step:", str(self.global_step), "Validation Metric:", str(np.sum(total_metric) / np.sum(cnt))) 99 | 100 | def configure_optimizers(self): 101 | 102 | max_iter = None 103 | 104 | if int(self.config.get("max_epochs", -1)) > 0: 105 | assert len(self.config.get("train_batch_sizes")) == 1, "Set max_epochs only if the number of datasets is 1" 106 | max_iter = (self.config.get("max_epochs") * self.config.get("num_training_samples_per_epoch")) / ( 107 | self.config.get("train_batch_sizes")[0] * torch.cuda.device_count() * self.config.get("num_nodes", 1) 108 | ) 109 | 110 | if int(self.config.get("max_steps", -1)) > 0: 111 | max_iter = min(self.config.get("max_steps"), max_iter) if max_iter is not None else self.config.get("max_steps") 112 | 113 | assert max_iter is not None 114 | optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr")) 115 | scheduler = { 116 | "scheduler": self.cosine_scheduler(optimizer, max_iter, self.config.get("warmup_steps")), 117 | "name": "learning_rate", 118 | "interval": "step", 119 | } 120 | return [optimizer], [scheduler] 121 | 122 | @staticmethod 123 | def cosine_scheduler(optimizer, training_steps, warmup_steps): 124 | def lr_lambda(current_step): 125 | if current_step < warmup_steps: 126 | return current_step / max(1, warmup_steps) 127 | progress = current_step - warmup_steps 128 | progress 
/= max(1, training_steps - warmup_steps) 129 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) 130 | 131 | return LambdaLR(optimizer, lr_lambda) 132 | 133 | def train_dataloader(self): 134 | return DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers) 135 | 136 | def val_dataloader(self): 137 | return DataLoader(self.val_dataset, batch_size=self.args.valid_batch_size, shuffle=False, num_workers=self.args.num_workers) 138 | 139 | @rank_zero_only 140 | def on_save_checkpoint(self, checkpoint): 141 | save_path = os.path.join(self.config['result_path'], 'chartqa-checkpoint-epoch='+str(self.current_epoch)+'-'+str(self.global_step)) 142 | self.model.save_pretrained(save_path) 143 | self.processor.save_pretrained(save_path) --------------------------------------------------------------------------------
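Since `on_save_checkpoint` stores every checkpoint with `save_pretrained`, a finetuned run can be reloaded directly through the Hugging Face API. The sketch below is illustrative only; the checkpoint directory name is hypothetical and simply follows the `chartqa-checkpoint-epoch=<epoch>-<step>` naming pattern used above.

```
# Sketch: reload a checkpoint directory written by ChartQAModule.on_save_checkpoint.
# The path is a placeholder; point it at a folder produced by your own run.
from transformers import DonutProcessor, VisionEncoderDecoderModel

ckpt_dir = "/content/output_data/chartqa-checkpoint-epoch=0-7000"
model = VisionEncoderDecoderModel.from_pretrained(ckpt_dir)
processor = DonutProcessor.from_pretrained(ckpt_dir)
# The reloaded model and processor can then be used exactly like the released
# checkpoints in the README's inference example.
```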