├── .gitignore
├── LICENSE
├── Owl
├── .gitattributes
├── .gitignore
├── LICENSE
├── OwlEval
│ ├── OwlEval.md
│ ├── answer
│ │ ├── MMreact_answer.jsonl
│ │ ├── blip2_13b_answer.jsonl
│ │ ├── llava_13b_answer.jsonl
│ │ ├── mPLUG_Owl_7b_answer.jsonl
│ │ ├── minigpt4_13b_answer.jsonl
│ │ └── openflanmingo_answer.jsonl
│ ├── cases
│ │ ├── 1.jpg
│ │ ├── 10.jpg
│ │ ├── 11.jpg
│ │ ├── 12.jpg
│ │ ├── 13.jpg
│ │ ├── 14.jpg
│ │ ├── 15.jpg
│ │ ├── 16.jpg
│ │ ├── 17.jpg
│ │ ├── 18.jpg
│ │ ├── 19.jpg
│ │ ├── 2.jpg
│ │ ├── 20.jpg
│ │ ├── 21.jpg
│ │ ├── 22.jpg
│ │ ├── 23.jpg
│ │ ├── 24.jpg
│ │ ├── 25.jpg
│ │ ├── 26.jpg
│ │ ├── 27.jpg
│ │ ├── 28.jpg
│ │ ├── 29.jpg
│ │ ├── 3.jpg
│ │ ├── 30.jpg
│ │ ├── 31.jpg
│ │ ├── 32.jpg
│ │ ├── 33.jpg
│ │ ├── 34.jpg
│ │ ├── 35.jpg
│ │ ├── 36.jpg
│ │ ├── 37.jpg
│ │ ├── 38.jpg
│ │ ├── 39.jpg
│ │ ├── 4.jpg
│ │ ├── 40.jpg
│ │ ├── 41.jpg
│ │ ├── 42.jpg
│ │ ├── 43.jpg
│ │ ├── 44.jpg
│ │ ├── 45.jpg
│ │ ├── 46.jpg
│ │ ├── 47.jpg
│ │ ├── 48.jpg
│ │ ├── 49.jpg
│ │ ├── 5.jpg
│ │ ├── 50.jpg
│ │ ├── 6.jpg
│ │ ├── 7.jpg
│ │ ├── 8.jpg
│ │ └── 9.jpg
│ └── questions.jsonl
├── README.md
├── README_zh.md
├── assets
│ ├── -twitter-blue.svg
│ ├── Demo-ModelScope-brightgreen.svg
│ ├── LICENSE-Apache License-blue.svg
│ ├── Paper-Arxiv-orange.svg
│ ├── Paper-PDF-orange.svg
│ └── modelscopeIcon.svg
├── configs
│ └── v0.yaml
├── examples
│ ├── Yao_Ming.jpeg
│ ├── ca.jpeg
│ ├── fridge.jpg
│ ├── laundry.jpeg
│ ├── monalisa-fun.jpg
│ ├── monday.jpg
│ ├── mug_ad.jpeg
│ ├── rap.jpeg
│ ├── titanic.jpeg
│ ├── vga.jpeg
│ └── website.jpg
├── mplug_owl
│ ├── __init__.py
│ ├── configuration_mplug_owl.py
│ ├── modeling_mplug_owl.py
│ ├── processing_mplug_owl.py
│ └── tokenization_mplug_owl.py
├── mplug_owl_video
│ ├── __init__.py
│ ├── configuration_mplug_owl.py
│ ├── modeling_mplug_owl.py
│ ├── processing_mplug_owl.py
│ └── tokenization_mplug_owl.py
├── pipeline
│ ├── __init__.py
│ ├── data_utils
│ │ ├── __init__.py
│ │ ├── processors
│ │ │ ├── __init__.py
│ │ │ ├── builder.py
│ │ │ ├── caption_processor.py
│ │ │ └── default_processor.py
│ │ ├── randaugment.py
│ │ ├── registry.py
│ │ └── xgpt3_dataset.py
│ ├── interface.py
│ ├── train.py
│ └── utils.py
├── requirements.txt
├── scripts
│ ├── train_it.sh
│ └── train_it_wo_lora.sh
├── serve
│ ├── __init__.py
│ ├── conversation.py
│ ├── gradio_css.py
│ ├── gradio_patch.py
│ ├── io_utils.py
│ ├── model_utils.py
│ ├── model_worker.py
│ ├── serve_utils.py
│ └── web_server.py
└── test-ant.py
├── assets
└── method.png
├── common
├── args.py
├── gpt.py
├── interrupt_wrapper.py
├── models.py
└── utils.py
├── configs
├── minigpt4_infer_fp16-dpo.yaml
├── minigpt4_infer_fp16.yaml
├── minigpt4_infer_fp16_hadpo.yaml
└── minigpt4_train_fp16.yaml
├── dataset
├── caption_prompt.txt
├── caption_prompt_finetune.txt
├── synonyms.txt
└── vqa_prompt.txt
├── deploy
├── dockerfile
├── requirements.txt
├── run.sh
├── run_eval.sh
└── sources.list
├── evaluate
├── eval_auto.py
├── eval_caption.py
├── eval_gqa.py
├── eval_mme.py
├── eval_qbench.py
├── eval_sqa.py
├── eval_utils.py
├── m4c_evaluator.py
└── mme
│ └── calculation.py
├── finetune
├── data.py
├── format_data.py
├── loss.py
├── train.py
└── train_vqa.py
├── insight
├── clip.py
├── object.py
└── plot_time.py
├── llava
├── __init__.py
├── constants.py
├── conversation.py
├── mm_utils.py
├── model
│ ├── __init__.py
│ ├── apply_delta.py
│ ├── builder.py
│ ├── consolidate.py
│ ├── language_model
│ │ ├── llava_llama.py
│ │ ├── llava_mpt.py
│ │ └── mpt
│ │ │ ├── adapt_tokenizer.py
│ │ │ ├── attention.py
│ │ │ ├── blocks.py
│ │ │ ├── configuration_mpt.py
│ │ │ ├── custom_embedding.py
│ │ │ ├── flash_attn_triton.py
│ │ │ ├── hf_prefixlm_converter.py
│ │ │ ├── meta_init_context.py
│ │ │ ├── modeling_mpt.py
│ │ │ ├── norm.py
│ │ │ └── param_init_fns.py
│ ├── llava_arch.py
│ ├── make_delta.py
│ ├── multimodal_encoder
│ │ ├── builder.py
│ │ └── clip_encoder.py
│ ├── multimodal_projector
│ │ └── builder.py
│ └── utils.py
├── train
│ ├── llama_flash_attn_monkey_patch.py
│ ├── llama_xformers_attn_monkey_patch.py
│ ├── llava_trainer.py
│ ├── train.py
│ ├── train_mem.py
│ └── train_xformers.py
└── utils.py
├── minigpt4
├── __init__.py
├── common
│ ├── __init__.py
│ ├── config.py
│ ├── dist_utils.py
│ ├── eval_utils.py
│ ├── gradcam.py
│ ├── logger.py
│ ├── optims.py
│ ├── registry.py
│ ├── utils.py
│ └── vqa_tools
│ │ ├── VQA
│ │ ├── PythonEvaluationTools
│ │ │ ├── vqaEvalDemo.py
│ │ │ └── vqaEvaluation
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqaEval.py
│ │ ├── PythonHelperTools
│ │ │ ├── vqaDemo.py
│ │ │ └── vqaTools
│ │ │ │ ├── __init__.py
│ │ │ │ └── vqa.py
│ │ ├── QuestionTypes
│ │ │ ├── abstract_v002_question_types.txt
│ │ │ └── mscoco_question_types.txt
│ │ ├── README.md
│ │ └── license.txt
│ │ ├── __init__.py
│ │ ├── vqa.py
│ │ └── vqa_eval.py
├── configs
│ ├── datasets
│ │ ├── aokvqa
│ │ │ └── defaults.yaml
│ │ ├── cc_sbu
│ │ │ ├── align.yaml
│ │ │ └── defaults.yaml
│ │ ├── coco
│ │ │ ├── caption.yaml
│ │ │ └── defaults_vqa.yaml
│ │ ├── coco_bbox
│ │ │ ├── invrefcoco.yaml
│ │ │ ├── invrefcocog.yaml
│ │ │ ├── invrefcocop.yaml
│ │ │ ├── refcoco.yaml
│ │ │ ├── refcocog.yaml
│ │ │ └── refcocop.yaml
│ │ ├── flickr
│ │ │ ├── caption_to_phrase.yaml
│ │ │ ├── default.yaml
│ │ │ └── object_to_phrase.yaml
│ │ ├── gqa
│ │ │ └── balanced_val.yaml
│ │ ├── laion
│ │ │ └── defaults.yaml
│ │ ├── llava
│ │ │ ├── conversation.yaml
│ │ │ ├── detail.yaml
│ │ │ └── reason.yaml
│ │ ├── multitask_conversation
│ │ │ └── default.yaml
│ │ ├── nlp
│ │ │ └── unnatural_instruction.yaml
│ │ ├── ocrvqa
│ │ │ └── ocrvqa.yaml
│ │ ├── okvqa
│ │ │ └── defaults.yaml
│ │ ├── textcaps
│ │ │ └── caption.yaml
│ │ └── vg
│ │ │ └── ref.yaml
│ ├── default.yaml
│ └── models
│ │ ├── minigpt4_llama2.yaml
│ │ ├── minigpt4_vicuna0.yaml
│ │ └── minigpt_v2.yaml
├── conversation
│ ├── __init__.py
│ └── conversation.py
├── datasets
│ ├── __init__.py
│ ├── builders
│ │ ├── __init__.py
│ │ ├── base_dataset_builder.py
│ │ └── image_text_pair_builder.py
│ ├── data_utils.py
│ └── datasets
│ │ ├── __init__.py
│ │ ├── aok_vqa_datasets.py
│ │ ├── base_dataset.py
│ │ ├── caption_datasets.py
│ │ ├── cc_sbu_dataset.py
│ │ ├── coco_caption.py
│ │ ├── coco_dataset.py
│ │ ├── coco_vqa_datasets.py
│ │ ├── dataloader_utils.py
│ │ ├── flickr.py
│ │ ├── gqa_datasets.py
│ │ ├── laion_dataset.py
│ │ ├── llava_dataset.py
│ │ ├── multitask_conversation.py
│ │ ├── ocrvqa_dataset.py
│ │ ├── text_caps.py
│ │ ├── unnatural_instruction.py
│ │ ├── vg_dataset.py
│ │ └── vqa_datasets.py
├── models
│ ├── Qformer.py
│ ├── __init__.py
│ ├── base_model.py
│ ├── eva_vit.py
│ ├── minigpt4.py
│ ├── minigpt_base.py
│ ├── minigpt_v2.py
│ └── modeling_llama.py
├── processors
│ ├── __init__.py
│ ├── base_processor.py
│ ├── blip_processors.py
│ └── randaugment.py
├── runners
│ ├── __init__.py
│ └── runner_base.py
└── tasks
│ ├── __init__.py
│ ├── base_task.py
│ └── image_text_pretrain.py
├── readme.md
└── share4v
├── __init__.py
├── constants.py
├── conversation.py
├── eval
├── eval_gpt_review.py
├── eval_gpt_review_bench.py
├── eval_gpt_review_visual.py
├── eval_pope.py
├── eval_science_qa.py
├── eval_science_qa_gpt4.py
├── eval_science_qa_gpt4_requery.py
├── eval_textvqa.py
├── m4c_evaluator.py
├── model_qa.py
├── model_vqa.py
├── model_vqa_loader.py
├── model_vqa_mmbench.py
├── model_vqa_qbench.py
├── model_vqa_science.py
├── qa_baseline_gpt35.py
├── run_share4v.py
├── summarize_gpt_review.py
└── table
│ └── rule.json
├── mm_utils.py
├── model
├── __init__.py
├── builder.py
├── consolidate.py
├── language_model
│ └── share4v_llama.py
├── multimodal_encoder
│ ├── builder.py
│ ├── clip_encoder.py
│ ├── configuration_evaclip.py
│ └── modeling_evaclip.py
├── multimodal_projector
│ └── builder.py
├── share4v_arch.py
└── utils.py
├── train
├── llama_flash_attn_monkey_patch.py
├── llama_xformers_attn_monkey_patch.py
├── share4v_trainer.py
├── train.py
├── train_mem.py
└── train_xformers.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dataset/data
2 | __pycache__
3 | checkpoint*
4 | tmp.txt
5 | finetune/sent_stat.py
6 | *.bak
7 | insight/search.py
8 | common/gpt_free.py
9 |
--------------------------------------------------------------------------------
/Owl/.gitattributes:
--------------------------------------------------------------------------------
1 | *.py eol=lf
2 | *.rst eol=lf
3 | *.md eol=lf
4 | *.mdx eol=lf
--------------------------------------------------------------------------------
/Owl/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # tests and logs
12 | tests/fixtures/cached_*_text.txt
13 | logs/
14 | lightning_logs/
15 | lang_code_data/
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 |
121 | # Pyre type checker
122 | .pyre/
123 |
124 | # vscode
125 | .vs
126 | .vscode
127 |
128 | # Pycharm
129 | .idea
130 |
131 | # TF code
132 | tensorflow_code
133 |
134 | # Models
135 | proc_data
136 |
137 | # examples
138 | runs
139 | /runs_old
140 | /wandb
141 | /output
142 | /configs_dev
143 | /scripts_dev
144 | # /examples/runs
145 | # /examples/**/*.args
146 | # /examples/rag/sweep
147 |
148 | # data
149 | /data
150 | serialization_dir
151 |
152 | # emacs
153 | *.*~
154 | debug.env
155 |
156 | # vim
157 | .*.swp
158 |
159 | #ctags
160 | tags
161 |
162 | # pre-commit
163 | .pre-commit*
164 |
165 | # .lock
166 | *.lock
167 |
168 | # DS_Store (MacOS)
169 | .DS_Store
170 |
171 | # ruff
172 | .ruff_cache
--------------------------------------------------------------------------------
/Owl/OwlEval/OwlEval.md:
--------------------------------------------------------------------------------
1 | # OwlEval
2 |
3 | We have compiled some examples and their corresponding questions from recent open-source work, and organized them into OwlEval.
4 |
5 | In the following, we introduce OwlEval and its data format.
6 |
7 | ## Data Format
8 |
9 | ### questions
10 |
11 | `questions.jsonl` lists the case images and their corresponding questions.
12 |
13 | Each row contains the following fields:
14 |
15 | - `image`: The file name of the case image
16 | - `question_id`: The question ID; there are 82 questions in total
17 |
18 | - `question`: The question text
19 | - `ability`: The abilities the model needs to answer the question. Detailed definitions can be found in the paper.
20 | - `a` is `Instruction Understanding`
21 | - `b` is `Visual Understanding`
22 | - `c` is `Optical Character Recognition`
23 | - `d` is `Knowledge Transfer Ability`
24 | - `e` is `Reasoning Ability`
25 | - `f` is `Multi-turn Dialogue Ability`
26 | - `category`: Whether the question is single-turn or multi-turn
27 |
28 | For example:
29 |
30 | ```json
31 | {"image": "1.jpg", "question_id": 1, "question": "What is funny about this image? Describe it panel by panel.", "ability": "a,b,e", "category": ["single_turn"]}
32 | ```
33 |
34 | ### answer
35 |
36 | This folder contains each model's responses to every question, organized into six JSONL files:
37 |
38 | `llava_13b_answer.jsonl`
39 |
40 | `minigpt4_13b_answer.jsonl`
41 |
42 | `MMreact_answer.jsonl`
43 |
44 | `mPLUG_Owl_7b_answer.jsonl`
45 |
46 | `blip2_13b_answer.jsonl`
47 |
48 | `openflamingo_answer.jsonl`
49 |
50 | Each `answer/xxx.jsonl` file contains the following fields:
51 |
52 | - `image`: The file name of the case image
53 | - `question_id`: The question ID; there are 82 questions in total
54 |
55 | - `question`: The question text
56 | - `answer`: The reply given by the model
57 | - `model_id`: The ID of the model that generated the answer
58 |
59 | For example:
60 |
61 | ```json
62 | {"image": "10.jpg", "question_id": 15, "question": "How many bedrooms are there in this floor plan?", "answer": "There are three bedrooms in this floor plan.", "model_id": "llava-13b"}
63 | ```
64 |
65 | ### cases
66 |
67 | This folder contains the 50 evaluation images: 21 from MiniGPT-4, 13 from MM-REACT, 9 from BLIP-2, 3 from GPT-4, and 4 collected by us.
68 |
--------------------------------------------------------------------------------
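The JSONL layout above is straightforward to consume. Below is a minimal sketch (not part of the repository) that joins `questions.jsonl` with one answer file by `question_id`; the `Owl/OwlEval` path is assumed from the directory tree above.

```python
import json
from pathlib import Path

OWLEVAL = Path("Owl/OwlEval")  # assumed location, taken from the directory tree


def load_jsonl(path):
    """Read one JSON object per line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


questions = {q["question_id"]: q for q in load_jsonl(OWLEVAL / "questions.jsonl")}
answers = load_jsonl(OWLEVAL / "answer" / "mPLUG_Owl_7b_answer.jsonl")

for ans in answers[:3]:
    q = questions[ans["question_id"]]
    print(q["image"], "|", q["question"])
    print("  ->", ans["answer"])
```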
/Owl/OwlEval/cases/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/1.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/10.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/11.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/12.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/13.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/14.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/15.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/16.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/17.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/18.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/19.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/2.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/20.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/20.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/21.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/21.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/22.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/23.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/24.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/24.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/25.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/25.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/26.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/26.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/27.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/27.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/28.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/28.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/29.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/29.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/3.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/30.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/31.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/32.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/32.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/33.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/33.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/34.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/34.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/35.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/35.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/36.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/36.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/37.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/37.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/38.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/39.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/39.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/4.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/40.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/40.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/41.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/41.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/42.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/42.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/43.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/43.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/44.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/44.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/45.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/45.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/46.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/46.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/47.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/47.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/48.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/48.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/49.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/49.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/5.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/50.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/6.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/7.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/8.jpg
--------------------------------------------------------------------------------
/Owl/OwlEval/cases/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/OwlEval/cases/9.jpg
--------------------------------------------------------------------------------
/Owl/assets/-twitter-blue.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/assets/Demo-ModelScope-brightgreen.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/assets/LICENSE-Apache License-blue.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/assets/Paper-Arxiv-orange.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/assets/Paper-PDF-orange.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/assets/modelscopeIcon.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Owl/configs/v0.yaml:
--------------------------------------------------------------------------------
1 | data_files: [
2 | 'sft_v0.1_train.jsonl',
3 | 'sft_v0.1_dev.jsonl'
4 | ]
5 |
6 | train_processors: {
7 | sft: {type: 'CaptionProcessor', image_size: 224, min_scale: 0.5, randaug: False}
8 | }
9 |
10 | valid_processors: {
11 | sft: {type: 'DefaultProcessor', image_size: 224}
12 | }
--------------------------------------------------------------------------------
/Owl/examples/Yao_Ming.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/Yao_Ming.jpeg
--------------------------------------------------------------------------------
/Owl/examples/ca.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/ca.jpeg
--------------------------------------------------------------------------------
/Owl/examples/fridge.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/fridge.jpg
--------------------------------------------------------------------------------
/Owl/examples/laundry.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/laundry.jpeg
--------------------------------------------------------------------------------
/Owl/examples/monalisa-fun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/monalisa-fun.jpg
--------------------------------------------------------------------------------
/Owl/examples/monday.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/monday.jpg
--------------------------------------------------------------------------------
/Owl/examples/mug_ad.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/mug_ad.jpeg
--------------------------------------------------------------------------------
/Owl/examples/rap.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/rap.jpeg
--------------------------------------------------------------------------------
/Owl/examples/titanic.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/titanic.jpeg
--------------------------------------------------------------------------------
/Owl/examples/vga.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/vga.jpeg
--------------------------------------------------------------------------------
/Owl/examples/website.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/examples/website.jpg
--------------------------------------------------------------------------------
/Owl/mplug_owl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
17 |
18 |
19 | _import_structure = {
20 | "configuration_mplug_owl": ["MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MplugOwlConfig"],
21 | "processing_mplug_owl": ["MplugOwlImageProcessor", "MplugOwlProcessor"],
22 | "tokenization_mplug_owl": ["MplugOwlTokenizer"],
23 | }
24 |
25 | try:
26 | if not is_tokenizers_available():
27 | raise OptionalDependencyNotAvailable()
28 | except OptionalDependencyNotAvailable:
29 | pass
30 |
31 |
32 | try:
33 | if not is_torch_available():
34 | raise OptionalDependencyNotAvailable()
35 | except OptionalDependencyNotAvailable:
36 | pass
37 | else:
38 | _import_structure["modeling_mplug_owl"] = [
39 | "MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST",
40 | "MplugOwlForConditionalGeneration",
41 | "MplugOwlModel",
42 | ]
43 |
44 |
45 | if TYPE_CHECKING:
46 | from .configuration_mplug_owl import MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP, MplugOwlConfig
47 | from .tokenization_mplug_owl import MplugOwlTokenizer
48 |
49 | try:
50 | if not is_tokenizers_available():
51 | raise OptionalDependencyNotAvailable()
52 | except OptionalDependencyNotAvailable:
53 | pass
54 |
55 | try:
56 | if not is_torch_available():
57 | raise OptionalDependencyNotAvailable()
58 | except OptionalDependencyNotAvailable:
59 | pass
60 | else:
61 | from .modeling_mplug_owl import (
62 | MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST,
63 | MplugOwlForConditionalGeneration,
64 | MplugOwlModel,
65 | MplugOwlPreTrainedModel,
66 | )
67 |
68 |
69 | else:
70 | import sys
71 |
72 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
73 |
74 | from .configuration_mplug_owl import *
75 | from .modeling_mplug_owl import *
76 | from .processing_mplug_owl import *
77 | from .tokenization_mplug_owl import *
78 |
--------------------------------------------------------------------------------
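This `__init__.py` follows the Hugging Face `_LazyModule` pattern: `modeling_mplug_owl` (and its `torch` dependency) is only imported when one of its symbols is first accessed. A minimal usage sketch, assuming the `Owl` directory is on `sys.path` and using a placeholder checkpoint path:

```python
import torch

# Accessing the symbol triggers the lazy import of modeling_mplug_owl.
from mplug_owl import MplugOwlForConditionalGeneration

# "path/to/mplug-owl-checkpoint" is a placeholder, not a published model name.
model = MplugOwlForConditionalGeneration.from_pretrained(
    "path/to/mplug-owl-checkpoint", torch_dtype=torch.half
)
```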
/Owl/mplug_owl/tokenization_mplug_owl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for MplugOwl."""
16 |
17 | from transformers.utils import logging
18 | from transformers.models.llama.tokenization_llama import LlamaTokenizer
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
24 |
25 | PRETRAINED_VOCAB_FILES_MAP = {
26 | "vocab_file": {
27 | "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/vocab.txt",
28 | },
29 | }
30 |
31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
32 | "MAGAer13/mplug-owl-llama-7b": 1024,
33 | }
34 |
35 |
36 | class MplugOwlTokenizer(LlamaTokenizer):
37 | def __init__(
38 | self,
39 | vocab_file,
40 |         unk_token="<unk>",
41 |         bos_token="<s>",
42 |         eos_token="</s>",
43 |         pad_token="<unk>",
44 | sp_model_kwargs=None,
45 | add_bos_token=False,
46 | add_eos_token=False,
47 | clean_up_tokenization_spaces=False,
48 | **kwargs,
49 | ):
50 | super().__init__(
51 | vocab_file,
52 | unk_token,
53 | bos_token,
54 | eos_token,
55 | pad_token,
56 | sp_model_kwargs,
57 | add_bos_token,
58 | add_eos_token,
59 | clean_up_tokenization_spaces,
60 | **kwargs,
61 | )
62 | self.eod_id = self.eos_token_id
63 |
--------------------------------------------------------------------------------
/Owl/mplug_owl_video/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
17 |
18 |
19 | _import_structure = {
20 | "configuration_mplug_owl": ["MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MplugOwlConfig"],
21 | "processing_mplug_owl": ["MplugOwlImageProcessor", "MplugOwlProcessor"],
22 | "tokenization_mplug_owl": ["MplugOwlTokenizer"],
23 | }
24 |
25 | try:
26 | if not is_tokenizers_available():
27 | raise OptionalDependencyNotAvailable()
28 | except OptionalDependencyNotAvailable:
29 | pass
30 |
31 |
32 | try:
33 | if not is_torch_available():
34 | raise OptionalDependencyNotAvailable()
35 | except OptionalDependencyNotAvailable:
36 | pass
37 | else:
38 | _import_structure["modeling_mplug_owl"] = [
39 | "MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST",
40 | "MplugOwlForConditionalGeneration",
41 | "MplugOwlModel",
42 | ]
43 |
44 |
45 | if TYPE_CHECKING:
46 | from .configuration_mplug_owl import MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP, MplugOwlConfig
47 | from .tokenization_mplug_owl import MplugOwlTokenizer
48 |
49 | try:
50 | if not is_tokenizers_available():
51 | raise OptionalDependencyNotAvailable()
52 | except OptionalDependencyNotAvailable:
53 | pass
54 |
55 | try:
56 | if not is_torch_available():
57 | raise OptionalDependencyNotAvailable()
58 | except OptionalDependencyNotAvailable:
59 | pass
60 | else:
61 | from .modeling_mplug_owl import (
62 | MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST,
63 | MplugOwlForConditionalGeneration,
64 | MplugOwlModel,
65 | MplugOwlPreTrainedModel,
66 | )
67 |
68 |
69 | else:
70 | import sys
71 |
72 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
73 |
74 | from .configuration_mplug_owl import *
75 | from .modeling_mplug_owl import *
76 | from .processing_mplug_owl import *
77 | from .tokenization_mplug_owl import *
78 |
--------------------------------------------------------------------------------
/Owl/mplug_owl_video/tokenization_mplug_owl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for MplugOwl."""
16 |
17 | from transformers.utils import logging
18 | from transformers.models.llama.tokenization_llama import LlamaTokenizer
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
24 |
25 | PRETRAINED_VOCAB_FILES_MAP = {
26 | "vocab_file": {
27 | "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/vocab.txt",
28 | },
29 | }
30 |
31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
32 | "MAGAer13/mplug-owl-llama-7b": 2048,
33 | }
34 |
35 |
36 | class MplugOwlTokenizer(LlamaTokenizer):
37 | def __init__(
38 | self,
39 | vocab_file,
40 |         unk_token="<unk>",
41 |         bos_token="<s>",
42 |         eos_token="</s>",
43 |         pad_token="<unk>",
44 | sp_model_kwargs=None,
45 | add_bos_token=False,
46 | add_eos_token=False,
47 | clean_up_tokenization_spaces=False,
48 | **kwargs,
49 | ):
50 | super().__init__(
51 | vocab_file,
52 | unk_token,
53 | bos_token,
54 | eos_token,
55 | pad_token,
56 | sp_model_kwargs,
57 | add_bos_token,
58 | add_eos_token,
59 | clean_up_tokenization_spaces,
60 | **kwargs,
61 | )
62 | self.eod_id = self.eos_token_id
63 |
--------------------------------------------------------------------------------
/Owl/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/pipeline/__init__.py
--------------------------------------------------------------------------------
/Owl/pipeline/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .processors.builder import build_processors
2 | from .xgpt3_dataset import MultiModalDataset
3 |
4 | def train_valid_test_datasets_provider(data_path, config, tokenizer, seq_length=1024):
5 | """Build train and valid datasets."""
6 | print('> building train and validation datasets for mPLUG-Owl ...')
7 | train_ds, valid_ds = build_train_valid_test_datasets(
8 | input_file=data_path,
9 | tokenizer=tokenizer,
10 | max_length=seq_length,
11 | config=config)
12 | print("> finished creating mPLUG-Owl datasets ...")
13 |
14 | return train_ds, valid_ds
15 |
16 | def build_train_valid_test_datasets(input_file, tokenizer, max_length=80, config=None):
17 | train_processors = build_processors(config['train_processors'])
18 | valid_processors = build_processors(config['valid_processors'])
19 |
20 |     assert len(input_file) == 2  # If you have more than 2 files, modify the code here or merge them into train and dev
21 | train_ds = MultiModalDataset(input_file[0], tokenizer, train_processors, max_length)
22 | valid_ds = MultiModalDataset(input_file[1], tokenizer, valid_processors, max_length)
23 | test_ds = None
24 | return (train_ds, valid_ds)
25 |
--------------------------------------------------------------------------------
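`pipeline/train.py` is what normally drives this provider, so the exact wiring below is an assumption; it is a minimal sketch of how the function pairs with `configs/v0.yaml`:

```python
import sys

sys.path.append("Owl/pipeline")  # so `data_utils` resolves, mirroring how train.py is launched

from ruamel.yaml import YAML
from transformers import AutoTokenizer
from data_utils import train_valid_test_datasets_provider

yaml = YAML(typ="safe")
with open("Owl/configs/v0.yaml") as f:
    config = yaml.load(f)  # data_files, train_processors, valid_processors

# Tokenizer checkpoint is the one referenced elsewhere in the repo; train.py
# normally supplies it via --pretrained-ckpt.
tokenizer = AutoTokenizer.from_pretrained("MAGAer13/mplug-owl-llama-7b")

train_ds, valid_ds = train_valid_test_datasets_provider(
    config["data_files"], config=config, tokenizer=tokenizer, seq_length=2048
)
```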
/Owl/pipeline/data_utils/processors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba. All rights reserved.
2 | from .builder import PROCESSORS, build_processors
3 | from .default_processor import DefaultProcessor
4 | from .caption_processor import CaptionProcessor
5 |
6 | __all__ = [
7 | 'PROCESSORS', 'build_processors',
8 | 'DefaultProcessor', 'CaptionProcessor'
9 | ]
--------------------------------------------------------------------------------
/Owl/pipeline/data_utils/processors/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from data_utils.registry import Registry, build_from_cfg
5 |
6 | PROCESSORS = Registry('processors')
7 |
8 | def build_processors(processors_cfg):
9 | processors = dict()
10 | for task, processor in processors_cfg.items():
11 | processors[task] = build_from_cfg(processor, PROCESSORS)
12 | return processors
13 |
--------------------------------------------------------------------------------
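`data_utils/registry.py` is not included in this dump, so the exact `Registry` and `build_from_cfg` implementation is unknown. The stand-in below is an assumed, simplified sketch of the usual registry pattern: the `type` key of a config entry selects a registered class and the remaining keys become constructor keyword arguments.

```python
class Registry:
    """Maps class names to registered classes."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._modules[key]


def build_from_cfg(cfg, registry):
    """Instantiate cfg['type'] with the remaining keys as kwargs."""
    cfg = dict(cfg)  # do not mutate the caller's config
    return registry.get(cfg.pop("type"))(**cfg)


# Toy demonstration mirroring the 'sft' entry of configs/v0.yaml.
PROCESSORS = Registry("processors")


@PROCESSORS.register_module()
class CaptionProcessor:
    def __init__(self, image_size=224, min_scale=0.5, randaug=False):
        self.image_size, self.min_scale, self.randaug = image_size, min_scale, randaug


proc = build_from_cfg(
    {"type": "CaptionProcessor", "image_size": 224, "min_scale": 0.5, "randaug": False},
    PROCESSORS,
)
```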
/Owl/pipeline/data_utils/processors/caption_processor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from PIL import Image
4 | import random
5 |
6 | from data_utils.randaugment import RandomAugment
7 | from .builder import PROCESSORS
8 |
9 |
10 | @PROCESSORS.register_module()
11 | class CaptionProcessor:
12 | def __init__(self, image_size=224, min_scale = 0.5, randaug=False):
13 | self.image_size = image_size
14 | self.min_scale = min_scale
15 |
16 | if randaug:
17 | self.image_transform = transforms.Compose([
18 | transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
19 | transforms.RandomHorizontalFlip(),
20 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
21 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
22 | transforms.ToTensor(),
23 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
24 | ])
25 | else:
26 | self.image_transform = transforms.Compose([
27 | transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
28 | transforms.RandomHorizontalFlip(),
29 | transforms.ToTensor(),
30 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
31 | ])
32 | self.text_transform = None
33 |
34 | def __call__(self, image, text):
35 | assert image or text
36 |
37 | if image:
38 | image_input = self.image_transform(image)
39 | else:
40 | image_input = None
41 |
42 | if text:
43 | if isinstance(text["prompt"], list):
44 | prompt = random.choice(text["prompt"])
45 | else:
46 | prompt = text["prompt"]
47 | text_input = dict(
48 | prompt=prompt,
49 | completion=text["text"],
50 | )
51 | else:
52 | text_input = None
53 | return image_input, text_input
--------------------------------------------------------------------------------
/Owl/pipeline/data_utils/processors/default_processor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from PIL import Image
4 | import random
5 |
6 | from data_utils.randaugment import RandomAugment
7 | from .builder import PROCESSORS
8 |
9 |
10 | @PROCESSORS.register_module()
11 | class DefaultProcessor:
12 | def __init__(self, image_size=224):
13 | self.image_size = image_size
14 |
15 | self.image_transform = transforms.Compose([
16 | transforms.Resize((image_size, image_size),interpolation=Image.BICUBIC),
17 | transforms.ToTensor(),
18 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
19 | ])
20 |
21 | self.text_transform = None
22 |
23 | def __call__(self, image, text):
24 | assert image or text
25 |
26 | if image:
27 | image_input = self.image_transform(image)
28 | else:
29 | image_input = None
30 |
31 | if text:
32 | if isinstance(text["prompt"], list):
33 | prompt = random.choice(text["prompt"])
34 | else:
35 | prompt = text["prompt"]
36 | text_input = dict(
37 | prompt=prompt,
38 | completion=text["text"],
39 | )
40 | else:
41 | text_input = None
42 | return image_input, text_input
--------------------------------------------------------------------------------
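`CaptionProcessor` and `DefaultProcessor` share the same `__call__(image, text)` contract: the image becomes a normalized 3x224x224 tensor and the text dict is reduced to a prompt/completion pair. A minimal usage sketch, assuming `Owl/pipeline` is on `sys.path` and using one of the repository's example images; the text entry is a toy placeholder:

```python
import sys

sys.path.append("Owl/pipeline")  # so `data_utils` resolves

from PIL import Image
from data_utils.processors import DefaultProcessor

processor = DefaultProcessor(image_size=224)
image = Image.open("Owl/examples/monday.jpg").convert("RGB")
# "prompt" may be a single string or a list to sample from.
text = {"prompt": ["Describe the image."], "text": "A placeholder caption."}

image_input, text_input = processor(image, text)
print(image_input.shape)  # torch.Size([3, 224, 224])
print(text_input)         # {'prompt': 'Describe the image.', 'completion': 'A placeholder caption.'}
```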
/Owl/pipeline/interface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import requests
4 | from PIL import Image
5 | import os, sys
6 | sys.path.append('/workspace/hal/code/Owl/')
7 | sys.path.append('/workspace/hal/Owl/')
8 | # sys.path.append('/data/ant/code/Owl/mplug_owl/')
9 | # sys.path.append('/data/ant/code/Owl/pipeline/')
10 | # sys.path.append('/data/ant/code/')
11 | from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
12 | from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
13 | from transformers import AutoTokenizer
14 |
15 |
16 | def get_model(pretrained_ckpt, use_bf16=False):
17 | """Model Provider with tokenizer and processor.
18 |
19 | Args:
20 | pretrained_ckpt (string): The path to pre-trained checkpoint.
21 | use_bf16 (bool, optional): Whether to use bfloat16 to load the model. Defaults to False.
22 |
23 | Returns:
24 | model: MplugOwl Model
25 | tokenizer: MplugOwl text tokenizer
26 | processor: MplugOwl processor (including text and image)
27 | """
28 | model = MplugOwlForConditionalGeneration.from_pretrained(
29 | pretrained_ckpt,
30 | torch_dtype=torch.bfloat16 if use_bf16 else torch.half,
31 | )
32 | image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
33 | tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt)
34 | processor = MplugOwlProcessor(image_processor, tokenizer)
35 | return model, tokenizer, processor
36 |
37 |
38 | def do_generate(prompts, image_list, model, tokenizer, processor, use_bf16=False, **generate_kwargs):
39 | """The interface for generation
40 |
41 | Args:
42 | prompts (List[str]): The prompt text
43 | image_list (List[str]): Paths of images
44 | model (MplugOwlForConditionalGeneration): MplugOwlForConditionalGeneration
45 | tokenizer (AutoTokenizer): AutoTokenizer
46 | processor (MplugOwlProcessor): MplugOwlProcessor
47 | use_bf16 (bool, optional): Whether to use bfloat16. Defaults to False.
48 |
49 | Returns:
50 | sentence (str): Generated sentence.
51 | """
52 | if image_list:
53 | images = [Image.open(_) for _ in image_list]
54 | else:
55 | images = None
56 | inputs = processor(text=prompts, images=images, return_tensors='pt')
57 | print(f"inputs={inputs}, inputs.shape={inputs['pixel_values'].shape}")
58 | inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
59 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
60 | with torch.no_grad():
61 | res = model.generate(**inputs, **generate_kwargs)
62 | sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
63 | return sentence
64 |
65 |
66 | if __name__ == '__main__':
67 | prompts = ['''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
68 | Human: <image>
69 | Human: Explain why this meme is funny.
70 | AI: ''']
71 | image_list = ['/workspace/hal/Owl/examples/monday.jpg']
72 | # base_model = 'MAGAer13/mplug-owl-llama-7b'
73 | base_model= '/workspace/hal/checkpoints/mplug_OWL_llama_7b'
74 | # base_model = 'output/sft_v0.1_lora_grad_ckpt/checkpoint-4000/generation_config.json'
75 | model, tokenizer, processor = get_model(base_model, use_bf16=True)
76 | # sentence = do_generate(
77 | # prompts, image_list, model,
78 | # tokenizer, processor, use_bf16=True,
79 | # max_length=512, top_k=5, do_sample=True
80 | # )
81 | print(tokenizer.eos_token_id)
82 | print(tokenizer.pad_token_id)
83 | print(tokenizer.bos_token_id)
84 | # print(sentence)
85 |
--------------------------------------------------------------------------------
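The `__main__` block above leaves the actual generation call commented out. For reference, a minimal sketch of how `get_model` and `do_generate` fit together, using the checkpoint name referenced elsewhere in the repo and one of the example images; the hard-coded `sys.path` entries in `interface.py` are environment-specific, so adjust them to your checkout:

```python
from pipeline.interface import get_model, do_generate  # assumes Owl/ is on sys.path

model, tokenizer, processor = get_model("MAGAer13/mplug-owl-llama-7b", use_bf16=True)
model.to("cuda").eval()

prompts = [
    "The following is a conversation between a curious human and AI assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    "Human: <image>\n"
    "Human: Explain why this meme is funny.\n"
    "AI: "
]
sentence = do_generate(
    prompts, ["Owl/examples/monday.jpg"], model, tokenizer, processor,
    use_bf16=True, max_length=512, top_k=5, do_sample=True,
)
print(sentence)
```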
/Owl/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.28.1
2 | einops
3 | icecream
4 | flask
5 | ruamel.yaml
6 | uvicorn
7 | fastapi
8 | markdown2
9 | gradio
10 | sconf
11 | tensorboardX
12 | tensorboard
13 | h5py
14 | sentencepiece
15 | peft
16 | opencv-python
17 | decord
18 | chardet
19 | cchardet
--------------------------------------------------------------------------------
/Owl/scripts/train_it.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DIR=`pwd`
3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
4 |
5 | if [ $MASTER_ADDR ];then
6 | echo $MASTER_ADDR
7 | echo $MASTER_PORT
8 | echo $WORLD_SIZE
9 | echo $RANK
10 | else
11 | MASTER_ADDR=127.0.0.1
12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15
13 | WORLD_SIZE=1
14 | RANK=0
15 | fi
16 |
17 | DISTRIBUTED_ARGS="--nproc_per_node 8 \
18 | --nnodes ${WORLD_SIZE} \
19 | --node_rank ${RANK} \
20 | --master_addr ${MASTER_ADDR} \
21 | --master_port ${MASTER_PORT}"
22 |
23 | EXP_NAME=sft_v0.1
24 | SAVE_NAME=sft_v0.1_ft_grad_ckpt
25 |
26 | SAVE_PATH="./output/${SAVE_NAME}/"
27 |
28 | max_length=2048
29 | micro_batch_size=4
30 | global_batch_size=256
31 | gradient_accumulation_steps=1
32 |
33 | # train_iters = total_data * train_epochs // global_batch_size
34 | # 361481 * 3 / 256 = 4236
35 | train_epochs=3
36 | train_iters=4236
37 |
38 | lr_warmup_iters=50
39 |
40 | eval_iter=50
41 | eval_interval=50
42 | save_interval=500
43 |
44 | mkdir -p ${SAVE_PATH}
45 |
46 | options=" \
47 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \
48 | --seq-length ${max_length} \
49 | --micro-batch-size ${micro_batch_size} \
50 | --num-training-steps ${train_iters} \
51 | --train-epochs ${train_epochs} \
52 | --num-warmup-steps ${lr_warmup_iters} \
53 | --gradient-accumulation-steps ${gradient_accumulation_steps} \
54 | --lr 2e-5 \
55 | --min-lr 1e-6 \
56 | --eval-iters ${eval_iter} \
57 | --save-interval ${save_interval} \
58 | --save-path ${SAVE_PATH} \
59 | --clip-grad 1.0 \
60 | --weight-decay 0.0001 \
61 | --adam-beta1 0.9 \
62 | --adam-beta2 0.999 \
63 | --num-workers 32 \
64 | --use-lora \
65 | --gradient-checkpointing \
66 | --bf16"
67 |
68 | multimodal_options=" \
69 | --mm-config configs/v0.yaml
70 | "
71 |
72 | python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee ${SAVE_PATH}/train.log
--------------------------------------------------------------------------------
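The hard-coded `train_iters=4236` above follows the formula in the script comment. A small sketch for recomputing it when the dataset size, epoch count, or global batch size changes (361,481 is the sample count quoted in the script):

```python
# train_iters = total_data * train_epochs // global_batch_size
total_data = 361481        # number of SFT samples, as quoted in the script comment
train_epochs = 3
global_batch_size = 256

train_iters = total_data * train_epochs // global_batch_size
print(train_iters)  # 4236, the value hard-coded in train_it.sh
```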
/Owl/scripts/train_it_wo_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DIR=`pwd`
3 | DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
4 |
5 | if [ $MASTER_ADDR ];then
6 | echo $MASTER_ADDR
7 | echo $MASTER_PORT
8 | echo $WORLD_SIZE
9 | echo $RANK
10 | else
11 | MASTER_ADDR=127.0.0.1
12 | MASTER_PORT=2$(($RANDOM % 10))$(($RANDOM % 10))15
13 | WORLD_SIZE=1
14 | RANK=0
15 | fi
16 |
17 | DISTRIBUTED_ARGS="--nproc_per_node 8 \
18 | --nnodes ${WORLD_SIZE} \
19 | --node_rank ${RANK} \
20 | --master_addr ${MASTER_ADDR} \
21 | --master_port ${MASTER_PORT}"
22 |
23 | EXP_NAME=sft_v0.1
24 | SAVE_NAME=sft_v0.1_ft_grad_ckpt
25 |
26 | SAVE_PATH="./output/${SAVE_NAME}/"
27 |
28 | max_length=2048
29 | micro_batch_size=1
30 | global_batch_size=256
31 | gradient_accumulation_steps=4
32 |
33 | # train_iters = total_data * train_epochs // global_batch_size
34 | # 361481 * 3 / 256 = 4236
35 | train_epochs=3
36 | train_iters=4236
37 |
38 | lr_warmup_iters=50
39 | lr_decay_iters=`expr $train_iters - $lr_warmup_iters`
40 |
41 | eval_iter=50
42 | eval_interval=50
43 | save_interval=500
44 |
45 | mkdir -p ${SAVE_PATH}
46 |
47 | options=" \
48 | --pretrained-ckpt MAGAer13/mplug-owl-llama-7b-pt \
49 | --seq-length ${max_length} \
50 | --micro-batch-size ${micro_batch_size} \
51 | --train-epochs ${train_epochs} \
52 | --num-warmup-steps ${lr_warmup_iters} \
53 | --num-training-steps ${train_iters} \
54 | --gradient-accumulation-steps ${gradient_accumulation_steps} \
55 | --lr 1e-5 \
56 | --min-lr 1e-6 \
57 | --eval-iters ${eval_iter} \
58 | --save-interval ${save_interval} \
59 | --save-path ${SAVE_PATH} \
60 | --clip-grad 1.0 \
61 | --weight-decay 0.0001 \
62 | --adam-beta1 0.9 \
63 | --adam-beta2 0.999 \
64 | --num-workers 32 \
65 | --gradient-checkpointing \
66 | --bf16"
67 |
68 | multimodal_options=" \
69 | --mm-config configs/v0.yaml
70 | "
71 |
72 | python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pipeline/train.py $@ ${options} ${multimodal_options} 2>&1 | tee ${SAVE_PATH}/train.log
--------------------------------------------------------------------------------
/Owl/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/Owl/serve/__init__.py
--------------------------------------------------------------------------------
/Owl/serve/gradio_css.py:
--------------------------------------------------------------------------------
1 | code_highlight_css = (
2 | """
3 | #chatbot .hll { background-color: #ffffcc }
4 | #chatbot .c { color: #408080; font-style: italic }
5 | #chatbot .err { border: 1px solid #FF0000 }
6 | #chatbot .k { color: #008000; font-weight: bold }
7 | #chatbot .o { color: #666666 }
8 | #chatbot .ch { color: #408080; font-style: italic }
9 | #chatbot .cm { color: #408080; font-style: italic }
10 | #chatbot .cp { color: #BC7A00 }
11 | #chatbot .cpf { color: #408080; font-style: italic }
12 | #chatbot .c1 { color: #408080; font-style: italic }
13 | #chatbot .cs { color: #408080; font-style: italic }
14 | #chatbot .gd { color: #A00000 }
15 | #chatbot .ge { font-style: italic }
16 | #chatbot .gr { color: #FF0000 }
17 | #chatbot .gh { color: #000080; font-weight: bold }
18 | #chatbot .gi { color: #00A000 }
19 | #chatbot .go { color: #888888 }
20 | #chatbot .gp { color: #000080; font-weight: bold }
21 | #chatbot .gs { font-weight: bold }
22 | #chatbot .gu { color: #800080; font-weight: bold }
23 | #chatbot .gt { color: #0044DD }
24 | #chatbot .kc { color: #008000; font-weight: bold }
25 | #chatbot .kd { color: #008000; font-weight: bold }
26 | #chatbot .kn { color: #008000; font-weight: bold }
27 | #chatbot .kp { color: #008000 }
28 | #chatbot .kr { color: #008000; font-weight: bold }
29 | #chatbot .kt { color: #B00040 }
30 | #chatbot .m { color: #666666 }
31 | #chatbot .s { color: #BA2121 }
32 | #chatbot .na { color: #7D9029 }
33 | #chatbot .nb { color: #008000 }
34 | #chatbot .nc { color: #0000FF; font-weight: bold }
35 | #chatbot .no { color: #880000 }
36 | #chatbot .nd { color: #AA22FF }
37 | #chatbot .ni { color: #999999; font-weight: bold }
38 | #chatbot .ne { color: #D2413A; font-weight: bold }
39 | #chatbot .nf { color: #0000FF }
40 | #chatbot .nl { color: #A0A000 }
41 | #chatbot .nn { color: #0000FF; font-weight: bold }
42 | #chatbot .nt { color: #008000; font-weight: bold }
43 | #chatbot .nv { color: #19177C }
44 | #chatbot .ow { color: #AA22FF; font-weight: bold }
45 | #chatbot .w { color: #bbbbbb }
46 | #chatbot .mb { color: #666666 }
47 | #chatbot .mf { color: #666666 }
48 | #chatbot .mh { color: #666666 }
49 | #chatbot .mi { color: #666666 }
50 | #chatbot .mo { color: #666666 }
51 | #chatbot .sa { color: #BA2121 }
52 | #chatbot .sb { color: #BA2121 }
53 | #chatbot .sc { color: #BA2121 }
54 | #chatbot .dl { color: #BA2121 }
55 | #chatbot .sd { color: #BA2121; font-style: italic }
56 | #chatbot .s2 { color: #BA2121 }
57 | #chatbot .se { color: #BB6622; font-weight: bold }
58 | #chatbot .sh { color: #BA2121 }
59 | #chatbot .si { color: #BB6688; font-weight: bold }
60 | #chatbot .sx { color: #008000 }
61 | #chatbot .sr { color: #BB6688 }
62 | #chatbot .s1 { color: #BA2121 }
63 | #chatbot .ss { color: #19177C }
64 | #chatbot .bp { color: #008000 }
65 | #chatbot .fm { color: #0000FF }
66 | #chatbot .vc { color: #19177C }
67 | #chatbot .vg { color: #19177C }
68 | #chatbot .vi { color: #19177C }
69 | #chatbot .vm { color: #19177C }
70 | #chatbot .il { color: #666666 }
71 | """)
72 | #.highlight { background: #f8f8f8; }
73 |
74 |
--------------------------------------------------------------------------------
/Owl/serve/model_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import torch
5 | import transformers
6 | import traceback
7 |
8 | from queue import Queue
9 | from threading import Thread
10 |
11 |
12 | def post_process_output(text):
13 | text = text.strip()
14 | pattern = re.compile(
15 |         r"<unk>|<pad>|<s>|</s>|\[PAD\]|<\|endoftext\|>|\[UNK\]|\[CLS\]|\[MASK\]|<\|startofpiece\|>|<\|endofpiece\|>|\[gMASK\]|\[sMASK\]"
16 | )
17 | text = pattern.sub("", text.strip()).strip()
18 | return text
19 |
20 |
21 | def post_process_code(code):
22 | sep = "\n```"
23 | if sep in code:
24 | blocks = code.split(sep)
25 | if len(blocks) % 2 == 1:
26 | for i in range(1, len(blocks), 2):
27 | blocks[i] = blocks[i].replace("\\_", "_")
28 | code = sep.join(blocks)
29 | return code
30 |
31 |
32 | class Stream(transformers.StoppingCriteria):
33 | def __init__(self, callback_func=None):
34 | self.callback_func = callback_func
35 |
36 | def __call__(self, input_ids, scores) -> bool:
37 | if self.callback_func is not None:
38 | self.callback_func(input_ids[0])
39 | return False
40 |
41 |
42 | class Iteratorize:
43 |
44 | """
45 | Transforms a function that takes a callback
46 | into a lazy iterator (generator).
47 | """
48 |
49 | def __init__(self, func, kwargs={}, callback=None):
50 | self.mfunc = func
51 | self.c_callback = callback
52 | self.q = Queue()
53 | self.sentinel = object()
54 | self.kwargs = kwargs
55 | self.stop_now = False
56 |
57 | def _callback(val):
58 | if self.stop_now:
59 | raise ValueError
60 | self.q.put(val)
61 |
62 | def gentask():
63 | try:
64 | ret = self.mfunc(callback=_callback, **self.kwargs)
65 |             except ValueError:
66 |                 ret = None  # generation was interrupted; still signal completion below
67 |             except Exception:
68 |                 traceback.print_exc()
69 |                 ret = None
70 |
71 | self.q.put(self.sentinel)
72 | if self.c_callback:
73 | self.c_callback(ret)
74 |
75 | self.thread = Thread(target=gentask)
76 | self.thread.start()
77 |
78 | def __iter__(self):
79 | return self
80 |
81 | def __next__(self):
82 | obj = self.q.get(True, None)
83 | if obj is self.sentinel:
84 | raise StopIteration
85 | else:
86 | return obj
87 |
88 | def __enter__(self):
89 | return self
90 |
91 | def __exit__(self, exc_type, exc_val, exc_tb):
92 | self.stop_now = True
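93 |
94 |
95 | # Usage sketch (illustrative, not part of the original file). `Iteratorize` turns
96 | # a callback-style function into a generator: every value handed to `callback`
97 | # is yielded, and iteration stops once the wrapped function returns.
98 | #
99 | # def fake_generate(callback=None, prompt=""):
100 | #     for token in ["Hello", ",", " ", "world"]:
101 | #         callback(token)
102 | #
103 | # with Iteratorize(fake_generate, kwargs={"prompt": "hi"}) as tokens:
104 | #     for tok in tokens:
105 | #         print(tok, end="")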
--------------------------------------------------------------------------------
/Owl/serve/serve_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import gradio as gr
4 | import logging
5 | import sys
6 | import os
7 | import json
8 | import requests
9 | from .conversation import default_conversation
10 | from .gradio_patch import Chatbot as grChatbot
11 | from .gradio_css import code_highlight_css
12 | import datetime
13 | import uuid
14 | import base64
15 | from io import BytesIO
16 | import time
17 |
18 | from .io_utils import IO, DefaultIO, OSS
19 |
20 |
21 | handler = None
22 |
23 |
24 | class _IOWrapper:
25 | def __init__(self):
26 | self._io = DefaultIO()
27 |
28 | def set_io(self, new_io):
29 | self._io = new_io
30 |
31 | def __getattr__(self, name):
32 | if hasattr(self._io, name):
33 | return getattr(self._io, name)
34 | return super().__getattr__(name)
35 |
36 | def __str__(self):
37 | return self._io.__name__
38 |
39 | def init():
40 | io = _IOWrapper()
41 | return io
42 |
43 |
44 | def vote_last_response(state, vote_type, model_selector, request: gr.Request):
45 | pass
46 |
47 | def upvote_last_response(state, model_selector, request: gr.Request):
48 | vote_last_response(state, "upvote", model_selector, request)
49 | return ("",) + (disable_btn,) * 3
50 |
51 | def downvote_last_response(state, model_selector, request: gr.Request):
52 | vote_last_response(state, "downvote", model_selector, request)
53 | return ("",) + (disable_btn,) * 3
54 |
55 | def flag_last_response(state, model_selector, request: gr.Request):
56 | vote_last_response(state, "flag", model_selector, request)
57 | return ("",) + (disable_btn,) * 3
58 |
59 | def regenerate(state, request: gr.Request):
60 | state.messages[-1][-1] = None
61 | state.skip_next = False
62 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
63 |
64 | def clear_history(request: gr.Request):
65 | state = default_conversation.copy()
66 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
67 |
68 |
69 | def add_text(state, text, image, video, request: gr.Request):
70 | if len(text) <= 0 and (image is None or video is None):
71 | state.skip_next = True
72 | return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
73 |
74 |     if image is not None:
75 |         if '<image>' not in text:
76 |             text = text + '\n<image>'
77 |         text = (text, image)
78 |
79 |     if video is not None:
80 |         num_frames = 4
81 |         if '<image>' not in text:
82 |             text = text + '\n<image>' * num_frames
83 |         text = (text, video)
84 |
85 | state.append_message(state.roles[0], text)
86 | state.append_message(state.roles[1], None)
87 | state.skip_next = False
88 | return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
89 |
90 | def after_process_image(prompt):
91 |     prompt = prompt.replace("\n<image>", "<image>")
92 |     pro_prompt = ""
93 |     prompt = prompt.split("\n")
94 |     for p in prompt:
95 |         if p.count("<image>") > 0:
96 |             pro_prompt += "Human: <image>\n"
97 |             if p != "<image>":
98 |                 pro_prompt += p.replace("<image>", "") + "\n"
99 | else:
100 | pro_prompt += p + "\n"
101 | return (pro_prompt[:-1]+" ").replace("\n Human", "\nHuman").replace("\n AI", "\nAI")
102 |
103 |
104 | headers = {"User-Agent": "mPLUG-Owl Client"}
105 |
106 | no_change_btn = gr.Button.update()
107 | enable_btn = gr.Button.update(interactive=True)
108 | disable_btn = gr.Button.update(interactive=False)
109 |
110 | get_window_url_params = """
111 | function() {
112 | const params = new URLSearchParams(window.location.search);
113 | url_params = Object.fromEntries(params);
114 | console.log(url_params);
115 | return url_params;
116 | }
117 | """
--------------------------------------------------------------------------------
/Owl/test-ant.py:
--------------------------------------------------------------------------------
1 | import torch
2 | print(isinstance('loss', (list, torch.Tensor)))
3 |
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/assets/method.png
--------------------------------------------------------------------------------
/common/gpt.py:
--------------------------------------------------------------------------------
1 | import openai
2 |
3 | openai.api_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
4 | proxy = "http://127.0.0.1:7890"
5 |
6 |
7 | def gpt_infer(messages: list, model="gpt-3.5-turbo"):
8 | completion = openai.ChatCompletion.create(model=model, messages=messages, proxy=proxy)
9 | return completion.choices[0].message.content # type: ignore
10 |
11 |
12 | if __name__ == "__main__":
13 | messages = [
14 |         {"role": "user", "content": "鲁迅和周树人的关系"},  # "The relationship between Lu Xun and Zhou Shuren"
15 | ]
16 | print(gpt_infer(messages))
17 |
--------------------------------------------------------------------------------
/configs/minigpt4_infer_fp16-dpo.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | has_qformer: False
11 |
12 | model_type: pretrain_llama2
13 | max_txt_len: 160
14 | end_sym: "</s>"
15 | low_resource: False
16 | prompt_template: "[INST] {} [/INST] "
17 | ckpt: "/home/ant/llm-hal/efuf/checkpoints/minigpt4_llama2_7b/pretrained.pth"
18 |
19 | # generation configs
20 | prompt: ""
21 |
22 | # llama_model: "/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf"
23 | llama_model: "/home/ant/llm-hal/HA-DPO/ha_dpo/models/minigpt4/merged_minigpt4_ha_dpo"
24 |
25 | preprocess:
26 | vis_processor:
27 | train:
28 | name: "blip2_image_train"
29 | image_size: 224
30 | eval:
31 | name: "blip2_image_eval"
32 | image_size: 224
33 | text_processor:
34 | train:
35 | name: "blip_caption"
36 | eval:
37 | name: "blip_caption"
38 |
39 | datasets:
40 | cc_sbu_align:
41 | vis_processor:
42 | train:
43 | name: "blip2_image_eval"
44 | image_size: 224
45 | text_processor:
46 | train:
47 | name: "blip_caption"
48 |
49 | run:
50 | task: image_text_pretrain
51 |
--------------------------------------------------------------------------------
/configs/minigpt4_infer_fp16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | has_qformer: False
11 |
12 | model_type: pretrain_llama2
13 | max_txt_len: 160
14 | end_sym: "</s>"
15 | low_resource: False
16 | prompt_template: "[INST] {} [/INST] "
17 | ckpt: "checkpoints/minigpt4_llama2_7b/pretrained.pth"
18 |
19 | # generation configs
20 | prompt: ""
21 |
22 | llama_model: "/home/nfs02/model/llama2/hf/Llama-2-7b-chat-hf/"
23 |
24 | preprocess:
25 | vis_processor:
26 | train:
27 | name: "blip2_image_train"
28 | image_size: 224
29 | eval:
30 | name: "blip2_image_eval"
31 | image_size: 224
32 | text_processor:
33 | train:
34 | name: "blip_caption"
35 | eval:
36 | name: "blip_caption"
37 |
38 | datasets:
39 | cc_sbu_align:
40 | vis_processor:
41 | train:
42 | name: "blip2_image_eval"
43 | image_size: 224
44 | text_processor:
45 | train:
46 | name: "blip_caption"
47 |
48 | run:
49 | task: image_text_pretrain
50 |
--------------------------------------------------------------------------------
/configs/minigpt4_infer_fp16_hadpo.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | has_qformer: False
11 |
12 | model_type: pretrain_llama2
13 | max_txt_len: 160
14 | end_sym: "</s>"
15 | low_resource: False
16 | prompt_template: "[INST] {} [/INST] "
17 | ckpt: "/home/ant/llm-hal/efuf/checkpoints/minigpt4_llama2_7b/pretrained.pth"
18 |
19 | # generation configs
20 | prompt: ""
21 |
22 | # llama_model: "/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf"
23 | llama_model: "/home/ant/llm-hal/HA-DPO/ha_dpo/models/minigpt4/merged_minigpt4_ha_dpo"
24 |
25 | preprocess:
26 | vis_processor:
27 | train:
28 | name: "blip2_image_train"
29 | image_size: 224
30 | eval:
31 | name: "blip2_image_eval"
32 | image_size: 224
33 | text_processor:
34 | train:
35 | name: "blip_caption"
36 | eval:
37 | name: "blip_caption"
38 |
39 | datasets:
40 | cc_sbu_align:
41 | vis_processor:
42 | train:
43 | name: "blip2_image_eval"
44 | image_size: 224
45 | text_processor:
46 | train:
47 | name: "blip_caption"
48 |
49 | run:
50 | task: image_text_pretrain
51 |
--------------------------------------------------------------------------------
/configs/minigpt4_train_fp16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | has_qformer: False
11 |
12 | model_type: pretrain_llama2
13 | max_txt_len: 160
14 | end_sym: "</s>"
15 | low_resource: False
16 | prompt_template: "{}"
17 | ckpt: "checkpoints/minigpt4_llama2_7b/pretrained.pth"
18 |
19 | # generation configs
20 | prompt: ""
21 |
22 | llama_model: "/home/nfs02/model/llama2/hf/Llama-2-7b-chat-hf/"
23 |
24 | preprocess:
25 | vis_processor:
26 | train:
27 | name: "blip2_image_train"
28 | image_size: 224
29 | eval:
30 | name: "blip2_image_eval"
31 | image_size: 224
32 | text_processor:
33 | train:
34 | name: "blip_caption"
35 | eval:
36 | name: "blip_caption"
37 |
38 | datasets:
39 | cc_sbu_align:
40 | vis_processor:
41 | train:
42 | name: "blip2_image_eval"
43 | image_size: 224
44 | text_processor:
45 | train:
46 | name: "blip_caption"
47 |
48 | run:
49 | task: image_text_pretrain
50 |
--------------------------------------------------------------------------------
/dataset/caption_prompt.txt:
--------------------------------------------------------------------------------
1 | Please extract all the concrete visual objects from an image caption. If there are brackets around a phrase, it means this is a concrete visual object pre-identified for you in advance, and you must output the phrase. Then you need to examine all noun phrases and infer whether they are concrete visual objects in the image. Your response should be part of the input. You **MUST** respond in the format of a non-repeating list "object_1, object_2, ..., object_n". Strictly follow the format and **do not output any other words**, including but not limited to other phrases or additional explanations.
2 | Example:
3 | {input_label} The image shows a train on the tracks, with cars parked on the platform. The train is painted in a blue and white color scheme, with [a large logo] on the front. The cars are painted in a red and white color scheme, with [a small logo] on the front. The train is moving slowly through the station, with [people] standing on the platform and watching. The sky is clear and blue, with a few clouds in the distance. The buildings in the background are tall and made of brick, with windows and doors on the upper floors. There are trees and bushes growing along the sides of the tracks, and [a small bridge] over the tracks in the distance.
4 | {output_label} a train, tracks, cars, platform, a large logo, a small logo, people, sky, clouds, buildings, windows, doors, trees, bushes, a small bridge
5 |
--------------------------------------------------------------------------------
/dataset/caption_prompt_finetune.txt:
--------------------------------------------------------------------------------
1 | Please extract all the concrete visual objects from an image caption. You need to examine all noun phrases and infer whether they are concrete visual objects in the image. Your response should be part of the input. You **MUST** respond in the format of a non-repeating list "object_1, object_2, ..., object_n". Strictly follow the format and **do not output any other words**, including but not limited to other phrases or additional explanations.
2 | Example:
3 | {input_label} The image shows a train on the tracks, with cars parked on the platform. The train is painted in a blue and white color scheme, with a large logo on the front. The cars are painted in a red and white color scheme, with a small logo on the front. The train is moving slowly through the station, with people standing on the platform and watching. The sky is clear and blue, with a few clouds in the distance. The buildings in the background are tall and made of brick, with windows and doors on the upper floors. There are trees and bushes growing along the sides of the tracks, and a small bridge over the tracks in the distance.
4 | {output_label} a train, tracks, cars, platform, a large logo, a small logo, people, sky, clouds, buildings, windows, doors, trees, bushes, a small bridge
5 |
--------------------------------------------------------------------------------
/dataset/synonyms.txt:
--------------------------------------------------------------------------------
1 | person, girl, boy, man, woman, kid, child, chef, baker, people, adult, rider, children, baby, worker, passenger, sister, brother, biker, policeman, cop, officer, lady, cowboy, bride, groom, male, female, guy, traveler, mother, father, gentleman, pitcher, player, skier, snowboarder, skater, skateboarder, guy, foreigner, child, gentleman, caller, offender, coworker, trespasser, patient, politician, soldier, grandchild, serviceman, walker, drinker, doctor, bicyclist, thief, buyer, teenager, student, camper, driver, solider, hunter, shopper, villager, pedestrian
2 | bicycle, bike, unicycle, minibike, trike
3 | car, automobile, van, minivan, sedan, suv, hatchback, cab, jeep, coupe, taxicab, limo, taxi
4 | motorcycle, scooter, motor bike, motor cycle, motorbike, scooter, moped
5 | airplane, jetliner, plane, air plane, monoplane, aircraft, jet, jetliner, airbus, biplane, seaplane
6 | bus, minibus, trolley
7 | train, locomotive, tramway, caboose
8 | truck, pickup, lorry, hauler, firetruck
9 | boat, ship, liner, sailboat, motorboat, dinghy, powerboat, speedboat, canoe, skiff, yacht, kayak, catamaran, pontoon, houseboat, vessel, rowboat, trawler, ferryboat, watercraft, tugboat, schooner, barge, ferry, sailboard, paddleboat, lifeboat, freighter, steamboat, riverboat, battleship, steamship
10 | traffic light, street light, traffic signal, stop light, streetlight, stoplight
11 | fire hydrant, hydrant
12 | stop sign
13 | parking meter
14 | bench, pew
15 | bird, ostrich, owl, seagull, goose, duck, parakeet, falcon, robin, pelican, waterfowl, heron, hummingbird, mallard, finch, pigeon, sparrow, seabird, osprey, blackbird, fowl, shorebird, woodpecker, egret, chickadee, quail, bluebird, kingfisher, buzzard, willet, gull, swan, bluejay, flamingo, cormorant, parrot, loon, gosling, waterbird, pheasant, rooster, sandpiper, crow, raven, turkey, oriole, cowbird, warbler, magpie, peacock, cockatiel, lorikeet, puffin, vulture, condor, macaw, peafowl, cockatoo, songbird
16 | cat, kitten, feline, tabby
17 | dog, puppy, beagle, pup, chihuahua, schnauzer, dachshund, rottweiler, canine, pitbull, collie, pug, terrier, poodle, labrador, doggie, doberman, mutt, doggy, spaniel, bulldog, sheepdog, weimaraner, corgi, cocker, greyhound, retriever, brindle, hound, whippet, husky
18 | horse, colt, pony, racehorse, stallion, equine, mare, foal, palomino, mustang, clydesdale, bronc, bronco
19 | sheep, lamb, ram, lamb, goat, ewe
20 | cow, cattle, oxen, ox, calf, cattle, holstein, heifer, buffalo, bull, zebu, bison
21 | elephant
22 | bear, panda
23 | zebra
24 | giraffe
25 | backpack, knapsack
26 | umbrella
27 | handbag, wallet, purse, briefcase
28 | tie, bow, bow tie
29 | suitcase, suit case, luggage
30 | frisbee
31 | skis, ski
32 | snowboard
33 | sports ball, ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard, longboard, skimboard, shortboard, wakeboard
39 | tennis racket, racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife, pocketknife, knive
45 | spoon
46 | bowl, container
47 | banana
48 | apple
49 | sandwich, burger, sub, cheeseburger, hamburger
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut, doughnut, bagel
56 | cake, cheesecake, cupcake, shortcake, coffeecake, pancake
57 | chair, seat, stool
58 | couch, sofa, recliner, futon, loveseat, settee, chesterfield
59 | potted plant, houseplant
60 | bed
61 | dining table, table, desk, coffee table
62 | toilet, urinal, commode, toilet, lavatory, potty
63 | tv, monitor, televison, television
64 | laptop, computer, notebook, netbook, lenovo, macbook, laptop computer
65 | mouse
66 | remote, remote control
67 | keyboard
68 | cell phone, mobile phone, phone, cellphone, telephone, phon, smartphone, iPhone
69 | microwave
70 | oven, stovetop, stove, stove top oven
71 | toaster
72 | sink
73 | refrigerator, fridge, fridge, freezer
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear, teddybear
79 | hair drier, hairdryer
80 | toothbrush
81 |
--------------------------------------------------------------------------------
/dataset/vqa_prompt.txt:
--------------------------------------------------------------------------------
1 | Please extract all the visual objects from a piece of text which is an answer to a Visual Question Answering task. You should carefully infer whether each object indeed appears in the image, and avoid those objects that are not depicted in the image. Only extract objects that you are certain appear in the image according to the text, and do not invent objects that are not mentioned. If you are not sure whether an object meets the above condition, do not add it to the result. Respond in the format of a non-repeating list "object_1, object_2, ..., object_n". If there are none, do not respond with anything. Strictly follow the format and **do not output any other words**, including but not limited to other non-visual objects or explanations.
2 |
3 | Example:
4 | Input: If someone wants to try a similar water sport, such as surfing, one important factor to consider is the need for proper training and skill development. In the image, the surfer is seen riding a wave on his surfboard, showcasing balance and technique. Learning the fundamentals of surfing, including wave selection, paddling, standing up on the board, and maneuvering, can significantly increase the chances of a successful and enjoyable experience. Additionally, it's essential to use appropriate safety gear such as a leash and a wetsuit, as well as follow local surfing etiquette and regulations to ensure the safety of oneself and others in the water.
5 | Output: surfer, wave, surfboard
6 |
7 | Task:
8 | Input: {}
9 | Output:
--------------------------------------------------------------------------------
/deploy/dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-devel
2 |
3 | # Copy all files to the /build directory
4 | COPY . /build
5 | WORKDIR /build
6 | RUN chmod -R u+rwX,go+rX,go-w /build
7 |
8 | # Update sources list and install necessary packages
9 | RUN cp sources.list /etc/apt/sources.list && apt-get update && apt-get install -y python3-pip git libgl1
10 |
11 | # Install Python dependencies
12 | RUN pip install -r requirements.txt -i https://mirrors.nju.edu.cn/pypi/web/simple/
13 |
14 | # Install flash attention
15 | RUN if [ -e "flash_attn-2.5.5+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl" ]; then \
16 | pip install flash_attn-2.5.5+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; \
17 | else \
18 | pip install flash_attn==2.5.5 -i https://mirrors.nju.edu.cn/pypi/web/simple/; \
19 | fi
20 |
21 | # Assign workdir
22 | WORKDIR /workspace
23 |
24 | # it should be built:
25 | # cd deploy
26 | # docker build -t efuf:1.0 .
27 |
28 | # then run:
29 | # cd ..
30 | # docker run --gpus all --ipc=host --network=host --rm -it -v .:/workspace efuf:1.0
--------------------------------------------------------------------------------
/deploy/requirements.txt:
--------------------------------------------------------------------------------
1 | openai==0.28.0
2 | tqdm
3 | transformers==4.31.0
4 | git+https://github.com/MaureenZOU/detectron2-xyz.git
5 | einops
6 | einops-exts
7 | wandb
8 | nltk
9 | pandas
10 | spacy
11 | torchmetrics
12 | altair
13 | whisper
14 | gpustat
15 | timm
16 | opencv-python-headless
17 | webdataset
18 | visual_genome
19 | scikit-image
20 | visual_genome
21 | decord
22 | peft
23 | sentence-transformers
24 | gradio
25 | fairscale
26 | easydict
27 | pycocoevalcap
28 | moviepy
29 | pyyaml_env_tag
30 | open3d
31 | h5py
32 | diffusers==0.15.0
33 | seaborn
--------------------------------------------------------------------------------
/deploy/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker run --gpus all --network=host -it \
3 | -v .:/workspace/hal \
4 | -v /data/cache/huggingface:/root/.cache/huggingface \
5 | -v /data/cache/torch:/root/.cache/torch \
6 | -v /root/nltk_data:/root/nltk_data \
7 | mm_hal:1.2
--------------------------------------------------------------------------------
/deploy/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker run --gpus all --ipc=host --network=host --rm -it \
3 | -v .:/workspace/hal \
4 | -v /data/cache/huggingface:/root/.cache/huggingface \
5 | -v /data/cache/torch:/root/.cache/torch \
6 | -v /root/nltk_data:/root/nltk_data \
7 | -v /data/NJU/lib/zf/LLaVA:/workspace/hal/LLaVA \
8 | mm_hal:1.3
9 |
--------------------------------------------------------------------------------
/deploy/sources.list:
--------------------------------------------------------------------------------
1 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse
2 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse
3 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse
4 | deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse
--------------------------------------------------------------------------------
/evaluate/eval_caption.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2023-11-12 09:22:40
3 | # @Author : Shangyu.Xing (starreeze@foxmail.com)
4 |
5 | from __future__ import annotations
6 | import os, sys, torch
7 |
8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
9 | sys.path.append(os.path.dirname(SCRIPT_DIR))
10 | from io import TextIOWrapper
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from torch.utils.data import DataLoader, Dataset
14 | from common.models import model_loaders, generators
15 | from common.args import args
16 |
17 |
18 | class CocoImageDataset(Dataset):
19 | def __init__(self, image_names: list[str], processor):
20 | super(CocoImageDataset, self).__init__()
21 | self.image_names = image_names
22 | self.processor = processor
23 |
24 | def __len__(self):
25 | return len(self.image_names)
26 |
27 | def __getitem__(self, index):
28 | image_path = os.path.join(args.image_dir_path, self.image_names[index])
29 | img = self.processor(Image.open(image_path).convert("RGB"))
30 | if args.device != "cpu":
31 | img = img.to(torch.float16).to(args.device)
32 | return img, self.image_names[index]
33 |
34 |
35 | def process_single(batch, model, prompt: str, output_fd: TextIOWrapper):
36 | images, image_names = batch
37 | texts = [prompt] * args.infer_bs_total
38 | results = [""] * args.infer_bs_total
39 | filtered = []
40 | for _ in range(args.infer_retry):
41 | with torch.no_grad():
42 | answers = generators[args.model](model, texts, images)
43 | for i, (name, answer) in enumerate(zip(image_names, answers)):
44 | if answer.replace("\n", "") and not results[i]:
45 | results[i] = name + " ### " + answer.replace("\n", " ")
46 | filtered = [r for r in results if r]
47 | if len(filtered) == len(results):
48 | break
49 | output_fd.write("\n".join(filtered) + "\n")
50 | output_fd.flush()
51 |
52 |
53 | def main():
54 | model_load_path = getattr(args, f"{args.model}_ckpt_load_path")
55 | model_args = ["--cfg-path", args.minigpt_train_cfg] if args.model == "minigpt" else []
56 | model, vis_processor = model_loaders[args.model](model_load_path, args.device, False, model_args)
57 | with open(args.object_data_path, "r") as f:
58 | objects = f.read().splitlines()
59 | images_used = {obj.split(args.column_splitter)[0] for obj in objects}
60 | image_names = filter(lambda x: x not in images_used, sorted(os.listdir(args.image_dir_path)))
61 | eval_end_pos = args.default_eval_samples if args.end_pos == int(1e10) else args.end_pos
62 | image_names = list(image_names)[args.start_pos : eval_end_pos]
63 | dataloader = DataLoader(
64 | CocoImageDataset(image_names, vis_processor),
65 | args.infer_bs_total,
66 | False,
67 | num_workers=args.infer_dataloader_worker,
68 | )
69 |
70 | prompt = getattr(args, f"{args.model}_eval_prompt")
71 | with open(args.caption_eval_path, "w" if args.restart else "a", encoding="utf-8") as f:
72 | for batch in tqdm(dataloader):
73 | process_single(batch, model, prompt, output_fd=f)
74 |
75 |
76 | if __name__ == "__main__":
77 | main()
78 |
--------------------------------------------------------------------------------
/evaluate/eval_gqa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2024-06-11 15:15:36
3 | # @Author : Shangyu.Xing (starreeze@foxmail.com)
4 |
5 | from __future__ import annotations
6 | import os, sys, json, torch
7 |
8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
9 | sys.path.append(os.path.dirname(SCRIPT_DIR))
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils.data import DataLoader, Dataset
13 | from torch.cuda.amp import autocast # type: ignore
14 | from common.args import args
15 | from common.models import model_loaders, model_forward, generators, data_maps, sample_collators
16 | from common.utils import to_device
17 |
18 | basename = getattr(args, f"{args.model}_ckpt_load_path").split("/")[-1][:10]
19 | pred_path = os.path.join(args.gqa_data_path, "testdev_balanced_predictions.json")
20 | train_question_path = os.path.join(args.gqa_data_path, "questions1.2", "val_balanced_questions.json")
21 | eval_question_path = os.path.join(args.gqa_data_path, "questions1.2", "testdev_balanced_questions.json")
22 | image_dir = os.path.join(args.gqa_data_path, "images")
23 | model_dtype = {"llava": torch.bfloat16, "llavarlhf": torch.bfloat16, "share4v": torch.bfloat16}
24 |
25 |
26 | class GQAData(Dataset):
27 | def __init__(self, processor, start=0, end=int(1e9)):
28 | super().__init__()
29 | self.processor = processor
30 | with open(eval_question_path, "r") as f:
31 | questions: dict = json.load(f)
32 | self.data = list(questions.items())[start:end]
33 |
34 | def __len__(self):
35 | return len(self.data)
36 |
37 | def __getitem__(self, index: int):
38 | id, question = self.data[index]
39 | image_name = question["imageId"] + ".jpg"
40 | image_path = os.path.join(image_dir, image_name)
41 | image = Image.open(image_path).convert("RGB")
42 | answer = question["answer"]
43 | text = getattr(args, f"{args.model}_eval_vqa_prompt").format(question=question["question"])
44 | return self.processor(image), id, text, answer
45 |
46 | def save(self):
47 | os.rename(eval_question_path, eval_question_path + ".bak")
48 | with open(eval_question_path, "w") as f:
49 | json.dump({k: v for k, v in self.data}, f)
50 | return self
51 |
52 | @staticmethod
53 | def restore():
54 | os.rename(eval_question_path + ".bak", eval_question_path)
55 |
56 |
57 | def inference(model, vis_processor, start, end):
58 | eval_data = GQAData(vis_processor, start, end)
59 | eval_loader = DataLoader(eval_data, args.infer_bs_total, False)
60 | model.eval()
61 | model.to(torch.float16)
62 | correct, total = 0, 0
63 | with torch.no_grad():
64 | results = []
65 | for batch in tqdm(eval_loader):
66 | image, question_id, question, answer = to_device(batch) # type: ignore
67 | with autocast(dtype=torch.float16):
68 | responses = generators[args.model](model, question, image)
69 | for q, response, ans in zip(question_id, responses, answer):
70 | results.append({"questionId": q, "prediction": response.rstrip(".").lower()})
71 | if ans.lower() in response.lower(): # type: ignore
72 | correct += 1
73 | total += len(question_id)
74 | with open(pred_path, "w") as f:
75 | json.dump(results, f)
76 | print(f"correct: {correct}, total: {total}, acc: {correct / total}")
77 |
78 |
79 | def eval():
80 | model, vis_processor = model_loaders[args.model](getattr(args, f"{args.model}_ckpt_load_path"))
81 | try:
82 | model.to(model_dtype[args.model])
83 | except KeyError:
84 | pass
85 | inference(model, vis_processor, start=0, end=args.default_eval_samples)
86 |
87 |
88 | if __name__ == "__main__":
89 | eval()
90 |
--------------------------------------------------------------------------------
/evaluate/eval_mme.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2024-04-02 18:36:40
3 | # @Author : Shangyu.Xing (starreeze@foxmail.com)
4 |
5 | from __future__ import annotations
6 | import os, sys, torch
7 |
8 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
9 | sys.path.append(os.path.dirname(SCRIPT_DIR))
10 | from glob import glob
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from torch.utils.data import DataLoader, Dataset
14 | from common.args import args
15 | from common.models import model_loaders, generators
16 | from common.utils import to_device
17 |
18 | basename = getattr(args, f"{args.model}_ckpt_load_path").split("/")[-1]
19 | pred_name = f"{args.run_name}_" if args.run_name else ""
20 | pred_path = os.path.join(args.mme_result_path, f"{pred_name}{args.model}_{basename}")
21 | os.makedirs(pred_path, exist_ok=True)
22 | # model_dtype = {"llava": torch.bfloat16, "llavarlhf": torch.bfloat16, "share4v": torch.bfloat16}
23 | model_dtype = {"llavarlhf": torch.bfloat16}
24 |
25 |
26 | class MMEData(Dataset):
27 | def __init__(self, processor, category_path: str):
28 | super().__init__()
29 | self.processor = processor
30 |
31 | if os.path.exists(os.path.join(category_path, "images")):
32 | image_path = os.path.join(category_path, "images")
33 | qa_path = os.path.join(category_path, "questions_answers_YN")
34 | else:
35 | image_path = qa_path = category_path
36 | assert os.path.isdir(image_path), image_path
37 | assert os.path.isdir(qa_path), qa_path
38 |
39 | self.data = []
40 | for file in os.listdir(qa_path):
41 | if not file.endswith(".txt"):
42 | continue
43 | for line in open(os.path.join(qa_path, file), encoding="utf-8"):
44 | question, answer = line.strip().split("\t")
45 | image_globs = glob(os.path.join(image_path, file.split(".")[0] + ".*"))
46 | image = list(filter(lambda x: not x.endswith(".txt"), image_globs))
47 | if image:
48 | self.data.append((image[0], question, answer))
49 | else:
50 | tqdm.write("No image found for " + file)
51 |
52 | def __len__(self):
53 | return len(self.data)
54 |
55 | def __getitem__(self, index: int):
56 | sample = self.data[index]
57 | image_path, question, answer = sample
58 | image = Image.open(image_path).convert("RGB")
59 | question = question.replace(" Please answer yes or no.", "")
60 | text = getattr(args, f"{args.model}_eval_vqa_prompt").format(question=question)
61 | return self.processor(image), text, "\t".join(sample)
62 |
63 |
64 | def inference(model, vis_processor, category_dir):
65 | eval_data = MMEData(vis_processor, os.path.join(args.mme_data_path, category_dir))
66 | eval_loader = DataLoader(eval_data, args.infer_bs_total, False)
67 | model.eval()
68 | with torch.no_grad():
69 | results = []
70 | for batch in tqdm(eval_loader, position=1):
71 | images, questions, lines = to_device(batch, model_dtype.get(args.model, torch.float16)) # type: ignore
72 | # with autocast(dtype=torch.float16):
73 | answers = generators[args.model](model, questions, images)
74 | for line, answer in zip(lines, answers):
75 | results.append(line + "\t" + answer.replace("\n", " "))
76 | with open(os.path.join(pred_path, category_dir + ".txt"), "w", encoding="utf-8") as f:
77 | f.write("\n".join(results))
78 |
79 |
80 | def eval():
81 | model, vis_processor = model_loaders[args.model](getattr(args, f"{args.model}_ckpt_load_path"))
82 | try:
83 | model.to(model_dtype[args.model])
84 | except KeyError:
85 | pass
86 | for category_dir in tqdm(os.listdir(args.mme_data_path), position=0):
87 | if os.path.isdir(os.path.join(args.mme_data_path, category_dir)):
88 | inference(model, vis_processor, category_dir)
89 | eval_cmd = f"python evaluate/mme/calculation.py --results_dir {pred_path}"
90 | print(f"Inference done. running `{eval_cmd}`")
91 | os.system(eval_cmd)
92 |
93 |
94 | if __name__ == "__main__":
95 | eval()
96 |
--------------------------------------------------------------------------------
/evaluate/eval_utils.py:
--------------------------------------------------------------------------------
1 | from common.args import args
2 |
3 |
4 | def get_eval_caption():
5 | with open(args.caption_eval_path, "r") as f:
6 | content: list[str] = f.read().splitlines()
7 | image_ids, captions = [], []
8 | eval_end_pos = args.default_eval_samples if args.end_pos == int(1e10) else args.end_pos
9 | for line in content[args.start_pos : eval_end_pos]:
10 | try:
11 | image_name, caption = line.replace("### gpt: ", "").split("###")[:2]
12 | except ValueError as e:
13 | print(f"Skipping line {line} due to {e}")
14 | continue
15 | image_ids.append(int(image_name.split("_")[-1].split(".")[0]))
16 | captions.append(caption)
17 | return image_ids, captions
18 |
--------------------------------------------------------------------------------
/finetune/format_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2023-11-29 18:07:24
3 | # @Author : Shangyu.Xing (starreeze@foxmail.com)
4 | """
5 | from captions, objects and scores format a flattened json file,
6 | where each object has an entry containing fields: sentence, sub-sentence/object position, score
7 |
8 | normal text ... hallucination object/subsentence
9 | ^
10 | |
11 | sub-sentence/object position
12 |
13 | sentence: type=str, stripped sentence until the target
14 | sub-sentence/object mask: type=int, beginning position (char-level) of the unlearn target
15 | score: float, clip score of the object
16 | """
17 |
18 | from __future__ import annotations
19 | import sys, os, json, re
20 |
21 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
22 | sys.path.append(os.path.dirname(SCRIPT_DIR))
23 | from tqdm import tqdm
24 | import numpy as np
25 | from common.args import args
26 |
27 |
28 | def find_subsentence(sentence: str, target_word: str):
29 | sub_sentences = re.split(f"[{args.subsentence_splitter_set}]+", sentence)
30 | for sub_sentence in sub_sentences:
31 | if target_word in sub_sentence:
32 | start_position = sentence.find(sub_sentence)
33 | end_position = start_position + len(sub_sentence)
34 | return start_position, end_position
35 | tqdm.write(f"'{target_word}' not found in '{sentence}', skipping...")
36 | return None
37 |
38 |
39 | def process_pos_neg(image: str, caption: str, objects: str, scores: np.ndarray) -> list[dict[str, str | int]] | None:
40 | results = []
41 | for object, score in zip([obj.strip("[]") for obj in objects.split(args.object_splitter)], scores):
42 | # XXX which to unlearn? objects, subsentence or else?
43 | if args.unlearn_target == "objects":
44 | index = caption.find(object)
45 | assert index != -1
46 | # XXX should we keep the former hallucination objects in the training sample?
47 | results.append(
48 | {
49 | "image": image,
50 | "sentence": caption[: index + len(object)],
51 | "position": index,
52 | "score": float(score),
53 | }
54 | )
55 | elif args.unlearn_target == "subsentence":
56 | result = find_subsentence(caption, object)
57 | if result is not None:
58 | start, end = result
59 | results.append({"image": image, "sentence": caption[:end], "position": start, "score": float(score)})
60 | else:
61 | raise NotImplementedError()
62 | return results
63 |
64 |
65 | def main():
66 | with open(args.caption_data_path, "r") as f:
67 | captions = f.read().splitlines()
68 | captions_d = {}
69 | for sample in captions:
70 | name, caption = sample.split(args.column_splitter)
71 | captions_d[name] = caption
72 | with open(args.object_data_path, "r") as f:
73 | objects = f.read().splitlines()
74 | # as new generated data has no [], it is regarded as norm
75 | scores = np.load(args.norm_result_path, allow_pickle=True)
76 | assert len(objects) == len(scores)
77 | pos_neg, sentence = [], []
78 | for sample in tqdm(zip(objects, scores), total=len(objects)):
79 | object, score = sample
80 | image_name, _, object = object.split(args.column_splitter)
81 | if len(object.split(args.object_splitter)) != score.shape[0] or score.shape[0] == 0:
82 | tqdm.write("objects and scores not match or empty objects! skipping...")
83 | continue
84 | caption = captions_d[image_name]
85 | result = process_pos_neg(image_name, caption, object, score)
86 | if result is not None:
87 | pos_neg.extend(result)
88 | sentence.append(
89 | {"image": image_name, "sentence": caption, "mean": float(score.mean()), "min": float(score.min())}
90 | )
91 | with open(args.pos_neg_data_path, "w") as f:
92 | json.dump(pos_neg, f, indent=2)
93 | with open(args.sentence_data_path, "w") as f:
94 | json.dump(sentence, f, indent=2)
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
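100 | # Illustrative sketch (not part of the original script) of one flattened entry,
101 | # assuming unlearn_target == "objects", a hypothetical image/caption pair and a
102 | # CLIP score of 0.12 for the hallucinated object "a sofa":
103 | #   caption = "A dog is sleeping on a sofa."
104 | #   entry   = {"image": "some_image.jpg",
105 | #              "sentence": "A dog is sleeping on a sofa",
106 | #              "position": 21,
107 | #              "score": 0.12}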
--------------------------------------------------------------------------------
/insight/plot_time.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2024-02-09 15:49:44
3 | # @Author : Shangyu.Xing (starreeze@foxmail.com)
4 |
5 | from __future__ import annotations
6 | from matplotlib import pyplot as plt
7 |
8 |
9 | def main():
10 | labels = ["RLHF", "DPO", "CL", "EFUF"]
11 | values = [20, 12, 10, 3]
12 | colors = ["#45B39D"] * (len(values) - 1) + ["#F5B041"]
13 |
14 | # Create the bar chart
15 | plt.figure(figsize=(12, 9))
16 | plt.subplots_adjust(left=0.15, right=0.95, top=0.92, bottom=0.12)
17 | plt.bar(labels, values, color=colors, width=0.6) # type: ignore
18 |
19 | plt.plot(labels, values, color="darkred", marker="o", linestyle="-", linewidth=3, markersize=12)
20 | for label, value in enumerate(values):
21 | plt.text(label, value + 0.7, str(value), ha="center", fontsize=26)
22 |
23 | # Adding the title and labels
24 | # plt.xlabel("Method", fontsize=24)
25 | plt.ylabel("A100 GPU hours", fontsize=30)
26 | plt.xticks(fontsize=30)
27 | plt.yticks(fontsize=30)
28 | plt.ylim(top=22.5)
29 |
30 | # Show the plot
31 | plt.show()
32 |
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
3 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4 | NUM_SENTINEL_TOKENS: int = 100
5 |
6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7 | """Adds sentinel tokens and padding token (if missing).
8 |
9 | Expands the tokenizer vocabulary to include sentinel tokens
10 | used in mixture-of-denoiser tasks as well as a padding token.
11 |
12 | All added tokens are added as special tokens. No tokens are
13 | added if sentinel tokens and padding token already exist.
14 | """
15 |     sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16 |     tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17 |     if tokenizer.pad_token is None:
18 |         tokenizer.add_tokens('<pad>', special_tokens=True)
19 |         tokenizer.pad_token = '<pad>'
20 |         assert tokenizer.pad_token_id is not None
21 |     sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23 | tokenizer.sentinel_token_ids = _sentinel_token_ids
24 |
25 | class AutoTokenizerForMOD(AutoTokenizer):
26 | """AutoTokenizer + Adaptation for MOD.
27 |
28 | A simple wrapper around AutoTokenizer to make instantiating
29 | an MOD-adapted tokenizer a bit easier.
30 |
31 |     MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32 | a padding token, and a property to get the token ids of the
33 | sentinel tokens.
34 | """
35 |
36 | @classmethod
37 | def from_pretrained(cls, *args, **kwargs):
38 | """See `AutoTokenizer.from_pretrained` docstring."""
39 | tokenizer = super().from_pretrained(*args, **kwargs)
40 | adapt_tokenizer_for_denoising(tokenizer)
41 | return tokenizer
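42 |
43 | # Usage sketch (illustrative, not part of the original file): any tokenizer loaded
44 | # through AutoTokenizerForMOD comes back with the 100 <extra_id_*> sentinel tokens
45 | # registered and a <pad> token set when one was missing; the model name below is
46 | # only an example.
47 | #
48 | # tok = AutoTokenizerForMOD.from_pretrained('EleutherAI/gpt-neox-20b')
49 | # assert tok.pad_token is not None
50 | # print(tok.sentinel_token_ids[:3])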
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, attn_weights, past_key_value)
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/custom_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 | class SharedEmbedding(nn.Embedding):
7 |
8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9 | if unembed:
10 | return F.linear(input, self.weight)
11 | return super().forward(input)
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import torch
3 | import torch.nn as nn
4 |
5 | @contextmanager
6 | def init_empty_weights(include_buffers: bool=False):
7 | """Meta initialization context manager.
8 |
9 | A context manager under which models are initialized with all parameters
10 | on the meta device, therefore creating an empty model. Useful when just
11 | initializing the model would blow the available RAM.
12 |
13 | Args:
14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15 | not to also put all buffers on the meta device while initializing.
16 |
17 | Example:
18 | ```python
19 | import torch.nn as nn
20 |
21 | # Initialize a model with 100 billion parameters in no time and without using any RAM.
22 | with init_empty_weights():
23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24 | ```
25 |
26 |
27 |
28 | Any model created under this context manager has no weights. As such you can't do something like
29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30 |
31 |
32 | """
33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34 | yield f
35 |
36 | @contextmanager
37 | def init_on_device(device: torch.device, include_buffers: bool=False):
38 | """Device initialization context manager.
39 |
40 | A context manager under which models are initialized with all parameters
41 | on the specified device.
42 |
43 | Args:
44 | device (`torch.device`): Device to initialize all parameters on.
45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46 | not to also put all buffers on the meta device while initializing.
47 |
48 | Example:
49 | ```python
50 | import torch.nn as nn
51 |
52 | with init_on_device(device=torch.device("cuda")):
53 | tst = nn.Linear(100, 100)  # on `cuda` device
54 | ```
55 | """
56 | old_register_parameter = nn.Module.register_parameter
57 | if include_buffers:
58 | old_register_buffer = nn.Module.register_buffer
59 |
60 | def register_empty_parameter(module, name, param):
61 | old_register_parameter(module, name, param)
62 | if param is not None:
63 | param_cls = type(module._parameters[name])
64 | kwargs = module._parameters[name].__dict__
65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66 |
67 | def register_empty_buffer(module, name, buffer):
68 | old_register_buffer(module, name, buffer)
69 | if buffer is not None:
70 | module._buffers[name] = module._buffers[name].to(device)
71 | if include_buffers:
72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73 | else:
74 | tensor_constructors_to_patch = {}
75 |
76 | def patch_tensor_constructor(fn):
77 |
78 | def wrapper(*args, **kwargs):
79 | kwargs['device'] = device
80 | return fn(*args, **kwargs)
81 | return wrapper
82 | try:
83 | nn.Module.register_parameter = register_empty_parameter
84 | if include_buffers:
85 | nn.Module.register_buffer = register_empty_buffer
86 | for torch_function_name in tensor_constructors_to_patch.keys():
87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88 | yield
89 | finally:
90 | nn.Module.register_parameter = old_register_parameter
91 | if include_buffers:
92 | nn.Module.register_buffer = old_register_buffer
93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94 | setattr(torch, torch_function_name, old_torch_function)
--------------------------------------------------------------------------------
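The context manager works by temporarily patching `nn.Module.register_parameter` (and, optionally, `register_buffer` plus the `torch.empty/zeros/ones/full` constructors) so that everything a model allocates during `__init__` lands on the target device. A quick check of the behavior; the import assumes the repo root is on `PYTHONPATH`:

```python
import torch
import torch.nn as nn

# Assumption: the repo root is importable so the mpt helpers resolve.
from llava.model.language_model.mpt.meta_init_context import init_empty_weights

with init_empty_weights():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# Parameters exist structurally but hold no storage: they live on the meta device.
print(all(p.device.type == "meta" for p in model.parameters()))  # True
# model(torch.randn(1, 1024)) would fail here; real weights must be loaded first.
```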
/llava/model/language_model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
--------------------------------------------------------------------------------
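RMSNorm differs from LayerNorm in that it only rescales by the root-mean-square of the features (no mean subtraction and no bias), and the low-precision variants simply downcast inputs and weights under autocast before running the math with autocast disabled. A worked check of the `rms_norm` formula against a manual computation (standalone; it re-implements the formula rather than importing the repo module):

```python
import torch

def rms_norm(x, weight=None, eps=1e-5):
    # x / sqrt(mean(x^2) + eps), optionally scaled by a learned weight
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return out * weight if weight is not None else out

x = torch.randn(2, 4, 8)
w = torch.ones(8)
manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
print(torch.allclose(rms_norm(x, w), manual, atol=1e-6))  # True
```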
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
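make_delta stores target − base for every tensor the two models share (padding-aware for the embedding and LM-head rows, which may have grown), so that base + delta reproduces the fine-tuned weights without redistributing them directly. A toy round trip of that arithmetic; the tensor names and shapes here are made up purely for illustration:

```python
import torch

base = {"w": torch.randn(4, 4), "embed": torch.randn(10, 4)}
# Pretend the target grew its embedding table (e.g. extra special tokens).
target = {"w": base["w"] + 0.1,
          "embed": torch.cat([base["embed"] + 0.1, torch.randn(2, 4)])}

# --- make delta: subtract the base from the overlapping region ---
delta = {k: v.clone() for k, v in target.items()}
delta["w"] -= base["w"]
b = base["embed"]
delta["embed"][: b.shape[0], : b.shape[1]] -= b

# --- apply delta: add the base back to recover the fine-tuned weights ---
recovered = {k: v.clone() for k, v in delta.items()}
recovered["w"] += base["w"]
recovered["embed"][: b.shape[0], : b.shape[1]] += b

print(all(torch.allclose(recovered[k], target[k]) for k in target))  # True
```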
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | else:
20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
21 |
22 | def load_model(self):
23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
25 | self.vision_tower.requires_grad_(False)
26 |
27 | self.is_loaded = True
28 |
29 | def feature_select(self, image_forward_outs):
30 | image_features = image_forward_outs.hidden_states[self.select_layer]
31 | if self.select_feature == 'patch':
32 | image_features = image_features[:, 1:]
33 | elif self.select_feature == 'cls_patch':
34 | image_features = image_features
35 | else:
36 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
37 | return image_features
38 |
39 | @torch.no_grad()
40 | def forward(self, images):
41 | if type(images) is list:
42 | image_features = []
43 | for image in images:
44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
46 | image_features.append(image_feature)
47 | else:
48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
50 |
51 | return image_features
52 |
53 | @property
54 | def dummy_feature(self):
55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56 |
57 | @property
58 | def dtype(self):
59 | return self.vision_tower.dtype
60 |
61 | @property
62 | def device(self):
63 | return self.vision_tower.device
64 |
65 | @property
66 | def config(self):
67 | if self.is_loaded:
68 | return self.vision_tower.config
69 | else:
70 | return self.cfg_only
71 |
72 | @property
73 | def hidden_size(self):
74 | return self.config.hidden_size
75 |
76 | @property
77 | def num_patches(self):
78 | return (self.config.image_size // self.config.patch_size) ** 2
79 |
--------------------------------------------------------------------------------
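CLIPVisionTower takes a single hidden-state layer from the CLIP vision model (`mm_vision_select_layer`, commonly a negative index such as -2) and, in `'patch'` mode, drops the leading CLS token so only patch embeddings reach the projector. The patch count follows from the image and patch sizes; for the commonly used `openai/clip-vit-large-patch14-336` checkpoint that is (336 / 14)² = 576. A shape-only sketch of `feature_select` on a dummy tensor, with no CLIP weights needed:

```python
import torch

hidden_size, image_size, patch_size = 1024, 336, 14
num_patches = (image_size // patch_size) ** 2          # 576

# Dummy "hidden state": batch of 2, CLS token + patches, hidden dim.
layer_output = torch.randn(2, 1 + num_patches, hidden_size)

select_feature = "patch"
features = layer_output[:, 1:] if select_feature == "patch" else layer_output
print(features.shape)  # torch.Size([2, 576, 1024]) -- what the projector receives
```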
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
--------------------------------------------------------------------------------
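build_vision_projector parses the projector type with a regex, so a config value of `mlp2x_gelu` expands to Linear → GELU → Linear. A sketch driving the same expansion with a stand-in config object (`SimpleNamespace` here is only a placeholder for the real model config):

```python
import re
from types import SimpleNamespace
import torch.nn as nn

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)

match = re.match(r"^mlp(\d+)x_gelu$", cfg.mm_projector_type)
depth = int(match.group(1))                                   # 2
layers = [nn.Linear(cfg.mm_hidden_size, cfg.hidden_size)]
for _ in range(1, depth):
    layers += [nn.GELU(), nn.Linear(cfg.hidden_size, cfg.hidden_size)]
projector = nn.Sequential(*layers)
print(projector)  # Linear(1024 -> 4096), GELU, Linear(4096 -> 4096)
```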
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from llava.train.llama_xformers_attn_monkey_patch import (
5 | replace_llama_attn_with_xformers_attn,
6 | )
7 |
8 | replace_llama_attn_with_xformers_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/minigpt4/__init__.py:
--------------------------------------------------------------------------------
1 | from .common.registry import Registry
2 |
3 | Registry.register_path("library_root", "minigpt4")
4 |
--------------------------------------------------------------------------------
/minigpt4/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/common/__init__.py
--------------------------------------------------------------------------------
/minigpt4/common/eval_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from nltk.translate.bleu_score import sentence_bleu
4 |
5 | from minigpt4.common.registry import registry
6 | from minigpt4.common.config import Config
7 |
8 | # imports modules for registration
9 | from minigpt4.datasets.builders import *
10 | from minigpt4.models import *
11 | from minigpt4.processors import *
12 | from minigpt4.runners import *
13 | from minigpt4.tasks import *
14 |
15 |
16 | def eval_parser():
17 | parser = argparse.ArgumentParser(description="Demo")
18 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
19 | parser.add_argument("--name", type=str, default="A2", help="evaluation name")
20 | parser.add_argument("--ckpt", type=str, help="path to configuration file.")
21 | parser.add_argument("--eval_opt", type=str, default="all", help="path to configuration file.")
22 | parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
23 | parser.add_argument("--batch_size", type=int, default=32)
24 | parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
25 | parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
26 | parser.add_argument(
27 | "--options",
28 | nargs="+",
29 | help="override some settings in the used config, the key-value pair "
30 | "in xxx=yyy format will be merged into config file (deprecate), "
31 | "change to --cfg-options instead.",
32 | )
33 | return parser
34 |
35 |
36 | def prepare_texts(texts, conv_temp):
37 | convs = [conv_temp.copy() for _ in range(len(texts))]
38 | [conv.append_message(conv.roles[0], "<Img><ImageHere></Img> {}".format(text)) for conv, text in zip(convs, texts)]
39 | [conv.append_message(conv.roles[1], None) for conv in convs]
40 | texts = [conv.get_prompt() for conv in convs]
41 | return texts
42 |
43 |
44 | def init_model(args, device="cuda:0"):
45 | print("Initialization Model")
46 | cfg = Config(args)
47 | # cfg.model_cfg.ckpt = args.ckpt
48 | # cfg.model_cfg.lora_r = args.lora_r
49 | # cfg.model_cfg.lora_alpha = args.lora_alpha
50 |
51 | model_config = cfg.model_cfg
52 | model_cls = registry.get_model_class(model_config.arch)
53 | model = model_cls.from_config(model_config).to(device)
54 |
55 | # import pudb; pudb.set_trace()
56 | key = list(cfg.datasets_cfg.keys())[0]
57 | vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
58 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
59 | print("Initialization Finished")
60 | return model, vis_processor
61 |
62 |
63 | def computeIoU(bbox1, bbox2):
64 | x1, y1, x2, y2 = bbox1
65 | x3, y3, x4, y4 = bbox2
66 | intersection_x1 = max(x1, x3)
67 | intersection_y1 = max(y1, y3)
68 | intersection_x2 = min(x2, x4)
69 | intersection_y2 = min(y2, y4)
70 | intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
71 | bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
72 | bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
73 | union_area = bbox1_area + bbox2_area - intersection_area
74 | iou = intersection_area / union_area
75 | return iou
76 |
--------------------------------------------------------------------------------
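computeIoU uses the inclusive-pixel convention (+1 on widths and heights), so identical boxes give IoU 1.0 and disjoint boxes give 0.0. A quick worked example with the same formula, re-implemented inline so it runs standalone:

```python
def compute_iou(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    iw = max(0, min(x2, x4) - max(x1, x3) + 1)
    ih = max(0, min(y2, y4) - max(y1, y3) + 1)
    inter = iw * ih
    union = (x2 - x1 + 1) * (y2 - y1 + 1) + (x4 - x3 + 1) * (y4 - y3 + 1) - inter
    return inter / union

print(compute_iou((0, 0, 9, 9), (0, 0, 9, 9)))    # 1.0
print(compute_iou((0, 0, 9, 9), (5, 5, 14, 14)))  # 25 / 175 ~= 0.143
```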
/minigpt4/common/gradcam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from scipy.ndimage import filters
4 | from skimage import transform as skimage_transform
5 |
6 |
7 | def getAttMap(img, attMap, blur=True, overlap=True):
8 | attMap -= attMap.min()
9 | if attMap.max() > 0:
10 | attMap /= attMap.max()
11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12 | if blur:
13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14 | attMap -= attMap.min()
15 | attMap /= attMap.max()
16 | cmap = plt.get_cmap("jet")
17 | attMapV = cmap(attMap)
18 | attMapV = np.delete(attMapV, 3, 2)
19 | if overlap:
20 | attMap = (
21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23 | )
24 | return attMap
25 |
--------------------------------------------------------------------------------
/minigpt4/common/optims.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import math
9 |
10 | from minigpt4.common.registry import registry
11 |
12 |
13 | @registry.register_lr_scheduler("linear_warmup_step_lr")
14 | class LinearWarmupStepLRScheduler:
15 | def __init__(
16 | self,
17 | optimizer,
18 | max_epoch,
19 | min_lr,
20 | init_lr,
21 | decay_rate=1,
22 | warmup_start_lr=-1,
23 | warmup_steps=0,
24 | **kwargs
25 | ):
26 | self.optimizer = optimizer
27 |
28 | self.max_epoch = max_epoch
29 | self.min_lr = min_lr
30 |
31 | self.decay_rate = decay_rate
32 |
33 | self.init_lr = init_lr
34 | self.warmup_steps = warmup_steps
35 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36 |
37 | def step(self, cur_epoch, cur_step):
38 | if cur_epoch == 0:
39 | warmup_lr_schedule(
40 | step=cur_step,
41 | optimizer=self.optimizer,
42 | max_step=self.warmup_steps,
43 | init_lr=self.warmup_start_lr,
44 | max_lr=self.init_lr,
45 | )
46 | else:
47 | step_lr_schedule(
48 | epoch=cur_epoch,
49 | optimizer=self.optimizer,
50 | init_lr=self.init_lr,
51 | min_lr=self.min_lr,
52 | decay_rate=self.decay_rate,
53 | )
54 |
55 |
56 | @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57 | class LinearWarmupCosineLRScheduler:
58 | def __init__(
59 | self,
60 | optimizer,
61 | max_epoch,
62 | iters_per_epoch,
63 | min_lr,
64 | init_lr,
65 | warmup_steps=0,
66 | warmup_start_lr=-1,
67 | **kwargs
68 | ):
69 | self.optimizer = optimizer
70 |
71 | self.max_epoch = max_epoch
72 | self.iters_per_epoch = iters_per_epoch
73 | self.min_lr = min_lr
74 |
75 | self.init_lr = init_lr
76 | self.warmup_steps = warmup_steps
77 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78 |
79 | def step(self, cur_epoch, cur_step):
80 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81 | if total_cur_step < self.warmup_steps:
82 | warmup_lr_schedule(
83 | step=cur_step,
84 | optimizer=self.optimizer,
85 | max_step=self.warmup_steps,
86 | init_lr=self.warmup_start_lr,
87 | max_lr=self.init_lr,
88 | )
89 | else:
90 | cosine_lr_schedule(
91 | epoch=total_cur_step,
92 | optimizer=self.optimizer,
93 | max_epoch=self.max_epoch * self.iters_per_epoch,
94 | init_lr=self.init_lr,
95 | min_lr=self.min_lr,
96 | )
97 |
98 |
99 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100 | """Decay the learning rate"""
101 | lr = (init_lr - min_lr) * 0.5 * (
102 | 1.0 + math.cos(math.pi * epoch / max_epoch)
103 | ) + min_lr
104 | for param_group in optimizer.param_groups:
105 | param_group["lr"] = lr
106 |
107 |
108 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109 | """Warmup the learning rate"""
110 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111 | for param_group in optimizer.param_groups:
112 | param_group["lr"] = lr
113 |
114 |
115 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116 | """Decay the learning rate"""
117 | lr = max(min_lr, init_lr * (decay_rate**epoch))
118 | for param_group in optimizer.param_groups:
119 | param_group["lr"] = lr
120 |
--------------------------------------------------------------------------------
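The two schedulers compose a linear warmup with either step decay or cosine decay; the helper functions at the bottom do all of the math. A few sample values of the warmup-then-cosine schedule, computed directly from the same formulas (the numbers are toy choices: warmup over 100 steps, 1000 total steps, lr ramping from 1e-6 to 1e-4 and decaying to 1e-5):

```python
import math

init_lr, min_lr, warmup_start_lr = 1e-4, 1e-5, 1e-6
warmup_steps, total_steps = 100, 1000

def lr_at(step):
    if step < warmup_steps:
        # linear warmup: warmup_start_lr -> init_lr
        return min(init_lr, warmup_start_lr + (init_lr - warmup_start_lr) * step / max(warmup_steps, 1))
    # cosine decay: init_lr -> min_lr over the full step budget
    return (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / total_steps)) + min_lr

for step in (0, 50, 100, 500, 1000):
    print(step, f"{lr_at(step):.2e}")
# 0 -> 1.0e-06, 50 -> ~5.1e-05, 100 -> ~9.8e-05, 500 -> 5.5e-05, 1000 -> 1.0e-05
```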
/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import sys
4 | dataDir = '../../VQA'
5 | sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
6 | from vqa import VQA
7 | from vqaEvaluation.vqaEval import VQAEval
8 | import matplotlib.pyplot as plt
9 | import skimage.io as io
10 | import json
11 | import random
12 | import os
13 |
14 | # set up file names and paths
15 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset
16 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
17 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
18 | dataSubType ='train2014'
19 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
20 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
21 | imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
22 | resultType ='fake'
23 | fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
24 |
25 | # An example result json file has been provided in './Results' folder.
26 |
27 | [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
28 | resultType, fileType) for fileType in fileTypes]
29 |
30 | # create vqa object and vqaRes object
31 | vqa = VQA(annFile, quesFile)
32 | vqaRes = vqa.loadRes(resFile, quesFile)
33 |
34 | # create vqaEval object by taking vqa and vqaRes
35 | vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
36 |
37 | # evaluate results
38 | """
39 | If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
40 | By default it uses all the question ids in annotation file
41 | """
42 | vqaEval.evaluate()
43 |
44 | # print accuracies
45 | print "\n"
46 | print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
47 | print "Per Question Type Accuracy is the following:"
48 | for quesType in vqaEval.accuracy['perQuestionType']:
49 | print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
50 | print "\n"
51 | print "Per Answer Type Accuracy is the following:"
52 | for ansType in vqaEval.accuracy['perAnswerType']:
53 | print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
54 | print "\n"
55 | # demo how to use evalQA to retrieve low score result
56 | evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
57 | if len(evals) > 0:
58 | print 'ground truth answers'
59 | randomEval = random.choice(evals)
60 | randomAnn = vqa.loadQA(randomEval)
61 | vqa.showQA(randomAnn)
62 |
63 | print '\n'
64 | print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
65 | ann = vqaRes.loadQA(randomEval)[0]
66 | print "Answer: %s\n" %(ann['answer'])
67 |
68 | imgId = randomAnn[0]['image_id']
69 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
70 | if os.path.isfile(imgDir + imgFilename):
71 | I = io.imread(imgDir + imgFilename)
72 | plt.imshow(I)
73 | plt.axis('off')
74 | plt.show()
75 |
76 | # plot accuracy for various question types
77 | plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
78 | plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
79 | plt.title('Per Question Type Accuracy', fontsize=10)
80 | plt.xlabel('Question Types', fontsize=10)
81 | plt.ylabel('Accuracy', fontsize=10)
82 | plt.show()
83 |
84 | # save evaluation results to ./Results folder
85 | json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
86 | json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
87 | json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
88 | json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
89 |
90 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py:
--------------------------------------------------------------------------------
1 | author='aagrawal'
2 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from vqaTools.vqa import VQA
4 | import random
5 | import skimage.io as io
6 | import matplotlib.pyplot as plt
7 | import os
8 |
9 | dataDir ='../../VQA'
10 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset
11 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
12 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
13 | dataSubType ='train2014'
14 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
15 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
16 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
17 |
18 | # initialize VQA api for QA annotations
19 | vqa=VQA(annFile, quesFile)
20 |
21 | # load and display QA annotations for given question types
22 | """
23 | All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
24 | """
25 | annIds = vqa.getQuesIds(quesTypes='how many');
26 | anns = vqa.loadQA(annIds)
27 | randomAnn = random.choice(anns)
28 | vqa.showQA([randomAnn])
29 | imgId = randomAnn['image_id']
30 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
31 | if os.path.isfile(imgDir + imgFilename):
32 | I = io.imread(imgDir + imgFilename)
33 | plt.imshow(I)
34 | plt.axis('off')
35 | plt.show()
36 |
37 | # load and display QA annotations for given answer types
38 | """
39 | ansTypes can be one of the following
40 | yes/no
41 | number
42 | other
43 | """
44 | annIds = vqa.getQuesIds(ansTypes='yes/no');
45 | anns = vqa.loadQA(annIds)
46 | randomAnn = random.choice(anns)
47 | vqa.showQA([randomAnn])
48 | imgId = randomAnn['image_id']
49 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
50 | if os.path.isfile(imgDir + imgFilename):
51 | I = io.imread(imgDir + imgFilename)
52 | plt.imshow(I)
53 | plt.axis('off')
54 | plt.show()
55 |
56 | # load and display QA annotations for given images
57 | """
58 | Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
59 | Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
60 | """
61 | ids = vqa.getImgIds()
62 | annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
63 | anns = vqa.loadQA(annIds)
64 | randomAnn = random.choice(anns)
65 | vqa.showQA([randomAnn])
66 | imgId = randomAnn['image_id']
67 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
68 | if os.path.isfile(imgDir + imgFilename):
69 | I = io.imread(imgDir + imgFilename)
70 | plt.imshow(I)
71 | plt.axis('off')
72 | plt.show()
73 |
74 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt:
--------------------------------------------------------------------------------
1 | how many
2 | what color is the
3 | is the
4 | where is the
5 | what
6 | what is
7 | are the
8 | what is the
9 | is there a
10 | does the
11 | is the woman
12 | is the man
13 | what is on the
14 | is it
15 | is the girl
16 | is the boy
17 | is the dog
18 | are they
19 | who is
20 | what kind of
21 | what color are the
22 | what is in the
23 | what is the man
24 | is there
25 | what is the woman
26 | what are the
27 | what is the boy
28 | are there
29 | what is the girl
30 | is this
31 | how
32 | which
33 | how many people are
34 | is the cat
35 | why is the
36 | are
37 | will the
38 | what type of
39 | what is the dog
40 | do
41 | is she
42 | does
43 | do the
44 | is
45 | is the baby
46 | are there any
47 | is the lady
48 | can
49 | what animal is
50 | where are the
51 | is the sun
52 | what are they
53 | did the
54 | what is the cat
55 | what is the lady
56 | how many clouds are
57 | is that
58 | is the little girl
59 | is he
60 | are these
61 | how many trees are
62 | how many pillows
63 | are the people
64 | why
65 | is the young
66 | how many windows are
67 | is this a
68 | what is the little
69 | is the tv
70 | how many animals are
71 | who
72 | how many pictures
73 | how many plants are
74 | how many birds are
75 | what color is
76 | what is the baby
77 | is anyone
78 | what color
79 | how many bushes
80 | is the old man
81 | none of the above
82 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt:
--------------------------------------------------------------------------------
1 | how many
2 | is the
3 | what
4 | what color is the
5 | what is the
6 | is this
7 | is this a
8 | what is
9 | are the
10 | what kind of
11 | is there a
12 | what type of
13 | is it
14 | what are the
15 | where is the
16 | is there
17 | does the
18 | what color are the
19 | are these
20 | are there
21 | which
22 | is
23 | what is the man
24 | is the man
25 | are
26 | how
27 | does this
28 | what is on the
29 | what does the
30 | how many people are
31 | what is in the
32 | what is this
33 | do
34 | what are
35 | are they
36 | what time
37 | what sport is
38 | are there any
39 | is he
40 | what color is
41 | why
42 | where are the
43 | what color
44 | who is
45 | what animal is
46 | is the woman
47 | is this an
48 | do you
49 | how many people are in
50 | what room is
51 | has
52 | is this person
53 | what is the woman
54 | can you
55 | why is the
56 | is the person
57 | what is the color of the
58 | what is the person
59 | could
60 | was
61 | is that a
62 | what number is
63 | what is the name
64 | what brand
65 | none of the above
66 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/VQA/license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014, Aishwarya Agrawal
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14 | AND
15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
18 | FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
26 | The views and conclusions contained in the software and documentation are
27 | those
28 | of the authors and should not be interpreted as representing official
29 | policies,
30 | either expressed or implied, of the FreeBSD Project.
31 |
--------------------------------------------------------------------------------
/minigpt4/common/vqa_tools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | __author__ = "aagrawal"
9 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/aokvqa/defaults.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | aok_vqa:
8 | # data_dir: ${env.data_dir}/datasets
9 | data_type: images # [images|videos|features]
10 |
11 | build_info:
12 | # Be careful not to append minus sign (-) before split to avoid itemizing
13 | annotations:
14 | train:
15 | url:
16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json
17 | storage:
18 | - /path/to/aokvqa_v1p0_train.json
19 | images:
20 | storage: /path/to/coco/images
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_sbu/align.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu_align:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_align/
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/cc_sbu/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | cc_sbu:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco/caption.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | coco_caption: # name of the dataset builder
8 | # dataset_card: dataset_card/coco_caption.md
9 | # data_dir: ${env.data_dir}/datasets
10 | data_type: images # [images|videos|features]
11 |
12 | build_info:
13 | # Be careful not to append minus sign (-) before split to avoid itemizing
14 | annotations:
15 | train:
16 | url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
17 | md5: aa31ac474cf6250ebb81d18348a07ed8
18 | storage: /path/to/coco_caption/coco_karpathy_train.json
19 | images:
20 | storage: /path/to/coco/images
21 |
22 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco/defaults_vqa.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | coco_vqa:
8 | # data_dir: ${env.data_dir}/datasets
9 | data_type: images # [images|videos|features]
10 |
11 | build_info:
12 |
13 | annotations:
14 | train:
15 | url:
16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json
17 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json
18 | storage:
19 | - /path/to/vqav2/vqa_train.json
20 | - /path/to/vqav2/vqa_val.json
21 | images:
22 | storage: /path/to/coco/images
23 |
24 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | invrefcoco:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: invrefcoco
8 | splitBy: unc
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | invrefcocog:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: invrefcocog
8 | splitBy: umd
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | invrefcocop:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: invrefcoco+
8 | splitBy: unc
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/refcoco.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | refcoco:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: refcoco
8 | splitBy: unc
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/refcocog.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | refcocog:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: refcocog
8 | splitBy: umd
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/coco_bbox/refcocop.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | refcocop:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/refcoco_annotations
7 | dataset: refcoco+
8 | splitBy: unc
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/flickr/caption_to_phrase.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | flickr_CaptionToPhrase:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/filtered_flikcr/images
6 | ann_path: /path/to/filtered_flickr/captiontobbox.json
7 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/flickr/default.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | flickr_grounded_caption:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/filtered_flikcr/images
6 | ann_path: /path/to/filtered_flikcr/groundedcaption.json
7 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/flickr/object_to_phrase.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | flickr_ObjectToPhrase:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/filtered_flikcr/images
6 | ann_path: /path/to/filtered_flikcr/phrasetobbox.json
7 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/gqa/balanced_val.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | gqa:
8 | # data_dir: ${env.data_dir}/datasets
9 | data_type: images # [images|videos|features]
10 |
11 | build_info:
12 | # Be careful not to append minus sign (-) before split to avoid itemizing
13 | annotations:
14 | train:
15 | url:
16 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json
17 | storage:
18 | - /path/to/gqa/train_balanced_questions.json
19 |
20 | images:
21 | storage: /path/to/gqa/images
22 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/laion/defaults.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | laion:
3 | data_type: images
4 | build_info:
5 | storage: /path/to/laion_dataset/{00000..10488}.tar
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/llava/conversation.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 |
3 | llava_conversation:
4 | data_type: images
5 | build_info:
6 | image_path: /path/to/coco/images
7 | ann_path: /path/to/llava/conversation_58k.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/llava/detail.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | llava_detail:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/coco/images
6 | ann_path: /path/to/llava/detail_23k.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/llava/reason.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 |
3 | llava_reason:
4 | data_type: images
5 | build_info:
6 | image_path: /path/to/coco/images
7 | ann_path: /path/to/llava/complex_reasoning_77k.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/multitask_conversation/default.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | multitask_conversation:
3 | data_type: images
4 | build_info:
5 |
6 | image_path: /path/to/coco/images
7 | ann_path: /path/to/multitask_conversation/multi_task_conversation.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | unnatural_instruction:
3 | data_type: text
4 | build_info:
5 | ann_path: /path/to/unnatural_instructions/filtered_unnatural_instruction.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | ocrvqa:
3 | data_type: images
4 | build_info:
5 | image_path: /path/to/ocrvqa/images
6 | ann_path: /path/to/ocrvqa/dataset.json
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/okvqa/defaults.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5 |
6 | datasets:
7 | ok_vqa:
8 | # data_dir: ${env.data_dir}/datasets
9 | data_type: images # [images|videos|features]
10 |
11 | build_info:
12 | # Be careful not to append minus sign (-) before split to avoid itemizing
13 | annotations:
14 | train:
15 | url:
16 | # TODO make this order insensitive
17 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json
18 | storage:
19 | - /path/to/okvqa/okvqa_train.json
20 | images:
21 | storage: /path/to/coco/images
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/textcaps/caption.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | textcaps_caption:
3 | data_type: images
4 |
5 | build_info:
6 | image_path: /path/to/textcaps/train_images
7 | ann_path: /path/to/textcaps/TextCaps_0.1_train.json
8 |
9 |
10 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/vg/ref.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | refvg:
3 | data_type: images
4 | build_info:
5 | data_dir: /path/to/visual_genome
--------------------------------------------------------------------------------
/minigpt4/configs/default.yaml:
--------------------------------------------------------------------------------
1 | env:
2 | # For default users
3 | # cache_root: "cache"
4 | # For internal use with persistent storage
5 | cache_root: "/export/home/.cache/minigpt4"
6 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4_llama2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | has_qformer: False
11 |
12 | # generation configs
13 | prompt: ""
14 |
15 | llama_model: "please set this value to the path of llama2-chat-7b"
16 |
17 | preprocess:
18 | vis_processor:
19 | train:
20 | name: "blip2_image_train"
21 | image_size: 224
22 | eval:
23 | name: "blip2_image_eval"
24 | image_size: 224
25 | text_processor:
26 | train:
27 | name: "blip_caption"
28 | eval:
29 | name: "blip_caption"
30 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4_vicuna0.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt4
3 |
4 | # vit encoder
5 | image_size: 224
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 | freeze_qformer: True
11 |
12 | # Q-Former
13 | num_query_token: 32
14 |
15 | # generation configs
16 | prompt: ""
17 |
18 | llama_model: "please set this value to the path of vicuna model"
19 |
20 | preprocess:
21 | vis_processor:
22 | train:
23 | name: "blip2_image_train"
24 | image_size: 224
25 | eval:
26 | name: "blip2_image_eval"
27 | image_size: 224
28 | text_processor:
29 | train:
30 | name: "blip_caption"
31 | eval:
32 | name: "blip_caption"
33 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt_v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: minigpt_v2
3 |
4 | # vit encoder
5 | image_size: 448
6 | drop_path_rate: 0
7 | use_grad_checkpoint: False
8 | vit_precision: "fp16"
9 | freeze_vit: True
10 |
11 | # generation configs
12 | prompt: ""
13 |
14 | llama_model: "please set this value to the path of llama2-chat-7b"
15 | lora_r: 64
16 | lora_alpha: 16
17 |
18 |
19 | preprocess:
20 | vis_processor:
21 | train:
22 | name: "blip2_image_train"
23 | image_size: 448
24 | eval:
25 | name: "blip2_image_eval"
26 | image_size: 448
27 | text_processor:
28 | train:
29 | name: "blip_caption"
30 | eval:
31 | name: "blip_caption"
32 |
--------------------------------------------------------------------------------
/minigpt4/conversation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/conversation/__init__.py
--------------------------------------------------------------------------------
/minigpt4/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/datasets/__init__.py
--------------------------------------------------------------------------------
/minigpt4/datasets/builders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
9 | from minigpt4.datasets.builders.image_text_pair_builder import (
10 | CCSBUBuilder,
11 | LaionBuilder,
12 | CCSBUAlignBuilder
13 | )
14 | from minigpt4.common.registry import registry
15 |
16 | __all__ = [
17 | "CCSBUBuilder",
18 | "LaionBuilder",
19 | "CCSBUAlignBuilder"
20 | ]
21 |
22 |
23 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
24 | """
25 | Example
26 |
27 | >>> dataset = load_dataset("coco_caption", cfg=None)
28 | >>> splits = dataset.keys()
29 | >>> print([len(dataset[split]) for split in splits])
30 |
31 | """
32 | if cfg_path is None:
33 | cfg = None
34 | else:
35 | cfg = load_dataset_config(cfg_path)
36 |
37 | try:
38 | builder = registry.get_builder_class(name)(cfg)
39 | except TypeError:
40 | print(
41 | f"Dataset {name} not found. Available datasets:\n"
42 | + ", ".join([str(k) for k in dataset_zoo.get_names()])
43 | )
44 | exit(1)
45 |
46 | if vis_path is not None:
47 | if data_type is None:
48 | # use default data type in the config
49 | data_type = builder.config.data_type
50 |
51 | assert (
52 | data_type in builder.config.build_info
53 | ), f"Invalid data_type {data_type} for {name}."
54 |
55 | builder.config.build_info.get(data_type).storage = vis_path
56 |
57 | dataset = builder.build_datasets()
58 | return dataset
59 |
60 |
61 | class DatasetZoo:
62 | def __init__(self) -> None:
63 | self.dataset_zoo = {
64 | k: list(v.DATASET_CONFIG_DICT.keys())
65 | for k, v in sorted(registry.mapping["builder_name_mapping"].items())
66 | }
67 |
68 | def get_names(self):
69 | return list(self.dataset_zoo.keys())
70 |
71 |
72 | dataset_zoo = DatasetZoo()
73 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/starreeze/efuf/e0798b6918ca94c5eb6c7f8405868b1e8b2a692a/minigpt4/datasets/datasets/__init__.py
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 |
16 |
17 | class BaseDataset(Dataset):
18 | def __init__(
19 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
20 | ):
21 | """
22 | vis_root (string): Root directory of images (e.g. coco/images/)
23 | ann_root (string): directory to store the annotation file
24 | """
25 | self.vis_root = vis_root
26 |
27 | self.annotation = []
28 | # print("ann paths", ann_paths)
29 | for ann_path in ann_paths:
30 | # print("ann_path", ann_path)
31 |             ann = json.load(open(ann_path, "r"))
32 |             if isinstance(ann, dict):
33 |                 # COCO-style files wrap the records under an 'annotations' key
34 |                 self.annotation.extend(ann['annotations'])
35 |             else:
36 |                 self.annotation.extend(ann)
37 |
38 | self.vis_processor = vis_processor
39 | self.text_processor = text_processor
40 |
41 | self._add_instance_ids()
42 |
43 | def __len__(self):
44 | return len(self.annotation)
45 |
46 | def collater(self, samples):
47 | return default_collate(samples)
48 |
49 | def set_processors(self, vis_processor, text_processor):
50 | self.vis_processor = vis_processor
51 | self.text_processor = text_processor
52 |
53 | def _add_instance_ids(self, key="instance_id"):
54 | for idx, ann in enumerate(self.annotation):
55 | ann[key] = str(idx)
56 |
57 |
58 |
59 | class ConcatDataset(ConcatDataset):
60 | def __init__(self, datasets: Iterable[Dataset]) -> None:
61 | super().__init__(datasets)
62 |
63 | def collater(self, samples):
64 | # TODO For now only supports datasets with same underlying collater implementations
65 |
66 | all_keys = set()
67 | for s in samples:
68 | all_keys.update(s)
69 |
70 | shared_keys = all_keys
71 | for s in samples:
72 | shared_keys = shared_keys & set(s.keys())
73 |
74 | samples_shared_keys = []
75 | for s in samples:
76 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
77 |
78 | return self.datasets[0].collater(samples_shared_keys)
79 |
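
To make the shared-key collation rule above concrete, a short standalone sketch with toy samples (not part of the repo):

    # Samples from different datasets may carry different fields; only the key
    # intersection is forwarded to the first dataset's collater.
    samples = [
        {"image": 1, "answer": "a cat", "image_id": 7},
        {"image": 2, "answer": "a dog"},               # no "image_id" here
    ]
    shared = set(samples[0]) & set(samples[1])          # {"image", "answer"}
    filtered = [{k: s[k] for k in s if k in shared} for s in samples]
    # default_collate(filtered) then batches only "image" and "answer".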
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/cc_sbu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import webdataset as wds
4 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
5 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
6 |
7 |
8 | class CCSBUDataset(BaseDataset):
9 | def __init__(self, vis_processor, text_processor, location):
10 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11 |
12 | self.inner_dataset = wds.DataPipeline(
13 | wds.ResampledShards(location),
14 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
15 | wds.shuffle(1000, handler=wds.warn_and_continue),
16 | wds.decode("pilrgb", handler=wds.warn_and_continue),
17 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19 | wds.map(self.to_dict, handler=wds.warn_and_continue),
20 | )
21 |
22 | def to_dict(self, sample):
23 | return {
24 | "image": sample[0],
25 | "answer": self.text_processor(sample[1]["caption"]),
26 | }
27 |
28 |
29 | class CCSBUAlignDataset(CaptionDataset):
30 |
31 | def __getitem__(self, index):
32 |
33 | # TODO this assumes image input, not general enough
34 | ann = self.annotation[index]
35 |
36 | img_file = '{}.jpg'.format(ann["image_id"])
37 | image_path = os.path.join(self.vis_root, img_file)
38 | image = Image.open(image_path).convert("RGB")
39 |
40 | image = self.vis_processor(image)
41 | caption = ann["caption"]
42 |
43 | return {
44 | "image": image,
45 | "answer": caption,
46 | "image_id": self.img_ids[ann["image_id"]],
47 | }
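
A rough sketch of how the webdataset pipeline might be consumed; the shard pattern is a placeholder and the processors are built with their assumed default configs:

    from torch.utils.data import DataLoader
    from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset
    from minigpt4.processors import Blip2ImageTrainProcessor, BlipCaptionProcessor

    ds = CCSBUDataset(
        vis_processor=Blip2ImageTrainProcessor(),        # default image size assumed
        text_processor=BlipCaptionProcessor(),
        location="/path/to/cc_sbu/{00000..00010}.tar",   # placeholder shard pattern
    )
    loader = DataLoader(ds.inner_dataset, batch_size=32, num_workers=4)
    batch = next(iter(loader))    # batch["image"]: tensors, batch["answer"]: captions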
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/gqa_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import json
10 |
11 | from PIL import Image
12 |
13 | from minigpt4.datasets.datasets.vqa_datasets import VQADataset
14 |
15 | from collections import OrderedDict
16 | import random
17 |
18 | class __DisplMixin:
19 | def displ_item(self, index):
20 | sample, ann = self.__getitem__(index), self.annotation[index]
21 |
22 | return OrderedDict(
23 | {
24 | "file": ann["image"],
25 | "question": ann["question"],
26 | "question_id": ann["question_id"],
27 | "answers": "; ".join(ann["answer"]),
28 | "image": sample["image"],
29 | }
30 | )
31 |
32 |
33 | class GQADataset(VQADataset, __DisplMixin):
34 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
35 | super().__init__(vis_processor, text_processor, vis_root, ann_paths)
36 | self.instruction_pool =[
37 | "[vqa] {}",
38 | "[vqa] Based on the image, respond to this question with a short answer: {}"
39 | ]
40 |
41 | def __getitem__(self, index):
42 | ann = self.annotation[index]
43 |
44 | image_path = os.path.join(self.vis_root, ann["image"])
45 | image = Image.open(image_path).convert("RGB")
46 |
47 | image = self.vis_processor(image)
48 | question = self.text_processor(ann["question"])
49 |
50 | instruction = random.choice(self.instruction_pool).format(question)
51 |         instruction = "<Img><ImageHere></Img> {} ".format(instruction)
52 |
53 | answers = self.text_processor(ann["answer"])
54 |
55 | return {
56 | "image": image,
57 | "instruction_input": instruction,
58 | "answer": answers,
59 | }
60 |
61 |
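
For reference, what an instruction_input roughly looks like after the wrapping above (assuming the <Img><ImageHere></Img> placeholder used throughout these dataset classes):

    question = "what color is the car"
    instruction = "[vqa] {}".format(question)
    instruction = "<Img><ImageHere></Img> {} ".format(instruction)
    # -> '<Img><ImageHere></Img> [vqa] what color is the car '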
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/laion_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import webdataset as wds
9 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
10 |
11 |
12 | class LaionDataset(BaseDataset):
13 | def __init__(self, vis_processor, text_processor, location):
14 | super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15 |
16 | self.inner_dataset = wds.DataPipeline(
17 | wds.ResampledShards(location),
18 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
19 | wds.shuffle(1000, handler=wds.warn_and_continue),
20 | wds.decode("pilrgb", handler=wds.warn_and_continue),
21 | wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22 | wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23 | wds.map(self.to_dict, handler=wds.warn_and_continue),
24 | )
25 |
26 | def to_dict(self, sample):
27 | return {
28 | "image": sample[0],
29 | "answer": self.text_processor(sample[1]["caption"]),
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/multitask_conversation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import random
5 | import time
6 | import itertools
7 |
8 | import numpy as np
9 | from PIL import Image
10 | import skimage.io as io
11 | import matplotlib.pyplot as plt
12 | from matplotlib.collections import PatchCollection
13 | from matplotlib.patches import Polygon, Rectangle
14 | from torch.utils.data import Dataset
15 | import webdataset as wds
16 |
17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19 |
20 |
21 |
22 |
23 | class MultiTaskConversationDataset(Dataset):
24 | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
25 | """
26 | vis_root (string): Root directory of images (e.g. coco/images/)
27 | ann_root (string): directory to store the annotation file
28 | """
29 | self.vis_root = vis_root
30 |
31 | self.vis_processor = vis_processor
32 | self.text_processor = text_processor
33 |
34 |
35 | with open(ann_path, 'r') as f:
36 | self.ann = json.load(f)
37 |
38 | self.connect_sym = "!@#"
39 |
40 | def __len__(self):
41 | return len(self.ann)
42 |
43 | def __getitem__(self, index):
44 | info = self.ann[index]
45 |
46 | image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
47 | image_path = os.path.join(self.vis_root, image_file)
48 | image = Image.open(image_path).convert("RGB")
49 | image = self.vis_processor(image)
50 |
51 |         first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
52 |         first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)
53 |
54 | questions = [first_instruction]
55 | answers = []
56 |
57 | for i, item in enumerate(info["conversations"][1:]):
58 | if i % 2 ==0: # assistant
59 | assistant_answer = item["value"]
60 | answers.append(assistant_answer)
61 | else:
62 | human_instruction = item["value"]+" "
63 | questions.append(human_instruction)
64 |
65 | questions = self.connect_sym.join(questions)
66 | answers = self.connect_sym.join(answers)
67 |
68 |
69 | return {
70 | "image": image,
71 | "conv_q": questions,
72 | 'conv_a': answers,
73 | "image_id": info['id'],
74 | "connect_sym": self.connect_sym
75 | }
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/ocrvqa_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import random
5 | import time
6 | import itertools
7 |
8 | import numpy as np
9 | from PIL import Image
10 | import skimage.io as io
11 | import matplotlib.pyplot as plt
12 | from matplotlib.collections import PatchCollection
13 | from matplotlib.patches import Polygon, Rectangle
14 | from torch.utils.data import Dataset
15 | import webdataset as wds
16 |
17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19 |
20 |
21 | class OCRVQADataset(Dataset):
22 | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
23 | """
24 | vis_root (string): Root directory of images (e.g. coco/images/)
25 | ann_root (string): directory to store the annotation file
26 | """
27 | self.vis_root = vis_root
28 |
29 | self.vis_processor = vis_processor
30 | self.text_processor = text_processor
31 | self.data = self.create_data(ann_path)
32 |
33 | self.instruction_pool =[
34 | "[vqa] {}",
35 | "[vqa] Based on the image, respond to this question with a short answer: {}"
36 | ]
37 |
38 | def create_data(self, ann_path):
39 | processed_data = []
40 | with open(ann_path, 'r') as f:
41 | data = json.load(f)
42 | for k in data.keys():
43 | if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test
44 | ext = os.path.splitext(data[k]['imageURL'])[1]
45 | imageFile = k + ext
46 | assert len(data[k]['questions']) == len(data[k]['answers'])
47 | for q, a in zip(data[k]['questions'], data[k]['answers']):
48 | processed_data.append(
49 | {'question': q,
50 | 'answer': a,
51 | 'image_path': imageFile,
52 | 'image_id': k,
53 | 'title': data[k]['title'],
54 | 'genre': data[k]['genre'],
55 | }
56 | )
57 | return processed_data
58 |
59 | def __len__(self):
60 | return len(self.data)
61 |
62 | def __getitem__(self, index):
63 | sample = self.data[index]
64 | image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
65 | image = self.vis_processor(image)
66 | question = self.text_processor(sample["question"])
67 | answer = self.text_processor(sample["answer"])
68 |
69 | instruction = random.choice(self.instruction_pool).format(question)
70 |         instruction = "<Img><ImageHere></Img> {} ".format(instruction)
71 | return {
72 | "image": image,
73 | "instruction_input": instruction,
74 | "answer": answer,
75 | "image_id": sample['image_id']
76 | }
77 |
78 |
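
A toy annotation entry in the shape create_data() expects, to show the parsing (illustrative values only):

    toy_ann = {
        "book123": {
            "split": 1,                                   # 1 = train, 2 = val, 3 = test
            "imageURL": "http://example.com/book123.jpg",
            "questions": ["Who is the author of this book?"],
            "answers": ["Jane Doe"],
            "title": "A Toy Book",
            "genre": "Fiction",
        }
    }
    # create_data() keeps only split == 1 and emits one record per QA pair:
    # {'question': 'Who is the author of this book?', 'answer': 'Jane Doe',
    #  'image_path': 'book123.jpg', 'image_id': 'book123',
    #  'title': 'A Toy Book', 'genre': 'Fiction'}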
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/text_caps.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import random
5 | import time
6 | import itertools
7 |
8 | import numpy as np
9 | from PIL import Image
10 | import skimage.io as io
11 | import matplotlib.pyplot as plt
12 | from matplotlib.collections import PatchCollection
13 | from matplotlib.patches import Polygon, Rectangle
14 | from torch.utils.data import Dataset
15 | import webdataset as wds
16 |
17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19 |
20 |
21 |
22 | class TextCapDataset(Dataset):
23 | def __init__(self, vis_processor, text_processor, vis_root, ann_path):
24 | """
25 | vis_root (string): Root directory of images (e.g. coco/images/)
26 | ann_root (string): directory to store the annotation file
27 | """
28 | self.vis_root = vis_root
29 |
30 | self.vis_processor = vis_processor
31 | self.text_processor = text_processor
32 |
33 | self.instruction_pool = [
34 | 'Briefly describe this image.',
35 | 'Provide a concise depiction of this image.',
36 | 'Present a short description of this image.',
37 | 'Summarize this image in a few words.',
38 | 'A short image caption:',
39 | 'A short image description:',
40 | 'A photo of ',
41 | 'An image that shows ',
42 | 'Write a short description for the image. ',
43 | 'Write a description for the photo.',
44 | 'Provide a description of what is presented in the photo.',
45 | 'Briefly describe the content of the image.',
46 | 'Can you briefly explain what you see in the image?',
47 | 'Could you use a few words to describe what you perceive in the photo?',
48 | 'Please provide a short depiction of the picture.',
49 | 'Using language, provide a short account of the image.',
50 | 'Use a few words to illustrate what is happening in the picture.',
51 | ]
52 |
53 | with open(ann_path, 'r') as f:
54 | self.ann = json.load(f)
55 |
56 |
57 | def __len__(self):
58 | return len(self.ann["data"])
59 |
60 |
61 | def __getitem__(self, index):
62 | info = self.ann["data"][index]
63 |
64 | image_file = '{}.jpg'.format(info['image_id'])
65 |
66 | image_path = os.path.join(self.vis_root, image_file)
67 | image = Image.open(image_path).convert("RGB")
68 | image = self.vis_processor(image)
69 |
70 | caption = info["caption_str"]
71 | caption = self.text_processor(caption)
72 |         instruction = "<Img><ImageHere></Img> [caption] {} ".format(random.choice(self.instruction_pool))
73 | return {
74 | "image": image,
75 | "instruction_input": instruction,
76 | "answer": caption,
77 | }
78 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/unnatural_instruction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import random
5 | import time
6 | import itertools
7 |
8 | import numpy as np
9 | from PIL import Image
10 | import skimage.io as io
11 | import matplotlib.pyplot as plt
12 | from matplotlib.collections import PatchCollection
13 | from matplotlib.patches import Polygon, Rectangle
14 | from torch.utils.data import Dataset
15 | import webdataset as wds
16 |
17 | from minigpt4.datasets.datasets.base_dataset import BaseDataset
18 | from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
19 |
20 |
21 | class UnnaturalDataset(Dataset):
22 | def __init__(self, text_processor, ann_path):
23 | """
24 | vis_root (string): Root directory of images (e.g. coco/images/)
25 | ann_root (string): directory to store the annotation file
26 | """
27 | self.text_processor = text_processor
28 |
29 | with open(ann_path, 'r') as f:
30 | self.ann = json.load(f)
31 |
32 | def __len__(self):
33 | return len(self.ann)
34 |
35 | def __getitem__(self, index):
36 | info = self.ann[index]["instances"][0]
37 | instruction = info["instruction_with_input"]
38 | constraints = info["constraints"]
39 | answer = info["output"]
40 |         if constraints is not None:
41 | instruction = instruction+" "+constraints
42 |
43 | return {
44 | "instruction_input": self.text_processor(instruction),
45 | "answer": self.text_processor(answer),
46 | }
47 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/vg_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import random
5 | import time
6 | import itertools
7 |
8 | import numpy as np
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 | from visual_genome import local
12 |
13 |
14 |
15 |
16 | class ReferVisualGenomeDataset(Dataset):
17 | def __init__(self, vis_processor, text_processor, data_dir):
18 | """
19 | vis_root (string): Root directory of images (e.g. coco/images/)
20 | ann_root (string): directory to store the annotation file
21 | """
22 | self.data_dir = data_dir
23 |
24 | self.vis_processor = vis_processor
25 | self.text_processor = text_processor
26 |
27 | all_regions = local.get_all_region_descriptions(self.data_dir)
28 | all_regions = [region for regions in all_regions for region in regions]
29 |
30 | # follow OFA practice, only regions smaller than 16384 pixels are used for refer
31 | self.regions = [region for region in all_regions if region.width * region.height < 16384]
32 |
33 |
34 | self.instruction_pool = [
35 | "[refer] {}",
36 | "[refer] give me the location of {}",
37 | "[refer] where is {} ?",
38 | "[refer] from this image, tell me the location of {}",
39 | "[refer] the location of {} is",
40 | "[refer] could you tell me the location for {} ?",
41 | "[refer] where can I locate the {} ?",
42 | ]
43 |
44 |
45 | def __len__(self):
46 | return len(self.regions)
47 |
48 | def preprocess(self, index):
49 | region = self.regions[index]
50 | image_file = region.image.url.split('/')[-2:]
51 | image_path = os.path.join(self.data_dir, *image_file)
52 | image = Image.open(image_path).convert("RGB")
53 | image_orig_size = image.size
54 | image = self.vis_processor(image)
55 | image_new_size = [100,100]
56 |
57 | sample_sentence = region.phrase
58 | refer_sentence = self.text_processor(sample_sentence)
59 |
60 | bbox = [region.x, region.y, region.width, region.height]
61 |
62 | bbox = [
63 | bbox[0] / image_orig_size[0] * image_new_size[0],
64 | bbox[1] / image_orig_size[1] * image_new_size[1],
65 | (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
66 | (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
67 | ]
68 | bbox = [int(x) for x in bbox]
69 | bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
70 | return {
71 | "image": image,
72 | "refer_sentence": refer_sentence,
73 | "bbox": bbox,
74 | "image_id": region.image.id,
75 | }
76 |
77 | def __getitem__(self, index):
78 | data = self.preprocess(index)
79 | instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])
80 |
81 |         instruction = "<Img><ImageHere></Img> {} ".format(instruction)
82 |
83 | return {
84 | "image": data['image'],
85 | "instruction_input": instruction,
86 | "answer": data['bbox'],
87 | "image_id": data['image_id'],
88 | }
89 |
90 |
91 |
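
A worked example of the bbox normalisation above, mapping a region onto the fixed 100x100 grid (toy numbers):

    image_orig_size = (640, 480)      # (width, height) from PIL
    image_new_size = [100, 100]
    x, y, w, h = 64, 48, 128, 96      # region.x, region.y, region.width, region.height
    bbox = [
        x / image_orig_size[0] * image_new_size[0],         # 10.0
        y / image_orig_size[1] * image_new_size[1],         # 10.0
        (x + w) / image_orig_size[0] * image_new_size[0],   # 30.0
        (y + h) / image_orig_size[1] * image_new_size[1],   # 30.0
    ]
    bbox = [int(v) for v in bbox]
    print("{{<{}><{}><{}><{}>}}".format(*bbox))             # {<10><10><30><30>}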
--------------------------------------------------------------------------------
/minigpt4/processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.processors.base_processor import BaseProcessor
9 | from minigpt4.processors.blip_processors import (
10 | Blip2ImageTrainProcessor,
11 | Blip2ImageEvalProcessor,
12 | BlipCaptionProcessor,
13 | )
14 |
15 | from minigpt4.common.registry import registry
16 |
17 | __all__ = [
18 | "BaseProcessor",
19 | "Blip2ImageTrainProcessor",
20 | "Blip2ImageEvalProcessor",
21 | "BlipCaptionProcessor",
22 | ]
23 |
24 |
25 | def load_processor(name, cfg=None):
26 | """
27 | Example
28 |
29 | >>> processor = load_processor("alpro_video_train", cfg=None)
30 | """
31 | processor = registry.get_processor_class(name).from_config(cfg)
32 |
33 | return processor
34 |
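
A hypothetical call, assuming "blip2_image_eval" is the registry name declared in blip_processors.py:

    from minigpt4.processors import load_processor

    vis_processor = load_processor("blip2_image_eval", cfg=None)   # falls back to defaults
    # vis_processor(pil_image) then returns the tensor the model expects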
--------------------------------------------------------------------------------
/minigpt4/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from omegaconf import OmegaConf
9 |
10 |
11 | class BaseProcessor:
12 | def __init__(self):
13 | self.transform = lambda x: x
14 | return
15 |
16 | def __call__(self, item):
17 | return self.transform(item)
18 |
19 | @classmethod
20 | def from_config(cls, cfg=None):
21 | return cls()
22 |
23 | def build(self, **kwargs):
24 | cfg = OmegaConf.create(kwargs)
25 |
26 | return self.from_config(cfg)
27 |
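
A minimal sketch of a custom processor built on BaseProcessor (illustrative only):

    from minigpt4.processors.base_processor import BaseProcessor

    class LowercaseProcessor(BaseProcessor):
        def __init__(self):
            super().__init__()
            self.transform = lambda text: text.lower()   # replace the identity transform

    proc = LowercaseProcessor()
    print(proc("A Photo OF a CAT"))    # "a photo of a cat"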
--------------------------------------------------------------------------------
/minigpt4/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.runners.runner_base import RunnerBase
9 |
10 | __all__ = ["RunnerBase"]
11 |
--------------------------------------------------------------------------------
/minigpt4/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 | from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
11 |
12 |
13 | def setup_task(cfg):
14 | assert "task" in cfg.run_cfg, "Task name must be provided."
15 |
16 | task_name = cfg.run_cfg.task
17 | task = registry.get_task_class(task_name).setup_task(cfg=cfg)
18 | assert task is not None, "Task {} not properly registered.".format(task_name)
19 |
20 | return task
21 |
22 |
23 | __all__ = [
24 | "BaseTask",
25 | "ImageTextPretrainTask",
26 | ]
27 |
--------------------------------------------------------------------------------
/minigpt4/tasks/image_text_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 |
11 |
12 | @registry.register_task("image_text_pretrain")
13 | class ImageTextPretrainTask(BaseTask):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def evaluation(self, model, data_loader, cuda_enabled=True):
18 | pass
19 |
--------------------------------------------------------------------------------
/share4v/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Share4VLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/share4v/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/share4v/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 |
6 | import openai
7 | import ray
8 | import tqdm
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 |
13 | @ray.remote(num_cpus=4)
14 | def get_eval(content: str, max_tokens: int):
15 | while True:
16 | try:
17 | response = openai.ChatCompletion.create(
18 | model='gpt-4',
19 | messages=[{
20 | 'role': 'system',
21 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
22 | }, {
23 | 'role': 'user',
24 | 'content': content,
25 | }],
26 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
27 | max_tokens=max_tokens,
28 | )
29 | break
30 | except openai.error.RateLimitError:
31 | pass
32 | except Exception as e:
33 | print(e)
34 | time.sleep(NUM_SECONDS_TO_SLEEP)
35 |
36 | print('success!')
37 | return response['choices'][0]['message']['content']
38 |
39 |
40 | def parse_score(review):
41 | try:
42 | score_pair = review.split('\n')[0]
43 | score_pair = score_pair.replace(',', ' ')
44 | sp = score_pair.split(' ')
45 | if len(sp) == 2:
46 | return [float(sp[0]), float(sp[1])]
47 | else:
48 | print('error', review)
49 | return [-1, -1]
50 | except Exception as e:
51 | print(e)
52 | print('error', review)
53 | return [-1, -1]
54 |
55 |
56 | if __name__ == '__main__':
57 | parser = argparse.ArgumentParser(
58 | description='ChatGPT-based QA evaluation.')
59 | parser.add_argument('-q', '--question')
60 | # parser.add_argument('-a', '--answer')
61 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
62 | parser.add_argument('-r', '--rule')
63 | parser.add_argument('-o', '--output')
64 | parser.add_argument('--max-tokens', type=int, default=1024,
65 | help='maximum number of tokens produced in the output')
66 | args = parser.parse_args()
67 |
68 | ray.init()
69 |
70 | f_q = open(os.path.expanduser(args.question))
71 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
72 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
73 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
74 |
75 | review_file = open(f'{args.output}', 'w')
76 |
77 | js_list = []
78 | handles = []
79 | idx = 0
80 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
81 | # if idx == 1:
82 | # break
83 |
84 | ques = json.loads(ques_js)
85 | ans1 = json.loads(ans1_js)
86 | ans2 = json.loads(ans2_js)
87 |
88 | category = json.loads(ques_js)['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | rule = rule_dict['default']
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Question]\n{ques["text"]}\n\n'
96 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
97 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
98 | f'[System]\n{prompt}\n\n')
99 | js_list.append({
100 | 'id': idx+1,
101 | 'question_id': ques['question_id'],
102 | 'answer1_id': ans1['answer_id'],
103 | 'answer2_id': ans2['answer_id'],
104 | 'category': category})
105 | idx += 1
106 | handles.append(get_eval.remote(content, args.max_tokens))
107 | # To avoid the rate limit set by OpenAI
108 | time.sleep(NUM_SECONDS_TO_SLEEP)
109 |
110 | reviews = ray.get(handles)
111 | for idx, review in enumerate(reviews):
112 | scores = parse_score(review)
113 | js_list[idx]['content'] = review
114 | js_list[idx]['tuple'] = scores
115 | review_file.write(json.dumps(js_list[idx]) + '\n')
116 | review_file.close()
117 |
--------------------------------------------------------------------------------
/share4v/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 |
6 | def eval_pope(answers, label_file):
7 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
8 |
9 | for answer in answers:
10 | text = answer['text']
11 |
12 | # Only keep the first sentence
13 | if text.find('.') != -1:
14 | text = text.split('.')[0]
15 |
16 | text = text.replace(',', '')
17 | words = text.split(' ')
18 | if 'No' in words or 'not' in words or 'no' in words:
19 | answer['text'] = 'no'
20 | else:
21 | answer['text'] = 'yes'
22 |
23 | for i in range(len(label_list)):
24 | if label_list[i] == 'no':
25 | label_list[i] = 0
26 | else:
27 | label_list[i] = 1
28 |
29 | pred_list = []
30 | for answer in answers:
31 | if answer['text'] == 'no':
32 | pred_list.append(0)
33 | else:
34 | pred_list.append(1)
35 |
36 | pos = 1
37 | neg = 0
38 | yes_ratio = pred_list.count(1) / len(pred_list)
39 |
40 | TP, TN, FP, FN = 0, 0, 0, 0
41 | for pred, label in zip(pred_list, label_list):
42 | if pred == pos and label == pos:
43 | TP += 1
44 | elif pred == pos and label == neg:
45 | FP += 1
46 | elif pred == neg and label == neg:
47 | TN += 1
48 | elif pred == neg and label == pos:
49 | FN += 1
50 |
51 | print('TP\tFP\tTN\tFN\t')
52 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
53 |
54 | precision = float(TP) / float(TP + FP)
55 | recall = float(TP) / float(TP + FN)
56 | f1 = 2*precision*recall / (precision + recall)
57 | acc = (TP + TN) / (TP + TN + FP + FN)
58 | print('Accuracy: {}'.format(acc))
59 | print('Precision: {}'.format(precision))
60 | print('Recall: {}'.format(recall))
61 | print('F1 score: {}'.format(f1))
62 | print('Yes ratio: {}'.format(yes_ratio))
63 | print('%.3f, %.3f, %.3f, %.3f, %.3f' %
64 | (f1, acc, precision, recall, yes_ratio))
65 |
66 |
67 | if __name__ == "__main__":
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument("--annotation-dir", type=str)
70 | parser.add_argument("--question-file", type=str)
71 | parser.add_argument("--result-file", type=str)
72 | args = parser.parse_args()
73 |
74 | questions = [json.loads(line) for line in open(args.question_file)]
75 | questions = {question['question_id']: question for question in questions}
76 | answers = [json.loads(q) for q in open(args.result_file)]
77 | for file in os.listdir(args.annotation_dir):
78 | assert file.startswith('coco_pope_')
79 | assert file.endswith('.json')
80 | category = file[10:-5]
81 | cur_answers = [
82 | x for x in answers if questions[x['question_id']]['category'] == category]
83 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
84 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
85 | print("====================================")
86 |
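
A tiny sanity check of the metric block above with toy predictions and labels:

    pred_list  = [1, 0, 1, 1, 0]    # 1 = "yes", 0 = "no"
    label_list = [1, 0, 0, 1, 1]
    TP = sum(p == 1 and l == 1 for p, l in zip(pred_list, label_list))   # 2
    FP = sum(p == 1 and l == 0 for p, l in zip(pred_list, label_list))   # 1
    TN = sum(p == 0 and l == 0 for p, l in zip(pred_list, label_list))   # 1
    FN = sum(p == 0 and l == 1 for p, l in zip(pred_list, label_list))   # 1
    precision, recall = TP / (TP + FP), TP / (TP + FN)                   # 0.667, 0.667
    f1 = 2 * precision * recall / (precision + recall)                   # 0.667
    acc = (TP + TN) / (TP + TN + FP + FN)                                # 0.6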
--------------------------------------------------------------------------------
/share4v/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
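
A small illustration of the answer-extraction regex used above:

    import re

    pattern = re.compile(r'The answer is ([A-Z]).')
    print(pattern.findall("The answer is B. Because the lecture says so."))   # ['B']
    print(pattern.findall("I am not sure."))                                  # [] -> treated as "FAILED"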
--------------------------------------------------------------------------------
/share4v/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from share4v.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
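
An illustration of the OCR-token branch of prompt_processor (toy prompt):

    import re

    prompt = "OCR tokens: abc, def Question: What brand is shown? Short answer:"
    match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
    print(match.group(1).lower())    # "what brand is shown?"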
--------------------------------------------------------------------------------
/share4v/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import shortuuid
6 | import torch
7 | from tqdm import tqdm
8 | from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
9 |
10 | from share4v.conversation import default_conversation
11 | from share4v.utils import disable_torch_init
12 |
13 |
14 | # new stopping implementation
15 | class KeywordsStoppingCriteria(StoppingCriteria):
16 | def __init__(self, keywords, tokenizer, input_ids):
17 | self.keywords = keywords
18 | self.tokenizer = tokenizer
19 | self.start_len = None
20 | self.input_ids = input_ids
21 |
22 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
23 | if self.start_len is None:
24 | self.start_len = self.input_ids.shape[1]
25 | else:
26 | outputs = self.tokenizer.batch_decode(
27 | output_ids[:, self.start_len:], skip_special_tokens=True)[0]
28 | for keyword in self.keywords:
29 | if keyword in outputs:
30 | return True
31 | return False
32 |
33 |
34 | @torch.inference_mode()
35 | def eval_model(model_name, questions_file, answers_file):
36 | # Model
37 | disable_torch_init()
38 | model_name = os.path.expanduser(model_name)
39 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
40 | model = AutoModelForCausalLM.from_pretrained(model_name,
41 | torch_dtype=torch.float16).cuda()
42 |
43 | ques_file = open(os.path.expanduser(questions_file), "r")
44 | ans_file = open(os.path.expanduser(answers_file), "w")
45 | for i, line in enumerate(tqdm(ques_file)):
46 | idx = json.loads(line)["question_id"]
47 | qs = json.loads(line)["text"]
48 | cat = json.loads(line)["category"]
49 | conv = default_conversation.copy()
50 | conv.append_message(conv.roles[0], qs)
51 | prompt = conv.get_prompt()
52 | inputs = tokenizer([prompt])
53 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
54 | stopping_criteria = KeywordsStoppingCriteria(
55 | [conv.sep], tokenizer, input_ids)
56 | output_ids = model.generate(
57 | input_ids,
58 | do_sample=True,
59 | use_cache=True,
60 | temperature=0.7,
61 | max_new_tokens=1024,
62 | stopping_criteria=[stopping_criteria])
63 | outputs = tokenizer.batch_decode(
64 | output_ids, skip_special_tokens=True)[0]
65 | try:
66 | index = outputs.index(conv.sep, len(prompt))
67 | except ValueError:
68 | outputs += conv.sep
69 | index = outputs.index(conv.sep, len(prompt))
70 |
71 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
72 | ans_id = shortuuid.uuid()
73 | ans_file.write(json.dumps({"question_id": idx,
74 | "text": outputs,
75 | "answer_id": ans_id,
76 | "model_id": model_name,
77 | "metadata": {}}) + "\n")
78 | ans_file.flush()
79 | ans_file.close()
80 |
81 |
82 | if __name__ == "__main__":
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
85 | parser.add_argument("--question-file", type=str,
86 | default="tables/question.jsonl")
87 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
88 | args = parser.parse_args()
89 |
90 | eval_model(args.model_name, args.question_file, args.answers_file)
91 |
--------------------------------------------------------------------------------
/share4v/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import concurrent.futures
5 | import json
6 | import os
7 | import time
8 |
9 | import openai
10 | import shortuuid
11 | import tqdm
12 |
13 | openai.api_key = 'xxx' # replace with your own key
14 | MODEL = 'gpt-3.5-turbo'
15 | MODEL_ID = 'gpt-3.5-turbo:20230327'
16 |
17 |
18 | def get_answer(question_id: int, question: str, max_tokens: int):
19 | ans = {
20 | 'answer_id': shortuuid.uuid(),
21 | 'question_id': question_id,
22 | 'model_id': MODEL_ID,
23 | }
24 | for _ in range(3):
25 | try:
26 | response = openai.ChatCompletion.create(
27 | model=MODEL,
28 | messages=[{
29 | 'role': 'system',
30 | 'content': 'You are a helpful assistant.'
31 | }, {
32 | 'role': 'user',
33 | 'content': question,
34 | }],
35 | temperature=0,
36 | max_tokens=max_tokens,
37 | )
38 | ans['text'] = response['choices'][0]['message']['content']
39 | return ans
40 | except Exception as e:
41 | print('[ERROR]', e)
42 | ans['text'] = '#ERROR#'
43 | time.sleep(1)
44 | return ans
45 |
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
49 | parser.add_argument('-q', '--question')
50 | parser.add_argument('-o', '--output')
51 | parser.add_argument('--max-tokens', type=int, default=1024,
52 | help='maximum number of tokens produced in the output')
53 | args = parser.parse_args()
54 |
55 | questions_dict = {}
56 | with open(os.path.expanduser(args.question)) as f:
57 | for line in f:
58 | if not line:
59 | continue
60 | q = json.loads(line)
61 | questions_dict[q['question_id']] = q['text']
62 |
63 | answers = []
64 |
65 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
66 | futures = []
67 | for qid, question in questions_dict.items():
68 | future = executor.submit(
69 | get_answer, qid, question, args.max_tokens)
70 | futures.append(future)
71 |
72 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
73 | answers.append(future.result())
74 |
75 | answers.sort(key=lambda x: x['question_id'])
76 |
77 | with open(os.path.expanduser(args.output), 'w') as f:
78 | table = [json.dumps(ans) for ans in answers]
79 | f.write('\n'.join(table))
80 |
--------------------------------------------------------------------------------
/share4v/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(
11 | description='ChatGPT-based QA evaluation.')
12 | parser.add_argument('-d', '--dir', default=None)
13 | parser.add_argument('-v', '--version', default=None)
14 | parser.add_argument('-s', '--select', nargs='*', default=None)
15 | parser.add_argument('-f', '--files', nargs='*', default=[])
16 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
17 | return parser.parse_args()
18 |
19 |
20 | if __name__ == '__main__':
21 | args = parse_args()
22 |
23 | if args.ignore is not None:
24 | args.ignore = [int(x) for x in args.ignore]
25 |
26 | if len(args.files) > 0:
27 | review_files = args.files
28 | else:
29 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith(
30 | 'gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
31 |
32 | for review_file in sorted(review_files):
33 | config = os.path.basename(review_file).replace(
34 | 'gpt4_text_', '').replace('.jsonl', '')
35 | if args.select is not None and any(x not in config for x in args.select):
36 | continue
37 | if '0613' in config:
38 | version = '0613'
39 | else:
40 | version = '0314'
41 | if args.version is not None and args.version != version:
42 | continue
43 | scores = defaultdict(list)
44 | print(config)
45 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
46 | for review_str in f:
47 | review = json.loads(review_str)
48 | if review['question_id'] in args.ignore:
49 | continue
50 | if 'category' in review:
51 | scores[review['category']].append(review['tuple'])
52 | scores['all'].append(review['tuple'])
53 | else:
54 | if 'tuple' in review:
55 | scores['all'].append(review['tuple'])
56 | else:
57 | scores['all'].append(review['score'])
58 | for k, v in sorted(scores.items()):
59 | stats = np.asarray(v).mean(0).tolist()
60 | stats = [round(x, 3) for x in stats]
61 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
62 | print(k, round(stats[1]/stats[0]*100, 1),
63 | round(stats[0] * 10, 1), round(stats[1] * 10, 1))
64 | print('=================================')
65 |
--------------------------------------------------------------------------------
/share4v/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.share4v_llama import (Share4VConfig,
2 | Share4VLlamaForCausalLM)
3 | from .multimodal_encoder.configuration_evaclip import EvaCLIPVisionConfig
4 | from .multimodal_encoder.modeling_evaclip import EvaCLIPVisionModel
5 |
--------------------------------------------------------------------------------
/share4v/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m share4v.model.consolidate --src ~/model_weights/share4v-7b --dst ~/model_weights/share4v-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 |
10 | from share4v.model import *
11 | from share4v.model.utils import auto_upgrade
12 |
13 |
14 | def consolidate_ckpt(src_path, dst_path):
15 | print("Loading model")
16 | auto_upgrade(src_path)
17 | src_model = AutoModelForCausalLM.from_pretrained(
18 | src_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
20 | src_model.save_pretrained(dst_path)
21 | src_tokenizer.save_pretrained(dst_path)
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--src", type=str, required=True)
27 | parser.add_argument("--dst", type=str, required=True)
28 |
29 | args = parser.parse_args()
30 |
31 | consolidate_ckpt(args.src, args.dst)
32 |
--------------------------------------------------------------------------------
/share4v/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .clip_encoder import CLIPVisionTower
4 |
5 |
6 | def build_vision_tower(vision_tower_cfg, **kwargs):
7 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
8 | is_absolute_path_exists = os.path.exists(vision_tower)
9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("Lin-Chen"):
10 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
11 |
12 | raise ValueError(f'Unknown vision tower: {vision_tower}')
13 |
--------------------------------------------------------------------------------
/share4v/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
4 |
5 | from .configuration_evaclip import EvaCLIPVisionConfig
6 | from .modeling_evaclip import EvaCLIPVisionModel
7 |
8 |
9 | class CLIPVisionTower(nn.Module):
10 | def __init__(self, vision_tower, args, delay_load=False):
11 | super().__init__()
12 |
13 | self.is_loaded = False
14 |
15 | self.vision_tower_name = vision_tower
16 | self.select_layer = args.mm_vision_select_layer
17 | self.select_feature = getattr(
18 | args, 'mm_vision_select_feature', 'patch')
19 |
20 | if not delay_load:
21 | self.load_model()
22 | else:
23 | self.cfg_only = CLIPVisionConfig.from_pretrained(
24 | self.vision_tower_name)
25 |
26 | def load_model(self):
27 | print(f'Load vision tower from {self.vision_tower_name}')
28 | self.image_processor = CLIPImageProcessor.from_pretrained(
29 | self.vision_tower_name)
30 | if 'eva' in self.vision_tower_name.lower():
31 | vision_cfg = EvaCLIPVisionConfig.from_pretrained(
32 | self.vision_tower_name)
33 | self.vision_tower = EvaCLIPVisionModel.from_pretrained(
34 | self.vision_tower_name, config=vision_cfg)
35 | else:
36 | self.vision_tower = CLIPVisionModel.from_pretrained(
37 | self.vision_tower_name)
38 | self.vision_tower.requires_grad_(False)
39 |
40 | self.is_loaded = True
41 |
42 | def feature_select(self, image_forward_outs):
43 | image_features = image_forward_outs.hidden_states[self.select_layer]
44 | if self.select_feature == 'patch':
45 | image_features = image_features[:, 1:]
46 | elif self.select_feature == 'cls_patch':
47 | image_features = image_features
48 | else:
49 | raise ValueError(
50 | f'Unexpected select feature: {self.select_feature}')
51 | return image_features
52 |
53 | # @torch.no_grad() comment to enable fine-tune vit
54 | def forward(self, images):
55 | if type(images) is list:
56 | image_features = []
57 | for image in images:
58 | image_forward_out = self.vision_tower(image.to(
59 | device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
60 | image_feature = self.feature_select(
61 | image_forward_out).to(image.dtype)
62 | image_features.append(image_feature)
63 | else:
64 | image_forward_outs = self.vision_tower(
65 | images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
66 | image_features = self.feature_select(
67 | image_forward_outs).to(images.dtype)
68 |
69 | return image_features
70 |
71 | @property
72 | def dummy_feature(self):
73 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
74 |
75 | @property
76 | def dtype(self):
77 | return self.vision_tower.dtype
78 |
79 | @property
80 | def device(self):
81 | return self.vision_tower.device
82 |
83 | @property
84 | def config(self):
85 | if self.is_loaded:
86 | return self.vision_tower.config
87 | else:
88 | return self.cfg_only
89 |
90 | @property
91 | def hidden_size(self):
92 | return self.config.hidden_size
93 |
94 | @property
95 | def num_patches(self):
96 | return (self.config.image_size // self.config.patch_size) ** 2
97 |
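
A rough usage sketch; the tower name and config fields are assumptions mirroring how builder.py constructs this class:

    from types import SimpleNamespace
    import torch
    from share4v.model.multimodal_encoder.clip_encoder import CLIPVisionTower

    args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
    tower = CLIPVisionTower("openai/clip-vit-large-patch14", args=args)   # downloads weights
    images = torch.randn(1, 3, 224, 224).to(device=tower.device, dtype=tower.dtype)
    feats = tower(images)     # (1, num_patches, hidden_size), CLS token dropped
    print(tower.num_patches, feats.shape)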
--------------------------------------------------------------------------------
/share4v/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 |
29 | def forward(self, x):
30 | x = self.pre_norm(x)
31 | return x + self.proj(x)
32 |
33 |
34 | def build_vision_projector(config, delay_load=False, **kwargs):
35 | projector_type = getattr(config, 'mm_projector_type', 'linear')
36 |
37 | if projector_type == 'linear':
38 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
39 |
40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41 | if mlp_gelu_match:
42 | mlp_depth = int(mlp_gelu_match.group(1))
43 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44 | for _ in range(1, mlp_depth):
45 | modules.append(nn.GELU())
46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47 | return nn.Sequential(*modules)
48 |
49 | if projector_type == 'identity':
50 | return IdentityMap()
51 |
52 | raise ValueError(f'Unknown projector type: {projector_type}')
53 |
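
A sketch of the projector naming convention handled above; the config object is a stand-in for the model's real HF config:

    from types import SimpleNamespace
    from share4v.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu',
                          mm_hidden_size=1024, hidden_size=4096)
    proj = build_vision_projector(cfg)
    # -> Sequential(Linear(1024, 4096), GELU(), Linear(4096, 4096))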
--------------------------------------------------------------------------------
/share4v/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'share4v' in config and 'share4v' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer ShareGPT4V code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "share4v")
15 | cfg.architectures[0] = 'Share4VLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/share4v/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from share4v.train.llama_flash_attn_monkey_patch import \
7 |     replace_llama_attn_with_flash_attn
8 | 
9 | replace_llama_attn_with_flash_attn()
10 | 
11 | # Import (and thereby load transformers) only after the patch is applied.
12 | from share4v.train.train import train
13 | if __name__ == "__main__":
14 | train()
15 |
--------------------------------------------------------------------------------
/share4v/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from share4v.train.llama_xformers_attn_monkey_patch import \
5 |     replace_llama_attn_with_xformers_attn
6 | 
7 | replace_llama_attn_with_xformers_attn()
8 | 
9 | # Import (and thereby load transformers) only after the patch is applied.
10 | from share4v.train.train import train
11 | if __name__ == "__main__":
12 | train()
13 |
--------------------------------------------------------------------------------
/share4v/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def disable_torch_init():
5 | """
6 | Disable the redundant torch default initialization to accelerate model creation.
7 | """
8 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
9 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
10 |
--------------------------------------------------------------------------------