├── DATASET_CARD.md ├── README.md ├── html_extraction ├── README.md ├── demo_magic-html_extraction.html ├── demo_origin_page.html ├── demo_trafilatura_extraction.html └── demo_url.txt ├── human_feedback_textfiltering ├── README.md ├── meta.py ├── mytraf.py ├── myxml.py ├── run.py ├── test.jsonl └── text_filter │ ├── __init__.py │ ├── filtering_utils.py │ ├── filters.py │ └── spam_word_less.txt ├── mllm_internvl ├── DATASET_CARD.md ├── README.md ├── __init__.py ├── classification_utils.py ├── coco_metric.py ├── eval_datasets.py ├── eval_models.py ├── evaluate.py ├── evaluate_with_slurm.sh ├── fill_vqa_testdev_results.py ├── ok_vqa_utils.py ├── requirements.txt ├── rices.py ├── test.sh ├── utils.py └── vqa_metric.py ├── mllm_llava ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── conversation_2.py │ ├── eval │ │ ├── eval_gpt_review.py │ │ ├── eval_gpt_review_bench.py │ │ ├── eval_gpt_review_visual.py │ │ ├── eval_mm-vet.py │ │ ├── eval_mmlu.py │ │ ├── eval_mmlu_ppl.py │ │ ├── eval_pope.py │ │ ├── eval_science_qa.py │ │ ├── eval_science_qa_gpt4.py │ │ ├── eval_science_qa_gpt4_requery.py │ │ ├── eval_textvqa.py │ │ ├── eval_vqa.py │ │ ├── generate_webpage_data_from_table.py │ │ ├── m4c_evaluator.py │ │ ├── model_captioning_loader_few_shot.py │ │ ├── model_qa.py │ │ ├── model_vqa.py │ │ ├── model_vqa_loader.py │ │ ├── model_vqa_loader_few_shot.py │ │ ├── model_vqa_mmbench.py │ │ ├── model_vqa_science.py │ │ ├── qa_baseline_gpt35.py │ │ ├── rices.py │ │ ├── run_llava.py │ │ ├── summarize_gpt_review.py │ │ ├── translation_husky.py │ │ ├── translation_tool.py │ │ └── webpage │ │ │ ├── figures │ │ │ ├── alpaca.png │ │ │ ├── bard.jpg │ │ │ ├── chatgpt.svg │ │ │ ├── llama.jpg │ │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ │ └── vicuna.jpeg │ │ │ ├── index.html │ │ │ ├── script.js │ │ │ └── styles.css │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model │ │ │ ├── __init__.py │ │ │ ├── internlm_chat │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_internlm2.py │ │ │ │ ├── modeling_internlm2.py │ │ │ │ └── tokenization_internlm2.py │ │ │ ├── llava_internlm.py │ │ │ ├── llava_llama.py │ │ │ ├── llava_mistral.py │ │ │ └── llava_mpt.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── multimodal_projector │ │ │ └── builder.py │ │ └── utils.py │ ├── serve │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── examples │ │ │ ├── extreme_ironing.jpg │ │ │ └── waterview.jpg │ │ ├── gradio_web_server.py │ │ ├── model_worker.py │ │ ├── register_worker.py │ │ ├── sglang_worker.py │ │ └── test_message.py │ ├── train │ │ ├── datasets.py │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llama_xformers_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ ├── train_interleaved.py │ │ ├── train_interleaved_mem.py │ │ ├── train_mem.py │ │ └── train_xformers.py │ └── utils.py ├── pyproject.toml └── scripts │ ├── convert_fewshot_sft.py │ └── vis_fewshot_sft.py └── mllm_openflamingo ├── .gitignore ├── LICENSE ├── README.md ├── _optim_utils.py ├── accelerate_configs ├── accelerate_config_ddp.yaml ├── accelerate_config_fsdp.yaml ├── accelerate_config_standalone.yaml ├── accelerate_config_zero1.yaml ├── accelerate_config_zero1_slurm.yaml ├── accelerate_config_zero2.yaml ├── accelerate_config_zero2_slurm.yaml ├── accelerate_config_zero3.yaml ├── 
accelerate_config_zero3_offload.yaml └── accelerate_config_zero3_slurm.yaml ├── environment.yml ├── open_flamingo ├── __init__.py ├── eval │ ├── README.md │ ├── __init__.py │ ├── classification_utils.py │ ├── coco_metric.py │ ├── eval_datasets.py │ ├── eval_model.py │ ├── evaluate.py │ ├── models │ │ ├── blip.py │ │ └── open_flamingo.py │ ├── ok_vqa_utils.py │ ├── rices.py │ ├── utils.py │ └── vqa_metric.py ├── scripts │ ├── cache_rices_features.py │ ├── convert_mmc4_to_wds.py │ └── fill_vqa_testdev_results.py ├── src │ ├── __init__.py │ ├── factory.py │ ├── flamingo.py │ ├── flamingo_lm.py │ ├── helpers.py │ └── utils.py └── train │ ├── README.md │ ├── __init__.py │ ├── data.py │ ├── data_utils.py │ ├── distributed.py │ ├── train.py │ ├── train_omnicorpus.py │ └── train_utils.py ├── requirements-eval.txt ├── requirements-training.txt ├── requirements.txt ├── setup.py └── setup_deepspeed_adamcpu.py /DATASET_CARD.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /html_extraction/README.md: 
1 | # HTML Extraction Tools in OmniCorpus: magic-html 
2 | 
3 | The main-body extraction toolkit of the OmniCorpus data engine has been merged into [magic-html](https://github.com/opendatalab/magic-html), which offers significant improvements over the commonly used [trafilatura](https://github.com/adbar/trafilatura). 
4 | 
5 | In terms of accuracy, we have addressed the issue where trafilatura would overlook the main content of an HTML document when extracting images, and enhanced its capability to handle Chinese, Japanese, and Arabic documents. Additionally, we have incorporated techniques to trim web noise regions based on HTML structure (such as clusters of lists and navigation bars) and style (targeting elements like advertisements, comments, JavaScript, and CSS). 
6 | 
7 | In terms of efficiency, we optimized the HTML-node-level processing and streamlined the pipeline by eliminating the fallback process for challenging cases. With these two improvements, we not only extract more informative main-body content but also double the speed of the extraction process. 
8 | 
9 | We present a demo comparing the [extraction results of trafilatura](./demo_trafilatura_extraction.html) with [ours](./demo_magic-html_extraction.html). (You can download these files and open them in a web browser.) 
10 | 
11 | 
12 | 
13 | ## Features of magic-html 
14 | 
15 | - Flexible output formats (HTML, or customizable TXT/Markdown). 
16 | - Supports extraction of both pure-text and multimodal corpora. 
17 | - Robust across various layouts (such as articles and forums). 
18 | - Supports LaTeX formula extraction and conversion. 
19 | 
20 | 
21 | 
22 | ## Installation and Usage 
23 | 
24 | Install from the pip wheel: 
25 | 
26 | ```shell 
27 | pip install https://github.com/opendatalab/magic-html/releases/download/magic_html-0.1.2-released/magic_html-0.1.2-py3-none-any.whl 
28 | ``` 
29 | 
30 | Extract the main body of a demo HTML document: 
31 | 
32 | ```python 
33 | from magic_html import GeneralExtractor 
34 | 
35 | # initialize the extractor 
36 | extractor = GeneralExtractor() 
37 | 
38 | url = "http://example.com/" 
39 | html = """ 
40 | <!doctype html> 
41 | <html> 
42 | <head> 
43 | 
44 |     <title>Example Domain</title> 
45 |     <meta charset="utf-8" /> 
46 |     <meta http-equiv="Content-type" content="text/html; charset=utf-8" /> 
47 |     <meta name="viewport" content="width=device-width, initial-scale=1" /> 
48 |     <style type="text/css"> 
49 |     </style> 
50 | </head> 
51 | 
52 | <body>
53 | <div> 
54 |     <h1>Example Domain</h1> 
55 |     <p>This domain is for use in illustrative examples in documents. You may use this 
    domain in literature without prior coordination or asking for permission.</p> 
56 |     <p><a href="https://www.iana.org/domains/example">More information...</a></p> 
57 | </div>
58 | </body> 
59 | </html> 
60 | """ 
61 | 
62 | # extract the main content of an article-style HTML page 
63 | data = extractor.extract(html, base_url=url) 
64 | 
65 | # extract the main content of a forum HTML page 
66 | # data = extractor.extract(html, base_url=url, html_type="forum") 
67 | 
68 | # extract the main content of a WeChat official-account HTML page 
69 | # data = extractor.extract(html, base_url=url, html_type="weixin") 
70 | 
71 | print(data) 
72 | ``` 
73 | 
74 | 
75 | 
76 | ## Others 
77 | 
78 | LICENSE: [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0.html) 
79 | 
80 | Acknowledgments: 
81 | 
82 | - [trafilatura](https://github.com/adbar/trafilatura) 
83 | - [readability-lxml](https://github.com/buriy/python-readability) 
84 | 
-------------------------------------------------------------------------------- 
/html_extraction/demo_trafilatura_extraction.html: 
-------------------------------------------------------------------------------- 
1 | 
2 |

Welcome to Round Table Gauteng

Did you know that the best young men’s organization in the world was founded in February 1927? That’s right, Round Table was created after the Duke of Windsor gave a speech urging young business and professional men to get together around a table and improve themselves and their communities.

Round Table is a club specifically designed for young men between 18 and 40 (45 in some countries) years old, regardless of their beliefs or political opinions. It’s an exclusive club, but anyone can join if they’re invited by a current member who will then become their sponsor.

Round Table is not just any ordinary club, it’s a voluntary non-political, non-denomination non-profit organization. It’s all about meeting new people, exchanging ideas, and developing friendships. With approximately 30,000 members worldwide, Round Table is an international organization that promotes fellowship and community service.

As a member of Round Table, you’ll have the opportunity to give back to your community by raising funds through various events and lending a hand wherever it’s needed. You’ll also have the chance to involve your family, allowing your kids to learn important values and grow alongside you.

Our motto is Adopt, Adapt, Improve, and we strive to foster growth in all areas of our lives. We aim to develop fellowship among young men through their professional and business occupations, while emphasizing that one’s calling offers an excellent medium of service to the community. We cultivate the highest ideals in business, professional, and civic traditions, while recognizing the worthiness of all legitimate occupations.

Round Table fosters responsible citizenship and loyalty to our country, while furthering the establishment of peace and goodwill in international relationships. We achieve these goals through meetings, lectures, discussions, and other activities.

Our logo, the Rondel, is a circular object that symbolizes equality. Each seat around the table is equal, and everyone has an equal voice. The original disc that inspired the logo is over 4 meters in diameter and was likely used as a table. The center of the logo features the Tudor’s rose, which was the emblem of the Tudors from 1450 on, and above it, a king with his scepter, possibly King Arthur.

So, if you’re a young man looking for a way to make a difference in your community while making new friends and developing yourself, consider joining Round Table. Adopt, Adapt, Improve – it’s not just our motto, it’s a way of life.

Round Table raises funds through either organizing various fund-raising events, or by offering time and energy to give a hand wherever it’s needed. As with any hobby, you certainly get out of it what you put in.

This personal commitment certainly also offers a wider experience: Round Table is used to involve families, where our kids can learn to spend time together, sharing our values and growing in similar aspects in life as the Tablers themselves do.

Aims and Objectives

These are the goals that we strive for together – to foster growth across all levels

Our Motto is Adopt, Adapt, Improve.

The Rondel

Where it comes from and what it means

The first Round Table (Norwich, England) adopted the proposal from Neville Headon from RT Manchester: a drawing inspired by a wooden disc which can still be admired in the Great Hall in Winchester (England).

In its centre, we can see the “Tudor’s rose” which was the emblems of the Tudors from 1450 on; just above we can see a king (which could be the King Arthur) with its sceptre.

As this original disc is more than 4 metres as diameter, it seems it was used as a Table. This is why the 2-colors-radius all around remembers that every seat around the Table is equal.

This logo was taken over by RT Britain and Ireland while chartering their national association in 1928. This general theme was then taken over by each new association. King Arthur and / or the central rose was replaced each time by a symbol that refers to the relevant national association.

View our Social media below to follow us.

3 | -------------------------------------------------------------------------------- /html_extraction/demo_url.txt: -------------------------------------------------------------------------------- 1 | https://roundtablegauteng.co.za/?page_number_0=4 2 | -------------------------------------------------------------------------------- /human_feedback_textfiltering/README.md: -------------------------------------------------------------------------------- 1 | # Human Feedback Textfiltering 2 | 3 | 4 | 5 | This repository contains 2,000 examples in test.jsonl. To perform text filtering on these examples, simply run the following command: 6 | 7 | ```shell 8 | python run.py 9 | ``` -------------------------------------------------------------------------------- /human_feedback_textfiltering/myxml.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from lxml.etree import Element, RelaxNG, SubElement, XMLParser, fromstring, tostring 3 | from html import unescape 4 | 5 | WITH_ATTRIBUTES = {'cell', 'del', 'graphic', 'head', 'hi', 'item', 'list', 'ref'} 6 | 7 | from trafilatura.xml import replace_element_text, NEWLINE_ELEMS, SPECIAL_FORMATTING, CONTROL_PARSER, sanitize, validate_tei 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | def clean_attributes(tree): 12 | '''Remove unnecessary attributes.''' 13 | for elem in tree.iter('*'): 14 | if elem.tag not in WITH_ATTRIBUTES: 15 | elem.attrib.clear() 16 | return tree 17 | 18 | def control_xml_output(output_tree, output_format, tei_validation, docmeta, pretty_print=True): 19 | '''Make sure the XML output is conform and valid if required''' 20 | control_string = sanitize(tostring(output_tree, encoding='unicode')) 21 | # necessary for cleaning 22 | output_tree = fromstring(control_string, CONTROL_PARSER) 23 | # validate 24 | if output_format == 'xmltei' and tei_validation is True: 25 | result = validate_tei(output_tree) 26 | LOGGER.debug('TEI validation result: %s %s %s', result, docmeta.id, docmeta.url) 27 | return tostring(output_tree, pretty_print=pretty_print, encoding='unicode').strip() 28 | 29 | def xmltotxt(xmloutput, include_formatting, include_images=True): 30 | '''Convert to plain text format and optionally preserve formatting as markdown.''' 31 | returnlist = [] 32 | # strip_tags(xmloutput, 'div', 'main', 'span') 33 | # iterate and convert to list of strings 34 | for element in xmloutput.iter('*'): 35 | if element.text is None and element.tail is None: 36 | if element.tag == 'graphic' and include_images: 37 | # add source, default to '' 38 | text = element.get('title', '') 39 | if element.get('alt') is not None: 40 | text += ' ' + element.get('alt') 41 | url = element.get('src', '') 42 | if not url: url = element.get('data-src', '') 43 | if url: returnlist.extend(['![', text, ']', '(', url, ')']) 44 | # newlines for textless elements 45 | if element.tag in ('graphic', 'row', 'table'): 46 | returnlist.append('\n') 47 | continue 48 | # process text 49 | textelement = replace_element_text(element, include_formatting) 50 | # common elements 51 | if element.tag in NEWLINE_ELEMS: 52 | returnlist.extend(['\n', textelement, '\n']) 53 | # particular cases 54 | elif element.tag == 'item': 55 | returnlist.extend(['\n- ', textelement, '\n']) 56 | elif element.tag == 'cell': 57 | returnlist.extend(['|', textelement, '|']) 58 | elif element.tag == 'comments': 59 | returnlist.append('\n\n') 60 | else: 61 | if element.tag not in SPECIAL_FORMATTING: 62 | LOGGER.debug('unprocessed element in output: 
%s', element.tag) 63 | returnlist.extend([textelement, ' ']) 64 | return unescape(sanitize(''.join(returnlist))) 65 | 66 | def build_xml_output(docmeta): 67 | '''Build XML output tree based on extracted information''' 68 | output = Element('doc') 69 | output = add_xml_meta(output, docmeta) 70 | docmeta.body.tag = 'main' 71 | # clean XML tree 72 | output.append(clean_attributes(docmeta.body)) 73 | if docmeta.commentsbody is not None: 74 | docmeta.commentsbody.tag = 'comments' 75 | output.append(clean_attributes(docmeta.commentsbody)) 76 | # XML invalid characters 77 | # https://chase-seibert.github.io/blog/2011/05/20/stripping-control-characters-in-python.html 78 | return output 79 | 80 | def add_xml_meta(output, docmeta): 81 | '''Add extracted metadata to the XML output tree''' 82 | # metadata 83 | if docmeta: 84 | if docmeta.sitename is not None: 85 | output.set('sitename', docmeta.sitename) 86 | if docmeta.title is not None: 87 | output.set('title', docmeta.title) 88 | if docmeta.author is not None: 89 | output.set('author', docmeta.author) 90 | if docmeta.date is not None: 91 | output.set('date', docmeta.date) 92 | if docmeta.url is not None: 93 | output.set('source', docmeta.url) 94 | if docmeta.hostname is not None: 95 | output.set('hostname', docmeta.hostname) 96 | if docmeta.description is not None: 97 | output.set('excerpt', docmeta.description) 98 | if docmeta.categories is not None: 99 | try: 100 | output.set('categories', ';'.join(docmeta.categories)) 101 | except: 102 | pass 103 | if docmeta.tags is not None: 104 | try: 105 | output.set('tags', ';'.join(docmeta.tags)) 106 | except: 107 | pass 108 | if docmeta.license is not None: 109 | output.set('license', docmeta.license) 110 | if docmeta.id is not None: 111 | output.set('id', docmeta.id) 112 | if docmeta.fingerprint is not None: 113 | output.set('fingerprint', docmeta.fingerprint) 114 | if docmeta.language is not None: 115 | output.set('language', docmeta.language) 116 | return output 117 | -------------------------------------------------------------------------------- /human_feedback_textfiltering/run.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | from lxml import etree 4 | from meta import xmltodoc 5 | from myxml import xmltotxt 6 | from mytraf import mydetermine_returnstring 7 | from text_filter.filters import (words_discard, 8 | char_discard, stop_word_discard, document_porn_discard, 9 | img_txt_ratio_discard, uppercase_discard, numerical_discard, 10 | social_media_counter_discard, one_word_discard, short_discard, 11 | porn_discard, comments_discard, header_footer_discard, 12 | newlines_discard, heads_discard, underlines_split, 13 | video_field_discard, 14 | readme_discard, fulltext_discard, url_discard, image_caption_discard, 15 | advertisement_discard, re_short_long_paragraphs, 16 | check_paragraph_lengths, 17 | tooshort_discard, aberrant_item_discard, cite_discard, 18 | social_media_discard, phonenum_author_time_discard, 19 | filter_download_links, filter_source_references) 20 | 21 | 22 | def filter_single_xml_en(xml_str): 23 | doc = xmltodoc(xml_str) 24 | xml = etree.fromstring(xml_str) 25 | _, num_images = mydetermine_returnstring( 26 | doc, output_format='txt', include_formatting=False, tei_validation=False) 27 | 28 | xml, _ = newlines_discard(xml) 29 | xml, _ = underlines_split(xml) 30 | xml, _ = video_field_discard(xml) 31 | xml, _ = fulltext_discard(xml) 32 | 33 | # Paragraph Level Filtering 34 | xml, _ = filter_download_links(xml) 35 | 
xml, _ = filter_source_references(xml) 36 | xml, _ = uppercase_discard(xml, uppercase_threshold=0.8, immutable=True) 37 | xml, _ = numerical_discard(xml, numerical_threshold=0.8, immutable=True) 38 | xml, _ = social_media_counter_discard(xml) 39 | xml, _ = one_word_discard(xml, immutable=True) 40 | xml, _ = short_discard(xml) 41 | xml, _ = porn_discard(xml) 42 | xml, _ = comments_discard(xml) 43 | xml, _ = header_footer_discard(xml) 44 | xml, _ = heads_discard(xml) 45 | xml, _ = readme_discard(xml) 46 | xml, _ = url_discard(xml) 47 | xml, _ = image_caption_discard(xml) 48 | xml, _ = advertisement_discard(xml) 49 | xml, _ = re_short_long_paragraphs(xml) 50 | xml, _ = tooshort_discard(xml) 51 | xml, _ = aberrant_item_discard(xml) 52 | xml, _ = cite_discard(xml) 53 | xml, _ = social_media_discard(xml) 54 | xml, _ = phonenum_author_time_discard(xml) 55 | 56 | pure_txt = xmltotxt( 57 | xml, include_formatting=False, include_images=False) 58 | if pure_txt is None or len(pure_txt) == 0: 59 | return None 60 | words_keep, words = words_discard(pure_txt, words_count_range=(50, 100000), avg_word_len_range=[3, 10], 61 | return_words=True) 62 | if not words_keep: 63 | return None 64 | if not char_discard(pure_txt, char_threshold=0.8, words=words): return None 65 | if not stop_word_discard(pure_txt, stop_word_threshold=2, words=words): return None 66 | if not document_porn_discard(pure_txt, thld=0.02): return None 67 | if not img_txt_ratio_discard(pure_txt, image_count=num_images, min_image_count=2, 68 | min_text_image_ratio=50): return None 69 | if not check_paragraph_lengths(xml): return None 70 | return etree.tostring(xml).decode() 71 | 72 | def read_jsonl_file(file_path): 73 | data = [] 74 | with open(file_path, 'r', encoding='utf8') as file: 75 | # print(file) 76 | for line in file: 77 | # print(line) 78 | data.append(json.loads(line)) 79 | return data 80 | 81 | def main_function(input_file=None, output_file=None): 82 | datas = read_jsonl_file(file_path=input_file) 83 | with open(output_file, 'w', encoding='utf-8') as f: 84 | for line_dict in datas: 85 | xml_txt = line_dict['content'] 86 | res = filter_single_xml_en(xml_txt) 87 | if res is not None: 88 | res = filter_single_xml_en(res) 89 | if res is not None: 90 | res = filter_single_xml_en(res) 91 | if res is None: 92 | continue 93 | line_dict['content'] = res 94 | f.write(json.dumps(line_dict, ensure_ascii=False) + '\n') 95 | 96 | 97 | if __name__ == '__main__': 98 | main_function(input_file="test.jsonl", output_file="output.jsonl") 99 | -------------------------------------------------------------------------------- /human_feedback_textfiltering/text_filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/human_feedback_textfiltering/text_filter/__init__.py -------------------------------------------------------------------------------- /human_feedback_textfiltering/text_filter/spam_word_less.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/human_feedback_textfiltering/text_filter/spam_word_less.txt -------------------------------------------------------------------------------- /mllm_internvl/DATASET_CARD.md: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /mllm_internvl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_internvl/__init__.py -------------------------------------------------------------------------------- /mllm_internvl/coco_metric.py: -------------------------------------------------------------------------------- 1 | from pycocoevalcap.eval import COCOEvalCap 2 | from pycocotools.coco import COCO 3 | 4 | 5 | def compute_cider( 6 | result_path, 7 | annotations_path, 8 | ): 9 | # create coco object and coco_result object 10 | coco = COCO(annotations_path) 11 | coco_result = coco.loadRes(result_path) 12 | 13 | # create coco_eval object by taking coco and coco_result 14 | coco_eval = COCOEvalCap(coco, coco_result) 15 | coco_eval.params["image_id"] = coco_result.getImgIds() 16 | coco_eval.evaluate() 17 | 18 | return coco_eval.eval 19 | 20 | 21 | def postprocess_captioning_generation(predictions): 22 | return predictions.split("Output", 1)[0] 23 | -------------------------------------------------------------------------------- /mllm_internvl/eval_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets import ImageFolder 7 | 8 | from classification_utils import IMAGENET_CLASSNAMES 9 | 10 | 11 | class CaptionDataset(Dataset): 12 | def __init__( 13 | self, 14 | image_train_dir_path, 15 | annotations_path, 16 | is_train, 17 | dataset_name, 18 | image_val_dir_path=None, 19 | ): 20 | self.image_train_dir_path = image_train_dir_path 21 | self.image_val_dir_path = image_val_dir_path 22 | self.annotations = [] 23 | self.is_train = is_train 24 | self.dataset_name = dataset_name 25 | 26 | full_annotations = json.load(open(annotations_path))["images"] 27 | 28 | for i in range(len(full_annotations)): 29 | if self.is_train and full_annotations[i]["split"] != "train": 30 | continue 31 | elif not self.is_train and full_annotations[i]["split"] != "test": 32 | continue 33 | 34 | self.annotations.append(full_annotations[i]) 35 | 36 | def __len__(self): 37 | return len(self.annotations) 38 | 39 | def __getitem__(self, idx): 40 | if self.dataset_name == "coco": 41 | image = Image.open( 42 | os.path.join( 43 | self.image_train_dir_path, self.annotations[idx]["filename"] 44 | ) 45 | if self.annotations[idx]["filepath"] == "train2014" 46 | else os.path.join( 47 | self.image_val_dir_path, self.annotations[idx]["filename"] 48 | ) 49 | ) 50 | elif self.dataset_name == "flickr": 51 | image = Image.open( 52 | os.path.join( 53 | self.image_train_dir_path, self.annotations[idx]["filename"] 54 | ) 55 | ) 56 | image.load() 57 | caption = self.annotations[idx]["sentences"][0]["raw"] 58 | return { 59 | "image": image, 60 | "caption": caption, 61 | "image_id": self.annotations[idx]["cocoid"] 62 | if self.dataset_name == "coco" 63 | else self.annotations[idx]["filename"].split(".")[0], 64 | } 65 | 66 | 67 | class VQADataset(Dataset): 68 | def __init__( 69 | self, image_dir_path, question_path, annotations_path, is_train, dataset_name 70 | ): 71 | self.questions = json.load(open(question_path, "r"))["questions"] 72 | if annotations_path is not None: 73 | self.answers = json.load(open(annotations_path, "r"))["annotations"] 74 | else: 75 | self.answers = None 76 | 
self.image_dir_path = image_dir_path 77 | self.is_train = is_train 78 | self.dataset_name = dataset_name 79 | if self.dataset_name in {"vqav2", "ok_vqa"}: 80 | self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] 81 | assert self.img_coco_split in {"train2014", "val2014", "test2015"} 82 | 83 | def __len__(self): 84 | return len(self.questions) 85 | 86 | def get_img_path(self, question): 87 | if self.dataset_name in {"vqav2", "ok_vqa"}: 88 | return os.path.join( 89 | self.image_dir_path, 90 | f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" 91 | if self.is_train 92 | else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg", 93 | ) 94 | elif self.dataset_name == "vizwiz": 95 | return os.path.join(self.image_dir_path, question["image_id"]) 96 | elif self.dataset_name == "textvqa": 97 | return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") 98 | else: 99 | raise Exception(f"Unknown VQA dataset {self.dataset_name}") 100 | 101 | def __getitem__(self, idx): 102 | question = self.questions[idx] 103 | img_path = self.get_img_path(question) 104 | image = Image.open(img_path) 105 | image.load() 106 | results = { 107 | "image": image, 108 | "question": question["question"], 109 | "question_id": question["question_id"], 110 | } 111 | if self.answers is not None: 112 | answers = self.answers[idx] 113 | results["answers"] = [a["answer"] for a in answers["answers"]] 114 | return results 115 | 116 | 117 | class ImageNetDataset(ImageFolder): 118 | """Class to represent the ImageNet1k dataset.""" 119 | 120 | def __init__(self, root, **kwargs): 121 | super().__init__(root=root, **kwargs) 122 | self.class_id_to_name = dict( 123 | zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES) 124 | ) 125 | 126 | def __getitem__(self, idx): 127 | sample, target = super().__getitem__(idx) 128 | target_label = self.class_id_to_name[target] 129 | return { 130 | "id": idx, 131 | "image": sample, 132 | "class_id": target, # numeric ID of the ImageNet class 133 | "class_name": target_label, # human-readable name of ImageNet class 134 | } 135 | 136 | 137 | class HatefulMemesDataset(Dataset): 138 | def __init__(self, image_dir_path, annotations_path): 139 | self.image_dir_path = image_dir_path 140 | with open(annotations_path, "r") as f: 141 | self.annotations = [json.loads(line) for line in f] 142 | 143 | def __len__(self): 144 | return len(self.annotations) 145 | 146 | def __getitem__(self, idx): 147 | annotation = self.annotations[idx] 148 | img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1]) 149 | image = Image.open(img_path) 150 | image.load() 151 | return { 152 | "id": annotation["id"], 153 | "image": image, 154 | "ocr": annotation["text"], 155 | "class_name": "yes" if annotation["label"] == 1 else "no", 156 | "class_id": annotation["label"], 157 | } 158 | -------------------------------------------------------------------------------- /mllm_internvl/evaluate_with_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 4 | export MASTER_PORT=12322 5 | echo $MASTER_ADDR 6 | echo $MASTER_PORT 7 | # SLURM_PROCID 8 | HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 9 | echo HOSTNAMES=$HOSTNAMES 10 | H=$(hostname) 11 | THEID=$(echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]") 12 | if [[ -z 
"${THEID// }" ]]; then 13 | THEID=0 14 | fi 15 | echo SLURM_PROCID=$THEID 16 | NNODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l) 17 | echo NNODES=$NNODES 18 | 19 | set -x 20 | 21 | torchrun --nnodes=$NNODES --nproc-per-node=8 \ 22 | --master_port ${MASTER_PORT} --master_addr ${MASTER_ADDR} --node_rank ${THEID} \ 23 | evaluate.py $@ 24 | -------------------------------------------------------------------------------- /mllm_internvl/fill_vqa_testdev_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission. 3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set. 4 | Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction. 5 | """ 6 | import json 7 | import sys 8 | import os 9 | 10 | sys.path.append( 11 | os.path.join( 12 | os.path.dirname(os.path.abspath(__file__)), 13 | "..", 14 | ) 15 | ) 16 | from vqa_metric import VQAEval 17 | 18 | postprocessor = VQAEval(None, None) 19 | 20 | 21 | def fill_vizwiz_test_json( 22 | input_path, 23 | output_path, 24 | vqa_test_questions_json_path, 25 | ): 26 | # read the input json and build a set with all question_ids 27 | with open(input_path, "r") as f: 28 | input_json = json.load(f) 29 | 30 | # postprocess answers 31 | question_id_to_answer = {} 32 | for q in input_json: 33 | resAns = q["answer"] 34 | resAns = resAns.replace("\n", " ") 35 | resAns = resAns.replace("\t", " ") 36 | resAns = resAns.strip() 37 | resAns = postprocessor.processPunctuation(resAns) 38 | resAns = postprocessor.processDigitArticle(resAns) 39 | question_id_to_answer[q["question_id"]] = resAns 40 | 41 | # read the vqa test json to get all the qustion_ids that need to be filled 42 | with open(vqa_test_questions_json_path, "r") as f: 43 | vqa_test_json = json.load(f) 44 | vqa_test_json = vqa_test_json["questions"] 45 | 46 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 47 | output_json = [] 48 | for q in vqa_test_json: 49 | output_json.append( 50 | { 51 | "image": q["image_id"], 52 | "answer": question_id_to_answer.get(q["question_id"], ""), 53 | } 54 | ) 55 | 56 | # write the json to the output path 57 | with open(output_path, "w") as f: 58 | json.dump(output_json, f) 59 | 60 | 61 | def fill_vqav2_test_json( 62 | input_path, 63 | output_path, 64 | vqa_test_questions_json_path, 65 | ): 66 | # read the input json and build a set with all question_ids 67 | with open(input_path, "r") as f: 68 | input_json = json.load(f) 69 | question_ids = set() 70 | for q in input_json: 71 | question_ids.add(q["question_id"]) 72 | 73 | # make a copy of the input json 74 | output_json = [] 75 | for q in input_json: 76 | resAns = q["answer"] 77 | resAns = resAns.replace("\n", " ") 78 | resAns = resAns.replace("\t", " ") 79 | resAns = resAns.strip() 80 | resAns = postprocessor.processPunctuation(resAns) 81 | resAns = postprocessor.processDigitArticle(resAns) 82 | q["answer"] = resAns 83 | output_json.append(q) 84 | 85 | # read the vqa test json to get all the qustion_ids that need to be filled 86 | with open(vqa_test_questions_json_path, "r") as f: 87 | vqa_test_json = json.load(f) 88 | vqa_test_json = vqa_test_json["questions"] 89 | 90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 91 | for q in 
vqa_test_json: 92 | if q["question_id"] not in question_ids: 93 | output_json.append( 94 | { 95 | "question_id": q["question_id"], 96 | "answer": "", 97 | } 98 | ) 99 | 100 | # write the json to the output path 101 | with open(output_path, "w") as f: 102 | json.dump(output_json, f) 103 | 104 | 105 | if __name__ == "__main__": 106 | import argparse 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--dataset", 111 | type=str, 112 | choices=["vqav2", "vizwiz"], 113 | ) 114 | parser.add_argument( 115 | "--input_path", 116 | type=str, 117 | help="Path to the json file with the subset of the vqa test-dev questions.", 118 | ) 119 | parser.add_argument( 120 | "--vqa_test_questions_json_path", 121 | type=str, 122 | help="Path to the json file with all the vqa test questions.", 123 | ) 124 | parser.add_argument( 125 | "--output_path", 126 | type=str, 127 | help="Path to store the filled json.", 128 | ) 129 | args = parser.parse_args() 130 | 131 | if args.dataset == "vqav2": 132 | fill_vqav2_test_json( 133 | args.input_path, 134 | args.output_path, 135 | args.vqa_test_questions_json_path, 136 | ) 137 | else: 138 | fill_vizwiz_test_json( 139 | args.input_path, 140 | args.output_path, 141 | args.vqa_test_questions_json_path, 142 | ) 143 | -------------------------------------------------------------------------------- /mllm_internvl/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | pycocoevalcap 3 | einops_exts 4 | opencv-python 5 | imageio 6 | decord 7 | nltk 8 | inflection 9 | termcolor 10 | yacs 11 | pyyaml 12 | scipy 13 | tqdm -------------------------------------------------------------------------------- /mllm_internvl/rices.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | from tqdm import tqdm 4 | import torch 5 | from utils import custom_collate_fn 6 | 7 | 8 | class RICES: 9 | def __init__( 10 | self, 11 | dataset, 12 | device, 13 | batch_size, 14 | vision_encoder_path="ViT-B-32", 15 | vision_encoder_pretrained="openai", 16 | cached_features=None, 17 | ): 18 | self.dataset = dataset 19 | self.device = device 20 | self.batch_size = batch_size 21 | 22 | # Load the model and processor 23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 24 | vision_encoder_path, 25 | pretrained=vision_encoder_pretrained, 26 | ) 27 | self.model = vision_encoder.to(self.device) 28 | self.image_processor = image_processor 29 | 30 | # Precompute features 31 | if cached_features is None: 32 | self.features = self._precompute_features() 33 | else: 34 | self.features = cached_features 35 | 36 | def _precompute_features(self): 37 | features = [] 38 | 39 | # Switch to evaluation mode 40 | self.model.eval() 41 | 42 | # Set up loader 43 | loader = torch.utils.data.DataLoader( 44 | self.dataset, 45 | batch_size=self.batch_size, 46 | collate_fn=custom_collate_fn, 47 | ) 48 | 49 | with torch.no_grad(): 50 | for batch in tqdm( 51 | loader, 52 | desc="Precomputing features for RICES", 53 | ): 54 | batch = batch["image"] 55 | inputs = torch.stack( 56 | [self.image_processor(image) for image in batch] 57 | ).to(self.device) 58 | image_features = self.model.encode_image(inputs) 59 | image_features /= image_features.norm(dim=-1, keepdim=True) 60 | features.append(image_features.detach()) 61 | 62 | features = torch.cat(features) 63 | return features 64 | 65 | def find(self, batch, num_examples): 66 | """ 67 | Get the top 
num_examples most similar examples to the images. 68 | """ 69 | # Switch to evaluation mode 70 | self.model.eval() 71 | 72 | with torch.no_grad(): 73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to( 74 | self.device 75 | ) 76 | 77 | # Get the feature of the input image 78 | query_feature = self.model.encode_image(inputs) 79 | query_feature /= query_feature.norm(dim=-1, keepdim=True) 80 | query_feature = query_feature.detach().cpu() 81 | 82 | if query_feature.ndim == 1: 83 | query_feature = query_feature.unsqueeze(0) 84 | 85 | # Compute the similarity of the input image to the precomputed features 86 | similarity = (query_feature @ self.features.T).squeeze() 87 | 88 | if similarity.ndim == 1: 89 | similarity = similarity.unsqueeze(0) 90 | 91 | # Get the indices of the 'num_examples' most similar images 92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples] 93 | 94 | # Return with the most similar images last 95 | return [[self.dataset[i] for i in reversed(row)] for row in indices] 96 | -------------------------------------------------------------------------------- /mllm_internvl/test.sh: -------------------------------------------------------------------------------- 1 | CKPT_ROOT=".." 2 | RESULT_ROOT="./results" 3 | CKPT_FNAMES=( 4 | "checkpoint-10200" 5 | ) 6 | mkdir -p $RESULT_ROOT 7 | 8 | for CKPT_FNAME in "${CKPT_FNAMES[@]}"; do 9 | set -x 10 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 11 | srun bash run_eval.sh --model internvl_chat \ 12 | --batch_size 1 --shots 0 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 13 | --load_in_8bit False --dynamic False --max_num 6 \ 14 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 15 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_0shots-multi-rounds.result" 16 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 17 | srun bash run_eval.sh --model internvl_chat \ 18 | --batch_size 1 --shots 1 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 19 | --load_in_8bit False --dynamic False --max_num 6 \ 20 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 21 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_1shots-multi-rounds.result" 22 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 23 | srun bash run_eval.sh --model internvl_chat \ 24 | --batch_size 1 --shots 2 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 25 | --load_in_8bit False --dynamic False --max_num 6 \ 26 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 27 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_2shots-multi-rounds.result" 28 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 29 | srun bash run_eval.sh --model internvl_chat \ 30 | --batch_size 1 --shots 4 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 31 | --load_in_8bit False --dynamic False --max_num 6 \ 32 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 33 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_4shots-multi-rounds.result" 34 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 35 | srun bash run_eval.sh --model internvl_chat \ 36 | --batch_size 1 --shots 8 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 37 | --load_in_8bit False --dynamic False --max_num 6 \ 38 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 
39 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_8shots-multi-rounds.result" 40 | PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat"$PYTHONPATH \ 41 | srun bash run_eval.sh --model internvl_chat \ 42 | --batch_size 1 --shots 0 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \ 43 | --load_in_8bit False --dynamic False --max_num 6 --zero-shot-add-text-shots 2 \ 44 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \ 45 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_trick0shot-multi-rounds.result" 46 | done 47 | -------------------------------------------------------------------------------- /mllm_llava/.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /mllm_llava/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.json 11 | *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | 16 | # Editor 17 | .idea 18 | *.swp 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | 25 | checkpoints 26 | ckpts* 27 | 28 | .ipynb_checkpoints 29 | *.ipynb 30 | 31 | # DevContainer 32 | !.devcontainer/* 33 | 34 | # Demo 35 | serve_images/ 36 | 37 | # data links 38 | playground/data/coco 39 | playground/data/data_path.txt 40 | playground/data/gqa 41 | playground/data/ocr_vqa 42 | playground/data/textvqa 43 | playground/data/vg 44 | playground/data/LLaVA-Pretrain/images 45 | playground/data/eval* 46 | 47 | # lqy custom 48 | runnings_*/*/* 49 | yt-sb-1b/* 50 | batchscript-* 51 | phoenix-slurm-* 52 | tmp* 53 | slurm_out/* 54 | slurm_out_*/* 55 | unit_test/mmc4_img_num/* 56 | cached_features_* 57 | unit_test/test_internlm2_tokenizer/* 58 | playground/data/lmms_eval_logs/* 59 | -------------------------------------------------------------------------------- /mllm_llava/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM, LlavaInternLM2ForCausalLM 2 | -------------------------------------------------------------------------------- /mllm_llava/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | 
time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | from tqdm import tqdm 9 | 10 | NUM_SECONDS_TO_SLEEP = 0.5 11 | 12 | 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4-0314', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | return response['choices'][0]['message']['content'] 36 | 37 | 38 | def parse_score(review): 39 | try: 40 | score_pair = review.split('\n')[0] 41 | score_pair = score_pair.replace(',', ' ') 42 | sp = score_pair.split(' ') 43 | if len(sp) == 2: 44 | return [float(sp[0]), float(sp[1])] 45 | else: 46 | print('error', review) 47 | return [-1, -1] 48 | except Exception as e: 49 | print(e) 50 | print('error', review) 51 | return [-1, -1] 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 56 | parser.add_argument('-q', '--question') 57 | parser.add_argument('-c', '--context') 58 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 59 | parser.add_argument('-r', '--rule') 60 | parser.add_argument('-o', '--output') 61 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 62 | args = parser.parse_args() 63 | 64 | f_q = open(os.path.expanduser(args.question)) 65 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 66 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 67 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 68 | 69 | if os.path.isfile(os.path.expanduser(args.output)): 70 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 71 | else: 72 | cur_reviews = [] 73 | 74 | review_file = open(f'{args.output}', 'a') 75 | 76 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 77 | image_to_context = {context['image']: context for context in context_list} 78 | 79 | handles = [] 80 | idx = 0 81 | for ques_js, ans1_js, ans2_js in tqdm(zip(f_q, f_ans1, f_ans2)): 82 | ques = json.loads(ques_js) 83 | ans1 = json.loads(ans1_js) 84 | ans2 = json.loads(ans2_js) 85 | 86 | inst = image_to_context[ques['image']] 87 | 88 | if isinstance(inst['caption'], list): 89 | cap_str = '\n'.join(inst['caption']) 90 | else: 91 | cap_str = inst['caption'] 92 | 93 | category = 'llava_bench_' + json.loads(ques_js)['category'] 94 | if category in rule_dict: 95 | rule = rule_dict[category] 96 | else: 97 | assert False, f"Visual QA category not found in rule file: {category}." 
98 | prompt = rule['prompt'] 99 | role = rule['role'] 100 | content = (f'[Context]\n{cap_str}\n\n' 101 | f'[Question]\n{ques["text"]}\n\n' 102 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 103 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 104 | f'[System]\n{prompt}\n\n') 105 | cur_js = { 106 | 'id': idx+1, 107 | 'question_id': ques['question_id'], 108 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 109 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 110 | 'category': category 111 | } 112 | if idx >= len(cur_reviews): 113 | review = get_eval(content, args.max_tokens) 114 | scores = parse_score(review) 115 | cur_js['content'] = review 116 | cur_js['tuple'] = scores 117 | review_file.write(json.dumps(cur_js) + '\n') 118 | review_file.flush() 119 | else: 120 | print(f'Skipping {idx} as we already have it.') 121 | idx += 1 122 | print(idx) 123 | review_file.close() 124 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, 
ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str, 
default="./playground/data/eval/pope/coco") 67 | parser.add_argument("--question-file", type=str, default="./playground/data/eval/pope/llava_pope_test.jsonl") 68 | parser.add_argument("--result-file", type=str, required=True) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '<image>' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = 
read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | @torch.inference_mode() 14 | def eval_model(model_name, questions_file, answers_file): 15 | # Model 16 | disable_torch_init() 17 | model_name = os.path.expanduser(model_name) 18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 19 | model = AutoModelForCausalLM.from_pretrained(model_name, 20 | torch_dtype=torch.float16).cuda() 21 | 22 | 23 | ques_file = open(os.path.expanduser(questions_file), "r") 24 | ans_file = open(os.path.expanduser(answers_file), "w") 25 | for i, line in enumerate(tqdm(ques_file)): 26 | idx = json.loads(line)["question_id"] 27 | qs = json.loads(line)["text"] 28 | cat = json.loads(line)["category"] 29 | conv = default_conversation.copy() 30 | conv.append_message(conv.roles[0], qs) 31 | prompt = conv.get_prompt() 32 | 
inputs = tokenizer([prompt]) 33 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 34 | output_ids = model.generate( 35 | input_ids, 36 | do_sample=True, 37 | use_cache=True, 38 | temperature=0.7, 39 | max_new_tokens=1024,) 40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 41 | try: 42 | index = outputs.index(conv.sep, len(prompt)) 43 | except ValueError: 44 | outputs += conv.sep 45 | index = outputs.index(conv.sep, len(prompt)) 46 | 47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 48 | ans_id = shortuuid.uuid() 49 | ans_file.write(json.dumps({"question_id": idx, 50 | "text": outputs, 51 | "answer_id": ans_id, 52 | "model_id": model_name, 53 | "metadata": {}}) + "\n") 54 | ans_file.flush() 55 | ans_file.close() 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 62 | args = parser.parse_args() 63 | 64 | eval_model(args.model_name, args.question_file, args.answers_file) 65 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def split_list_v2(lst, n): 25 | """Split a list into n (roughly) equal-sized chunks""" 26 | base_chunk_size = len(lst) // n # integer division 27 | remainder = len(lst) % n # remaining elements 28 | chunks = [] 29 | for i in range(n): 30 | chunk_size = base_chunk_size + (i < remainder) # add one to the chunk size for the first 'remainder' chunks 31 | start = i * base_chunk_size + min(i, remainder) # calculate the start index 32 | chunks.append(lst[start:start+chunk_size]) 33 | return chunks 34 | 35 | 36 | def get_chunk(lst, n, k): 37 | chunks = split_list_v2(lst, n) 38 | return chunks[k] 39 | 40 | 41 | def eval_model(args): 42 | # Model 43 | disable_torch_init() 44 | model_path = os.path.expanduser(args.model_path) 45 | model_name = get_model_name_from_path(model_path) 46 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 47 | 48 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 49 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 50 | answers_file = os.path.expanduser(args.answers_file) 51 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 52 | ans_file = open(answers_file, "w") 53 | for line in tqdm(questions): 54 | idx = line["question_id"] 55 | image_file = line["image"] 56 | qs = 
line["text"] 57 | cur_prompt = qs 58 | if model.config.mm_use_im_start_end: 59 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 60 | else: 61 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 62 | 63 | conv = conv_templates[args.conv_mode].copy() 64 | conv.append_message(conv.roles[0], qs) 65 | conv.append_message(conv.roles[1], None) 66 | prompt = conv.get_prompt() 67 | 68 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 69 | 70 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') 71 | image_tensor = process_images([image], image_processor, model.config)[0] 72 | 73 | with torch.inference_mode(): 74 | output_ids = model.generate( 75 | input_ids, 76 | images=image_tensor.unsqueeze(0).half().cuda(), 77 | image_sizes=[image.size], 78 | do_sample=True if args.temperature > 0 else False, 79 | temperature=args.temperature, 80 | top_p=args.top_p, 81 | num_beams=args.num_beams, 82 | # no_repeat_ngram_size=3, 83 | max_new_tokens=1024, 84 | use_cache=True) 85 | 86 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 87 | 88 | ans_id = shortuuid.uuid() 89 | ans_file.write(json.dumps({"question_id": idx, 90 | "prompt": cur_prompt, 91 | "text": outputs, 92 | "answer_id": ans_id, 93 | "model_id": model_name, 94 | "metadata": {}}) + "\n") 95 | ans_file.flush() 96 | ans_file.close() 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 101 | parser.add_argument("--model-base", type=str, default=None) 102 | parser.add_argument("--image-folder", type=str, default="") 103 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 104 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 105 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 106 | parser.add_argument("--num-chunks", type=int, default=1) 107 | parser.add_argument("--chunk-idx", type=int, default=0) 108 | parser.add_argument("--temperature", type=float, default=0.2) 109 | parser.add_argument("--top_p", type=float, default=None) 110 | parser.add_argument("--num_beams", type=int, default=1) 111 | args = parser.parse_args() 112 | 113 | if args.model_base is None and "::" in args.model_path: 114 | model_base_and_path = args.model_path 115 | args.model_base, args.model_path = model_base_and_path.split("::") 116 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})") 117 | 118 | eval_model(args) 119 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 
| 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | ) 19 | 20 | from PIL import Image 21 | 22 | import requests 23 | from PIL import Image 24 | from io import BytesIO 25 | import re 26 | 27 | 28 | def image_parser(args): 29 | out = args.image_file.split(args.sep) 30 | return out 31 | 32 | 33 | def load_image(image_file): 34 | if image_file.startswith("http") or image_file.startswith("https"): 35 | response = requests.get(image_file) 36 | image = Image.open(BytesIO(response.content)).convert("RGB") 37 | else: 38 | image = Image.open(image_file).convert("RGB") 39 | return image 40 | 41 | 42 | def load_images(image_files): 43 | out = [] 44 | for image_file in image_files: 45 | image = load_image(image_file) 46 | out.append(image) 47 | return out 48 | 49 | 50 | def eval_model(args): 51 | # Model 52 | disable_torch_init() 53 | 54 | model_name = get_model_name_from_path(args.model_path) 55 | tokenizer, model, image_processor, context_len = load_pretrained_model( 56 | args.model_path, args.model_base, model_name 57 | ) 58 | 59 | qs = args.query 60 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 61 | if IMAGE_PLACEHOLDER in qs: 62 | if model.config.mm_use_im_start_end: 63 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 64 | else: 65 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 66 | 
else: 67 | if model.config.mm_use_im_start_end: 68 | qs = image_token_se + "\n" + qs 69 | else: 70 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 71 | 72 | if "llama-2" in model_name.lower(): 73 | conv_mode = "llava_llama_2" 74 | elif "mistral" in model_name.lower(): 75 | conv_mode = "mistral_instruct" 76 | elif "v1.6-34b" in model_name.lower(): 77 | conv_mode = "chatml_direct" 78 | elif "v1" in model_name.lower(): 79 | conv_mode = "llava_v1" 80 | elif "mpt" in model_name.lower(): 81 | conv_mode = "mpt" 82 | else: 83 | conv_mode = "llava_v0" 84 | 85 | if args.conv_mode is not None and conv_mode != args.conv_mode: 86 | print( 87 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 88 | conv_mode, args.conv_mode, args.conv_mode 89 | ) 90 | ) 91 | else: 92 | args.conv_mode = conv_mode 93 | 94 | conv = conv_templates[args.conv_mode].copy() 95 | conv.append_message(conv.roles[0], qs) 96 | conv.append_message(conv.roles[1], None) 97 | prompt = conv.get_prompt() 98 | 99 | image_files = image_parser(args) 100 | images = load_images(image_files) 101 | image_sizes = [x.size for x in images] 102 | images_tensor = process_images( 103 | images, 104 | image_processor, 105 | model.config 106 | ).to(model.device, dtype=torch.float16) 107 | 108 | input_ids = ( 109 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 110 | .unsqueeze(0) 111 | .cuda() 112 | ) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | image_sizes=image_sizes, 119 | do_sample=True if args.temperature > 0 else False, 120 | temperature=args.temperature, 121 | top_p=args.top_p, 122 | num_beams=args.num_beams, 123 | max_new_tokens=args.max_new_tokens, 124 | use_cache=True, 125 | ) 126 | 127 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 128 | print(outputs) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 134 | parser.add_argument("--model-base", type=str, default=None) 135 | parser.add_argument("--image-file", type=str, required=True) 136 | parser.add_argument("--query", type=str, required=True) 137 | parser.add_argument("--conv-mode", type=str, default=None) 138 | parser.add_argument("--sep", type=str, default=",") 139 | parser.add_argument("--temperature", type=float, default=0.2) 140 | parser.add_argument("--top_p", type=float, default=None) 141 | parser.add_argument("--num_beams", type=int, default=1) 142 | parser.add_argument("--max_new_tokens", type=int, default=512) 143 | args = parser.parse_args() 144 | 145 | if args.model_base is None and "::" in args.model_path: 146 | model_base_and_path = args.model_path 147 | args.model_base, args.model_path = model_base_and_path.split("::") 148 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})") 149 | 150 | eval_model(args) 151 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', 
default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/translation_tool.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import hashlib 3 | import time 4 | import random 5 | import json 6 | from googletrans import Translator 7 | import requests 8 | from requests.adapters import HTTPAdapter 9 | from requests.packages.urllib3.util.retry import Retry 10 | 11 | 12 | class TranslationTool: 13 | def __init__(self): 14 | self.translator = Translator() 15 | 16 | def translate(self, text, src='en', dest='zh-cn'): 17 | """ 18 | Translate the given text. 19 | :param text: the text string to be translated. 20 | :param src: source language code, defaults to English 'en'. 21 | :param dest: target language code, defaults to Simplified Chinese 'zh-cn'. 22 | :return: the translated text string. 23 | """ 24 | translated = self.translator.translate(text, src=src, dest=dest) 25 | return translated.text 26 | 27 | 28 | class YoudaoTranslationTool: 29 | def __init__(self, app_id, app_secret): 30 | self.app_id = app_id 31 | self.app_secret = app_secret 32 | self.url = "https://openapi.youdao.com/api" 33 | def generate_sign(self, q, salt): 34 | sign_str = self.app_id + q + str(salt) + self.app_secret 35 | sign = hashlib.md5(sign_str.encode('utf-8')).hexdigest() 36 | return sign 37 | def translate(self, text, src='en', dest='zh-CHS'): 38 | salt = random.randint(1, 65536) 39 | sign = self.generate_sign(text, salt) 40 | params = { 41 | 'q': text, 42 | 'from': src, 43 | 'to': dest, 44 | 'appKey': self.app_id, 45 | 'salt': salt, 46 | 'sign': sign 47 | } 48 | try: 49 | response = 
requests.get(self.url, params=params) 50 | response.raise_for_status() 51 | result = response.json() 52 | if 'translation' in result: 53 | return result['translation'][0] 54 | else: 55 | return "Translation error: " + result.get('errorCode', 'Unknown error') 56 | except requests.exceptions.RequestException as e: 57 | return f"Network error occurred: {e}" 58 | except Exception as e: 59 | return f"An error occurred: {e}" 60 | 61 | # # Usage example 62 | # app_id = '553c1ee6d3f9f808' 63 | # app_secret = 'iOiNGdr2OtJrspEciUDOirk6wnKzcmP5' 64 | # translation_tool = YoudaoTranslationTool(app_id, app_secret) 65 | # result = translation_tool.translate("Provide a one-sentence caption for the provided image.") 66 | # print(result) 67 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /mllm_llava/llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 
40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | # NOTE: Solutions may be found at https://github.com/haotian-liu/LLaVA/issues/1101 3 | try: 4 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 5 | except: 6 | pass 7 | try: 8 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 9 | except: 10 | pass 11 | try: 12 | from llava.model.language_model.llava_internlm import LlavaInternLM2ForCausalLM, LlavaInternLM2Config 13 | except: 14 | pass 15 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in 
['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/model/language_model/__init__.py -------------------------------------------------------------------------------- /mllm_llava/llava/model/language_model/internlm_chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/model/language_model/internlm_chat/__init__.py -------------------------------------------------------------------------------- /mllm_llava/llava/model/language_model/llava_internlm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import AutoConfig, AutoModelForCausalLM 7 | 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from transformers.generation.utils import GenerateOutput 10 | 11 | from llava.model.language_model.internlm_chat.modeling_internlm2 import InternLM2ForCausalLM, InternLM2Model, InternLM2Config 12 | 13 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 14 | 15 | 16 | class LlavaInternLM2Config(InternLM2Config): 17 
| model_type = "llava_internlm2" 18 | 19 | 20 | class LlavaInternLM2Model(LlavaMetaModel, InternLM2Model): 21 | config_class = LlavaInternLM2Config 22 | 23 | def __init__(self, config: InternLM2Config): 24 | super(LlavaInternLM2Model, self).__init__(config) 25 | 26 | 27 | class LlavaInternLM2ForCausalLM(InternLM2ForCausalLM, LlavaMetaForCausalLM): 28 | config_class = LlavaInternLM2Config 29 | 30 | def __init__(self, config): 31 | super(InternLM2ForCausalLM, self).__init__(config) 32 | self.model = LlavaInternLM2Model(config) 33 | self.vocab_size = config.vocab_size 34 | self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 35 | 36 | # Initialize weights and apply final processing 37 | self.post_init() 38 | 39 | def get_model(self): 40 | return self.model 41 | 42 | def forward( 43 | self, 44 | input_ids: torch.LongTensor = None, 45 | attention_mask: Optional[torch.Tensor] = None, 46 | position_ids: Optional[torch.LongTensor] = None, 47 | past_key_values: Optional[List[torch.FloatTensor]] = None, 48 | inputs_embeds: Optional[torch.FloatTensor] = None, 49 | labels: Optional[torch.LongTensor] = None, 50 | use_cache: Optional[bool] = None, 51 | output_attentions: Optional[bool] = None, 52 | output_hidden_states: Optional[bool] = None, 53 | images: Optional[torch.FloatTensor] = None, 54 | image_sizes: Optional[List[List[int]]] = None, 55 | return_dict: Optional[bool] = None, 56 | ) -> Union[Tuple, CausalLMOutputWithPast]: 57 | 58 | if inputs_embeds is None: 59 | ( 60 | input_ids, 61 | position_ids, 62 | attention_mask, 63 | past_key_values, 64 | inputs_embeds, 65 | labels 66 | ) = self.prepare_inputs_labels_for_multimodal( 67 | input_ids, 68 | position_ids, 69 | attention_mask, 70 | past_key_values, 71 | labels, 72 | images, 73 | image_sizes 74 | ) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | @torch.no_grad() 90 | def generate( 91 | self, 92 | inputs: Optional[torch.Tensor] = None, 93 | images: Optional[torch.Tensor] = None, 94 | image_sizes: Optional[torch.Tensor] = None, 95 | **kwargs, 96 | ) -> Union[GenerateOutput, torch.LongTensor]: 97 | position_ids = kwargs.pop("position_ids", None) 98 | attention_mask = kwargs.pop("attention_mask", None) 99 | if "inputs_embeds" in kwargs: 100 | raise NotImplementedError("`inputs_embeds` is not supported") 101 | 102 | if images is not None: 103 | ( 104 | inputs, 105 | position_ids, 106 | attention_mask, 107 | _, 108 | inputs_embeds, 109 | _ 110 | ) = self.prepare_inputs_labels_for_multimodal( 111 | inputs, 112 | position_ids, 113 | attention_mask, 114 | None, 115 | None, 116 | images, 117 | image_sizes=image_sizes 118 | ) 119 | else: 120 | inputs_embeds = self.get_model().embed_tokens(inputs) 121 | 122 | return super().generate( 123 | position_ids=position_ids, 124 | attention_mask=attention_mask, 125 | inputs_embeds=inputs_embeds, 126 | **kwargs 127 | ) 128 | 129 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 130 | inputs_embeds=None, **kwargs): 131 | images = kwargs.pop("images", None) 132 | image_sizes = kwargs.pop("image_sizes", None) 133 | inputs = super().prepare_inputs_for_generation( 134 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, 
**kwargs 135 | ) 136 | if images is not None: 137 | inputs['images'] = images 138 | if image_sizes is not None: 139 | inputs['image_sizes'] = image_sizes 140 | return inputs 141 | 142 | AutoConfig.register("llava_internlm2", LlavaInternLM2Config) 143 | AutoModelForCausalLM.register(LlavaInternLM2Config, LlavaInternLM2ForCausalLM) 144 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM 22 | 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | from transformers.generation.utils import GenerateOutput 25 | 26 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | 28 | 29 | class LlavaConfig(LlamaConfig): 30 | model_type = "llava_llama" 31 | 32 | 33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 34 | config_class = LlavaConfig 35 | 36 | def __init__(self, config: LlamaConfig): 37 | super(LlavaLlamaModel, self).__init__(config) 38 | 39 | 40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaConfig 42 | 43 | def __init__(self, config): 44 | super(LlamaForCausalLM, self).__init__(config) 45 | self.model = LlavaLlamaModel(config) 46 | self.pretraining_tp = config.pretraining_tp 47 | self.vocab_size = config.vocab_size 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | ) -> Union[Tuple, CausalLMOutputWithPast]: 71 | 72 | if inputs_embeds is None: 73 | ( 74 | input_ids, 75 | position_ids, 76 | attention_mask, 77 | past_key_values, 78 | inputs_embeds, 79 | labels 80 | ) = self.prepare_inputs_labels_for_multimodal( 81 | input_ids, 82 | position_ids, 83 | attention_mask, 84 | past_key_values, 85 | labels, 86 | images, 87 | image_sizes 88 | ) 89 | 90 | return super().forward( 91 | input_ids=input_ids, 92 | 
attention_mask=attention_mask, 93 | position_ids=position_ids, 94 | past_key_values=past_key_values, 95 | inputs_embeds=inputs_embeds, 96 | labels=labels, 97 | use_cache=use_cache, 98 | output_attentions=output_attentions, 99 | output_hidden_states=output_hidden_states, 100 | return_dict=return_dict 101 | ) 102 | 103 | @torch.no_grad() 104 | def generate( 105 | self, 106 | inputs: Optional[torch.Tensor] = None, 107 | images: Optional[torch.Tensor] = None, 108 | image_sizes: Optional[torch.Tensor] = None, 109 | **kwargs, 110 | ) -> Union[GenerateOutput, torch.LongTensor]: 111 | position_ids = kwargs.pop("position_ids", None) 112 | attention_mask = kwargs.pop("attention_mask", None) 113 | if "inputs_embeds" in kwargs: 114 | raise NotImplementedError("`inputs_embeds` is not supported") 115 | 116 | if images is not None: 117 | ( 118 | inputs, 119 | position_ids, 120 | attention_mask, 121 | _, 122 | inputs_embeds, 123 | _ 124 | ) = self.prepare_inputs_labels_for_multimodal( 125 | inputs, 126 | position_ids, 127 | attention_mask, 128 | None, 129 | None, 130 | images, 131 | image_sizes=image_sizes 132 | ) 133 | else: 134 | input_ids = kwargs.pop("input_ids", inputs) 135 | inputs_embeds = self.get_model().embed_tokens(input_ids) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_llama", LlavaConfig) 158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 159 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 
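# Vision-tower factory: build_vision_tower() below reads `mm_vision_tower` (falling back to
# `vision_tower`) from the model config and instantiates a CLIPVisionTower for local checkpoint
# paths as well as OpenAI / LAION / ShareGPT4V hub names; anything else raises a ValueError.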
1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 31 | self.vision_tower.requires_grad_(False) 32 | 33 | self.is_loaded = True 34 | 35 | def feature_select(self, image_forward_outs): 36 | image_features = image_forward_outs.hidden_states[self.select_layer] 37 | if self.select_feature == 'patch': 38 | image_features = image_features[:, 1:] 39 | elif self.select_feature == 'cls_patch': 40 | image_features = image_features 41 | else: 42 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 43 | return image_features 44 | 45 | @torch.no_grad() 46 | def forward(self, images): 47 | if type(images) is list: 48 | image_features = [] 49 | for image in images: 50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 51 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 52 | image_features.append(image_feature) 53 | else: 54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 55 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 56 | 57 | return image_features 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.device 70 | 71 | @property 72 | def config(self): 73 | if self.is_loaded: 74 | return self.vision_tower.config 75 | else: 76 | return self.cfg_only 77 | 78 | @property 79 | def hidden_size(self): 80 | return self.config.hidden_size 81 | 82 | @property 83 | def 
num_patches_per_side(self): 84 | return self.config.image_size // self.config.patch_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | pixel_shuffle_mlp_gelu_match = re.match(r'^pixel_shuffle_ln_mlp(\d+)x_gelu$', projector_type) 49 | if pixel_shuffle_mlp_gelu_match: 50 | pixel_shuffle_ratio = getattr(config, 'pixel_shuffle_ratio', None) 51 | scale_factor = int(1 / pixel_shuffle_ratio ** 2) 52 | mlp_depth = int(pixel_shuffle_mlp_gelu_match.group(1)) 53 | modules = [ 54 | nn.LayerNorm(config.mm_hidden_size * scale_factor), 55 | nn.Linear(config.mm_hidden_size * scale_factor, config.hidden_size), 56 | ] 57 | for _ in range(1, mlp_depth): 58 | modules.append(nn.GELU()) 59 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == 'identity': 63 | return IdentityMap() 64 | 65 | raise ValueError(f'Unknown projector type: {projector_type}') 66 | -------------------------------------------------------------------------------- /mllm_llava/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava_llama") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /mllm_llava/llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/__init__.py -------------------------------------------------------------------------------- /mllm_llava/llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "mistral" in model_name.lower(): 37 | conv_mode = "mistral_instruct" 38 | elif "v1.6-34b" in model_name.lower(): 39 | conv_mode = "chatml_direct" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | image = load_image(args.image_file) 59 | image_size = image.size 60 | # Similar operation in model_worker.py 61 | image_tensor = process_images([image], image_processor, model.config) 62 | if type(image_tensor) is list: 63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 64 | else: 65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 66 | 67 | while True: 68 | try: 69 | inp = input(f"{roles[0]}: ") 70 | except EOFError: 71 | inp = "" 72 | if not inp: 73 | print("exit...") 74 | break 75 | 76 | print(f"{roles[1]}: ", end="") 77 | 78 | if image is not 
None: 79 | # first message 80 | if model.config.mm_use_im_start_end: 81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 82 | else: 83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 84 | conv.append_message(conv.roles[0], inp) 85 | image = None 86 | else: 87 | # later messages 88 | conv.append_message(conv.roles[0], inp) 89 | conv.append_message(conv.roles[1], None) 90 | prompt = conv.get_prompt() 91 | 92 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 93 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 94 | keywords = [stop_str] 95 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=image_tensor, 101 | image_sizes=[image_size], 102 | do_sample=True if args.temperature > 0 else False, 103 | temperature=args.temperature, 104 | max_new_tokens=args.max_new_tokens, 105 | streamer=streamer, 106 | use_cache=True) 107 | 108 | outputs = tokenizer.decode(output_ids[0]).strip() 109 | conv.messages[-1][-1] = outputs 110 | 111 | if args.debug: 112 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 118 | parser.add_argument("--model-base", type=str, default=None) 119 | parser.add_argument("--image-file", type=str, required=True) 120 | parser.add_argument("--device", type=str, default="cuda") 121 | parser.add_argument("--conv-mode", type=str, default=None) 122 | parser.add_argument("--temperature", type=float, default=0.2) 123 | parser.add_argument("--max-new-tokens", type=int, default=512) 124 | parser.add_argument("--load-8bit", action="store_true") 125 | parser.add_argument("--load-4bit", action="store_true") 126 | parser.add_argument("--debug", action="store_true") 127 | args = parser.parse_args() 128 | 129 | if args.model_base is None and "::" in args.model_path: 130 | model_base_and_path = args.model_path 131 | args.model_base, args.model_path = model_base_and_path.split("::") 132 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})") 133 | 134 | main(args) 135 | -------------------------------------------------------------------------------- /mllm_llava/llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /mllm_llava/llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /mllm_llava/llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /mllm_llava/llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /mllm_llava/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from 
flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def 
replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /mllm_llava/llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /mllm_llava/llava/train/train_interleaved_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train_interleaved import train 2 | 3 | if __name__ == "__main__": 4 | # # OSError: IOError: broken data stream when reading image file 5 | # from PIL import Image 6 | # from PIL import ImageFile 7 | # ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | train(attn_implementation="flash_attention_2") 10 | -------------------------------------------------------------------------------- /mllm_llava/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | -------------------------------------------------------------------------------- /mllm_llava/llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /mllm_llava/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /mllm_llava/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.2.2.post1" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2+cu118", "torchvision==0.16.2+cu118", 17 | "transformers==4.36.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.30.1", "peft", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.16.0", "gradio_client==0.8.1", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 23 | "deepspeed==0.12.6" 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 28 | build = ["build", "twine"] 29 | 30 | [project.urls] 31 | "Homepage" = "https://llava-vl.github.io" 32 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 33 | 34 | [tool.setuptools.packages.find] 35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 36 | 37 | [tool.wheel] 38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 39 | -------------------------------------------------------------------------------- /mllm_openflamingo/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.json 3 | 4 | wandb/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Pycharm project settings 119 | .idea 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | *.out 137 | src/wandb 138 | wandb 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # Cache 144 | cache/ 145 | 146 | __*.sh 147 | 148 | 149 | # added by lqy 150 | tmp/* 151 | scripts_sh/* 152 | *.err 153 | .deepspeed_env 154 | */*/results_* 155 | -------------------------------------------------------------------------------- /mllm_openflamingo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | SOFTWARE. 
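The `accelerate_configs/` files that follow are launcher configurations for DDP, FSDP, and DeepSpeed ZeRO stages 1-3, including CPU-offload and Slurm variants. A config is selected at launch time via `accelerate launch --config_file ...`; the sketch below is only illustrative, and the script arguments are placeholders: the actual training entry points and their flags live under `open_flamingo/train/`.

```bash
# Illustrative only: hand the ZeRO-2 config to the training script.
# Consult open_flamingo/train/train.py for the real arguments.
accelerate launch \
    --config_file accelerate_configs/accelerate_config_zero2.yaml \
    open_flamingo/train/train.py \
    --run_name my_run
```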
21 | -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | downcast_bf16: false 4 | machine_rank: 0 5 | main_training_function: main 6 | mixed_precision: bf16 7 | num_machines: 1 8 | num_processes: 2 9 | rdzv_backend: static 10 | same_network: false 11 | tpu_use_cluster: false 12 | tpu_use_sudo: false 13 | use_cpu: false 14 | main_process_port: 20685 15 | -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_fsdp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: no 3 | downcast_bf16: true 4 | machine_rank: 0 5 | main_training_function: main 6 | mixed_precision: bf16 7 | num_machines: 1 8 | num_processes: 1 9 | rdzv_backend: static 10 | same_network: true 11 | tpu_use_cluster: false 12 | tpu_use_sudo: false 13 | use_cpu: false 14 | main_process_port: 20687 15 | -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_standalone.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 1 11 | same_network: true 12 | -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 1 9 | distributed_type: DEEPSPEED 10 | fsdp_config: {} 11 | main_training_function: main 12 | mixed_precision: bf16 13 | use_cpu: false 14 | num_machines: 1 15 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero1_slurm.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: false 9 | zero_stage: 1 10 | distributed_type: DEEPSPEED 11 | fsdp_config: {} 12 | main_training_function: main 13 | mixed_precision: bf16 14 | use_cpu: false 15 | num_machines: 1 16 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 2 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | 
distributed_type: DEEPSPEED 11 | fsdp_config: {} 12 | main_training_function: main 13 | mixed_precision: bf16 14 | use_cpu: false 15 | num_machines: 1 16 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero2_slurm.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | fsdp_config: {} 12 | main_training_function: main 13 | mixed_precision: bf16 14 | use_cpu: false 15 | num_machines: 2 16 | num_processes: 16 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | fsdp_config: {} 12 | main_training_function: main 13 | mixed_precision: bf16 14 | use_cpu: false 15 | num_machines: 1 16 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero3_offload.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | fsdp_config: {} 12 | main_training_function: main 13 | mixed_precision: bf16 14 | use_cpu: false 15 | num_machines: 1 16 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/accelerate_configs/accelerate_config_zero3_slurm.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 2 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | fsdp_config: {} 13 | main_training_function: main 14 | mixed_precision: bf16 15 | use_cpu: false 16 | num_machines: 1 17 | num_processes: 8 -------------------------------------------------------------------------------- /mllm_openflamingo/environment.yml: -------------------------------------------------------------------------------- 1 | name: lmm_baseline_openflamingo 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - conda-forge::openjdk 7 | - pip 8 | - pip: 9 | - -r requirements.txt 10 | - -r requirements-training.txt 11 | - -r requirements-eval.txt 12 | - -e . 
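The `environment.yml` above pins Python 3.9, adds OpenJDK from conda-forge, and installs the three requirements files plus an editable install of the package. Assuming conda is available, a typical setup would be:

```bash
# Create the environment defined in environment.yml and activate it
# (the name comes from the file's `name:` field).
conda env create -f environment.yml
conda activate lmm_baseline_openflamingo
```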
13 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/__init__.py: -------------------------------------------------------------------------------- 1 | from .src.flamingo import Flamingo 2 | from .src.factory import create_model_and_transforms 3 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/README.md: -------------------------------------------------------------------------------- 1 | # OpenFlamingo Evaluation Suite 2 | 3 | This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets. 4 | 5 | *This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).* 6 | 7 | ## Supported datasets 8 | 9 | |Dataset|Task|Metric|Evaluation method| 10 | |-------|----|------|-----------------| 11 | |[COCO](https://arxiv.org/abs/1405.0312)|Captioning|CIDEr|Generation| 12 | |[Flickr-30K](https://aclanthology.org/Q14-1006/)|Captioning|CIDEr|Generation| 13 | |[VQAv2](https://arxiv.org/abs/1612.00837v3)|VQA|VQA accuracy|Generation| 14 | |[OK-VQA](https://arxiv.org/abs/1906.00067)|VQA|VQA accuracy|Generation| 15 | |[TextVQA](https://arxiv.org/abs/1904.08920)|VQA|VQA accuracy|Generation| 16 | |[VizWiz](https://arxiv.org/abs/1802.08218)|VQA|VQA accuracy|Generation| 17 | |[Hateful Memes](https://arxiv.org/abs/2005.04790)|Classification|ROC AUC|Logprobs| 18 | |[ImageNet](https://arxiv.org/abs/1409.0575)|Classification|Top-1 accuracy|Logprobs| 19 | 20 | When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`). 21 | 22 | ## Sample scripts 23 | Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`. 24 | 25 | We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16. 26 | 27 | To evaluate one of our pretrained checkpoints, we suggest first downloading a local copy of the weights, as follows: 28 | 29 | ``` 30 | # grab model checkpoint from huggingface hub 31 | from huggingface_hub import hf_hub_download 32 | HF_TOKEN="" 33 | 34 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") 35 | checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", 36 | "checkpoint.pt", 37 | local_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 38 | cache_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 39 | local_dir_use_symlinks=False, 40 | token=HF_TOKEN) 41 | print(checkpoint_path) 42 | ## openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt 43 | ``` 44 | 45 | This should place the OpenFlamingo model at the expected location in the evaluation script. 46 | 47 | For TextVQA and VizWiz we expect annotations to be formatted differently than the original datasets. We provide the custom annotations in `open_flamingo/open_flamingo/eval/data/`. 
We have also uploaded all the annotation files in a [huggingface dataset](https://huggingface.co/datasets/openflamingo/eval_benchmark/tree/main) for easy access. 48 | 49 | # Evaluating using RICES (Retrieval-based In-Context Example Selection) 50 | 51 | We provide the option to evaluate using RICES, which is a method for selecting exemplars from the training set based on image similarity. This method was used in DeepMind's implementation for evaluating on ImageNet, but can be used for any dataset in our evaluation suite. 52 | 53 | To use RICES, you must first create features for a benchmark's training set. We provide a script for doing so in `open_flamingo/open_flamingo/scripts/cache_rices_features.py`. This script will extract image features for a given dataset using a given CLIP model checkpoint. For example, to extract features for the COCO training set, you can run: 54 | 55 | ```bash 56 | python cache_rices_features.py \ 57 | --vision_encoder_path ViT-L-14 \ 58 | --vision_encoder_pretrained openai \ 59 | --batch_size 128 \ 60 | --eval_coco \ 61 | --coco_train_image_dir_path /path/to/coco/train2014 \ 62 | --coco_val_image_dir_path /path/to/coco/val2014 \ 63 | --coco_karpathy_json_path /path/to/coco/dataset_coco.json \ 64 | --coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json \ 65 | --output_dir /path/to/coco/features 66 | ``` 67 | 68 | This will create a directory at `/path/to/coco/features` containing a file named `coco.pkl` with the extracted features. You can then use this directory to evaluate using RICES by passing the `--rices` flag to the evaluation script, specifying the path to the features directory using the `--cached_demonstration_features` flag, and specifying the vision encoder to use for RICES using the `--rices_vision_encoder_path` and `--rices_vision_encoder_pretrained` flags. 69 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/coco_metric.py: -------------------------------------------------------------------------------- 1 | from pycocoevalcap.eval import COCOEvalCap 2 | from pycocotools.coco import COCO 3 | 4 | 5 | def compute_cider( 6 | result_path, 7 | annotations_path, 8 | ): 9 | # create coco object and coco_result object 10 | coco = COCO(annotations_path) 11 | coco_result = coco.loadRes(result_path) 12 | 13 | # create coco_eval object by taking coco and coco_result 14 | coco_eval = COCOEvalCap(coco, coco_result) 15 | coco_eval.params["image_id"] = coco_result.getImgIds() 16 | coco_eval.evaluate() 17 | 18 | return coco_eval.eval 19 | 20 | 21 | def postprocess_captioning_generation(predictions): 22 | return predictions.split("Output", 1)[0] 23 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/eval_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import argparse 3 | from typing import List 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from PIL import Image 6 | 7 | 8 | class BaseEvalModel(abc.ABC): 9 | """Base class encapsulating functionality needed to evaluate a model.""" 10 | 11 | def __init__(self, args: List[str]): 12 | """Initialize model. 13 | 14 | Args: 15 | args: arguments to model. 
These should be parsed, or if the model 16 | has no applicable arguments, an error should be thrown if `args` 17 | is non-empty. 18 | """ 19 | 20 | def init_distributed(self): 21 | """Wrap model as DDP.""" 22 | self.model = DDP(self.model, device_ids=[self.device]) 23 | 24 | def set_device(self, device): 25 | """Set device for model.""" 26 | self.device = device 27 | self.model = self.model.to(device) 28 | 29 | def get_outputs( 30 | self, 31 | batch_text: List[str], 32 | batch_images: List[List[Image.Image]], 33 | min_generation_length: int, 34 | max_generation_length: int, 35 | num_beams: int, 36 | length_penalty: float, 37 | ) -> List[str]: 38 | """Get outputs for a batch of images and text. 39 | 40 | Args: 41 | batch_text: list of text strings, with the text "" in place 42 | of any images to be included. 43 | batch_images: images to provide to model. Should be a list of lists, 44 | where each list contains the images for a single example. 45 | max_generation_length: maximum length of the generated caption. 46 | Defaults to 10. 47 | num_beams: number of beams to use for beam search. Defaults to 3. 48 | length_penalty: length penalty for beam search. Defaults to -2.0. 49 | 50 | Returns: 51 | List of decoded output strings. 52 | """ 53 | 54 | def vqa_prompt(self, question, answer=None) -> str: 55 | """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model. 56 | 57 | Returns: 58 | The prompt to use for VQA. 59 | """ 60 | 61 | def caption_prompt(self, caption=None) -> str: 62 | """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model. 63 | 64 | Returns: 65 | The prompt to use for captioning. 66 | """ 67 | 68 | def get_rank_classifications( 69 | self, 70 | batch_text: List[str], 71 | batch_images: List[List[Image.Image]], 72 | all_class_names: List[str], 73 | use_cache: bool, 74 | normalize_length: bool, 75 | ): 76 | """ 77 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. 78 | Args: 79 | batch_text: list of text strings, with the text "" in place 80 | of any images to be included. 81 | batch_images: images to provide to model. Should be a list of lists, 82 | where each list contains the images for a single example. 83 | all_class_names: list of all class names. 84 | use_cache: whether to cache the context to speed up evaluations. 85 | normalize_length: whether to normalize logprobs by the length of the 86 | class name 87 | Returns: 88 | (B, |all_class_names|) tensor containing the logprobs for each class name. 89 | """ 90 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/models/blip.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from PIL import Image 4 | import torch 5 | 6 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 7 | from open_flamingo.eval.eval_model import BaseEvalModel 8 | from open_flamingo.eval.utils import unwrap_model 9 | 10 | 11 | class EvalModel(BaseEvalModel): 12 | """BLIP-2 model evaluation. 13 | 14 | Attributes: 15 | model (nn.Module): Underlying Torch model. 16 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. 
17 | device: Index of GPU to use, or the string "cpu" 18 | """ 19 | 20 | def __init__(self, model_args): 21 | assert ( 22 | "processor_path" in model_args and "lm_path" in model_args 23 | ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" 24 | 25 | self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) 26 | self.model = Blip2ForConditionalGeneration.from_pretrained( 27 | model_args["lm_path"] 28 | ) 29 | self.model.eval() 30 | self.processor.tokenizer.padding_side = "left" 31 | self.lm_name = model_args["lm_path"].split("/")[-1] 32 | 33 | def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor: 34 | """Preprocess images and stack them. 35 | 36 | Args: 37 | batch: A list of lists of images. 38 | 39 | Returns: 40 | A Tensor of shape 41 | (batch_size, channels, height, width). 42 | """ 43 | batch_images = None 44 | assert all( 45 | len(example) == 1 for example in batch 46 | ), "BLIP-2 only supports one image per example" 47 | 48 | for example in batch: 49 | assert len(example) == 1, "BLIP-2 only supports one image per example" 50 | batch_images = torch.cat( 51 | [ 52 | batch_images, 53 | self.processor.image_processor(example, return_tensors="pt")[ 54 | "pixel_values" 55 | ], 56 | ] 57 | if batch_images is not None 58 | else [ 59 | self.processor.image_processor(example, return_tensors="pt")[ 60 | "pixel_values" 61 | ] 62 | ], 63 | dim=0, 64 | ) 65 | return batch_images 66 | 67 | def get_outputs( 68 | self, 69 | batch_text: List[str], 70 | batch_images: List[List[Image.Image]], 71 | min_generation_length: int, 72 | max_generation_length: int, 73 | num_beams: int, 74 | length_penalty: float, 75 | ) -> List[str]: 76 | encodings = self.processor.tokenizer( 77 | batch_text, 78 | padding="longest", 79 | truncation=True, 80 | return_tensors="pt", 81 | max_length=2000, 82 | ) 83 | input_ids = encodings["input_ids"] 84 | attention_mask = encodings["attention_mask"] 85 | 86 | with torch.inference_mode(): 87 | outputs = unwrap_model(self.model).generate( 88 | self._prepare_images(batch_images).to(self.device), 89 | input_ids.to(self.device), 90 | attention_mask=attention_mask.to(self.device), 91 | max_new_tokens=max_generation_length, 92 | min_new_tokens=min_generation_length, 93 | num_beams=num_beams, 94 | length_penalty=length_penalty, 95 | ) 96 | 97 | return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) 98 | 99 | def get_vqa_prompt(self, question, answer=None) -> str: 100 | return ( 101 | f"Question:{question} Short answer:{answer if answer is not None else ''}" 102 | ) 103 | 104 | def get_caption_prompt(self, caption=None) -> str: 105 | return f"A photo of {caption if caption is not None else ''}" 106 | 107 | def get_rank_classifications( 108 | self, 109 | batch_text: List[str], 110 | batch_images: List[List[Image.Image]], 111 | all_class_names: List[str], 112 | use_cache: bool, 113 | normalize_length: bool, 114 | ): 115 | raise NotImplementedError( 116 | "BLIP-2 classification-based evaluation not implemented" 117 | ) 118 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/rices.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | from tqdm import tqdm 4 | import torch 5 | from utils import custom_collate_fn 6 | 7 | 8 | class RICES: 9 | def __init__( 10 | self, 11 | dataset, 12 | device, 13 | batch_size, 14 | vision_encoder_path="ViT-B-32", 15 | 
vision_encoder_pretrained="openai", 16 | cached_features=None, 17 | ): 18 | self.dataset = dataset 19 | self.device = device 20 | self.batch_size = batch_size 21 | 22 | # Load the model and processor 23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 24 | vision_encoder_path, 25 | pretrained=vision_encoder_pretrained, 26 | ) 27 | self.model = vision_encoder.to(self.device) 28 | self.image_processor = image_processor 29 | 30 | # Precompute features 31 | if cached_features is None: 32 | self.features = self._precompute_features() 33 | else: 34 | self.features = cached_features 35 | 36 | def _precompute_features(self): 37 | features = [] 38 | 39 | # Switch to evaluation mode 40 | self.model.eval() 41 | 42 | # Set up loader 43 | loader = torch.utils.data.DataLoader( 44 | self.dataset, 45 | batch_size=self.batch_size, 46 | collate_fn=custom_collate_fn, 47 | ) 48 | 49 | with torch.no_grad(): 50 | for batch in tqdm( 51 | loader, 52 | desc="Precomputing features for RICES", 53 | ): 54 | batch = batch["image"] 55 | inputs = torch.stack( 56 | [self.image_processor(image) for image in batch] 57 | ).to(self.device) 58 | image_features = self.model.encode_image(inputs) 59 | image_features /= image_features.norm(dim=-1, keepdim=True) 60 | features.append(image_features.detach()) 61 | 62 | features = torch.cat(features) 63 | return features 64 | 65 | def find(self, batch, num_examples): 66 | """ 67 | Get the top num_examples most similar examples to the images. 68 | """ 69 | # Switch to evaluation mode 70 | self.model.eval() 71 | 72 | with torch.no_grad(): 73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to( 74 | self.device 75 | ) 76 | 77 | # Get the feature of the input image 78 | query_feature = self.model.encode_image(inputs) 79 | query_feature /= query_feature.norm(dim=-1, keepdim=True) 80 | query_feature = query_feature.detach().cpu() 81 | 82 | if query_feature.ndim == 1: 83 | query_feature = query_feature.unsqueeze(0) 84 | 85 | # Compute the similarity of the input image to the precomputed features 86 | similarity = (query_feature @ self.features.T).squeeze() 87 | 88 | if similarity.ndim == 1: 89 | similarity = similarity.unsqueeze(0) 90 | 91 | # Get the indices of the 'num_examples' most similar images 92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples] 93 | 94 | # Return with the most similar images last 95 | return [[self.dataset[i] for i in reversed(row)] for row in indices] 96 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/eval/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torch.nn as nn 5 | from contextlib import suppress 6 | 7 | 8 | def random_seed(seed=42, rank=0): 9 | torch.manual_seed(seed + rank) 10 | np.random.seed(seed + rank) 11 | random.seed(seed + rank) 12 | 13 | 14 | def custom_collate_fn(batch): 15 | """ 16 | Collate function for DataLoader that collates a list of dicts into a dict of lists. 17 | """ 18 | collated_batch = {} 19 | for key in batch[0].keys(): 20 | collated_batch[key] = [item[key] for item in batch] 21 | return collated_batch 22 | 23 | 24 | def compute_effective_num_shots(num_shots, model_type): 25 | """ 26 | Compute the effective number of shots for a given model type. 27 | For example, following Flamingo, 0-shot OF evaluations use two text-only shots. 
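For example, num_shots=0 with model_type "open_flamingo" yields an effective 2 shots, while any other model type keeps num_shots unchanged.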
28 | """ 29 | if model_type == "open_flamingo": 30 | return num_shots if num_shots > 0 else 2 31 | return num_shots 32 | 33 | 34 | def sample_batch_demos_from_query_set(query_set, num_samples, batch_size): 35 | """ 36 | Sample random demonstrations from the query set. 37 | """ 38 | return [random.sample(query_set, num_samples) for _ in range(batch_size)] 39 | 40 | 41 | def get_query_set(train_dataset, query_set_size): 42 | """ 43 | Get a subset of the training dataset to use as the query set. 44 | """ 45 | query_set = np.random.choice(len(train_dataset), query_set_size, replace=False) 46 | return [train_dataset[i] for i in query_set] 47 | 48 | 49 | def prepare_eval_samples(test_dataset, num_samples, batch_size): 50 | """ 51 | Subset the test dataset and return a DataLoader. 52 | """ 53 | random_indices = np.random.choice(len(test_dataset), num_samples, replace=False) 54 | dataset = torch.utils.data.Subset(test_dataset, random_indices) 55 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 56 | loader = torch.utils.data.DataLoader( 57 | dataset, 58 | batch_size=batch_size, 59 | sampler=sampler, 60 | collate_fn=custom_collate_fn, 61 | ) 62 | return loader 63 | 64 | 65 | def get_indices_of_unique(x): 66 | """ 67 | Return the indices of x that correspond to unique elements. 68 | If value v is unique and two indices in x have value v, the first index is returned. 69 | """ 70 | unique_elements = torch.unique(x) 71 | first_indices = [] 72 | for v in unique_elements: 73 | indices = torch.where(x == v)[0] 74 | first_indices.append(indices[0]) # Take the first index for each unique element 75 | return torch.tensor(first_indices) 76 | 77 | 78 | def unwrap_model(model): 79 | """ 80 | Unwrap a model from a DataParallel or DistributedDataParallel wrapper. 
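Models that are not wrapped are returned unchanged.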
81 | """ 82 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): 83 | return model.module 84 | else: 85 | return model 86 | 87 | 88 | def get_predicted_classnames(logprobs, k, class_id_to_name): 89 | """ 90 | Args: 91 | - logprobs shape (B, Y) containing logprobs for each classname 92 | - k: number for top-k 93 | - class_id_to_name: dict mapping class index to classname 94 | 95 | Returns: 96 | - top-k predicted classnames shape (B, k) type str 97 | - top-k logprobs shape (B, k) type float 98 | """ 99 | # convert indices to classnames 100 | _, predictions = torch.topk(logprobs, k=k, dim=1) # shape (B, k) 101 | predicted_classnames = [ 102 | [class_id_to_name[ix] for ix in item] for item in predictions.tolist() 103 | ] 104 | predicted_logprobs = torch.gather(logprobs, 1, predictions) 105 | return predicted_classnames, predicted_logprobs 106 | 107 | 108 | def get_cast_dtype(precision: str): 109 | cast_dtype = None 110 | if precision == "bf16": 111 | cast_dtype = torch.bfloat16 112 | elif precision == "fp16": 113 | cast_dtype = torch.float16 114 | return cast_dtype 115 | 116 | 117 | def get_autocast(precision): 118 | if precision == "amp": 119 | return torch.cuda.amp.autocast 120 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 121 | # amp_bfloat16 is more stable than amp float16 for clip training 122 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 123 | else: 124 | return suppress 125 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/scripts/convert_mmc4_to_wds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import uuid 5 | import zipfile 6 | from PIL import Image 7 | import base64 8 | from io import BytesIO 9 | 10 | import braceexpand 11 | import webdataset as wds 12 | 13 | arg_parser = argparse.ArgumentParser() 14 | arg_parser.add_argument( 15 | "--output_dir", 16 | type=str, 17 | help="Pass in the directory where the output shards (as tar files) will be written to.", 18 | ) 19 | arg_parser.add_argument( 20 | "--zip_files", 21 | type=str, 22 | help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip", 23 | ) 24 | arg_parser.add_argument( 25 | "--image_dir", 26 | type=str, 27 | help="Pass in the directory where the images have been downloaded to.", 28 | ) 29 | arg_parser.add_argument( 30 | "--num_files_per_shard", 31 | type=int, 32 | default=1000, 33 | ) 34 | args = arg_parser.parse_args() 35 | 36 | 37 | def main(): 38 | os.makedirs(args.output_dir, exist_ok=True) 39 | 40 | doc_shards = list(braceexpand.braceexpand(args.zip_files)) 41 | 42 | with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink: 43 | for idx in range(len(doc_shards)): 44 | # Open the ZIP archive and extract the JSON file 45 | with zipfile.ZipFile(doc_shards[idx], "r") as zip_file: 46 | # Assumes the JSON file is the first file in the archive 47 | json_filename = zip_file.namelist()[0] 48 | with zip_file.open(json_filename, "r") as json_file: 49 | for sample_data in json_file: 50 | # get image names from json 51 | sample_data = json.loads(sample_data) 52 | image_info = sample_data["image_info"] 53 | image_names = [image["image_name"] for image in image_info] 54 | 55 | # Add each image to the tar file 56 | for img_idx, image_name in enumerate(image_names): 57 | try: 58 | # load image 59 | img = Image.open( 60 | os.path.join(args.image_dir, str(idx), image_name) 61 | ).convert("RGB") 
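# Re-encode the image as JPEG and store it inline as base64 in the sample's
# JSON below, so each output shard is self-contained at training time.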
62 | buffered = BytesIO() 63 | img.save(buffered, format="JPEG") 64 | img_str = base64.b64encode(buffered.getvalue()) 65 | 66 | # convert to base64 67 | sample_data["image_info"][img_idx][ 68 | "image_base64" 69 | ] = img_str.decode("utf-8") 70 | except FileNotFoundError: 71 | print( 72 | f"Did not find {image_name} downloaded. This can happen if the url is now 404." 73 | ) 74 | except Exception as e: 75 | print(f"Error processing {image_name}: {e}") 76 | 77 | key_str = uuid.uuid4().hex 78 | sink.write({"__key__": key_str, "json": sample_data}) 79 | 80 | if (idx + 1) % args.num_files_per_shard == 0: 81 | sink.next_stream() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/scripts/fill_vqa_testdev_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission. 3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set. 4 | Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction. 5 | """ 6 | import json 7 | import sys 8 | import os 9 | 10 | sys.path.append( 11 | os.path.join( 12 | os.path.dirname(os.path.abspath(__file__)), 13 | "..", 14 | ) 15 | ) 16 | from eval.vqa_metric import VQAEval 17 | 18 | postprocessor = VQAEval(None, None) 19 | 20 | 21 | def fill_vizwiz_test_json( 22 | input_path, 23 | output_path, 24 | vqa_test_questions_json_path, 25 | ): 26 | # read the input json and build a set with all question_ids 27 | with open(input_path, "r") as f: 28 | input_json = json.load(f) 29 | 30 | # postprocess answers 31 | question_id_to_answer = {} 32 | for q in input_json: 33 | resAns = q["answer"] 34 | resAns = resAns.replace("\n", " ") 35 | resAns = resAns.replace("\t", " ") 36 | resAns = resAns.strip() 37 | resAns = postprocessor.processPunctuation(resAns) 38 | resAns = postprocessor.processDigitArticle(resAns) 39 | question_id_to_answer[q["question_id"]] = resAns 40 | 41 | # read the vqa test json to get all the qustion_ids that need to be filled 42 | with open(vqa_test_questions_json_path, "r") as f: 43 | vqa_test_json = json.load(f) 44 | vqa_test_json = vqa_test_json["questions"] 45 | 46 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 47 | output_json = [] 48 | for q in vqa_test_json: 49 | output_json.append( 50 | { 51 | "image": q["image_id"], 52 | "answer": question_id_to_answer.get(q["question_id"], ""), 53 | } 54 | ) 55 | 56 | # write the json to the output path 57 | with open(output_path, "w") as f: 58 | json.dump(output_json, f) 59 | 60 | 61 | def fill_vqav2_test_json( 62 | input_path, 63 | output_path, 64 | vqa_test_questions_json_path, 65 | ): 66 | # read the input json and build a set with all question_ids 67 | with open(input_path, "r") as f: 68 | input_json = json.load(f) 69 | question_ids = set() 70 | for q in input_json: 71 | question_ids.add(q["question_id"]) 72 | 73 | # make a copy of the input json 74 | output_json = [] 75 | for q in input_json: 76 | resAns = q["answer"] 77 | resAns = resAns.replace("\n", " ") 78 | resAns = resAns.replace("\t", " ") 79 | resAns = resAns.strip() 80 | resAns = postprocessor.processPunctuation(resAns) 81 | resAns = postprocessor.processDigitArticle(resAns) 82 | q["answer"] = resAns 83 | 
output_json.append(q) 84 | 85 | # read the vqa test json to get all the qustion_ids that need to be filled 86 | with open(vqa_test_questions_json_path, "r") as f: 87 | vqa_test_json = json.load(f) 88 | vqa_test_json = vqa_test_json["questions"] 89 | 90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 91 | for q in vqa_test_json: 92 | if q["question_id"] not in question_ids: 93 | output_json.append( 94 | { 95 | "question_id": q["question_id"], 96 | "answer": "", 97 | } 98 | ) 99 | 100 | # write the json to the output path 101 | with open(output_path, "w") as f: 102 | json.dump(output_json, f) 103 | 104 | 105 | if __name__ == "__main__": 106 | import argparse 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--dataset", 111 | type=str, 112 | choices=["vqav2", "vizwiz"], 113 | ) 114 | parser.add_argument( 115 | "--input_path", 116 | type=str, 117 | help="Path to the json file with the subset of the vqa test-dev questions.", 118 | ) 119 | parser.add_argument( 120 | "--vqa_test_questions_json_path", 121 | type=str, 122 | help="Path to the json file with all the vqa test questions.", 123 | ) 124 | parser.add_argument( 125 | "--output_path", 126 | type=str, 127 | help="Path to store the filled json.", 128 | ) 129 | args = parser.parse_args() 130 | 131 | if args.dataset == "vqav2": 132 | fill_vqav2_test_json( 133 | args.input_path, 134 | args.output_path, 135 | args.vqa_test_questions_json_path, 136 | ) 137 | else: 138 | fill_vizwiz_test_json( 139 | args.input_path, 140 | args.output_path, 141 | args.vqa_test_questions_json_path, 142 | ) 143 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_openflamingo/open_flamingo/src/__init__.py -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/src/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | import open_clip 5 | 6 | from .flamingo import Flamingo 7 | from .flamingo_lm import FlamingoLMMixin 8 | from .utils import extend_instance 9 | 10 | 11 | def create_model_and_transforms( 12 | clip_vision_encoder_path: str, 13 | clip_vision_encoder_pretrained: str, 14 | lang_encoder_path: str, 15 | tokenizer_path: str, 16 | cross_attn_every_n_layers: int = 1, 17 | use_local_files: bool = False, 18 | decoder_layers_attr_name: str = None, 19 | freeze_lm_embeddings: bool = False, 20 | cache_dir: Optional[str] = None, 21 | **flamingo_kwargs, 22 | ): 23 | """ 24 | Initialize a Flamingo model from a pretrained vision encoder and language encoder. 25 | Appends special tokens to the tokenizer and freezes backbones. 26 | 27 | Args: 28 | clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") 29 | clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") 30 | lang_encoder_path (str): path to pretrained language encoder 31 | tokenizer_path (str): path to pretrained tokenizer 32 | cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. 
33 | use_local_files (bool, optional): whether to use local files. Defaults to False. 34 | decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. 35 | freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver. 36 | cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. 37 | Returns: 38 | Flamingo: Flamingo model from pretrained vision and language encoders 39 | Image processor: Pipeline to preprocess input images 40 | Tokenizer: A tokenizer for the language model 41 | """ 42 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 43 | clip_vision_encoder_path, 44 | pretrained=clip_vision_encoder_pretrained, 45 | cache_dir=cache_dir, 46 | ) 47 | # set the vision encoder to output the visual features 48 | vision_encoder.visual.output_tokens = True 49 | 50 | text_tokenizer = AutoTokenizer.from_pretrained( 51 | tokenizer_path, 52 | local_files_only=use_local_files, 53 | trust_remote_code=True, 54 | cache_dir=cache_dir, 55 | ) 56 | # add Flamingo special tokens to the tokenizer 57 | text_tokenizer.add_special_tokens( 58 | {"additional_special_tokens": ["<|endofchunk|>", "<image>"]} 59 | ) 60 | if text_tokenizer.pad_token is None: 61 | # Issue: GPT models don't have a pad token, which we use to 62 | # modify labels for the loss. 63 | text_tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 64 | 65 | lang_encoder = AutoModelForCausalLM.from_pretrained( 66 | lang_encoder_path, 67 | local_files_only=use_local_files, 68 | trust_remote_code=True, 69 | cache_dir=cache_dir, 70 | ) 71 | 72 | # hacks for MPT-1B, which doesn't have a get_input_embeddings method 73 | if "mpt-1b-redpajama-200b" in lang_encoder_path: 74 | 75 | class EmbeddingFnMixin: 76 | def get_input_embeddings(self): 77 | return self.transformer.wte 78 | 79 | def set_input_embeddings(self, new_embeddings): 80 | self.transformer.wte = new_embeddings 81 | 82 | extend_instance(lang_encoder, EmbeddingFnMixin) 83 | 84 | # convert LM to FlamingoLM 85 | extend_instance(lang_encoder, FlamingoLMMixin) 86 | 87 | if decoder_layers_attr_name is None: 88 | decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) 89 | lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) 90 | lang_encoder.resize_token_embeddings(len(text_tokenizer)) 91 | 92 | model = Flamingo( 93 | vision_encoder, 94 | lang_encoder, 95 | text_tokenizer.encode("<|endofchunk|>")[-1], 96 | text_tokenizer.encode("<image>")[-1], 97 | vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][ 98 | "width" 99 | ], 100 | cross_attn_every_n_layers=cross_attn_every_n_layers, 101 | **flamingo_kwargs, 102 | ) 103 | 104 | # Freeze all parameters 105 | model.requires_grad_(False) 106 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 107 | 108 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 109 | model.perceiver.requires_grad_(True) 110 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 111 | if not freeze_lm_embeddings: 112 | model.lang_encoder.get_input_embeddings().requires_grad_(True) 113 | # TODO: investigate also training the output embeddings when untied 114 | 115 | print( 116 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" 117 | ) 118 | 119 | return model, image_processor, text_tokenizer 120 | 121 | 122 | def _infer_decoder_layers_attr_name(model): 123 | for k in
__KNOWN_DECODER_LAYERS_ATTR_NAMES: 124 | if k.lower() in model.__class__.__name__.lower(): 125 | return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] 126 | 127 | raise ValueError( 128 | f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." 129 | ) 130 | 131 | 132 | __KNOWN_DECODER_LAYERS_ATTR_NAMES = { 133 | "opt": "model.decoder.layers", 134 | "gptj": "transformer.h", 135 | "gpt-j": "transformer.h", 136 | "pythia": "gpt_neox.layers", 137 | "llama": "model.layers", 138 | "gptneoxforcausallm": "gpt_neox.layers", 139 | "mpt": "transformer.blocks", 140 | "mosaicgpt": "transformer.blocks", 141 | } 142 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/src/utils.py: -------------------------------------------------------------------------------- 1 | def extend_instance(obj, mixin): 2 | """Apply mixins to a class instance after creation""" 3 | base_cls = obj.__class__ 4 | base_cls_name = obj.__class__.__name__ 5 | obj.__class__ = type( 6 | base_cls_name, (mixin, base_cls), {} 7 | ) # mixin needs to go first for our forward() logic to work 8 | 9 | 10 | def getattr_recursive(obj, att): 11 | """ 12 | Return nested attribute of obj 13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 14 | """ 15 | if att == "": 16 | return obj 17 | i = att.find(".") 18 | if i < 0: 19 | return getattr(obj, att) 20 | else: 21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 22 | 23 | 24 | def setattr_recursive(obj, att, val): 25 | """ 26 | Set nested attribute of obj 27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 28 | """ 29 | if "." in att: 30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 31 | setattr(obj, att.split(".")[-1], val) 32 | 33 | 34 | def apply_with_stopping_condition( 35 | module, apply_fn, apply_condition=None, stopping_condition=None, **other_args 36 | ): 37 | if stopping_condition(module): 38 | return 39 | if apply_condition(module): 40 | apply_fn(module, **other_args) 41 | for child in module.children(): 42 | apply_with_stopping_condition( 43 | child, 44 | apply_fn, 45 | apply_condition=apply_condition, 46 | stopping_condition=stopping_condition, 47 | **other_args 48 | ) 49 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/train/README.md: -------------------------------------------------------------------------------- 1 | # OpenFlamingo Training 2 | To train OpenFlamingo, please ensure your environment matches that of `environment.yml`. 3 | 4 | ## Data 5 | Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag. 6 | 7 | ### LAION-2B Dataset 8 | [LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs. 9 | We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files. 10 | 11 | ### Multimodal C4 Dataset 12 | We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence. 
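Once shards have been converted with the steps described just below, they can be sanity-checked directly. The following is a minimal, illustrative sketch (not the training loader) for peeking at one converted shard and decoding its base64-embedded images; the shard path is a placeholder.

```python
# Minimal sketch: open one converted shard and decode the base64 images that
# the conversion script embeds in each document's JSON record.
import base64
import json
from io import BytesIO

import webdataset as wds
from PIL import Image

dataset = wds.WebDataset("/path/to/shards/000000000.tar")  # placeholder path

for sample in dataset:
    doc = json.loads(sample["json"])                 # one interleaved document
    for info in doc.get("image_info", []):
        if "image_base64" in info:                   # absent if the image 404'd at download time
            img = Image.open(BytesIO(base64.b64decode(info["image_base64"])))
            print(info.get("image_name"), img.size)
    break  # inspect only the first document
```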
13 | 14 | Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64. 15 | We provide scripts to convert MMC4 to this format: 16 | 17 | 1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`). 18 | 2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`). 19 | 3. Run `scripts/convert_interleaved_to_wds.py` to convert the downloaded items into the expected tar files. 20 | 21 | ### ChatGPT-generated sequences 22 | A subset of our models (listed below) was also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the urls in the json files and converted to base64 before using this data for training in our codebase. 23 | 24 | Models trained with ChatGPT-generated sequences: 25 | 26 | * OpenFlamingo-4B-vitl-rpj3b 27 | * OpenFlamingo-4B-vitl-rpj3b-langinstruct 28 | 29 | ## Example training command 30 | We provide a sample Slurm training script in `scripts/`. You can also modify the following command: 31 | 32 | ``` 33 | torchrun --nnodes=1 --nproc_per_node=4 train.py \ 34 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \ 35 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ 36 | --cross_attn_every_n_layers 1 \ 37 | --dataset_resampled \ 38 | --batch_size_interleaved 32 \ 39 | --batch_size_laion 64 \ 40 | --train_num_samples_interleaved 125000\ 41 | --train_num_samples_laion 250000 \ 42 | --loss_multiplier_laion 0.2 \ 43 | --workers=4 \ 44 | --run_name OpenFlamingo-3B-vitl-mpt1b \ 45 | --num_epochs 480 \ 46 | --warmup_steps 1875 \ 47 | --mmc4_textsim_threshold 0.24 \ 48 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ 49 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ 50 | --report_to_wandb 51 | ``` 52 | *Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).* 53 | 54 | ## Distributed training 55 | 56 | By default, `train.py` uses PyTorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training. 57 | To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag. 58 | 59 | Some notes on FSDP: 60 | 61 | * We recommend using the `--fsdp_use_orig_params` flag. If `--fsdp` is on without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to only train the newly added `<image>` and `<|endofchunk|>` tokens.) 62 | * Note: we've encountered issues using OPT with this flag. Other language models should be compatible.
63 | * Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input / output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the `--freeze_lm_embeddings` flag. 64 | 65 | We also implement gradient checkpointing and mixed precision training. Use the `--gradient_checkpointing` and `--precision` arguments respectively. -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/train/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mllm_openflamingo/open_flamingo/train/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | 6 | import os 7 | import torch 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def is_global_master(args): 16 | return args.rank == 0 17 | 18 | 19 | def is_local_master(args): 20 | return args.local_rank == 0 21 | 22 | 23 | def is_master(args, local=False): 24 | return is_local_master(args) if local else is_global_master(args) 25 | 26 | 27 | def is_using_horovod(): 28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 32 | if all([var in os.environ for var in ompi_vars]) or all( 33 | [var in os.environ for var in pmi_vars] 34 | ): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def is_using_distributed(): 41 | if "WORLD_SIZE" in os.environ: 42 | return int(os.environ["WORLD_SIZE"]) > 1 43 | if "SLURM_NTASKS" in os.environ: 44 | return int(os.environ["SLURM_NTASKS"]) > 1 45 | return False 46 | 47 | 48 | def world_info_from_env(): 49 | local_rank = 0 50 | for v in ( 51 | "LOCAL_RANK", 52 | "MPI_LOCALRANKID", 53 | "SLURM_LOCALID", 54 | "OMPI_COMM_WORLD_LOCAL_RANK", 55 | ): 56 | if v in os.environ: 57 | local_rank = int(os.environ[v]) 58 | break 59 | global_rank = 0 60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 61 | if v in os.environ: 62 | global_rank = int(os.environ[v]) 63 | break 64 | world_size = 1 65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 66 | if v in os.environ: 67 | world_size = int(os.environ[v]) 68 | break 69 | 70 | return local_rank, global_rank, world_size 71 | 72 | 73 | def init_distributed_device(args): 74 | # Distributed training = training on more than one GPU. 75 | # Works in both single and multi-node scenarios. 
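# The defaults below assume a single, non-distributed process; they are
# overwritten when Horovod, SLURM, or torchrun environment variables are present.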
76 | args.distributed = False 77 | args.world_size = 1 78 | args.rank = 0 # global rank 79 | args.local_rank = 0 80 | if args.horovod: 81 | assert hvd is not None, "Horovod is not installed" 82 | hvd.init() 83 | args.local_rank = int(hvd.local_rank()) 84 | args.rank = hvd.rank() 85 | args.world_size = hvd.size() 86 | args.distributed = True 87 | os.environ["LOCAL_RANK"] = str(args.local_rank) 88 | os.environ["RANK"] = str(args.rank) 89 | os.environ["WORLD_SIZE"] = str(args.world_size) 90 | elif is_using_distributed(): 91 | if "SLURM_PROCID" in os.environ: 92 | # DDP via SLURM 93 | args.local_rank, args.rank, args.world_size = world_info_from_env() 94 | # SLURM var -> torch.distributed vars in case needed 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | os.environ["RANK"] = str(args.rank) 97 | os.environ["WORLD_SIZE"] = str(args.world_size) 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | world_size=args.world_size, 102 | rank=args.rank, 103 | ) 104 | else: 105 | # DDP via torchrun, torch.distributed.launch 106 | args.local_rank, _, _ = world_info_from_env() 107 | torch.distributed.init_process_group( 108 | backend=args.dist_backend, init_method=args.dist_url 109 | ) 110 | args.world_size = torch.distributed.get_world_size() 111 | args.rank = torch.distributed.get_rank() 112 | args.distributed = True 113 | else: 114 | # needed to run on single gpu 115 | torch.distributed.init_process_group( 116 | backend=args.dist_backend, 117 | init_method=args.dist_url, 118 | world_size=1, 119 | rank=0, 120 | ) 121 | 122 | if torch.cuda.is_available(): 123 | if args.distributed and not args.no_set_device_rank: 124 | device = "cuda:%d" % args.local_rank 125 | else: 126 | device = "cuda:0" 127 | torch.cuda.set_device(device) 128 | else: 129 | device = "cpu" 130 | args.device = device 131 | device = torch.device(device) 132 | return device 133 | -------------------------------------------------------------------------------- /mllm_openflamingo/requirements-eval.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | torchvision 3 | nltk 4 | inflection 5 | pycocoevalcap 6 | pycocotools 7 | tqdm 8 | scikit-learn 9 | black 10 | mypy 11 | pylint 12 | pytest 13 | requests 14 | -------------------------------------------------------------------------------- /mllm_openflamingo/requirements-training.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | braceexpand 3 | webdataset 4 | tqdm 5 | wandb 6 | -------------------------------------------------------------------------------- /mllm_openflamingo/requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | einops-exts 3 | transformers==4.28.1 4 | torch==2.0.1 5 | pillow 6 | open_clip_torch>=2.16.0 7 | sentencepiece 8 | -------------------------------------------------------------------------------- /mllm_openflamingo/setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | if __name__ == "__main__": 6 | with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: 7 | long_description = file.read() 8 | 9 | REQUIREMENTS = [ 10 | "einops", 11 | "einops-exts", 12 | "transformers>=4.28.1", 13 | "torch==2.0.1", 14 | "pillow", 15 | "open_clip_torch>=2.16.0", 16 | "sentencepiece", 17 | ] 18 | 19 | EVAL = [ 20 | 
"scipy", 21 | "torchvision", 22 | "nltk", 23 | "inflection", 24 | "pycocoevalcap", 25 | "pycocotools", 26 | "tqdm", 27 | ] 28 | 29 | TRAINING = [ 30 | "wandb", 31 | "torchvision", 32 | "braceexpand", 33 | "webdataset", 34 | "tqdm", 35 | "deepspeed==0.14.2", 36 | "accelerate==0.30.0", 37 | ] 38 | 39 | setup( 40 | name="open_flamingo", 41 | packages=find_packages(), 42 | include_package_data=True, 43 | version="2.0.1", 44 | license="MIT", 45 | description="An open-source framework for training large multimodal models", 46 | long_description=long_description, 47 | long_description_content_type="text/markdown", 48 | data_files=[(".", ["README.md"])], 49 | keywords=["machine learning"], 50 | install_requires=REQUIREMENTS, 51 | extras_require={ 52 | "eval": EVAL, 53 | "training": TRAINING, 54 | "all": list(set(REQUIREMENTS + EVAL + TRAINING)), 55 | }, 56 | classifiers=[ 57 | "Development Status :: 4 - Beta", 58 | "Intended Audience :: Developers", 59 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 60 | "License :: OSI Approved :: MIT License", 61 | "Programming Language :: Python :: 3.9", 62 | ], 63 | ) 64 | -------------------------------------------------------------------------------- /mllm_openflamingo/setup_deepspeed_adamcpu.py: -------------------------------------------------------------------------------- 1 | import deepspeed 2 | deepspeed.ops.op_builder.CPUAdamBuilder().load() --------------------------------------------------------------------------------