├── .gitignore ├── LICENSE ├── Owl ├── .gitattributes ├── .gitignore ├── LICENSE ├── OwlEval │ ├── OwlEval.md │ ├── answer │ │ ├── MMreact_answer.jsonl │ │ ├── blip2_13b_answer.jsonl │ │ ├── llava_13b_answer.jsonl │ │ ├── mPLUG_Owl_7b_answer.jsonl │ │ ├── minigpt4_13b_answer.jsonl │ │ └── openflanmingo_answer.jsonl │ ├── cases │ │ ├── 1.jpg │ │ ├── 10.jpg │ │ ├── 11.jpg │ │ ├── 12.jpg │ │ ├── 13.jpg │ │ ├── 14.jpg │ │ ├── 15.jpg │ │ ├── 16.jpg │ │ ├── 17.jpg │ │ ├── 18.jpg │ │ ├── 19.jpg │ │ ├── 2.jpg │ │ ├── 20.jpg │ │ ├── 21.jpg │ │ ├── 22.jpg │ │ ├── 23.jpg │ │ ├── 24.jpg │ │ ├── 25.jpg │ │ ├── 26.jpg │ │ ├── 27.jpg │ │ ├── 28.jpg │ │ ├── 29.jpg │ │ ├── 3.jpg │ │ ├── 30.jpg │ │ ├── 31.jpg │ │ ├── 32.jpg │ │ ├── 33.jpg │ │ ├── 34.jpg │ │ ├── 35.jpg │ │ ├── 36.jpg │ │ ├── 37.jpg │ │ ├── 38.jpg │ │ ├── 39.jpg │ │ ├── 4.jpg │ │ ├── 40.jpg │ │ ├── 41.jpg │ │ ├── 42.jpg │ │ ├── 43.jpg │ │ ├── 44.jpg │ │ ├── 45.jpg │ │ ├── 46.jpg │ │ ├── 47.jpg │ │ ├── 48.jpg │ │ ├── 49.jpg │ │ ├── 5.jpg │ │ ├── 50.jpg │ │ ├── 6.jpg │ │ ├── 7.jpg │ │ ├── 8.jpg │ │ └── 9.jpg │ └── questions.jsonl ├── README.md ├── README_zh.md ├── assets │ ├── -twitter-blue.svg │ ├── Demo-ModelScope-brightgreen.svg │ ├── LICENSE-Apache License-blue.svg │ ├── Paper-Arxiv-orange.svg │ ├── Paper-PDF-orange.svg │ └── modelscopeIcon.svg ├── configs │ └── v0.yaml ├── examples │ ├── Yao_Ming.jpeg │ ├── ca.jpeg │ ├── fridge.jpg │ ├── laundry.jpeg │ ├── monalisa-fun.jpg │ ├── monday.jpg │ ├── mug_ad.jpeg │ ├── rap.jpeg │ ├── titanic.jpeg │ ├── vga.jpeg │ └── website.jpg ├── mplug_owl │ ├── __init__.py │ ├── configuration_mplug_owl.py │ ├── modeling_mplug_owl.py │ ├── processing_mplug_owl.py │ └── tokenization_mplug_owl.py ├── mplug_owl_video │ ├── __init__.py │ ├── configuration_mplug_owl.py │ ├── modeling_mplug_owl.py │ ├── processing_mplug_owl.py │ └── tokenization_mplug_owl.py ├── pipeline │ ├── __init__.py │ ├── data_utils │ │ ├── __init__.py │ │ ├── processors │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── caption_processor.py │ │ │ └── default_processor.py │ │ ├── randaugment.py │ │ ├── registry.py │ │ └── xgpt3_dataset.py │ ├── interface.py │ ├── train.py │ └── utils.py ├── requirements.txt ├── scripts │ ├── train_it.sh │ └── train_it_wo_lora.sh ├── serve │ ├── __init__.py │ ├── conversation.py │ ├── gradio_css.py │ ├── gradio_patch.py │ ├── io_utils.py │ ├── model_utils.py │ ├── model_worker.py │ ├── serve_utils.py │ └── web_server.py └── test-ant.py ├── assets └── method.png ├── common ├── args.py ├── gpt.py ├── interrupt_wrapper.py ├── models.py └── utils.py ├── configs ├── minigpt4_infer_fp16-dpo.yaml ├── minigpt4_infer_fp16.yaml ├── minigpt4_infer_fp16_hadpo.yaml └── minigpt4_train_fp16.yaml ├── dataset ├── caption_prompt.txt ├── caption_prompt_finetune.txt ├── synonyms.txt └── vqa_prompt.txt ├── deploy ├── dockerfile ├── requirements.txt ├── run.sh ├── run_eval.sh └── sources.list ├── evaluate ├── eval_auto.py ├── eval_caption.py ├── eval_gqa.py ├── eval_mme.py ├── eval_qbench.py ├── eval_sqa.py ├── eval_utils.py ├── m4c_evaluator.py └── mme │ └── calculation.py ├── finetune ├── data.py ├── format_data.py ├── loss.py ├── train.py └── train_vqa.py ├── insight ├── clip.py ├── object.py └── plot_time.py ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ 
├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── minigpt4 ├── __init__.py ├── common │ ├── __init__.py │ ├── config.py │ ├── dist_utils.py │ ├── eval_utils.py │ ├── gradcam.py │ ├── logger.py │ ├── optims.py │ ├── registry.py │ ├── utils.py │ └── vqa_tools │ │ ├── VQA │ │ ├── PythonEvaluationTools │ │ │ ├── vqaEvalDemo.py │ │ │ └── vqaEvaluation │ │ │ │ ├── __init__.py │ │ │ │ └── vqaEval.py │ │ ├── PythonHelperTools │ │ │ ├── vqaDemo.py │ │ │ └── vqaTools │ │ │ │ ├── __init__.py │ │ │ │ └── vqa.py │ │ ├── QuestionTypes │ │ │ ├── abstract_v002_question_types.txt │ │ │ └── mscoco_question_types.txt │ │ ├── README.md │ │ └── license.txt │ │ ├── __init__.py │ │ ├── vqa.py │ │ └── vqa_eval.py ├── configs │ ├── datasets │ │ ├── aokvqa │ │ │ └── defaults.yaml │ │ ├── cc_sbu │ │ │ ├── align.yaml │ │ │ └── defaults.yaml │ │ ├── coco │ │ │ ├── caption.yaml │ │ │ └── defaults_vqa.yaml │ │ ├── coco_bbox │ │ │ ├── invrefcoco.yaml │ │ │ ├── invrefcocog.yaml │ │ │ ├── invrefcocop.yaml │ │ │ ├── refcoco.yaml │ │ │ ├── refcocog.yaml │ │ │ └── refcocop.yaml │ │ ├── flickr │ │ │ ├── caption_to_phrase.yaml │ │ │ ├── default.yaml │ │ │ └── object_to_phrase.yaml │ │ ├── gqa │ │ │ └── balanced_val.yaml │ │ ├── laion │ │ │ └── defaults.yaml │ │ ├── llava │ │ │ ├── conversation.yaml │ │ │ ├── detail.yaml │ │ │ └── reason.yaml │ │ ├── multitask_conversation │ │ │ └── default.yaml │ │ ├── nlp │ │ │ └── unnatural_instruction.yaml │ │ ├── ocrvqa │ │ │ └── ocrvqa.yaml │ │ ├── okvqa │ │ │ └── defaults.yaml │ │ ├── textcaps │ │ │ └── caption.yaml │ │ └── vg │ │ │ └── ref.yaml │ ├── default.yaml │ └── models │ │ ├── minigpt4_llama2.yaml │ │ ├── minigpt4_vicuna0.yaml │ │ └── minigpt_v2.yaml ├── conversation │ ├── __init__.py │ └── conversation.py ├── datasets │ ├── __init__.py │ ├── builders │ │ ├── __init__.py │ │ ├── base_dataset_builder.py │ │ └── image_text_pair_builder.py │ ├── data_utils.py │ └── datasets │ │ ├── __init__.py │ │ ├── aok_vqa_datasets.py │ │ ├── base_dataset.py │ │ ├── caption_datasets.py │ │ ├── cc_sbu_dataset.py │ │ ├── coco_caption.py │ │ ├── coco_dataset.py │ │ ├── coco_vqa_datasets.py │ │ ├── dataloader_utils.py │ │ ├── flickr.py │ │ ├── gqa_datasets.py │ │ ├── laion_dataset.py │ │ ├── llava_dataset.py │ │ ├── multitask_conversation.py │ │ ├── ocrvqa_dataset.py │ │ ├── text_caps.py │ │ ├── unnatural_instruction.py │ │ ├── vg_dataset.py │ │ └── vqa_datasets.py ├── models │ ├── Qformer.py │ ├── __init__.py │ ├── base_model.py │ ├── eva_vit.py │ ├── minigpt4.py │ ├── minigpt_base.py │ ├── minigpt_v2.py │ └── modeling_llama.py ├── processors │ ├── __init__.py │ ├── base_processor.py │ ├── blip_processors.py │ └── randaugment.py ├── runners │ ├── __init__.py │ └── runner_base.py └── tasks │ ├── __init__.py │ ├── base_task.py │ └── image_text_pretrain.py ├── readme.md └── share4v ├── __init__.py ├── constants.py ├── conversation.py ├── eval ├── eval_gpt_review.py ├── eval_gpt_review_bench.py ├── eval_gpt_review_visual.py ├── eval_pope.py ├── eval_science_qa.py ├── 
eval_science_qa_gpt4.py ├── eval_science_qa_gpt4_requery.py ├── eval_textvqa.py ├── m4c_evaluator.py ├── model_qa.py ├── model_vqa.py ├── model_vqa_loader.py ├── model_vqa_mmbench.py ├── model_vqa_qbench.py ├── model_vqa_science.py ├── qa_baseline_gpt35.py ├── run_share4v.py ├── summarize_gpt_review.py └── table │ └── rule.json ├── mm_utils.py ├── model ├── __init__.py ├── builder.py ├── consolidate.py ├── language_model │ └── share4v_llama.py ├── multimodal_encoder │ ├── builder.py │ ├── clip_encoder.py │ ├── configuration_evaclip.py │ └── modeling_evaclip.py ├── multimodal_projector │ └── builder.py ├── share4v_arch.py └── utils.py ├── train ├── llama_flash_attn_monkey_patch.py ├── llama_xformers_attn_monkey_patch.py ├── share4v_trainer.py ├── train.py ├── train_mem.py └── train_xformers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/data 2 | __pycache__ 3 | checkpoint* 4 | tmp.txt 5 | finetune/sent_stat.py 6 | *.bak 7 | insight/search.py 8 | common/gpt_free.py 9 | -------------------------------------------------------------------------------- /Owl/.gitattributes: -------------------------------------------------------------------------------- 1 | *.py eol=lf 2 | *.rst eol=lf 3 | *.md eol=lf 4 | *.mdx eol=lf -------------------------------------------------------------------------------- /Owl/.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # tests and logs 12 | tests/fixtures/cached_*_text.txt 13 | logs/ 14 | lightning_logs/ 15 | lang_code_data/ 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # vscode 125 | .vs 126 | .vscode 127 | 128 | # Pycharm 129 | .idea 130 | 131 | # TF code 132 | tensorflow_code 133 | 134 | # Models 135 | proc_data 136 | 137 | # examples 138 | runs 139 | /runs_old 140 | /wandb 141 | /output 142 | /configs_dev 143 | /scripts_dev 144 | # /examples/runs 145 | # /examples/**/*.args 146 | # /examples/rag/sweep 147 | 148 | # data 149 | /data 150 | serialization_dir 151 | 152 | # emacs 153 | *.*~ 154 | debug.env 155 | 156 | # vim 157 | .*.swp 158 | 159 | #ctags 160 | tags 161 | 162 | # pre-commit 163 | .pre-commit* 164 | 165 | # .lock 166 | *.lock 167 | 168 | # DS_Store (MacOS) 169 | .DS_Store 170 | 171 | # ruff 172 | .ruff_cache -------------------------------------------------------------------------------- /Owl/OwlEval/OwlEval.md: -------------------------------------------------------------------------------- 1 | # OwlEval 2 | 3 | We have compiled some examples and their corresponding questions from recent open-source work, and organized them into OwlEval. 4 | 5 | In this document, we introduce OwlEval and its data format. 6 | 7 | ## Data Format 8 | 9 | ### questions 10 | 11 | `questions.jsonl` contains the information about each case image and its corresponding questions. 12 | 13 | Each row contains the following fields: 14 | 15 | - `image`: Indicates the name of the image file 16 | - `question_id`: Indicates the question ID; there are 82 questions in total 17 | 18 | - `question`: The text of the question 19 | - `ability`: The abilities required of the model. Detailed definitions can be found in the paper. 20 | - `a` is `Instruction Understanding` 21 | - `b` is `Visual Understanding` 22 | - `c` is `Optical Character Recognition` 23 | - `d` is `Knowledge Transfer Ability` 24 | - `e` is `Reasoning Ability` 25 | - `f` is `Multi-turn Dialogue Ability` 26 | - `category`: Indicates whether the question is single-turn or multi-turn 27 | 28 | For example: 29 | 30 | ```json 31 | {"image": "1.jpg", "question_id": 1, "question": "What is funny about this image? Describe it panel by panel.", "ability": "a,b,e", "category": ["single_turn"]} 32 | ```
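For reference, here is a minimal sketch of how `questions.jsonl` might be loaded and summarized with plain Python. The path and helper names are illustrative assumptions; only the field names follow the format described above.

```python
import json
from collections import Counter

# Assumed location; adjust to wherever OwlEval is checked out.
QUESTIONS_PATH = "Owl/OwlEval/questions.jsonl"

# Ability codes as defined above.
ABILITIES = {
    "a": "Instruction Understanding",
    "b": "Visual Understanding",
    "c": "Optical Character Recognition",
    "d": "Knowledge Transfer Ability",
    "e": "Reasoning Ability",
    "f": "Multi-turn Dialogue Ability",
}


def load_questions(path: str = QUESTIONS_PATH) -> list[dict]:
    """Read one JSON object per line (JSONL)."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    questions = load_questions()
    print(f"{len(questions)} questions loaded")  # 82 expected
    # Tally how often each ability tag appears across all questions.
    tags = Counter(t.strip() for q in questions for t in q["ability"].split(","))
    for tag, count in tags.most_common():
        print(f"{ABILITIES.get(tag, tag)}: {count}")
```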
33 | 34 | ### answer 35 | 36 | This folder contains each model's responses to every question, organized into six JSONL files: 37 | 38 | `llava_13b_answer.jsonl` 39 | 40 | `minigpt4_13b_answer.jsonl` 41 | 42 | `MMreact_answer.jsonl` 43 | 44 | `mPLUG_Owl_7b_answer.jsonl` 45 | 46 | `blip2_13b_answer.jsonl` 47 | 48 | `openflamingo_answer.jsonl` 49 | 50 | Each `answer/xxx.jsonl` file contains the following fields: 51 | 52 | - `image`: Indicates the name of the image file 53 | - `question_id`: Indicates the question ID; there are 82 questions in total 54 | 55 | - `question`: The text of the question 56 | - `answer`: The reply given by the model 57 | - `model_id`: The ID of the model that generated the answer 58 | 59 | For example: 60 | 61 | ```json 62 | {"image": "10.jpg", "question_id": 15, "question": "How many bedrooms are there in this floor plan?", "answer": "There are three bedrooms in this floor plan.", "model_id": "llava-13b"} 63 | ``` 64 | 65 | ### cases 66 | 67 | This folder contains the 50 evaluation pictures: 21 from MiniGPT-4, 13 from MM-REACT, 9 from BLIP-2, 3 from GPT-4, and 4 collected by us. -------------------------------------------------------------------------------- /Owl/OwlEval/cases/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/1.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/10.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/11.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/12.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/13.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/14.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/15.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/16.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/16.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/17.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/18.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/19.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/2.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/20.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/21.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/22.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/23.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/24.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/25.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/26.jpg -------------------------------------------------------------------------------- 
/Owl/OwlEval/cases/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/27.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/28.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/28.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/29.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/3.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/30.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/31.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/32.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/33.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/34.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/35.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/35.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/36.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/37.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/37.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/38.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/39.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/4.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/40.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/41.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/41.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/42.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/42.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/43.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/43.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/44.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/45.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/45.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/46.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/46.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/47.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/47.jpg -------------------------------------------------------------------------------- 
/Owl/OwlEval/cases/48.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/48.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/49.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/49.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/5.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/50.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/6.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/7.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/8.jpg -------------------------------------------------------------------------------- /Owl/OwlEval/cases/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/9.jpg -------------------------------------------------------------------------------- /Owl/assets/-twitter-blue.svg: -------------------------------------------------------------------------------- 1 | twittertwitter -------------------------------------------------------------------------------- /Owl/assets/Demo-ModelScope-brightgreen.svg: -------------------------------------------------------------------------------- 1 | Demo: ModelScopeDemoModelScope -------------------------------------------------------------------------------- /Owl/assets/LICENSE-Apache License-blue.svg: -------------------------------------------------------------------------------- 1 | LICENSE: Apache LicenseLICENSEApache License -------------------------------------------------------------------------------- /Owl/assets/Paper-Arxiv-orange.svg: -------------------------------------------------------------------------------- 1 | Paper: ArxivPaperArxiv -------------------------------------------------------------------------------- /Owl/assets/Paper-PDF-orange.svg: -------------------------------------------------------------------------------- 1 | Paper: PDFPaperPDF 
-------------------------------------------------------------------------------- /Owl/assets/modelscopeIcon.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Owl/configs/v0.yaml: -------------------------------------------------------------------------------- 1 | data_files: [ 2 | 'sft_v0.1_train.jsonl', 3 | 'sft_v0.1_dev.jsonl' 4 | ] 5 | 6 | train_processors: { 7 | sft: {type: 'CaptionProcessor', image_size: 224, min_scale: 0.5, randaug: False} 8 | } 9 | 10 | valid_processors: { 11 | sft: {type: 'DefaultProcessor', image_size: 224} 12 | } -------------------------------------------------------------------------------- /Owl/examples/Yao_Ming.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/Yao_Ming.jpeg -------------------------------------------------------------------------------- /Owl/examples/ca.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/ca.jpeg -------------------------------------------------------------------------------- /Owl/examples/fridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/fridge.jpg -------------------------------------------------------------------------------- /Owl/examples/laundry.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/laundry.jpeg -------------------------------------------------------------------------------- /Owl/examples/monalisa-fun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/monalisa-fun.jpg -------------------------------------------------------------------------------- /Owl/examples/monday.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/monday.jpg -------------------------------------------------------------------------------- /Owl/examples/mug_ad.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/mug_ad.jpeg -------------------------------------------------------------------------------- /Owl/examples/rap.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/rap.jpeg -------------------------------------------------------------------------------- /Owl/examples/titanic.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/titanic.jpeg -------------------------------------------------------------------------------- /Owl/examples/vga.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/vga.jpeg -------------------------------------------------------------------------------- /Owl/examples/website.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/website.jpg -------------------------------------------------------------------------------- /Owl/mplug_owl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available 17 | 18 | 19 | _import_structure = { 20 | "configuration_mplug_owl": ["MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MplugOwlConfig"], 21 | "processing_mplug_owl": ["MplugOwlImageProcessor", "MplugOwlProcessor"], 22 | "tokenization_mplug_owl": ["MplugOwlTokenizer"], 23 | } 24 | 25 | try: 26 | if not is_tokenizers_available(): 27 | raise OptionalDependencyNotAvailable() 28 | except OptionalDependencyNotAvailable: 29 | pass 30 | 31 | 32 | try: 33 | if not is_torch_available(): 34 | raise OptionalDependencyNotAvailable() 35 | except OptionalDependencyNotAvailable: 36 | pass 37 | else: 38 | _import_structure["modeling_mplug_owl"] = [ 39 | "MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST", 40 | "MplugOwlForConditionalGeneration", 41 | "MplugOwlModel", 42 | ] 43 | 44 | 45 | if TYPE_CHECKING: 46 | from .configuration_mplug_owl import MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP, MplugOwlConfig 47 | from .tokenization_mplug_owl import MplugOwlTokenizer 48 | 49 | try: 50 | if not is_tokenizers_available(): 51 | raise OptionalDependencyNotAvailable() 52 | except OptionalDependencyNotAvailable: 53 | pass 54 | 55 | try: 56 | if not is_torch_available(): 57 | raise OptionalDependencyNotAvailable() 58 | except OptionalDependencyNotAvailable: 59 | pass 60 | else: 61 | from .modeling_mplug_owl import ( 62 | MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST, 63 | MplugOwlForConditionalGeneration, 64 | MplugOwlModel, 65 | MplugOwlPreTrainedModel, 66 | ) 67 | 68 | 69 | else: 70 | import sys 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 73 | 74 | from .configuration_mplug_owl import * 75 | from .modeling_mplug_owl import * 76 | from .processing_mplug_owl import * 77 | from .tokenization_mplug_owl import * 78 | -------------------------------------------------------------------------------- /Owl/mplug_owl/tokenization_mplug_owl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 x-plug and The HuggingFace 
Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for MplugOwl.""" 16 | 17 | from transformers.utils import logging 18 | from transformers.models.llama.tokenization_llama import LlamaTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/vocab.txt", 28 | }, 29 | } 30 | 31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 32 | "MAGAer13/mplug-owl-llama-7b": 1024, 33 | } 34 | 35 | 36 | class MplugOwlTokenizer(LlamaTokenizer): 37 | def __init__( 38 | self, 39 | vocab_file, 40 | unk_token="", 41 | bos_token="", 42 | eos_token="", 43 | pad_token="", 44 | sp_model_kwargs=None, 45 | add_bos_token=False, 46 | add_eos_token=False, 47 | clean_up_tokenization_spaces=False, 48 | **kwargs, 49 | ): 50 | super().__init__( 51 | vocab_file, 52 | unk_token, 53 | bos_token, 54 | eos_token, 55 | pad_token, 56 | sp_model_kwargs, 57 | add_bos_token, 58 | add_eos_token, 59 | clean_up_tokenization_spaces, 60 | **kwargs, 61 | ) 62 | self.eod_id = self.eos_token_id 63 | -------------------------------------------------------------------------------- /Owl/mplug_owl_video/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available 17 | 18 | 19 | _import_structure = { 20 | "configuration_mplug_owl": ["MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MplugOwlConfig"], 21 | "processing_mplug_owl": ["MplugOwlImageProcessor", "MplugOwlProcessor"], 22 | "tokenization_mplug_owl": ["MplugOwlTokenizer"], 23 | } 24 | 25 | try: 26 | if not is_tokenizers_available(): 27 | raise OptionalDependencyNotAvailable() 28 | except OptionalDependencyNotAvailable: 29 | pass 30 | 31 | 32 | try: 33 | if not is_torch_available(): 34 | raise OptionalDependencyNotAvailable() 35 | except OptionalDependencyNotAvailable: 36 | pass 37 | else: 38 | _import_structure["modeling_mplug_owl"] = [ 39 | "MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST", 40 | "MplugOwlForConditionalGeneration", 41 | "MplugOwlModel", 42 | ] 43 | 44 | 45 | if TYPE_CHECKING: 46 | from .configuration_mplug_owl import MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP, MplugOwlConfig 47 | from .tokenization_mplug_owl import MplugOwlTokenizer 48 | 49 | try: 50 | if not is_tokenizers_available(): 51 | raise OptionalDependencyNotAvailable() 52 | except OptionalDependencyNotAvailable: 53 | pass 54 | 55 | try: 56 | if not is_torch_available(): 57 | raise OptionalDependencyNotAvailable() 58 | except OptionalDependencyNotAvailable: 59 | pass 60 | else: 61 | from .modeling_mplug_owl import ( 62 | MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST, 63 | MplugOwlForConditionalGeneration, 64 | MplugOwlModel, 65 | MplugOwlPreTrainedModel, 66 | ) 67 | 68 | 69 | else: 70 | import sys 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 73 | 74 | from .configuration_mplug_owl import * 75 | from .modeling_mplug_owl import * 76 | from .processing_mplug_owl import * 77 | from .tokenization_mplug_owl import * 78 | -------------------------------------------------------------------------------- /Owl/mplug_owl_video/tokenization_mplug_owl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for MplugOwl.""" 16 | 17 | from transformers.utils import logging 18 | from transformers.models.llama.tokenization_llama import LlamaTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/vocab.txt", 28 | }, 29 | } 30 | 31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 32 | "MAGAer13/mplug-owl-llama-7b": 2048, 33 | } 34 | 35 | 36 | class MplugOwlTokenizer(LlamaTokenizer): 37 | def __init__( 38 | self, 39 | vocab_file, 40 | unk_token="", 41 | bos_token="", 42 | eos_token="", 43 | pad_token="", 44 | sp_model_kwargs=None, 45 | add_bos_token=False, 46 | add_eos_token=False, 47 | clean_up_tokenization_spaces=False, 48 | **kwargs, 49 | ): 50 | super().__init__( 51 | vocab_file, 52 | unk_token, 53 | bos_token, 54 | eos_token, 55 | pad_token, 56 | sp_model_kwargs, 57 | add_bos_token, 58 | add_eos_token, 59 | clean_up_tokenization_spaces, 60 | **kwargs, 61 | ) 62 | self.eod_id = self.eos_token_id 63 | -------------------------------------------------------------------------------- /Owl/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/pipeline/__init__.py -------------------------------------------------------------------------------- /Owl/pipeline/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors.builder import build_processors 2 | from .xgpt3_dataset import MultiModalDataset 3 | 4 | def train_valid_test_datasets_provider(data_path, config, tokenizer, seq_length=1024): 5 | """Build train and valid datasets.""" 6 | print('> building train and validation datasets for mPLUG-Owl ...') 7 | train_ds, valid_ds = build_train_valid_test_datasets( 8 | input_file=data_path, 9 | tokenizer=tokenizer, 10 | max_length=seq_length, 11 | config=config) 12 | print("> finished creating mPLUG-Owl datasets ...") 13 | 14 | return train_ds, valid_ds 15 | 16 | def build_train_valid_test_datasets(input_file, tokenizer, max_length=80, config=None): 17 | train_processors = build_processors(config['train_processors']) 18 | valid_processors = build_processors(config['valid_processors']) 19 | 20 | assert len(input_file) == 2 # If you have files more than 2, modify code at here or merger them into train and dev 21 | train_ds = MultiModalDataset(input_file[0], tokenizer, train_processors, max_length) 22 | valid_ds = MultiModalDataset(input_file[1], tokenizer, valid_processors, max_length) 23 | test_ds = None 24 | return (train_ds, valid_ds) 25 | -------------------------------------------------------------------------------- /Owl/pipeline/data_utils/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba. All rights reserved. 
2 | from .builder import PROCESSORS, build_processors 3 | from .default_processor import DefaultProcessor 4 | from .caption_processor import CaptionProcessor 5 | 6 | __all__ = [ 7 | 'PROCESSORS', 'build_processors', 8 | 'DefaultProcessor', 'CaptionProcessor' 9 | ] -------------------------------------------------------------------------------- /Owl/pipeline/data_utils/processors/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from data_utils.registry import Registry, build_from_cfg 5 | 6 | PROCESSORS = Registry('processors') 7 | 8 | def build_processors(processors_cfg): 9 | processors = dict() 10 | for task, processor in processors_cfg.items(): 11 | processors[task] = build_from_cfg(processor, PROCESSORS) 12 | return processors 13 | -------------------------------------------------------------------------------- /Owl/pipeline/data_utils/processors/caption_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image 4 | import random 5 | 6 | from data_utils.randaugment import RandomAugment 7 | from .builder import PROCESSORS 8 | 9 | 10 | @PROCESSORS.register_module() 11 | class CaptionProcessor: 12 | def __init__(self, image_size=224, min_scale = 0.5, randaug=False): 13 | self.image_size = image_size 14 | self.min_scale = min_scale 15 | 16 | if randaug: 17 | self.image_transform = transforms.Compose([ 18 | transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC), 19 | transforms.RandomHorizontalFlip(), 20 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness', 21 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 24 | ]) 25 | else: 26 | self.image_transform = transforms.Compose([ 27 | transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 31 | ]) 32 | self.text_transform = None 33 | 34 | def __call__(self, image, text): 35 | assert image or text 36 | 37 | if image: 38 | image_input = self.image_transform(image) 39 | else: 40 | image_input = None 41 | 42 | if text: 43 | if isinstance(text["prompt"], list): 44 | prompt = random.choice(text["prompt"]) 45 | else: 46 | prompt = text["prompt"] 47 | text_input = dict( 48 | prompt=prompt, 49 | completion=text["text"], 50 | ) 51 | else: 52 | text_input = None 53 | return image_input, text_input -------------------------------------------------------------------------------- /Owl/pipeline/data_utils/processors/default_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image 4 | import random 5 | 6 | from data_utils.randaugment import RandomAugment 7 | from .builder import PROCESSORS 8 | 9 | 10 | @PROCESSORS.register_module() 11 | class DefaultProcessor: 12 | def __init__(self, image_size=224): 13 | self.image_size = image_size 14 | 15 | self.image_transform = transforms.Compose([ 16 | transforms.Resize((image_size, image_size),interpolation=Image.BICUBIC), 17 | transforms.ToTensor(), 18 | 
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 19 | ]) 20 | 21 | self.text_transform = None 22 | 23 | def __call__(self, image, text): 24 | assert image or text 25 | 26 | if image: 27 | image_input = self.image_transform(image) 28 | else: 29 | image_input = None 30 | 31 | if text: 32 | if isinstance(text["prompt"], list): 33 | prompt = random.choice(text["prompt"]) 34 | else: 35 | prompt = text["prompt"] 36 | text_input = dict( 37 | prompt=prompt, 38 | completion=text["text"], 39 | ) 40 | else: 41 | text_input = None 42 | return image_input, text_input -------------------------------------------------------------------------------- /Owl/pipeline/interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import requests 4 | from PIL import Image 5 | import os, sys 6 | sys.path.append('/workspace/hal/code/Owl/') 7 | sys.path.append('/workspace/hal/Owl/') 8 | # sys.path.append('/data/ant/code/Owl/mplug_owl/') 9 | # sys.path.append('/data/ant/code/Owl/pipeline/') 10 | # sys.path.append('/data/ant/code/') 11 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration 12 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor 13 | from transformers import AutoTokenizer 14 | 15 | 16 | def get_model(pretrained_ckpt, use_bf16=False): 17 | """Model Provider with tokenizer and processor. 18 | 19 | Args: 20 | pretrained_ckpt (string): The path to pre-trained checkpoint. 21 | use_bf16 (bool, optional): Whether to use bfloat16 to load the model. Defaults to False. 22 | 23 | Returns: 24 | model: MplugOwl Model 25 | tokenizer: MplugOwl text tokenizer 26 | processor: MplugOwl processor (including text and image) 27 | """ 28 | model = MplugOwlForConditionalGeneration.from_pretrained( 29 | pretrained_ckpt, 30 | torch_dtype=torch.bfloat16 if use_bf16 else torch.half, 31 | ) 32 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt) 33 | tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt) 34 | processor = MplugOwlProcessor(image_processor, tokenizer) 35 | return model, tokenizer, processor 36 | 37 | 38 | def do_generate(prompts, image_list, model, tokenizer, processor, use_bf16=False, **generate_kwargs): 39 | """The interface for generation 40 | 41 | Args: 42 | prompts (List[str]): The prompt text 43 | image_list (List[str]): Paths of images 44 | model (MplugOwlForConditionalGeneration): MplugOwlForConditionalGeneration 45 | tokenizer (AutoTokenizer): AutoTokenizer 46 | processor (MplugOwlProcessor): MplugOwlProcessor 47 | use_bf16 (bool, optional): Whether to use bfloat16. Defaults to False. 48 | 49 | Returns: 50 | sentence (str): Generated sentence. 51 | """ 52 | if image_list: 53 | images = [Image.open(_) for _ in image_list] 54 | else: 55 | images = None 56 | inputs = processor(text=prompts, images=images, return_tensors='pt') 57 | print(f"inputs={inputs}, inputs.shape={inputs['pixel_values'].shape}") 58 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()} 59 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 60 | with torch.no_grad(): 61 | res = model.generate(**inputs, **generate_kwargs) 62 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True) 63 | return sentence 64 | 65 | 66 | if __name__ == '__main__': 67 | prompts = ['''The following is a conversation between a curious human and AI assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions. 68 | Human: 69 | Human: Explain why this meme is funny. 70 | AI: '''] 71 | image_list = ['/workspace/hal/Owl/examples/monday.jpg'] 72 | # base_model = 'MAGAer13/mplug-owl-llama-7b' 73 | base_model= '/workspace/hal/checkpoints/mplug_OWL_llama_7b' 74 | # base_model = 'output/sft_v0.1_lora_grad_ckpt/checkpoint-4000/generation_config.json' 75 | model, tokenizer, processor = get_model(base_model, use_bf16=True) 76 | # sentence = do_generate( 77 | # prompts, image_list, model, 78 | # tokenizer, processor, use_bf16=True, 79 | # max_length=512, top_k=5, do_sample=True 80 | # ) 81 | print(tokenizer.eos_token_id) 82 | print(tokenizer.pad_token_id) 83 | print(tokenizer.bos_token_id) 84 | # print(sentence) 85 | -------------------------------------------------------------------------------- /Owl/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.28.1 2 | einops 3 | icecream 4 | flask 5 | ruamel.yaml 6 | uvicorn 7 | fastapi 8 | markdown2 9 | gradio 10 | sconf 11 | tensorboardX 12 | tensorboard 13 | h5py 14 | sentencepiece 15 | peft 16 | opencv-python 17 | decord 18 | chardet 19 | cchardet -------------------------------------------------------------------------------- /Owl/scripts/train_it.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR=`pwd` 3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` 4 | 5 | if [ $MASTER_ADDR ];then 6 | echo $MASTER_ADDR 7 | echo $MASTER_PORT 8 | echo $WORLD_SIZE 9 | echo $RANK 10 | else 11 | MASTER_ADDR=127.0.0.1 12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15 13 | WORLD_SIZE=1 14 | RANK=0 15 | fi 16 | 17 | DISTRIBUTED_ARGS="--nproc_per_node 8 \ 18 | --nnodes ${WORLD_SIZE} \ 19 | --node_rank ${RANK} \ 20 | --master_addr ${MASTER_ADDR} \ 21 | --master_port ${MASTER_PORT}" 22 | 23 | EXP_NAME=sft_v0.1 24 | SAVE_NAME=sft_v0.1_ft_grad_ckpt 25 | 26 | SAVE_PATH="./output/${SAVE_NAME}/" 27 | 28 | max_length=2048 29 | micro_batch_size=4 30 | global_batch_size=256 31 | gradient_accumulation_steps=1 32 | 33 | # train_iters = total_data * train_epochs // global_batch_size 34 | # 361481 * 3 / 256 = 4236 35 | train_epochs=3 36 | train_iters=4236 37 | 38 | lr_warmup_iters=50 39 | 40 | eval_iter=50 41 | eval_interval=50 42 | save_interval=500 43 | 44 | mkdir -p ${SAVE_PATH} 45 | 46 | options=" \ 47 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \ 48 | --seq-length ${max_length} \ 49 | --micro-batch-size ${micro_batch_size} \ 50 | --num-training-steps ${train_iters} \ 51 | --train-epochs ${train_epochs} \ 52 | --num-warmup-steps ${lr_warmup_iters} \ 53 | --gradient-accumulation-steps ${gradient_accumulation_steps} \ 54 | --lr 2e-5 \ 55 | --min-lr 1e-6 \ 56 | --eval-iters ${eval_iter} \ 57 | --save-interval ${save_interval} \ 58 | --save-path ${SAVE_PATH} \ 59 | --clip-grad 1.0 \ 60 | --weight-decay 0.0001 \ 61 | --adam-beta1 0.9 \ 62 | --adam-beta2 0.999 \ 63 | --num-workers 32 \ 64 | --use-lora \ 65 | --gradient-checkpointing \ 66 | --bf16" 67 | 68 | multimodal_options=" \ 69 | --mm-config configs/v0.yaml 70 | " 71 | 72 | python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee ${SAVE_PATH}/train.log -------------------------------------------------------------------------------- /Owl/scripts/train_it_wo_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 
| DIR=`pwd` 3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` 4 | 5 | if [ $MASTER_ADDR ];then 6 | echo $MASTER_ADDR 7 | echo $MASTER_PORT 8 | echo $WORLD_SIZE 9 | echo $RANK 10 | else 11 | MASTER_ADDR=127.0.0.1 12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15 13 | WORLD_SIZE=1 14 | RANK=0 15 | fi 16 | 17 | DISTRIBUTED_ARGS="--nproc_per_node 8 \ 18 | --nnodes ${WORLD_SIZE} \ 19 | --node_rank ${RANK} \ 20 | --master_addr ${MASTER_ADDR} \ 21 | --master_port ${MASTER_PORT}" 22 | 23 | EXP_NAME=sft_v0.1 24 | SAVE_NAME=sft_v0.1_ft_grad_ckpt 25 | 26 | SAVE_PATH="./output/${SAVE_NAME}/" 27 | 28 | max_length=2048 29 | micro_batch_size=1 30 | global_batch_size=256 31 | gradient_accumulation_steps=4 32 | 33 | # train_iters = total_data * train_epochs // global_batch_size 34 | # 361481 * 3 / 256 = 4236 35 | train_epochs=3 36 | train_iters=4236 37 | 38 | lr_warmup_iters=50 39 | lr_decay_iters=`expr $train_iters - $lr_warmup_iters` 40 | 41 | eval_iter=50 42 | eval_interval=50 43 | save_interval=500 44 | 45 | mkdir -p ${SAVE_PATH} 46 | 47 | options=" \ 48 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \ 49 | --seq-length ${max_length} \ 50 | --micro-batch-size ${micro_batch_size} \ 51 | --train-epochs ${train_epochs} \ 52 | --num-warmup-steps ${lr_warmup_iters} \ 53 | --num-training-steps ${train_iters} \ 54 | --gradient-accumulation-steps ${gradient_accumulation_steps} \ 55 | --lr 1e-5 \ 56 | --min-lr 1e-6 \ 57 | --eval-iters ${eval_iter} \ 58 | --save-interval ${save_interval} \ 59 | --save-path ${SAVE_PATH} \ 60 | --clip-grad 1.0 \ 61 | --weight-decay 0.0001 \ 62 | --adam-beta1 0.9 \ 63 | --adam-beta2 0.999 \ 64 | --num-workers 32 \ 65 | --gradient-checkpointing \ 66 | --bf16" 67 | 68 | multimodal_options=" \ 69 | --mm-config configs/v0.yaml 70 | " 71 | 72 | python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee ${SAVE_PATH}/train.log -------------------------------------------------------------------------------- /Owl/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/serve/__init__.py -------------------------------------------------------------------------------- /Owl/serve/gradio_css.py: -------------------------------------------------------------------------------- 1 | code_highlight_css = ( 2 | """ 3 | #chatbot .hll { background-color: #ffffcc } 4 | #chatbot .c { color: #408080; font-style: italic } 5 | #chatbot .err { border: 1px solid #FF0000 } 6 | #chatbot .k { color: #008000; font-weight: bold } 7 | #chatbot .o { color: #666666 } 8 | #chatbot .ch { color: #408080; font-style: italic } 9 | #chatbot .cm { color: #408080; font-style: italic } 10 | #chatbot .cp { color: #BC7A00 } 11 | #chatbot .cpf { color: #408080; font-style: italic } 12 | #chatbot .c1 { color: #408080; font-style: italic } 13 | #chatbot .cs { color: #408080; font-style: italic } 14 | #chatbot .gd { color: #A00000 } 15 | #chatbot .ge { font-style: italic } 16 | #chatbot .gr { color: #FF0000 } 17 | #chatbot .gh { color: #000080; font-weight: bold } 18 | #chatbot .gi { color: #00A000 } 19 | #chatbot .go { color: #888888 } 20 | #chatbot .gp { color: #000080; font-weight: bold } 21 | #chatbot .gs { font-weight: bold } 22 | #chatbot .gu { color: #800080; font-weight: bold } 23 | #chatbot .gt { color: #0044DD } 24 | #chatbot .kc { color: #008000; font-weight: bold } 25 | #chatbot .kd { color: #008000; font-weight: 
bold } 26 | #chatbot .kn { color: #008000; font-weight: bold } 27 | #chatbot .kp { color: #008000 } 28 | #chatbot .kr { color: #008000; font-weight: bold } 29 | #chatbot .kt { color: #B00040 } 30 | #chatbot .m { color: #666666 } 31 | #chatbot .s { color: #BA2121 } 32 | #chatbot .na { color: #7D9029 } 33 | #chatbot .nb { color: #008000 } 34 | #chatbot .nc { color: #0000FF; font-weight: bold } 35 | #chatbot .no { color: #880000 } 36 | #chatbot .nd { color: #AA22FF } 37 | #chatbot .ni { color: #999999; font-weight: bold } 38 | #chatbot .ne { color: #D2413A; font-weight: bold } 39 | #chatbot .nf { color: #0000FF } 40 | #chatbot .nl { color: #A0A000 } 41 | #chatbot .nn { color: #0000FF; font-weight: bold } 42 | #chatbot .nt { color: #008000; font-weight: bold } 43 | #chatbot .nv { color: #19177C } 44 | #chatbot .ow { color: #AA22FF; font-weight: bold } 45 | #chatbot .w { color: #bbbbbb } 46 | #chatbot .mb { color: #666666 } 47 | #chatbot .mf { color: #666666 } 48 | #chatbot .mh { color: #666666 } 49 | #chatbot .mi { color: #666666 } 50 | #chatbot .mo { color: #666666 } 51 | #chatbot .sa { color: #BA2121 } 52 | #chatbot .sb { color: #BA2121 } 53 | #chatbot .sc { color: #BA2121 } 54 | #chatbot .dl { color: #BA2121 } 55 | #chatbot .sd { color: #BA2121; font-style: italic } 56 | #chatbot .s2 { color: #BA2121 } 57 | #chatbot .se { color: #BB6622; font-weight: bold } 58 | #chatbot .sh { color: #BA2121 } 59 | #chatbot .si { color: #BB6688; font-weight: bold } 60 | #chatbot .sx { color: #008000 } 61 | #chatbot .sr { color: #BB6688 } 62 | #chatbot .s1 { color: #BA2121 } 63 | #chatbot .ss { color: #19177C } 64 | #chatbot .bp { color: #008000 } 65 | #chatbot .fm { color: #0000FF } 66 | #chatbot .vc { color: #19177C } 67 | #chatbot .vg { color: #19177C } 68 | #chatbot .vi { color: #19177C } 69 | #chatbot .vm { color: #19177C } 70 | #chatbot .il { color: #666666 } 71 | """) 72 | #.highlight { background: #f8f8f8; } 73 | 74 | -------------------------------------------------------------------------------- /Owl/serve/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import torch 5 | import transformers 6 | import traceback 7 | 8 | from queue import Queue 9 | from threading import Thread 10 | 11 | 12 | def post_process_output(text): 13 | text = text.strip() 14 | pattern = re.compile( 15 | r"||||\[PAD\]|<\|endoftext\|>|\[UNK\]|\[CLS\]|\[MASK\]|<\|startofpiece\|>|<\|endofpiece\|>|\[gMASK\]|\[sMASK\]" 16 | ) 17 | text = pattern.sub("", text.strip()).strip() 18 | return text 19 | 20 | 21 | def post_process_code(code): 22 | sep = "\n```" 23 | if sep in code: 24 | blocks = code.split(sep) 25 | if len(blocks) % 2 == 1: 26 | for i in range(1, len(blocks), 2): 27 | blocks[i] = blocks[i].replace("\\_", "_") 28 | code = sep.join(blocks) 29 | return code 30 | 31 | 32 | class Stream(transformers.StoppingCriteria): 33 | def __init__(self, callback_func=None): 34 | self.callback_func = callback_func 35 | 36 | def __call__(self, input_ids, scores) -> bool: 37 | if self.callback_func is not None: 38 | self.callback_func(input_ids[0]) 39 | return False 40 | 41 | 42 | class Iteratorize: 43 | 44 | """ 45 | Transforms a function that takes a callback 46 | into a lazy iterator (generator). 
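    Illustrative usage (`generate_fn` below is a placeholder for any function that
    accepts a `callback` keyword argument and calls it once per produced value;
    the keyword arguments shown are assumptions for the sketch):

        for value in Iteratorize(generate_fn, kwargs={"prompt": "Hello"}):
            print(value)

    Iteration stops once `generate_fn` returns and the internal sentinel object
    is queued.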
47 | """ 48 | 49 | def __init__(self, func, kwargs={}, callback=None): 50 | self.mfunc = func 51 | self.c_callback = callback 52 | self.q = Queue() 53 | self.sentinel = object() 54 | self.kwargs = kwargs 55 | self.stop_now = False 56 | 57 | def _callback(val): 58 | if self.stop_now: 59 | raise ValueError 60 | self.q.put(val) 61 | 62 | def gentask(): 63 | try: 64 | ret = self.mfunc(callback=_callback, **self.kwargs) 65 | except ValueError: 66 | pass 67 | except: 68 | traceback.print_exc() 69 | pass 70 | 71 | self.q.put(self.sentinel) 72 | if self.c_callback: 73 | self.c_callback(ret) 74 | 75 | self.thread = Thread(target=gentask) 76 | self.thread.start() 77 | 78 | def __iter__(self): 79 | return self 80 | 81 | def __next__(self): 82 | obj = self.q.get(True, None) 83 | if obj is self.sentinel: 84 | raise StopIteration 85 | else: 86 | return obj 87 | 88 | def __enter__(self): 89 | return self 90 | 91 | def __exit__(self, exc_type, exc_val, exc_tb): 92 | self.stop_now = True -------------------------------------------------------------------------------- /Owl/serve/serve_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | import gradio as gr 4 | import logging 5 | import sys 6 | import os 7 | import json 8 | import requests 9 | from .conversation import default_conversation 10 | from .gradio_patch import Chatbot as grChatbot 11 | from .gradio_css import code_highlight_css 12 | import datetime 13 | import uuid 14 | import base64 15 | from io import BytesIO 16 | import time 17 | 18 | from .io_utils import IO, DefaultIO, OSS 19 | 20 | 21 | handler = None 22 | 23 | 24 | class _IOWrapper: 25 | def __init__(self): 26 | self._io = DefaultIO() 27 | 28 | def set_io(self, new_io): 29 | self._io = new_io 30 | 31 | def __getattr__(self, name): 32 | if hasattr(self._io, name): 33 | return getattr(self._io, name) 34 | return super().__getattr__(name) 35 | 36 | def __str__(self): 37 | return self._io.__name__ 38 | 39 | def init(): 40 | io = _IOWrapper() 41 | return io 42 | 43 | 44 | def vote_last_response(state, vote_type, model_selector, request: gr.Request): 45 | pass 46 | 47 | def upvote_last_response(state, model_selector, request: gr.Request): 48 | vote_last_response(state, "upvote", model_selector, request) 49 | return ("",) + (disable_btn,) * 3 50 | 51 | def downvote_last_response(state, model_selector, request: gr.Request): 52 | vote_last_response(state, "downvote", model_selector, request) 53 | return ("",) + (disable_btn,) * 3 54 | 55 | def flag_last_response(state, model_selector, request: gr.Request): 56 | vote_last_response(state, "flag", model_selector, request) 57 | return ("",) + (disable_btn,) * 3 58 | 59 | def regenerate(state, request: gr.Request): 60 | state.messages[-1][-1] = None 61 | state.skip_next = False 62 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 63 | 64 | def clear_history(request: gr.Request): 65 | state = default_conversation.copy() 66 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 67 | 68 | 69 | def add_text(state, text, image, video, request: gr.Request): 70 | if len(text) <= 0 and (image is None or video is None): 71 | state.skip_next = True 72 | return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5 73 | 74 | if image is not None: 75 | if '' not in text: 76 | text = text + '\n' 77 | text = (text, image) 78 | 79 | if video is not None: 80 | num_frames = 4 81 | if '' not in text: 82 | text = 
text + '\n' * num_frames 83 | text = (text, video) 84 | 85 | state.append_message(state.roles[0], text) 86 | state.append_message(state.roles[1], None) 87 | state.skip_next = False 88 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 89 | 90 | def after_process_image(prompt): 91 | prompt = prompt.replace("\n", "") 92 | pro_prompt = "" 93 | prompt = prompt.split("\n") 94 | for p in prompt: 95 | if p.count("") > 0: 96 | pro_prompt += "Human: \n" 97 | if p != "": 98 | pro_prompt += p.replace("", "") + "\n" 99 | else: 100 | pro_prompt += p + "\n" 101 | return (pro_prompt[:-1]+" ").replace("\n Human", "\nHuman").replace("\n AI", "\nAI") 102 | 103 | 104 | headers = {"User-Agent": "mPLUG-Owl Client"} 105 | 106 | no_change_btn = gr.Button.update() 107 | enable_btn = gr.Button.update(interactive=True) 108 | disable_btn = gr.Button.update(interactive=False) 109 | 110 | get_window_url_params = """ 111 | function() { 112 | const params = new URLSearchParams(window.location.search); 113 | url_params = Object.fromEntries(params); 114 | console.log(url_params); 115 | return url_params; 116 | } 117 | """ -------------------------------------------------------------------------------- /Owl/test-ant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | print(isinstance('loss', (list, torch.Tensor))) 3 | -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/assets/method.png -------------------------------------------------------------------------------- /common/gpt.py: -------------------------------------------------------------------------------- 1 | import openai 2 | 3 | openai.api_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 4 | proxy = "http://127.0.0.1:7890" 5 | 6 | 7 | def gpt_infer(messages: list, model="gpt-3.5-turbo"): 8 | completion = openai.ChatCompletion.create(model=model, messages=messages, proxy=proxy) 9 | return completion.choices[0].message.content # type: ignore 10 | 11 | 12 | if __name__ == "__main__": 13 | messages = [ 14 | {"role": "user", "content": "鲁迅和周树人的关系"}, 15 | ] 16 | print(gpt_infer(messages)) 17 | -------------------------------------------------------------------------------- /configs/minigpt4_infer_fp16-dpo.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | has_qformer: False 11 | 12 | model_type: pretrain_llama2 13 | max_txt_len: 160 14 | end_sym: "" 15 | low_resource: False 16 | prompt_template: "[INST] {} [/INST] " 17 | ckpt: "/home/ant/llm-hal/efuf/checkpoints/minigpt4_llama2_7b/pretrained.pth" 18 | 19 | # generation configs 20 | prompt: "" 21 | 22 | # llama_model: "/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf" 23 | llama_model: "/home/ant/llm-hal/HA-DPO/ha_dpo/models/minigpt4/merged_minigpt4_ha_dpo" 24 | 25 | preprocess: 26 | vis_processor: 27 | train: 28 | name: "blip2_image_train" 29 | image_size: 224 30 | eval: 31 | name: "blip2_image_eval" 32 | image_size: 224 33 | text_processor: 34 | train: 35 | name: "blip_caption" 36 | eval: 37 | name: "blip_caption" 38 | 39 | datasets: 40 | cc_sbu_align: 41 | vis_processor: 42 | train: 43 | name: 
"blip2_image_eval" 44 | image_size: 224 45 | text_processor: 46 | train: 47 | name: "blip_caption" 48 | 49 | run: 50 | task: image_text_pretrain 51 | -------------------------------------------------------------------------------- /configs/minigpt4_infer_fp16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | has_qformer: False 11 | 12 | model_type: pretrain_llama2 13 | max_txt_len: 160 14 | end_sym: "" 15 | low_resource: False 16 | prompt_template: "[INST] {} [/INST] " 17 | ckpt: "checkpoints/minigpt4_llama2_7b/pretrained.pth" 18 | 19 | # generation configs 20 | prompt: "" 21 | 22 | llama_model: "/home/nfs02/model/llama2/hf/Llama-2-7b-chat-hf/" 23 | 24 | preprocess: 25 | vis_processor: 26 | train: 27 | name: "blip2_image_train" 28 | image_size: 224 29 | eval: 30 | name: "blip2_image_eval" 31 | image_size: 224 32 | text_processor: 33 | train: 34 | name: "blip_caption" 35 | eval: 36 | name: "blip_caption" 37 | 38 | datasets: 39 | cc_sbu_align: 40 | vis_processor: 41 | train: 42 | name: "blip2_image_eval" 43 | image_size: 224 44 | text_processor: 45 | train: 46 | name: "blip_caption" 47 | 48 | run: 49 | task: image_text_pretrain 50 | -------------------------------------------------------------------------------- /configs/minigpt4_infer_fp16_hadpo.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | has_qformer: False 11 | 12 | model_type: pretrain_llama2 13 | max_txt_len: 160 14 | end_sym: "" 15 | low_resource: False 16 | prompt_template: "[INST] {} [/INST] " 17 | ckpt: "/home/ant/llm-hal/efuf/checkpoints/minigpt4_llama2_7b/pretrained.pth" 18 | 19 | # generation configs 20 | prompt: "" 21 | 22 | # llama_model: "/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf" 23 | llama_model: "/home/ant/llm-hal/HA-DPO/ha_dpo/models/minigpt4/merged_minigpt4_ha_dpo" 24 | 25 | preprocess: 26 | vis_processor: 27 | train: 28 | name: "blip2_image_train" 29 | image_size: 224 30 | eval: 31 | name: "blip2_image_eval" 32 | image_size: 224 33 | text_processor: 34 | train: 35 | name: "blip_caption" 36 | eval: 37 | name: "blip_caption" 38 | 39 | datasets: 40 | cc_sbu_align: 41 | vis_processor: 42 | train: 43 | name: "blip2_image_eval" 44 | image_size: 224 45 | text_processor: 46 | train: 47 | name: "blip_caption" 48 | 49 | run: 50 | task: image_text_pretrain 51 | -------------------------------------------------------------------------------- /configs/minigpt4_train_fp16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | has_qformer: False 11 | 12 | model_type: pretrain_llama2 13 | max_txt_len: 160 14 | end_sym: "" 15 | low_resource: False 16 | prompt_template: "{}" 17 | ckpt: "checkpoints/minigpt4_llama2_7b/pretrained.pth" 18 | 19 | # generation configs 20 | prompt: "" 21 | 22 | llama_model: "/home/nfs02/model/llama2/hf/Llama-2-7b-chat-hf/" 23 | 24 | preprocess: 25 | vis_processor: 26 | train: 27 | name: "blip2_image_train" 28 | image_size: 224 29 | eval: 30 | name: 
"blip2_image_eval" 31 | image_size: 224 32 | text_processor: 33 | train: 34 | name: "blip_caption" 35 | eval: 36 | name: "blip_caption" 37 | 38 | datasets: 39 | cc_sbu_align: 40 | vis_processor: 41 | train: 42 | name: "blip2_image_eval" 43 | image_size: 224 44 | text_processor: 45 | train: 46 | name: "blip_caption" 47 | 48 | run: 49 | task: image_text_pretrain 50 | -------------------------------------------------------------------------------- /dataset/caption_prompt.txt: -------------------------------------------------------------------------------- 1 | Please extract all the concrete visual objects in a piece of image caption. If there are brackets around a phrase, it means this is a concrete visual object pre-identified for you in advance, and you must output the phrase. Then you need to examine all noun phrases and infer if they are concrete visual objects in the image. Your response should be part of the input. You **MUST** respond in the format of a non-repeat list "object_1, object_2, ..., object_n". Strictly follow the format and **do not output any other words**, including but not limited to other phrases or additional explanations. 2 | Example: 3 | {input_label} The image shows a train on the tracks, with cars parked on the platform. The train is painted in a blue and white color scheme, with [a large logo] on the front. The cars are painted in a red and white color scheme, with [a small logo] on the front. The train is moving slowly through the station, with [people] standing on the platform and watching. The sky is clear and blue, with a few clouds in the distance. The buildings in the background are tall and made of brick, with windows and doors on the upper floors. There are trees and bushes growing along the sides of the tracks, and [a small bridge] over the tracks in the distance. 4 | {output_label} a train, tracks, cars, platform, a large logo, a small logo, people, sky, clouds, buildings, windows, doors, trees, bushes, a small bridge 5 | -------------------------------------------------------------------------------- /dataset/caption_prompt_finetune.txt: -------------------------------------------------------------------------------- 1 | Please extract all the concrete visual objects in a piece of image caption. You need to examine all noun phrases and infer if they are concrete visual objects in the image. Your response should be part of the input. You **MUST** respond in the format of a non-repeat list "object_1, object_2, ..., object_n". Strictly follow the format and **do not output any other words**, including but not limited to other phrases or additional explanations. 2 | Example: 3 | {input_label} The image shows a train on the tracks, with cars parked on the platform. The train is painted in a blue and white color scheme, with a large logo on the front. The cars are painted in a red and white color scheme, with a small logo on the front. The train is moving slowly through the station, with people standing on the platform and watching. The sky is clear and blue, with a few clouds in the distance. The buildings in the background are tall and made of brick, with windows and doors on the upper floors. There are trees and bushes growing along the sides of the tracks, and a small bridge over the tracks in the distance. 
4 | {output_label} a train, tracks, cars, platform, a large logo, a small logo, people, sky, clouds, buildings, windows, doors, trees, bushes, a small bridge 5 | -------------------------------------------------------------------------------- /dataset/synonyms.txt: -------------------------------------------------------------------------------- 1 | person, girl, boy, man, woman, kid, child, chef, baker, people, adult, rider, children, baby, worker, passenger, sister, brother, biker, policeman, cop, officer, lady, cowboy, bride, groom, male, female, guy, traveler, mother, father, gentleman, pitcher, player, skier, snowboarder, skater, skateboarder, guy, foreigner, child, gentleman, caller, offender, coworker, trespasser, patient, politician, soldier, grandchild, serviceman, walker, drinker, doctor, bicyclist, thief, buyer, teenager, student, camper, driver, solider, hunter, shopper, villager, pedestrian 2 | bicycle, bike, unicycle, minibike, trike 3 | car, automobile, van, minivan, sedan, suv, hatchback, cab, jeep, coupe, taxicab, limo, taxi 4 | motorcycle, scooter, motor bike, motor cycle, motorbike, scooter, moped 5 | airplane, jetliner, plane, air plane, monoplane, aircraft, jet, jetliner, airbus, biplane, seaplane 6 | bus, minibus, trolley 7 | train, locomotive, tramway, caboose 8 | truck, pickup, lorry, hauler, firetruck 9 | boat, ship, liner, sailboat, motorboat, dinghy, powerboat, speedboat, canoe, skiff, yacht, kayak, catamaran, pontoon, houseboat, vessel, rowboat, trawler, ferryboat, watercraft, tugboat, schooner, barge, ferry, sailboard, paddleboat, lifeboat, freighter, steamboat, riverboat, battleship, steamship 10 | traffic light, street light, traffic signal, stop light, streetlight, stoplight 11 | fire hydrant, hydrant 12 | stop sign 13 | parking meter 14 | bench, pew 15 | bird, ostrich, owl, seagull, goose, duck, parakeet, falcon, robin, pelican, waterfowl, heron, hummingbird, mallard, finch, pigeon, sparrow, seabird, osprey, blackbird, fowl, shorebird, woodpecker, egret, chickadee, quail, bluebird, kingfisher, buzzard, willet, gull, swan, bluejay, flamingo, cormorant, parrot, loon, gosling, waterbird, pheasant, rooster, sandpiper, crow, raven, turkey, oriole, cowbird, warbler, magpie, peacock, cockatiel, lorikeet, puffin, vulture, condor, macaw, peafowl, cockatoo, songbird 16 | cat, kitten, feline, tabby 17 | dog, puppy, beagle, pup, chihuahua, schnauzer, dachshund, rottweiler, canine, pitbull, collie, pug, terrier, poodle, labrador, doggie, doberman, mutt, doggy, spaniel, bulldog, sheepdog, weimaraner, corgi, cocker, greyhound, retriever, brindle, hound, whippet, husky 18 | horse, colt, pony, racehorse, stallion, equine, mare, foal, palomino, mustang, clydesdale, bronc, bronco 19 | sheep, lamb, ram, lamb, goat, ewe 20 | cow, cattle, oxen, ox, calf, cattle, holstein, heifer, buffalo, bull, zebu, bison 21 | elephant 22 | bear, panda 23 | zebra 24 | giraffe 25 | backpack, knapsack 26 | umbrella 27 | handbag, wallet, purse, briefcase 28 | tie, bow, bow tie 29 | suitcase, suit case, luggage 30 | frisbee 31 | skis, ski 32 | snowboard 33 | sports ball, ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard, longboard, skimboard, shortboard, wakeboard 39 | tennis racket, racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife, pocketknife, knive 45 | spoon 46 | bowl, container 47 | banana 48 | apple 49 | sandwich, burger, sub, cheeseburger, hamburger 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut, doughnut, bagel 56 | 
cake, cheesecake, cupcake, shortcake, coffeecake, pancake 57 | chair, seat, stool 58 | couch, sofa, recliner, futon, loveseat, settee, chesterfield 59 | potted plant, houseplant 60 | bed 61 | dining table, table, desk, coffee table 62 | toilet, urinal, commode, toilet, lavatory, potty 63 | tv, monitor, televison, television 64 | laptop, computer, notebook, netbook, lenovo, macbook, laptop computer 65 | mouse 66 | remote, remote control 67 | keyboard 68 | cell phone, mobile phone, phone, cellphone, telephone, phon, smartphone, iPhone 69 | microwave 70 | oven, stovetop, stove, stove top oven 71 | toaster 72 | sink 73 | refrigerator, fridge, fridge, freezer 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear, teddybear 79 | hair drier, hairdryer 80 | toothbrush 81 | -------------------------------------------------------------------------------- /dataset/vqa_prompt.txt: -------------------------------------------------------------------------------- 1 | Please extract all the visual objects in a piece of text which is an answer of a Visual Question Answering task. You should carefully infer if the object indeed appears in the image, and avoid those objects that are not depicted in the image. Only extract objects that you are certain to appear in the image according to the text, and do not invent objects that are not mentioned. If you are not sure if an object meets the above condition, do not add it to the result. Respond in the format of a non-repeat list "object_1, object_2, ..., object_n". If there are none, do not respond anything. Strictly follow the format and **do not output any other words**, including but not limited to other non-visual objects or explanations. 2 | 3 | Example: 4 | Input: If someone wants to try a similar water sport, such as surfing, one important factor to consider is the need for proper training and skill development. In the image, the surfer is seen riding a wave on his surfboard, showcasing balance and technique. Learning the fundamentals of surfing, including wave selection, paddling, standing up on the board, and maneuvering, can significantly increase the chances of a successful and enjoyable experience. Additionally, it's essential to use appropriate safety gear such as a leash and a wetsuit, as well as follow local surfing etiquette and regulations to ensure the safety of oneself and others in the water. 5 | Output: surfer, wave, surfboard 6 | 7 | Task: 8 | Input: {} 9 | Output: -------------------------------------------------------------------------------- /deploy/dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-devel 2 | 3 | # Copy all files to the /build directory 4 | COPY . 
/build 5 | WORKDIR /build 6 | RUN chmod -R u+rwX,go+rX,go-w /build 7 | 8 | # Update sources list and install necessary packages 9 | RUN cp sources.list /etc/apt/sources.list && apt-get update && apt-get install -y python3-pip git libgl1 10 | 11 | # Install Python dependencies 12 | RUN pip install -r requirements.txt -i https://mirrors.nju.edu.cn/pypi/web/simple/ 13 | 14 | # Install flash attention 15 | RUN if [ -e "flash_attn-2.5.5+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl" ]; then \ 16 | pip install flash_attn-2.5.5+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; \ 17 | else \ 18 | pip install flash_attn==2.5.5 -i https://mirrors.nju.edu.cn/pypi/web/simple/; \ 19 | fi 20 | 21 | # Assign workdir 22 | WORKDIR /workspace 23 | 24 | # it should be built: 25 | # cd deploy 26 | # docker build -t efuf:1.0 . 27 | 28 | # then run: 29 | # cd .. 30 | # docker run --gpus all --ipc=host --network=host --rm -it -v .:/workspace efuf:1.0 -------------------------------------------------------------------------------- /deploy/requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.28.0 2 | tqdm 3 | transformers==4.31.0 4 | git+https://github.com/MaureenZOU/detectron2-xyz.git 5 | einops 6 | einops-exts 7 | wandb 8 | nltk 9 | pandas 10 | spacy 11 | torchmetrics 12 | altair 13 | whisper 14 | gpustat 15 | timm 16 | opencv-python-headless 17 | webdataset 18 | visual_genome 19 | scikit-image 20 | visual_genome 21 | decord 22 | peft 23 | sentence-transformers 24 | gradio 25 | fairscale 26 | easydict 27 | pycocoevalcap 28 | moviepy 29 | pyyaml_env_tag 30 | open3d 31 | h5py 32 | diffusers==0.15.0 33 | seaborn -------------------------------------------------------------------------------- /deploy/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker run --gpus all --network=host -it \ 3 | -v .:/workspace/hal \ 4 | -v /data/cache/huggingface:/root/.cache/huggingface \ 5 | -v /data/cache/torch:/root/.cache/torch \ 6 | -v /root/nltk_data:/root/nltk_data \ 7 | mm_hal:1.2 -------------------------------------------------------------------------------- /deploy/run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker run --gpus all --ipc=host --network=host --rm -it \ 3 | -v .:/workspace/hal \ 4 | -v /data/cache/huggingface:/root/.cache/huggingface \ 5 | -v /data/cache/torch:/root/.cache/torch \ 6 | -v /root/nltk_data:/root/nltk_data \ 7 | -v /data/NJU/lib/zf/LLaVA:/workspace/hal/LLaVA \ 8 | mm_hal:1.3 9 | -------------------------------------------------------------------------------- /deploy/sources.list: -------------------------------------------------------------------------------- 1 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse 2 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse 3 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse 4 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse -------------------------------------------------------------------------------- /evaluate/eval_caption.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2023-11-12 09:22:40 3 | # @Author : Shangyu.Xing (starreeze@foxmail.com) 4 | 5 | from __future__ import 
annotations 6 | import os, sys, torch 7 | 8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 10 | from io import TextIOWrapper 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torch.utils.data import DataLoader, Dataset 14 | from common.models import model_loaders, generators 15 | from common.args import args 16 | 17 | 18 | class CocoImageDataset(Dataset): 19 | def __init__(self, image_names: list[str], processor): 20 | super(CocoImageDataset, self).__init__() 21 | self.image_names = image_names 22 | self.processor = processor 23 | 24 | def __len__(self): 25 | return len(self.image_names) 26 | 27 | def __getitem__(self, index): 28 | image_path = os.path.join(args.image_dir_path, self.image_names[index]) 29 | img = self.processor(Image.open(image_path).convert("RGB")) 30 | if args.device != "cpu": 31 | img = img.to(torch.float16).to(args.device) 32 | return img, self.image_names[index] 33 | 34 | 35 | def process_single(batch, model, prompt: str, output_fd: TextIOWrapper): 36 | images, image_names = batch 37 | texts = [prompt] * args.infer_bs_total 38 | results = [""] * args.infer_bs_total 39 | filtered = [] 40 | for _ in range(args.infer_retry): 41 | with torch.no_grad(): 42 | answers = generators[args.model](model, texts, images) 43 | for i, (name, answer) in enumerate(zip(image_names, answers)): 44 | if answer.replace("\n", "") and not results[i]: 45 | results[i] = name + " ### " + answer.replace("\n", " ") 46 | filtered = [r for r in results if r] 47 | if len(filtered) == len(results): 48 | break 49 | output_fd.write("\n".join(filtered) + "\n") 50 | output_fd.flush() 51 | 52 | 53 | def main(): 54 | model_load_path = getattr(args, f"{args.model}_ckpt_load_path") 55 | model_args = ["--cfg-path", args.minigpt_train_cfg] if args.model == "minigpt" else [] 56 | model, vis_processor = model_loaders[args.model](model_load_path, args.device, False, model_args) 57 | with open(args.object_data_path, "r") as f: 58 | objects = f.read().splitlines() 59 | images_used = {obj.split(args.column_splitter)[0] for obj in objects} 60 | image_names = filter(lambda x: x not in images_used, sorted(os.listdir(args.image_dir_path))) 61 | eval_end_pos = args.default_eval_samples if args.end_pos == int(1e10) else args.end_pos 62 | image_names = list(image_names)[args.start_pos : eval_end_pos] 63 | dataloader = DataLoader( 64 | CocoImageDataset(image_names, vis_processor), 65 | args.infer_bs_total, 66 | False, 67 | num_workers=args.infer_dataloader_worker, 68 | ) 69 | 70 | prompt = getattr(args, f"{args.model}_eval_prompt") 71 | with open(args.caption_eval_path, "w" if args.restart else "a", encoding="utf-8") as f: 72 | for batch in tqdm(dataloader): 73 | process_single(batch, model, prompt, output_fd=f) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /evaluate/eval_gqa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2024-06-11 15:15:36 3 | # @Author : Shangyu.Xing (starreeze@foxmail.com) 4 | 5 | from __future__ import annotations 6 | import os, sys, json, torch 7 | 8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torch.utils.data import DataLoader, Dataset 13 | from torch.cuda.amp import autocast # type: ignore 14 | from common.args import args 15 | from 
common.models import model_loaders, model_forward, generators, data_maps, sample_collators 16 | from common.utils import to_device 17 | 18 | basename = getattr(args, f"{args.model}_ckpt_load_path").split("/")[-1][:10] 19 | pred_path = os.path.join(args.gqa_data_path, "testdev_balanced_predictions.json") 20 | train_question_path = os.path.join(args.gqa_data_path, "questions1.2", "val_balanced_questions.json") 21 | eval_question_path = os.path.join(args.gqa_data_path, "questions1.2", "testdev_balanced_questions.json") 22 | image_dir = os.path.join(args.gqa_data_path, "images") 23 | model_dtype = {"llava": torch.bfloat16, "llavarlhf": torch.bfloat16, "share4v": torch.bfloat16} 24 | 25 | 26 | class GQAData(Dataset): 27 | def __init__(self, processor, start=0, end=int(1e9)): 28 | super().__init__() 29 | self.processor = processor 30 | with open(eval_question_path, "r") as f: 31 | questions: dict = json.load(f) 32 | self.data = list(questions.items())[start:end] 33 | 34 | def __len__(self): 35 | return len(self.data) 36 | 37 | def __getitem__(self, index: int): 38 | id, question = self.data[index] 39 | image_name = question["imageId"] + ".jpg" 40 | image_path = os.path.join(image_dir, image_name) 41 | image = Image.open(image_path).convert("RGB") 42 | answer = question["answer"] 43 | text = getattr(args, f"{args.model}_eval_vqa_prompt").format(question=question["question"]) 44 | return self.processor(image), id, text, answer 45 | 46 | def save(self): 47 | os.rename(eval_question_path, eval_question_path + ".bak") 48 | with open(eval_question_path, "w") as f: 49 | json.dump({k: v for k, v in self.data}, f) 50 | return self 51 | 52 | @staticmethod 53 | def restore(): 54 | os.rename(eval_question_path + ".bak", eval_question_path) 55 | 56 | 57 | def inference(model, vis_processor, start, end): 58 | eval_data = GQAData(vis_processor, start, end) 59 | eval_loader = DataLoader(eval_data, args.infer_bs_total, False) 60 | model.eval() 61 | model.to(torch.float16) 62 | correct, total = 0, 0 63 | with torch.no_grad(): 64 | results = [] 65 | for batch in tqdm(eval_loader): 66 | image, question_id, question, answer = to_device(batch) # type: ignore 67 | with autocast(dtype=torch.float16): 68 | responses = generators[args.model](model, question, image) 69 | for q, response, ans in zip(question_id, responses, answer): 70 | results.append({"questionId": q, "prediction": response.rstrip(".").lower()}) 71 | if ans.lower() in response.lower(): # type: ignore 72 | correct += 1 73 | total += len(question_id) 74 | with open(pred_path, "w") as f: 75 | json.dump(results, f) 76 | print(f"correct: {correct}, total: {total}, acc: {correct / total}") 77 | 78 | 79 | def eval(): 80 | model, vis_processor = model_loaders[args.model](getattr(args, f"{args.model}_ckpt_load_path")) 81 | try: 82 | model.to(model_dtype[args.model]) 83 | except KeyError: 84 | pass 85 | inference(model, vis_processor, start=0, end=args.default_eval_samples) 86 | 87 | 88 | if __name__ == "__main__": 89 | eval() 90 | -------------------------------------------------------------------------------- /evaluate/eval_mme.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2024-04-02 18:36:40 3 | # @Author : Shangyu.Xing (starreeze@foxmail.com) 4 | 5 | from __future__ import annotations 6 | import os, sys, torch 7 | 8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 10 | from glob import glob 11 | from PIL import Image 12 | 
from tqdm import tqdm 13 | from torch.utils.data import DataLoader, Dataset 14 | from common.args import args 15 | from common.models import model_loaders, generators 16 | from common.utils import to_device 17 | 18 | basename = getattr(args, f"{args.model}_ckpt_load_path").split("/")[-1] 19 | pred_name = f"{args.run_name}_" if args.run_name else "" 20 | pred_path = os.path.join(args.mme_result_path, f"{pred_name}{args.model}_{basename}") 21 | os.makedirs(pred_path, exist_ok=True) 22 | # model_dtype = {"llava": torch.bfloat16, "llavarlhf": torch.bfloat16, "share4v": torch.bfloat16} 23 | model_dtype = {"llavarlhf": torch.bfloat16} 24 | 25 | 26 | class MMEData(Dataset): 27 | def __init__(self, processor, category_path: str): 28 | super().__init__() 29 | self.processor = processor 30 | 31 | if os.path.exists(os.path.join(category_path, "images")): 32 | image_path = os.path.join(category_path, "images") 33 | qa_path = os.path.join(category_path, "questions_answers_YN") 34 | else: 35 | image_path = qa_path = category_path 36 | assert os.path.isdir(image_path), image_path 37 | assert os.path.isdir(qa_path), qa_path 38 | 39 | self.data = [] 40 | for file in os.listdir(qa_path): 41 | if not file.endswith(".txt"): 42 | continue 43 | for line in open(os.path.join(qa_path, file), encoding="utf-8"): 44 | question, answer = line.strip().split("\t") 45 | image_globs = glob(os.path.join(image_path, file.split(".")[0] + ".*")) 46 | image = list(filter(lambda x: not x.endswith(".txt"), image_globs)) 47 | if image: 48 | self.data.append((image[0], question, answer)) 49 | else: 50 | tqdm.write("No image found for " + file) 51 | 52 | def __len__(self): 53 | return len(self.data) 54 | 55 | def __getitem__(self, index: int): 56 | sample = self.data[index] 57 | image_path, question, answer = sample 58 | image = Image.open(image_path).convert("RGB") 59 | question = question.replace(" Please answer yes or no.", "") 60 | text = getattr(args, f"{args.model}_eval_vqa_prompt").format(question=question) 61 | return self.processor(image), text, "\t".join(sample) 62 | 63 | 64 | def inference(model, vis_processor, category_dir): 65 | eval_data = MMEData(vis_processor, os.path.join(args.mme_data_path, category_dir)) 66 | eval_loader = DataLoader(eval_data, args.infer_bs_total, False) 67 | model.eval() 68 | with torch.no_grad(): 69 | results = [] 70 | for batch in tqdm(eval_loader, position=1): 71 | images, questions, lines = to_device(batch, model_dtype.get(args.model, torch.float16)) # type: ignore 72 | # with autocast(dtype=torch.float16): 73 | answers = generators[args.model](model, questions, images) 74 | for line, answer in zip(lines, answers): 75 | results.append(line + "\t" + answer.replace("\n", " ")) 76 | with open(os.path.join(pred_path, category_dir + ".txt"), "w", encoding="utf-8") as f: 77 | f.write("\n".join(results)) 78 | 79 | 80 | def eval(): 81 | model, vis_processor = model_loaders[args.model](getattr(args, f"{args.model}_ckpt_load_path")) 82 | try: 83 | model.to(model_dtype[args.model]) 84 | except KeyError: 85 | pass 86 | for category_dir in tqdm(os.listdir(args.mme_data_path), position=0): 87 | if os.path.isdir(os.path.join(args.mme_data_path, category_dir)): 88 | inference(model, vis_processor, category_dir) 89 | eval_cmd = f"python evaluate/mme/calculation.py --results_dir {pred_path}" 90 | print(f"Inference done. 
running `{eval_cmd}`") 91 | os.system(eval_cmd) 92 | 93 | 94 | if __name__ == "__main__": 95 | eval() 96 | -------------------------------------------------------------------------------- /evaluate/eval_utils.py: -------------------------------------------------------------------------------- 1 | from common.args import args 2 | 3 | 4 | def get_eval_caption(): 5 | with open(args.caption_eval_path, "r") as f: 6 | content: list[str] = f.read().splitlines() 7 | image_ids, captions = [], [] 8 | eval_end_pos = args.default_eval_samples if args.end_pos == int(1e10) else args.end_pos 9 | for line in content[args.start_pos : eval_end_pos]: 10 | try: 11 | image_name, caption = line.replace("### gpt: ", "").split("###")[:2] 12 | except ValueError as e: 13 | print(f"Skipping line {line} due to {e}") 14 | continue 15 | image_ids.append(int(image_name.split("_")[-1].split(".")[0])) 16 | captions.append(caption) 17 | return image_ids, captions 18 | -------------------------------------------------------------------------------- /finetune/format_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2023-11-29 18:07:24 3 | # @Author : Shangyu.Xing (starreeze@foxmail.com) 4 | """ 5 | from captions, objects and scores format a flattened json file, 6 | where each object has an entry containing fields: sentence, sub-sentence/object position, score 7 | 8 | normal text ... hallucination object/subsentence 9 | ^ 10 | | 11 | sub-sentence/object position 12 | 13 | sentence: type=str, stripped sentence until the target 14 | sub-sentence/object mask: type=int, beginning position (char-level) of the unlearn target 15 | score: float, clip score of the object 16 | """ 17 | 18 | from __future__ import annotations 19 | import sys, os, json, re 20 | 21 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 23 | from tqdm import tqdm 24 | import numpy as np 25 | from common.args import args 26 | 27 | 28 | def find_subsentence(sentence: str, target_word: str): 29 | sub_sentences = re.split(f"[{args.subsentence_splitter_set}]+", sentence) 30 | for sub_sentence in sub_sentences: 31 | if target_word in sub_sentence: 32 | start_position = sentence.find(sub_sentence) 33 | end_position = start_position + len(sub_sentence) 34 | return start_position, end_position 35 | tqdm.write(f"'{target_word}' not found in '{sentence}', skipping...") 36 | return None 37 | 38 | 39 | def process_pos_neg(image: str, caption: str, objects: str, scores: np.ndarray) -> list[dict[str, str | int]] | None: 40 | results = [] 41 | for object, score in zip([obj.strip("[]") for obj in objects.split(args.object_splitter)], scores): 42 | # XXX which to unlearn? objects, subsentence or else? 43 | if args.unlearn_target == "objects": 44 | index = caption.find(object) 45 | assert index != -1 46 | # XXX should we keep the former hallucination objects in the training sample? 
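# Illustrative example (values are assumed, not taken from the dataset): for the
# caption "A dog sits on a red sofa." with target object "a red sofa" and a CLIP
# score of 0.31, the entry appended below would be
# {"image": "<image file name>", "sentence": "A dog sits on a red sofa",
#  "position": 14, "score": 0.31},
# i.e. the caption truncated right after the target object, plus the character
# offset at which the object starts.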
47 | results.append( 48 | { 49 | "image": image, 50 | "sentence": caption[: index + len(object)], 51 | "position": index, 52 | "score": float(score), 53 | } 54 | ) 55 | elif args.unlearn_target == "subsentence": 56 | result = find_subsentence(caption, object) 57 | if result is not None: 58 | start, end = result 59 | results.append({"image": image, "sentence": caption[:end], "position": start, "score": float(score)}) 60 | else: 61 | raise NotImplementedError() 62 | return results 63 | 64 | 65 | def main(): 66 | with open(args.caption_data_path, "r") as f: 67 | captions = f.read().splitlines() 68 | captions_d = {} 69 | for sample in captions: 70 | name, caption = sample.split(args.column_splitter) 71 | captions_d[name] = caption 72 | with open(args.object_data_path, "r") as f: 73 | objects = f.read().splitlines() 74 | # as new generated data has no [], it is regarded as norm 75 | scores = np.load(args.norm_result_path, allow_pickle=True) 76 | assert len(objects) == len(scores) 77 | pos_neg, sentence = [], [] 78 | for sample in tqdm(zip(objects, scores), total=len(objects)): 79 | object, score = sample 80 | image_name, _, object = object.split(args.column_splitter) 81 | if len(object.split(args.object_splitter)) != score.shape[0] or score.shape[0] == 0: 82 | tqdm.write("objects and scores not match or empty objects! skipping...") 83 | continue 84 | caption = captions_d[image_name] 85 | result = process_pos_neg(image_name, caption, object, score) 86 | if result is not None: 87 | pos_neg.extend(result) 88 | sentence.append( 89 | {"image": image_name, "sentence": caption, "mean": float(score.mean()), "min": float(score.min())} 90 | ) 91 | with open(args.pos_neg_data_path, "w") as f: 92 | json.dump(pos_neg, f, indent=2) 93 | with open(args.sentence_data_path, "w") as f: 94 | json.dump(sentence, f, indent=2) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /insight/plot_time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2024-02-09 15:49:44 3 | # @Author : Shangyu.Xing (starreeze@foxmail.com) 4 | 5 | from __future__ import annotations 6 | from matplotlib import pyplot as plt 7 | 8 | 9 | def main(): 10 | labels = ["RLHF", "DPO", "CL", "EFUF"] 11 | values = [20, 12, 10, 3] 12 | colors = ["#45B39D"] * (len(values) - 1) + ["#F5B041"] 13 | 14 | # Create the bar chart 15 | plt.figure(figsize=(12, 9)) 16 | plt.subplots_adjust(left=0.15, right=0.95, top=0.92, bottom=0.12) 17 | plt.bar(labels, values, color=colors, width=0.6) # type: ignore 18 | 19 | plt.plot(labels, values, color="darkred", marker="o", linestyle="-", linewidth=3, markersize=12) 20 | for label, value in enumerate(values): 21 | plt.text(label, value + 0.7, str(value), ha="center", fontsize=26) 22 | 23 | # Adding the title and labels 24 | # plt.xlabel("Method", fontsize=24) 25 | plt.ylabel("A100 GPU hours", fontsize=30) 26 | plt.xticks(fontsize=30) 27 | plt.yticks(fontsize=30) 28 | plt.ylim(top=22.5) 29 | 30 | # Show the plot 31 | plt.show() 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- 
/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14
| print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('', special_tokens=True) 19 | tokenizer.pad_token = '' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 
34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 
-------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, 
args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // self.config.patch_size) ** 2 79 
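The `CLIPVisionTower` above wraps a Hugging Face CLIP vision model and returns features from one hidden layer, dropping the CLS token when `select_feature` is `'patch'`. Below is a minimal, hedged usage sketch: the checkpoint name, the `-2` select layer, and the ad-hoc `SimpleNamespace` config are illustrative assumptions rather than values taken from this repository's configs.

```python
from types import SimpleNamespace

import torch

from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

# Hypothetical stand-in for the training args that normally supply these fields;
# -2 (penultimate layer) and 'patch' mirror common LLaVA-style defaults.
cfg = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")

tower = CLIPVisionTower("openai/clip-vit-large-patch14", args=cfg)  # downloads the CLIP weights
pixels = torch.randn(1, 3, 224, 224)  # stands in for already-preprocessed pixel values
features = tower(pixels)              # shape: (1, num_patches, hidden_size) = (1, 256, 1024)
print(features.shape)
```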
| -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 
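# Note: the import order below matters. replace_llama_attn_with_flash_attn()
# rebinds the LLaMA attention implementation inside `transformers` before
# `llava.train.train` is imported, so the model is built with the
# FlashAttention-backed forward pass. The sibling script train_xformers.py
# follows the same pattern with an xformers attention backend instead.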
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /minigpt4/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.registry import Registry 2 | 3 | Registry.register_path("library_root", "minigpt4") 4 | -------------------------------------------------------------------------------- /minigpt4/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/common/__init__.py -------------------------------------------------------------------------------- /minigpt4/common/eval_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from nltk.translate.bleu_score import sentence_bleu 4 | 5 | from minigpt4.common.registry import registry 6 | from minigpt4.common.config import Config 7 | 8 | # imports modules for registration 9 | from minigpt4.datasets.builders import * 10 | from minigpt4.models import * 11 | from minigpt4.processors import * 12 | from minigpt4.runners import * 13 | from minigpt4.tasks import * 14 | 15 | 16 | def eval_parser(): 17 | parser = argparse.ArgumentParser(description="Demo") 18 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 19 | parser.add_argument("--name", type=str, default="A2", help="evaluation name") 20 | parser.add_argument("--ckpt", type=str, help="path to configuration file.") 21 | parser.add_argument("--eval_opt", type=str, default="all", help="path to configuration file.") 22 | parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens") 23 | parser.add_argument("--batch_size", type=int, default=32) 24 | parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") 25 | parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") 26 | parser.add_argument( 27 | "--options", 28 | nargs="+", 29 | help="override some settings in the used config, the key-value pair " 30 | "in xxx=yyy format will be merged into config file (deprecate), " 31 | "change to --cfg-options instead.", 32 | ) 33 | return parser 34 | 35 | 36 | def prepare_texts(texts, conv_temp): 37 | convs = [conv_temp.copy() for _ in range(len(texts))] 38 | [conv.append_message(conv.roles[0], " {}".format(text)) for conv, text in zip(convs, texts)] 39 | [conv.append_message(conv.roles[1], None) for conv in convs] 40 | texts = [conv.get_prompt() for conv in convs] 41 | return texts 42 | 43 | 44 | def init_model(args, device="cuda:0"): 45 | 
print("Initialization Model") 46 | cfg = Config(args) 47 | # cfg.model_cfg.ckpt = args.ckpt 48 | # cfg.model_cfg.lora_r = args.lora_r 49 | # cfg.model_cfg.lora_alpha = args.lora_alpha 50 | 51 | model_config = cfg.model_cfg 52 | model_cls = registry.get_model_class(model_config.arch) 53 | model = model_cls.from_config(model_config).to(device) 54 | 55 | # import pudb; pudb.set_trace() 56 | key = list(cfg.datasets_cfg.keys())[0] 57 | vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train 58 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 59 | print("Initialization Finished") 60 | return model, vis_processor 61 | 62 | 63 | def computeIoU(bbox1, bbox2): 64 | x1, y1, x2, y2 = bbox1 65 | x3, y3, x4, y4 = bbox2 66 | intersection_x1 = max(x1, x3) 67 | intersection_y1 = max(y1, y3) 68 | intersection_x2 = min(x2, x4) 69 | intersection_y2 = min(y2, y4) 70 | intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) 71 | bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) 72 | bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) 73 | union_area = bbox1_area + bbox2_area - intersection_area 74 | iou = intersection_area / union_area 75 | return iou 76 | -------------------------------------------------------------------------------- /minigpt4/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /minigpt4/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from minigpt4.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | **kwargs 25 | ): 26 | self.optimizer = optimizer 27 | 28 | self.max_epoch = max_epoch 29 | self.min_lr = min_lr 30 | 31 | self.decay_rate = decay_rate 32 | 33 | self.init_lr = init_lr 34 | self.warmup_steps = warmup_steps 35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 36 | 37 | def step(self, cur_epoch, cur_step): 38 | if cur_epoch == 0: 39 | warmup_lr_schedule( 40 | step=cur_step, 41 | optimizer=self.optimizer, 42 | max_step=self.warmup_steps, 43 | init_lr=self.warmup_start_lr, 44 | max_lr=self.init_lr, 45 | ) 46 | else: 47 | step_lr_schedule( 48 | epoch=cur_epoch, 49 | optimizer=self.optimizer, 50 | init_lr=self.init_lr, 51 | min_lr=self.min_lr, 52 | decay_rate=self.decay_rate, 53 | ) 54 | 55 | 56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 57 | class LinearWarmupCosineLRScheduler: 58 | def __init__( 59 | self, 60 | optimizer, 61 | max_epoch, 62 | iters_per_epoch, 63 | min_lr, 64 | init_lr, 65 | warmup_steps=0, 66 | warmup_start_lr=-1, 67 | **kwargs 68 | ): 69 | self.optimizer = optimizer 70 | 71 | self.max_epoch = max_epoch 72 | self.iters_per_epoch = iters_per_epoch 73 | self.min_lr = min_lr 74 | 75 | self.init_lr = init_lr 76 | self.warmup_steps = warmup_steps 77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 78 | 79 | def step(self, cur_epoch, cur_step): 80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 81 | if total_cur_step < self.warmup_steps: 82 | warmup_lr_schedule( 83 | step=cur_step, 84 | optimizer=self.optimizer, 85 | max_step=self.warmup_steps, 86 | init_lr=self.warmup_start_lr, 87 | max_lr=self.init_lr, 88 | ) 89 | else: 90 | cosine_lr_schedule( 91 | epoch=total_cur_step, 92 | optimizer=self.optimizer, 93 | max_epoch=self.max_epoch * self.iters_per_epoch, 94 | init_lr=self.init_lr, 95 | min_lr=self.min_lr, 96 | ) 97 | 98 | 99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 100 | """Decay the learning rate""" 101 | lr = (init_lr - min_lr) * 0.5 * ( 102 | 1.0 + math.cos(math.pi * epoch / max_epoch) 103 | ) + min_lr 104 | for param_group in optimizer.param_groups: 105 | param_group["lr"] = lr 106 | 107 | 108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 109 | """Warmup the learning rate""" 110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 111 | for param_group in optimizer.param_groups: 112 | param_group["lr"] = lr 113 | 114 | 115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 116 | """Decay the learning rate""" 117 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 118 | for param_group in optimizer.param_groups: 119 | param_group["lr"] = lr 120 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import sys 4 | dataDir = '../../VQA' 5 | sys.path.insert(0, 
'%s/PythonHelperTools/vqaTools' %(dataDir)) 6 | from vqa import VQA 7 | from vqaEvaluation.vqaEval import VQAEval 8 | import matplotlib.pyplot as plt 9 | import skimage.io as io 10 | import json 11 | import random 12 | import os 13 | 14 | # set up file names and paths 15 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset 16 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 17 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. 18 | dataSubType ='train2014' 19 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 20 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 21 | imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 22 | resultType ='fake' 23 | fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType'] 24 | 25 | # An example result json file has been provided in './Results' folder. 26 | 27 | [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \ 28 | resultType, fileType) for fileType in fileTypes] 29 | 30 | # create vqa object and vqaRes object 31 | vqa = VQA(annFile, quesFile) 32 | vqaRes = vqa.loadRes(resFile, quesFile) 33 | 34 | # create vqaEval object by taking vqa and vqaRes 35 | vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2 36 | 37 | # evaluate results 38 | """ 39 | If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function 40 | By default it uses all the question ids in annotation file 41 | """ 42 | vqaEval.evaluate() 43 | 44 | # print accuracies 45 | print "\n" 46 | print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']) 47 | print "Per Question Type Accuracy is the following:" 48 | for quesType in vqaEval.accuracy['perQuestionType']: 49 | print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType]) 50 | print "\n" 51 | print "Per Answer Type Accuracy is the following:" 52 | for ansType in vqaEval.accuracy['perAnswerType']: 53 | print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType]) 54 | print "\n" 55 | # demo how to use evalQA to retrieve low score result 56 | evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy 57 | if len(evals) > 0: 58 | print 'ground truth answers' 59 | randomEval = random.choice(evals) 60 | randomAnn = vqa.loadQA(randomEval) 61 | vqa.showQA(randomAnn) 62 | 63 | print '\n' 64 | print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval]) 65 | ann = vqaRes.loadQA(randomEval)[0] 66 | print "Answer: %s\n" %(ann['answer']) 67 | 68 | imgId = randomAnn[0]['image_id'] 69 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 70 | if os.path.isfile(imgDir + imgFilename): 71 | I = io.imread(imgDir + imgFilename) 72 | plt.imshow(I) 73 | plt.axis('off') 74 | plt.show() 75 | 76 | # plot accuracy for various question types 77 | plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center') 78 | plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10) 79 | plt.title('Per Question Type Accuracy', fontsize=10) 80 | 
plt.xlabel('Question Types', fontsize=10) 81 | plt.ylabel('Accuracy', fontsize=10) 82 | plt.show() 83 | 84 | # save evaluation results to ./Results folder 85 | json.dump(vqaEval.accuracy, open(accuracyFile, 'w')) 86 | json.dump(vqaEval.evalQA, open(evalQAFile, 'w')) 87 | json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w')) 88 | json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w')) 89 | 90 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py: -------------------------------------------------------------------------------- 1 | author='aagrawal' 2 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from vqaTools.vqa import VQA 4 | import random 5 | import skimage.io as io 6 | import matplotlib.pyplot as plt 7 | import os 8 | 9 | dataDir ='../../VQA' 10 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset 11 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 12 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. 13 | dataSubType ='train2014' 14 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 15 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 16 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 17 | 18 | # initialize VQA api for QA annotations 19 | vqa=VQA(annFile, quesFile) 20 | 21 | # load and display QA annotations for given question types 22 | """ 23 | All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder. 24 | """ 25 | annIds = vqa.getQuesIds(quesTypes='how many'); 26 | anns = vqa.loadQA(annIds) 27 | randomAnn = random.choice(anns) 28 | vqa.showQA([randomAnn]) 29 | imgId = randomAnn['image_id'] 30 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 31 | if os.path.isfile(imgDir + imgFilename): 32 | I = io.imread(imgDir + imgFilename) 33 | plt.imshow(I) 34 | plt.axis('off') 35 | plt.show() 36 | 37 | # load and display QA annotations for given answer types 38 | """ 39 | ansTypes can be one of the following 40 | yes/no 41 | number 42 | other 43 | """ 44 | annIds = vqa.getQuesIds(ansTypes='yes/no'); 45 | anns = vqa.loadQA(annIds) 46 | randomAnn = random.choice(anns) 47 | vqa.showQA([randomAnn]) 48 | imgId = randomAnn['image_id'] 49 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 50 | if os.path.isfile(imgDir + imgFilename): 51 | I = io.imread(imgDir + imgFilename) 52 | plt.imshow(I) 53 | plt.axis('off') 54 | plt.show() 55 | 56 | # load and display QA annotations for given images 57 | """ 58 | Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[]) 59 | Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types. 
60 | """ 61 | ids = vqa.getImgIds() 62 | annIds = vqa.getQuesIds(imgIds=random.sample(ids,5)); 63 | anns = vqa.loadQA(annIds) 64 | randomAnn = random.choice(anns) 65 | vqa.showQA([randomAnn]) 66 | imgId = randomAnn['image_id'] 67 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 68 | if os.path.isfile(imgDir + imgFilename): 69 | I = io.imread(imgDir + imgFilename) 70 | plt.imshow(I) 71 | plt.axis('off') 72 | plt.show() 73 | 74 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | what color is the 3 | is the 4 | where is the 5 | what 6 | what is 7 | are the 8 | what is the 9 | is there a 10 | does the 11 | is the woman 12 | is the man 13 | what is on the 14 | is it 15 | is the girl 16 | is the boy 17 | is the dog 18 | are they 19 | who is 20 | what kind of 21 | what color are the 22 | what is in the 23 | what is the man 24 | is there 25 | what is the woman 26 | what are the 27 | what is the boy 28 | are there 29 | what is the girl 30 | is this 31 | how 32 | which 33 | how many people are 34 | is the cat 35 | why is the 36 | are 37 | will the 38 | what type of 39 | what is the dog 40 | do 41 | is she 42 | does 43 | do the 44 | is 45 | is the baby 46 | are there any 47 | is the lady 48 | can 49 | what animal is 50 | where are the 51 | is the sun 52 | what are they 53 | did the 54 | what is the cat 55 | what is the lady 56 | how many clouds are 57 | is that 58 | is the little girl 59 | is he 60 | are these 61 | how many trees are 62 | how many pillows 63 | are the people 64 | why 65 | is the young 66 | how many windows are 67 | is this a 68 | what is the little 69 | is the tv 70 | how many animals are 71 | who 72 | how many pictures 73 | how many plants are 74 | how many birds are 75 | what color is 76 | what is the baby 77 | is anyone 78 | what color 79 | how many bushes 80 | is the old man 81 | none of the above 82 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | is the 3 | what 4 | what color is the 5 | what is the 6 | is this 7 | is this a 8 | what is 9 | are the 10 | what kind of 11 | is there a 12 | what type of 13 | is it 14 | what are the 15 | where is the 16 | is there 17 | does the 18 | what color are the 19 | are these 20 | are there 21 | which 22 | is 23 | what is the man 24 | is the man 25 | are 26 | how 27 | does this 28 | what is on the 29 | what does the 30 | how many people are 31 | what is in the 32 | what is this 33 | do 34 | what are 35 | are they 36 | what time 37 | what sport is 38 | are there any 39 | is he 40 | what color is 41 | why 42 | where are the 43 | what color 44 | who is 45 | what animal is 46 | is the woman 47 | is this an 48 | do you 49 | how many people are in 50 | what room is 51 | has 52 | is this person 53 | what is the woman 54 | can you 55 | why is the 56 | is the person 57 | what is the color of the 58 | what is the person 59 | could 60 | was 61 | is that a 62 | 
what number is 63 | what is the name 64 | what brand 65 | none of the above 66 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/VQA/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Aishwarya Agrawal 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 18 | FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | 26 | The views and conclusions contained in the software and documentation are 27 | those 28 | of the authors and should not be interpreted as representing official 29 | policies, 30 | either expressed or implied, of the FreeBSD Project. 31 | -------------------------------------------------------------------------------- /minigpt4/common/vqa_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = "aagrawal" 9 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/aokvqa/defaults.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | aok_vqa: 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: images # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: 16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json 17 | storage: 18 | - /path/to/aokvqa_v1p0_train.json 19 | images: 20 | storage: /path/to/coco/images -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_sbu/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/cc_sbu_align/ 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/cc_sbu/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/cc_sbu_dataset/{00000..01255}.tar 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco/caption.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | coco_caption: # name of the dataset builder 8 | # dataset_card: dataset_card/coco_caption.md 9 | # data_dir: ${env.data_dir}/datasets 10 | data_type: images # [images|videos|features] 11 | 12 | build_info: 13 | # Be careful not to append minus sign (-) before split to avoid itemizing 14 | annotations: 15 | train: 16 | url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json 17 | md5: aa31ac474cf6250ebb81d18348a07ed8 18 | storage: /path/to/coco_caption/coco_karpathy_train.json 19 | images: 20 | storage: /path/to/coco/images 21 | 22 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco/defaults_vqa.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | coco_vqa: 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: images # [images|videos|features] 10 | 11 | build_info: 12 | 13 | annotations: 14 | train: 15 | url: 16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json 17 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json 18 | storage: 19 | - /path/to/vqav2/vqa_train.json 20 | - /path/to/vqav2/vqa_val.json 21 | images: 22 | storage: /path/to/coco/images 23 | 24 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | invrefcoco: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: invrefcoco 8 | splitBy: unc -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | invrefcocog: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: invrefcocog 8 | splitBy: umd -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | invrefcocop: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: invrefcoco+ 8 | splitBy: unc -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/refcoco.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | refcoco: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: refcoco 8 | splitBy: unc -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/refcocog.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | refcocog: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: refcocog 8 | splitBy: umd -------------------------------------------------------------------------------- /minigpt4/configs/datasets/coco_bbox/refcocop.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | refcocop: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/refcoco_annotations 7 | dataset: refcoco+ 8 | splitBy: unc -------------------------------------------------------------------------------- /minigpt4/configs/datasets/flickr/caption_to_phrase.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | flickr_CaptionToPhrase: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/filtered_flikcr/images 6 | ann_path: /path/to/filtered_flickr/captiontobbox.json 7 | 
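The dataset YAMLs in this directory all follow the same shape: a `datasets.<name>` key matching a registered builder, a `data_type`, and a `build_info` block whose `/path/to/...` placeholders must be pointed at local data. As a rough sketch of how they are consumed, the `load_dataset` helper defined in `minigpt4/datasets/builders/__init__.py` (shown further below) looks the builder up in the registry and calls `build_datasets()`; the builder name used here and the assumption that the YAML paths have already been filled in are for illustration only.

```python
# Sketch only: assumes the storage paths in the corresponding YAML
# (e.g. minigpt4/configs/datasets/cc_sbu/align.yaml) point at real local data.
from minigpt4.datasets.builders import load_dataset

dataset = load_dataset("cc_sbu_align")  # name of a registered dataset builder
print({split: len(dataset[split]) for split in dataset.keys()})
```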
-------------------------------------------------------------------------------- /minigpt4/configs/datasets/flickr/default.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | flickr_grounded_caption: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/filtered_flikcr/images 6 | ann_path: /path/to/filtered_flikcr/groundedcaption.json 7 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/flickr/object_to_phrase.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | flickr_ObjectToPhrase: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/filtered_flikcr/images 6 | ann_path: /path/to/filtered_flikcr/phrasetobbox.json 7 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/gqa/balanced_val.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | gqa: 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: images # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: 16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json 17 | storage: 18 | - /path/to/gqa/train_balanced_questions.json 19 | 20 | images: 21 | storage: /path/to/gqa/images 22 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/laion/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/laion_dataset/{00000..10488}.tar 6 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/llava/conversation.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | 3 | llava_conversation: 4 | data_type: images 5 | build_info: 6 | image_path: /path/to/coco/images 7 | ann_path: /path/to/llava/conversation_58k.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/llava/detail.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | llava_detail: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/coco/images 6 | ann_path: /path/to/llava/detail_23k.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/llava/reason.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | 3 | llava_reason: 4 | data_type: images 5 | build_info: 6 | image_path: /path/to/coco/images 7 | ann_path: /path/to/llava/complex_reasoning_77k.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/multitask_conversation/default.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | multitask_conversation: 3 | data_type: images 4 | build_info: 5 | 6 | image_path: /path/to/coco/images 7 | 
ann_path: /path/to/multitask_conversation/multi_task_conversation.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/nlp/unnatural_instruction.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | unnatural_instruction: 3 | data_type: text 4 | build_info: 5 | ann_path: /path/to/unnatural_instructions/filtered_unnatural_instruction.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | ocrvqa: 3 | data_type: images 4 | build_info: 5 | image_path: /path/to/ocrvqa/images 6 | ann_path: /path/to/ocrvqa/dataset.json -------------------------------------------------------------------------------- /minigpt4/configs/datasets/okvqa/defaults.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | datasets: 7 | ok_vqa: 8 | # data_dir: ${env.data_dir}/datasets 9 | data_type: images # [images|videos|features] 10 | 11 | build_info: 12 | # Be careful not to append minus sign (-) before split to avoid itemizing 13 | annotations: 14 | train: 15 | url: 16 | # TODO make this order insensitive 17 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json 18 | storage: 19 | - /path/to/okvqa/okvqa_train.json 20 | images: 21 | storage: /path/to/coco/images -------------------------------------------------------------------------------- /minigpt4/configs/datasets/textcaps/caption.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | textcaps_caption: 3 | data_type: images 4 | 5 | build_info: 6 | image_path: /path/to/textcaps/train_images 7 | ann_path: /path/to/textcaps/TextCaps_0.1_train.json 8 | 9 | 10 | -------------------------------------------------------------------------------- /minigpt4/configs/datasets/vg/ref.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | refvg: 3 | data_type: images 4 | build_info: 5 | data_dir: /path/to/visual_genome -------------------------------------------------------------------------------- /minigpt4/configs/default.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | # For default users 3 | # cache_root: "cache" 4 | # For internal use with persistent storage 5 | cache_root: "/export/home/.cache/minigpt4" 6 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt4_llama2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | has_qformer: False 11 | 12 | # generation configs 13 | prompt: "" 14 | 15 | llama_model: "please set this value to the path of llama2-chat-7b" 16 | 17 | preprocess: 18 | vis_processor: 19 | train: 20 | name: "blip2_image_train" 21 | image_size: 224 22 | eval: 23 | name: "blip2_image_eval" 24 | image_size: 224 25 | text_processor: 26 | train: 27 | 
name: "blip_caption" 28 | eval: 29 | name: "blip_caption" 30 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt4_vicuna0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt4 3 | 4 | # vit encoder 5 | image_size: 224 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | freeze_qformer: True 11 | 12 | # Q-Former 13 | num_query_token: 32 14 | 15 | # generation configs 16 | prompt: "" 17 | 18 | llama_model: "please set this value to the path of vicuna model" 19 | 20 | preprocess: 21 | vis_processor: 22 | train: 23 | name: "blip2_image_train" 24 | image_size: 224 25 | eval: 26 | name: "blip2_image_eval" 27 | image_size: 224 28 | text_processor: 29 | train: 30 | name: "blip_caption" 31 | eval: 32 | name: "blip_caption" 33 | -------------------------------------------------------------------------------- /minigpt4/configs/models/minigpt_v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: minigpt_v2 3 | 4 | # vit encoder 5 | image_size: 448 6 | drop_path_rate: 0 7 | use_grad_checkpoint: False 8 | vit_precision: "fp16" 9 | freeze_vit: True 10 | 11 | # generation configs 12 | prompt: "" 13 | 14 | llama_model: "please set this value to the path of llama2-chat-7b" 15 | lora_r: 64 16 | lora_alpha: 16 17 | 18 | 19 | preprocess: 20 | vis_processor: 21 | train: 22 | name: "blip2_image_train" 23 | image_size: 448 24 | eval: 25 | name: "blip2_image_eval" 26 | image_size: 448 27 | text_processor: 28 | train: 29 | name: "blip_caption" 30 | eval: 31 | name: "blip_caption" 32 | -------------------------------------------------------------------------------- /minigpt4/conversation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/conversation/__init__.py -------------------------------------------------------------------------------- /minigpt4/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/datasets/__init__.py -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config 9 | from minigpt4.datasets.builders.image_text_pair_builder import ( 10 | CCSBUBuilder, 11 | LaionBuilder, 12 | CCSBUAlignBuilder 13 | ) 14 | from minigpt4.common.registry import registry 15 | 16 | __all__ = [ 17 | "CCSBUBuilder", 18 | "LaionBuilder", 19 | "CCSBUAlignBuilder" 20 | ] 21 | 22 | 23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 24 | """ 25 | Example 26 | 27 | >>> dataset = load_dataset("coco_caption", cfg=None) 28 | >>> splits = dataset.keys() 29 | >>> print([len(dataset[split]) for split in splits]) 30 | 31 | """ 32 | if cfg_path is None: 33 | cfg = None 34 | else: 35 | cfg = load_dataset_config(cfg_path) 36 | 37 | try: 38 | builder = registry.get_builder_class(name)(cfg) 39 | except TypeError: 40 | print( 41 | f"Dataset {name} not found. Available datasets:\n" 42 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 43 | ) 44 | exit(1) 45 | 46 | if vis_path is not None: 47 | if data_type is None: 48 | # use default data type in the config 49 | data_type = builder.config.data_type 50 | 51 | assert ( 52 | data_type in builder.config.build_info 53 | ), f"Invalid data_type {data_type} for {name}." 54 | 55 | builder.config.build_info.get(data_type).storage = vis_path 56 | 57 | dataset = builder.build_datasets() 58 | return dataset 59 | 60 | 61 | class DatasetZoo: 62 | def __init__(self) -> None: 63 | self.dataset_zoo = { 64 | k: list(v.DATASET_CONFIG_DICT.keys()) 65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 66 | } 67 | 68 | def get_names(self): 69 | return list(self.dataset_zoo.keys()) 70 | 71 | 72 | dataset_zoo = DatasetZoo() 73 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/datasets/datasets/__init__.py -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | 16 | 17 | class BaseDataset(Dataset): 18 | def __init__( 19 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 20 | ): 21 | """ 22 | vis_root (string): Root directory of images (e.g. 
coco/images/) 23 | ann_root (string): directory to store the annotation file 24 | """ 25 | self.vis_root = vis_root 26 | 27 | self.annotation = [] 28 | # print("ann paths", ann_paths) 29 | for ann_path in ann_paths: 30 | # print("ann_path", ann_path) 31 | ann = json.load(open(ann_path, "r")) 32 | if isinstance(ann, dict): 33 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 34 | # self.annotation.extend(json.load(open(ann_path, "r"))) 35 | else: 36 | self.annotation.extend(json.load(open(ann_path, "r"))) 37 | 38 | self.vis_processor = vis_processor 39 | self.text_processor = text_processor 40 | 41 | self._add_instance_ids() 42 | 43 | def __len__(self): 44 | return len(self.annotation) 45 | 46 | def collater(self, samples): 47 | return default_collate(samples) 48 | 49 | def set_processors(self, vis_processor, text_processor): 50 | self.vis_processor = vis_processor 51 | self.text_processor = text_processor 52 | 53 | def _add_instance_ids(self, key="instance_id"): 54 | for idx, ann in enumerate(self.annotation): 55 | ann[key] = str(idx) 56 | 57 | 58 | 59 | class ConcatDataset(ConcatDataset): 60 | def __init__(self, datasets: Iterable[Dataset]) -> None: 61 | super().__init__(datasets) 62 | 63 | def collater(self, samples): 64 | # TODO For now only supports datasets with same underlying collater implementations 65 | 66 | all_keys = set() 67 | for s in samples: 68 | all_keys.update(s) 69 | 70 | shared_keys = all_keys 71 | for s in samples: 72 | shared_keys = shared_keys & set(s.keys()) 73 | 74 | samples_shared_keys = [] 75 | for s in samples: 76 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 77 | 78 | return self.datasets[0].collater(samples_shared_keys) 79 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/cc_sbu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import webdataset as wds 4 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 5 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 6 | 7 | 8 | class CCSBUDataset(BaseDataset): 9 | def __init__(self, vis_processor, text_processor, location): 10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 11 | 12 | self.inner_dataset = wds.DataPipeline( 13 | wds.ResampledShards(location), 14 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 15 | wds.shuffle(1000, handler=wds.warn_and_continue), 16 | wds.decode("pilrgb", handler=wds.warn_and_continue), 17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 19 | wds.map(self.to_dict, handler=wds.warn_and_continue), 20 | ) 21 | 22 | def to_dict(self, sample): 23 | return { 24 | "image": sample[0], 25 | "answer": self.text_processor(sample[1]["caption"]), 26 | } 27 | 28 | 29 | class CCSBUAlignDataset(CaptionDataset): 30 | 31 | def __getitem__(self, index): 32 | 33 | # TODO this assumes image input, not general enough 34 | ann = self.annotation[index] 35 | 36 | img_file = '{}.jpg'.format(ann["image_id"]) 37 | image_path = os.path.join(self.vis_root, img_file) 38 | image = Image.open(image_path).convert("RGB") 39 | 40 | image = self.vis_processor(image) 41 | caption = ann["caption"] 42 | 43 | return { 44 | "image": image, 45 | "answer": caption, 46 | "image_id": self.img_ids[ann["image_id"]], 47 | } 
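`CCSBUAlignDataset` yields plain `{"image", "answer", "image_id"}` samples, so it can also be exercised directly, outside the builder pipeline. The sketch below is illustrative: the processor names come from the `preprocess` sections of the model YAMLs above, while the local paths and annotation filename are assumptions about where the cc_sbu_align data has been unpacked.

```python
import minigpt4.processors  # noqa: F401 -- importing this module registers the processor classes
from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUAlignDataset

vis_processor = registry.get_processor_class("blip2_image_train").from_config()
text_processor = registry.get_processor_class("blip_caption").from_config()

dataset = CCSBUAlignDataset(
    vis_processor=vis_processor,
    text_processor=text_processor,
    vis_root="/data/cc_sbu_align/image",               # hypothetical image directory
    ann_paths=["/data/cc_sbu_align/filter_cap.json"],  # hypothetical annotation file
)
sample = dataset[0]
print(sample["answer"], sample["image"].shape)
```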
-------------------------------------------------------------------------------- /minigpt4/datasets/datasets/gqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import json 10 | 11 | from PIL import Image 12 | 13 | from minigpt4.datasets.datasets.vqa_datasets import VQADataset 14 | 15 | from collections import OrderedDict 16 | import random 17 | 18 | class __DisplMixin: 19 | def displ_item(self, index): 20 | sample, ann = self.__getitem__(index), self.annotation[index] 21 | 22 | return OrderedDict( 23 | { 24 | "file": ann["image"], 25 | "question": ann["question"], 26 | "question_id": ann["question_id"], 27 | "answers": "; ".join(ann["answer"]), 28 | "image": sample["image"], 29 | } 30 | ) 31 | 32 | 33 | class GQADataset(VQADataset, __DisplMixin): 34 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 35 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 36 | self.instruction_pool =[ 37 | "[vqa] {}", 38 | "[vqa] Based on the image, respond to this question with a short answer: {}" 39 | ] 40 | 41 | def __getitem__(self, index): 42 | ann = self.annotation[index] 43 | 44 | image_path = os.path.join(self.vis_root, ann["image"]) 45 | image = Image.open(image_path).convert("RGB") 46 | 47 | image = self.vis_processor(image) 48 | question = self.text_processor(ann["question"]) 49 | 50 | instruction = random.choice(self.instruction_pool).format(question) 51 | instruction = " {} ".format(instruction) 52 | 53 | answers = self.text_processor(ann["answer"]) 54 | 55 | return { 56 | "image": image, 57 | "instruction_input": instruction, 58 | "answer": answers, 59 | } 60 | 61 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/laion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import webdataset as wds 9 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 10 | 11 | 12 | class LaionDataset(BaseDataset): 13 | def __init__(self, vis_processor, text_processor, location): 14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor) 15 | 16 | self.inner_dataset = wds.DataPipeline( 17 | wds.ResampledShards(location), 18 | wds.tarfile_to_samples(handler=wds.warn_and_continue), 19 | wds.shuffle(1000, handler=wds.warn_and_continue), 20 | wds.decode("pilrgb", handler=wds.warn_and_continue), 21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), 22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), 23 | wds.map(self.to_dict, handler=wds.warn_and_continue), 24 | ) 25 | 26 | def to_dict(self, sample): 27 | return { 28 | "image": sample[0], 29 | "answer": self.text_processor(sample[1]["caption"]), 30 | } 31 | 32 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/multitask_conversation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import random 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | from PIL import Image 10 | import skimage.io as io 11 | import matplotlib.pyplot as plt 12 | from matplotlib.collections import PatchCollection 13 | from matplotlib.patches import Polygon, Rectangle 14 | from torch.utils.data import Dataset 15 | import webdataset as wds 16 | 17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 19 | 20 | 21 | 22 | 23 | class MultiTaskConversationDataset(Dataset): 24 | def __init__(self, vis_processor, text_processor, vis_root, ann_path): 25 | """ 26 | vis_root (string): Root directory of images (e.g. 
coco/images/) 27 | ann_root (string): directory to store the annotation file 28 | """ 29 | self.vis_root = vis_root 30 | 31 | self.vis_processor = vis_processor 32 | self.text_processor = text_processor 33 | 34 | 35 | with open(ann_path, 'r') as f: 36 | self.ann = json.load(f) 37 | 38 | self.connect_sym = "!@#" 39 | 40 | def __len__(self): 41 | return len(self.ann) 42 | 43 | def __getitem__(self, index): 44 | info = self.ann[index] 45 | 46 | image_file = 'COCO_train2014_{}.jpg'.format(info['id']) 47 | image_path = os.path.join(self.vis_root, image_file) 48 | image = Image.open(image_path).convert("RGB") 49 | image = self.vis_processor(image) 50 | 51 | first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip() 52 | first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction) 53 | 54 | questions = [first_instruction] 55 | answers = [] 56 | 57 | for i, item in enumerate(info["conversations"][1:]): 58 | if i % 2 == 0: # assistant 59 | assistant_answer = item["value"] 60 | answers.append(assistant_answer) 61 | else: 62 | human_instruction = item["value"]+" " 63 | questions.append(human_instruction) 64 | 65 | questions = self.connect_sym.join(questions) 66 | answers = self.connect_sym.join(answers) 67 | 68 | 69 | return { 70 | "image": image, 71 | "conv_q": questions, 72 | 'conv_a': answers, 73 | "image_id": info['id'], 74 | "connect_sym": self.connect_sym 75 | } -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/ocrvqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import random 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | from PIL import Image 10 | import skimage.io as io 11 | import matplotlib.pyplot as plt 12 | from matplotlib.collections import PatchCollection 13 | from matplotlib.patches import Polygon, Rectangle 14 | from torch.utils.data import Dataset 15 | import webdataset as wds 16 | 17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 19 | 20 | 21 | class OCRVQADataset(Dataset): 22 | def __init__(self, vis_processor, text_processor, vis_root, ann_path): 23 | """ 24 | vis_root (string): Root directory of images (e.g.
coco/images/) 25 | ann_root (string): directory to store the annotation file 26 | """ 27 | self.vis_root = vis_root 28 | 29 | self.vis_processor = vis_processor 30 | self.text_processor = text_processor 31 | self.data = self.create_data(ann_path) 32 | 33 | self.instruction_pool = [ 34 | "[vqa] {}", 35 | "[vqa] Based on the image, respond to this question with a short answer: {}" 36 | ] 37 | 38 | def create_data(self, ann_path): 39 | processed_data = [] 40 | with open(ann_path, 'r') as f: 41 | data = json.load(f) 42 | for k in data.keys(): 43 | if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test 44 | ext = os.path.splitext(data[k]['imageURL'])[1] 45 | imageFile = k + ext 46 | assert len(data[k]['questions']) == len(data[k]['answers']) 47 | for q, a in zip(data[k]['questions'], data[k]['answers']): 48 | processed_data.append( 49 | {'question': q, 50 | 'answer': a, 51 | 'image_path': imageFile, 52 | 'image_id': k, 53 | 'title': data[k]['title'], 54 | 'genre': data[k]['genre'], 55 | } 56 | ) 57 | return processed_data 58 | 59 | def __len__(self): 60 | return len(self.data) 61 | 62 | def __getitem__(self, index): 63 | sample = self.data[index] 64 | image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") 65 | image = self.vis_processor(image) 66 | question = self.text_processor(sample["question"]) 67 | answer = self.text_processor(sample["answer"]) 68 | 69 | instruction = random.choice(self.instruction_pool).format(question) 70 | instruction = "<Img><ImageHere></Img> {} ".format(instruction) 71 | return { 72 | "image": image, 73 | "instruction_input": instruction, 74 | "answer": answer, 75 | "image_id": sample['image_id'] 76 | } 77 | 78 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/text_caps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import random 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | from PIL import Image 10 | import skimage.io as io 11 | import matplotlib.pyplot as plt 12 | from matplotlib.collections import PatchCollection 13 | from matplotlib.patches import Polygon, Rectangle 14 | from torch.utils.data import Dataset 15 | import webdataset as wds 16 | 17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 19 | 20 | 21 | 22 | class TextCapDataset(Dataset): 23 | def __init__(self, vis_processor, text_processor, vis_root, ann_path): 24 | """ 25 | vis_root (string): Root directory of images (e.g. coco/images/) 26 | ann_root (string): directory to store the annotation file 27 | """ 28 | self.vis_root = vis_root 29 | 30 | self.vis_processor = vis_processor 31 | self.text_processor = text_processor 32 | 33 | self.instruction_pool = [ 34 | 'Briefly describe this image.', 35 | 'Provide a concise depiction of this image.', 36 | 'Present a short description of this image.', 37 | 'Summarize this image in a few words.', 38 | 'A short image caption:', 39 | 'A short image description:', 40 | 'A photo of ', 41 | 'An image that shows ', 42 | 'Write a short description for the image.
', 43 | 'Write a description for the photo.', 44 | 'Provide a description of what is presented in the photo.', 45 | 'Briefly describe the content of the image.', 46 | 'Can you briefly explain what you see in the image?', 47 | 'Could you use a few words to describe what you perceive in the photo?', 48 | 'Please provide a short depiction of the picture.', 49 | 'Using language, provide a short account of the image.', 50 | 'Use a few words to illustrate what is happening in the picture.', 51 | ] 52 | 53 | with open(ann_path, 'r') as f: 54 | self.ann = json.load(f) 55 | 56 | 57 | def __len__(self): 58 | return len(self.ann["data"]) 59 | 60 | 61 | def __getitem__(self, index): 62 | info = self.ann["data"][index] 63 | 64 | image_file = '{}.jpg'.format(info['image_id']) 65 | 66 | image_path = os.path.join(self.vis_root, image_file) 67 | image = Image.open(image_path).convert("RGB") 68 | image = self.vis_processor(image) 69 | 70 | caption = info["caption_str"] 71 | caption = self.text_processor(caption) 72 | instruction = "<Img><ImageHere></Img> [caption] {} ".format(random.choice(self.instruction_pool)) 73 | return { 74 | "image": image, 75 | "instruction_input": instruction, 76 | "answer": caption, 77 | } 78 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/unnatural_instruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import random 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | from PIL import Image 10 | import skimage.io as io 11 | import matplotlib.pyplot as plt 12 | from matplotlib.collections import PatchCollection 13 | from matplotlib.patches import Polygon, Rectangle 14 | from torch.utils.data import Dataset 15 | import webdataset as wds 16 | 17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset 18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset 19 | 20 | 21 | class UnnaturalDataset(Dataset): 22 | def __init__(self, text_processor, ann_path): 23 | """ 24 | vis_root (string): Root directory of images (e.g. coco/images/) 25 | ann_root (string): directory to store the annotation file 26 | """ 27 | self.text_processor = text_processor 28 | 29 | with open(ann_path, 'r') as f: 30 | self.ann = json.load(f) 31 | 32 | def __len__(self): 33 | return len(self.ann) 34 | 35 | def __getitem__(self, index): 36 | info = self.ann[index]["instances"][0] 37 | instruction = info["instruction_with_input"] 38 | constraints = info["constraints"] 39 | answer = info["output"] 40 | if constraints is not None: 41 | instruction = instruction+" "+constraints 42 | 43 | return { 44 | "instruction_input": self.text_processor(instruction), 45 | "answer": self.text_processor(answer), 46 | } 47 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/vg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import random 5 | import time 6 | import itertools 7 | 8 | import numpy as np 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from visual_genome import local 12 | 13 | 14 | 15 | 16 | class ReferVisualGenomeDataset(Dataset): 17 | def __init__(self, vis_processor, text_processor, data_dir): 18 | """ 19 | vis_root (string): Root directory of images (e.g.
coco/images/) 20 | ann_root (string): directory to store the annotation file 21 | """ 22 | self.data_dir = data_dir 23 | 24 | self.vis_processor = vis_processor 25 | self.text_processor = text_processor 26 | 27 | all_regions = local.get_all_region_descriptions(self.data_dir) 28 | all_regions = [region for regions in all_regions for region in regions] 29 | 30 | # follow OFA practice, only regions smaller than 16384 pixels are used for refer 31 | self.regions = [region for region in all_regions if region.width * region.height < 16384] 32 | 33 | 34 | self.instruction_pool = [ 35 | "[refer] {}", 36 | "[refer] give me the location of {}", 37 | "[refer] where is {} ?", 38 | "[refer] from this image, tell me the location of {}", 39 | "[refer] the location of {} is", 40 | "[refer] could you tell me the location for {} ?", 41 | "[refer] where can I locate the {} ?", 42 | ] 43 | 44 | 45 | def __len__(self): 46 | return len(self.regions) 47 | 48 | def preprocess(self, index): 49 | region = self.regions[index] 50 | image_file = region.image.url.split('/')[-2:] 51 | image_path = os.path.join(self.data_dir, *image_file) 52 | image = Image.open(image_path).convert("RGB") 53 | image_orig_size = image.size 54 | image = self.vis_processor(image) 55 | image_new_size = [100, 100] 56 | 57 | sample_sentence = region.phrase 58 | refer_sentence = self.text_processor(sample_sentence) 59 | 60 | bbox = [region.x, region.y, region.width, region.height] 61 | 62 | bbox = [ 63 | bbox[0] / image_orig_size[0] * image_new_size[0], 64 | bbox[1] / image_orig_size[1] * image_new_size[1], 65 | (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0], 66 | (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1] 67 | ] 68 | bbox = [int(x) for x in bbox] 69 | bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) 70 | return { 71 | "image": image, 72 | "refer_sentence": refer_sentence, 73 | "bbox": bbox, 74 | "image_id": region.image.id, 75 | } 76 | 77 | def __getitem__(self, index): 78 | data = self.preprocess(index) 79 | instruction = random.choice(self.instruction_pool).format(data['refer_sentence']) 80 | 81 | instruction = "<Img><ImageHere></Img> {} ".format(instruction) 82 | 83 | return { 84 | "image": data['image'], 85 | "instruction_input": instruction, 86 | "answer": data['bbox'], 87 | "image_id": data['image_id'], 88 | } 89 | 90 | 91 | -------------------------------------------------------------------------------- /minigpt4/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.processors.base_processor import BaseProcessor 9 | from minigpt4.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | 15 | from minigpt4.common.registry import registry 16 | 17 | __all__ = [ 18 | "BaseProcessor", 19 | "Blip2ImageTrainProcessor", 20 | "Blip2ImageEvalProcessor", 21 | "BlipCaptionProcessor", 22 | ] 23 | 24 | 25 | def load_processor(name, cfg=None): 26 | """ 27 | Example 28 | 29 | >>> processor = load_processor("alpro_video_train", cfg=None) 30 | """ 31 | processor = registry.get_processor_class(name).from_config(cfg) 32 | 33 | return processor 34 | -------------------------------------------------------------------------------- /minigpt4/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /minigpt4/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.runners.runner_base import RunnerBase 9 | 10 | __all__ = ["RunnerBase"] 11 | -------------------------------------------------------------------------------- /minigpt4/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask 11 | 12 | 13 | def setup_task(cfg): 14 | assert "task" in cfg.run_cfg, "Task name must be provided." 15 | 16 | task_name = cfg.run_cfg.task 17 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 18 | assert task is not None, "Task {} not properly registered.".format(task_name) 19 | 20 | return task 21 | 22 | 23 | __all__ = [ 24 | "BaseTask", 25 | "ImageTextPretrainTask", 26 | ] 27 | -------------------------------------------------------------------------------- /minigpt4/tasks/image_text_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | 11 | 12 | @registry.register_task("image_text_pretrain") 13 | class ImageTextPretrainTask(BaseTask): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def evaluation(self, model, data_loader, cuda_enabled=True): 18 | pass 19 | -------------------------------------------------------------------------------- /share4v/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Share4VLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /share4v/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /share4v/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | 6 | import openai 7 | import ray 8 | import tqdm 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | 13 | @ray.remote(num_cpus=4) 14 | def get_eval(content: str, max_tokens: int): 15 | while True: 16 | try: 17 | response = openai.ChatCompletion.create( 18 | model='gpt-4', 19 | messages=[{ 20 | 'role': 'system', 21 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
22 | }, { 23 | 'role': 'user', 24 | 'content': content, 25 | }], 26 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 27 | max_tokens=max_tokens, 28 | ) 29 | break 30 | except openai.error.RateLimitError: 31 | pass 32 | except Exception as e: 33 | print(e) 34 | time.sleep(NUM_SECONDS_TO_SLEEP) 35 | 36 | print('success!') 37 | return response['choices'][0]['message']['content'] 38 | 39 | 40 | def parse_score(review): 41 | try: 42 | score_pair = review.split('\n')[0] 43 | score_pair = score_pair.replace(',', ' ') 44 | sp = score_pair.split(' ') 45 | if len(sp) == 2: 46 | return [float(sp[0]), float(sp[1])] 47 | else: 48 | print('error', review) 49 | return [-1, -1] 50 | except Exception as e: 51 | print(e) 52 | print('error', review) 53 | return [-1, -1] 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser( 58 | description='ChatGPT-based QA evaluation.') 59 | parser.add_argument('-q', '--question') 60 | # parser.add_argument('-a', '--answer') 61 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 62 | parser.add_argument('-r', '--rule') 63 | parser.add_argument('-o', '--output') 64 | parser.add_argument('--max-tokens', type=int, default=1024, 65 | help='maximum number of tokens produced in the output') 66 | args = parser.parse_args() 67 | 68 | ray.init() 69 | 70 | f_q = open(os.path.expanduser(args.question)) 71 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 72 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 73 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 74 | 75 | review_file = open(f'{args.output}', 'w') 76 | 77 | js_list = [] 78 | handles = [] 79 | idx = 0 80 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 81 | # if idx == 1: 82 | # break 83 | 84 | ques = json.loads(ques_js) 85 | ans1 = json.loads(ans1_js) 86 | ans2 = json.loads(ans2_js) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | rule = rule_dict['default'] 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Question]\n{ques["text"]}\n\n' 96 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 97 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 98 | f'[System]\n{prompt}\n\n') 99 | js_list.append({ 100 | 'id': idx+1, 101 | 'question_id': ques['question_id'], 102 | 'answer1_id': ans1['answer_id'], 103 | 'answer2_id': ans2['answer_id'], 104 | 'category': category}) 105 | idx += 1 106 | handles.append(get_eval.remote(content, args.max_tokens)) 107 | # To avoid the rate limit set by OpenAI 108 | time.sleep(NUM_SECONDS_TO_SLEEP) 109 | 110 | reviews = ray.get(handles) 111 | for idx, review in enumerate(reviews): 112 | scores = parse_score(review) 113 | js_list[idx]['content'] = review 114 | js_list[idx]['tuple'] = scores 115 | review_file.write(json.dumps(js_list[idx]) + '\n') 116 | review_file.close() 117 | -------------------------------------------------------------------------------- /share4v/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | 6 | def eval_pope(answers, label_file): 7 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 8 | 9 | for answer in answers: 10 | text = answer['text'] 11 | 12 | # Only keep the first sentence 13 | if text.find('.') != -1: 14 | text = text.split('.')[0] 15 | 16 | text = text.replace(',', '') 17 | words = text.split(' ') 18 | if 'No' in 
words or 'not' in words or 'no' in words: 19 | answer['text'] = 'no' 20 | else: 21 | answer['text'] = 'yes' 22 | 23 | for i in range(len(label_list)): 24 | if label_list[i] == 'no': 25 | label_list[i] = 0 26 | else: 27 | label_list[i] = 1 28 | 29 | pred_list = [] 30 | for answer in answers: 31 | if answer['text'] == 'no': 32 | pred_list.append(0) 33 | else: 34 | pred_list.append(1) 35 | 36 | pos = 1 37 | neg = 0 38 | yes_ratio = pred_list.count(1) / len(pred_list) 39 | 40 | TP, TN, FP, FN = 0, 0, 0, 0 41 | for pred, label in zip(pred_list, label_list): 42 | if pred == pos and label == pos: 43 | TP += 1 44 | elif pred == pos and label == neg: 45 | FP += 1 46 | elif pred == neg and label == neg: 47 | TN += 1 48 | elif pred == neg and label == pos: 49 | FN += 1 50 | 51 | print('TP\tFP\tTN\tFN\t') 52 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 53 | 54 | precision = float(TP) / float(TP + FP) 55 | recall = float(TP) / float(TP + FN) 56 | f1 = 2*precision*recall / (precision + recall) 57 | acc = (TP + TN) / (TP + TN + FP + FN) 58 | print('Accuracy: {}'.format(acc)) 59 | print('Precision: {}'.format(precision)) 60 | print('Recall: {}'.format(recall)) 61 | print('F1 score: {}'.format(f1)) 62 | print('Yes ratio: {}'.format(yes_ratio)) 63 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % 64 | (f1, acc, precision, recall, yes_ratio)) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--annotation-dir", type=str) 70 | parser.add_argument("--question-file", type=str) 71 | parser.add_argument("--result-file", type=str) 72 | args = parser.parse_args() 73 | 74 | questions = [json.loads(line) for line in open(args.question_file)] 75 | questions = {question['question_id']: question for question in questions} 76 | answers = [json.loads(q) for q in open(args.result_file)] 77 | for file in os.listdir(args.annotation_dir): 78 | assert file.startswith('coco_pope_') 79 | assert file.endswith('.json') 80 | category = file[10:-5] 81 | cur_answers = [ 82 | x for x in answers if questions[x['question_id']]['category'] == category] 83 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 84 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 85 | print("====================================") 86 | -------------------------------------------------------------------------------- /share4v/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /share4v/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from share4v.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /share4v/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import shortuuid 6 | import torch 7 | from tqdm import tqdm 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria 9 | 10 | from share4v.conversation import default_conversation 11 | from share4v.utils import disable_torch_init 12 | 13 | 14 | # new stopping implementation 15 | class KeywordsStoppingCriteria(StoppingCriteria): 16 | def __init__(self, keywords, tokenizer, input_ids): 17 | self.keywords = keywords 18 | self.tokenizer = tokenizer 19 | self.start_len = None 20 | self.input_ids = input_ids 21 | 22 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 23 | if self.start_len is None: 24 | self.start_len = self.input_ids.shape[1] 25 | else: 26 | outputs = self.tokenizer.batch_decode( 27 | output_ids[:, self.start_len:], skip_special_tokens=True)[0] 28 | for keyword in self.keywords: 29 | if keyword in outputs: 30 | return True 31 | return False 32 | 33 | 34 | @torch.inference_mode() 35 | def eval_model(model_name, questions_file, answers_file): 36 | # Model 37 | disable_torch_init() 38 | model_name = os.path.expanduser(model_name) 39 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 40 | model = AutoModelForCausalLM.from_pretrained(model_name, 41 | torch_dtype=torch.float16).cuda() 42 | 43 | ques_file = open(os.path.expanduser(questions_file), "r") 44 | ans_file = open(os.path.expanduser(answers_file), "w") 45 | for i, line in enumerate(tqdm(ques_file)): 46 | idx = json.loads(line)["question_id"] 47 | qs = json.loads(line)["text"] 48 | cat = 
json.loads(line)["category"] 49 | conv = default_conversation.copy() 50 | conv.append_message(conv.roles[0], qs) 51 | prompt = conv.get_prompt() 52 | inputs = tokenizer([prompt]) 53 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 54 | stopping_criteria = KeywordsStoppingCriteria( 55 | [conv.sep], tokenizer, input_ids) 56 | output_ids = model.generate( 57 | input_ids, 58 | do_sample=True, 59 | use_cache=True, 60 | temperature=0.7, 61 | max_new_tokens=1024, 62 | stopping_criteria=[stopping_criteria]) 63 | outputs = tokenizer.batch_decode( 64 | output_ids, skip_special_tokens=True)[0] 65 | try: 66 | index = outputs.index(conv.sep, len(prompt)) 67 | except ValueError: 68 | outputs += conv.sep 69 | index = outputs.index(conv.sep, len(prompt)) 70 | 71 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 72 | ans_id = shortuuid.uuid() 73 | ans_file.write(json.dumps({"question_id": idx, 74 | "text": outputs, 75 | "answer_id": ans_id, 76 | "model_id": model_name, 77 | "metadata": {}}) + "\n") 78 | ans_file.flush() 79 | ans_file.close() 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 85 | parser.add_argument("--question-file", type=str, 86 | default="tables/question.jsonl") 87 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 88 | args = parser.parse_args() 89 | 90 | eval_model(args.model_name, args.question_file, args.answers_file) 91 | -------------------------------------------------------------------------------- /share4v/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import concurrent.futures 5 | import json 6 | import os 7 | import time 8 | 9 | import openai 10 | import shortuuid 11 | import tqdm 12 | 13 | openai.api_key = 'xxx' # replace with your own key 14 | MODEL = 'gpt-3.5-turbo' 15 | MODEL_ID = 'gpt-3.5-turbo:20230327' 16 | 17 | 18 | def get_answer(question_id: int, question: str, max_tokens: int): 19 | ans = { 20 | 'answer_id': shortuuid.uuid(), 21 | 'question_id': question_id, 22 | 'model_id': MODEL_ID, 23 | } 24 | for _ in range(3): 25 | try: 26 | response = openai.ChatCompletion.create( 27 | model=MODEL, 28 | messages=[{ 29 | 'role': 'system', 30 | 'content': 'You are a helpful assistant.' 
31 | }, { 32 | 'role': 'user', 33 | 'content': question, 34 | }], 35 | temperature=0, 36 | max_tokens=max_tokens, 37 | ) 38 | ans['text'] = response['choices'][0]['message']['content'] 39 | return ans 40 | except Exception as e: 41 | print('[ERROR]', e) 42 | ans['text'] = '#ERROR#' 43 | time.sleep(1) 44 | return ans 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 49 | parser.add_argument('-q', '--question') 50 | parser.add_argument('-o', '--output') 51 | parser.add_argument('--max-tokens', type=int, default=1024, 52 | help='maximum number of tokens produced in the output') 53 | args = parser.parse_args() 54 | 55 | questions_dict = {} 56 | with open(os.path.expanduser(args.question)) as f: 57 | for line in f: 58 | if not line: 59 | continue 60 | q = json.loads(line) 61 | questions_dict[q['question_id']] = q['text'] 62 | 63 | answers = [] 64 | 65 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 66 | futures = [] 67 | for qid, question in questions_dict.items(): 68 | future = executor.submit( 69 | get_answer, qid, question, args.max_tokens) 70 | futures.append(future) 71 | 72 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 73 | answers.append(future.result()) 74 | 75 | answers.sort(key=lambda x: x['question_id']) 76 | 77 | with open(os.path.expanduser(args.output), 'w') as f: 78 | table = [json.dumps(ans) for ans in answers] 79 | f.write('\n'.join(table)) 80 | -------------------------------------------------------------------------------- /share4v/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='ChatGPT-based QA evaluation.') 12 | parser.add_argument('-d', '--dir', default=None) 13 | parser.add_argument('-v', '--version', default=None) 14 | parser.add_argument('-s', '--select', nargs='*', default=None) 15 | parser.add_argument('-f', '--files', nargs='*', default=[]) 16 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 17 | return parser.parse_args() 18 | 19 | 20 | if __name__ == '__main__': 21 | args = parse_args() 22 | 23 | if args.ignore is not None: 24 | args.ignore = [int(x) for x in args.ignore] 25 | 26 | if len(args.files) > 0: 27 | review_files = args.files 28 | else: 29 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith( 30 | 'gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 31 | 32 | for review_file in sorted(review_files): 33 | config = os.path.basename(review_file).replace( 34 | 'gpt4_text_', '').replace('.jsonl', '') 35 | if args.select is not None and any(x not in config for x in args.select): 36 | continue 37 | if '0613' in config: 38 | version = '0613' 39 | else: 40 | version = '0314' 41 | if args.version is not None and args.version != version: 42 | continue 43 | scores = defaultdict(list) 44 | print(config) 45 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 46 | for review_str in f: 47 | review = json.loads(review_str) 48 | if review['question_id'] in args.ignore: 49 | continue 50 | if 'category' in review: 51 | scores[review['category']].append(review['tuple']) 52 | scores['all'].append(review['tuple']) 53 | else: 54 | if 'tuple' in 
review: 55 | scores['all'].append(review['tuple']) 56 | else: 57 | scores['all'].append(review['score']) 58 | for k, v in sorted(scores.items()): 59 | stats = np.asarray(v).mean(0).tolist() 60 | stats = [round(x, 3) for x in stats] 61 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 62 | print(k, round(stats[1]/stats[0]*100, 1), 63 | round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 64 | print('=================================') 65 | -------------------------------------------------------------------------------- /share4v/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.share4v_llama import (Share4VConfig, 2 | Share4VLlamaForCausalLM) 3 | from .multimodal_encoder.configuration_evaclip import EvaCLIPVisionConfig 4 | from .multimodal_encoder.modeling_evaclip import EvaCLIPVisionModel 5 | -------------------------------------------------------------------------------- /share4v/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m share4v.model.consolidate --src ~/model_weights/share4v-7b --dst ~/model_weights/share4v-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | from share4v.model import * 11 | from share4v.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained( 18 | src_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) 19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 20 | src_model.save_pretrained(dst_path) 21 | src_tokenizer.save_pretrained(dst_path) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--src", type=str, required=True) 27 | parser.add_argument("--dst", type=str, required=True) 28 | 29 | args = parser.parse_args() 30 | 31 | consolidate_ckpt(args.src, args.dst) 32 | -------------------------------------------------------------------------------- /share4v/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .clip_encoder import CLIPVisionTower 4 | 5 | 6 | def build_vision_tower(vision_tower_cfg, **kwargs): 7 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 8 | is_absolute_path_exists = os.path.exists(vision_tower) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("Lin-Chen"): 10 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 11 | 12 | raise ValueError(f'Unknown vision tower: {vision_tower}') 13 | -------------------------------------------------------------------------------- /share4v/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel 4 | 5 | from .configuration_evaclip import EvaCLIPVisionConfig 6 | from .modeling_evaclip import EvaCLIPVisionModel 7 | 8 | 9 | class CLIPVisionTower(nn.Module): 10 | def __init__(self, vision_tower, args, delay_load=False): 11 | super().__init__() 12 | 13 | self.is_loaded = False 14 | 15 | self.vision_tower_name = 
vision_tower 16 | self.select_layer = args.mm_vision_select_layer 17 | self.select_feature = getattr( 18 | args, 'mm_vision_select_feature', 'patch') 19 | 20 | if not delay_load: 21 | self.load_model() 22 | else: 23 | self.cfg_only = CLIPVisionConfig.from_pretrained( 24 | self.vision_tower_name) 25 | 26 | def load_model(self): 27 | print(f'Load vision tower from {self.vision_tower_name}') 28 | self.image_processor = CLIPImageProcessor.from_pretrained( 29 | self.vision_tower_name) 30 | if 'eva' in self.vision_tower_name.lower(): 31 | vision_cfg = EvaCLIPVisionConfig.from_pretrained( 32 | self.vision_tower_name) 33 | self.vision_tower = EvaCLIPVisionModel.from_pretrained( 34 | self.vision_tower_name, config=vision_cfg) 35 | else: 36 | self.vision_tower = CLIPVisionModel.from_pretrained( 37 | self.vision_tower_name) 38 | self.vision_tower.requires_grad_(False) 39 | 40 | self.is_loaded = True 41 | 42 | def feature_select(self, image_forward_outs): 43 | image_features = image_forward_outs.hidden_states[self.select_layer] 44 | if self.select_feature == 'patch': 45 | image_features = image_features[:, 1:] 46 | elif self.select_feature == 'cls_patch': 47 | image_features = image_features 48 | else: 49 | raise ValueError( 50 | f'Unexpected select feature: {self.select_feature}') 51 | return image_features 52 | 53 | # @torch.no_grad() comment to enable fine-tune vit 54 | def forward(self, images): 55 | if type(images) is list: 56 | image_features = [] 57 | for image in images: 58 | image_forward_out = self.vision_tower(image.to( 59 | device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 60 | image_feature = self.feature_select( 61 | image_forward_out).to(image.dtype) 62 | image_features.append(image_feature) 63 | else: 64 | image_forward_outs = self.vision_tower( 65 | images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 66 | image_features = self.feature_select( 67 | image_forward_outs).to(images.dtype) 68 | 69 | return image_features 70 | 71 | @property 72 | def dummy_feature(self): 73 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 74 | 75 | @property 76 | def dtype(self): 77 | return self.vision_tower.dtype 78 | 79 | @property 80 | def device(self): 81 | return self.vision_tower.device 82 | 83 | @property 84 | def config(self): 85 | if self.is_loaded: 86 | return self.vision_tower.config 87 | else: 88 | return self.cfg_only 89 | 90 | @property 91 | def hidden_size(self): 92 | return self.config.hidden_size 93 | 94 | @property 95 | def num_patches(self): 96 | return (self.config.image_size // self.config.patch_size) ** 2 97 | -------------------------------------------------------------------------------- /share4v/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.pre_norm(x) 31 | return x + self.proj(x) 32 | 33 | 34 | def 
build_vision_projector(config, delay_load=False, **kwargs): 35 | projector_type = getattr(config, 'mm_projector_type', 'linear') 36 | 37 | if projector_type == 'linear': 38 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 39 | 40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 41 | if mlp_gelu_match: 42 | mlp_depth = int(mlp_gelu_match.group(1)) 43 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 44 | for _ in range(1, mlp_depth): 45 | modules.append(nn.GELU()) 46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 47 | return nn.Sequential(*modules) 48 | 49 | if projector_type == 'identity': 50 | return IdentityMap() 51 | 52 | raise ValueError(f'Unknown projector type: {projector_type}') 53 | -------------------------------------------------------------------------------- /share4v/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'share4v' in config and 'share4v' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer ShareGPT4V code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "share4v") 15 | cfg.architectures[0] = 'Share4VLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /share4v/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from share4v.train.train import train 7 | from share4v.train.llama_flash_attn_monkey_patch import \ 8 | replace_llama_attn_with_flash_attn 9 | 10 | replace_llama_attn_with_flash_attn() 11 | 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /share4v/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from share4v.train.train import train 5 | from share4v.train.llama_xformers_attn_monkey_patch import \ 6 | replace_llama_attn_with_xformers_attn 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | 11 | if __name__ == "__main__": 12 | train() 13 | -------------------------------------------------------------------------------- /share4v/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def disable_torch_init(): 5 | """ 6 | Disable the redundant torch default initialization to accelerate model creation. 
7 | """ 8 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 9 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 10 | --------------------------------------------------------------------------------
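To make the 'mlp2x_gelu' branch of build_vision_projector (share4v/model/multimodal_projector/builder.py above) concrete, the self-contained sketch below builds the equivalent module directly with torch.nn. The sizes 1024 and 4096 are illustrative assumptions (roughly a CLIP-L patch width feeding a 7B LLM), not values read from the repo configs.

import torch
import torch.nn as nn

# What build_vision_projector returns for mm_projector_type='mlp2x_gelu':
# a two-layer MLP from the vision hidden size into the language-model hidden size.
mm_hidden_size, hidden_size = 1024, 4096  # assumed sizes, for illustration only
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)

# A dummy batch of 576 patch features (a 24 x 24 CLIP grid) projected into the LLM width.
patch_features = torch.randn(1, 576, mm_hidden_size)
print(projector(patch_features).shape)  # torch.Size([1, 576, 4096])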