This domain is for use in illustrative examples in documents. You may use this
55 | domain in literature without prior coordination or asking for permission.
58 |
59 |
60 | """
61 |
62 | # extract main content of article HTML
63 | data = extractor.extract(html, base_url=url)
64 |
65 | # extract main content of forum HTML
66 | # data = extractor.extract(html, base_url=url, html_type="forum")
67 |
68 | # extract main content of WeChat official account HTML
69 | # data = extractor.extract(html, base_url=url, html_type="weixin")
70 |
71 | print(data)
72 | ```
73 |
74 |
75 |
76 | ## Others
77 |
78 | LICENSE: [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0.html)
79 |
80 | Acknowledgments:
81 |
82 | - [trafilatura](https://github.com/adbar/trafilatura)
83 | - [readability-lxml](https://github.com/buriy/python-readability)
84 |
85 |
--------------------------------------------------------------------------------
/html_extraction/demo_trafilatura_extraction.html:
--------------------------------------------------------------------------------
1 |
2 |
Welcome to Round Table Gauteng
Did you know that the best young men’s organization in the world was founded in February 1927? That’s right, Round Table was created after the Duke of Windsor gave a speech urging young business and professional men to get together around a table and improve themselves and their communities.
Round Table is a club specifically designed for young men between 18 and 40 (45 in some countries) years old, regardless of their beliefs or political opinions. It’s an exclusive club, but anyone can join if they’re invited by a current member who will then become their sponsor.
Round Table is not just any ordinary club, it’s a voluntary non-political, non-denomination non-profit organization. It’s all about meeting new people, exchanging ideas, and developing friendships. With approximately 30,000 members worldwide, Round Table is an international organization that promotes fellowship and community service.
As a member of Round Table, you’ll have the opportunity to give back to your community by raising funds through various events and lending a hand wherever it’s needed. You’ll also have the chance to involve your family, allowing your kids to learn important values and grow alongside you.
Our motto is Adopt, Adapt, Improve, and we strive to foster growth in all areas of our lives. We aim to develop fellowship among young men through their professional and business occupations, while emphasizing that one’s calling offers an excellent medium of service to the community. We cultivate the highest ideals in business, professional, and civic traditions, while recognizing the worthiness of all legitimate occupations.
Round Table fosters responsible citizenship and loyalty to our country, while furthering the establishment of peace and goodwill in international relationships. We achieve these goals through meetings, lectures, discussions, and other activities.
Our logo, the Rondel, is a circular object that symbolizes equality. Each seat around the table is equal, and everyone has an equal voice. The original disc that inspired the logo is over 4 meters in diameter and was likely used as a table. The center of the logo features the Tudor’s rose, which was the emblem of the Tudors from 1450 on, and above it, a king with his scepter, possibly King Arthur.
So, if you’re a young man looking for a way to make a difference in your community while making new friends and developing yourself, consider joining Round Table. Adopt, Adapt, Improve – it’s not just our motto, it’s a way of life.
Round Table raises funds through either organizing various fund-raising events, or by offering time and energy to give a hand wherever it’s needed. As with any hobby, you certainly get out of it what you put in.
This personal commitment certainly also offers a wider experience: Round Table is used to involve families, where our kids can learn to spend time together, sharing our values and growing in similar aspects in life as the Tablers themselves do.
Aims and Objectives
These are the goals that we strive for together – to foster growth across all levels
To develop fellowship among young men through the medium of their professional and business occupations .
To emphasise the fact that one’s calling offers an excellent medium of service to the community.
To cultivate the highest ideals in business, professional and civic traditions.
To recognise the worthiness of all legitimate occupations and to dignify each his own by precept and example.
To foster responsible citizenship and loyalty to our country
To further the establishment of peace and goodwill in international relationships
To further the above objects by meetings, lectures, discussions and other activities.
Our Motto is Adopt, Adapt, Improve.
The Rondel
Where it comes from and what it means
The first Round Table (Norwich, England) adopted the proposal from Neville Headon from RT Manchester: a drawing inspired by a wooden disc which can still be admired in the Great Hall in Winchester (England).
In its centre, we can see the “Tudor’s rose”, which was the emblem of the Tudors from 1450 on; just above we can see a king (who could be King Arthur) with his sceptre.
As this original disc is more than 4 metres in diameter, it seems it was used as a Table. This is why the two-coloured radius all around reminds us that every seat around the Table is equal.
This logo was taken over by RT Britain and Ireland while chartering their national association in 1928. This general theme was then taken over by each new association. King Arthur and / or the central rose was replaced each time by a symbol that refers to the relevant national association.
View our Social media below to follow us.
3 |
--------------------------------------------------------------------------------
/html_extraction/demo_url.txt:
--------------------------------------------------------------------------------
1 | https://roundtablegauteng.co.za/?page_number_0=4
2 |
--------------------------------------------------------------------------------
/human_feedback_textfiltering/README.md:
--------------------------------------------------------------------------------
1 | # Human Feedback Textfiltering
2 |
3 |
4 |
5 | This repository contains 2,000 examples in test.jsonl. To perform text filtering on these examples, simply run the following command:
6 |
7 | ```shell
8 | python run.py
9 | ```
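10 | 
11 | As a minimal sketch of the data flow (based on `run.py`, which reads `test.jsonl` and writes `output.jsonl`): each JSONL line is a JSON object whose `content` field holds one document's XML, and `filter_single_xml_en` either returns the filtered XML or discards the document. Assuming `run.py` has already produced `output.jsonl` in the working directory, the kept examples can be inspected like this:
12 | 
13 | ```python
14 | import json
15 | 
16 | # load the documents that survived filtering
17 | with open("output.jsonl", encoding="utf-8") as f:
18 |     kept = [json.loads(line) for line in f]
19 | 
20 | print(len(kept), "documents kept")
21 | if kept:
22 |     print(kept[0]["content"][:200])  # start of the filtered XML of the first kept document
23 | ```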
--------------------------------------------------------------------------------
/human_feedback_textfiltering/myxml.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from lxml.etree import Element, RelaxNG, SubElement, XMLParser, fromstring, tostring
3 | from html import unescape
4 |
5 | WITH_ATTRIBUTES = {'cell', 'del', 'graphic', 'head', 'hi', 'item', 'list', 'ref'}
6 |
7 | from trafilatura.xml import replace_element_text, NEWLINE_ELEMS, SPECIAL_FORMATTING, CONTROL_PARSER, sanitize, validate_tei
8 |
9 | LOGGER = logging.getLogger(__name__)
10 |
11 | def clean_attributes(tree):
12 | '''Remove unnecessary attributes.'''
13 | for elem in tree.iter('*'):
14 | if elem.tag not in WITH_ATTRIBUTES:
15 | elem.attrib.clear()
16 | return tree
17 |
18 | def control_xml_output(output_tree, output_format, tei_validation, docmeta, pretty_print=True):
19 |     '''Make sure the XML output is well-formed and valid if required'''
20 | control_string = sanitize(tostring(output_tree, encoding='unicode'))
21 | # necessary for cleaning
22 | output_tree = fromstring(control_string, CONTROL_PARSER)
23 | # validate
24 | if output_format == 'xmltei' and tei_validation is True:
25 | result = validate_tei(output_tree)
26 | LOGGER.debug('TEI validation result: %s %s %s', result, docmeta.id, docmeta.url)
27 | return tostring(output_tree, pretty_print=pretty_print, encoding='unicode').strip()
28 |
29 | def xmltotxt(xmloutput, include_formatting, include_images=True):
30 | '''Convert to plain text format and optionally preserve formatting as markdown.'''
31 | returnlist = []
32 | # strip_tags(xmloutput, 'div', 'main', 'span')
33 | # iterate and convert to list of strings
34 | for element in xmloutput.iter('*'):
35 | if element.text is None and element.tail is None:
36 | if element.tag == 'graphic' and include_images:
37 |                 # caption text from title/alt attributes; image source from src/data-src, default to ''
38 | text = element.get('title', '')
39 | if element.get('alt') is not None:
40 | text += ' ' + element.get('alt')
41 | url = element.get('src', '')
42 | if not url: url = element.get('data-src', '')
43 | if url: returnlist.extend(['![', text, ']', '(', url, ')'])
44 | # newlines for textless elements
45 | if element.tag in ('graphic', 'row', 'table'):
46 | returnlist.append('\n')
47 | continue
48 | # process text
49 | textelement = replace_element_text(element, include_formatting)
50 | # common elements
51 | if element.tag in NEWLINE_ELEMS:
52 | returnlist.extend(['\n', textelement, '\n'])
53 | # particular cases
54 | elif element.tag == 'item':
55 | returnlist.extend(['\n- ', textelement, '\n'])
56 | elif element.tag == 'cell':
57 | returnlist.extend(['|', textelement, '|'])
58 | elif element.tag == 'comments':
59 | returnlist.append('\n\n')
60 | else:
61 | if element.tag not in SPECIAL_FORMATTING:
62 | LOGGER.debug('unprocessed element in output: %s', element.tag)
63 | returnlist.extend([textelement, ' '])
64 | return unescape(sanitize(''.join(returnlist)))
65 |
66 | def build_xml_output(docmeta):
67 | '''Build XML output tree based on extracted information'''
68 | output = Element('doc')
69 | output = add_xml_meta(output, docmeta)
70 | docmeta.body.tag = 'main'
71 | # clean XML tree
72 | output.append(clean_attributes(docmeta.body))
73 | if docmeta.commentsbody is not None:
74 | docmeta.commentsbody.tag = 'comments'
75 | output.append(clean_attributes(docmeta.commentsbody))
76 | # XML invalid characters
77 | # https://chase-seibert.github.io/blog/2011/05/20/stripping-control-characters-in-python.html
78 | return output
79 |
80 | def add_xml_meta(output, docmeta):
81 | '''Add extracted metadata to the XML output tree'''
82 | # metadata
83 | if docmeta:
84 | if docmeta.sitename is not None:
85 | output.set('sitename', docmeta.sitename)
86 | if docmeta.title is not None:
87 | output.set('title', docmeta.title)
88 | if docmeta.author is not None:
89 | output.set('author', docmeta.author)
90 | if docmeta.date is not None:
91 | output.set('date', docmeta.date)
92 | if docmeta.url is not None:
93 | output.set('source', docmeta.url)
94 | if docmeta.hostname is not None:
95 | output.set('hostname', docmeta.hostname)
96 | if docmeta.description is not None:
97 | output.set('excerpt', docmeta.description)
98 | if docmeta.categories is not None:
99 | try:
100 | output.set('categories', ';'.join(docmeta.categories))
101 |             except Exception:
102 | pass
103 | if docmeta.tags is not None:
104 | try:
105 | output.set('tags', ';'.join(docmeta.tags))
106 |             except Exception:
107 | pass
108 | if docmeta.license is not None:
109 | output.set('license', docmeta.license)
110 | if docmeta.id is not None:
111 | output.set('id', docmeta.id)
112 | if docmeta.fingerprint is not None:
113 | output.set('fingerprint', docmeta.fingerprint)
114 | if docmeta.language is not None:
115 | output.set('language', docmeta.language)
116 | return output
117 |
--------------------------------------------------------------------------------
/human_feedback_textfiltering/run.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | from lxml import etree
4 | from meta import xmltodoc
5 | from myxml import xmltotxt
6 | from mytraf import mydetermine_returnstring
7 | from text_filter.filters import (words_discard,
8 | char_discard, stop_word_discard, document_porn_discard,
9 | img_txt_ratio_discard, uppercase_discard, numerical_discard,
10 | social_media_counter_discard, one_word_discard, short_discard,
11 | porn_discard, comments_discard, header_footer_discard,
12 | newlines_discard, heads_discard, underlines_split,
13 | video_field_discard,
14 | readme_discard, fulltext_discard, url_discard, image_caption_discard,
15 | advertisement_discard, re_short_long_paragraphs,
16 | check_paragraph_lengths,
17 | tooshort_discard, aberrant_item_discard, cite_discard,
18 | social_media_discard, phonenum_author_time_discard,
19 | filter_download_links, filter_source_references)
20 |
21 |
22 | def filter_single_xml_en(xml_str):
23 | doc = xmltodoc(xml_str)
24 | xml = etree.fromstring(xml_str)
25 | _, num_images = mydetermine_returnstring(
26 | doc, output_format='txt', include_formatting=False, tei_validation=False)
27 |
28 | xml, _ = newlines_discard(xml)
29 | xml, _ = underlines_split(xml)
30 | xml, _ = video_field_discard(xml)
31 | xml, _ = fulltext_discard(xml)
32 |
33 | # Paragraph Level Filtering
34 | xml, _ = filter_download_links(xml)
35 | xml, _ = filter_source_references(xml)
36 | xml, _ = uppercase_discard(xml, uppercase_threshold=0.8, immutable=True)
37 | xml, _ = numerical_discard(xml, numerical_threshold=0.8, immutable=True)
38 | xml, _ = social_media_counter_discard(xml)
39 | xml, _ = one_word_discard(xml, immutable=True)
40 | xml, _ = short_discard(xml)
41 | xml, _ = porn_discard(xml)
42 | xml, _ = comments_discard(xml)
43 | xml, _ = header_footer_discard(xml)
44 | xml, _ = heads_discard(xml)
45 | xml, _ = readme_discard(xml)
46 | xml, _ = url_discard(xml)
47 | xml, _ = image_caption_discard(xml)
48 | xml, _ = advertisement_discard(xml)
49 | xml, _ = re_short_long_paragraphs(xml)
50 | xml, _ = tooshort_discard(xml)
51 | xml, _ = aberrant_item_discard(xml)
52 | xml, _ = cite_discard(xml)
53 | xml, _ = social_media_discard(xml)
54 | xml, _ = phonenum_author_time_discard(xml)
55 |
56 | pure_txt = xmltotxt(
57 | xml, include_formatting=False, include_images=False)
58 | if pure_txt is None or len(pure_txt) == 0:
59 | return None
60 | words_keep, words = words_discard(pure_txt, words_count_range=(50, 100000), avg_word_len_range=[3, 10],
61 | return_words=True)
62 | if not words_keep:
63 | return None
64 | if not char_discard(pure_txt, char_threshold=0.8, words=words): return None
65 | if not stop_word_discard(pure_txt, stop_word_threshold=2, words=words): return None
66 | if not document_porn_discard(pure_txt, thld=0.02): return None
67 | if not img_txt_ratio_discard(pure_txt, image_count=num_images, min_image_count=2,
68 | min_text_image_ratio=50): return None
69 | if not check_paragraph_lengths(xml): return None
70 | return etree.tostring(xml).decode()
71 |
72 | def read_jsonl_file(file_path):
73 | data = []
74 | with open(file_path, 'r', encoding='utf8') as file:
75 | # print(file)
76 | for line in file:
77 | # print(line)
78 | data.append(json.loads(line))
79 | return data
80 |
81 | def main_function(input_file=None, output_file=None):
82 | datas = read_jsonl_file(file_path=input_file)
83 | with open(output_file, 'w', encoding='utf-8') as f:
84 | for line_dict in datas:
85 | xml_txt = line_dict['content']
86 |             res = filter_single_xml_en(xml_txt)  # first filtering pass
87 |             if res is not None:
88 |                 res = filter_single_xml_en(res)  # second pass over the already-filtered XML
89 |                 if res is not None:
90 |                     res = filter_single_xml_en(res)  # third and final pass
91 | if res is None:
92 | continue
93 | line_dict['content'] = res
94 | f.write(json.dumps(line_dict, ensure_ascii=False) + '\n')
95 |
96 |
97 | if __name__ == '__main__':
98 | main_function(input_file="test.jsonl", output_file="output.jsonl")
99 |
--------------------------------------------------------------------------------
/human_feedback_textfiltering/text_filter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/human_feedback_textfiltering/text_filter/__init__.py
--------------------------------------------------------------------------------
/human_feedback_textfiltering/text_filter/spam_word_less.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/human_feedback_textfiltering/text_filter/spam_word_less.txt
--------------------------------------------------------------------------------
/mllm_internvl/DATASET_CARD.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mllm_internvl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_internvl/__init__.py
--------------------------------------------------------------------------------
/mllm_internvl/coco_metric.py:
--------------------------------------------------------------------------------
1 | from pycocoevalcap.eval import COCOEvalCap
2 | from pycocotools.coco import COCO
3 |
4 |
5 | def compute_cider(
6 | result_path,
7 | annotations_path,
8 | ):
9 | # create coco object and coco_result object
10 | coco = COCO(annotations_path)
11 | coco_result = coco.loadRes(result_path)
12 |
13 | # create coco_eval object by taking coco and coco_result
14 | coco_eval = COCOEvalCap(coco, coco_result)
15 | coco_eval.params["image_id"] = coco_result.getImgIds()
16 | coco_eval.evaluate()
17 |
18 | return coco_eval.eval
19 |
20 |
21 | def postprocess_captioning_generation(predictions):
22 | return predictions.split("Output", 1)[0]
23 |
--------------------------------------------------------------------------------
/mllm_internvl/eval_datasets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision.datasets import ImageFolder
7 |
8 | from classification_utils import IMAGENET_CLASSNAMES
9 |
10 |
11 | class CaptionDataset(Dataset):
12 | def __init__(
13 | self,
14 | image_train_dir_path,
15 | annotations_path,
16 | is_train,
17 | dataset_name,
18 | image_val_dir_path=None,
19 | ):
20 | self.image_train_dir_path = image_train_dir_path
21 | self.image_val_dir_path = image_val_dir_path
22 | self.annotations = []
23 | self.is_train = is_train
24 | self.dataset_name = dataset_name
25 |
26 | full_annotations = json.load(open(annotations_path))["images"]
27 |
28 | for i in range(len(full_annotations)):
29 | if self.is_train and full_annotations[i]["split"] != "train":
30 | continue
31 | elif not self.is_train and full_annotations[i]["split"] != "test":
32 | continue
33 |
34 | self.annotations.append(full_annotations[i])
35 |
36 | def __len__(self):
37 | return len(self.annotations)
38 |
39 | def __getitem__(self, idx):
40 | if self.dataset_name == "coco":
41 | image = Image.open(
42 | os.path.join(
43 | self.image_train_dir_path, self.annotations[idx]["filename"]
44 | )
45 | if self.annotations[idx]["filepath"] == "train2014"
46 | else os.path.join(
47 | self.image_val_dir_path, self.annotations[idx]["filename"]
48 | )
49 | )
50 | elif self.dataset_name == "flickr":
51 | image = Image.open(
52 | os.path.join(
53 | self.image_train_dir_path, self.annotations[idx]["filename"]
54 | )
55 | )
56 | image.load()
57 | caption = self.annotations[idx]["sentences"][0]["raw"]
58 | return {
59 | "image": image,
60 | "caption": caption,
61 | "image_id": self.annotations[idx]["cocoid"]
62 | if self.dataset_name == "coco"
63 | else self.annotations[idx]["filename"].split(".")[0],
64 | }
65 |
66 |
67 | class VQADataset(Dataset):
68 | def __init__(
69 | self, image_dir_path, question_path, annotations_path, is_train, dataset_name
70 | ):
71 | self.questions = json.load(open(question_path, "r"))["questions"]
72 | if annotations_path is not None:
73 | self.answers = json.load(open(annotations_path, "r"))["annotations"]
74 | else:
75 | self.answers = None
76 | self.image_dir_path = image_dir_path
77 | self.is_train = is_train
78 | self.dataset_name = dataset_name
79 | if self.dataset_name in {"vqav2", "ok_vqa"}:
80 | self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
81 | assert self.img_coco_split in {"train2014", "val2014", "test2015"}
82 |
83 | def __len__(self):
84 | return len(self.questions)
85 |
86 | def get_img_path(self, question):
87 | if self.dataset_name in {"vqav2", "ok_vqa"}:
88 | return os.path.join(
89 | self.image_dir_path,
90 | f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg"
91 | if self.is_train
92 | else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg",
93 | )
94 | elif self.dataset_name == "vizwiz":
95 | return os.path.join(self.image_dir_path, question["image_id"])
96 | elif self.dataset_name == "textvqa":
97 | return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
98 | else:
99 | raise Exception(f"Unknown VQA dataset {self.dataset_name}")
100 |
101 | def __getitem__(self, idx):
102 | question = self.questions[idx]
103 | img_path = self.get_img_path(question)
104 | image = Image.open(img_path)
105 | image.load()
106 | results = {
107 | "image": image,
108 | "question": question["question"],
109 | "question_id": question["question_id"],
110 | }
111 | if self.answers is not None:
112 | answers = self.answers[idx]
113 | results["answers"] = [a["answer"] for a in answers["answers"]]
114 | return results
115 |
116 |
117 | class ImageNetDataset(ImageFolder):
118 | """Class to represent the ImageNet1k dataset."""
119 |
120 | def __init__(self, root, **kwargs):
121 | super().__init__(root=root, **kwargs)
122 | self.class_id_to_name = dict(
123 | zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES)
124 | )
125 |
126 | def __getitem__(self, idx):
127 | sample, target = super().__getitem__(idx)
128 | target_label = self.class_id_to_name[target]
129 | return {
130 | "id": idx,
131 | "image": sample,
132 | "class_id": target, # numeric ID of the ImageNet class
133 | "class_name": target_label, # human-readable name of ImageNet class
134 | }
135 |
136 |
137 | class HatefulMemesDataset(Dataset):
138 | def __init__(self, image_dir_path, annotations_path):
139 | self.image_dir_path = image_dir_path
140 | with open(annotations_path, "r") as f:
141 | self.annotations = [json.loads(line) for line in f]
142 |
143 | def __len__(self):
144 | return len(self.annotations)
145 |
146 | def __getitem__(self, idx):
147 | annotation = self.annotations[idx]
148 | img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1])
149 | image = Image.open(img_path)
150 | image.load()
151 | return {
152 | "id": annotation["id"],
153 | "image": image,
154 | "ocr": annotation["text"],
155 | "class_name": "yes" if annotation["label"] == 1 else "no",
156 | "class_id": annotation["label"],
157 | }
158 |
--------------------------------------------------------------------------------
/mllm_internvl/evaluate_with_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
4 | export MASTER_PORT=12322
5 | echo $MASTER_ADDR
6 | echo $MASTER_PORT
7 | # SLURM_PROCID
8 | HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
9 | echo HOSTNAMES=$HOSTNAMES
10 | H=$(hostname)
11 | THEID=$(echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]")
12 | if [[ -z "${THEID// }" ]]; then
13 | THEID=0
14 | fi
15 | echo SLURM_PROCID=$THEID
16 | NNODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
17 | echo NNODES=$NNODES
18 |
19 | set -x
20 |
21 | torchrun --nnodes=$NNODES --nproc-per-node=8 \
22 | --master_port ${MASTER_PORT} --master_addr ${MASTER_ADDR} --node_rank ${THEID} \
23 | evaluate.py $@
24 |
--------------------------------------------------------------------------------
/mllm_internvl/fill_vqa_testdev_results.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper script to prepare a VQA test-dev evaluation for EvalAI submission.
3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set.
4 | Given a JSON file with a subset of the VQA questions, fill in the remaining questions with an empty string as the model prediction.
5 | """
6 | import json
7 | import sys
8 | import os
9 |
10 | sys.path.append(
11 | os.path.join(
12 | os.path.dirname(os.path.abspath(__file__)),
13 | "..",
14 | )
15 | )
16 | from vqa_metric import VQAEval
17 |
18 | postprocessor = VQAEval(None, None)
19 |
20 |
21 | def fill_vizwiz_test_json(
22 | input_path,
23 | output_path,
24 | vqa_test_questions_json_path,
25 | ):
26 | # read the input json and build a set with all question_ids
27 | with open(input_path, "r") as f:
28 | input_json = json.load(f)
29 |
30 | # postprocess answers
31 | question_id_to_answer = {}
32 | for q in input_json:
33 | resAns = q["answer"]
34 | resAns = resAns.replace("\n", " ")
35 | resAns = resAns.replace("\t", " ")
36 | resAns = resAns.strip()
37 | resAns = postprocessor.processPunctuation(resAns)
38 | resAns = postprocessor.processDigitArticle(resAns)
39 | question_id_to_answer[q["question_id"]] = resAns
40 |
41 |     # read the vqa test json to get all the question_ids that need to be filled
42 | with open(vqa_test_questions_json_path, "r") as f:
43 | vqa_test_json = json.load(f)
44 | vqa_test_json = vqa_test_json["questions"]
45 |
46 |     # build an entry for every test question, using an empty string when the model produced no answer
47 | output_json = []
48 | for q in vqa_test_json:
49 | output_json.append(
50 | {
51 | "image": q["image_id"],
52 | "answer": question_id_to_answer.get(q["question_id"], ""),
53 | }
54 | )
55 |
56 | # write the json to the output path
57 | with open(output_path, "w") as f:
58 | json.dump(output_json, f)
59 |
60 |
61 | def fill_vqav2_test_json(
62 | input_path,
63 | output_path,
64 | vqa_test_questions_json_path,
65 | ):
66 | # read the input json and build a set with all question_ids
67 | with open(input_path, "r") as f:
68 | input_json = json.load(f)
69 | question_ids = set()
70 | for q in input_json:
71 | question_ids.add(q["question_id"])
72 |
73 | # make a copy of the input json
74 | output_json = []
75 | for q in input_json:
76 | resAns = q["answer"]
77 | resAns = resAns.replace("\n", " ")
78 | resAns = resAns.replace("\t", " ")
79 | resAns = resAns.strip()
80 | resAns = postprocessor.processPunctuation(resAns)
81 | resAns = postprocessor.processDigitArticle(resAns)
82 | q["answer"] = resAns
83 | output_json.append(q)
84 |
85 |     # read the vqa test json to get all the question_ids that need to be filled
86 | with open(vqa_test_questions_json_path, "r") as f:
87 | vqa_test_json = json.load(f)
88 | vqa_test_json = vqa_test_json["questions"]
89 |
90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer
91 | for q in vqa_test_json:
92 | if q["question_id"] not in question_ids:
93 | output_json.append(
94 | {
95 | "question_id": q["question_id"],
96 | "answer": "",
97 | }
98 | )
99 |
100 | # write the json to the output path
101 | with open(output_path, "w") as f:
102 | json.dump(output_json, f)
103 |
104 |
105 | if __name__ == "__main__":
106 | import argparse
107 |
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument(
110 | "--dataset",
111 | type=str,
112 | choices=["vqav2", "vizwiz"],
113 | )
114 | parser.add_argument(
115 | "--input_path",
116 | type=str,
117 | help="Path to the json file with the subset of the vqa test-dev questions.",
118 | )
119 | parser.add_argument(
120 | "--vqa_test_questions_json_path",
121 | type=str,
122 | help="Path to the json file with all the vqa test questions.",
123 | )
124 | parser.add_argument(
125 | "--output_path",
126 | type=str,
127 | help="Path to store the filled json.",
128 | )
129 | args = parser.parse_args()
130 |
131 | if args.dataset == "vqav2":
132 | fill_vqav2_test_json(
133 | args.input_path,
134 | args.output_path,
135 | args.vqa_test_questions_json_path,
136 | )
137 | else:
138 | fill_vizwiz_test_json(
139 | args.input_path,
140 | args.output_path,
141 | args.vqa_test_questions_json_path,
142 | )
143 |
--------------------------------------------------------------------------------
/mllm_internvl/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | pycocoevalcap
3 | einops_exts
4 | opencv-python
5 | imageio
6 | decord
7 | nltk
8 | inflection
9 | termcolor
10 | yacs
11 | pyyaml
12 | scipy
13 | tqdm
--------------------------------------------------------------------------------
/mllm_internvl/rices.py:
--------------------------------------------------------------------------------
1 | import open_clip
2 | import torch
3 | from tqdm import tqdm
4 | import torch
5 | from utils import custom_collate_fn
6 |
7 |
8 | class RICES:
9 | def __init__(
10 | self,
11 | dataset,
12 | device,
13 | batch_size,
14 | vision_encoder_path="ViT-B-32",
15 | vision_encoder_pretrained="openai",
16 | cached_features=None,
17 | ):
18 | self.dataset = dataset
19 | self.device = device
20 | self.batch_size = batch_size
21 |
22 | # Load the model and processor
23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
24 | vision_encoder_path,
25 | pretrained=vision_encoder_pretrained,
26 | )
27 | self.model = vision_encoder.to(self.device)
28 | self.image_processor = image_processor
29 |
30 | # Precompute features
31 | if cached_features is None:
32 | self.features = self._precompute_features()
33 | else:
34 | self.features = cached_features
35 |
36 | def _precompute_features(self):
37 | features = []
38 |
39 | # Switch to evaluation mode
40 | self.model.eval()
41 |
42 | # Set up loader
43 | loader = torch.utils.data.DataLoader(
44 | self.dataset,
45 | batch_size=self.batch_size,
46 | collate_fn=custom_collate_fn,
47 | )
48 |
49 | with torch.no_grad():
50 | for batch in tqdm(
51 | loader,
52 | desc="Precomputing features for RICES",
53 | ):
54 | batch = batch["image"]
55 | inputs = torch.stack(
56 | [self.image_processor(image) for image in batch]
57 | ).to(self.device)
58 | image_features = self.model.encode_image(inputs)
59 | image_features /= image_features.norm(dim=-1, keepdim=True)
60 | features.append(image_features.detach())
61 |
62 | features = torch.cat(features)
63 | return features
64 |
65 | def find(self, batch, num_examples):
66 | """
67 | Get the top num_examples most similar examples to the images.
68 | """
69 | # Switch to evaluation mode
70 | self.model.eval()
71 |
72 | with torch.no_grad():
73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to(
74 | self.device
75 | )
76 |
77 | # Get the feature of the input image
78 | query_feature = self.model.encode_image(inputs)
79 | query_feature /= query_feature.norm(dim=-1, keepdim=True)
80 | query_feature = query_feature.detach().cpu()
81 |
82 | if query_feature.ndim == 1:
83 | query_feature = query_feature.unsqueeze(0)
84 |
85 | # Compute the similarity of the input image to the precomputed features
86 | similarity = (query_feature @ self.features.T).squeeze()
87 |
88 | if similarity.ndim == 1:
89 | similarity = similarity.unsqueeze(0)
90 |
91 | # Get the indices of the 'num_examples' most similar images
92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]
93 |
94 | # Return with the most similar images last
95 | return [[self.dataset[i] for i in reversed(row)] for row in indices]
96 |
--------------------------------------------------------------------------------
/mllm_internvl/test.sh:
--------------------------------------------------------------------------------
1 | CKPT_ROOT=".."
2 | RESULT_ROOT="./results"
3 | CKPT_FNAMES=(
4 | "checkpoint-10200"
5 | )
6 | mkdir -p $RESULT_ROOT
7 |
8 | for CKPT_FNAME in "${CKPT_FNAMES[@]}"; do
9 | set -x
10 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
11 | srun bash run_eval.sh --model internvl_chat \
12 | --batch_size 1 --shots 0 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
13 | --load_in_8bit False --dynamic False --max_num 6 \
14 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
15 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_0shots-multi-rounds.result"
16 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
17 | srun bash run_eval.sh --model internvl_chat \
18 | --batch_size 1 --shots 1 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
19 | --load_in_8bit False --dynamic False --max_num 6 \
20 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
21 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_1shots-multi-rounds.result"
22 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
23 | srun bash run_eval.sh --model internvl_chat \
24 | --batch_size 1 --shots 2 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
25 | --load_in_8bit False --dynamic False --max_num 6 \
26 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
27 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_2shots-multi-rounds.result"
28 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
29 | srun bash run_eval.sh --model internvl_chat \
30 | --batch_size 1 --shots 4 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
31 | --load_in_8bit False --dynamic False --max_num 6 \
32 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
33 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_4shots-multi-rounds.result"
34 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
35 | srun bash run_eval.sh --model internvl_chat \
36 | --batch_size 1 --shots 8 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
37 | --load_in_8bit False --dynamic False --max_num 6 \
38 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
39 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_8shots-multi-rounds.result"
40 |     PYTHONPATH="/mnt/petrelfs/liqingyun/OmniCorpus_internal/mllm_internvl/InternVL/internvl_chat:$PYTHONPATH" \
41 | srun bash run_eval.sh --model internvl_chat \
42 | --batch_size 1 --shots 0 --datasets coco flickr ok_vqa textvqa --chat-few-shot-style multi \
43 | --load_in_8bit False --dynamic False --max_num 6 --zero-shot-add-text-shots 2 \
44 | --checkpoint $CKPT_ROOT/$CKPT_FNAME \
45 | --results_file "$RESULT_ROOT/results_${CKPT_FNAME}_trick0shot-multi-rounds.result"
46 | done
47 |
--------------------------------------------------------------------------------
/mllm_llava/.gitattributes:
--------------------------------------------------------------------------------
1 | # https://git-scm.com/docs/gitattributes
2 |
3 | # Set the default behavior, in case people don't have core.autocrlf set.
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion
5 | * text=auto
6 |
7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
8 | # Source files
9 | # ============
10 | *.pxd text diff=python
11 | *.py text diff=python
12 | *.py3 text diff=python
13 | *.pyw text diff=python
14 | *.pyx text diff=python
15 | *.pyz text diff=python
16 | *.pyi text diff=python
17 |
18 | # Binary files
19 | # ============
20 | *.db binary
21 | *.p binary
22 | *.pkl binary
23 | *.pickle binary
24 | *.pyc binary export-ignore
25 | *.pyo binary export-ignore
26 | *.pyd binary
27 |
28 | # Jupyter notebook
29 | *.ipynb text eol=lf
30 |
--------------------------------------------------------------------------------
/mllm_llava/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | *.pyc
4 | *.egg-info
5 | dist
6 |
7 | # Log
8 | *.log
9 | *.log.*
10 | *.json
11 | *.jsonl
12 |
13 | # Data
14 | !**/alpaca-data-conversation.json
15 |
16 | # Editor
17 | .idea
18 | *.swp
19 |
20 | # Other
21 | .DS_Store
22 | wandb
23 | output
24 |
25 | checkpoints
26 | ckpts*
27 |
28 | .ipynb_checkpoints
29 | *.ipynb
30 |
31 | # DevContainer
32 | !.devcontainer/*
33 |
34 | # Demo
35 | serve_images/
36 |
37 | # data links
38 | playground/data/coco
39 | playground/data/data_path.txt
40 | playground/data/gqa
41 | playground/data/ocr_vqa
42 | playground/data/textvqa
43 | playground/data/vg
44 | playground/data/LLaVA-Pretrain/images
45 | playground/data/eval*
46 |
47 | # lqy custom
48 | runnings_*/*/*
49 | yt-sb-1b/*
50 | batchscript-*
51 | phoenix-slurm-*
52 | tmp*
53 | slurm_out/*
54 | slurm_out_*/*
55 | unit_test/mmc4_img_num/*
56 | cached_features_*
57 | unit_test/test_internlm2_tokenizer/*
58 | playground/data/lmms_eval_logs/*
59 |
--------------------------------------------------------------------------------
/mllm_llava/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM, LlavaInternLM2ForCausalLM
2 |
--------------------------------------------------------------------------------
/mllm_llava/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 | @ray.remote(num_cpus=4)
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | print('success!')
36 | return response['choices'][0]['message']['content']
37 |
38 |
39 | def parse_score(review):
40 | try:
41 | score_pair = review.split('\n')[0]
42 | score_pair = score_pair.replace(',', ' ')
43 | sp = score_pair.split(' ')
44 | if len(sp) == 2:
45 | return [float(sp[0]), float(sp[1])]
46 | else:
47 | print('error', review)
48 | return [-1, -1]
49 | except Exception as e:
50 | print(e)
51 | print('error', review)
52 | return [-1, -1]
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57 | parser.add_argument('-q', '--question')
58 | # parser.add_argument('-a', '--answer')
59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60 | parser.add_argument('-r', '--rule')
61 | parser.add_argument('-o', '--output')
62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63 | args = parser.parse_args()
64 |
65 | ray.init()
66 |
67 | f_q = open(os.path.expanduser(args.question))
68 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | review_file = open(f'{args.output}', 'w')
73 |
74 | js_list = []
75 | handles = []
76 | idx = 0
77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78 | # if idx == 1:
79 | # break
80 |
81 | ques = json.loads(ques_js)
82 | ans1 = json.loads(ans1_js)
83 | ans2 = json.loads(ans2_js)
84 |
85 | category = json.loads(ques_js)['category']
86 | if category in rule_dict:
87 | rule = rule_dict[category]
88 | else:
89 | rule = rule_dict['default']
90 | prompt = rule['prompt']
91 | role = rule['role']
92 | content = (f'[Question]\n{ques["text"]}\n\n'
93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95 | f'[System]\n{prompt}\n\n')
96 | js_list.append({
97 | 'id': idx+1,
98 | 'question_id': ques['question_id'],
99 | 'answer1_id': ans1['answer_id'],
100 | 'answer2_id': ans2['answer_id'],
101 | 'category': category})
102 | idx += 1
103 | handles.append(get_eval.remote(content, args.max_tokens))
104 | # To avoid the rate limit set by OpenAI
105 | time.sleep(NUM_SECONDS_TO_SLEEP)
106 |
107 | reviews = ray.get(handles)
108 | for idx, review in enumerate(reviews):
109 | scores = parse_score(review)
110 | js_list[idx]['content'] = review
111 | js_list[idx]['tuple'] = scores
112 | review_file.write(json.dumps(js_list[idx]) + '\n')
113 | review_file.close()
114 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_gpt_review_bench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | from tqdm import tqdm
9 |
10 | NUM_SECONDS_TO_SLEEP = 0.5
11 |
12 |
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4-0314',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | return response['choices'][0]['message']['content']
36 |
37 |
38 | def parse_score(review):
39 | try:
40 | score_pair = review.split('\n')[0]
41 | score_pair = score_pair.replace(',', ' ')
42 | sp = score_pair.split(' ')
43 | if len(sp) == 2:
44 | return [float(sp[0]), float(sp[1])]
45 | else:
46 | print('error', review)
47 | return [-1, -1]
48 | except Exception as e:
49 | print(e)
50 | print('error', review)
51 | return [-1, -1]
52 |
53 |
54 | if __name__ == '__main__':
55 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
56 | parser.add_argument('-q', '--question')
57 | parser.add_argument('-c', '--context')
58 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
59 | parser.add_argument('-r', '--rule')
60 | parser.add_argument('-o', '--output')
61 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
62 | args = parser.parse_args()
63 |
64 | f_q = open(os.path.expanduser(args.question))
65 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
66 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
67 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
68 |
69 | if os.path.isfile(os.path.expanduser(args.output)):
70 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
71 | else:
72 | cur_reviews = []
73 |
74 | review_file = open(f'{args.output}', 'a')
75 |
76 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
77 | image_to_context = {context['image']: context for context in context_list}
78 |
79 | handles = []
80 | idx = 0
81 | for ques_js, ans1_js, ans2_js in tqdm(zip(f_q, f_ans1, f_ans2)):
82 | ques = json.loads(ques_js)
83 | ans1 = json.loads(ans1_js)
84 | ans2 = json.loads(ans2_js)
85 |
86 | inst = image_to_context[ques['image']]
87 |
88 | if isinstance(inst['caption'], list):
89 | cap_str = '\n'.join(inst['caption'])
90 | else:
91 | cap_str = inst['caption']
92 |
93 | category = 'llava_bench_' + json.loads(ques_js)['category']
94 | if category in rule_dict:
95 | rule = rule_dict[category]
96 | else:
97 | assert False, f"Visual QA category not found in rule file: {category}."
98 | prompt = rule['prompt']
99 | role = rule['role']
100 | content = (f'[Context]\n{cap_str}\n\n'
101 | f'[Question]\n{ques["text"]}\n\n'
102 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
103 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
104 | f'[System]\n{prompt}\n\n')
105 | cur_js = {
106 | 'id': idx+1,
107 | 'question_id': ques['question_id'],
108 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
109 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
110 | 'category': category
111 | }
112 | if idx >= len(cur_reviews):
113 | review = get_eval(content, args.max_tokens)
114 | scores = parse_score(review)
115 | cur_js['content'] = review
116 | cur_js['tuple'] = scores
117 | review_file.write(json.dumps(cur_js) + '\n')
118 | review_file.flush()
119 | else:
120 | print(f'Skipping {idx} as we already have it.')
121 | idx += 1
122 | print(idx)
123 | review_file.close()
124 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 | cap_str = '\n'.join(inst['captions'])
86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87 |
88 | category = json.loads(ques_js)['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | assert False, f"Visual QA category not found in rule file: {category}."
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96 | f'[Question]\n{ques["text"]}\n\n'
97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99 | f'[System]\n{prompt}\n\n')
100 | cur_js = {
101 | 'id': idx+1,
102 | 'question_id': ques['question_id'],
103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
105 | 'category': category
106 | }
107 | if idx >= len(cur_reviews):
108 | review = get_eval(content, args.max_tokens)
109 | scores = parse_score(review)
110 | cur_js['content'] = review
111 | cur_js['tuple'] = scores
112 | review_file.write(json.dumps(cur_js) + '\n')
113 | review_file.flush()
114 | else:
115 | print(f'Skipping {idx} as we already have it.')
116 | idx += 1
117 | print(idx)
118 | review_file.close()
119 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def eval_pope(answers, label_file):
6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7 |
8 | for answer in answers:
9 | text = answer['text']
10 |
11 | # Only keep the first sentence
12 | if text.find('.') != -1:
13 | text = text.split('.')[0]
14 |
15 | text = text.replace(',', '')
16 | words = text.split(' ')
17 | if 'No' in words or 'not' in words or 'no' in words:
18 | answer['text'] = 'no'
19 | else:
20 | answer['text'] = 'yes'
21 |
22 | for i in range(len(label_list)):
23 | if label_list[i] == 'no':
24 | label_list[i] = 0
25 | else:
26 | label_list[i] = 1
27 |
28 | pred_list = []
29 | for answer in answers:
30 | if answer['text'] == 'no':
31 | pred_list.append(0)
32 | else:
33 | pred_list.append(1)
34 |
35 | pos = 1
36 | neg = 0
37 | yes_ratio = pred_list.count(1) / len(pred_list)
38 |
39 | TP, TN, FP, FN = 0, 0, 0, 0
40 | for pred, label in zip(pred_list, label_list):
41 | if pred == pos and label == pos:
42 | TP += 1
43 | elif pred == pos and label == neg:
44 | FP += 1
45 | elif pred == neg and label == neg:
46 | TN += 1
47 | elif pred == neg and label == pos:
48 | FN += 1
49 |
50 | print('TP\tFP\tTN\tFN\t')
51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 |
53 | precision = float(TP) / float(TP + FP)
54 | recall = float(TP) / float(TP + FN)
55 | f1 = 2*precision*recall / (precision + recall)
56 | acc = (TP + TN) / (TP + TN + FP + FN)
57 | print('Accuracy: {}'.format(acc))
58 | print('Precision: {}'.format(precision))
59 | print('Recall: {}'.format(recall))
60 | print('F1 score: {}'.format(f1))
61 | print('Yes ratio: {}'.format(yes_ratio))
62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--annotation-dir", type=str, default="./playground/data/eval/pope/coco")
67 | parser.add_argument("--question-file", type=str, default="./playground/data/eval/pope/llava_pope_test.jsonl")
68 | parser.add_argument("--result-file", type=str, required=True)
69 | args = parser.parse_args()
70 |
71 | questions = [json.loads(line) for line in open(args.question_file)]
72 | questions = {question['question_id']: question for question in questions}
73 | answers = [json.loads(q) for q in open(args.result_file)]
74 | for file in os.listdir(args.annotation_dir):
75 | assert file.startswith('coco_pope_')
76 | assert file.endswith('.json')
77 | category = file[10:-5]
78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 | print("====================================")
82 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return -1
36 | return random.choice(range(len(choices)))
37 |
38 |
39 | if __name__ == "__main__":
40 | args = get_args()
41 |
42 | base_dir = args.base_dir
43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
45 | predictions = [json.loads(line) for line in open(args.result_file)]
46 | predictions = {pred['question_id']: pred for pred in predictions}
47 | split_problems = {idx: problems[idx] for idx in split_indices}
48 |
49 | results = {'correct': [], 'incorrect': []}
50 | sqa_results = {}
51 | sqa_results['acc'] = None
52 | sqa_results['correct'] = None
53 | sqa_results['count'] = None
54 | sqa_results['results'] = {}
55 | sqa_results['outputs'] = {}
56 |
57 | for prob_id, prob in split_problems.items():
58 | if prob_id not in predictions:
59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60 | pred_text = 'FAILED'
61 | else:
62 | pred = predictions[prob_id]
63 | pred_text = pred['text']
64 |
65 | if pred_text in args.options:
66 | answer = pred_text
67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68 | answer = pred_text[0]
69 | else:
70 | pattern = re.compile(r'The answer is ([A-Z]).')
71 | res = pattern.findall(pred_text)
72 | if len(res) == 1:
73 | answer = res[0] # 'A', 'B', ...
74 | else:
75 | answer = "FAILED"
76 |
77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78 |
79 | analysis = {
80 | 'question_id': prob_id,
81 | 'parsed_ans': answer,
82 | 'ground_truth': args.options[prob['answer']],
83 | 'question': pred['prompt'],
84 | 'pred': pred_text,
85 |             'is_multimodal': '<image>' in pred['prompt'],
86 | }
87 |
88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89 | sqa_results['outputs'][prob_id] = pred_text
90 |
91 | if pred_idx == prob['answer']:
92 | results['correct'].append(analysis)
93 | else:
94 | results['incorrect'].append(analysis)
95 |
96 | correct = len(results['correct'])
97 | total = len(results['correct']) + len(results['incorrect'])
98 |
99 | ###### IMG ######
100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 | multimodal_total = multimodal_correct + multimodal_incorrect
103 | ###### IMG ######
104 |
105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 |
107 | sqa_results['acc'] = correct / total * 100
108 | sqa_results['correct'] = correct
109 | sqa_results['count'] = total
110 |
111 | with open(args.output_file, 'w') as f:
112 | json.dump(results, f, indent=2)
113 | with open(args.output_result, 'w') as f:
114 | json.dump(sqa_results, f, indent=2)
115 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
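    # Illustrative example: a prompt such as
    #   "OCR tokens: stop, sign Question: What does the sign say? Short answer:"
    # is reduced to the lower-cased question "what does the sign say?" so it can be
    # matched against the (image_id, question) keys built in eval_single below.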
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | @torch.inference_mode()
14 | def eval_model(model_name, questions_file, answers_file):
15 | # Model
16 | disable_torch_init()
17 | model_name = os.path.expanduser(model_name)
18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19 | model = AutoModelForCausalLM.from_pretrained(model_name,
20 | torch_dtype=torch.float16).cuda()
21 |
22 |
23 | ques_file = open(os.path.expanduser(questions_file), "r")
24 | ans_file = open(os.path.expanduser(answers_file), "w")
25 | for i, line in enumerate(tqdm(ques_file)):
26 | idx = json.loads(line)["question_id"]
27 | qs = json.loads(line)["text"]
28 | cat = json.loads(line)["category"]
29 | conv = default_conversation.copy()
30 | conv.append_message(conv.roles[0], qs)
31 | prompt = conv.get_prompt()
32 | inputs = tokenizer([prompt])
33 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
34 | output_ids = model.generate(
35 | input_ids,
36 | do_sample=True,
37 | use_cache=True,
38 | temperature=0.7,
39 | max_new_tokens=1024,)
40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
41 | try:
42 | index = outputs.index(conv.sep, len(prompt))
43 | except ValueError:
44 | outputs += conv.sep
45 | index = outputs.index(conv.sep, len(prompt))
46 |
47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
48 | ans_id = shortuuid.uuid()
49 | ans_file.write(json.dumps({"question_id": idx,
50 | "text": outputs,
51 | "answer_id": ans_id,
52 | "model_id": model_name,
53 | "metadata": {}}) + "\n")
54 | ans_file.flush()
55 | ans_file.close()
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
62 | args = parser.parse_args()
63 |
64 | eval_model(args.model_name, args.question_file, args.answers_file)
65 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 |     """Split a list into chunks of size ceil(len(lst) / n); the last chunk may be shorter"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def split_list_v2(lst, n):
25 |     """Split a list into exactly n chunks whose sizes differ by at most one"""
26 | base_chunk_size = len(lst) // n # integer division
27 | remainder = len(lst) % n # remaining elements
28 | chunks = []
29 | for i in range(n):
30 | chunk_size = base_chunk_size + (i < remainder) # add one to the chunk size for the first 'remainder' chunks
31 | start = i * base_chunk_size + min(i, remainder) # calculate the start index
32 | chunks.append(lst[start:start+chunk_size])
33 | return chunks
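# Illustrative comparison: for 10 questions and n=4 chunks, split_list yields chunk
# sizes [3, 3, 3, 1] while split_list_v2 yields [3, 3, 2, 2]; get_chunk below uses
# the v2 variant, so per-worker shards differ in size by at most one.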
34 |
35 |
36 | def get_chunk(lst, n, k):
37 | chunks = split_list_v2(lst, n)
38 | return chunks[k]
39 |
40 |
41 | def eval_model(args):
42 | # Model
43 | disable_torch_init()
44 | model_path = os.path.expanduser(args.model_path)
45 | model_name = get_model_name_from_path(model_path)
46 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
47 |
48 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
49 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
50 | answers_file = os.path.expanduser(args.answers_file)
51 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
52 | ans_file = open(answers_file, "w")
53 | for line in tqdm(questions):
54 | idx = line["question_id"]
55 | image_file = line["image"]
56 | qs = line["text"]
57 | cur_prompt = qs
58 | if model.config.mm_use_im_start_end:
59 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
60 | else:
61 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
62 |
63 | conv = conv_templates[args.conv_mode].copy()
64 | conv.append_message(conv.roles[0], qs)
65 | conv.append_message(conv.roles[1], None)
66 | prompt = conv.get_prompt()
67 |
68 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
69 |
70 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
71 | image_tensor = process_images([image], image_processor, model.config)[0]
72 |
73 | with torch.inference_mode():
74 | output_ids = model.generate(
75 | input_ids,
76 | images=image_tensor.unsqueeze(0).half().cuda(),
77 | image_sizes=[image.size],
78 | do_sample=True if args.temperature > 0 else False,
79 | temperature=args.temperature,
80 | top_p=args.top_p,
81 | num_beams=args.num_beams,
82 | # no_repeat_ngram_size=3,
83 | max_new_tokens=1024,
84 | use_cache=True)
85 |
86 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
87 |
88 | ans_id = shortuuid.uuid()
89 | ans_file.write(json.dumps({"question_id": idx,
90 | "prompt": cur_prompt,
91 | "text": outputs,
92 | "answer_id": ans_id,
93 | "model_id": model_name,
94 | "metadata": {}}) + "\n")
95 | ans_file.flush()
96 | ans_file.close()
97 |
98 | if __name__ == "__main__":
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
101 | parser.add_argument("--model-base", type=str, default=None)
102 | parser.add_argument("--image-folder", type=str, default="")
103 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
104 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
105 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
106 | parser.add_argument("--num-chunks", type=int, default=1)
107 | parser.add_argument("--chunk-idx", type=int, default=0)
108 | parser.add_argument("--temperature", type=float, default=0.2)
109 | parser.add_argument("--top_p", type=float, default=None)
110 | parser.add_argument("--num_beams", type=int, default=1)
111 | args = parser.parse_args()
112 |
113 | if args.model_base is None and "::" in args.model_path:
114 | model_base_and_path = args.model_path
115 | args.model_base, args.model_path = model_base_and_path.split("::")
116 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})")
117 |
118 | eval_model(args)
119 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line:
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import (
5 | IMAGE_TOKEN_INDEX,
6 | DEFAULT_IMAGE_TOKEN,
7 | DEFAULT_IM_START_TOKEN,
8 | DEFAULT_IM_END_TOKEN,
9 | IMAGE_PLACEHOLDER,
10 | )
11 | from llava.conversation import conv_templates, SeparatorStyle
12 | from llava.model.builder import load_pretrained_model
13 | from llava.utils import disable_torch_init
14 | from llava.mm_utils import (
15 | process_images,
16 | tokenizer_image_token,
17 | get_model_name_from_path,
18 | )
19 |
20 | from PIL import Image
21 |
22 | import requests
23 | from PIL import Image
24 | from io import BytesIO
25 | import re
26 |
27 |
28 | def image_parser(args):
29 | out = args.image_file.split(args.sep)
30 | return out
31 |
32 |
33 | def load_image(image_file):
34 | if image_file.startswith("http") or image_file.startswith("https"):
35 | response = requests.get(image_file)
36 | image = Image.open(BytesIO(response.content)).convert("RGB")
37 | else:
38 | image = Image.open(image_file).convert("RGB")
39 | return image
40 |
41 |
42 | def load_images(image_files):
43 | out = []
44 | for image_file in image_files:
45 | image = load_image(image_file)
46 | out.append(image)
47 | return out
48 |
49 |
50 | def eval_model(args):
51 | # Model
52 | disable_torch_init()
53 |
54 | model_name = get_model_name_from_path(args.model_path)
55 | tokenizer, model, image_processor, context_len = load_pretrained_model(
56 | args.model_path, args.model_base, model_name
57 | )
58 |
59 | qs = args.query
60 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
61 | if IMAGE_PLACEHOLDER in qs:
62 | if model.config.mm_use_im_start_end:
63 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
64 | else:
65 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
66 | else:
67 | if model.config.mm_use_im_start_end:
68 | qs = image_token_se + "\n" + qs
69 | else:
70 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
71 |
72 | if "llama-2" in model_name.lower():
73 | conv_mode = "llava_llama_2"
74 | elif "mistral" in model_name.lower():
75 | conv_mode = "mistral_instruct"
76 | elif "v1.6-34b" in model_name.lower():
77 | conv_mode = "chatml_direct"
78 | elif "v1" in model_name.lower():
79 | conv_mode = "llava_v1"
80 | elif "mpt" in model_name.lower():
81 | conv_mode = "mpt"
82 | else:
83 | conv_mode = "llava_v0"
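    # Illustrative example: a checkpoint name like "llava-v1.5-7b" contains "v1", so
    # the chain above selects the "llava_v1" conversation template unless --conv-mode
    # explicitly overrides it (see the warning below).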
84 |
85 | if args.conv_mode is not None and conv_mode != args.conv_mode:
86 | print(
87 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
88 | conv_mode, args.conv_mode, args.conv_mode
89 | )
90 | )
91 | else:
92 | args.conv_mode = conv_mode
93 |
94 | conv = conv_templates[args.conv_mode].copy()
95 | conv.append_message(conv.roles[0], qs)
96 | conv.append_message(conv.roles[1], None)
97 | prompt = conv.get_prompt()
98 |
99 | image_files = image_parser(args)
100 | images = load_images(image_files)
101 | image_sizes = [x.size for x in images]
102 | images_tensor = process_images(
103 | images,
104 | image_processor,
105 | model.config
106 | ).to(model.device, dtype=torch.float16)
107 |
108 | input_ids = (
109 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
110 | .unsqueeze(0)
111 | .cuda()
112 | )
113 |
114 | with torch.inference_mode():
115 | output_ids = model.generate(
116 | input_ids,
117 | images=images_tensor,
118 | image_sizes=image_sizes,
119 | do_sample=True if args.temperature > 0 else False,
120 | temperature=args.temperature,
121 | top_p=args.top_p,
122 | num_beams=args.num_beams,
123 | max_new_tokens=args.max_new_tokens,
124 | use_cache=True,
125 | )
126 |
127 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
128 | print(outputs)
129 |
130 |
131 | if __name__ == "__main__":
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
134 | parser.add_argument("--model-base", type=str, default=None)
135 | parser.add_argument("--image-file", type=str, required=True)
136 | parser.add_argument("--query", type=str, required=True)
137 | parser.add_argument("--conv-mode", type=str, default=None)
138 | parser.add_argument("--sep", type=str, default=",")
139 | parser.add_argument("--temperature", type=float, default=0.2)
140 | parser.add_argument("--top_p", type=float, default=None)
141 | parser.add_argument("--num_beams", type=int, default=1)
142 | parser.add_argument("--max_new_tokens", type=int, default=512)
143 | args = parser.parse_args()
144 |
145 | if args.model_base is None and "::" in args.model_path:
146 | model_base_and_path = args.model_path
147 | args.model_base, args.model_path = model_base_and_path.split("::")
148 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})")
149 |
150 | eval_model(args)
151 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | import argparse
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 | parser.add_argument('-d', '--dir', default=None)
12 | parser.add_argument('-v', '--version', default=None)
13 | parser.add_argument('-s', '--select', nargs='*', default=None)
14 | parser.add_argument('-f', '--files', nargs='*', default=[])
15 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 | return parser.parse_args()
17 |
18 |
19 | if __name__ == '__main__':
20 | args = parse_args()
21 |
22 | if args.ignore is not None:
23 | args.ignore = [int(x) for x in args.ignore]
24 |
25 | if len(args.files) > 0:
26 | review_files = args.files
27 | else:
28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 |
30 | for review_file in sorted(review_files):
31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 | if args.select is not None and any(x not in config for x in args.select):
33 | continue
34 | if '0613' in config:
35 | version = '0613'
36 | else:
37 | version = '0314'
38 | if args.version is not None and args.version != version:
39 | continue
40 | scores = defaultdict(list)
41 | print(config)
42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 | for review_str in f:
44 | review = json.loads(review_str)
45 | if review['question_id'] in args.ignore:
46 | continue
47 | if 'category' in review:
48 | scores[review['category']].append(review['tuple'])
49 | scores['all'].append(review['tuple'])
50 | else:
51 | if 'tuple' in review:
52 | scores['all'].append(review['tuple'])
53 | else:
54 | scores['all'].append(review['score'])
55 | for k, v in sorted(scores.items()):
56 | stats = np.asarray(v).mean(0).tolist()
57 | stats = [round(x, 3) for x in stats]
58 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 | print('=================================')
61 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/translation_tool.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import hashlib
3 | import time
4 | import random
5 | import json
6 | from googletrans import Translator
7 | import requests
8 | from requests.adapters import HTTPAdapter
9 | from requests.packages.urllib3.util.retry import Retry
10 |
11 |
12 | class TranslationTool:
13 | def __init__(self):
14 | self.translator = Translator()
15 |
16 | def translate(self, text, src='en', dest='zh-cn'):
17 | """
18 |         Translate the given text.
19 |         :param text: the text string to translate.
20 |         :param src: source language code, defaults to English 'en'.
21 |         :param dest: target language code, defaults to Simplified Chinese 'zh-cn'.
22 |         :return: the translated text string.
23 | """
24 | translated = self.translator.translate(text, src=src, dest=dest)
25 | return translated.text
26 |
27 |
28 | class YoudaoTranslationTool:
29 | def __init__(self, app_id, app_secret):
30 | self.app_id = app_id
31 | self.app_secret = app_secret
32 | self.url = "https://openapi.youdao.com/api"
33 | def generate_sign(self, q, salt):
34 | sign_str = self.app_id + q + str(salt) + self.app_secret
35 | sign = hashlib.md5(sign_str.encode('utf-8')).hexdigest()
36 | return sign
37 | def translate(self, text, src='en', dest='zh-CHS'):
38 | salt = random.randint(1, 65536)
39 | sign = self.generate_sign(text, salt)
40 | params = {
41 | 'q': text,
42 | 'from': src,
43 | 'to': dest,
44 | 'appKey': self.app_id,
45 | 'salt': salt,
46 | 'sign': sign
47 | }
48 | try:
49 | response = requests.get(self.url, params=params)
50 | response.raise_for_status()
51 | result = response.json()
52 | if 'translation' in result:
53 | return result['translation'][0]
54 | else:
55 | return "Translation error: " + result.get('errorCode', 'Unknown error')
56 | except requests.exceptions.RequestException as e:
57 | return f"Network error occurred: {e}"
58 | except Exception as e:
59 | return f"An error occurred: {e}"
60 |
61 | # # Usage example
62 | # app_id = '553c1ee6d3f9f808'
63 | # app_secret = 'iOiNGdr2OtJrspEciUDOirk6wnKzcmP5'
64 | # translation_tool = YoudaoTranslationTool(app_id, app_secret)
65 | # result = translation_tool.translate("Provide a one-sentence caption for the provided image.")
66 | # print(result)
67 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/mllm_llava/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 | # NOTE: Solutions may be found at https://github.com/haotian-liu/LLaVA/issues/1101
3 | try:
4 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
5 | except:
6 | pass
7 | try:
8 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
9 | except:
10 | pass
11 | try:
12 | from llava.model.language_model.llava_internlm import LlavaInternLM2ForCausalLM, LlavaInternLM2Config
13 | except:
14 | pass
15 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
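            # Illustrative note on the branch above: when shapes differ (e.g. an
            # extended embed_tokens.weight), only the slice overlapping the base
            # model receives the base weights; rows added by the delta, such as
            # newly introduced tokens, keep the values stored in the delta checkpoint.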
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/language_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/model/language_model/__init__.py
--------------------------------------------------------------------------------
/mllm_llava/llava/model/language_model/internlm_chat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/model/language_model/internlm_chat/__init__.py
--------------------------------------------------------------------------------
/mllm_llava/llava/model/language_model/llava_internlm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import AutoConfig, AutoModelForCausalLM
7 |
8 | from transformers.modeling_outputs import CausalLMOutputWithPast
9 | from transformers.generation.utils import GenerateOutput
10 |
11 | from llava.model.language_model.internlm_chat.modeling_internlm2 import InternLM2ForCausalLM, InternLM2Model, InternLM2Config
12 |
13 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
14 |
15 |
16 | class LlavaInternLM2Config(InternLM2Config):
17 | model_type = "llava_internlm2"
18 |
19 |
20 | class LlavaInternLM2Model(LlavaMetaModel, InternLM2Model):
21 | config_class = LlavaInternLM2Config
22 |
23 | def __init__(self, config: InternLM2Config):
24 | super(LlavaInternLM2Model, self).__init__(config)
25 |
26 |
27 | class LlavaInternLM2ForCausalLM(InternLM2ForCausalLM, LlavaMetaForCausalLM):
28 | config_class = LlavaInternLM2Config
29 |
30 | def __init__(self, config):
31 | super(InternLM2ForCausalLM, self).__init__(config)
32 | self.model = LlavaInternLM2Model(config)
33 | self.vocab_size = config.vocab_size
34 | self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
35 |
36 | # Initialize weights and apply final processing
37 | self.post_init()
38 |
39 | def get_model(self):
40 | return self.model
41 |
42 | def forward(
43 | self,
44 | input_ids: torch.LongTensor = None,
45 | attention_mask: Optional[torch.Tensor] = None,
46 | position_ids: Optional[torch.LongTensor] = None,
47 | past_key_values: Optional[List[torch.FloatTensor]] = None,
48 | inputs_embeds: Optional[torch.FloatTensor] = None,
49 | labels: Optional[torch.LongTensor] = None,
50 | use_cache: Optional[bool] = None,
51 | output_attentions: Optional[bool] = None,
52 | output_hidden_states: Optional[bool] = None,
53 | images: Optional[torch.FloatTensor] = None,
54 | image_sizes: Optional[List[List[int]]] = None,
55 | return_dict: Optional[bool] = None,
56 | ) -> Union[Tuple, CausalLMOutputWithPast]:
57 |
58 | if inputs_embeds is None:
59 | (
60 | input_ids,
61 | position_ids,
62 | attention_mask,
63 | past_key_values,
64 | inputs_embeds,
65 | labels
66 | ) = self.prepare_inputs_labels_for_multimodal(
67 | input_ids,
68 | position_ids,
69 | attention_mask,
70 | past_key_values,
71 | labels,
72 | images,
73 | image_sizes
74 | )
75 |
76 | return super().forward(
77 | input_ids=input_ids,
78 | attention_mask=attention_mask,
79 | position_ids=position_ids,
80 | past_key_values=past_key_values,
81 | inputs_embeds=inputs_embeds,
82 | labels=labels,
83 | use_cache=use_cache,
84 | output_attentions=output_attentions,
85 | output_hidden_states=output_hidden_states,
86 | return_dict=return_dict
87 | )
88 |
89 | @torch.no_grad()
90 | def generate(
91 | self,
92 | inputs: Optional[torch.Tensor] = None,
93 | images: Optional[torch.Tensor] = None,
94 | image_sizes: Optional[torch.Tensor] = None,
95 | **kwargs,
96 | ) -> Union[GenerateOutput, torch.LongTensor]:
97 | position_ids = kwargs.pop("position_ids", None)
98 | attention_mask = kwargs.pop("attention_mask", None)
99 | if "inputs_embeds" in kwargs:
100 | raise NotImplementedError("`inputs_embeds` is not supported")
101 |
102 | if images is not None:
103 | (
104 | inputs,
105 | position_ids,
106 | attention_mask,
107 | _,
108 | inputs_embeds,
109 | _
110 | ) = self.prepare_inputs_labels_for_multimodal(
111 | inputs,
112 | position_ids,
113 | attention_mask,
114 | None,
115 | None,
116 | images,
117 | image_sizes=image_sizes
118 | )
119 | else:
120 | inputs_embeds = self.get_model().embed_tokens(inputs)
121 |
122 | return super().generate(
123 | position_ids=position_ids,
124 | attention_mask=attention_mask,
125 | inputs_embeds=inputs_embeds,
126 | **kwargs
127 | )
128 |
129 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
130 | inputs_embeds=None, **kwargs):
131 | images = kwargs.pop("images", None)
132 | image_sizes = kwargs.pop("image_sizes", None)
133 | inputs = super().prepare_inputs_for_generation(
134 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
135 | )
136 | if images is not None:
137 | inputs['images'] = images
138 | if image_sizes is not None:
139 | inputs['image_sizes'] = image_sizes
140 | return inputs
141 |
142 | AutoConfig.register("llava_internlm2", LlavaInternLM2Config)
143 | AutoModelForCausalLM.register(LlavaInternLM2Config, LlavaInternLM2ForCausalLM)
144 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
22 |
23 | from transformers.modeling_outputs import CausalLMOutputWithPast
24 | from transformers.generation.utils import GenerateOutput
25 |
26 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
27 |
28 |
29 | class LlavaConfig(LlamaConfig):
30 | model_type = "llava_llama"
31 |
32 |
33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
34 | config_class = LlavaConfig
35 |
36 | def __init__(self, config: LlamaConfig):
37 | super(LlavaLlamaModel, self).__init__(config)
38 |
39 |
40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaConfig
42 |
43 | def __init__(self, config):
44 | super(LlamaForCausalLM, self).__init__(config)
45 | self.model = LlavaLlamaModel(config)
46 | self.pretraining_tp = config.pretraining_tp
47 | self.vocab_size = config.vocab_size
48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.model
55 |
56 | def forward(
57 | self,
58 | input_ids: torch.LongTensor = None,
59 | attention_mask: Optional[torch.Tensor] = None,
60 | position_ids: Optional[torch.LongTensor] = None,
61 | past_key_values: Optional[List[torch.FloatTensor]] = None,
62 | inputs_embeds: Optional[torch.FloatTensor] = None,
63 | labels: Optional[torch.LongTensor] = None,
64 | use_cache: Optional[bool] = None,
65 | output_attentions: Optional[bool] = None,
66 | output_hidden_states: Optional[bool] = None,
67 | images: Optional[torch.FloatTensor] = None,
68 | image_sizes: Optional[List[List[int]]] = None,
69 | return_dict: Optional[bool] = None,
70 | ) -> Union[Tuple, CausalLMOutputWithPast]:
71 |
72 | if inputs_embeds is None:
73 | (
74 | input_ids,
75 | position_ids,
76 | attention_mask,
77 | past_key_values,
78 | inputs_embeds,
79 | labels
80 | ) = self.prepare_inputs_labels_for_multimodal(
81 | input_ids,
82 | position_ids,
83 | attention_mask,
84 | past_key_values,
85 | labels,
86 | images,
87 | image_sizes
88 | )
89 |
90 | return super().forward(
91 | input_ids=input_ids,
92 | attention_mask=attention_mask,
93 | position_ids=position_ids,
94 | past_key_values=past_key_values,
95 | inputs_embeds=inputs_embeds,
96 | labels=labels,
97 | use_cache=use_cache,
98 | output_attentions=output_attentions,
99 | output_hidden_states=output_hidden_states,
100 | return_dict=return_dict
101 | )
102 |
103 | @torch.no_grad()
104 | def generate(
105 | self,
106 | inputs: Optional[torch.Tensor] = None,
107 | images: Optional[torch.Tensor] = None,
108 | image_sizes: Optional[torch.Tensor] = None,
109 | **kwargs,
110 | ) -> Union[GenerateOutput, torch.LongTensor]:
111 | position_ids = kwargs.pop("position_ids", None)
112 | attention_mask = kwargs.pop("attention_mask", None)
113 | if "inputs_embeds" in kwargs:
114 | raise NotImplementedError("`inputs_embeds` is not supported")
115 |
116 | if images is not None:
117 | (
118 | inputs,
119 | position_ids,
120 | attention_mask,
121 | _,
122 | inputs_embeds,
123 | _
124 | ) = self.prepare_inputs_labels_for_multimodal(
125 | inputs,
126 | position_ids,
127 | attention_mask,
128 | None,
129 | None,
130 | images,
131 | image_sizes=image_sizes
132 | )
133 | else:
134 | input_ids = kwargs.pop("input_ids", inputs)
135 | inputs_embeds = self.get_model().embed_tokens(input_ids)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_llama", LlavaConfig)
158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
159 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/language_model/llava_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | MistralConfig, MistralModel, MistralForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 | from transformers.generation.utils import GenerateOutput
27 |
28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29 |
30 |
31 | class LlavaMistralConfig(MistralConfig):
32 | model_type = "llava_mistral"
33 |
34 |
35 | class LlavaMistralModel(LlavaMetaModel, MistralModel):
36 | config_class = LlavaMistralConfig
37 |
38 | def __init__(self, config: MistralConfig):
39 | super(LlavaMistralModel, self).__init__(config)
40 |
41 |
42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43 | config_class = LlavaMistralConfig
44 |
45 | def __init__(self, config):
46 | super(MistralForCausalLM, self).__init__(config)
47 | self.model = LlavaMistralModel(config)
48 |
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_mistral", LlavaMistralConfig)
158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
159 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
/mllm_llava/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | elif getattr(args, 'unfreeze_mm_vision_tower', False):
20 | self.load_model()
21 | else:
22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23 |
24 | def load_model(self, device_map=None):
25 | if self.is_loaded:
26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27 | return
28 |
29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31 | self.vision_tower.requires_grad_(False)
32 |
33 | self.is_loaded = True
34 |
35 | def feature_select(self, image_forward_outs):
36 | image_features = image_forward_outs.hidden_states[self.select_layer]
37 | if self.select_feature == 'patch':
38 | image_features = image_features[:, 1:]
39 | elif self.select_feature == 'cls_patch':
40 | image_features = image_features
41 | else:
42 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
43 | return image_features
44 |
45 | @torch.no_grad()
46 | def forward(self, images):
47 | if type(images) is list:
48 | image_features = []
49 | for image in images:
50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
52 | image_features.append(image_feature)
53 | else:
54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
56 |
57 | return image_features
58 |
59 | @property
60 | def dummy_feature(self):
61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62 |
63 | @property
64 | def dtype(self):
65 | return self.vision_tower.dtype
66 |
67 | @property
68 | def device(self):
69 | return self.vision_tower.device
70 |
71 | @property
72 | def config(self):
73 | if self.is_loaded:
74 | return self.vision_tower.config
75 | else:
76 | return self.cfg_only
77 |
78 | @property
79 | def hidden_size(self):
80 | return self.config.hidden_size
81 |
82 | @property
83 | def num_patches_per_side(self):
84 | return self.config.image_size // self.config.patch_size
85 |
86 | @property
87 | def num_patches(self):
88 | return (self.config.image_size // self.config.patch_size) ** 2
89 |
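A minimal sketch of how this tower is constructed and queried: the checkpoint name and attribute values below are illustrative, and `args` only needs the fields read in `__init__` above.

```python
from types import SimpleNamespace

import torch

from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

# The tower reads mm_vision_select_layer (required) and mm_vision_select_feature
# (optional) from `args`; -2 selects the second-to-last hidden layer.
args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)

images = torch.randn(2, 3, 336, 336)
features = tower(images)  # (batch, num_patches, hidden_size)
print(features.shape, tower.num_patches, tower.hidden_size)
```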
--------------------------------------------------------------------------------
/mllm_llava/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | pixel_shuffle_mlp_gelu_match = re.match(r'^pixel_shuffle_ln_mlp(\d+)x_gelu$', projector_type)
49 | if pixel_shuffle_mlp_gelu_match:
50 | pixel_shuffle_ratio = getattr(config, 'pixel_shuffle_ratio', None)
51 | scale_factor = int(1 / pixel_shuffle_ratio ** 2)
52 | mlp_depth = int(pixel_shuffle_mlp_gelu_match.group(1))
53 | modules = [
54 | nn.LayerNorm(config.mm_hidden_size * scale_factor),
55 | nn.Linear(config.mm_hidden_size * scale_factor, config.hidden_size),
56 | ]
57 | for _ in range(1, mlp_depth):
58 | modules.append(nn.GELU())
59 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
60 | return nn.Sequential(*modules)
61 |
62 | if projector_type == 'identity':
63 | return IdentityMap()
64 |
65 | raise ValueError(f'Unknown projector type: {projector_type}')
66 |
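A small usage sketch with illustrative sizes; the attribute names match what `build_vision_projector` reads above.

```python
from types import SimpleNamespace

import torch

from llava.model.multimodal_projector.builder import build_vision_projector

# 'mlp2x_gelu' matches the regex above and builds Linear -> GELU -> Linear.
cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)

vision_features = torch.randn(2, 576, 1024)  # (batch, patches, mm_hidden_size)
tokens = projector(vision_features)          # (batch, patches, hidden_size)
print(tokens.shape)
```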
--------------------------------------------------------------------------------
/mllm_llava/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava_llama")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/__init__.py
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9 |
10 | from PIL import Image
11 | from io import BytesIO
12 |
13 | import requests
14 |
15 | from transformers import TextStreamer
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if "llama-2" in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "mistral" in model_name.lower():
37 | conv_mode = "mistral_instruct"
38 | elif "v1.6-34b" in model_name.lower():
39 | conv_mode = "chatml_direct"
40 | elif "v1" in model_name.lower():
41 | conv_mode = "llava_v1"
42 | elif "mpt" in model_name.lower():
43 | conv_mode = "mpt"
44 | else:
45 | conv_mode = "llava_v0"
46 |
47 | if args.conv_mode is not None and conv_mode != args.conv_mode:
48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
49 | else:
50 | args.conv_mode = conv_mode
51 |
52 | conv = conv_templates[args.conv_mode].copy()
53 | if "mpt" in model_name.lower():
54 | roles = ('user', 'assistant')
55 | else:
56 | roles = conv.roles
57 |
58 | image = load_image(args.image_file)
59 | image_size = image.size
60 | # Similar operation in model_worker.py
61 | image_tensor = process_images([image], image_processor, model.config)
62 | if type(image_tensor) is list:
63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
64 | else:
65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
66 |
67 | while True:
68 | try:
69 | inp = input(f"{roles[0]}: ")
70 | except EOFError:
71 | inp = ""
72 | if not inp:
73 | print("exit...")
74 | break
75 |
76 | print(f"{roles[1]}: ", end="")
77 |
78 | if image is not None:
79 | # first message
80 | if model.config.mm_use_im_start_end:
81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
82 | else:
83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
84 | conv.append_message(conv.roles[0], inp)
85 | image = None
86 | else:
87 | # later messages
88 | conv.append_message(conv.roles[0], inp)
89 | conv.append_message(conv.roles[1], None)
90 | prompt = conv.get_prompt()
91 |
92 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
93 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
94 | keywords = [stop_str]
95 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
96 |
97 | with torch.inference_mode():
98 | output_ids = model.generate(
99 | input_ids,
100 | images=image_tensor,
101 | image_sizes=[image_size],
102 | do_sample=True if args.temperature > 0 else False,
103 | temperature=args.temperature,
104 | max_new_tokens=args.max_new_tokens,
105 | streamer=streamer,
106 | use_cache=True)
107 |
108 | outputs = tokenizer.decode(output_ids[0]).strip()
109 | conv.messages[-1][-1] = outputs
110 |
111 | if args.debug:
112 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
118 | parser.add_argument("--model-base", type=str, default=None)
119 | parser.add_argument("--image-file", type=str, required=True)
120 | parser.add_argument("--device", type=str, default="cuda")
121 | parser.add_argument("--conv-mode", type=str, default=None)
122 | parser.add_argument("--temperature", type=float, default=0.2)
123 | parser.add_argument("--max-new-tokens", type=int, default=512)
124 | parser.add_argument("--load-8bit", action="store_true")
125 | parser.add_argument("--load-4bit", action="store_true")
126 | parser.add_argument("--debug", action="store_true")
127 | args = parser.parse_args()
128 |
129 | if args.model_base is None and "::" in args.model_path:
130 | model_base_and_path = args.model_path
131 | args.model_base, args.model_path = model_base_and_path.split("::")
132 | print(f"model_base_and_path ({model_base_and_path}) has been split into model_path ({args.model_path}) and model_base ({args.model_base})")
133 |
134 | main(args)
135 |
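An example invocation of this CLI (a sketch: the checkpoint name and image path are placeholders, and sampling flags follow the defaults defined above):

```bash
# Illustrative only: substitute your own checkpoint and image.
python -m llava.serve.cli \
    --model-path liuhaotian/llava-v1.5-7b \
    --image-file ./llava/serve/examples/waterview.jpg \
    --temperature 0.2 \
    --max-new-tokens 512
```

Note that `--model-path` also accepts a combined `model_base::model_path` string, which the `__main__` block above splits into `--model-base` and `--model-path`.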
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_llava/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/mllm_llava/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/mllm_llava/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
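Analogous to `train_xformers.py` later in this repository, the patch is applied before the model is created so that the patched `forward` is in place; a minimal sketch (the checkpoint path is a placeholder):

```python
# Apply the monkey patch first, then build the LLaMA model as usual.
from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/path/to/llama-checkpoint")  # placeholder path
```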
--------------------------------------------------------------------------------
/mllm_llava/llava/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3 | """
4 |
5 | import logging
6 | import math
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import transformers.models.llama.modeling_llama
11 | from torch import nn
12 |
13 | try:
14 | import xformers.ops
15 | except ImportError:
16 | logging.error("xformers not found! Please install it before trying to use it.")
17 |
18 |
19 | def replace_llama_attn_with_xformers_attn():
20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21 |
22 |
23 | def xformers_forward(
24 | self,
25 | hidden_states: torch.Tensor,
26 | attention_mask: Optional[torch.Tensor] = None,
27 | position_ids: Optional[torch.LongTensor] = None,
28 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
29 | output_attentions: bool = False,
30 | use_cache: bool = False,
31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32 | # pylint: disable=duplicate-code
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 |
51 | kv_seq_len = key_states.shape[-2]
52 | if past_key_value is not None:
53 | kv_seq_len += past_key_value[0].shape[-2]
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | (
56 | query_states,
57 | key_states,
58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 |
63 | if past_key_value is not None:
64 | # reuse k, v, self_attention
65 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
66 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
67 |
68 | past_key_value = (key_states, value_states) if use_cache else None
69 |
70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
71 | if not output_attentions:
72 | query_states = query_states.transpose(1, 2)
73 | key_states = key_states.transpose(1, 2)
74 | value_states = value_states.transpose(1, 2)
75 |
76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
80 | attn_output = xformers.ops.memory_efficient_attention(
81 | query_states, key_states, value_states, attn_bias=None
82 | )
83 | else:
84 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
85 | attn_output = xformers.ops.memory_efficient_attention(
86 | query_states,
87 | key_states,
88 | value_states,
89 | attn_bias=xformers.ops.LowerTriangularMask(),
90 | )
91 | attn_weights = None
92 | else:
93 | attn_weights = torch.matmul(
94 | query_states, key_states.transpose(2, 3)
95 | ) / math.sqrt(self.head_dim)
96 |
97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98 | raise ValueError(
99 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
100 | f" {attn_weights.size()}"
101 | )
102 |
103 | if attention_mask is not None:
104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 | raise ValueError(
106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 | )
108 | attn_weights = attn_weights + attention_mask
109 | attn_weights = torch.max(
110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 | )
112 |
113 | # upcast attention to fp32
114 | attn_weights = nn.functional.softmax(
115 | attn_weights, dim=-1, dtype=torch.float32
116 | ).to(query_states.dtype)
117 | attn_output = torch.matmul(attn_weights, value_states)
118 |
119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 | raise ValueError(
121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 | f" {attn_output.size()}"
123 | )
124 |
125 | attn_output = attn_output.transpose(1, 2)
126 |
127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 | attn_output = self.o_proj(attn_output)
129 | return attn_output, attn_weights, past_key_value
130 |
--------------------------------------------------------------------------------
/mllm_llava/llava/train/train_interleaved_mem.py:
--------------------------------------------------------------------------------
1 | from llava.train.train_interleaved import train
2 |
3 | if __name__ == "__main__":
4 | # # OSError: IOError: broken data stream when reading image file
5 | # from PIL import Image
6 | # from PIL import ImageFile
7 | # ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 | train(attn_implementation="flash_attention_2")
10 |
--------------------------------------------------------------------------------
/mllm_llava/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from llava.train.train import train
2 |
3 | if __name__ == "__main__":
4 | train(attn_implementation="flash_attention_2")
5 |
--------------------------------------------------------------------------------
/mllm_llava/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from llava.train.llama_xformers_attn_monkey_patch import (
5 | replace_llama_attn_with_xformers_attn,
6 | )
7 |
8 | replace_llama_attn_with_xformers_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/mllm_llava/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 |     text = text.replace("\n", "")
110 |     # use `json=` so quotes and backslashes in the text are escaped correctly
111 |     payload = {"input": text}
112 |     try:
113 |         ret = requests.post(url, headers=headers, json=payload, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
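A small sketch of how these helpers are typically used together; the logger name and filename below are hypothetical.

```python
from llava.utils import build_logger, disable_torch_init

# Skip torch's default weight init to speed up model construction, then route
# stdout/stderr through a rotating file handler under LOGDIR.
disable_torch_init()
logger = build_logger("model_worker", "model_worker.log")  # hypothetical names
logger.info("worker started")
```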
--------------------------------------------------------------------------------
/mllm_llava/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.2.2.post1"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.1.2+cu118", "torchvision==0.16.2+cu118",
17 | "transformers==4.36.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==0.30.1", "peft", "bitsandbytes",
19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20 | "gradio==4.16.0", "gradio_client==0.8.1",
21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23 | "deepspeed==0.12.6"
24 | ]
25 |
26 | [project.optional-dependencies]
27 | train = ["deepspeed==0.12.6", "ninja", "wandb"]
28 | build = ["build", "twine"]
29 |
30 | [project.urls]
31 | "Homepage" = "https://llava-vl.github.io"
32 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
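With this `pyproject.toml`, an editable install is the usual entry point. A sketch (the CUDA-specific `+cu118` pins may require pointing pip at the matching PyTorch wheel index for your environment):

```bash
cd mllm_llava
# e.g. --extra-index-url https://download.pytorch.org/whl/cu118 if needed for the +cu118 pins
pip install -e .            # core dependencies
pip install -e ".[train]"   # additionally pulls in ninja and wandb (and pins deepspeed) for training
```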
--------------------------------------------------------------------------------
/mllm_openflamingo/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | *.json
3 |
4 | wandb/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Pycharm project settings
119 | .idea
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | *.out
137 | src/wandb
138 | wandb
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # Cache
144 | cache/
145 |
146 | __*.sh
147 |
148 |
149 | # added by lqy
150 | tmp/*
151 | scripts_sh/*
152 | *.err
153 | .deepspeed_env
154 | */*/results_*
155 |
--------------------------------------------------------------------------------
/mllm_openflamingo/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all
12 | copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 | SOFTWARE.
21 |
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | downcast_bf16: false
4 | machine_rank: 0
5 | main_training_function: main
6 | mixed_precision: bf16
7 | num_machines: 1
8 | num_processes: 2
9 | rdzv_backend: static
10 | same_network: false
11 | tpu_use_cluster: false
12 | tpu_use_sudo: false
13 | use_cpu: false
14 | main_process_port: 20685
15 |
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_fsdp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: no
3 | downcast_bf16: true
4 | machine_rank: 0
5 | main_training_function: main
6 | mixed_precision: bf16
7 | num_machines: 1
8 | num_processes: 1
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_use_cluster: false
12 | tpu_use_sudo: false
13 | use_cpu: false
14 | main_process_port: 20687
15 |
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_standalone.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: 'NO'
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 1
11 | same_network: true
12 |
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 1
9 | distributed_type: DEEPSPEED
10 | fsdp_config: {}
11 | main_training_function: main
12 | mixed_precision: bf16
13 | use_cpu: false
14 | num_machines: 1
15 | num_processes: 8
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero1_slurm.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_multinode_launcher: standard
4 | gradient_accumulation_steps: 1
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: false
9 | zero_stage: 1
10 | distributed_type: DEEPSPEED
11 | fsdp_config: {}
12 | main_training_function: main
13 | mixed_precision: bf16
14 | use_cpu: false
15 | num_machines: 1
16 | num_processes: 8
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_multinode_launcher: standard
4 | gradient_accumulation_steps: 2
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | fsdp_config: {}
12 | main_training_function: main
13 | mixed_precision: bf16
14 | use_cpu: false
15 | num_machines: 1
16 | num_processes: 8
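These YAML files are consumed by `accelerate launch` through `--config_file`. A hypothetical launch command (the training script name and its flags are placeholders, not taken from this repository):

```bash
accelerate launch \
    --config_file mllm_openflamingo/accelerate_configs/accelerate_config_zero2.yaml \
    path/to/train_script.py --your-training-flags
```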
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero2_slurm.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_multinode_launcher: standard
4 | gradient_accumulation_steps: 1
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | fsdp_config: {}
12 | main_training_function: main
13 | mixed_precision: bf16
14 | use_cpu: false
15 | num_machines: 2
16 | num_processes: 16
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | fsdp_config: {}
12 | main_training_function: main
13 | mixed_precision: bf16
14 | use_cpu: false
15 | num_machines: 1
16 | num_processes: 8
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero3_offload.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: cpu
6 | offload_param_device: cpu
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | fsdp_config: {}
12 | main_training_function: main
13 | mixed_precision: bf16
14 | use_cpu: false
15 | num_machines: 1
16 | num_processes: 8
--------------------------------------------------------------------------------
/mllm_openflamingo/accelerate_configs/accelerate_config_zero3_slurm.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_multinode_launcher: standard
4 | gradient_accumulation_steps: 2
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | fsdp_config: {}
13 | main_training_function: main
14 | mixed_precision: bf16
15 | use_cpu: false
16 | num_machines: 1
17 | num_processes: 8
--------------------------------------------------------------------------------
/mllm_openflamingo/environment.yml:
--------------------------------------------------------------------------------
1 | name: lmm_baseline_openflamingo
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.9
6 | - conda-forge::openjdk
7 | - pip
8 | - pip:
9 | - -r requirements.txt
10 | - -r requirements-training.txt
11 | - -r requirements-eval.txt
12 | - -e .
13 |
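The environment can be created with conda in the usual way; the `pip:` section then installs the requirements files and the package itself in editable mode.

```bash
conda env create -f environment.yml
conda activate lmm_baseline_openflamingo
```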
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/__init__.py:
--------------------------------------------------------------------------------
1 | from .src.flamingo import Flamingo
2 | from .src.factory import create_model_and_transforms
3 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/README.md:
--------------------------------------------------------------------------------
1 | # OpenFlamingo Evaluation Suite
2 |
3 | This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets.
4 |
5 | *This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).*
6 |
7 | ## Supported datasets
8 |
9 | |Dataset|Task|Metric|Evaluation method|
10 | |-------|----|------|-----------------|
11 | |[COCO](https://arxiv.org/abs/1405.0312)|Captioning|CIDEr|Generation|
12 | |[Flickr-30K](https://aclanthology.org/Q14-1006/)|Captioning|CIDEr|Generation|
13 | |[VQAv2](https://arxiv.org/abs/1612.00837v3)|VQA|VQA accuracy|Generation|
14 | |[OK-VQA](https://arxiv.org/abs/1906.00067)|VQA|VQA accuracy|Generation|
15 | |[TextVQA](https://arxiv.org/abs/1904.08920)|VQA|VQA accuracy|Generation|
16 | |[VizWiz](https://arxiv.org/abs/1802.08218)|VQA|VQA accuracy|Generation|
17 | |[Hateful Memes](https://arxiv.org/abs/2005.04790)|Classification|ROC AUC|Logprobs|
18 | |[ImageNet](https://arxiv.org/abs/1409.0575)|Classification|Top-1 accuracy|Logprobs|
19 |
20 | When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`).
21 |
22 | ## Sample scripts
23 | Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`.
24 |
25 | We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16.
26 |
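As a sketch of such a launch (not the provided Slurm script): assuming the evaluation driver is `open_flamingo/eval/evaluate.py` and reusing the COCO flags from the feature-caching example below, a run might look like the following; flag names other than those shown elsewhere in this README are assumptions.

```bash
torchrun --nproc_per_node=8 open_flamingo/eval/evaluate.py \
    --vision_encoder_path ViT-L-14 \
    --vision_encoder_pretrained openai \
    --checkpoint_path openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt \
    --precision amp_bf16 \
    --batch_size 8 \
    --num_samples 5000 \
    --eval_coco \
    --coco_train_image_dir_path /path/to/coco/train2014 \
    --coco_val_image_dir_path /path/to/coco/val2014 \
    --coco_karpathy_json_path /path/to/coco/dataset_coco.json \
    --coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json
```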
27 | To evaluate one of our pretrained checkpoints, we suggest first downloading a local copy of the weights, as follows:
28 |
29 | ```
30 | # grab model checkpoint from huggingface hub
31 | from huggingface_hub import hf_hub_download
32 | HF_TOKEN=""  # set your Hugging Face Hub token here if the checkpoint requires authentication
33 |
34 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
35 | checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b",
36 | "checkpoint.pt",
37 | local_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
38 | cache_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
39 | local_dir_use_symlinks=False,
40 | token=HF_TOKEN)
41 | print(checkpoint_path)
42 | ## openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt
43 | ```
44 |
45 | This should place the OpenFlamingo model at the expected location in the evaluation script.
46 |
47 | For TextVQA and VizWiz we expect annotations to be formatted differently than the original datasets. We provide the custom annotations in `open_flamingo/open_flamingo/eval/data/`. We have also uploaded all the annotation files in a [huggingface dataset](https://huggingface.co/datasets/openflamingo/eval_benchmark/tree/main) for easy access.
48 |
49 | ## Evaluating using RICES (Retrieval-based In-Context Example Selection)
50 |
51 | We provide the option to evaluate using RICES, which is a method for selecting exemplars from the training set based on image similarity. This method was used in DeepMind's implementation for evaluating on ImageNet, but can be used for any dataset in our evaluation suite.
52 |
53 | To use RICES, you must first create features for a benchmark's training set. We provide a script for doing so in `open_flamingo/open_flamingo/scripts/cache_rices_features.py`. This script will extract image features for a given dataset using a given CLIP model checkpoint. For example, to extract features for the COCO training set, you can run:
54 |
55 | ```bash
56 | python cache_rices_features.py \
57 | --vision_encoder_path ViT-L-14 \
58 | --vision_encoder_pretrained openai \
59 | --batch_size 128 \
60 | --eval_coco \
61 | --coco_train_image_dir_path /path/to/coco/train2014 \
62 | --coco_val_image_dir_path /path/to/coco/val2014 \
63 | --coco_karpathy_json_path /path/to/coco/dataset_coco.json \
64 | --coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json \
65 | --output_dir /path/to/coco/features
66 | ```
67 |
68 | This will create a directory at `/path/to/coco/features` containing a file named `coco.pkl` with the extracted features. You can then use this directory to evaluate using RICES by passing the `--rices` flag to the evaluation script, specifying the path to the features directory using the `--cached_demonstration_features` flag, and specifying the vision encoder to use for RICES using the `--rices_vision_encoder_path` and `--rices_vision_encoder_pretrained` flags.
69 |
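Putting the pieces together, a hypothetical RICES run adds the flags described above to the same evaluation command (the driver script name is again an assumption; checkpoint and dataset flags are as in the earlier sketch):

```bash
# Add the RICES flags to the evaluation command; the checkpoint and COCO path
# flags from the earlier sketch are elided here for brevity.
torchrun --nproc_per_node=8 open_flamingo/eval/evaluate.py \
    --eval_coco \
    --rices \
    --cached_demonstration_features /path/to/coco/features \
    --rices_vision_encoder_path ViT-L-14 \
    --rices_vision_encoder_pretrained openai
```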
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/coco_metric.py:
--------------------------------------------------------------------------------
1 | from pycocoevalcap.eval import COCOEvalCap
2 | from pycocotools.coco import COCO
3 |
4 |
5 | def compute_cider(
6 | result_path,
7 | annotations_path,
8 | ):
9 | # create coco object and coco_result object
10 | coco = COCO(annotations_path)
11 | coco_result = coco.loadRes(result_path)
12 |
13 | # create coco_eval object by taking coco and coco_result
14 | coco_eval = COCOEvalCap(coco, coco_result)
15 | coco_eval.params["image_id"] = coco_result.getImgIds()
16 | coco_eval.evaluate()
17 |
18 | return coco_eval.eval
19 |
20 |
21 | def postprocess_captioning_generation(predictions):
22 | return predictions.split("Output", 1)[0]
23 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/eval_model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import argparse
3 | from typing import List
4 | from torch.nn.parallel import DistributedDataParallel as DDP
5 | from PIL import Image
6 |
7 |
8 | class BaseEvalModel(abc.ABC):
9 | """Base class encapsulating functionality needed to evaluate a model."""
10 |
11 | def __init__(self, args: List[str]):
12 | """Initialize model.
13 |
14 | Args:
15 | args: arguments to model. These should be parsed, or if the model
16 | has no applicable arguments, an error should be thrown if `args`
17 | is non-empty.
18 | """
19 |
20 | def init_distributed(self):
21 | """Wrap model as DDP."""
22 | self.model = DDP(self.model, device_ids=[self.device])
23 |
24 | def set_device(self, device):
25 | """Set device for model."""
26 | self.device = device
27 | self.model = self.model.to(device)
28 |
29 | def get_outputs(
30 | self,
31 | batch_text: List[str],
32 | batch_images: List[List[Image.Image]],
33 | min_generation_length: int,
34 | max_generation_length: int,
35 | num_beams: int,
36 | length_penalty: float,
37 | ) -> List[str]:
38 | """Get outputs for a batch of images and text.
39 |
40 | Args:
41 |             batch_text: list of text strings, with the text "<image>" in place
42 | of any images to be included.
43 | batch_images: images to provide to model. Should be a list of lists,
44 | where each list contains the images for a single example.
45 |             min_generation_length: minimum length of the generated caption.
46 |             max_generation_length: maximum length of the generated caption.
47 |             num_beams: number of beams to use for beam search.
48 |             length_penalty: length penalty for beam search.
49 |
50 | Returns:
51 | List of decoded output strings.
52 | """
53 |
54 | def vqa_prompt(self, question, answer=None) -> str:
55 | """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model.
56 |
57 | Returns:
58 | The prompt to use for VQA.
59 | """
60 |
61 | def caption_prompt(self, caption=None) -> str:
62 | """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model.
63 |
64 | Returns:
65 | The prompt to use for captioning.
66 | """
67 |
68 | def get_rank_classifications(
69 | self,
70 | batch_text: List[str],
71 | batch_images: List[List[Image.Image]],
72 | all_class_names: List[str],
73 | use_cache: bool,
74 | normalize_length: bool,
75 | ):
76 | """
77 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name.
78 | Args:
79 |             batch_text: list of text strings, with the text "<image>" in place
80 | of any images to be included.
81 | batch_images: images to provide to model. Should be a list of lists,
82 | where each list contains the images for a single example.
83 | all_class_names: list of all class names.
84 | use_cache: whether to cache the context to speed up evaluations.
85 | normalize_length: whether to normalize logprobs by the length of the
86 | class name
87 | Returns:
88 | (B, |all_class_names|) tensor containing the logprobs for each class name.
89 | """
90 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/models/blip.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from PIL import Image
4 | import torch
5 |
6 | from transformers import Blip2Processor, Blip2ForConditionalGeneration
7 | from open_flamingo.eval.eval_model import BaseEvalModel
8 | from open_flamingo.eval.utils import unwrap_model
9 |
10 |
11 | class EvalModel(BaseEvalModel):
12 | """BLIP-2 model evaluation.
13 |
14 | Attributes:
15 | model (nn.Module): Underlying Torch model.
16 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model.
17 | device: Index of GPU to use, or the string "cpu"
18 | """
19 |
20 | def __init__(self, model_args):
21 | assert (
22 | "processor_path" in model_args and "lm_path" in model_args
23 |         ), "BLIP-2 requires processor_path and lm_path arguments to be specified"
24 |
25 | self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
26 | self.model = Blip2ForConditionalGeneration.from_pretrained(
27 | model_args["lm_path"]
28 | )
29 | self.model.eval()
30 | self.processor.tokenizer.padding_side = "left"
31 | self.lm_name = model_args["lm_path"].split("/")[-1]
32 |
33 | def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor:
34 | """Preprocess images and stack them.
35 |
36 | Args:
37 | batch: A list of lists of images.
38 |
39 | Returns:
40 | A Tensor of shape
41 | (batch_size, channels, height, width).
42 | """
43 | batch_images = None
44 | assert all(
45 | len(example) == 1 for example in batch
46 | ), "BLIP-2 only supports one image per example"
47 |
48 | for example in batch:
49 | assert len(example) == 1, "BLIP-2 only supports one image per example"
50 | batch_images = torch.cat(
51 | [
52 | batch_images,
53 | self.processor.image_processor(example, return_tensors="pt")[
54 | "pixel_values"
55 | ],
56 | ]
57 | if batch_images is not None
58 | else [
59 | self.processor.image_processor(example, return_tensors="pt")[
60 | "pixel_values"
61 | ]
62 | ],
63 | dim=0,
64 | )
65 | return batch_images
66 |
67 | def get_outputs(
68 | self,
69 | batch_text: List[str],
70 | batch_images: List[List[Image.Image]],
71 | min_generation_length: int,
72 | max_generation_length: int,
73 | num_beams: int,
74 | length_penalty: float,
75 | ) -> List[str]:
76 | encodings = self.processor.tokenizer(
77 | batch_text,
78 | padding="longest",
79 | truncation=True,
80 | return_tensors="pt",
81 | max_length=2000,
82 | )
83 | input_ids = encodings["input_ids"]
84 | attention_mask = encodings["attention_mask"]
85 |
86 | with torch.inference_mode():
87 | outputs = unwrap_model(self.model).generate(
88 | self._prepare_images(batch_images).to(self.device),
89 | input_ids.to(self.device),
90 | attention_mask=attention_mask.to(self.device),
91 | max_new_tokens=max_generation_length,
92 | min_new_tokens=min_generation_length,
93 | num_beams=num_beams,
94 | length_penalty=length_penalty,
95 | )
96 |
97 | return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
98 |
99 | def get_vqa_prompt(self, question, answer=None) -> str:
100 | return (
101 | f"Question:{question} Short answer:{answer if answer is not None else ''}"
102 | )
103 |
104 | def get_caption_prompt(self, caption=None) -> str:
105 | return f"A photo of {caption if caption is not None else ''}"
106 |
107 | def get_rank_classifications(
108 | self,
109 | batch_text: List[str],
110 | batch_images: List[List[Image.Image]],
111 | all_class_names: List[str],
112 | use_cache: bool,
113 | normalize_length: bool,
114 | ):
115 | raise NotImplementedError(
116 | "BLIP-2 classification-based evaluation not implemented"
117 | )
118 |
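A minimal usage sketch for this wrapper (not from the repo): it assumes a Hugging Face BLIP-2 checkpoint name and a local image path, both placeholders, and assigns `self.device` explicitly since device placement is expected to come from the surrounding eval harness:

```python
# Hypothetical checkpoint and image path; substitute your own.
from PIL import Image

from open_flamingo.eval.models.blip import EvalModel

model = EvalModel({
    "processor_path": "Salesforce/blip2-opt-2.7b",
    "lm_path": "Salesforce/blip2-opt-2.7b",
})
model.device = "cpu"  # assumption: device placement is normally handled by the eval harness

image = Image.open("example.jpg").convert("RGB")  # placeholder image
prompt = model.get_caption_prompt()               # "A photo of "

captions = model.get_outputs(
    batch_text=[prompt],
    batch_images=[[image]],
    min_generation_length=0,
    max_generation_length=20,
    num_beams=3,
    length_penalty=0.0,
)
print(captions[0])
```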
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/rices.py:
--------------------------------------------------------------------------------
1 | import open_clip
2 | import torch
3 | from tqdm import tqdm
4 | import torch
5 | from utils import custom_collate_fn
6 |
7 |
8 | class RICES:
9 | def __init__(
10 | self,
11 | dataset,
12 | device,
13 | batch_size,
14 | vision_encoder_path="ViT-B-32",
15 | vision_encoder_pretrained="openai",
16 | cached_features=None,
17 | ):
18 | self.dataset = dataset
19 | self.device = device
20 | self.batch_size = batch_size
21 |
22 | # Load the model and processor
23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
24 | vision_encoder_path,
25 | pretrained=vision_encoder_pretrained,
26 | )
27 | self.model = vision_encoder.to(self.device)
28 | self.image_processor = image_processor
29 |
30 | # Precompute features
31 | if cached_features is None:
32 | self.features = self._precompute_features()
33 | else:
34 | self.features = cached_features
35 |
36 | def _precompute_features(self):
37 | features = []
38 |
39 | # Switch to evaluation mode
40 | self.model.eval()
41 |
42 | # Set up loader
43 | loader = torch.utils.data.DataLoader(
44 | self.dataset,
45 | batch_size=self.batch_size,
46 | collate_fn=custom_collate_fn,
47 | )
48 |
49 | with torch.no_grad():
50 | for batch in tqdm(
51 | loader,
52 | desc="Precomputing features for RICES",
53 | ):
54 | batch = batch["image"]
55 | inputs = torch.stack(
56 | [self.image_processor(image) for image in batch]
57 | ).to(self.device)
58 | image_features = self.model.encode_image(inputs)
59 | image_features /= image_features.norm(dim=-1, keepdim=True)
60 | features.append(image_features.detach())
61 |
62 | features = torch.cat(features)
63 | return features
64 |
65 | def find(self, batch, num_examples):
66 | """
67 | Get the top num_examples most similar examples to the images.
68 | """
69 | # Switch to evaluation mode
70 | self.model.eval()
71 |
72 | with torch.no_grad():
73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to(
74 | self.device
75 | )
76 |
77 | # Get the feature of the input image
78 | query_feature = self.model.encode_image(inputs)
79 | query_feature /= query_feature.norm(dim=-1, keepdim=True)
80 | query_feature = query_feature.detach().cpu()
81 |
82 | if query_feature.ndim == 1:
83 | query_feature = query_feature.unsqueeze(0)
84 |
85 | # Compute the similarity of the input image to the precomputed features
86 | similarity = (query_feature @ self.features.T).squeeze()
87 |
88 | if similarity.ndim == 1:
89 | similarity = similarity.unsqueeze(0)
90 |
91 | # Get the indices of the 'num_examples' most similar images
92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]
93 |
94 | # Return with the most similar images last
95 | return [[self.dataset[i] for i in reversed(row)] for row in indices]
96 |
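A sketch of retrieving in-context demonstrations with RICES. It assumes the demonstration pool is a map-style dataset whose items are dicts holding a PIL image under the `"image"` key (as `_precompute_features` and `custom_collate_fn` expect); the solid-color images are placeholders, and the import mirrors the flat `from utils import ...` style above, so run it from `open_flamingo/eval/`:

```python
# Illustrative sketch; assumes open_flamingo/eval is the working directory.
from PIL import Image

from rices import RICES

# Tiny placeholder pool: in practice this is the training split used for demos.
pool = [
    {"image": Image.new("RGB", (224, 224), color=c)}
    for c in ("red", "green", "blue", "white", "black")
]

rices = RICES(dataset=pool, device="cpu", batch_size=2)

query = [Image.new("RGB", (224, 224), color="pink")]  # placeholder query image
demos = rices.find(query, num_examples=4)
# demos[0] holds 4 pool items, ordered with the most similar last.
```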
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/eval/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import torch.nn as nn
5 | from contextlib import suppress
6 |
7 |
8 | def random_seed(seed=42, rank=0):
9 | torch.manual_seed(seed + rank)
10 | np.random.seed(seed + rank)
11 | random.seed(seed + rank)
12 |
13 |
14 | def custom_collate_fn(batch):
15 | """
16 | Collate function for DataLoader that collates a list of dicts into a dict of lists.
17 | """
18 | collated_batch = {}
19 | for key in batch[0].keys():
20 | collated_batch[key] = [item[key] for item in batch]
21 | return collated_batch
22 |
23 |
24 | def compute_effective_num_shots(num_shots, model_type):
25 | """
26 | Compute the effective number of shots for a given model type.
27 | For example, following Flamingo, 0-shot OF evaluations use two text-only shots.
28 | """
29 | if model_type == "open_flamingo":
30 | return num_shots if num_shots > 0 else 2
31 | return num_shots
32 |
33 |
34 | def sample_batch_demos_from_query_set(query_set, num_samples, batch_size):
35 | """
36 | Sample random demonstrations from the query set.
37 | """
38 | return [random.sample(query_set, num_samples) for _ in range(batch_size)]
39 |
40 |
41 | def get_query_set(train_dataset, query_set_size):
42 | """
43 | Get a subset of the training dataset to use as the query set.
44 | """
45 | query_set = np.random.choice(len(train_dataset), query_set_size, replace=False)
46 | return [train_dataset[i] for i in query_set]
47 |
48 |
49 | def prepare_eval_samples(test_dataset, num_samples, batch_size):
50 | """
51 | Subset the test dataset and return a DataLoader.
52 | """
53 | random_indices = np.random.choice(len(test_dataset), num_samples, replace=False)
54 | dataset = torch.utils.data.Subset(test_dataset, random_indices)
55 | sampler = torch.utils.data.distributed.DistributedSampler(dataset)
56 | loader = torch.utils.data.DataLoader(
57 | dataset,
58 | batch_size=batch_size,
59 | sampler=sampler,
60 | collate_fn=custom_collate_fn,
61 | )
62 | return loader
63 |
64 |
65 | def get_indices_of_unique(x):
66 | """
67 | Return the indices of x that correspond to unique elements.
68 | If value v is unique and two indices in x have value v, the first index is returned.
69 | """
70 | unique_elements = torch.unique(x)
71 | first_indices = []
72 | for v in unique_elements:
73 | indices = torch.where(x == v)[0]
74 | first_indices.append(indices[0]) # Take the first index for each unique element
75 | return torch.tensor(first_indices)
76 |
77 |
78 | def unwrap_model(model):
79 | """
80 | Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
81 | """
82 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
83 | return model.module
84 | else:
85 | return model
86 |
87 |
88 | def get_predicted_classnames(logprobs, k, class_id_to_name):
89 | """
90 | Args:
91 | - logprobs shape (B, Y) containing logprobs for each classname
92 | - k: number for top-k
93 | - class_id_to_name: dict mapping class index to classname
94 |
95 | Returns:
96 | - top-k predicted classnames shape (B, k) type str
97 | - top-k logprobs shape (B, k) type float
98 | """
99 | # convert indices to classnames
100 | _, predictions = torch.topk(logprobs, k=k, dim=1) # shape (B, k)
101 | predicted_classnames = [
102 | [class_id_to_name[ix] for ix in item] for item in predictions.tolist()
103 | ]
104 | predicted_logprobs = torch.gather(logprobs, 1, predictions)
105 | return predicted_classnames, predicted_logprobs
106 |
107 |
108 | def get_cast_dtype(precision: str):
109 | cast_dtype = None
110 | if precision == "bf16":
111 | cast_dtype = torch.bfloat16
112 | elif precision == "fp16":
113 | cast_dtype = torch.float16
114 | return cast_dtype
115 |
116 |
117 | def get_autocast(precision):
118 | if precision == "amp":
119 | return torch.cuda.amp.autocast
120 | elif precision == "amp_bfloat16" or precision == "amp_bf16":
121 | # amp_bfloat16 is more stable than amp float16 for clip training
122 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
123 | else:
124 | return suppress
125 |
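For reference, a tiny example of the dict-of-lists convention produced by `custom_collate_fn`, plus the zero-shot rule in `compute_effective_num_shots` (values are made up):

```python
from open_flamingo.eval.utils import compute_effective_num_shots, custom_collate_fn

batch = [
    {"image": "img0", "question": "What color is the sky?"},
    {"image": "img1", "question": "How many dogs are there?"},
]
print(custom_collate_fn(batch))
# {'image': ['img0', 'img1'],
#  'question': ['What color is the sky?', 'How many dogs are there?']}

# Following Flamingo, 0-shot OpenFlamingo evaluations still use two text-only demos.
print(compute_effective_num_shots(0, "open_flamingo"))  # 2
print(compute_effective_num_shots(4, "open_flamingo"))  # 4
```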
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/scripts/convert_mmc4_to_wds.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import uuid
5 | import zipfile
6 | from PIL import Image
7 | import base64
8 | from io import BytesIO
9 |
10 | import braceexpand
11 | import webdataset as wds
12 |
13 | arg_parser = argparse.ArgumentParser()
14 | arg_parser.add_argument(
15 | "--output_dir",
16 | type=str,
17 | help="Pass in the directory where the output shards (as tar files) will be written to.",
18 | )
19 | arg_parser.add_argument(
20 | "--zip_files",
21 | type=str,
22 | help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip",
23 | )
24 | arg_parser.add_argument(
25 | "--image_dir",
26 | type=str,
27 | help="Pass in the directory where the images have been downloaded to.",
28 | )
29 | arg_parser.add_argument(
30 | "--num_files_per_shard",
31 | type=int,
32 | default=1000,
33 | )
34 | args = arg_parser.parse_args()
35 |
36 |
37 | def main():
38 | os.makedirs(args.output_dir, exist_ok=True)
39 |
40 | doc_shards = list(braceexpand.braceexpand(args.zip_files))
41 |
42 | with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink:
43 | for idx in range(len(doc_shards)):
44 | # Open the ZIP archive and extract the JSON file
45 | with zipfile.ZipFile(doc_shards[idx], "r") as zip_file:
46 | # Assumes the JSON file is the first file in the archive
47 | json_filename = zip_file.namelist()[0]
48 | with zip_file.open(json_filename, "r") as json_file:
49 | for sample_data in json_file:
50 | # get image names from json
51 | sample_data = json.loads(sample_data)
52 | image_info = sample_data["image_info"]
53 | image_names = [image["image_name"] for image in image_info]
54 |
55 | # Add each image to the tar file
56 | for img_idx, image_name in enumerate(image_names):
57 | try:
58 | # load image
59 | img = Image.open(
60 | os.path.join(args.image_dir, str(idx), image_name)
61 | ).convert("RGB")
62 | buffered = BytesIO()
63 | img.save(buffered, format="JPEG")
64 | img_str = base64.b64encode(buffered.getvalue())
65 |
66 | # convert to base64
67 | sample_data["image_info"][img_idx][
68 | "image_base64"
69 | ] = img_str.decode("utf-8")
70 | except FileNotFoundError:
71 | print(
72 | f"Did not find {image_name} downloaded. This can happen if the url is now 404."
73 | )
74 | except Exception as e:
75 | print(f"Error processing {image_name}: {e}")
76 |
77 | key_str = uuid.uuid4().hex
78 | sink.write({"__key__": key_str, "json": sample_data})
79 |
80 | if (idx + 1) % args.num_files_per_shard == 0:
81 | sink.next_stream()
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
86 |
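A small spot-check (not part of the repo) for a shard written by this script: it reads one sample back with WebDataset and decodes the first embedded image, if any. The shard path is a placeholder following the `%09d.tar` naming pattern above:

```python
# Read back one sample from a converted shard and decode its first image.
import base64
import json
from io import BytesIO

import webdataset as wds
from PIL import Image

dataset = wds.WebDataset("output_dir/000000000.tar")  # placeholder path
for sample in dataset:
    doc = json.loads(sample["json"])
    img_b64 = doc["image_info"][0].get("image_base64") if doc.get("image_info") else None
    if img_b64:
        img = Image.open(BytesIO(base64.b64decode(img_b64)))
        print(sample["__key__"], img.size)
    break
```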
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/scripts/fill_vqa_testdev_results.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission.
3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set.
4 | Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction.
5 | """
6 | import json
7 | import sys
8 | import os
9 |
10 | sys.path.append(
11 | os.path.join(
12 | os.path.dirname(os.path.abspath(__file__)),
13 | "..",
14 | )
15 | )
16 | from eval.vqa_metric import VQAEval
17 |
18 | postprocessor = VQAEval(None, None)
19 |
20 |
21 | def fill_vizwiz_test_json(
22 | input_path,
23 | output_path,
24 | vqa_test_questions_json_path,
25 | ):
26 |     # read the input json with the model predictions
27 | with open(input_path, "r") as f:
28 | input_json = json.load(f)
29 |
30 | # postprocess answers
31 | question_id_to_answer = {}
32 | for q in input_json:
33 | resAns = q["answer"]
34 | resAns = resAns.replace("\n", " ")
35 | resAns = resAns.replace("\t", " ")
36 | resAns = resAns.strip()
37 | resAns = postprocessor.processPunctuation(resAns)
38 | resAns = postprocessor.processDigitArticle(resAns)
39 | question_id_to_answer[q["question_id"]] = resAns
40 |
41 |     # read the vqa test json to get all the question_ids that need to be filled
42 | with open(vqa_test_questions_json_path, "r") as f:
43 | vqa_test_json = json.load(f)
44 | vqa_test_json = vqa_test_json["questions"]
45 |
46 |     # build an entry for every test question, falling back to an empty answer when no prediction exists
47 | output_json = []
48 | for q in vqa_test_json:
49 | output_json.append(
50 | {
51 | "image": q["image_id"],
52 | "answer": question_id_to_answer.get(q["question_id"], ""),
53 | }
54 | )
55 |
56 | # write the json to the output path
57 | with open(output_path, "w") as f:
58 | json.dump(output_json, f)
59 |
60 |
61 | def fill_vqav2_test_json(
62 | input_path,
63 | output_path,
64 | vqa_test_questions_json_path,
65 | ):
66 | # read the input json and build a set with all question_ids
67 | with open(input_path, "r") as f:
68 | input_json = json.load(f)
69 | question_ids = set()
70 | for q in input_json:
71 | question_ids.add(q["question_id"])
72 |
73 | # make a copy of the input json
74 | output_json = []
75 | for q in input_json:
76 | resAns = q["answer"]
77 | resAns = resAns.replace("\n", " ")
78 | resAns = resAns.replace("\t", " ")
79 | resAns = resAns.strip()
80 | resAns = postprocessor.processPunctuation(resAns)
81 | resAns = postprocessor.processDigitArticle(resAns)
82 | q["answer"] = resAns
83 | output_json.append(q)
84 |
85 |     # read the vqa test json to get all the question_ids that need to be filled
86 | with open(vqa_test_questions_json_path, "r") as f:
87 | vqa_test_json = json.load(f)
88 | vqa_test_json = vqa_test_json["questions"]
89 |
90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer
91 | for q in vqa_test_json:
92 | if q["question_id"] not in question_ids:
93 | output_json.append(
94 | {
95 | "question_id": q["question_id"],
96 | "answer": "",
97 | }
98 | )
99 |
100 | # write the json to the output path
101 | with open(output_path, "w") as f:
102 | json.dump(output_json, f)
103 |
104 |
105 | if __name__ == "__main__":
106 | import argparse
107 |
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument(
110 | "--dataset",
111 | type=str,
112 | choices=["vqav2", "vizwiz"],
113 | )
114 | parser.add_argument(
115 | "--input_path",
116 | type=str,
117 | help="Path to the json file with the subset of the vqa test-dev questions.",
118 | )
119 | parser.add_argument(
120 | "--vqa_test_questions_json_path",
121 | type=str,
122 | help="Path to the json file with all the vqa test questions.",
123 | )
124 | parser.add_argument(
125 | "--output_path",
126 | type=str,
127 | help="Path to store the filled json.",
128 | )
129 | args = parser.parse_args()
130 |
131 | if args.dataset == "vqav2":
132 | fill_vqav2_test_json(
133 | args.input_path,
134 | args.output_path,
135 | args.vqa_test_questions_json_path,
136 | )
137 | else:
138 | fill_vizwiz_test_json(
139 | args.input_path,
140 | args.output_path,
141 | args.vqa_test_questions_json_path,
142 | )
143 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/OmniCorpus/ce4a012c6adf7c298cd6943a7ad875ec1af1ab6b/mllm_openflamingo/open_flamingo/src/__init__.py
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/src/factory.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | import open_clip
5 |
6 | from .flamingo import Flamingo
7 | from .flamingo_lm import FlamingoLMMixin
8 | from .utils import extend_instance
9 |
10 |
11 | def create_model_and_transforms(
12 | clip_vision_encoder_path: str,
13 | clip_vision_encoder_pretrained: str,
14 | lang_encoder_path: str,
15 | tokenizer_path: str,
16 | cross_attn_every_n_layers: int = 1,
17 | use_local_files: bool = False,
18 | decoder_layers_attr_name: str = None,
19 | freeze_lm_embeddings: bool = False,
20 | cache_dir: Optional[str] = None,
21 | **flamingo_kwargs,
22 | ):
23 | """
24 | Initialize a Flamingo model from a pretrained vision encoder and language encoder.
25 | Appends special tokens to the tokenizer and freezes backbones.
26 |
27 | Args:
28 | clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
29 | clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
30 | lang_encoder_path (str): path to pretrained language encoder
31 | tokenizer_path (str): path to pretrained tokenizer
32 | cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
33 | use_local_files (bool, optional): whether to use local files. Defaults to False.
34 | decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
35 | freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
36 | cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
37 | Returns:
38 | Flamingo: Flamingo model from pretrained vision and language encoders
39 | Image processor: Pipeline to preprocess input images
40 | Tokenizer: A tokenizer for the language model
41 | """
42 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
43 | clip_vision_encoder_path,
44 | pretrained=clip_vision_encoder_pretrained,
45 | cache_dir=cache_dir,
46 | )
47 | # set the vision encoder to output the visual features
48 | vision_encoder.visual.output_tokens = True
49 |
50 | text_tokenizer = AutoTokenizer.from_pretrained(
51 | tokenizer_path,
52 | local_files_only=use_local_files,
53 | trust_remote_code=True,
54 | cache_dir=cache_dir,
55 | )
56 | # add Flamingo special tokens to the tokenizer
57 | text_tokenizer.add_special_tokens(
58 |         {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
59 | )
60 | if text_tokenizer.pad_token is None:
61 | # Issue: GPT models don't have a pad token, which we use to
62 | # modify labels for the loss.
63 |         text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
64 |
65 | lang_encoder = AutoModelForCausalLM.from_pretrained(
66 | lang_encoder_path,
67 | local_files_only=use_local_files,
68 | trust_remote_code=True,
69 | cache_dir=cache_dir,
70 | )
71 |
72 | # hacks for MPT-1B, which doesn't have a get_input_embeddings method
73 | if "mpt-1b-redpajama-200b" in lang_encoder_path:
74 |
75 | class EmbeddingFnMixin:
76 | def get_input_embeddings(self):
77 | return self.transformer.wte
78 |
79 | def set_input_embeddings(self, new_embeddings):
80 | self.transformer.wte = new_embeddings
81 |
82 | extend_instance(lang_encoder, EmbeddingFnMixin)
83 |
84 | # convert LM to FlamingoLM
85 | extend_instance(lang_encoder, FlamingoLMMixin)
86 |
87 | if decoder_layers_attr_name is None:
88 | decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
89 | lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
90 | lang_encoder.resize_token_embeddings(len(text_tokenizer))
91 |
92 | model = Flamingo(
93 | vision_encoder,
94 | lang_encoder,
95 | text_tokenizer.encode("<|endofchunk|>")[-1],
96 |         text_tokenizer.encode("<image>")[-1],
97 | vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
98 | "width"
99 | ],
100 | cross_attn_every_n_layers=cross_attn_every_n_layers,
101 | **flamingo_kwargs,
102 | )
103 |
104 | # Freeze all parameters
105 | model.requires_grad_(False)
106 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
107 |
108 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
109 | model.perceiver.requires_grad_(True)
110 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
111 | if not freeze_lm_embeddings:
112 | model.lang_encoder.get_input_embeddings().requires_grad_(True)
113 | # TODO: investigate also training the output embeddings when untied
114 |
115 | print(
116 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
117 | )
118 |
119 | return model, image_processor, text_tokenizer
120 |
121 |
122 | def _infer_decoder_layers_attr_name(model):
123 | for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
124 | if k.lower() in model.__class__.__name__.lower():
125 | return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
126 |
127 | raise ValueError(
128 | f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
129 | )
130 |
131 |
132 | __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
133 | "opt": "model.decoder.layers",
134 | "gptj": "transformer.h",
135 | "gpt-j": "transformer.h",
136 | "pythia": "gpt_neox.layers",
137 | "llama": "model.layers",
138 | "gptneoxforcausallm": "gpt_neox.layers",
139 | "mpt": "transformer.blocks",
140 | "mosaicgpt": "transformer.blocks",
141 | }
142 |
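A usage sketch of `create_model_and_transforms`, following the docstring above; the CLIP and MPT-1B names are the ones referenced in the training README, and the module path matches this repo's layout:

```python
from open_flamingo.src.factory import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)
# The factory leaves only the perceiver, the gated cross-attention layers,
# and (unless frozen) the LM input embeddings trainable.
```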
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/src/utils.py:
--------------------------------------------------------------------------------
1 | def extend_instance(obj, mixin):
2 | """Apply mixins to a class instance after creation"""
3 | base_cls = obj.__class__
4 | base_cls_name = obj.__class__.__name__
5 | obj.__class__ = type(
6 | base_cls_name, (mixin, base_cls), {}
7 | ) # mixin needs to go first for our forward() logic to work
8 |
9 |
10 | def getattr_recursive(obj, att):
11 | """
12 | Return nested attribute of obj
13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
14 | """
15 | if att == "":
16 | return obj
17 | i = att.find(".")
18 | if i < 0:
19 | return getattr(obj, att)
20 | else:
21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
22 |
23 |
24 | def setattr_recursive(obj, att, val):
25 | """
26 | Set nested attribute of obj
27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
28 | """
29 | if "." in att:
30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
31 | setattr(obj, att.split(".")[-1], val)
32 |
33 |
34 | def apply_with_stopping_condition(
35 | module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
36 | ):
37 | if stopping_condition(module):
38 | return
39 | if apply_condition(module):
40 | apply_fn(module, **other_args)
41 | for child in module.children():
42 | apply_with_stopping_condition(
43 | child,
44 | apply_fn,
45 | apply_condition=apply_condition,
46 | stopping_condition=stopping_condition,
47 | **other_args
48 | )
49 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/train/README.md:
--------------------------------------------------------------------------------
1 | # OpenFlamingo Training
2 | To train OpenFlamingo, please ensure your environment matches that of `environment.yml`.
3 |
4 | ## Data
5 | Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag.
6 |
7 | ### LAION-2B Dataset
8 | [LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs.
9 | We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files.
10 |
11 | ### Multimodal C4 Dataset
12 | We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence.
13 |
14 | Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64.
15 | We provide scripts to convert MMC4 to this format:
16 |
17 | 1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`).
18 | 2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`).
19 | 3. Run `scripts/convert_mmc4_to_wds.py` to convert the downloaded items into the expected tar files.
20 |
21 | ### ChatGPT-generated sequences
22 | A subset of our models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the urls in the json files and converted to base64 before using this data for training in our codebase.
23 |
24 | Models trained with ChatGPT-generated sequences:
25 |
26 | * OpenFlamingo-4B-vitl-rpj3b
27 | * OpenFlamingo-4B-vitl-rpj3b-langinstruct
28 |
29 | ## Example training command
30 | We provide a sample Slurm training script in `scripts/`. You can also modify the following command:
31 |
32 | ```
33 | torchrun --nnodes=1 --nproc_per_node=4 train.py \
34 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \
35 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
36 | --cross_attn_every_n_layers 1 \
37 | --dataset_resampled \
38 | --batch_size_interleaved 32 \
39 | --batch_size_laion 64 \
40 |   --train_num_samples_interleaved 125000 \
41 | --train_num_samples_laion 250000 \
42 | --loss_multiplier_laion 0.2 \
43 | --workers=4 \
44 | --run_name OpenFlamingo-3B-vitl-mpt1b \
45 | --num_epochs 480 \
46 | --warmup_steps 1875 \
47 | --mmc4_textsim_threshold 0.24 \
48 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
49 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
50 | --report_to_wandb
51 | ```
52 | *Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).*
53 |
54 | ## Distributed training
55 |
56 | By default, `train.py` uses Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training.
57 | To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag.
58 |
59 | Some notes on FSDP:
60 |
61 | * We recommend using the `--fsdp_use_orig_params` flag. If `--fsdp` is on without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to only train the newly added `<image>` and `<|endofchunk|>` tokens.)
62 | * Note: we've encountered issues using OPT with this flag. Other language models should be compatible.
63 | * Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input / output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the `--freeze_lm_embeddings` flag.
64 |
65 | We also implement gradient checkpointing and mixed precision training. Use the `--gradient_checkpointing` and `--precision` arguments respectively.
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/train/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mllm_openflamingo/open_flamingo/train/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | Util functions for setting up distributed training.
3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py
4 | """
5 |
6 | import os
7 | import torch
8 |
9 | try:
10 | import horovod.torch as hvd
11 | except ImportError:
12 | hvd = None
13 |
14 |
15 | def is_global_master(args):
16 | return args.rank == 0
17 |
18 |
19 | def is_local_master(args):
20 | return args.local_rank == 0
21 |
22 |
23 | def is_master(args, local=False):
24 | return is_local_master(args) if local else is_global_master(args)
25 |
26 |
27 | def is_using_horovod():
28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
32 | if all([var in os.environ for var in ompi_vars]) or all(
33 | [var in os.environ for var in pmi_vars]
34 | ):
35 | return True
36 | else:
37 | return False
38 |
39 |
40 | def is_using_distributed():
41 | if "WORLD_SIZE" in os.environ:
42 | return int(os.environ["WORLD_SIZE"]) > 1
43 | if "SLURM_NTASKS" in os.environ:
44 | return int(os.environ["SLURM_NTASKS"]) > 1
45 | return False
46 |
47 |
48 | def world_info_from_env():
49 | local_rank = 0
50 | for v in (
51 | "LOCAL_RANK",
52 | "MPI_LOCALRANKID",
53 | "SLURM_LOCALID",
54 | "OMPI_COMM_WORLD_LOCAL_RANK",
55 | ):
56 | if v in os.environ:
57 | local_rank = int(os.environ[v])
58 | break
59 | global_rank = 0
60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
61 | if v in os.environ:
62 | global_rank = int(os.environ[v])
63 | break
64 | world_size = 1
65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
66 | if v in os.environ:
67 | world_size = int(os.environ[v])
68 | break
69 |
70 | return local_rank, global_rank, world_size
71 |
72 |
73 | def init_distributed_device(args):
74 | # Distributed training = training on more than one GPU.
75 | # Works in both single and multi-node scenarios.
76 | args.distributed = False
77 | args.world_size = 1
78 | args.rank = 0 # global rank
79 | args.local_rank = 0
80 | if args.horovod:
81 | assert hvd is not None, "Horovod is not installed"
82 | hvd.init()
83 | args.local_rank = int(hvd.local_rank())
84 | args.rank = hvd.rank()
85 | args.world_size = hvd.size()
86 | args.distributed = True
87 | os.environ["LOCAL_RANK"] = str(args.local_rank)
88 | os.environ["RANK"] = str(args.rank)
89 | os.environ["WORLD_SIZE"] = str(args.world_size)
90 | elif is_using_distributed():
91 | if "SLURM_PROCID" in os.environ:
92 | # DDP via SLURM
93 | args.local_rank, args.rank, args.world_size = world_info_from_env()
94 | # SLURM var -> torch.distributed vars in case needed
95 | os.environ["LOCAL_RANK"] = str(args.local_rank)
96 | os.environ["RANK"] = str(args.rank)
97 | os.environ["WORLD_SIZE"] = str(args.world_size)
98 | torch.distributed.init_process_group(
99 | backend=args.dist_backend,
100 | init_method=args.dist_url,
101 | world_size=args.world_size,
102 | rank=args.rank,
103 | )
104 | else:
105 | # DDP via torchrun, torch.distributed.launch
106 | args.local_rank, _, _ = world_info_from_env()
107 | torch.distributed.init_process_group(
108 | backend=args.dist_backend, init_method=args.dist_url
109 | )
110 | args.world_size = torch.distributed.get_world_size()
111 | args.rank = torch.distributed.get_rank()
112 | args.distributed = True
113 | else:
114 | # needed to run on single gpu
115 | torch.distributed.init_process_group(
116 | backend=args.dist_backend,
117 | init_method=args.dist_url,
118 | world_size=1,
119 | rank=0,
120 | )
121 |
122 | if torch.cuda.is_available():
123 | if args.distributed and not args.no_set_device_rank:
124 | device = "cuda:%d" % args.local_rank
125 | else:
126 | device = "cuda:0"
127 | torch.cuda.set_device(device)
128 | else:
129 | device = "cpu"
130 | args.device = device
131 | device = torch.device(device)
132 | return device
133 |
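A minimal check of the rank helpers above; it only reads environment variables, so it can be run outside a launcher (the printed values depend on how the process was started):

```python
from open_flamingo.train.distributed import is_using_distributed, world_info_from_env

# Under torchrun or SLURM these env vars are populated; in a plain shell they are not.
local_rank, global_rank, world_size = world_info_from_env()
print(f"local_rank={local_rank} rank={global_rank} world_size={world_size}")
print("distributed run detected:", is_using_distributed())
```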
--------------------------------------------------------------------------------
/mllm_openflamingo/requirements-eval.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | torchvision
3 | nltk
4 | inflection
5 | pycocoevalcap
6 | pycocotools
7 | tqdm
8 | scikit-learn
9 | black
10 | mypy
11 | pylint
12 | pytest
13 | requests
14 |
--------------------------------------------------------------------------------
/mllm_openflamingo/requirements-training.txt:
--------------------------------------------------------------------------------
1 | torchvision
2 | braceexpand
3 | webdataset
4 | tqdm
5 | wandb
6 |
--------------------------------------------------------------------------------
/mllm_openflamingo/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | einops-exts
3 | transformers==4.28.1
4 | torch==2.0.1
5 | pillow
6 | open_clip_torch>=2.16.0
7 | sentencepiece
8 |
--------------------------------------------------------------------------------
/mllm_openflamingo/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from setuptools import find_packages, setup
4 |
5 | if __name__ == "__main__":
6 | with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file:
7 | long_description = file.read()
8 |
9 | REQUIREMENTS = [
10 | "einops",
11 | "einops-exts",
12 | "transformers>=4.28.1",
13 | "torch==2.0.1",
14 | "pillow",
15 | "open_clip_torch>=2.16.0",
16 | "sentencepiece",
17 | ]
18 |
19 | EVAL = [
20 | "scipy",
21 | "torchvision",
22 | "nltk",
23 | "inflection",
24 | "pycocoevalcap",
25 | "pycocotools",
26 | "tqdm",
27 | ]
28 |
29 | TRAINING = [
30 | "wandb",
31 | "torchvision",
32 | "braceexpand",
33 | "webdataset",
34 | "tqdm",
35 | "deepspeed==0.14.2",
36 | "accelerate==0.30.0",
37 | ]
38 |
39 | setup(
40 | name="open_flamingo",
41 | packages=find_packages(),
42 | include_package_data=True,
43 | version="2.0.1",
44 | license="MIT",
45 | description="An open-source framework for training large multimodal models",
46 | long_description=long_description,
47 | long_description_content_type="text/markdown",
48 | data_files=[(".", ["README.md"])],
49 | keywords=["machine learning"],
50 | install_requires=REQUIREMENTS,
51 | extras_require={
52 | "eval": EVAL,
53 | "training": TRAINING,
54 | "all": list(set(REQUIREMENTS + EVAL + TRAINING)),
55 | },
56 | classifiers=[
57 | "Development Status :: 4 - Beta",
58 | "Intended Audience :: Developers",
59 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
60 | "License :: OSI Approved :: MIT License",
61 | "Programming Language :: Python :: 3.9",
62 | ],
63 | )
64 |
--------------------------------------------------------------------------------
/mllm_openflamingo/setup_deepspeed_adamcpu.py:
--------------------------------------------------------------------------------
1 | import deepspeed
2 | deepspeed.ops.op_builder.CPUAdamBuilder().load()
--------------------------------------------------------------------------------